diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index cc014787..3c2b345d 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -22,7 +22,6 @@ using DnsServerCore.Dns.Zones; using Newtonsoft.Json; using System; using System.Collections.Generic; -using System.Collections.Specialized; using System.IO; using System.Net; using System.Net.Security; @@ -30,10 +29,12 @@ using System.Net.Sockets; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; +using System.Threading.Tasks; using TechnitiumLibrary.IO; using TechnitiumLibrary.Net; using TechnitiumLibrary.Net.Dns; using TechnitiumLibrary.Net.Dns.ResourceRecords; +using TechnitiumLibrary.Net.Http; using TechnitiumLibrary.Net.Proxy; namespace DnsServerCore.Dns @@ -54,7 +55,6 @@ namespace DnsServerCore.Dns #region variables - const int LISTENER_THREAD_COUNT = 4; const int MAX_CNAME_HOPS = 16; string _serverDomain; @@ -69,7 +69,6 @@ namespace DnsServerCore.Dns readonly List _httpListeners = new List(); readonly List _tlsListeners = new List(); readonly List _httpsListeners = new List(); - readonly List _listenerThreads = new List(); bool _enableDnsOverHttp = false; bool _enableDnsOverTls = false; @@ -93,9 +92,13 @@ namespace DnsServerCore.Dns NetProxy _proxy; IReadOnlyList _forwarders; bool _preferIPv6 = false; - int _retries = 2; - int _timeout = 4000; - int _maxStackCount = 10; + int _forwarderRetries = 3; + int _resolverRetries = 5; + int _forwarderTimeout = 4000; + int _resolverTimeout = 4000; + int _clientTimeout = 4000; + int _forwarderConcurrency = 2; + int _resolverMaxStackCount = 10; int _cachePrefetchEligibility = 2; int _cachePrefetchTrigger = 9; int _cachePrefetchSampleIntervalInMinutes = 5; @@ -119,7 +122,8 @@ namespace DnsServerCore.Dns const int CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL = 60 * 60 * 1000; const int CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL = 60 * 60 * 1000; - readonly DomainTree _resolverQueryHandles = new DomainTree(); + readonly IndependentTaskScheduler _resolverTaskScheduler = new IndependentTaskScheduler(ThreadPriority.AboveNormal); + readonly DomainTree _resolverTasks = new DomainTree(); volatile ServiceState _state = ServiceState.Stopped; @@ -257,7 +261,9 @@ namespace DnsServerCore.Dns { try { - ThreadPool.QueueUserWorkItem(ProcessUdpRequestAsync, new object[] { udpListener, remoteEP, new DnsDatagram(new MemoryStream(recvBuffer, 0, bytesRecv, false), false) }); + DnsDatagram request = DnsDatagram.ReadFromUdp(new MemoryStream(recvBuffer, 0, bytesRecv, false)); + + _ = ProcessUdpRequestAsync(udpListener, remoteEP, request); } catch (Exception ex) { @@ -281,21 +287,15 @@ namespace DnsServerCore.Dns } } - private void ProcessUdpRequestAsync(object parameter) + private async Task ProcessUdpRequestAsync(Socket udpListener, EndPoint remoteEP, DnsDatagram request) { - object[] parameters = parameter as object[]; - - Socket udpListener = parameters[0] as Socket; - EndPoint remoteEP = parameters[1] as EndPoint; - DnsDatagram request = parameters[2] as DnsDatagram; - try { DnsDatagram response; if (request.ParsingException == null) { - response = ProcessQuery(request, remoteEP, IsRecursionAllowed(remoteEP), DnsTransportProtocol.Udp); + response = await ProcessQueryAsync(request, remoteEP, IsRecursionAllowed(remoteEP), DnsTransportProtocol.Udp); } else { @@ -316,18 +316,18 @@ namespace DnsServerCore.Dns try { - response.WriteTo(sendBufferStream, false); + response.WriteToUdp(sendBufferStream); } catch (NotSupportedException) { response = new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, true, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question); sendBufferStream.Position = 0; - response.WriteTo(sendBufferStream, false); + response.WriteToUdp(sendBufferStream); } - //send dns datagram - udpListener.SendTo(sendBuffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); + //send dns datagram async + await udpListener.SendToAsync(sendBuffer, 0, (int)sendBufferStream.Position, remoteEP); LogManager queryLog = _queryLog; if (queryLog != null) @@ -378,64 +378,7 @@ namespace DnsServerCore.Dns { Socket socket = tcpListener.Accept(); - ThreadPool.QueueUserWorkItem(delegate (object state) - { - EndPoint remoteEP = null; - - try - { - remoteEP = socket.RemoteEndPoint; - - switch (protocol) - { - case DnsTransportProtocol.Tcp: - ReadStreamRequest(new NetworkStream(socket), remoteEP, protocol); - break; - - case DnsTransportProtocol.Tls: - SslStream tlsStream = new SslStream(new NetworkStream(socket)); - tlsStream.AuthenticateAsServer(_certificate); - - ReadStreamRequest(tlsStream, remoteEP, protocol); - break; - - case DnsTransportProtocol.Https: - Stream stream = new NetworkStream(socket); - - if (usingHttps) - { - SslStream httpsStream = new SslStream(stream); - httpsStream.AuthenticateAsServer(_certificate); - - stream = httpsStream; - } - else if (!NetUtilities.IsPrivateIP((remoteEP as IPEndPoint).Address)) - { - //intentionally blocking public IP addresses from using DNS-over-HTTP (without TLS) - //this feature is intended to be used with an SSL terminated reverse proxy like nginx on private network - return; - } - - ProcessDoHRequest(stream, remoteEP, !usingHttps); - break; - } - } - catch (IOException) - { - //ignore IO exceptions - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(remoteEP as IPEndPoint, protocol, ex); - } - finally - { - if (socket != null) - socket.Dispose(); - } - }); + _ = ProcessConnectionAsync(socket, protocol, usingHttps); } } catch (Exception ex) @@ -451,24 +394,96 @@ namespace DnsServerCore.Dns } } - private void ReadStreamRequest(Stream stream, EndPoint remoteEP, DnsTransportProtocol protocol) + private async Task ProcessConnectionAsync(Socket socket, DnsTransportProtocol protocol, bool usingHttps) { - DnsDatagram request = null; + EndPoint remoteEP = null; + try + { + remoteEP = socket.RemoteEndPoint; + + switch (protocol) + { + case DnsTransportProtocol.Tcp: + await ReadStreamRequestAsync(new NetworkStream(socket), _tcpReceiveTimeout, remoteEP, protocol); + break; + + case DnsTransportProtocol.Tls: + SslStream tlsStream = new SslStream(new NetworkStream(socket)); + await tlsStream.AuthenticateAsServerAsync(_certificate); + + await ReadStreamRequestAsync(tlsStream, _tcpReceiveTimeout, remoteEP, protocol); + break; + + case DnsTransportProtocol.Https: + Stream stream = new NetworkStream(socket); + + if (usingHttps) + { + SslStream httpsStream = new SslStream(stream); + await httpsStream.AuthenticateAsServerAsync(_certificate); + + stream = httpsStream; + } + else if (!NetUtilities.IsPrivateIP((remoteEP as IPEndPoint).Address)) + { + //intentionally blocking public IP addresses from using DNS-over-HTTP (without TLS) + //this feature is intended to be used with an SSL terminated reverse proxy like nginx on private network + return; + } + + await ProcessDoHRequestAsync(stream, _tcpReceiveTimeout, remoteEP, !usingHttps); + break; + } + } + catch (IOException) + { + //ignore IO exceptions + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP as IPEndPoint, protocol, ex); + } + finally + { + if (socket != null) + socket.Dispose(); + } + } + + private async Task ReadStreamRequestAsync(Stream stream, int receiveTimeout, EndPoint remoteEP, DnsTransportProtocol protocol) + { try { MemoryStream readBuffer = new MemoryStream(64); MemoryStream writeBuffer = new MemoryStream(64); + SemaphoreSlim writeSemaphore = new SemaphoreSlim(1, 1); while (true) { - request = null; + DnsDatagram request; - //read dns datagram - request = new DnsDatagram(stream, true, readBuffer); + //read dns datagram with timeout + using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource()) + { + Task task = DnsDatagram.ReadFromTcpAsync(stream, readBuffer, cancellationTokenSource.Token); + + if (await Task.WhenAny(task, Task.Delay(receiveTimeout, cancellationTokenSource.Token)) != task) + { + //read timed out + stream.Dispose(); + return; + } + + cancellationTokenSource.Cancel(); //cancel delay task + + request = await task; + } //process request async - ThreadPool.QueueUserWorkItem(ProcessStreamRequestAsync, new object[] { stream, writeBuffer, remoteEP, request, protocol }); + _ = ProcessStreamRequestAsync(stream, writeBuffer, writeSemaphore, remoteEP, request, protocol); } } catch (ObjectDisposedException) @@ -481,54 +496,49 @@ namespace DnsServerCore.Dns } catch (Exception ex) { - LogManager queryLog = _queryLog; - if ((queryLog != null) && (request != null)) - queryLog.Write(remoteEP as IPEndPoint, protocol, request, null); - LogManager log = _log; if (log != null) log.Write(remoteEP as IPEndPoint, protocol, ex); } } - private void ProcessStreamRequestAsync(object parameter) + private async Task ProcessStreamRequestAsync(Stream stream, MemoryStream writeBuffer, SemaphoreSlim writeSemaphore, EndPoint remoteEP, DnsDatagram request, DnsTransportProtocol protocol) { - object[] parameters = parameter as object[]; - - Stream stream = parameters[0] as Stream; - MemoryStream writeBuffer = parameters[1] as MemoryStream; - EndPoint remoteEP = parameters[2] as EndPoint; - DnsDatagram request = parameters[3] as DnsDatagram; - DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[4]; - try { DnsDatagram response; if (request.ParsingException == null) { - response = ProcessQuery(request, remoteEP, IsRecursionAllowed(remoteEP), protocol); + response = await ProcessQueryAsync(request, remoteEP, IsRecursionAllowed(remoteEP), protocol); } else { //format error + response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); + + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write(remoteEP as IPEndPoint, protocol, request, response); + LogManager log = _log; if (log != null) log.Write(remoteEP as IPEndPoint, protocol, request.ParsingException); - - //format error response - response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); } //send response if (response != null) { - lock (stream) + await writeSemaphore.WaitAsync(); + try { //send dns datagram - response.WriteTo(stream, true, writeBuffer); - - stream.Flush(); + await response.WriteToTcpAsync(stream, writeBuffer); + await stream.FlushAsync(); + } + finally + { + writeSemaphore.Release(); } LogManager queryLog = _queryLog; @@ -556,7 +566,7 @@ namespace DnsServerCore.Dns } } - private void ProcessDoHRequest(Stream stream, EndPoint remoteEP, bool usingReverseProxy) + private async Task ProcessDoHRequestAsync(Stream stream, int receiveTimeout, EndPoint remoteEP, bool usingReverseProxy) { DnsDatagram dnsRequest = null; DnsTransportProtocol dnsProtocol = DnsTransportProtocol.Https; @@ -565,99 +575,13 @@ namespace DnsServerCore.Dns { while (true) { - string requestMethod; - string requestPath; - NameValueCollection requestQueryString = new NameValueCollection(); - string requestProtocol; - WebHeaderCollection requestHeaders = new WebHeaderCollection(); - - #region parse http request - - using (MemoryStream mS = new MemoryStream()) - { - //read http request header into memory stream - int byteRead; - int crlfCount = 0; - - while (true) - { - byteRead = stream.ReadByte(); - switch (byteRead) - { - case '\r': - case '\n': - crlfCount++; - break; - - case -1: - throw new EndOfStreamException(); - - default: - crlfCount = 0; - break; - } - - mS.WriteByte((byte)byteRead); - - if (crlfCount == 4) - break; //http request completed - } - - mS.Position = 0; - StreamReader sR = new StreamReader(mS); - - string[] requestParts = sR.ReadLine().Split(new char[] { ' ' }, 3); - - if (requestParts.Length != 3) - throw new InvalidDataException("Invalid HTTP request."); - - requestMethod = requestParts[0]; - string pathAndQueryString = requestParts[1]; - requestProtocol = requestParts[2]; - - string[] requestPathAndQueryParts = pathAndQueryString.Split(new char[] { '?' }, 2); - - requestPath = requestPathAndQueryParts[0]; - - string queryString = null; - if (requestPathAndQueryParts.Length > 1) - queryString = requestPathAndQueryParts[1]; - - if (!string.IsNullOrEmpty(queryString)) - { - foreach (string item in queryString.Split(new char[] { '&' }, StringSplitOptions.RemoveEmptyEntries)) - { - string[] itemParts = item.Split(new char[] { '=' }, 2); - - string name = itemParts[0]; - string value = null; - - if (itemParts.Length > 1) - value = itemParts[1]; - - requestQueryString.Add(name, value); - } - } - - while (true) - { - string line = sR.ReadLine(); - if (string.IsNullOrEmpty(line)) - break; - - string[] parts = line.Split(new char[] { ':' }, 2); - if (parts.Length != 2) - throw new InvalidDataException("Invalid HTTP request."); - - requestHeaders.Add(parts[0], parts[1]); - } - } - - #endregion + HttpRequest httpRequest = await HttpRequest.ReadRequestAsync(stream).WithTimeout(receiveTimeout); + if (httpRequest == null) + return; //connection closed gracefully by client if (usingReverseProxy) { - string xRealIp = requestHeaders["X-Real-IP"]; + string xRealIp = httpRequest.Headers["X-Real-IP"]; if (IPAddress.TryParse(xRealIp, out IPAddress address)) { //get the real IP address of the requesting client from X-Real-IP header set in nginx proxy_pass block @@ -665,16 +589,16 @@ namespace DnsServerCore.Dns } } - string requestConnection = requestHeaders[HttpRequestHeader.Connection]; + string requestConnection = httpRequest.Headers[HttpRequestHeader.Connection]; if (string.IsNullOrEmpty(requestConnection)) requestConnection = "close"; - switch (requestPath) + switch (httpRequest.RequestPath) { case "/dns-query": DnsTransportProtocol protocol = DnsTransportProtocol.Udp; - string strRequestAcceptTypes = requestHeaders[HttpRequestHeader.Accept]; + string strRequestAcceptTypes = httpRequest.Headers[HttpRequestHeader.Accept]; if (!string.IsNullOrEmpty(strRequestAcceptTypes)) { foreach (string acceptType in strRequestAcceptTypes.Split(',')) @@ -698,10 +622,10 @@ namespace DnsServerCore.Dns case DnsTransportProtocol.Https: #region https wire format { - switch (requestMethod) + switch (httpRequest.HttpMethod) { case "GET": - string strRequest = requestQueryString["dns"]; + string strRequest = httpRequest.QueryString["dns"]; if (string.IsNullOrEmpty(strRequest)) throw new ArgumentNullException("dns"); @@ -714,29 +638,23 @@ namespace DnsServerCore.Dns if (x > 0) strRequest = strRequest.PadRight(strRequest.Length - x + 4, '='); - dnsRequest = new DnsDatagram(new MemoryStream(Convert.FromBase64String(strRequest)), false); + dnsRequest = DnsDatagram.ReadFromUdp(new MemoryStream(Convert.FromBase64String(strRequest))); break; case "POST": - string strContentType = requestHeaders[HttpRequestHeader.ContentType]; + string strContentType = httpRequest.Headers[HttpRequestHeader.ContentType]; if (string.IsNullOrEmpty(strContentType)) throw new DnsServerException("Missing Content-Type header."); if (strContentType != "application/dns-message") throw new NotSupportedException("DNS request type not supported: " + strContentType); - string strContentLength = requestHeaders[HttpRequestHeader.ContentLength]; - if (string.IsNullOrEmpty(strContentLength)) - throw new DnsServerException("Missing Content-Length header."); - - int contentLength = int.Parse(strContentLength); - using (MemoryStream mS = new MemoryStream()) { - stream.CopyTo(mS, 512, contentLength); + await httpRequest.InputStream.CopyToAsync(mS, 512); mS.Position = 0; - dnsRequest = new DnsDatagram(mS, false); + dnsRequest = DnsDatagram.ReadFromUdp(mS); } break; @@ -749,7 +667,7 @@ namespace DnsServerCore.Dns if (dnsRequest.ParsingException == null) { - dnsResponse = ProcessQuery(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); + dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); } else { @@ -766,10 +684,10 @@ namespace DnsServerCore.Dns { using (MemoryStream mS = new MemoryStream()) { - dnsResponse.WriteTo(mS, false); + dnsResponse.WriteToUdp(mS); byte[] buffer = mS.ToArray(); - SendContent(stream, "application/dns-message", buffer); + await SendContentAsync(stream, "application/dns-message", buffer); } LogManager queryLog = _queryLog; @@ -787,27 +705,27 @@ namespace DnsServerCore.Dns case DnsTransportProtocol.HttpsJson: #region https json format { - string strName = requestQueryString["name"]; + string strName = httpRequest.QueryString["name"]; if (string.IsNullOrEmpty(strName)) throw new ArgumentNullException("name"); - string strType = requestQueryString["type"]; + string strType = httpRequest.QueryString["type"]; if (string.IsNullOrEmpty(strType)) strType = "1"; dnsRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord(strName, (DnsResourceRecordType)int.Parse(strType), DnsClass.IN) }); - DnsDatagram dnsResponse = ProcessQuery(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); + DnsDatagram dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); if (dnsResponse != null) { using (MemoryStream mS = new MemoryStream()) { JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); - dnsResponse.WriteTo(jsonWriter); + dnsResponse.WriteToJson(jsonWriter); jsonWriter.Flush(); byte[] buffer = mS.ToArray(); - SendContent(stream, "application/dns-json; charset=utf-8", buffer); + await SendContentAsync(stream, "application/dns-json; charset=utf-8", buffer); } LogManager queryLog = _queryLog; @@ -823,7 +741,7 @@ namespace DnsServerCore.Dns break; default: - SendError(stream, 406, "Only application/dns-message and application/dns-json types are accepted."); + await SendErrorAsync(stream, 406, "Only application/dns-message and application/dns-json types are accepted."); break; } @@ -833,11 +751,15 @@ namespace DnsServerCore.Dns break; default: - SendError(stream, 404); + await SendErrorAsync(stream, 404); break; } } } + catch (TimeoutException) + { + //ignore timeout exception + } catch (IOException) { //ignore IO exceptions @@ -852,25 +774,25 @@ namespace DnsServerCore.Dns if (log != null) log.Write(remoteEP as IPEndPoint, dnsProtocol, ex); - SendError(stream, ex); + await SendErrorAsync(stream, ex); } } - private static void SendContent(Stream outputStream, string contentType, byte[] bufferContent) + private static async Task SendContentAsync(Stream outputStream, string contentType, byte[] bufferContent) { byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); - outputStream.Write(bufferHeader, 0, bufferHeader.Length); - outputStream.Write(bufferContent, 0, bufferContent.Length); - outputStream.Flush(); + await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); + await outputStream.WriteAsync(bufferContent, 0, bufferContent.Length); + await outputStream.FlushAsync(); } - private static void SendError(Stream outputStream, Exception ex) + private static Task SendErrorAsync(Stream outputStream, Exception ex) { - SendError(outputStream, 500, ex.ToString()); + return SendErrorAsync(outputStream, 500, ex.ToString()); } - private static void SendError(Stream outputStream, int statusCode, string message = null) + private static async Task SendErrorAsync(Stream outputStream, int statusCode, string message = null) { try { @@ -878,9 +800,9 @@ namespace DnsServerCore.Dns byte[] bufferContent = Encoding.UTF8.GetBytes("" + statusString + "

