diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 55998446..4e5eb82b 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -175,6 +175,7 @@ namespace DnsServerCore.Dns X509Certificate2Collection _certificateCollection; SslServerAuthenticationOptions _dotSslServerAuthenticationOptions; SslServerAuthenticationOptions _doqSslServerAuthenticationOptions; + SslServerAuthenticationOptions _dohSslServerAuthenticationOptions; string _dnsOverHttpRealIpHeader = "X-Real-IP"; IReadOnlyDictionary _tsigKeys; @@ -239,8 +240,9 @@ namespace DnsServerCore.Dns readonly IndependentTaskScheduler _queryTaskScheduler = new IndependentTaskScheduler(); + TaskPool _resolverTaskPool; readonly IndependentTaskScheduler _resolverTaskScheduler = new IndependentTaskScheduler(ThreadPriority.AboveNormal); - readonly ConcurrentDictionary> _resolverTasks = new ConcurrentDictionary>(); + readonly ConcurrentDictionary> _resolverTasks = new ConcurrentDictionary>(-1, 1000); volatile ServiceState _state = ServiceState.Stopped; @@ -282,6 +284,8 @@ namespace DnsServerCore.Dns _localEndPoints = localEndPoints; _log = log; + ReconfigureResolverTaskPool(100); + _authZoneManager = new AuthZoneManager(this); _allowedZoneManager = new AllowedZoneManager(this); _blockedZoneManager = new BlockedZoneManager(this); @@ -318,6 +322,8 @@ namespace DnsServerCore.Dns _stats?.Dispose(); + _resolverTaskPool?.Dispose(); + _disposed = true; } @@ -786,9 +792,19 @@ namespace DnsServerCore.Dns { while (true) { - QuicConnection quicConnection = await quicListener.AcceptConnectionAsync(); + try + { + QuicConnection quicConnection = await quicListener.AcceptConnectionAsync(); - _ = ProcessQuicConnectionAsync(quicConnection); + _ = ProcessQuicConnectionAsync(quicConnection); + } + catch (QuicException ex) + { + if (ex.InnerException is OperationCanceledException) + continue; + + throw; + } } } catch (ObjectDisposedException) @@ -1553,7 +1569,7 @@ namespace DnsServerCore.Dns { foreach (KeyValuePair> rrsetEntry in zoneEntry.Value) { - IReadOnlyList prRRSet = rrsetEntry.Value; + List prRRSet = rrsetEntry.Value; IReadOnlyList rrset = _authZoneManager.GetRecords(zoneInfo.Name, zoneEntry.Key, rrsetEntry.Key); //check if RRSet exists (value dependent) @@ -2145,6 +2161,9 @@ namespace DnsServerCore.Dns case DnsResourceRecordType.APP: response = await ProcessAPPAsync(request, response, remoteEP, protocol, isRecursionAllowed, skipDnsAppAuthoritativeRequestHandlers); + if (response is null) + return null; //drop request + reprocessResponse = true; break; } @@ -2158,11 +2177,7 @@ namespace DnsServerCore.Dns private async Task AuthoritativeQueryAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, bool isRecursionAllowed, bool skipDnsAppAuthoritativeRequestHandlers) { - DnsDatagram response = await TechnitiumLibrary.TaskExtensions.TimeoutAsync(delegate (CancellationToken cancellationToken1) - { - return _authZoneManager.QueryAsync(request, remoteEP.Address, isRecursionAllowed, cancellationToken1); - }, _clientTimeout); - + DnsDatagram response = await _authZoneManager.QueryAsync(request, remoteEP.Address, isRecursionAllowed); if (response is not null) { response.Tag = DnsServerResponseType.Authoritative; @@ -2304,12 +2319,16 @@ namespace DnsServerCore.Dns DnsResourceRecord lastRR = response.GetLastAnswerRecord(); EDnsOption[] eDnsClientSubnetOption = null; DnsDatagram newResponse = null; + double responseRtt = 0.0; + + if (response.Metadata is not null) + responseRtt = response.Metadata.RoundTripTime; if (_eDnsClientSubnet) { EDnsClientSubnetOptionData requestECS = request.GetEDnsClientSubnetOption(); if (requestECS is not null) - eDnsClientSubnetOption = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EDNS_CLIENT_SUBNET, requestECS) }; + eDnsClientSubnetOption = [new EDnsOption(EDnsOptionCode.EDNS_CLIENT_SUBNET, requestECS)]; } int queryCount = 0; @@ -2330,6 +2349,9 @@ namespace DnsServerCore.Dns { //do recursion newResponse = await RecursiveResolveAsync(newRequest, remoteEP, null, _dnssecValidation, false, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + isAuthoritativeAnswer = false; } else @@ -2341,6 +2363,8 @@ namespace DnsServerCore.Dns else if ((newResponse.Answer.Count > 0) && (newResponse.GetLastAnswerRecord() is DnsResourceRecord lastAnswer) && ((lastAnswer.Type == DnsResourceRecordType.ANAME) || (lastAnswer.Type == DnsResourceRecordType.ALIAS))) { newResponse = await ProcessANAMEAsync(request, newResponse, remoteEP, protocol, isRecursionAllowed, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request } else if ((newResponse.Answer.Count == 0) && (newResponse.Authority.Count > 0)) { @@ -2353,6 +2377,9 @@ namespace DnsServerCore.Dns { //do forced recursive resolution using empty conditional forwarders; name servers will be provided via ResolveDnsCache newResponse = await RecursiveResolveAsync(newRequest, remoteEP, [], _dnssecValidation, false, false, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + isAuthoritativeAnswer = false; } @@ -2361,15 +2388,24 @@ namespace DnsServerCore.Dns case DnsResourceRecordType.FWD: //do conditional forwarding newResponse = await RecursiveResolveAsync(newRequest, remoteEP, newResponse.Authority, _dnssecValidation, false, false, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + isAuthoritativeAnswer = false; break; case DnsResourceRecordType.APP: newResponse = await ProcessAPPAsync(newRequest, newResponse, remoteEP, protocol, isRecursionAllowed, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + break; } } + if (newResponse.Metadata is not null) + responseRtt += newResponse.Metadata.RoundTripTime; + //check last response if (newResponse.Answer.Count == 0) break; //cannot proceed to resolve further @@ -2448,7 +2484,10 @@ namespace DnsServerCore.Dns additional = newResponse.Additional; } - return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, isAuthoritativeAnswer, false, request.RecursionDesired, isRecursionAllowed, false, request.CheckingDisabled, rcode, request.Question, newAnswer, authority, additional) { Tag = response.Tag }; + DnsDatagram finalResponse = new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, isAuthoritativeAnswer, false, request.RecursionDesired, isRecursionAllowed, false, request.CheckingDisabled, rcode, request.Question, newAnswer, authority, additional) { Tag = response.Tag }; + finalResponse.SetMetadata(null, responseRtt); + + return finalResponse; } private async Task ProcessANAMEAsync(DnsDatagram request, DnsDatagram response, IPEndPoint remoteEP, DnsTransportProtocol protocol, bool isRecursionAllowed, bool skipDnsAppAuthoritativeRequestHandlers) @@ -2459,7 +2498,7 @@ namespace DnsServerCore.Dns { EDnsClientSubnetOptionData requestECS = request.GetEDnsClientSubnetOption(); if (requestECS is not null) - eDnsClientSubnetOption = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EDNS_CLIENT_SUBNET, requestECS) }; + eDnsClientSubnetOption = [new EDnsOption(EDnsOptionCode.EDNS_CLIENT_SUBNET, requestECS)]; } Queue>> resolveQueue = new Queue>>(); @@ -2480,6 +2519,8 @@ namespace DnsServerCore.Dns { //not found in auth zone; do recursion newResponse = await RecursiveResolveAsync(newRequest, remoteEP, null, _dnssecValidation, false, false, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request } else if ((newResponse.Answer.Count == 0) && (newResponse.Authority.Count > 0)) { @@ -2490,15 +2531,24 @@ namespace DnsServerCore.Dns case DnsResourceRecordType.NS: //do forced recursive resolution using empty conditional forwarders; name servers will be provided via ResolverDnsCache newResponse = await RecursiveResolveAsync(newRequest, remoteEP, [], _dnssecValidation, false, false, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + break; case DnsResourceRecordType.FWD: //do conditional forwarding newResponse = await RecursiveResolveAsync(newRequest, remoteEP, newResponse.Authority, _dnssecValidation, false, false, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + break; case DnsResourceRecordType.APP: newResponse = await ProcessAPPAsync(newRequest, newResponse, remoteEP, protocol, isRecursionAllowed, skipDnsAppAuthoritativeRequestHandlers); + if (newResponse is null) + return null; //drop request + break; } } @@ -2852,6 +2902,8 @@ namespace DnsServerCore.Dns } DnsDatagram response = await RecursiveResolveAsync(request, remoteEP, conditionalForwarders, dnssecValidation, false, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers); + if (response is null) + return null; //drop request if (response.Answer.Count > 0) { @@ -2859,7 +2911,11 @@ namespace DnsServerCore.Dns DnsResourceRecord lastRR = response.GetLastAnswerRecord(); if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) + { response = await ProcessCNAMEAsync(request, response, remoteEP, protocol, true, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers); + if (response is null) + return null; //drop request + } if (!isAllowed) { @@ -3049,10 +3105,18 @@ namespace DnsServerCore.Dns if (resolverTask.Equals(resolverTaskCompletionSource.Task)) { //got new resolver task added so question is not being resolved; do recursive resolution in another task on resolver thread pool - _ = Task.Factory.StartNew(delegate () + if (!_resolverTaskPool.TryQueueTask(delegate (object state) + { + return RecursiveResolverBackgroundTaskAsync(question, eDnsClientSubnet, advancedForwardingClientSubnet, conditionalForwarders, dnssecValidation, cachePrefetchOperation, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers, resolverTaskCompletionSource); + }) + ) { - return RecursiveResolverBackgroundTaskAsync(question, eDnsClientSubnet, advancedForwardingClientSubnet, conditionalForwarders, dnssecValidation, cachePrefetchOperation, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers, resolverTaskCompletionSource); - }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _resolverTaskScheduler); + //resolver queue full + if (!_resolverTasks.TryRemove(GetResolverQueryKey(question, eDnsClientSubnet), out _)) //remove recursion lock entry + throw new InvalidOperationException(); + + return null; //drop request + } } //request is being recursively resolved by another thread @@ -3525,7 +3589,7 @@ namespace DnsServerCore.Dns } } - private async Task ConcurrentConditionalForwarderResolveAsync(DnsQuestionRecord question, NetworkAddress eDnsClientSubnet, bool advancedForwardingClientSubnet, IDnsCache dnsCache, IReadOnlyList conditionalForwarders, bool skipDnsAppAuthoritativeRequestHandlers, CancellationToken cancellationToken = default) + private async Task ConcurrentConditionalForwarderResolveAsync(DnsQuestionRecord question, NetworkAddress eDnsClientSubnet, bool advancedForwardingClientSubnet, IDnsCache dnsCache, List conditionalForwarders, bool skipDnsAppAuthoritativeRequestHandlers, CancellationToken cancellationToken = default) { if (conditionalForwarders.Count == 1) { @@ -3641,15 +3705,16 @@ namespace DnsServerCore.Dns if (dnssecOk && request.CheckingDisabled) { + DnsDatagram cdResponse = resolveResponse.CheckingDisabledResponse; bool authenticData = false; if (dnssecOk) { - if (resolveResponse.CheckingDisabledResponse.Answer.Count > 0) + if (cdResponse.Answer.Count > 0) { authenticData = true; - foreach (DnsResourceRecord record in resolveResponse.CheckingDisabledResponse.Answer) + foreach (DnsResourceRecord record in cdResponse.Answer) { if (record.DnssecStatus != DnssecStatus.Secure) { @@ -3658,11 +3723,11 @@ namespace DnsServerCore.Dns } } } - else if (resolveResponse.CheckingDisabledResponse.Authority.Count > 0) + else if (cdResponse.Authority.Count > 0) { authenticData = true; - foreach (DnsResourceRecord record in resolveResponse.CheckingDisabledResponse.Authority) + foreach (DnsResourceRecord record in cdResponse.Authority) { if (record.DnssecStatus != DnssecStatus.Secure) { @@ -3673,7 +3738,12 @@ namespace DnsServerCore.Dns } } - return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, authenticData, true, resolveResponse.CheckingDisabledResponse.RCODE, request.Question, resolveResponse.CheckingDisabledResponse.Answer, resolveResponse.CheckingDisabledResponse.Authority, RemoveOPTFromAdditional(resolveResponse.CheckingDisabledResponse.Additional, true), _udpPayloadSize, EDnsHeaderFlags.DNSSEC_OK, resolveResponse.CheckingDisabledResponse.EDNS?.Options); + DnsDatagram finalCdResponse = new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, authenticData, true, cdResponse.RCODE, request.Question, cdResponse.Answer, cdResponse.Authority, RemoveOPTFromAdditional(cdResponse.Additional, true), _udpPayloadSize, EDnsHeaderFlags.DNSSEC_OK, cdResponse.EDNS?.Options); + DnsDatagramMetadata metadata = cdResponse.Metadata; + if (metadata is not null) + finalCdResponse.SetMetadata(metadata.NameServer, metadata.RoundTripTime); + + return finalCdResponse; } DnsDatagram response = resolveResponse.Response; @@ -3888,7 +3958,12 @@ namespace DnsServerCore.Dns } } - return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, authenticData, request.CheckingDisabled, response.RCODE, request.Question, answer, authority, additional); + DnsDatagram finalResponse = new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, authenticData, request.CheckingDisabled, response.RCODE, request.Question, answer, authority, additional); + DnsDatagramMetadata metadata = response.Metadata; + if (metadata is not null) + finalResponse.SetMetadata(metadata.NameServer, metadata.RoundTripTime); + + return finalResponse; } } @@ -3951,7 +4026,7 @@ namespace DnsServerCore.Dns { try { - await RecursiveResolveAsync(request, remoteEP, conditionalForwarders, _dnssecValidation, true, false, false); + _ = await RecursiveResolveAsync(request, remoteEP, conditionalForwarders, _dnssecValidation, true, false, false); } catch (Exception ex) { @@ -3964,19 +4039,26 @@ namespace DnsServerCore.Dns try { //refresh cache + bool addBackToSampleList = false; + DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sample.SampleQuestion }); DnsDatagram response = await ProcessRecursiveQueryAsync(request, IPENDPOINT_ANY_0, DnsTransportProtocol.Udp, sample.ConditionalForwarders, _dnssecValidation, true, false); - - bool addBackToSampleList = false; - DateTime utcNow = DateTime.UtcNow; - - foreach (DnsResourceRecord answer in response.Answer) + if (response is null) { - if ((answer.OriginalTtlValue >= _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TTL) < _cachePrefetchSamplingTimerTriggersOn)) + addBackToSampleList = true; + } + else + { + DateTime utcNow = DateTime.UtcNow; + + foreach (DnsResourceRecord answer in response.Answer) { - //answer expires before next sampling so add back to the list to allow refreshing it - addBackToSampleList = true; - break; + if ((answer.OriginalTtlValue >= _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TTL) < _cachePrefetchSamplingTimerTriggersOn)) + { + //answer expires before next sampling so add back to the list to allow refreshing it + addBackToSampleList = true; + break; + } } } @@ -4150,10 +4232,16 @@ namespace DnsServerCore.Dns if (!IsCacheRefreshNeeded(sample.SampleQuestion, _cachePrefetchTrigger + 1)) continue; - cacheRefreshSampleList[i] = null; //remove from sample list to avoid concurrent refresh attempt - - int sampleQuestionIndex = i; - _ = Task.Run(delegate () { return RefreshCacheAsync(cacheRefreshSampleList, sample, sampleQuestionIndex); }); //run task in threadpool since its long running + //run in resolver thread pool + if (_resolverTaskPool.TryQueueTask(delegate (object state) + { + return RefreshCacheAsync(cacheRefreshSampleList, sample, (int)state); + }, i) + ) + { + //refresh cache task was queued + cacheRefreshSampleList[i] = null; //remove from sample list to avoid concurrent refresh attempt + } } } } @@ -4391,6 +4479,26 @@ namespace DnsServerCore.Dns #endregion + #region resolver task pool + + internal bool TryQueueResolverTask(Func task) + { + return _resolverTaskPool.TryQueueTask(task, null); + } + + public void ReconfigureResolverTaskPool(ushort maxConcurrentResolutionsPerCore) + { + TaskPool previousResolverTaskPool = _resolverTaskPool; + + int maxConcurrentResolutions = Environment.ProcessorCount * maxConcurrentResolutionsPerCore; + int resolverQueueSize = maxConcurrentResolutions * 3 * 10; //assuming 3 qps average resolution rate for 10 sec + _resolverTaskPool = new TaskPool(resolverQueueSize, maxConcurrentResolutions, _resolverTaskScheduler); + + previousResolverTaskPool?.Stop(); //stop previous task pool from queuing new tasks + } + + #endregion + #region doh web service private async Task StartDoHAsync() @@ -4425,65 +4533,20 @@ namespace DnsServerCore.Dns //bind to https port if (_enableDnsOverHttps && (_certificateCollection is not null)) { - X509Certificate2 serverCertificate = null; - - foreach (X509Certificate2 certificate in _certificateCollection) - { - if (certificate.HasPrivateKey) - { - serverCertificate = certificate; - break; - } - } - - if (serverCertificate is null) - throw new DnsServerException("DNS Server TLS certificate file must contain a certificate with private key."); - - bool isSupportedHttp2 = _enableDnsOverHttp3; - if (!isSupportedHttp2) - { - switch (Environment.OSVersion.Platform) - { - case PlatformID.Win32NT: - isSupportedHttp2 = Environment.OSVersion.Version.Major >= 10; //http/2 supported on Windows Server 2016/Windows 10 or later - break; - - case PlatformID.Unix: - isSupportedHttp2 = true; //http/2 supported on Linux with OpenSSL 1.0.2 or later (for example, Ubuntu 16.04 or later) - break; - } - } - - List applicationProtocols = new List(); - - if (_enableDnsOverHttp3) - applicationProtocols.Add(new SslApplicationProtocol("h3")); - - if (isSupportedHttp2) - applicationProtocols.Add(new SslApplicationProtocol("h2")); - - applicationProtocols.Add(new SslApplicationProtocol("http/1.1")); - - SslServerAuthenticationOptions sslServerAuthenticationOptions = new SslServerAuthenticationOptions - { - ApplicationProtocols = applicationProtocols, - ServerCertificateContext = SslStreamCertificateContext.Create(serverCertificate, _certificateCollection, false), - }; - foreach (IPAddress localAddress in localAddresses) { serverOptions.Listen(localAddress, _dnsOverHttpsPort, delegate (ListenOptions listenOptions) { if (_enableDnsOverHttp3) listenOptions.Protocols = HttpProtocols.Http1AndHttp2AndHttp3; - else if (isSupportedHttp2) + else if (IsHttp2Supported()) listenOptions.Protocols = HttpProtocols.Http1AndHttp2; else listenOptions.Protocols = HttpProtocols.Http1; listenOptions.UseHttps(delegate (SslStream stream, SslClientHelloInfo clientHelloInfo, object state, CancellationToken cancellationToken) { - return ValueTask.FromResult(sslServerAuthenticationOptions); + return ValueTask.FromResult(_dohSslServerAuthenticationOptions); }, null); }); } @@ -4569,6 +4632,24 @@ namespace DnsServerCore.Dns } } + private bool IsHttp2Supported() + { + if (_enableDnsOverHttp3) + return true; + + switch (Environment.OSVersion.Platform) + { + case PlatformID.Win32NT: + return Environment.OSVersion.Version.Major >= 10; //http/2 supported on Windows Server 2016/Windows 10 or later + + case PlatformID.Unix: + return true; //http/2 supported on Linux with OpenSSL 1.0.2 or later (for example, Ubuntu 16.04 or later) + + default: + return false; + } + } + internal static IReadOnlyList GetValidKestralLocalAddresses(IReadOnlyList localAddresses) { List supportedLocalAddresses = new List(localAddresses.Count); @@ -4875,7 +4956,7 @@ namespace DnsServerCore.Dns } //start reading query packets - int listenerTaskCount = Math.Max(1, Environment.ProcessorCount); + int listenerTaskCount = Environment.ProcessorCount; foreach (Socket udpListener in _udpListeners) { @@ -5522,6 +5603,19 @@ namespace DnsServerCore.Dns set { _listenBacklog = value; } } + public ushort MaxConcurrentResolutionsPerCore + { + get { return Convert.ToUInt16(_resolverTaskPool.MaximumConcurrencyLevel / Environment.ProcessorCount); } + set + { + if (value < 1) + throw new ArgumentOutOfRangeException(nameof(MaxConcurrentResolutionsPerCore), "Value cannot be less than 1."); + + if (MaxConcurrentResolutionsPerCore != value) + ReconfigureResolverTaskPool(value); + } + } + public bool EnableDnsOverUdpProxy { get { return _enableDnsOverUdpProxy; } @@ -5691,6 +5785,22 @@ namespace DnsServerCore.Dns ApplicationProtocols = _doqApplicationProtocols, ServerCertificateContext = certificateContext }; + + List applicationProtocols = new List(); + + if (_enableDnsOverHttp3) + applicationProtocols.Add(new SslApplicationProtocol("h3")); + + if (IsHttp2Supported()) + applicationProtocols.Add(new SslApplicationProtocol("h2")); + + applicationProtocols.Add(new SslApplicationProtocol("http/1.1")); + + _dohSslServerAuthenticationOptions = new SslServerAuthenticationOptions + { + ApplicationProtocols = applicationProtocols, + ServerCertificateContext = SslStreamCertificateContext.Create(serverCertificate, _certificateCollection, false), + }; } } }