From 1d3a1c5c3b441cdce7dbab15b8135f331fd12980 Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sun, 23 Apr 2023 16:12:58 +0530 Subject: [PATCH] CacheZoneManager: added support for conditional forwarding client subnet. --- .../Dns/ZoneManagers/CacheZoneManager.cs | 75 +++++++++++-------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs b/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs index 5d4391e5..4cc41245 100644 --- a/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs +++ b/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs @@ -73,11 +73,10 @@ namespace DnsServerCore.Dns.ZoneManagers IReadOnlyList glueRecords = GetGlueRecordsFrom(resourceRecord); IReadOnlyList rrsigRecords = GetRRSIGRecordsFrom(resourceRecord); IReadOnlyList nsecRecords = GetNSECRecordsFrom(resourceRecord); - NetworkAddress eDnsClientSubnet; + NetworkAddress eDnsClientSubnet = GetEDnsClientSubnetFrom(resourceRecord); + bool conditionalForwardingClientSubnet = GetConditionalForwardingClientSubnetFrom(resourceRecord); - if (CacheZone.IsTypeSupportedForEDnsClientSubnet(resourceRecord.Type)) - eDnsClientSubnet = GetEDnsClientSubnetFrom(resourceRecord); - else + if (!conditionalForwardingClientSubnet && !CacheZone.IsTypeSupportedForEDnsClientSubnet(resourceRecord.Type)) eDnsClientSubnet = null; if ((glueRecords is not null) || (rrsigRecords is not null) || (nsecRecords is not null) || (eDnsClientSubnet is not null)) @@ -88,6 +87,7 @@ namespace DnsServerCore.Dns.ZoneManagers rrInfo.RRSIGRecords = rrsigRecords; rrInfo.NSECRecords = nsecRecords; rrInfo.EDnsClientSubnet = eDnsClientSubnet; + rrInfo.ConditionalForwardingClientSubnet = conditionalForwardingClientSubnet; if (glueRecords is not null) { @@ -179,9 +179,9 @@ namespace DnsServerCore.Dns.ZoneManagers #region private - private static IReadOnlyList AddDSRecordsTo(CacheZone delegation, bool serveStale, IReadOnlyList nsRecords) + private static IReadOnlyList AddDSRecordsTo(CacheZone delegation, bool serveStale, IReadOnlyList nsRecords, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet) { - IReadOnlyList records = delegation.QueryRecords(DnsResourceRecordType.DS, serveStale, true, null); + IReadOnlyList records = delegation.QueryRecords(DnsResourceRecordType.DS, serveStale, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if ((records.Count > 0) && (records[0].Type == DnsResourceRecordType.DS)) { List newNSRecords = new List(nsRecords.Count + records.Count); @@ -253,7 +253,7 @@ namespace DnsServerCore.Dns.ZoneManagers newAuthority = newAuthorityList; } - private void ResolveCNAME(DnsQuestionRecord question, DnsResourceRecord lastCNAME, bool serveStale, NetworkAddress eDnsClientSubnet, List answerRecords) + private void ResolveCNAME(DnsQuestionRecord question, DnsResourceRecord lastCNAME, bool serveStale, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet, List answerRecords) { int queryCount = 0; @@ -266,7 +266,7 @@ namespace DnsServerCore.Dns.ZoneManagers if (!_root.TryGet(cnameDomain, out CacheZone cacheZone)) break; - IReadOnlyList records = cacheZone.QueryRecords(question.Type, serveStale, true, eDnsClientSubnet); + IReadOnlyList records = cacheZone.QueryRecords(question.Type, serveStale, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if (records.Count < 1) break; @@ -293,7 +293,7 @@ namespace DnsServerCore.Dns.ZoneManagers while (++queryCount < DnsServer.MAX_CNAME_HOPS); } - private bool DoDNAMESubstitution(DnsQuestionRecord question, IReadOnlyList answer, bool serveStale, out IReadOnlyList newAnswer) + private bool DoDNAMESubstitution(DnsQuestionRecord question, IReadOnlyList answer, bool serveStale, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet, out IReadOnlyList newAnswer) { DnsResourceRecord dnameRR = answer[0]; @@ -309,7 +309,7 @@ namespace DnsServerCore.Dns.ZoneManagers cnameRR }; - ResolveCNAME(question, cnameRR, serveStale, null, list); + ResolveCNAME(question, cnameRR, serveStale, eDnsClientSubnet, conditionalForwardingClientSubnet, list); newAnswer = list; return true; @@ -321,7 +321,7 @@ namespace DnsServerCore.Dns.ZoneManagers } } - private IReadOnlyList GetAdditionalRecords(IReadOnlyList refRecords, bool serveStale, bool dnssecOk) + private IReadOnlyList GetAdditionalRecords(IReadOnlyList refRecords, bool serveStale, bool dnssecOk, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet) { List additionalRecords = new List(); @@ -332,21 +332,21 @@ namespace DnsServerCore.Dns.ZoneManagers case DnsResourceRecordType.NS: DnsNSRecordData nsRecord = refRecord.RDATA as DnsNSRecordData; if (nsRecord is not null) - ResolveAdditionalRecords(refRecord, nsRecord.NameServer, serveStale, dnssecOk, additionalRecords); + ResolveAdditionalRecords(refRecord, nsRecord.NameServer, serveStale, dnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet, additionalRecords); break; case DnsResourceRecordType.MX: DnsMXRecordData mxRecord = refRecord.RDATA as DnsMXRecordData; if (mxRecord is not null) - ResolveAdditionalRecords(refRecord, mxRecord.Exchange, serveStale, dnssecOk, additionalRecords); + ResolveAdditionalRecords(refRecord, mxRecord.Exchange, serveStale, dnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet, additionalRecords); break; case DnsResourceRecordType.SRV: DnsSRVRecordData srvRecord = refRecord.RDATA as DnsSRVRecordData; if (srvRecord is not null) - ResolveAdditionalRecords(refRecord, srvRecord.Target, serveStale, dnssecOk, additionalRecords); + ResolveAdditionalRecords(refRecord, srvRecord.Target, serveStale, dnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet, additionalRecords); break; } @@ -355,7 +355,7 @@ namespace DnsServerCore.Dns.ZoneManagers return additionalRecords; } - private void ResolveAdditionalRecords(DnsResourceRecord refRecord, string domain, bool serveStale, bool dnssecOk, List additionalRecords) + private void ResolveAdditionalRecords(DnsResourceRecord refRecord, string domain, bool serveStale, bool dnssecOk, NetworkAddress eDnsClientSubnet, bool conditionalForwardingClientSubnet, List additionalRecords) { IReadOnlyList glueRecords = refRecord.GetCacheRecordInfo().GlueRecords; if (glueRecords is not null) @@ -385,13 +385,13 @@ namespace DnsServerCore.Dns.ZoneManagers if (_root.TryGet(domain, out CacheZone cacheZone)) { { - IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.A, serveStale, true, null); + IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.A, serveStale, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if ((records.Count > 0) && (records[0].Type == DnsResourceRecordType.A)) additionalRecords.AddRange(records); } { - IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.AAAA, serveStale, true, null); + IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.AAAA, serveStale, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if ((records.Count > 0) && (records[0].Type == DnsResourceRecordType.AAAA)) additionalRecords.AddRange(records); } @@ -548,6 +548,17 @@ namespace DnsServerCore.Dns.ZoneManagers { string domain = request.Question[0].Name; + NetworkAddress eDnsClientSubnet = null; + bool conditionalForwardingClientSubnet = false; + { + EDnsClientSubnetOptionData requestECS = request.GetEDnsClientSubnetOption(); + if (requestECS is not null) + { + eDnsClientSubnet = new NetworkAddress(requestECS.Address, requestECS.SourcePrefixLength); + conditionalForwardingClientSubnet = requestECS.ConditionalForwardingClientSubnet; + } + } + do { _ = _root.FindZone(domain, out _, out CacheZone delegation); @@ -555,23 +566,23 @@ namespace DnsServerCore.Dns.ZoneManagers return null; //return closest name servers in delegation - IReadOnlyList closestAuthority = delegation.QueryRecords(DnsResourceRecordType.NS, false, true, null); + IReadOnlyList closestAuthority = delegation.QueryRecords(DnsResourceRecordType.NS, false, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if ((closestAuthority.Count > 0) && (closestAuthority[0].Type == DnsResourceRecordType.NS) && (closestAuthority[0].Name.Length > 0)) //dont trust root name servers from cache! { if (request.DnssecOk) { if (closestAuthority[0].DnssecStatus != DnssecStatus.Disabled) //dont return records with disabled status { - closestAuthority = AddDSRecordsTo(delegation, false, closestAuthority); + closestAuthority = AddDSRecordsTo(delegation, false, closestAuthority, eDnsClientSubnet, conditionalForwardingClientSubnet); - IReadOnlyList additional = GetAdditionalRecords(closestAuthority, false, request.DnssecOk); + IReadOnlyList additional = GetAdditionalRecords(closestAuthority, false, request.DnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet); return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, closestAuthority, additional); } } else { - IReadOnlyList additional = GetAdditionalRecords(closestAuthority, false, request.DnssecOk); + IReadOnlyList additional = GetAdditionalRecords(closestAuthority, false, request.DnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet); return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, closestAuthority, additional); } @@ -590,10 +601,14 @@ namespace DnsServerCore.Dns.ZoneManagers DnsQuestionRecord question = request.Question[0]; NetworkAddress eDnsClientSubnet = null; + bool conditionalForwardingClientSubnet = false; { EDnsClientSubnetOptionData requestECS = request.GetEDnsClientSubnetOption(); if (requestECS is not null) + { eDnsClientSubnet = new NetworkAddress(requestECS.Address, requestECS.SourcePrefixLength); + conditionalForwardingClientSubnet = requestECS.ConditionalForwardingClientSubnet; + } } CacheZone zone; @@ -613,7 +628,7 @@ namespace DnsServerCore.Dns.ZoneManagers if (zone is not null) { //zone found - IReadOnlyList answer = zone.QueryRecords(question.Type, serveStaleAndResetExpiry, false, eDnsClientSubnet); + IReadOnlyList answer = zone.QueryRecords(question.Type, serveStaleAndResetExpiry, false, eDnsClientSubnet, conditionalForwardingClientSubnet); if (answer.Count > 0) { //answer found in cache @@ -724,7 +739,7 @@ namespace DnsServerCore.Dns.ZoneManagers List newAnswers = new List(answer.Count + 3); newAnswers.AddRange(answer); - ResolveCNAME(question, lastRR, serveStaleAndResetExpiry, eDnsClientSubnet, newAnswers); + ResolveCNAME(question, lastRR, serveStaleAndResetExpiry, eDnsClientSubnet, conditionalForwardingClientSubnet, newAnswers); answer = newAnswers; } @@ -754,7 +769,7 @@ namespace DnsServerCore.Dns.ZoneManagers case DnsResourceRecordType.NS: case DnsResourceRecordType.MX: case DnsResourceRecordType.SRV: - additional = GetAdditionalRecords(answer, serveStaleAndResetExpiry, request.DnssecOk); + additional = GetAdditionalRecords(answer, serveStaleAndResetExpiry, request.DnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet); break; } @@ -838,12 +853,12 @@ namespace DnsServerCore.Dns.ZoneManagers //check for DNAME in closest zone if (closest is not null) { - IReadOnlyList answer = closest.QueryRecords(DnsResourceRecordType.DNAME, serveStaleAndResetExpiry, true, null); + IReadOnlyList answer = closest.QueryRecords(DnsResourceRecordType.DNAME, serveStaleAndResetExpiry, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if ((answer.Count > 0) && (answer[0].Type == DnsResourceRecordType.DNAME)) { DnsResponseCode rCode; - if (DoDNAMESubstitution(question, answer, serveStaleAndResetExpiry, out answer)) + if (DoDNAMESubstitution(question, answer, serveStaleAndResetExpiry, eDnsClientSubnet, conditionalForwardingClientSubnet, out answer)) rCode = DnsResponseCode.NoError; else rCode = DnsResponseCode.YXDomain; @@ -916,23 +931,23 @@ namespace DnsServerCore.Dns.ZoneManagers while (true) { - IReadOnlyList closestAuthority = delegation.QueryRecords(DnsResourceRecordType.NS, serveStaleAndResetExpiry, true, null); + IReadOnlyList closestAuthority = delegation.QueryRecords(DnsResourceRecordType.NS, serveStaleAndResetExpiry, true, eDnsClientSubnet, conditionalForwardingClientSubnet); if ((closestAuthority.Count > 0) && (closestAuthority[0].Type == DnsResourceRecordType.NS) && (closestAuthority[0].Name.Length > 0)) //dont trust root name servers from cache! { if (request.DnssecOk) { if (closestAuthority[0].DnssecStatus != DnssecStatus.Disabled) //dont return records with disabled status { - closestAuthority = AddDSRecordsTo(delegation, serveStaleAndResetExpiry, closestAuthority); + closestAuthority = AddDSRecordsTo(delegation, serveStaleAndResetExpiry, closestAuthority, eDnsClientSubnet, conditionalForwardingClientSubnet); - IReadOnlyList additional = GetAdditionalRecords(closestAuthority, serveStaleAndResetExpiry, request.DnssecOk); + IReadOnlyList additional = GetAdditionalRecords(closestAuthority, serveStaleAndResetExpiry, request.DnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet); return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, closestAuthority[0].DnssecStatus == DnssecStatus.Secure, request.CheckingDisabled, DnsResponseCode.NoError, request.Question, null, closestAuthority, additional); } } else { - IReadOnlyList additional = GetAdditionalRecords(closestAuthority, serveStaleAndResetExpiry, request.DnssecOk); + IReadOnlyList additional = GetAdditionalRecords(closestAuthority, serveStaleAndResetExpiry, request.DnssecOk, eDnsClientSubnet, conditionalForwardingClientSubnet); return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, closestAuthority[0].DnssecStatus == DnssecStatus.Secure, request.CheckingDisabled, DnsResponseCode.NoError, request.Question, null, closestAuthority, additional); }