" + statusString + "

" + (message == null ? "" : "

" + message + "

") + ""); byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 " + statusString + "\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); - outputStream.Write(bufferHeader, 0, bufferHeader.Length); - outputStream.Write(bufferContent, 0, bufferContent.Length); - outputStream.Flush(); + await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); + await outputStream.WriteAsync(bufferContent, 0, bufferContent.Length); + await outputStream.FlushAsync(); } catch { } @@ -922,7 +844,7 @@ namespace DnsServerCore.Dns return true; } - private DnsDatagram ProcessQuery(DnsDatagram request, EndPoint remoteEP, bool isRecursionAllowed, DnsTransportProtocol protocol) + private async Task ProcessQueryAsync(DnsDatagram request, EndPoint remoteEP, bool isRecursionAllowed, DnsTransportProtocol protocol) { if (request.IsResponse) return null; @@ -941,10 +863,10 @@ namespace DnsServerCore.Dns if (protocol == DnsTransportProtocol.Udp) return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.FormatError, request.Question); - return ProcessZoneTransferQuery(request, remoteEP); + return await ProcessZoneTransferQueryAsync(request, remoteEP); case DnsResourceRecordType.IXFR: - return ProcessZoneTransferQuery(request, remoteEP); + return await ProcessZoneTransferQueryAsync(request, remoteEP); case DnsResourceRecordType.MAILB: case DnsResourceRecordType.MAILA: @@ -964,13 +886,13 @@ namespace DnsServerCore.Dns } //query authoritative zone - response = ProcessAuthoritativeQuery(request, inAllowedZone, isRecursionAllowed); + response = await ProcessAuthoritativeQueryAsync(request, inAllowedZone, isRecursionAllowed); if ((response.RCODE != DnsResponseCode.Refused) || !request.RecursionDesired || !isRecursionAllowed) return response; //do recursive query - return ProcessRecursiveQuery(request, null, null, !inAllowedZone, false); + return await ProcessRecursiveQueryAsync(request, null, null, !inAllowedZone, false); } } catch (Exception ex) @@ -983,14 +905,14 @@ namespace DnsServerCore.Dns } case DnsOpcode.Notify: - return ProcessNotifyQuery(request, remoteEP); + return await ProcessNotifyQueryAsync(request, remoteEP); default: return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NotImplemented, request.Question); } } - private DnsDatagram ProcessNotifyQuery(DnsDatagram request, EndPoint remoteEP) + private async Task ProcessNotifyQueryAsync(DnsDatagram request, EndPoint remoteEP) { AuthZoneInfo authZoneInfo = _authZoneManager.GetAuthZoneInfo(request.Question[0].Name); if ((authZoneInfo == null) || (authZoneInfo.Type != AuthZoneType.Secondary)) @@ -999,7 +921,7 @@ namespace DnsServerCore.Dns IPAddress remoteAddress = (remoteEP as IPEndPoint).Address; bool remoteVerified = false; - IReadOnlyList primaryNameServers = authZoneInfo.GetPrimaryNameServerAddresses(this); + IReadOnlyList primaryNameServers = await authZoneInfo.GetPrimaryNameServerAddressesAsync(this); foreach (NameServerAddress primaryNameServer in primaryNameServers) { @@ -1032,7 +954,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.Notify, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question) { Tag = StatsResponseType.Authoritative }; } - private DnsDatagram ProcessZoneTransferQuery(DnsDatagram request, EndPoint remoteEP) + private async Task ProcessZoneTransferQueryAsync(DnsDatagram request, EndPoint remoteEP) { AuthZoneInfo authZoneInfo = _authZoneManager.GetAuthZoneInfo(request.Question[0].Name); if ((authZoneInfo == null) || (authZoneInfo.Type != AuthZoneType.Primary)) @@ -1043,7 +965,7 @@ namespace DnsServerCore.Dns if (!isAxfrAllowed) { - IReadOnlyList secondaryNameServers = authZoneInfo.GetSecondaryNameServerAddresses(this); + IReadOnlyList secondaryNameServers = await authZoneInfo.GetSecondaryNameServerAddressesAsync(this); foreach (NameServerAddress secondaryNameServer in secondaryNameServers) { @@ -1067,7 +989,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, true, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, axfrRecords) { Tag = StatsResponseType.Authoritative }; } - private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, bool inAllowedZone, bool isRecursionAllowed) + private Task ProcessAuthoritativeQueryAsync(DnsDatagram request, bool inAllowedZone, bool isRecursionAllowed) { DnsDatagram response = _authZoneManager.Query(request); response.Tag = StatsResponseType.Authoritative; @@ -1084,10 +1006,10 @@ namespace DnsServerCore.Dns switch (lastRR.Type) { case DnsResourceRecordType.CNAME: - return ProcessCNAME(request, response, isRecursionAllowed, false); + return ProcessCNAMEAsync(request, response, isRecursionAllowed, false); case DnsResourceRecordType.ANAME: - return ProcessANAME(request, response, isRecursionAllowed); + return ProcessANAMEAsync(request, response, isRecursionAllowed); } } } @@ -1101,7 +1023,7 @@ namespace DnsServerCore.Dns //do recursive resolution using response authority name servers List nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false); - return ProcessRecursiveQuery(request, nameServers, null, !inAllowedZone, false); + return ProcessRecursiveQueryAsync(request, nameServers, null, !inAllowedZone, false); } break; @@ -1110,7 +1032,7 @@ namespace DnsServerCore.Dns if ((response.Authority.Count == 1) && (response.Authority[0].Type == DnsResourceRecordType.FWD) && (response.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" - return ProcessRecursiveQuery(request, null, null, !inAllowedZone, false); + return ProcessRecursiveQueryAsync(request, null, null, !inAllowedZone, false); } else { @@ -1128,16 +1050,16 @@ namespace DnsServerCore.Dns } } - return ProcessRecursiveQuery(request, null, forwarders, !inAllowedZone, false); + return ProcessRecursiveQueryAsync(request, null, forwarders, !inAllowedZone, false); } } } } - return response; + return Task.FromResult(response); } - private DnsDatagram ProcessCNAME(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed, bool cacheRefreshOperation) + private async Task ProcessCNAMEAsync(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed, bool cacheRefreshOperation) { List responseAnswer = new List(); responseAnswer.AddRange(response.Answer); @@ -1160,7 +1082,7 @@ namespace DnsServerCore.Dns if (newRequest.RecursionDesired && isRecursionAllowed) { //do recursion - lastResponse = RecursiveResolve(newRequest, null, null, false, cacheRefreshOperation); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, cacheRefreshOperation); isAuthoritativeAnswer = false; } else @@ -1171,7 +1093,7 @@ namespace DnsServerCore.Dns } else if ((lastResponse.Answer.Count > 0) && (lastResponse.Answer[0].Type == DnsResourceRecordType.ANAME)) { - lastResponse = ProcessANAME(request, lastResponse, isRecursionAllowed); + lastResponse = await ProcessANAMEAsync(request, lastResponse, isRecursionAllowed); } else if ((lastResponse.Answer.Count == 0) && (lastResponse.Authority.Count > 0)) { @@ -1184,7 +1106,7 @@ namespace DnsServerCore.Dns //do recursive resolution using last response authority name servers List nameServers = NameServerAddress.GetNameServersFromResponse(lastResponse, _preferIPv6, false); - lastResponse = RecursiveResolve(newRequest, nameServers, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, nameServers, null, false, false); isAuthoritativeAnswer = false; } @@ -1194,7 +1116,7 @@ namespace DnsServerCore.Dns if ((lastResponse.Authority.Count == 1) && (lastResponse.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" - lastResponse = RecursiveResolve(newRequest, null, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, false); isAuthoritativeAnswer = false; } else @@ -1213,7 +1135,7 @@ namespace DnsServerCore.Dns } } - lastResponse = RecursiveResolve(newRequest, null, forwarders, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, forwarders, false, false); isAuthoritativeAnswer = false; } @@ -1263,7 +1185,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, isAuthoritativeAnswer, false, request.RecursionDesired, isRecursionAllowed, false, false, rcode, request.Question, responseAnswer, authority, additional) { Tag = response.Tag }; } - private DnsDatagram ProcessANAME(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed) + private async Task ProcessANAMEAsync(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed) { List responseAnswer = new List(); @@ -1285,7 +1207,7 @@ namespace DnsServerCore.Dns if (lastResponse.RCODE == DnsResponseCode.Refused) { //not found in auth zone; do recursion - lastResponse = RecursiveResolve(newRequest, null, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, false); } else if ((lastResponse.Answer.Count == 0) && (lastResponse.Authority.Count > 0)) { @@ -1296,14 +1218,14 @@ namespace DnsServerCore.Dns //do recursive resolution using last response authority name servers List nameServers = NameServerAddress.GetNameServersFromResponse(lastResponse, _preferIPv6, false); - lastResponse = RecursiveResolve(newRequest, nameServers, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, nameServers, null, false, false); break; case DnsResourceRecordType.FWD: if ((lastResponse.Authority.Count == 1) && (lastResponse.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" - lastResponse = RecursiveResolve(newRequest, null, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, false); } else { @@ -1321,7 +1243,7 @@ namespace DnsServerCore.Dns } } - lastResponse = RecursiveResolve(newRequest, null, forwarders, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, forwarders, false, false); } break; @@ -1410,9 +1332,9 @@ namespace DnsServerCore.Dns return response; } - private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool checkForCnameCloaking, bool cacheRefreshOperation) + private async Task ProcessRecursiveQueryAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool checkForCnameCloaking, bool cacheRefreshOperation) { - DnsDatagram response = RecursiveResolve(request, viaNameServers, viaForwarders, false, cacheRefreshOperation); + DnsDatagram response = await RecursiveResolveAsync(request, viaNameServers, viaForwarders, false, cacheRefreshOperation); if (response.Answer.Count > 0) { @@ -1420,7 +1342,7 @@ namespace DnsServerCore.Dns DnsResourceRecord lastRR = response.Answer[response.Answer.Count - 1]; if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) - response = ProcessCNAME(request, response, true, cacheRefreshOperation); + response = await ProcessCNAMEAsync(request, response, true, cacheRefreshOperation); if (checkForCnameCloaking) { @@ -1432,6 +1354,13 @@ namespace DnsServerCore.Dns break; //no further CNAME records exists DnsDatagram newRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord((record.RDATA as DnsCNAMERecord).Domain, request.Question[0].Type, request.Question[0].Class) }); + + //check allowed zone + bool inAllowedZone = _allowedZoneManager.Query(newRequest).RCODE != DnsResponseCode.Refused; + if (inAllowedZone) + break; //CNAME is in allowed zone + + //check blocked zone and block list zone DnsDatagram lastResponse = ProcessBlockedQuery(newRequest); if (lastResponse != null) { @@ -1467,7 +1396,7 @@ namespace DnsServerCore.Dns } } - private DnsDatagram RecursiveResolve(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool cachePrefetchOperation, bool cacheRefreshOperation) + private async Task RecursiveResolveAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool cachePrefetchOperation, bool cacheRefreshOperation) { if (!cachePrefetchOperation && !cacheRefreshOperation) { @@ -1482,21 +1411,8 @@ namespace DnsServerCore.Dns { if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (answer.TtlValue < _cachePrefetchTrigger)) { - //trigger prefetch in worker thread - ThreadPool.QueueUserWorkItem(delegate (object state) - { - try - { - RecursiveResolve(request, viaNameServers, viaForwarders, true, false); - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(ex); - } - }); - + //trigger prefetch async + _ = PrefetchCacheAsync(request, viaNameServers, viaForwarders); break; } } @@ -1507,13 +1423,16 @@ namespace DnsServerCore.Dns } //recursion with locking - ResolverQueryHandle newQueryHandle = new ResolverQueryHandle(); - ResolverQueryHandle queryHandle = _resolverQueryHandles.GetOrAdd(GetResolverQueryKey(request.Question[0]), newQueryHandle); + ResolverTask newResolverTask = new ResolverTask(); + ResolverTask resolverTask = _resolverTasks.GetOrAdd(GetResolverQueryKey(request.Question[0]), newResolverTask); - if (queryHandle.Equals(newQueryHandle)) + if (resolverTask.Equals(newResolverTask)) { - //got query handle so question not being resolved; do recursive resolution in worker thread - ThreadPool.QueueUserWorkItem(RecursiveResolveAsync, new object[] { request, viaNameServers, viaForwarders, cachePrefetchOperation, cacheRefreshOperation, queryHandle }); + //got new resolver task added so question is not being resolved; do recursive resolution in another task on resolver thread pool + _ = Task.Factory.StartNew(delegate () + { + return RecursiveResolveAsync(request, viaNameServers, viaForwarders, cachePrefetchOperation, cacheRefreshOperation, newResolverTask.TaskCompletionSource); + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _resolverTaskScheduler); } //request is being recursively resolved by another thread @@ -1521,10 +1440,25 @@ namespace DnsServerCore.Dns if (cachePrefetchOperation) return null; //return null as prefetch worker thread does not need valid response and thus does not need to wait + if (resolverTask.IsStuck(_resolverTimeout)) + { + //resolver task is taking a long time to complete + //query cache zone to return stale answer (if available) + DnsDatagram staleResponse = QueryCache(request, true); + if (staleResponse != null) + return staleResponse; + + //no stale response available; wait for resolver task + } + + DateTime resolverWaitStartTime = DateTime.UtcNow; + //wait till short timeout for response - if (queryHandle.WaitForResponse(1800, out DnsDatagram response)) //1.8 sec wait as per draft-ietf-dnsop-serve-stale-04 + if (await Task.WhenAny(resolverTask.TaskCompletionSource.Task, Task.Delay(1800)) == resolverTask.TaskCompletionSource.Task) //1.8 sec wait as per draft-ietf-dnsop-serve-stale-04 { //resolver signaled + DnsDatagram response = await resolverTask.TaskCompletionSource.Task; + if (response != null) return response; @@ -1534,19 +1468,27 @@ namespace DnsServerCore.Dns { //wait timed out //query cache zone to return stale answer (if available) as per draft-ietf-dnsop-serve-stale-04 - DnsDatagram cacheResponse = QueryCache(request, true); - if ((cacheResponse != null) && (cacheResponse.RCODE == DnsResponseCode.NoError)) - return cacheResponse; + DnsDatagram staleResponse = QueryCache(request, true); + if (staleResponse != null) + return staleResponse; - //wait till full timeout before responding as ServerFailure - int timeout = _timeout - 1800; - if (timeout > 0) + if ((DateTime.UtcNow - resolverWaitStartTime).TotalMilliseconds < _clientTimeout) //check if there is any point in waiting further due to execution delay { - queryHandle.WaitForResponse(timeout, out response); - if (response != null) - return response; + //wait till full timeout before responding as ServerFailure + int timeout = _clientTimeout - 1800; + if (timeout > 0) + { + if (await Task.WhenAny(resolverTask.TaskCompletionSource.Task, Task.Delay(timeout)) == resolverTask.TaskCompletionSource.Task) + { + //resolver signaled + DnsDatagram response = await resolverTask.TaskCompletionSource.Task; - //no response available from resolver or resolver had exception and no stale record was found + if (response != null) + return response; + } + + //no response available from resolver or resolver had exception and no stale record was found + } } } @@ -1554,17 +1496,8 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Question); } - private void RecursiveResolveAsync(object parameter) + private async Task RecursiveResolveAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool cachePrefetchOperation, bool cacheRefreshOperation, TaskCompletionSource taskCompletionSource) { - object[] parameters = parameter as object[]; - - DnsDatagram request = parameters[0] as DnsDatagram; - IReadOnlyList viaNameServers = parameters[1] as IReadOnlyList; - IReadOnlyList viaForwarders = parameters[2] as IReadOnlyList; - bool cachePrefetchOperation = (bool)parameters[3]; - bool cacheRefreshOperation = (bool)parameters[4]; - ResolverQueryHandle queryHandle = parameters[5] as ResolverQueryHandle; - IReadOnlyList forwarders = _forwarders; if (viaForwarders != null) forwarders = viaForwarders; //use provided forwarders @@ -1574,14 +1507,13 @@ namespace DnsServerCore.Dns if ((viaNameServers == null) && (forwarders != null)) { //use forwarders - if (_proxy == null) { //recursive resolve name server when proxy is null else let proxy resolve it foreach (NameServerAddress nameServerAddress in forwarders) { if (nameServerAddress.IsIPEndPointStale) //refresh forwarder IPEndPoint if stale - nameServerAddress.RecursiveResolveIPAddress(_dnsCache, null, _preferIPv6, _retries, _timeout); + await nameServerAddress.RecursiveResolveIPAddressAsync(_dnsCache, null, _preferIPv6, _resolverRetries, _resolverTimeout); } } @@ -1590,14 +1522,15 @@ namespace DnsServerCore.Dns dnsClient.Proxy = _proxy; dnsClient.PreferIPv6 = _preferIPv6; - dnsClient.Retries = _retries; - dnsClient.Timeout = _timeout; + dnsClient.Retries = _forwarderRetries; + dnsClient.Timeout = _forwarderTimeout; + dnsClient.Concurrency = _forwarderConcurrency; - DnsDatagram response = dnsClient.Resolve(request.Question[0]); + DnsDatagram response = await dnsClient.ResolveAsync(request.Question[0]); _cacheZoneManager.CacheResponse(response); - queryHandle.Set(response); + taskCompletionSource.SetResult(response); } else { @@ -1609,8 +1542,8 @@ namespace DnsServerCore.Dns else dnsCache = _dnsCache; - DnsDatagram response = DnsClient.RecursiveResolve(request.Question[0], viaNameServers, dnsCache, _proxy, _preferIPv6, _retries, _timeout, _maxStackCount); - queryHandle.Set(response); + DnsDatagram response = await DnsClient.RecursiveResolveAsync(request.Question[0], viaNameServers, dnsCache, _proxy, _preferIPv6, _resolverRetries, _resolverTimeout, _resolverMaxStackCount); + taskCompletionSource.SetResult(response); } } catch (Exception ex) @@ -1645,28 +1578,28 @@ namespace DnsServerCore.Dns } //fetch stale record - DnsDatagram cacheResponse = QueryCache(request, true); - if (cacheResponse == null) + DnsDatagram staleResponse = QueryCache(request, true); + if (staleResponse == null) { - //no stale record was found; signal null response to release waiting threads - queryHandle.Set(null); + //no stale record was found; signal null response to release waiting tasks + taskCompletionSource.SetResult(null); } else { //reset expiry for stale records - foreach (DnsResourceRecord record in cacheResponse.Answer) + foreach (DnsResourceRecord record in staleResponse.Answer) { if (record.IsStale) record.ResetExpiry(30); //reset expiry by 30 seconds so that resolver tries again only after 30 seconds as per draft-ietf-dnsop-serve-stale-04 } - //signal stale record - queryHandle.Set(cacheResponse); + //signal stale response + taskCompletionSource.SetResult(staleResponse); } } finally { - _resolverQueryHandles.TryRemove(GetResolverQueryKey(request.Question[0]), out _); + _resolverTasks.TryRemove(GetResolverQueryKey(request.Question[0]), out _); } } @@ -1695,6 +1628,52 @@ namespace DnsServerCore.Dns return null; } + private async Task PrefetchCacheAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders) + { + try + { + await RecursiveResolveAsync(request, viaNameServers, viaForwarders, true, false); + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(ex); + } + } + + private async Task RefreshCacheAsync(IList cacheRefreshSampleList, DnsQuestionRecord sampleQuestion, int sampleQuestionIndex) + { + try + { + //refresh cache + DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sampleQuestion }); + DnsDatagram response = await ProcessRecursiveQueryAsync(request, null, null, false, true); + + bool removeFromSampleList = true; + DateTime utcNow = DateTime.UtcNow; + + foreach (DnsResourceRecord answer in response.Answer) + { + if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TtlValue) < _cachePrefetchSamplingTimerTriggersOn)) + { + //answer expires before next sampling so dont remove from list to allow refreshing it + removeFromSampleList = false; + break; + } + } + + if (removeFromSampleList) + cacheRefreshSampleList[sampleQuestionIndex] = null; + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(ex); + } + } + private DnsQuestionRecord GetCacheRefreshNeededQuery(DnsQuestionRecord question, int trigger) { int queryCount = 0; @@ -1732,7 +1711,7 @@ namespace DnsServerCore.Dns } } - private bool CacheRefreshNeeded(DnsQuestionRecord question, int trigger) + private bool IsCacheRefreshNeeded(DnsQuestionRecord question, int trigger) { DnsDatagram cacheResponse = QueryCache(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), false); if (cacheResponse == null) @@ -1751,7 +1730,7 @@ namespace DnsServerCore.Dns return false; //no need to refresh for this query } - private void CachePrefetchSamplingAsync(object state) + private void CachePrefetchSamplingTimerCallback(object state) { try { @@ -1809,7 +1788,7 @@ namespace DnsServerCore.Dns } } - private void CachePrefetchRefreshAsync(object state) + private void CachePrefetchRefreshTimerCallback(object state) { try { @@ -1822,42 +1801,10 @@ namespace DnsServerCore.Dns if (sampleQuestion == null) continue; - if (!CacheRefreshNeeded(sampleQuestion, _cachePrefetchTrigger + 2)) + if (!IsCacheRefreshNeeded(sampleQuestion, _cachePrefetchTrigger + 2)) continue; - int sampleQuestionIndex = i; - - ThreadPool.QueueUserWorkItem(delegate (object state2) - { - try - { - //refresh cache - DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sampleQuestion }); - DnsDatagram response = ProcessRecursiveQuery(request, null, null, false, true); - - bool removeFromSampleList = true; - DateTime utcNow = DateTime.UtcNow; - - foreach (DnsResourceRecord answer in response.Answer) - { - if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TtlValue) < _cachePrefetchSamplingTimerTriggersOn)) - { - //answer expires before next sampling so dont remove from list to allow refreshing it - removeFromSampleList = false; - break; - } - } - - if (removeFromSampleList) - cacheRefreshSampleList[sampleQuestionIndex] = null; - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(ex); - } - }); + _ = RefreshCacheAsync(cacheRefreshSampleList, sampleQuestion, i); } } } @@ -1872,12 +1819,12 @@ namespace DnsServerCore.Dns lock (_cachePrefetchRefreshTimerLock) { if (_cachePrefetchRefreshTimer != null) - _cachePrefetchRefreshTimer.Change((_cachePrefetchTrigger + 1) * 1000, System.Threading.Timeout.Infinite); + _cachePrefetchRefreshTimer.Change((_cachePrefetchTrigger + 1) * 1000, Timeout.Infinite); } } } - private void CacheMaintenanceAsync(object state) + private void CacheMaintenanceTimerCallback(object state) { try { @@ -2118,69 +2065,66 @@ namespace DnsServerCore.Dns } //start reading query packets + int listenerThreadCount = Math.Max(1, Environment.ProcessorCount); + foreach (Socket udpListener in _udpListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(ReadUdpRequestAsync); - listenerThread.IsBackground = true; - listenerThread.Start(udpListener); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(ReadUdpRequestAsync); + thread.Name = "DNS UDP Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(udpListener); } } foreach (Socket tcpListener in _tcpListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS TCP Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp }); } } foreach (Socket httpListener in _httpListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { httpListener, DnsTransportProtocol.Https, false }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS HTTP Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { httpListener, DnsTransportProtocol.Https, false }); } } foreach (Socket tlsListener in _tlsListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS TLS Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls }); } } foreach (Socket httpsListener in _httpsListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { httpsListener, DnsTransportProtocol.Https }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS HTTPS Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { httpsListener, DnsTransportProtocol.Https }); } } - _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingAsync, null, System.Threading.Timeout.Infinite, System.Threading.Timeout.Infinite); - _cachePrefetchRefreshTimer = new Timer(CachePrefetchRefreshAsync, null, System.Threading.Timeout.Infinite, System.Threading.Timeout.Infinite); - _cacheMaintenanceTimer = new Timer(CacheMaintenanceAsync, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL); + _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingTimerCallback, null, Timeout.Infinite, Timeout.Infinite); + _cachePrefetchRefreshTimer = new Timer(CachePrefetchRefreshTimerCallback, null, Timeout.Infinite, Timeout.Infinite); + _cacheMaintenanceTimer = new Timer(CacheMaintenanceTimerCallback, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL); _state = ServiceState.Running; @@ -2281,7 +2225,6 @@ namespace DnsServerCore.Dns } } - _listenerThreads.Clear(); _udpListeners.Clear(); _tcpListeners.Clear(); _httpListeners.Clear(); @@ -2291,29 +2234,30 @@ namespace DnsServerCore.Dns _state = ServiceState.Stopped; } - public DnsDatagram DirectQuery(DnsQuestionRecord question, int timeout = 2000) + public async Task DirectQueryAsync(DnsQuestionRecord question, int timeout = 2000) { - EventWaitHandle waitHandle = new ManualResetEvent(false); - DnsDatagram response = null; - - ThreadPool.QueueUserWorkItem(delegate (object state) + try { - try + Task task = ProcessQueryAsync(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp); + + using (CancellationTokenSource timeoutCancellationTokenSource = new CancellationTokenSource()) { - response = ProcessQuery(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp); - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(ex); + if (await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token)) != task) + return null; + + timeoutCancellationTokenSource.Cancel(); //stop delay task } - waitHandle.Set(); - }); + return await task; + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(ex); - waitHandle.WaitOne(timeout); - return response; + return null; + } } #endregion @@ -2432,30 +2376,66 @@ namespace DnsServerCore.Dns set { _preferIPv6 = value; } } - public int Retries + public int ForwarderRetries { - get { return _retries; } + get { return _forwarderRetries; } set { if (value > 0) - _retries = value; + _forwarderRetries = value; } } - public int Timeout + public int ResolverRetries { - get { return _timeout; } + get { return _resolverRetries; } + set + { + if (value > 0) + _resolverRetries = value; + } + } + + public int ForwarderTimeout + { + get { return _forwarderTimeout; } set { if (value >= 2000) - _timeout = value; + _forwarderTimeout = value; } } - public int MaxStackCount + public int ResolverTimeout { - get { return _maxStackCount; } - set { _maxStackCount = value; } + get { return _resolverTimeout; } + set + { + if (value >= 2000) + _resolverTimeout = value; + } + } + + public int ClientTimeout + { + get { return _clientTimeout; } + set + { + if (value >= 2000) + _clientTimeout = value; + } + } + + public int ForwarderConcurrency + { + get { return _forwarderConcurrency; } + set { _forwarderConcurrency = value; } + } + + public int ResolverMaxStackCount + { + get { return _resolverMaxStackCount; } + set { _resolverMaxStackCount = value; } } public int CachePrefetchEligibility