diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs index 19888b30..4964abcb 100644 --- a/DnsServerCore/DnsServer.cs +++ b/DnsServerCore/DnsServer.cs @@ -17,12 +17,17 @@ along with this program. If not, see . */ +using Newtonsoft.Json; using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Collections.Specialized; using System.IO; using System.Net; +using System.Net.Security; using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Text; using System.Threading; using TechnitiumLibrary.IO; using TechnitiumLibrary.Net; @@ -47,14 +52,20 @@ namespace DnsServerCore #region variables - const int UDP_LISTENER_THREAD_COUNT = 3; + const int LISTENER_THREAD_COUNT = 3; - IPEndPoint[] _localEPs; + IPAddress[] _localIPs; List _udpListeners = new List(); List _tcpListeners = new List(); + List _tlsListeners = new List(); + List _httpsListeners = new List(); List _listenerThreads = new List(); + bool _enableDoT = false; + bool _enableDoH = false; + X509Certificate2 _certificate; + readonly Zone _authoritativeZoneRoot = new Zone(true); readonly Zone _cacheZoneRoot = new Zone(false); readonly Zone _allowedZoneRoot = new Zone(true); @@ -66,8 +77,8 @@ namespace DnsServerCore bool _allowRecursionOnlyForPrivateNetworks = false; NetProxy _proxy; NameServerAddress[] _forwarders; - DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp; - DnsClientProtocol _recursiveResolveProtocol = DnsClientProtocol.Udp; + DnsTransportProtocol _forwarderProtocol = DnsTransportProtocol.Udp; + DnsTransportProtocol _recursiveResolveProtocol = DnsTransportProtocol.Udp; bool _preferIPv6 = false; int _retries = 3; int _timeout = 2000; @@ -102,30 +113,16 @@ namespace DnsServerCore } public DnsServer() - : this(new IPEndPoint[] { new IPEndPoint(IPAddress.Any, 53), new IPEndPoint(IPAddress.IPv6Any, 53) }) + : this(new IPAddress[] { IPAddress.Any, IPAddress.IPv6Any }) { } public DnsServer(IPAddress localIP) - : this(new IPEndPoint(localIP, 53)) - { } - - public DnsServer(IPEndPoint localEP) - : this(new IPEndPoint[] { localEP }) + : this(new IPAddress[] { localIP }) { } public DnsServer(IPAddress[] localIPs) { - _localEPs = new IPEndPoint[localIPs.Length]; - - for (int i = 0; i < _localEPs.Length; i++) - _localEPs[i] = new IPEndPoint(localIPs[i], 53); - - _dnsCache = new DnsCache(_cacheZoneRoot); - } - - public DnsServer(IPEndPoint[] localEPs) - { - _localEPs = localEPs; + _localIPs = localIPs; _dnsCache = new DnsCache(_cacheZoneRoot); } @@ -212,7 +209,7 @@ namespace DnsServerCore { LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, false, ex); + log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex); } } } @@ -224,7 +221,7 @@ namespace DnsServerCore LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, false, ex); + log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex); throw; } @@ -240,7 +237,7 @@ namespace DnsServerCore try { - DnsDatagram response = ProcessQuery(request, remoteEP, false); + DnsDatagram response = ProcessQuery(request, remoteEP, DnsTransportProtocol.Udp); //send response if (response != null) @@ -266,7 +263,7 @@ namespace DnsServerCore LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write(remoteEP as IPEndPoint, false, request, response); + queryLog.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, request, response); StatsManager stats = _stats; if (stats != null) @@ -280,17 +277,21 @@ namespace DnsServerCore LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write(remoteEP as IPEndPoint, false, request, null); + queryLog.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, request, null); LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, false, ex); + log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex); } } - private void AcceptTcpConnectionAsync(object parameter) + private void AcceptConnectionAsync(object parameter) { - Socket tcpListener = parameter as Socket; + object[] parameters = parameter as object[]; + + Socket tcpListener = parameters[0] as Socket; + DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[1]; + IPEndPoint localEP = tcpListener.LocalEndPoint as IPEndPoint; try @@ -305,7 +306,51 @@ namespace DnsServerCore { Socket socket = tcpListener.Accept(); - ThreadPool.QueueUserWorkItem(ReadTcpRequestAsync, socket); + ThreadPool.QueueUserWorkItem(delegate (object state) + { + EndPoint remoteEP = null; + + try + { + remoteEP = socket.RemoteEndPoint; + + switch (protocol) + { + case DnsTransportProtocol.Tcp: + ReadStreamRequest(new NetworkStream(socket), remoteEP, protocol); + break; + + case DnsTransportProtocol.Tls: + SslStream tlsStream = new SslStream(new NetworkStream(socket)); + tlsStream.AuthenticateAsServer(_certificate); + + ReadStreamRequest(tlsStream, remoteEP, protocol); + break; + + case DnsTransportProtocol.Https: + SslStream httpsStream = new SslStream(new NetworkStream(socket)); + httpsStream.AuthenticateAsServer(_certificate); + + ProcessHttpsRequest(httpsStream, remoteEP); + break; + } + } + catch (IOException) + { + //ignore IO exceptions + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP as IPEndPoint, protocol, ex); + } + finally + { + if (socket != null) + socket.Dispose(); + } + }); } } catch (Exception ex) @@ -315,24 +360,21 @@ namespace DnsServerCore LogManager log = _log; if (log != null) - log.Write(localEP, true, ex); + log.Write(localEP, protocol, ex); throw; } } - private void ReadTcpRequestAsync(object parameter) + private void ReadStreamRequest(Stream stream, EndPoint remoteEP, DnsTransportProtocol protocol) { - Socket tcpSocket = parameter as Socket; DnsDatagram request = null; - EndPoint remoteEP = null; try { - remoteEP = tcpSocket.RemoteEndPoint as IPEndPoint; - Stream tcpStream = new WriteBufferedStream(new NetworkStream(tcpSocket), 2048); - OffsetStream recvDatagramStream = new OffsetStream(tcpStream, 0, 0); - MemoryStream sendBufferStream = new MemoryStream(64); + OffsetStream recvDatagramStream = new OffsetStream(stream, 0, 0); + Stream writeBufferedStream = new WriteBufferedStream(stream, 2048); + MemoryStream writeBuffer = new MemoryStream(64); byte[] lengthBuffer = new byte[2]; ushort length; @@ -341,7 +383,7 @@ namespace DnsServerCore request = null; //read dns datagram length - tcpStream.ReadBytes(lengthBuffer, 0, 2); + stream.ReadBytes(lengthBuffer, 0, 2); Array.Reverse(lengthBuffer, 0, 2); length = BitConverter.ToUInt16(lengthBuffer, 0); @@ -350,7 +392,7 @@ namespace DnsServerCore request = new DnsDatagram(recvDatagramStream); //process request async - ThreadPool.QueueUserWorkItem(ProcessTcpRequestAsync, new object[] { remoteEP, tcpStream, request, sendBufferStream }); + ThreadPool.QueueUserWorkItem(ProcessStreamRequestAsync, new object[] { writeBufferedStream, writeBuffer, remoteEP, request, protocol }); } } catch (IOException) @@ -361,57 +403,53 @@ namespace DnsServerCore { LogManager queryLog = _queryLog; if ((queryLog != null) && (request != null)) - queryLog.Write(remoteEP as IPEndPoint, true, request, null); + queryLog.Write(remoteEP as IPEndPoint, protocol, request, null); LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, true, ex); - } - finally - { - if (tcpSocket != null) - tcpSocket.Dispose(); + log.Write(remoteEP as IPEndPoint, protocol, ex); } } - private void ProcessTcpRequestAsync(object parameter) + private void ProcessStreamRequestAsync(object parameter) { object[] parameters = parameter as object[]; - EndPoint remoteEP = parameters[0] as EndPoint; - Stream tcpStream = parameters[1] as Stream; - DnsDatagram request = parameters[2] as DnsDatagram; - MemoryStream sendBufferStream = parameters[3] as MemoryStream; + Stream stream = parameters[0] as Stream; + MemoryStream writeBuffer = parameters[1] as MemoryStream; + EndPoint remoteEP = parameters[2] as EndPoint; + DnsDatagram request = parameters[3] as DnsDatagram; + DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[4]; try { - DnsDatagram response = ProcessQuery(request, remoteEP, true); + DnsDatagram response = ProcessQuery(request, remoteEP, protocol); //send response if (response != null) { - lock (tcpStream) + lock (stream) { //write dns datagram - sendBufferStream.Position = 0; - response.WriteTo(sendBufferStream); + writeBuffer.Position = 0; + response.WriteTo(writeBuffer); //write dns datagram length - ushort length = Convert.ToUInt16(sendBufferStream.Position); + ushort length = Convert.ToUInt16(writeBuffer.Position); byte[] lengthBuffer = BitConverter.GetBytes(length); Array.Reverse(lengthBuffer, 0, 2); - tcpStream.Write(lengthBuffer); + stream.Write(lengthBuffer); //send dns datagram - sendBufferStream.Position = 0; - sendBufferStream.CopyTo(tcpStream, 512, length); + writeBuffer.Position = 0; + writeBuffer.CopyTo(stream, 512, length); - tcpStream.Flush(); + stream.Flush(); } LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write(remoteEP as IPEndPoint, true, request, response); + queryLog.Write(remoteEP as IPEndPoint, protocol, request, response); StatsManager stats = _stats; if (stats != null) @@ -426,11 +464,309 @@ namespace DnsServerCore { LogManager queryLog = _queryLog; if ((queryLog != null) && (request != null)) - queryLog.Write(remoteEP as IPEndPoint, true, request, null); + queryLog.Write(remoteEP as IPEndPoint, protocol, request, null); LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, true, ex); + log.Write(remoteEP as IPEndPoint, protocol, ex); + } + } + + private void ProcessHttpsRequest(Stream stream, EndPoint remoteEP) + { + DnsDatagram dnsRequest = null; + DnsTransportProtocol dnsProtocol = DnsTransportProtocol.Https; + + try + { + while (true) + { + string requestMethod; + string requestPath; + NameValueCollection requestQueryString = new NameValueCollection(); + string requestProtocol; + WebHeaderCollection requestHeaders = new WebHeaderCollection(); + + #region parse http request + + using (MemoryStream mS = new MemoryStream()) + { + //read http request header into memory stream + int byteRead; + int crlfCount = 0; + + while (true) + { + byteRead = stream.ReadByte(); + switch (byteRead) + { + case '\r': + case '\n': + crlfCount++; + break; + + case -1: + throw new EndOfStreamException(); + + default: + crlfCount = 0; + break; + } + + mS.WriteByte((byte)byteRead); + + if (crlfCount == 4) + break; //http request completed + } + + mS.Position = 0; + StreamReader sR = new StreamReader(mS); + + string[] requestParts = sR.ReadLine().Split(new char[] { ' ' }, 3); + + if (requestParts.Length != 3) + throw new InvalidDataException("Invalid HTTP request."); + + requestMethod = requestParts[0]; + string pathAndQueryString = requestParts[1]; + requestProtocol = requestParts[2]; + + string[] requestPathAndQueryParts = pathAndQueryString.Split(new char[] { '?' }, 2); + + requestPath = requestPathAndQueryParts[0]; + + string queryString = null; + if (requestPathAndQueryParts.Length > 1) + queryString = requestPathAndQueryParts[1]; + + if (!string.IsNullOrEmpty(queryString)) + { + foreach (string item in queryString.Split(new char[] { '&' }, StringSplitOptions.RemoveEmptyEntries)) + { + string[] itemParts = item.Split(new char[] { '=' }, 2); + + string name = itemParts[0]; + string value = null; + + if (itemParts.Length > 1) + value = itemParts[1]; + + requestQueryString.Add(name, value); + } + } + + while (true) + { + string line = sR.ReadLine(); + if (string.IsNullOrEmpty(line)) + break; + + string[] parts = line.Split(new char[] { ':' }, 2); + if (parts.Length != 2) + throw new InvalidDataException("Invalid HTTP request."); + + requestHeaders.Add(parts[0], parts[1]); + } + } + + #endregion + + string requestConnection = requestHeaders[HttpRequestHeader.Connection]; + if (string.IsNullOrEmpty(requestConnection)) + requestConnection = "close"; + + if (requestPath != "/dns-query") + { + Send404(stream); + return; + } + + DnsTransportProtocol protocol = DnsTransportProtocol.Udp; + + string strRequestAcceptTypes = requestHeaders[HttpRequestHeader.Accept]; + if (!string.IsNullOrEmpty(strRequestAcceptTypes)) + { + protocol = DnsTransportProtocol.Udp; + + foreach (string acceptType in strRequestAcceptTypes.Split(',')) + { + if (acceptType == "application/dns-message") + { + protocol = DnsTransportProtocol.Https; + break; + } + else if (acceptType == "application/dns-json") + { + protocol = DnsTransportProtocol.HttpsJson; + dnsProtocol = DnsTransportProtocol.HttpsJson; + break; + } + } + } + + switch (protocol) + { + case DnsTransportProtocol.Https: + #region https wire format + { + switch (requestMethod) + { + case "GET": + string strRequest = requestQueryString["dns"]; + if (string.IsNullOrEmpty(strRequest)) + throw new ArgumentNullException("dns"); + + //convert from base64url to base64 + strRequest = strRequest.Replace('-', '+'); + strRequest = strRequest.Replace('_', '/'); + + //add padding + int x = strRequest.Length % 4; + if (x > 0) + strRequest = strRequest.PadRight(strRequest.Length - x + 4, '='); + + dnsRequest = new DnsDatagram(new MemoryStream(Convert.FromBase64String(strRequest))); + break; + + case "POST": + string strContentType = requestHeaders[HttpRequestHeader.ContentType]; + if (strContentType != "application/dns-message") + throw new NotSupportedException("DNS request type not supported: " + strContentType); + + dnsRequest = new DnsDatagram(stream); + break; + + default: + throw new NotSupportedException("DoH request type not supported."); ; + } + + DnsDatagram dnsResponse = ProcessQuery(dnsRequest, remoteEP, protocol); + if (dnsResponse != null) + { + using (MemoryStream mS = new MemoryStream()) + { + dnsResponse.WriteTo(mS); + + byte[] buffer = mS.ToArray(); + Send200(stream, "application/dns-message", buffer); + } + + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write(remoteEP as IPEndPoint, protocol, dnsRequest, dnsResponse); + + StatsManager stats = _stats; + if (stats != null) + stats.Update(dnsResponse, (remoteEP as IPEndPoint).Address); + } + } + #endregion + break; + + case DnsTransportProtocol.HttpsJson: + #region https json format + { + string strName = requestQueryString["name"]; + if (string.IsNullOrEmpty(strName)) + throw new ArgumentNullException("name"); + + string strType = requestQueryString["type"]; + if (string.IsNullOrEmpty(strType)) + strType = "1"; + + dnsRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord(strName, (DnsResourceRecordType)int.Parse(strType), DnsClass.IN) }, null, null, null); + + DnsDatagram dnsResponse = ProcessQuery(dnsRequest, remoteEP, protocol); + if (dnsResponse != null) + { + using (MemoryStream mS = new MemoryStream()) + { + JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); + dnsResponse.WriteTo(jsonWriter); + jsonWriter.Flush(); + + byte[] buffer = mS.ToArray(); + Send200(stream, "application/dns-json; charset=utf-8", buffer); + } + + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write(remoteEP as IPEndPoint, protocol, dnsRequest, dnsResponse); + + StatsManager stats = _stats; + if (stats != null) + stats.Update(dnsResponse, (remoteEP as IPEndPoint).Address); + } + } + #endregion + break; + + default: + Send406(stream, "Only application/dns-message and application/dns-json types are accepted."); + return; + } + + if (requestConnection.Equals("close", StringComparison.CurrentCultureIgnoreCase)) + break; + } + } + catch (IOException) + { + //ignore IO exceptions + } + catch (Exception ex) + { + LogManager queryLog = _queryLog; + if ((queryLog != null) && (dnsRequest != null)) + queryLog.Write(remoteEP as IPEndPoint, dnsProtocol, dnsRequest, null); + + LogManager log = _log; + if (log != null) + log.Write(remoteEP as IPEndPoint, dnsProtocol, ex); + } + } + + private static void Send404(Stream outputStream) + { + byte[] bufferContent = Encoding.UTF8.GetBytes("

