From b9986f15014518668bc3e3f7bf443e26804a6696 Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 16 Mar 2024 13:59:22 +0530 Subject: [PATCH] DnsServer: Implemented EDNS Client Subnet override feature. Updated ProcessUpdateQueryAsync() to check for update permissions for secondary zone. Updated serve stale to wait for max 1800ms. Code refactoring done. --- DnsServerCore/Dns/DnsServer.cs | 98 +++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 19 deletions(-) diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 673997b8..c4aaa98d 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -90,7 +90,8 @@ namespace DnsServerCore.Dns #region variables internal const int MAX_CNAME_HOPS = 16; - const int SERVE_STALE_TIME_DIFFERENCE = 200; //200ms before client timeout as per RFC 8767 + const int SERVE_STALE_MAX_WAIT_TIME = 1800; //max time to wait before serve stale [RFC 8767] + const int SERVE_STALE_TIME_DIFFERENCE = 200; //200ms before client timeout [RFC 8767] static readonly IPEndPoint IPENDPOINT_ANY_0 = new IPEndPoint(IPAddress.Any, 0); static readonly IReadOnlyCollection _aRecords = new DnsARecordData[] { new DnsARecordData(IPAddress.Any) }; @@ -134,6 +135,8 @@ namespace DnsServerCore.Dns bool _eDnsClientSubnet; byte _eDnsClientSubnetIPv4PrefixLength = 24; byte _eDnsClientSubnetIPv6PrefixLength = 56; + NetworkAddress _eDnsClientSubnetIpv4Override; + NetworkAddress _eDnsClientSubnetIpv6Override; int _qpmLimitRequests = 6000; //100qps int _qpmLimitErrors = 600; //10qps @@ -598,7 +601,7 @@ namespace DnsServerCore.Dns SslStream tlsStream = new SslStream(new NetworkStream(socket)); string serverName = null; - await tlsStream.AuthenticateAsServerAsync(delegate (SslStream stream, SslClientHelloInfo clientHelloInfo, object? state, CancellationToken cancellationToken) + await tlsStream.AuthenticateAsServerAsync(delegate (SslStream stream, SslClientHelloInfo clientHelloInfo, object state, CancellationToken cancellationToken) { serverName = clientHelloInfo.ServerName; return ValueTask.FromResult(_sslServerAuthenticationOptions); @@ -1814,8 +1817,13 @@ namespace DnsServerCore.Dns } case AuthZoneType.Secondary: - //forward to primary + //forward { + //check for permissions + if (!await IsUpdatePermittedAsync()) + return new DnsDatagram(request.Identifier, true, DnsOpcode.Update, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.Refused, request.Question) { Tag = DnsServerResponseType.Authoritative }; + + //forward to primary IReadOnlyList primaryNameServers = await authZoneInfo.GetPrimaryNameServerAddressesAsync(this); DnsResourceRecord soaRecord = authZoneInfo.GetApexRecords(DnsResourceRecordType.SOA)[0]; @@ -2817,14 +2825,26 @@ namespace DnsServerCore.Dns { DnsQuestionRecord question = request.Question[0]; NetworkAddress eDnsClientSubnet = null; - bool conditionalForwardingClientSubnet = false; + bool advancedForwardingClientSubnet = false; //this feature is used by Advanced Forwarding app to cache response per network group if (_eDnsClientSubnet) { EDnsClientSubnetOptionData requestECS = request.GetEDnsClientSubnetOption(); if (requestECS is null) { - if (!NetUtilities.IsPrivateIP(remoteEP.Address)) + if ((_eDnsClientSubnetIpv4Override is not null) && (remoteEP.AddressFamily == AddressFamily.InterNetwork)) + { + //set ipv4 override shadow ECS option + eDnsClientSubnet = _eDnsClientSubnetIpv4Override; + request.SetShadowEDnsClientSubnetOption(eDnsClientSubnet); + } + else if ((_eDnsClientSubnetIpv6Override is not null) && (remoteEP.AddressFamily == AddressFamily.InterNetworkV6)) + { + //set ipv6 override shadow ECS option + eDnsClientSubnet = _eDnsClientSubnetIpv6Override; + request.SetShadowEDnsClientSubnetOption(eDnsClientSubnet); + } + else if (!NetUtilities.IsPrivateIP(remoteEP.Address)) { //set shadow ECS option switch (remoteEP.AddressFamily) @@ -2849,9 +2869,10 @@ namespace DnsServerCore.Dns { return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, request.CheckingDisabled, DnsResponseCode.FormatError, request.Question) { Tag = DnsServerResponseType.Authoritative }; } - else if (requestECS.ConditionalForwardingClientSubnet) + else if (requestECS.AdvancedForwardingClientSubnet) { - conditionalForwardingClientSubnet = true; + //request from Advanced Forwarding app + advancedForwardingClientSubnet = true; eDnsClientSubnet = new NetworkAddress(requestECS.Address, requestECS.SourcePrefixLength); } else if ((requestECS.SourcePrefixLength == 0) || NetUtilities.IsPrivateIP(requestECS.Address)) @@ -2859,6 +2880,18 @@ namespace DnsServerCore.Dns //disable ECS option request.ShadowHideEDnsClientSubnetOption(); } + else if ((_eDnsClientSubnetIpv4Override is not null) && (remoteEP.AddressFamily == AddressFamily.InterNetwork)) + { + //set ipv4 override shadow ECS option + eDnsClientSubnet = _eDnsClientSubnetIpv4Override; + request.SetShadowEDnsClientSubnetOption(eDnsClientSubnet); + } + else if ((_eDnsClientSubnetIpv6Override is not null) && (remoteEP.AddressFamily == AddressFamily.InterNetworkV6)) + { + //set ipv6 override shadow ECS option + eDnsClientSubnet = _eDnsClientSubnetIpv6Override; + request.SetShadowEDnsClientSubnetOption(eDnsClientSubnet); + } else { //use ECS from client request @@ -2878,12 +2911,13 @@ namespace DnsServerCore.Dns } else { + //ECS feature disabled EDnsClientSubnetOptionData requestECS = request.GetEDnsClientSubnetOption(); if (requestECS is not null) { - conditionalForwardingClientSubnet = requestECS.ConditionalForwardingClientSubnet; - if (conditionalForwardingClientSubnet) - eDnsClientSubnet = new NetworkAddress(requestECS.Address, requestECS.SourcePrefixLength); + advancedForwardingClientSubnet = requestECS.AdvancedForwardingClientSubnet; + if (advancedForwardingClientSubnet) + eDnsClientSubnet = new NetworkAddress(requestECS.Address, requestECS.SourcePrefixLength); //request from Advanced Forwarding app else request.ShadowHideEDnsClientSubnetOption(); //hide ECS option } @@ -2922,7 +2956,7 @@ namespace DnsServerCore.Dns //got new resolver task added so question is not being resolved; do recursive resolution in another task on resolver thread pool _ = Task.Factory.StartNew(delegate () { - return RecursiveResolveAsync(question, eDnsClientSubnet, conditionalForwardingClientSubnet, conditionalForwarders, dnssecValidation, cachePrefetchOperation, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers, resolverTaskCompletionSource); + return RecursiveResolveAsync(question, eDnsClientSubnet, advancedForwardingClientSubnet, conditionalForwarders, dnssecValidation, cachePrefetchOperation, cacheRefreshOperation, skipDnsAppAuthoritativeRequestHandlers, resolverTaskCompletionSource); }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _resolverTaskScheduler); } @@ -2934,9 +2968,10 @@ namespace DnsServerCore.Dns if (_serveStale) { DateTime resolverWaitStartTime = DateTime.UtcNow; + int waitTimeout = Math.Min(SERVE_STALE_MAX_WAIT_TIME, _clientTimeout - SERVE_STALE_TIME_DIFFERENCE); //200ms before client timeout or max 1800ms [RFC 8767] //wait till short timeout for response - if (await Task.WhenAny(resolverTask, Task.Delay(_clientTimeout - SERVE_STALE_TIME_DIFFERENCE)) == resolverTask) //200ms before client timeout as per RFC 8767 + if (await Task.WhenAny(resolverTask, Task.Delay(waitTimeout)) == resolverTask) { //resolver signaled RecursiveResolveResponse response = await resolverTask; @@ -2993,7 +3028,7 @@ namespace DnsServerCore.Dns return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, request.CheckingDisabled, DnsResponseCode.ServerFailure, request.Question, null, null, null, _udpPayloadSize, request.DnssecOk ? EDnsHeaderFlags.DNSSEC_OK : EDnsHeaderFlags.None, options); } - private async Task RecursiveResolveAsync(DnsQuestionRecord question, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet, IReadOnlyList conditionalForwarders, bool dnssecValidation, bool cachePrefetchOperation, bool cacheRefreshOperation, bool skipDnsAppAuthoritativeRequestHandlers, TaskCompletionSource taskCompletionSource) + private async Task RecursiveResolveAsync(DnsQuestionRecord question, NetworkAddress eDnsClientSubnet, bool advancedForwardingClientSubnet, IReadOnlyList conditionalForwarders, bool dnssecValidation, bool cachePrefetchOperation, bool cacheRefreshOperation, bool skipDnsAppAuthoritativeRequestHandlers, TaskCompletionSource taskCompletionSource) { try { @@ -3002,7 +3037,7 @@ namespace DnsServerCore.Dns if (cachePrefetchOperation || cacheRefreshOperation) dnsCache = new ResolverPrefetchDnsCache(this, skipDnsAppAuthoritativeRequestHandlers, question); - else if (skipDnsAppAuthoritativeRequestHandlers || conditionalForwardingClientSubnet) + else if (skipDnsAppAuthoritativeRequestHandlers || advancedForwardingClientSubnet) dnsCache = _dnsCacheSkipDnsApps; //to prevent request reaching apps again else dnsCache = _dnsCache; @@ -3042,7 +3077,7 @@ namespace DnsServerCore.Dns if (conditionalForwarders.Count == 1) { DnsResourceRecord conditionalForwarder = conditionalForwarders[0]; - response = await ConditionalForwarderResolveAsync(question, eDnsClientSubnet, conditionalForwardingClientSubnet, dnsCache, conditionalForwarder.RDATA as DnsForwarderRecordData, conditionalForwarder.Name); + response = await ConditionalForwarderResolveAsync(question, eDnsClientSubnet, advancedForwardingClientSubnet, dnsCache, conditionalForwarder.RDATA as DnsForwarderRecordData, conditionalForwarder.Name); } else { @@ -3064,7 +3099,7 @@ namespace DnsServerCore.Dns tasks.Add(Task.Factory.StartNew(delegate () { - return ConditionalForwarderResolveAsync(question, eDnsClientSubnet, conditionalForwardingClientSubnet, dnsCache, forwarder, conditionalForwarder.Name, cancellationToken); + return ConditionalForwarderResolveAsync(question, eDnsClientSubnet, advancedForwardingClientSubnet, dnsCache, forwarder, conditionalForwarder.Name, cancellationToken); }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Current).Unwrap()); } @@ -3333,7 +3368,7 @@ namespace DnsServerCore.Dns } } - private Task ConditionalForwarderResolveAsync(DnsQuestionRecord question, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet, IDnsCache dnsCache, DnsForwarderRecordData forwarder, string conditionalForwardingZoneCut, CancellationToken cancellationToken = default) + private Task ConditionalForwarderResolveAsync(DnsQuestionRecord question, NetworkAddress eDnsClientSubnet, bool advancedForwardingClientSubnet, IDnsCache dnsCache, DnsForwarderRecordData forwarder, string conditionalForwardingZoneCut, CancellationToken cancellationToken = default) { DnsClient dnsClient = new DnsClient(forwarder.NameServer); @@ -3347,7 +3382,7 @@ namespace DnsServerCore.Dns dnsClient.UdpPayloadSize = _udpPayloadSize; dnsClient.DnssecValidation = forwarder.DnssecValidation; dnsClient.EDnsClientSubnet = eDnsClientSubnet; - dnsClient.ConditionalForwardingClientSubnet = conditionalForwardingClientSubnet; + dnsClient.AdvancedForwardingClientSubnet = advancedForwardingClientSubnet; dnsClient.ConditionalForwardingZoneCut = conditionalForwardingZoneCut; return dnsClient.ResolveAsync(question, cancellationToken); @@ -3491,8 +3526,9 @@ namespace DnsServerCore.Dns //additional section checks if (additional.Count > 0) { - if ((request.EDNS is not null) && (response.EDNS is not null) && ((response.EDNS.Options.Count > 0) || (response.DnsClientExtendedErrors.Count > 0))) + if ((response.RCODE != DnsResponseCode.NoError) && (request.EDNS is not null) && (response.EDNS is not null) && ((response.EDNS.Options.Count > 0) || (response.DnsClientExtendedErrors.Count > 0))) { + //only responses with RCODE!=NoError gets cached as a special cache record to preserve EDNS options //copy options as new OPT and keep other records List newAdditional = new List(additional.Count); @@ -4848,6 +4884,30 @@ namespace DnsServerCore.Dns } } + public NetworkAddress EDnsClientSubnetIpv4Override + { + get { return _eDnsClientSubnetIpv4Override; } + set + { + if ((value is not null) && (value.AddressFamily != AddressFamily.InterNetwork)) + throw new ArgumentException(nameof(EDnsClientSubnetIpv4Override), "EDNS Client Subnet IPv4 Override must be an IPv4 network address."); + + _eDnsClientSubnetIpv4Override = value; + } + } + + public NetworkAddress EDnsClientSubnetIpv6Override + { + get { return _eDnsClientSubnetIpv6Override; } + set + { + if ((value is not null) && (value.AddressFamily != AddressFamily.InterNetworkV6)) + throw new ArgumentException(nameof(EDnsClientSubnetIpv6Override), "EDNS Client Subnet IPv6 Override must be an IPv6 network address."); + + _eDnsClientSubnetIpv6Override = value; + } + } + public int QpmLimitRequests { get { return _qpmLimitRequests; }