diff --git a/DnsServerCore/Dns/Zones/CacheZone.cs b/DnsServerCore/Dns/Zones/CacheZone.cs index 27174b3e..820ce6dc 100644 --- a/DnsServerCore/Dns/Zones/CacheZone.cs +++ b/DnsServerCore/Dns/Zones/CacheZone.cs @@ -1,6 +1,6 @@ /* Technitium DNS Server -Copyright (C) 2021 Shreyas Zare (shreyas@technitium.com) +Copyright (C) 2022 Shreyas Zare (shreyas@technitium.com) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -19,7 +19,7 @@ along with this program. If not, see . using System; using System.Collections.Generic; -using TechnitiumLibrary.IO; +using TechnitiumLibrary; using TechnitiumLibrary.Net.Dns; using TechnitiumLibrary.Net.Dns.ResourceRecords; @@ -37,14 +37,14 @@ namespace DnsServerCore.Dns.Zones #region private - private static IReadOnlyList ValidateRRSet(DnsResourceRecordType type, IReadOnlyList records, bool serveStale, bool checkForSpecialCacheRecord) + private static IReadOnlyList ValidateRRSet(DnsResourceRecordType type, IReadOnlyList records, bool serveStale, bool skipSpecialCacheRecord) { foreach (DnsResourceRecord record in records) { if (record.IsExpired(serveStale)) return Array.Empty(); //RR Set is expired - if (checkForSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecord)) + if (skipSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecord)) return Array.Empty(); //RR Set is special cache record } @@ -69,13 +69,13 @@ namespace DnsServerCore.Dns.Zones public void SetRecords(DnsResourceRecordType type, IReadOnlyList records, bool serveStale) { - bool isFailureRecord = (records.Count > 0) && (records[0].RDATA is DnsCache.DnsSpecialCacheRecord splRecord) && splRecord.IsFailure; + bool isFailureRecord = (records.Count > 0) && (records[0].RDATA is DnsCache.DnsSpecialCacheRecord splRecord) && (splRecord.Type == DnsCache.DnsSpecialCacheRecordType.FailureCache); if (isFailureRecord) { //call trying to cache failure record if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) { - if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecord existingSplRecord && existingSplRecord.IsFailure) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) + if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecord existingSplRecord && (existingSplRecord.Type == DnsCache.DnsSpecialCacheRecordType.FailureCache)) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) return; //skip to avoid overwriting a useful record with a failure record } } @@ -92,6 +92,7 @@ namespace DnsServerCore.Dns.Zones case DnsResourceRecordType.CNAME: case DnsResourceRecordType.SOA: case DnsResourceRecordType.NS: + case DnsResourceRecordType.DS: //do nothing break; @@ -121,32 +122,63 @@ namespace DnsServerCore.Dns.Zones } } - public IReadOnlyList QueryRecords(DnsResourceRecordType type, bool serveStale, bool checkForSpecialCacheRecord) + public IReadOnlyList QueryRecords(DnsResourceRecordType type, bool serveStale, bool skipSpecialCacheRecord) { - //check for CNAME - if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords)) + switch (type) { - IReadOnlyList rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, checkForSpecialCacheRecord); - if (rrset.Count > 0) - { - if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecord)) - return rrset; - } + case DnsResourceRecordType.DS: + { + //since some zones have CNAME at apex so no CNAME lookup for DS queries! + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord); + } + break; + + case DnsResourceRecordType.SOA: + case DnsResourceRecordType.DNSKEY: + { + //since some zones have CNAME at apex! + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord); + + if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords)) + { + IReadOnlyList rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, skipSpecialCacheRecord); + if (rrset.Count > 0) + { + if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecord)) + return rrset; + } + } + } + break; + + case DnsResourceRecordType.ANY: + List anyRecords = new List(); + + foreach (KeyValuePair> entry in _entries) + anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true)); + + return anyRecords; + + default: + { + if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords)) + { + IReadOnlyList rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, skipSpecialCacheRecord); + if (rrset.Count > 0) + { + if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecord)) + return rrset; + } + } + + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + return ValidateRRSet(type, existingRecords, serveStale, skipSpecialCacheRecord); + } + break; } - if (type == DnsResourceRecordType.ANY) - { - List anyRecords = new List(); - - foreach (KeyValuePair> entry in _entries) - anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true)); - - return anyRecords; - } - - if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) - return ValidateRRSet(type, existingRecords, serveStale, checkForSpecialCacheRecord); - return Array.Empty(); }