From bae6b0483ac11a7dd7c310f355767bfc5bde4e7e Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 29 Aug 2020 14:36:00 +0530 Subject: [PATCH] DnsServer: using separate forwarder and resolver settings. Implemented async methods for request processing and for resolver. Using intependent task scheduler for resolver. Added resolver task stuck check to allow using quick stale response. --- DnsServerCore/Dns/DnsServer.cs | 820 ++++++++++++++++----------------- 1 file changed, 400 insertions(+), 420 deletions(-) diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index cc014787..3c2b345d 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -22,7 +22,6 @@ using DnsServerCore.Dns.Zones; using Newtonsoft.Json; using System; using System.Collections.Generic; -using System.Collections.Specialized; using System.IO; using System.Net; using System.Net.Security; @@ -30,10 +29,12 @@ using System.Net.Sockets; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; +using System.Threading.Tasks; using TechnitiumLibrary.IO; using TechnitiumLibrary.Net; using TechnitiumLibrary.Net.Dns; using TechnitiumLibrary.Net.Dns.ResourceRecords; +using TechnitiumLibrary.Net.Http; using TechnitiumLibrary.Net.Proxy; namespace DnsServerCore.Dns @@ -54,7 +55,6 @@ namespace DnsServerCore.Dns #region variables - const int LISTENER_THREAD_COUNT = 4; const int MAX_CNAME_HOPS = 16; string _serverDomain; @@ -69,7 +69,6 @@ namespace DnsServerCore.Dns readonly List _httpListeners = new List(); readonly List _tlsListeners = new List(); readonly List _httpsListeners = new List(); - readonly List _listenerThreads = new List(); bool _enableDnsOverHttp = false; bool _enableDnsOverTls = false; @@ -93,9 +92,13 @@ namespace DnsServerCore.Dns NetProxy _proxy; IReadOnlyList _forwarders; bool _preferIPv6 = false; - int _retries = 2; - int _timeout = 4000; - int _maxStackCount = 10; + int _forwarderRetries = 3; + int _resolverRetries = 5; + int _forwarderTimeout = 4000; + int _resolverTimeout = 4000; + int _clientTimeout = 4000; + int _forwarderConcurrency = 2; + int _resolverMaxStackCount = 10; int _cachePrefetchEligibility = 2; int _cachePrefetchTrigger = 9; int _cachePrefetchSampleIntervalInMinutes = 5; @@ -119,7 +122,8 @@ namespace DnsServerCore.Dns const int CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL = 60 * 60 * 1000; const int CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL = 60 * 60 * 1000; - readonly DomainTree _resolverQueryHandles = new DomainTree(); + readonly IndependentTaskScheduler _resolverTaskScheduler = new IndependentTaskScheduler(ThreadPriority.AboveNormal); + readonly DomainTree _resolverTasks = new DomainTree(); volatile ServiceState _state = ServiceState.Stopped; @@ -257,7 +261,9 @@ namespace DnsServerCore.Dns { try { - ThreadPool.QueueUserWorkItem(ProcessUdpRequestAsync, new object[] { udpListener, remoteEP, new DnsDatagram(new MemoryStream(recvBuffer, 0, bytesRecv, false), false) }); + DnsDatagram request = DnsDatagram.ReadFromUdp(new MemoryStream(recvBuffer, 0, bytesRecv, false)); + + _ = ProcessUdpRequestAsync(udpListener, remoteEP, request); } catch (Exception ex) { @@ -281,21 +287,15 @@ namespace DnsServerCore.Dns } } - private void ProcessUdpRequestAsync(object parameter) + private async Task ProcessUdpRequestAsync(Socket udpListener, EndPoint remoteEP, DnsDatagram request) { - object[] parameters = parameter as object[]; - - Socket udpListener = parameters[0] as Socket; - EndPoint remoteEP = parameters[1] as EndPoint; - DnsDatagram request = parameters[2] as DnsDatagram; - try { DnsDatagram response; if (request.ParsingException == null) { - response = ProcessQuery(request, remoteEP, IsRecursionAllowed(remoteEP), DnsTransportProtocol.Udp); + response = await ProcessQueryAsync(request, remoteEP, IsRecursionAllowed(remoteEP), DnsTransportProtocol.Udp); } else { @@ -316,18 +316,18 @@ namespace DnsServerCore.Dns try { - response.WriteTo(sendBufferStream, false); + response.WriteToUdp(sendBufferStream); } catch (NotSupportedException) { response = new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, true, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question); sendBufferStream.Position = 0; - response.WriteTo(sendBufferStream, false); + response.WriteToUdp(sendBufferStream); } - //send dns datagram - udpListener.SendTo(sendBuffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); + //send dns datagram async + await udpListener.SendToAsync(sendBuffer, 0, (int)sendBufferStream.Position, remoteEP); LogManager queryLog = _queryLog; if (queryLog != null) @@ -378,64 +378,7 @@ namespace DnsServerCore.Dns { Socket socket = tcpListener.Accept(); - ThreadPool.QueueUserWorkItem(delegate (object state) - { - EndPoint remoteEP = null; - - try - { - remoteEP = socket.RemoteEndPoint; - - switch (protocol) - { - case DnsTransportProtocol.Tcp: - ReadStreamRequest(new NetworkStream(socket), remoteEP, protocol); - break; - - case DnsTransportProtocol.Tls: - SslStream tlsStream = new SslStream(new NetworkStream(socket)); - tlsStream.AuthenticateAsServer(_certificate); - - ReadStreamRequest(tlsStream, remoteEP, protocol); - break; - - case DnsTransportProtocol.Https: - Stream stream = new NetworkStream(socket); - - if (usingHttps) - { - SslStream httpsStream = new SslStream(stream); - httpsStream.AuthenticateAsServer(_certificate); - - stream = httpsStream; - } - else if (!NetUtilities.IsPrivateIP((remoteEP as IPEndPoint).Address)) - { - //intentionally blocking public IP addresses from using DNS-over-HTTP (without TLS) - //this feature is intended to be used with an SSL terminated reverse proxy like nginx on private network - return; - } - - ProcessDoHRequest(stream, remoteEP, !usingHttps); - break; - } - } - catch (IOException) - { - //ignore IO exceptions - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(remoteEP as IPEndPoint, protocol, ex); - } - finally - { - if (socket != null) - socket.Dispose(); - } - }); + _ = ProcessConnectionAsync(socket, protocol, usingHttps); } } catch (Exception ex) @@ -451,24 +394,96 @@ namespace DnsServerCore.Dns } } - private void ReadStreamRequest(Stream stream, EndPoint remoteEP, DnsTransportProtocol protocol) + private async Task ProcessConnectionAsync(Socket socket, DnsTransportProtocol protocol, bool usingHttps) { - DnsDatagram request = null; + EndPoint remoteEP = null; + try + { + remoteEP = socket.RemoteEndPoint; + + switch (protocol) + { + case DnsTransportProtocol.Tcp: + await ReadStreamRequestAsync(new NetworkStream(socket), _tcpReceiveTimeout, remoteEP, protocol); + break; + + case DnsTransportProtocol.Tls: + SslStream tlsStream = new SslStream(new NetworkStream(socket)); + await tlsStream.AuthenticateAsServerAsync(_certificate); + + await ReadStreamRequestAsync(tlsStream, _tcpReceiveTimeout, remoteEP, protocol); + break; + + case DnsTransportProtocol.Https: + Stream stream = new NetworkStream(socket); + + if (usingHttps) + { + SslStream httpsStream = new SslStream(stream); + await httpsStream.AuthenticateAsServerAsync(_certificate); + + stream = httpsStream; + } + else if (!NetUtilities.IsPrivateIP((remoteEP as IPEndPoint).Address)) + { + //intentionally blocking public IP addresses from using DNS-over-HTTP (without TLS) + //this feature is intended to be used with an SSL terminated reverse proxy like nginx on private network + return; + } + + await ProcessDoHRequestAsync(stream, _tcpReceiveTimeout, remoteEP, !usingHttps); + break; + } + } + catch (IOException) + { + //ignore IO exceptions + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP as IPEndPoint, protocol, ex); + } + finally + { + if (socket != null) + socket.Dispose(); + } + } + + private async Task ReadStreamRequestAsync(Stream stream, int receiveTimeout, EndPoint remoteEP, DnsTransportProtocol protocol) + { try { MemoryStream readBuffer = new MemoryStream(64); MemoryStream writeBuffer = new MemoryStream(64); + SemaphoreSlim writeSemaphore = new SemaphoreSlim(1, 1); while (true) { - request = null; + DnsDatagram request; - //read dns datagram - request = new DnsDatagram(stream, true, readBuffer); + //read dns datagram with timeout + using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource()) + { + Task task = DnsDatagram.ReadFromTcpAsync(stream, readBuffer, cancellationTokenSource.Token); + + if (await Task.WhenAny(task, Task.Delay(receiveTimeout, cancellationTokenSource.Token)) != task) + { + //read timed out + stream.Dispose(); + return; + } + + cancellationTokenSource.Cancel(); //cancel delay task + + request = await task; + } //process request async - ThreadPool.QueueUserWorkItem(ProcessStreamRequestAsync, new object[] { stream, writeBuffer, remoteEP, request, protocol }); + _ = ProcessStreamRequestAsync(stream, writeBuffer, writeSemaphore, remoteEP, request, protocol); } } catch (ObjectDisposedException) @@ -481,54 +496,49 @@ namespace DnsServerCore.Dns } catch (Exception ex) { - LogManager queryLog = _queryLog; - if ((queryLog != null) && (request != null)) - queryLog.Write(remoteEP as IPEndPoint, protocol, request, null); - LogManager log = _log; if (log != null) log.Write(remoteEP as IPEndPoint, protocol, ex); } } - private void ProcessStreamRequestAsync(object parameter) + private async Task ProcessStreamRequestAsync(Stream stream, MemoryStream writeBuffer, SemaphoreSlim writeSemaphore, EndPoint remoteEP, DnsDatagram request, DnsTransportProtocol protocol) { - object[] parameters = parameter as object[]; - - Stream stream = parameters[0] as Stream; - MemoryStream writeBuffer = parameters[1] as MemoryStream; - EndPoint remoteEP = parameters[2] as EndPoint; - DnsDatagram request = parameters[3] as DnsDatagram; - DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[4]; - try { DnsDatagram response; if (request.ParsingException == null) { - response = ProcessQuery(request, remoteEP, IsRecursionAllowed(remoteEP), protocol); + response = await ProcessQueryAsync(request, remoteEP, IsRecursionAllowed(remoteEP), protocol); } else { //format error + response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); + + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write(remoteEP as IPEndPoint, protocol, request, response); + LogManager log = _log; if (log != null) log.Write(remoteEP as IPEndPoint, protocol, request.ParsingException); - - //format error response - response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); } //send response if (response != null) { - lock (stream) + await writeSemaphore.WaitAsync(); + try { //send dns datagram - response.WriteTo(stream, true, writeBuffer); - - stream.Flush(); + await response.WriteToTcpAsync(stream, writeBuffer); + await stream.FlushAsync(); + } + finally + { + writeSemaphore.Release(); } LogManager queryLog = _queryLog; @@ -556,7 +566,7 @@ namespace DnsServerCore.Dns } } - private void ProcessDoHRequest(Stream stream, EndPoint remoteEP, bool usingReverseProxy) + private async Task ProcessDoHRequestAsync(Stream stream, int receiveTimeout, EndPoint remoteEP, bool usingReverseProxy) { DnsDatagram dnsRequest = null; DnsTransportProtocol dnsProtocol = DnsTransportProtocol.Https; @@ -565,99 +575,13 @@ namespace DnsServerCore.Dns { while (true) { - string requestMethod; - string requestPath; - NameValueCollection requestQueryString = new NameValueCollection(); - string requestProtocol; - WebHeaderCollection requestHeaders = new WebHeaderCollection(); - - #region parse http request - - using (MemoryStream mS = new MemoryStream()) - { - //read http request header into memory stream - int byteRead; - int crlfCount = 0; - - while (true) - { - byteRead = stream.ReadByte(); - switch (byteRead) - { - case '\r': - case '\n': - crlfCount++; - break; - - case -1: - throw new EndOfStreamException(); - - default: - crlfCount = 0; - break; - } - - mS.WriteByte((byte)byteRead); - - if (crlfCount == 4) - break; //http request completed - } - - mS.Position = 0; - StreamReader sR = new StreamReader(mS); - - string[] requestParts = sR.ReadLine().Split(new char[] { ' ' }, 3); - - if (requestParts.Length != 3) - throw new InvalidDataException("Invalid HTTP request."); - - requestMethod = requestParts[0]; - string pathAndQueryString = requestParts[1]; - requestProtocol = requestParts[2]; - - string[] requestPathAndQueryParts = pathAndQueryString.Split(new char[] { '?' }, 2); - - requestPath = requestPathAndQueryParts[0]; - - string queryString = null; - if (requestPathAndQueryParts.Length > 1) - queryString = requestPathAndQueryParts[1]; - - if (!string.IsNullOrEmpty(queryString)) - { - foreach (string item in queryString.Split(new char[] { '&' }, StringSplitOptions.RemoveEmptyEntries)) - { - string[] itemParts = item.Split(new char[] { '=' }, 2); - - string name = itemParts[0]; - string value = null; - - if (itemParts.Length > 1) - value = itemParts[1]; - - requestQueryString.Add(name, value); - } - } - - while (true) - { - string line = sR.ReadLine(); - if (string.IsNullOrEmpty(line)) - break; - - string[] parts = line.Split(new char[] { ':' }, 2); - if (parts.Length != 2) - throw new InvalidDataException("Invalid HTTP request."); - - requestHeaders.Add(parts[0], parts[1]); - } - } - - #endregion + HttpRequest httpRequest = await HttpRequest.ReadRequestAsync(stream).WithTimeout(receiveTimeout); + if (httpRequest == null) + return; //connection closed gracefully by client if (usingReverseProxy) { - string xRealIp = requestHeaders["X-Real-IP"]; + string xRealIp = httpRequest.Headers["X-Real-IP"]; if (IPAddress.TryParse(xRealIp, out IPAddress address)) { //get the real IP address of the requesting client from X-Real-IP header set in nginx proxy_pass block @@ -665,16 +589,16 @@ namespace DnsServerCore.Dns } } - string requestConnection = requestHeaders[HttpRequestHeader.Connection]; + string requestConnection = httpRequest.Headers[HttpRequestHeader.Connection]; if (string.IsNullOrEmpty(requestConnection)) requestConnection = "close"; - switch (requestPath) + switch (httpRequest.RequestPath) { case "/dns-query": DnsTransportProtocol protocol = DnsTransportProtocol.Udp; - string strRequestAcceptTypes = requestHeaders[HttpRequestHeader.Accept]; + string strRequestAcceptTypes = httpRequest.Headers[HttpRequestHeader.Accept]; if (!string.IsNullOrEmpty(strRequestAcceptTypes)) { foreach (string acceptType in strRequestAcceptTypes.Split(',')) @@ -698,10 +622,10 @@ namespace DnsServerCore.Dns case DnsTransportProtocol.Https: #region https wire format { - switch (requestMethod) + switch (httpRequest.HttpMethod) { case "GET": - string strRequest = requestQueryString["dns"]; + string strRequest = httpRequest.QueryString["dns"]; if (string.IsNullOrEmpty(strRequest)) throw new ArgumentNullException("dns"); @@ -714,29 +638,23 @@ namespace DnsServerCore.Dns if (x > 0) strRequest = strRequest.PadRight(strRequest.Length - x + 4, '='); - dnsRequest = new DnsDatagram(new MemoryStream(Convert.FromBase64String(strRequest)), false); + dnsRequest = DnsDatagram.ReadFromUdp(new MemoryStream(Convert.FromBase64String(strRequest))); break; case "POST": - string strContentType = requestHeaders[HttpRequestHeader.ContentType]; + string strContentType = httpRequest.Headers[HttpRequestHeader.ContentType]; if (string.IsNullOrEmpty(strContentType)) throw new DnsServerException("Missing Content-Type header."); if (strContentType != "application/dns-message") throw new NotSupportedException("DNS request type not supported: " + strContentType); - string strContentLength = requestHeaders[HttpRequestHeader.ContentLength]; - if (string.IsNullOrEmpty(strContentLength)) - throw new DnsServerException("Missing Content-Length header."); - - int contentLength = int.Parse(strContentLength); - using (MemoryStream mS = new MemoryStream()) { - stream.CopyTo(mS, 512, contentLength); + await httpRequest.InputStream.CopyToAsync(mS, 512); mS.Position = 0; - dnsRequest = new DnsDatagram(mS, false); + dnsRequest = DnsDatagram.ReadFromUdp(mS); } break; @@ -749,7 +667,7 @@ namespace DnsServerCore.Dns if (dnsRequest.ParsingException == null) { - dnsResponse = ProcessQuery(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); + dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); } else { @@ -766,10 +684,10 @@ namespace DnsServerCore.Dns { using (MemoryStream mS = new MemoryStream()) { - dnsResponse.WriteTo(mS, false); + dnsResponse.WriteToUdp(mS); byte[] buffer = mS.ToArray(); - SendContent(stream, "application/dns-message", buffer); + await SendContentAsync(stream, "application/dns-message", buffer); } LogManager queryLog = _queryLog; @@ -787,27 +705,27 @@ namespace DnsServerCore.Dns case DnsTransportProtocol.HttpsJson: #region https json format { - string strName = requestQueryString["name"]; + string strName = httpRequest.QueryString["name"]; if (string.IsNullOrEmpty(strName)) throw new ArgumentNullException("name"); - string strType = requestQueryString["type"]; + string strType = httpRequest.QueryString["type"]; if (string.IsNullOrEmpty(strType)) strType = "1"; dnsRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord(strName, (DnsResourceRecordType)int.Parse(strType), DnsClass.IN) }); - DnsDatagram dnsResponse = ProcessQuery(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); + DnsDatagram dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); if (dnsResponse != null) { using (MemoryStream mS = new MemoryStream()) { JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); - dnsResponse.WriteTo(jsonWriter); + dnsResponse.WriteToJson(jsonWriter); jsonWriter.Flush(); byte[] buffer = mS.ToArray(); - SendContent(stream, "application/dns-json; charset=utf-8", buffer); + await SendContentAsync(stream, "application/dns-json; charset=utf-8", buffer); } LogManager queryLog = _queryLog; @@ -823,7 +741,7 @@ namespace DnsServerCore.Dns break; default: - SendError(stream, 406, "Only application/dns-message and application/dns-json types are accepted."); + await SendErrorAsync(stream, 406, "Only application/dns-message and application/dns-json types are accepted."); break; } @@ -833,11 +751,15 @@ namespace DnsServerCore.Dns break; default: - SendError(stream, 404); + await SendErrorAsync(stream, 404); break; } } } + catch (TimeoutException) + { + //ignore timeout exception + } catch (IOException) { //ignore IO exceptions @@ -852,25 +774,25 @@ namespace DnsServerCore.Dns if (log != null) log.Write(remoteEP as IPEndPoint, dnsProtocol, ex); - SendError(stream, ex); + await SendErrorAsync(stream, ex); } } - private static void SendContent(Stream outputStream, string contentType, byte[] bufferContent) + private static async Task SendContentAsync(Stream outputStream, string contentType, byte[] bufferContent) { byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); - outputStream.Write(bufferHeader, 0, bufferHeader.Length); - outputStream.Write(bufferContent, 0, bufferContent.Length); - outputStream.Flush(); + await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); + await outputStream.WriteAsync(bufferContent, 0, bufferContent.Length); + await outputStream.FlushAsync(); } - private static void SendError(Stream outputStream, Exception ex) + private static Task SendErrorAsync(Stream outputStream, Exception ex) { - SendError(outputStream, 500, ex.ToString()); + return SendErrorAsync(outputStream, 500, ex.ToString()); } - private static void SendError(Stream outputStream, int statusCode, string message = null) + private static async Task SendErrorAsync(Stream outputStream, int statusCode, string message = null) { try { @@ -878,9 +800,9 @@ namespace DnsServerCore.Dns byte[] bufferContent = Encoding.UTF8.GetBytes("" + statusString + "