404 Not Found

"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 404 Not Found\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); + + using (MemoryStream mS = new MemoryStream()) + { + mS.Write(bufferHeader, 0, bufferHeader.Length); + mS.Write(bufferContent, 0, bufferContent.Length); + + byte[] buffer = mS.ToArray(); + outputStream.Write(buffer, 0, buffer.Length); + } + } + + private static void Send406(Stream outputStream, string message) + { + byte[] bufferContent = Encoding.UTF8.GetBytes("

406 Not Acceptable

" + message + "

"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 406 Not Acceptable\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); + + using (MemoryStream mS = new MemoryStream()) + { + mS.Write(bufferHeader, 0, bufferHeader.Length); + mS.Write(bufferContent, 0, bufferContent.Length); + + byte[] buffer = mS.ToArray(); + outputStream.Write(buffer, 0, buffer.Length); + } + } + + private static void Send200(Stream outputStream, string contentType, byte[] bufferContent) + { + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); + + using (MemoryStream mS = new MemoryStream()) + { + mS.Write(bufferHeader, 0, bufferHeader.Length); + mS.Write(bufferContent, 0, bufferContent.Length); + + byte[] buffer = mS.ToArray(); + outputStream.Write(buffer, 0, buffer.Length); } } @@ -455,7 +791,7 @@ namespace DnsServerCore return true; } - private DnsDatagram ProcessQuery(DnsDatagram request, EndPoint remoteEP, bool tcp) + internal DnsDatagram ProcessQuery(DnsDatagram request, EndPoint remoteEP, DnsTransportProtocol protocol) { if (request.Header.IsResponse) return null; @@ -536,7 +872,7 @@ namespace DnsServerCore { LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, tcp, ex); + log.Write(remoteEP as IPEndPoint, protocol, ex); return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } @@ -560,6 +896,7 @@ namespace DnsServerCore if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) { + //resolve cname record List responseAnswer = new List(); responseAnswer.AddRange(response.Answer); @@ -570,26 +907,42 @@ namespace DnsServerCore { DnsDatagram cnameRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN) }, null, null, null); + //query authoritative zone first lastResponse = _authoritativeZoneRoot.Query(cnameRequest); if (lastResponse.Header.RCODE == DnsResponseCode.Refused) { - if (!cnameRequest.Header.RecursionDesired || !isRecursionAllowed) - break; + //not found in auth zone + if (!isRecursionAllowed || !cnameRequest.Header.RecursionDesired) + break; //break since no recursion allowed/desired + //do recursion lastResponse = ProcessRecursiveQuery(cnameRequest); cacheHit &= ("cacheHit".Equals(lastResponse.Tag)); } + else if ((lastResponse.Header.RCODE == DnsResponseCode.NoError) && (lastResponse.Answer.Length == 0) && (lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.NS)) + { + //found delegated zone + if (!isRecursionAllowed || !cnameRequest.Header.RecursionDesired) + break; //break since no recursion allowed/desired + //do recursive resolution using delegated authority name servers + NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(lastResponse, _preferIPv6); + + lastResponse = ProcessRecursiveQuery(cnameRequest, nameServers); + cacheHit &= ("cacheHit".Equals(lastResponse.Tag)); + } + + //check last response if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) - break; + break; //cannot proceed to resolve cname further responseAnswer.AddRange(lastResponse.Answer); lastRR = lastResponse.Answer[lastResponse.Answer.Length - 1]; if (lastRR.Type != DnsResourceRecordType.CNAME) - break; + break; //cname was resolved } DnsResponseCode rcode; @@ -628,7 +981,7 @@ namespace DnsServerCore else if ((response.Authority.Length > 0) && (response.Authority[0].Type == DnsResourceRecordType.NS) && isRecursionAllowed) { //do recursive resolution using response authority name servers - NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false); + NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6); return ProcessRecursiveQuery(request, nameServers); } @@ -748,7 +1101,7 @@ namespace DnsServerCore } //select protocol - DnsClientProtocol protocol; + DnsTransportProtocol protocol; if ((viaNameServers == null) && (_forwarders != null)) { @@ -793,9 +1146,11 @@ namespace DnsServerCore _state = ServiceState.Starting; //bind on all local end points - for (int i = 0; i < _localEPs.Length; i++) + for (int i = 0; i < _localIPs.Length; i++) { - Socket udpListener = new Socket(_localEPs[i].AddressFamily, SocketType.Dgram, ProtocolType.Udp); + IPEndPoint dnsEP = new IPEndPoint(_localIPs[i], 53); + + Socket udpListener = new Socket(dnsEP.AddressFamily, SocketType.Dgram, ProtocolType.Udp); #region this code ignores ICMP port unreachable responses which creates SocketException in ReceiveFrom() @@ -812,50 +1167,102 @@ namespace DnsServerCore try { - udpListener.Bind(_localEPs[i]); + udpListener.Bind(dnsEP); _udpListeners.Add(udpListener); LogManager log = _log; if (log != null) - log.Write(_localEPs[i], false, "DNS Server was bound successfully."); + log.Write(dnsEP, DnsTransportProtocol.Udp, "DNS Server was bound successfully."); } catch (Exception ex) { LogManager log = _log; if (log != null) - log.Write(_localEPs[i], false, ex); + log.Write(dnsEP, DnsTransportProtocol.Udp, ex); udpListener.Dispose(); } - Socket tcpListener = new Socket(_localEPs[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp); + Socket tcpListener = new Socket(dnsEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); try { - tcpListener.Bind(_localEPs[i]); + tcpListener.Bind(dnsEP); tcpListener.Listen(100); _tcpListeners.Add(tcpListener); LogManager log = _log; if (log != null) - log.Write(_localEPs[i], true, "DNS Server was bound successfully."); + log.Write(dnsEP, DnsTransportProtocol.Tcp, "DNS Server was bound successfully."); } catch (Exception ex) { LogManager log = _log; if (log != null) - log.Write(_localEPs[i], true, ex); + log.Write(dnsEP, DnsTransportProtocol.Tcp, ex); tcpListener.Dispose(); } + + if (_enableDoT && (_certificate != null)) + { + IPEndPoint tlsEP = new IPEndPoint(_localIPs[i], 853); + Socket tlsListener = new Socket(tlsEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + try + { + tlsListener.Bind(tlsEP); + tlsListener.Listen(100); + + _tlsListeners.Add(tlsListener); + + LogManager log = _log; + if (log != null) + log.Write(tlsEP, DnsTransportProtocol.Tls, "DNS Server was bound successfully."); + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(tlsEP, DnsTransportProtocol.Tls, ex); + + tlsListener.Dispose(); + } + } + + if (_enableDoH && (_certificate != null)) + { + IPEndPoint httpsEP = new IPEndPoint(_localIPs[i], 443); + Socket httpsListener = new Socket(httpsEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + try + { + httpsListener.Bind(httpsEP); + httpsListener.Listen(100); + + _httpsListeners.Add(httpsListener); + + LogManager log = _log; + if (log != null) + log.Write(httpsEP, DnsTransportProtocol.Https, "DNS Server was bound successfully."); + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(httpsEP, DnsTransportProtocol.Https, ex); + + httpsListener.Dispose(); + } + } } //start reading query packets foreach (Socket udpListener in _udpListeners) { - for (int i = 0; i < UDP_LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < LISTENER_THREAD_COUNT; i++) { Thread listenerThread = new Thread(ReadUdpRequestAsync); listenerThread.IsBackground = true; @@ -867,11 +1274,38 @@ namespace DnsServerCore foreach (Socket tcpListener in _tcpListeners) { - Thread listenerThread = new Thread(AcceptTcpConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(tcpListener); + for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + { + Thread listenerThread = new Thread(AcceptConnectionAsync); + listenerThread.IsBackground = true; + listenerThread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp }); - _listenerThreads.Add(listenerThread); + _listenerThreads.Add(listenerThread); + } + } + + foreach (Socket tlsListener in _tlsListeners) + { + for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + { + Thread listenerThread = new Thread(AcceptConnectionAsync); + listenerThread.IsBackground = true; + listenerThread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls }); + + _listenerThreads.Add(listenerThread); + } + } + + foreach (Socket httpsListener in _httpsListeners) + { + for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + { + Thread listenerThread = new Thread(AcceptConnectionAsync); + listenerThread.IsBackground = true; + listenerThread.Start(new object[] { httpsListener, DnsTransportProtocol.Https }); + + _listenerThreads.Add(listenerThread); + } } _state = ServiceState.Running; @@ -890,9 +1324,17 @@ namespace DnsServerCore foreach (Socket tcpListener in _tcpListeners) tcpListener.Dispose(); + foreach (Socket tlsListener in _tlsListeners) + tlsListener.Dispose(); + + foreach (Socket httpsListener in _httpsListeners) + httpsListener.Dispose(); + _listenerThreads.Clear(); _udpListeners.Clear(); _tcpListeners.Clear(); + _tlsListeners.Clear(); + _httpsListeners.Clear(); _state = ServiceState.Stopped; } @@ -901,16 +1343,10 @@ namespace DnsServerCore #region properties - public IPEndPoint[] LocalEndPoints + public IPAddress[] LocalAddresses { - get { return _localEPs; } - set - { - if (_state != ServiceState.Stopped) - throw new InvalidOperationException("DNS Server is already running."); - - _localEPs = value; - } + get { return _localIPs; } + set { _localIPs = value; } } public string ServerDomain @@ -924,6 +1360,30 @@ namespace DnsServerCore } } + public bool EnableDoT + { + get { return _enableDoT; } + set { _enableDoT = value; } + } + + public bool EnableDoH + { + get { return _enableDoH; } + set { _enableDoH = value; } + } + + public X509Certificate2 Certificate + { + get { return _certificate; } + set + { + if (!value.HasPrivateKey) + throw new ArgumentException("Tls certificate does not contain private key."); + + _certificate = value; + } + } + public Zone AuthoritativeZoneRoot { get { return _authoritativeZoneRoot; } } @@ -976,13 +1436,13 @@ namespace DnsServerCore set { _forwarders = value; } } - public DnsClientProtocol ForwarderProtocol + public DnsTransportProtocol ForwarderProtocol { get { return _forwarderProtocol; } set { _forwarderProtocol = value; } } - public DnsClientProtocol RecursiveResolveProtocol + public DnsTransportProtocol RecursiveResolveProtocol { get { return _recursiveResolveProtocol; } set { _recursiveResolveProtocol = value; }