" + statusString + "

" + (message == null ? "" : "

" + message + "

") + ""); byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 " + statusString + "\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); - outputStream.Write(bufferHeader, 0, bufferHeader.Length); - outputStream.Write(bufferContent, 0, bufferContent.Length); - outputStream.Flush(); + await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); + await outputStream.WriteAsync(bufferContent, 0, bufferContent.Length); + await outputStream.FlushAsync(); } catch { } @@ -922,7 +844,7 @@ namespace DnsServerCore.Dns return true; } - private DnsDatagram ProcessQuery(DnsDatagram request, EndPoint remoteEP, bool isRecursionAllowed, DnsTransportProtocol protocol) + private async Task ProcessQueryAsync(DnsDatagram request, EndPoint remoteEP, bool isRecursionAllowed, DnsTransportProtocol protocol) { if (request.IsResponse) return null; @@ -941,10 +863,10 @@ namespace DnsServerCore.Dns if (protocol == DnsTransportProtocol.Udp) return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.FormatError, request.Question); - return ProcessZoneTransferQuery(request, remoteEP); + return await ProcessZoneTransferQueryAsync(request, remoteEP); case DnsResourceRecordType.IXFR: - return ProcessZoneTransferQuery(request, remoteEP); + return await ProcessZoneTransferQueryAsync(request, remoteEP); case DnsResourceRecordType.MAILB: case DnsResourceRecordType.MAILA: @@ -964,13 +886,13 @@ namespace DnsServerCore.Dns } //query authoritative zone - response = ProcessAuthoritativeQuery(request, inAllowedZone, isRecursionAllowed); + response = await ProcessAuthoritativeQueryAsync(request, inAllowedZone, isRecursionAllowed); if ((response.RCODE != DnsResponseCode.Refused) || !request.RecursionDesired || !isRecursionAllowed) return response; //do recursive query - return ProcessRecursiveQuery(request, null, null, !inAllowedZone, false); + return await ProcessRecursiveQueryAsync(request, null, null, !inAllowedZone, false); } } catch (Exception ex) @@ -983,14 +905,14 @@ namespace DnsServerCore.Dns } case DnsOpcode.Notify: - return ProcessNotifyQuery(request, remoteEP); + return await ProcessNotifyQueryAsync(request, remoteEP); default: return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NotImplemented, request.Question); } } - private DnsDatagram ProcessNotifyQuery(DnsDatagram request, EndPoint remoteEP) + private async Task ProcessNotifyQueryAsync(DnsDatagram request, EndPoint remoteEP) { AuthZoneInfo authZoneInfo = _authZoneManager.GetAuthZoneInfo(request.Question[0].Name); if ((authZoneInfo == null) || (authZoneInfo.Type != AuthZoneType.Secondary)) @@ -999,7 +921,7 @@ namespace DnsServerCore.Dns IPAddress remoteAddress = (remoteEP as IPEndPoint).Address; bool remoteVerified = false; - IReadOnlyList primaryNameServers = authZoneInfo.GetPrimaryNameServerAddresses(this); + IReadOnlyList primaryNameServers = await authZoneInfo.GetPrimaryNameServerAddressesAsync(this); foreach (NameServerAddress primaryNameServer in primaryNameServers) { @@ -1032,7 +954,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.Notify, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question) { Tag = StatsResponseType.Authoritative }; } - private DnsDatagram ProcessZoneTransferQuery(DnsDatagram request, EndPoint remoteEP) + private async Task ProcessZoneTransferQueryAsync(DnsDatagram request, EndPoint remoteEP) { AuthZoneInfo authZoneInfo = _authZoneManager.GetAuthZoneInfo(request.Question[0].Name); if ((authZoneInfo == null) || (authZoneInfo.Type != AuthZoneType.Primary)) @@ -1043,7 +965,7 @@ namespace DnsServerCore.Dns if (!isAxfrAllowed) { - IReadOnlyList secondaryNameServers = authZoneInfo.GetSecondaryNameServerAddresses(this); + IReadOnlyList secondaryNameServers = await authZoneInfo.GetSecondaryNameServerAddressesAsync(this); foreach (NameServerAddress secondaryNameServer in secondaryNameServers) { @@ -1067,7 +989,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, true, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, axfrRecords) { Tag = StatsResponseType.Authoritative }; } - private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, bool inAllowedZone, bool isRecursionAllowed) + private Task ProcessAuthoritativeQueryAsync(DnsDatagram request, bool inAllowedZone, bool isRecursionAllowed) { DnsDatagram response = _authZoneManager.Query(request); response.Tag = StatsResponseType.Authoritative; @@ -1084,10 +1006,10 @@ namespace DnsServerCore.Dns switch (lastRR.Type) { case DnsResourceRecordType.CNAME: - return ProcessCNAME(request, response, isRecursionAllowed, false); + return ProcessCNAMEAsync(request, response, isRecursionAllowed, false); case DnsResourceRecordType.ANAME: - return ProcessANAME(request, response, isRecursionAllowed); + return ProcessANAMEAsync(request, response, isRecursionAllowed); } } } @@ -1101,7 +1023,7 @@ namespace DnsServerCore.Dns //do recursive resolution using response authority name servers List nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false); - return ProcessRecursiveQuery(request, nameServers, null, !inAllowedZone, false); + return ProcessRecursiveQueryAsync(request, nameServers, null, !inAllowedZone, false); } break; @@ -1110,7 +1032,7 @@ namespace DnsServerCore.Dns if ((response.Authority.Count == 1) && (response.Authority[0].Type == DnsResourceRecordType.FWD) && (response.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" - return ProcessRecursiveQuery(request, null, null, !inAllowedZone, false); + return ProcessRecursiveQueryAsync(request, null, null, !inAllowedZone, false); } else { @@ -1128,16 +1050,16 @@ namespace DnsServerCore.Dns } } - return ProcessRecursiveQuery(request, null, forwarders, !inAllowedZone, false); + return ProcessRecursiveQueryAsync(request, null, forwarders, !inAllowedZone, false); } } } } - return response; + return Task.FromResult(response); } - private DnsDatagram ProcessCNAME(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed, bool cacheRefreshOperation) + private async Task ProcessCNAMEAsync(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed, bool cacheRefreshOperation) { List responseAnswer = new List(); responseAnswer.AddRange(response.Answer); @@ -1160,7 +1082,7 @@ namespace DnsServerCore.Dns if (newRequest.RecursionDesired && isRecursionAllowed) { //do recursion - lastResponse = RecursiveResolve(newRequest, null, null, false, cacheRefreshOperation); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, cacheRefreshOperation); isAuthoritativeAnswer = false; } else @@ -1171,7 +1093,7 @@ namespace DnsServerCore.Dns } else if ((lastResponse.Answer.Count > 0) && (lastResponse.Answer[0].Type == DnsResourceRecordType.ANAME)) { - lastResponse = ProcessANAME(request, lastResponse, isRecursionAllowed); + lastResponse = await ProcessANAMEAsync(request, lastResponse, isRecursionAllowed); } else if ((lastResponse.Answer.Count == 0) && (lastResponse.Authority.Count > 0)) { @@ -1184,7 +1106,7 @@ namespace DnsServerCore.Dns //do recursive resolution using last response authority name servers List nameServers = NameServerAddress.GetNameServersFromResponse(lastResponse, _preferIPv6, false); - lastResponse = RecursiveResolve(newRequest, nameServers, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, nameServers, null, false, false); isAuthoritativeAnswer = false; } @@ -1194,7 +1116,7 @@ namespace DnsServerCore.Dns if ((lastResponse.Authority.Count == 1) && (lastResponse.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" - lastResponse = RecursiveResolve(newRequest, null, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, false); isAuthoritativeAnswer = false; } else @@ -1213,7 +1135,7 @@ namespace DnsServerCore.Dns } } - lastResponse = RecursiveResolve(newRequest, null, forwarders, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, forwarders, false, false); isAuthoritativeAnswer = false; } @@ -1263,7 +1185,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, isAuthoritativeAnswer, false, request.RecursionDesired, isRecursionAllowed, false, false, rcode, request.Question, responseAnswer, authority, additional) { Tag = response.Tag }; } - private DnsDatagram ProcessANAME(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed) + private async Task ProcessANAMEAsync(DnsDatagram request, DnsDatagram response, bool isRecursionAllowed) { List responseAnswer = new List(); @@ -1285,7 +1207,7 @@ namespace DnsServerCore.Dns if (lastResponse.RCODE == DnsResponseCode.Refused) { //not found in auth zone; do recursion - lastResponse = RecursiveResolve(newRequest, null, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, false); } else if ((lastResponse.Answer.Count == 0) && (lastResponse.Authority.Count > 0)) { @@ -1296,14 +1218,14 @@ namespace DnsServerCore.Dns //do recursive resolution using last response authority name servers List nameServers = NameServerAddress.GetNameServersFromResponse(lastResponse, _preferIPv6, false); - lastResponse = RecursiveResolve(newRequest, nameServers, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, nameServers, null, false, false); break; case DnsResourceRecordType.FWD: if ((lastResponse.Authority.Count == 1) && (lastResponse.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" - lastResponse = RecursiveResolve(newRequest, null, null, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, null, false, false); } else { @@ -1321,7 +1243,7 @@ namespace DnsServerCore.Dns } } - lastResponse = RecursiveResolve(newRequest, null, forwarders, false, false); + lastResponse = await RecursiveResolveAsync(newRequest, null, forwarders, false, false); } break; @@ -1410,9 +1332,9 @@ namespace DnsServerCore.Dns return response; } - private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool checkForCnameCloaking, bool cacheRefreshOperation) + private async Task ProcessRecursiveQueryAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool checkForCnameCloaking, bool cacheRefreshOperation) { - DnsDatagram response = RecursiveResolve(request, viaNameServers, viaForwarders, false, cacheRefreshOperation); + DnsDatagram response = await RecursiveResolveAsync(request, viaNameServers, viaForwarders, false, cacheRefreshOperation); if (response.Answer.Count > 0) { @@ -1420,7 +1342,7 @@ namespace DnsServerCore.Dns DnsResourceRecord lastRR = response.Answer[response.Answer.Count - 1]; if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) - response = ProcessCNAME(request, response, true, cacheRefreshOperation); + response = await ProcessCNAMEAsync(request, response, true, cacheRefreshOperation); if (checkForCnameCloaking) { @@ -1432,6 +1354,13 @@ namespace DnsServerCore.Dns break; //no further CNAME records exists DnsDatagram newRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord((record.RDATA as DnsCNAMERecord).Domain, request.Question[0].Type, request.Question[0].Class) }); + + //check allowed zone + bool inAllowedZone = _allowedZoneManager.Query(newRequest).RCODE != DnsResponseCode.Refused; + if (inAllowedZone) + break; //CNAME is in allowed zone + + //check blocked zone and block list zone DnsDatagram lastResponse = ProcessBlockedQuery(newRequest); if (lastResponse != null) { @@ -1467,7 +1396,7 @@ namespace DnsServerCore.Dns } } - private DnsDatagram RecursiveResolve(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool cachePrefetchOperation, bool cacheRefreshOperation) + private async Task RecursiveResolveAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool cachePrefetchOperation, bool cacheRefreshOperation) { if (!cachePrefetchOperation && !cacheRefreshOperation) { @@ -1482,21 +1411,8 @@ namespace DnsServerCore.Dns { if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (answer.TtlValue < _cachePrefetchTrigger)) { - //trigger prefetch in worker thread - ThreadPool.QueueUserWorkItem(delegate (object state) - { - try - { - RecursiveResolve(request, viaNameServers, viaForwarders, true, false); - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(ex); - } - }); - + //trigger prefetch async + _ = PrefetchCacheAsync(request, viaNameServers, viaForwarders); break; } } @@ -1507,13 +1423,16 @@ namespace DnsServerCore.Dns } //recursion with locking - ResolverQueryHandle newQueryHandle = new ResolverQueryHandle(); - ResolverQueryHandle queryHandle = _resolverQueryHandles.GetOrAdd(GetResolverQueryKey(request.Question[0]), newQueryHandle); + ResolverTask newResolverTask = new ResolverTask(); + ResolverTask resolverTask = _resolverTasks.GetOrAdd(GetResolverQueryKey(request.Question[0]), newResolverTask); - if (queryHandle.Equals(newQueryHandle)) + if (resolverTask.Equals(newResolverTask)) { - //got query handle so question not being resolved; do recursive resolution in worker thread - ThreadPool.QueueUserWorkItem(RecursiveResolveAsync, new object[] { request, viaNameServers, viaForwarders, cachePrefetchOperation, cacheRefreshOperation, queryHandle }); + //got new resolver task added so question is not being resolved; do recursive resolution in another task on resolver thread pool + _ = Task.Factory.StartNew(delegate () + { + return RecursiveResolveAsync(request, viaNameServers, viaForwarders, cachePrefetchOperation, cacheRefreshOperation, newResolverTask.TaskCompletionSource); + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _resolverTaskScheduler); } //request is being recursively resolved by another thread @@ -1521,10 +1440,25 @@ namespace DnsServerCore.Dns if (cachePrefetchOperation) return null; //return null as prefetch worker thread does not need valid response and thus does not need to wait + if (resolverTask.IsStuck(_resolverTimeout)) + { + //resolver task is taking a long time to complete + //query cache zone to return stale answer (if available) + DnsDatagram staleResponse = QueryCache(request, true); + if (staleResponse != null) + return staleResponse; + + //no stale response available; wait for resolver task + } + + DateTime resolverWaitStartTime = DateTime.UtcNow; + //wait till short timeout for response - if (queryHandle.WaitForResponse(1800, out DnsDatagram response)) //1.8 sec wait as per draft-ietf-dnsop-serve-stale-04 + if (await Task.WhenAny(resolverTask.TaskCompletionSource.Task, Task.Delay(1800)) == resolverTask.TaskCompletionSource.Task) //1.8 sec wait as per draft-ietf-dnsop-serve-stale-04 { //resolver signaled + DnsDatagram response = await resolverTask.TaskCompletionSource.Task; + if (response != null) return response; @@ -1534,19 +1468,27 @@ namespace DnsServerCore.Dns { //wait timed out //query cache zone to return stale answer (if available) as per draft-ietf-dnsop-serve-stale-04 - DnsDatagram cacheResponse = QueryCache(request, true); - if ((cacheResponse != null) && (cacheResponse.RCODE == DnsResponseCode.NoError)) - return cacheResponse; + DnsDatagram staleResponse = QueryCache(request, true); + if (staleResponse != null) + return staleResponse; - //wait till full timeout before responding as ServerFailure - int timeout = _timeout - 1800; - if (timeout > 0) + if ((DateTime.UtcNow - resolverWaitStartTime).TotalMilliseconds < _clientTimeout) //check if there is any point in waiting further due to execution delay { - queryHandle.WaitForResponse(timeout, out response); - if (response != null) - return response; + //wait till full timeout before responding as ServerFailure + int timeout = _clientTimeout - 1800; + if (timeout > 0) + { + if (await Task.WhenAny(resolverTask.TaskCompletionSource.Task, Task.Delay(timeout)) == resolverTask.TaskCompletionSource.Task) + { + //resolver signaled + DnsDatagram response = await resolverTask.TaskCompletionSource.Task; - //no response available from resolver or resolver had exception and no stale record was found + if (response != null) + return response; + } + + //no response available from resolver or resolver had exception and no stale record was found + } } } @@ -1554,17 +1496,8 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Question); } - private void RecursiveResolveAsync(object parameter) + private async Task RecursiveResolveAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders, bool cachePrefetchOperation, bool cacheRefreshOperation, TaskCompletionSource taskCompletionSource) { - object[] parameters = parameter as object[]; - - DnsDatagram request = parameters[0] as DnsDatagram; - IReadOnlyList viaNameServers = parameters[1] as IReadOnlyList; - IReadOnlyList viaForwarders = parameters[2] as IReadOnlyList; - bool cachePrefetchOperation = (bool)parameters[3]; - bool cacheRefreshOperation = (bool)parameters[4]; - ResolverQueryHandle queryHandle = parameters[5] as ResolverQueryHandle; - IReadOnlyList forwarders = _forwarders; if (viaForwarders != null) forwarders = viaForwarders; //use provided forwarders @@ -1574,14 +1507,13 @@ namespace DnsServerCore.Dns if ((viaNameServers == null) && (forwarders != null)) { //use forwarders - if (_proxy == null) { //recursive resolve name server when proxy is null else let proxy resolve it foreach (NameServerAddress nameServerAddress in forwarders) { if (nameServerAddress.IsIPEndPointStale) //refresh forwarder IPEndPoint if stale - nameServerAddress.RecursiveResolveIPAddress(_dnsCache, null, _preferIPv6, _retries, _timeout); + await nameServerAddress.RecursiveResolveIPAddressAsync(_dnsCache, null, _preferIPv6, _resolverRetries, _resolverTimeout); } } @@ -1590,14 +1522,15 @@ namespace DnsServerCore.Dns dnsClient.Proxy = _proxy; dnsClient.PreferIPv6 = _preferIPv6; - dnsClient.Retries = _retries; - dnsClient.Timeout = _timeout; + dnsClient.Retries = _forwarderRetries; + dnsClient.Timeout = _forwarderTimeout; + dnsClient.Concurrency = _forwarderConcurrency; - DnsDatagram response = dnsClient.Resolve(request.Question[0]); + DnsDatagram response = await dnsClient.ResolveAsync(request.Question[0]); _cacheZoneManager.CacheResponse(response); - queryHandle.Set(response); + taskCompletionSource.SetResult(response); } else { @@ -1609,8 +1542,8 @@ namespace DnsServerCore.Dns else dnsCache = _dnsCache; - DnsDatagram response = DnsClient.RecursiveResolve(request.Question[0], viaNameServers, dnsCache, _proxy, _preferIPv6, _retries, _timeout, _maxStackCount); - queryHandle.Set(response); + DnsDatagram response = await DnsClient.RecursiveResolveAsync(request.Question[0], viaNameServers, dnsCache, _proxy, _preferIPv6, _resolverRetries, _resolverTimeout, _resolverMaxStackCount); + taskCompletionSource.SetResult(response); } } catch (Exception ex) @@ -1645,28 +1578,28 @@ namespace DnsServerCore.Dns } //fetch stale record - DnsDatagram cacheResponse = QueryCache(request, true); - if (cacheResponse == null) + DnsDatagram staleResponse = QueryCache(request, true); + if (staleResponse == null) { - //no stale record was found; signal null response to release waiting threads - queryHandle.Set(null); + //no stale record was found; signal null response to release waiting tasks + taskCompletionSource.SetResult(null); } else { //reset expiry for stale records - foreach (DnsResourceRecord record in cacheResponse.Answer) + foreach (DnsResourceRecord record in staleResponse.Answer) { if (record.IsStale) record.ResetExpiry(30); //reset expiry by 30 seconds so that resolver tries again only after 30 seconds as per draft-ietf-dnsop-serve-stale-04 } - //signal stale record - queryHandle.Set(cacheResponse); + //signal stale response + taskCompletionSource.SetResult(staleResponse); } } finally { - _resolverQueryHandles.TryRemove(GetResolverQueryKey(request.Question[0]), out _); + _resolverTasks.TryRemove(GetResolverQueryKey(request.Question[0]), out _); } } @@ -1695,6 +1628,52 @@ namespace DnsServerCore.Dns return null; } + private async Task PrefetchCacheAsync(DnsDatagram request, IReadOnlyList viaNameServers, IReadOnlyList viaForwarders) + { + try + { + await RecursiveResolveAsync(request, viaNameServers, viaForwarders, true, false); + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(ex); + } + } + + private async Task RefreshCacheAsync(IList cacheRefreshSampleList, DnsQuestionRecord sampleQuestion, int sampleQuestionIndex) + { + try + { + //refresh cache + DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sampleQuestion }); + DnsDatagram response = await ProcessRecursiveQueryAsync(request, null, null, false, true); + + bool removeFromSampleList = true; + DateTime utcNow = DateTime.UtcNow; + + foreach (DnsResourceRecord answer in response.Answer) + { + if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TtlValue) < _cachePrefetchSamplingTimerTriggersOn)) + { + //answer expires before next sampling so dont remove from list to allow refreshing it + removeFromSampleList = false; + break; + } + } + + if (removeFromSampleList) + cacheRefreshSampleList[sampleQuestionIndex] = null; + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(ex); + } + } + private DnsQuestionRecord GetCacheRefreshNeededQuery(DnsQuestionRecord question, int trigger) { int queryCount = 0; @@ -1732,7 +1711,7 @@ namespace DnsServerCore.Dns } } - private bool CacheRefreshNeeded(DnsQuestionRecord question, int trigger) + private bool IsCacheRefreshNeeded(DnsQuestionRecord question, int trigger) { DnsDatagram cacheResponse = QueryCache(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), false); if (cacheResponse == null) @@ -1751,7 +1730,7 @@ namespace DnsServerCore.Dns return false; //no need to refresh for this query } - private void CachePrefetchSamplingAsync(object state) + private void CachePrefetchSamplingTimerCallback(object state) { try { @@ -1809,7 +1788,7 @@ namespace DnsServerCore.Dns } } - private void CachePrefetchRefreshAsync(object state) + private void CachePrefetchRefreshTimerCallback(object state) { try { @@ -1822,42 +1801,10 @@ namespace DnsServerCore.Dns if (sampleQuestion == null) continue; - if (!CacheRefreshNeeded(sampleQuestion, _cachePrefetchTrigger + 2)) + if (!IsCacheRefreshNeeded(sampleQuestion, _cachePrefetchTrigger + 2)) continue; - int sampleQuestionIndex = i; - - ThreadPool.QueueUserWorkItem(delegate (object state2) - { - try - { - //refresh cache - DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sampleQuestion }); - DnsDatagram response = ProcessRecursiveQuery(request, null, null, false, true); - - bool removeFromSampleList = true; - DateTime utcNow = DateTime.UtcNow; - - foreach (DnsResourceRecord answer in response.Answer) - { - if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TtlValue) < _cachePrefetchSamplingTimerTriggersOn)) - { - //answer expires before next sampling so dont remove from list to allow refreshing it - removeFromSampleList = false; - break; - } - } - - if (removeFromSampleList) - cacheRefreshSampleList[sampleQuestionIndex] = null; - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(ex); - } - }); + _ = RefreshCacheAsync(cacheRefreshSampleList, sampleQuestion, i); } } } @@ -1872,12 +1819,12 @@ namespace DnsServerCore.Dns lock (_cachePrefetchRefreshTimerLock) { if (_cachePrefetchRefreshTimer != null) - _cachePrefetchRefreshTimer.Change((_cachePrefetchTrigger + 1) * 1000, System.Threading.Timeout.Infinite); + _cachePrefetchRefreshTimer.Change((_cachePrefetchTrigger + 1) * 1000, Timeout.Infinite); } } } - private void CacheMaintenanceAsync(object state) + private void CacheMaintenanceTimerCallback(object state) { try { @@ -2118,69 +2065,66 @@ namespace DnsServerCore.Dns } //start reading query packets + int listenerThreadCount = Math.Max(1, Environment.ProcessorCount); + foreach (Socket udpListener in _udpListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(ReadUdpRequestAsync); - listenerThread.IsBackground = true; - listenerThread.Start(udpListener); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(ReadUdpRequestAsync); + thread.Name = "DNS UDP Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(udpListener); } } foreach (Socket tcpListener in _tcpListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS TCP Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp }); } } foreach (Socket httpListener in _httpListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { httpListener, DnsTransportProtocol.Https, false }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS HTTP Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { httpListener, DnsTransportProtocol.Https, false }); } } foreach (Socket tlsListener in _tlsListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS TLS Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls }); } } foreach (Socket httpsListener in _httpsListeners) { - for (int i = 0; i < LISTENER_THREAD_COUNT; i++) + for (int i = 0; i < listenerThreadCount; i++) { - Thread listenerThread = new Thread(AcceptConnectionAsync); - listenerThread.IsBackground = true; - listenerThread.Start(new object[] { httpsListener, DnsTransportProtocol.Https }); - - _listenerThreads.Add(listenerThread); + Thread thread = new Thread(AcceptConnectionAsync); + thread.Name = "DNS HTTPS Read Request [" + i + "]"; + thread.IsBackground = true; + thread.Start(new object[] { httpsListener, DnsTransportProtocol.Https }); } } - _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingAsync, null, System.Threading.Timeout.Infinite, System.Threading.Timeout.Infinite); - _cachePrefetchRefreshTimer = new Timer(CachePrefetchRefreshAsync, null, System.Threading.Timeout.Infinite, System.Threading.Timeout.Infinite); - _cacheMaintenanceTimer = new Timer(CacheMaintenanceAsync, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL); + _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingTimerCallback, null, Timeout.Infinite, Timeout.Infinite); + _cachePrefetchRefreshTimer = new Timer(CachePrefetchRefreshTimerCallback, null, Timeout.Infinite, Timeout.Infinite); + _cacheMaintenanceTimer = new Timer(CacheMaintenanceTimerCallback, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL); _state = ServiceState.Running; @@ -2281,7 +2225,6 @@ namespace DnsServerCore.Dns } } - _listenerThreads.Clear(); _udpListeners.Clear(); _tcpListeners.Clear(); _httpListeners.Clear(); @@ -2291,29 +2234,30 @@ namespace DnsServerCore.Dns _state = ServiceState.Stopped; } - public DnsDatagram DirectQuery(DnsQuestionRecord question, int timeout = 2000) + public async Task DirectQueryAsync(DnsQuestionRecord question, int timeout = 2000) { - EventWaitHandle waitHandle = new ManualResetEvent(false); - DnsDatagram response = null; - - ThreadPool.QueueUserWorkItem(delegate (object state) + try { - try + Task task = ProcessQueryAsync(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp); + + using (CancellationTokenSource timeoutCancellationTokenSource = new CancellationTokenSource()) { - response = ProcessQuery(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp); - } - catch (Exception ex) - { - LogManager log = _log; - if (log != null) - log.Write(ex); + if (await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token)) != task) + return null; + + timeoutCancellationTokenSource.Cancel(); //stop delay task } - waitHandle.Set(); - }); + return await task; + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(ex); - waitHandle.WaitOne(timeout); - return response; + return null; + } } #endregion @@ -2432,30 +2376,66 @@ namespace DnsServerCore.Dns set { _preferIPv6 = value; } } - public int Retries + public int ForwarderRetries { - get { return _retries; } + get { return _forwarderRetries; } set { if (value > 0) - _retries = value; + _forwarderRetries = value; } } - public int Timeout + public int ResolverRetries { - get { return _timeout; } + get { return _resolverRetries; } + set + { + if (value > 0) + _resolverRetries = value; + } + } + + public int ForwarderTimeout + { + get { return _forwarderTimeout; } set { if (value >= 2000) - _timeout = value; + _forwarderTimeout = value; } } - public int MaxStackCount + public int ResolverTimeout { - get { return _maxStackCount; } - set { _maxStackCount = value; } + get { return _resolverTimeout; } + set + { + if (value >= 2000) + _resolverTimeout = value; + } + } + + public int ClientTimeout + { + get { return _clientTimeout; } + set + { + if (value >= 2000) + _clientTimeout = value; + } + } + + public int ForwarderConcurrency + { + get { return _forwarderConcurrency; } + set { _forwarderConcurrency = value; } + } + + public int ResolverMaxStackCount + { + get { return _resolverMaxStackCount; } + set { _resolverMaxStackCount = value; } } public int CachePrefetchEligibility