diff --git a/DnsServerCore/Dns/Zones/CacheZone.cs b/DnsServerCore/Dns/Zones/CacheZone.cs index 64df13fe..27174b3e 100644 --- a/DnsServerCore/Dns/Zones/CacheZone.cs +++ b/DnsServerCore/Dns/Zones/CacheZone.cs @@ -37,52 +37,30 @@ namespace DnsServerCore.Dns.Zones #region private - private static IReadOnlyList FilterExpiredRecords(DnsResourceRecordType type, IReadOnlyList records, bool serveStale, bool filterSpecialCacheRecords) + private static IReadOnlyList ValidateRRSet(DnsResourceRecordType type, IReadOnlyList records, bool serveStale, bool checkForSpecialCacheRecord) { - if (records.Count == 1) - { - DnsResourceRecord record = records[0]; - - if (record.IsExpired(serveStale)) - return Array.Empty(); //record expired - - if (filterSpecialCacheRecords) - { - if (record.RDATA is DnsCache.DnsSpecialCacheRecord) - return Array.Empty(); //special cache record - } - - return records; - } - - List newRecords = new List(records.Count); - foreach (DnsResourceRecord record in records) { if (record.IsExpired(serveStale)) - continue; //record expired + return Array.Empty(); //RR Set is expired - if (filterSpecialCacheRecords) - { - if (record.RDATA is DnsCache.DnsSpecialCacheRecord) - continue; //special cache record - } - - newRecords.Add(record); + if (checkForSpecialCacheRecord && (record.RDATA is DnsCache.DnsSpecialCacheRecord)) + return Array.Empty(); //RR Set is special cache record } - if (newRecords.Count > 1) + if (records.Count > 1) { switch (type) { case DnsResourceRecordType.A: case DnsResourceRecordType.AAAA: + List newRecords = new List(records); newRecords.Shuffle(); //shuffle records to allow load balancing - break; + return newRecords; } } - return newRecords; + return records; } #endregion @@ -97,7 +75,7 @@ namespace DnsServerCore.Dns.Zones //call trying to cache failure record if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) { - if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecord existingSplRecord && existingSplRecord.IsFailure) && !existingRecords[0].IsExpired(serveStale)) + if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecord existingSplRecord && existingSplRecord.IsFailure) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) return; //skip to avoid overwriting a useful record with a failure record } } @@ -138,57 +116,21 @@ namespace DnsServerCore.Dns.Zones { foreach (KeyValuePair> entry in _entries) { - bool isExpired = false; - - foreach (DnsResourceRecord record in entry.Value) - { - if (record.IsExpired(serveStale)) - { - //record expired - isExpired = true; - break; - } - } - - if (isExpired) - { - List newRecords = null; - - foreach (DnsResourceRecord record in entry.Value) - { - if (record.IsExpired(serveStale)) - continue; //record expired, skip it - - if (newRecords == null) - newRecords = new List(entry.Value.Count); - - newRecords.Add(record); - } - - if (newRecords == null) - { - //all records expired; remove entry - _entries.TryRemove(entry.Key, out _); - } - else - { - //try update entry with non-expired records - _entries.TryUpdate(entry.Key, newRecords, entry.Value); - } - } + if (DnsResourceRecord.IsRRSetExpired(entry.Value, serveStale)) + _entries.TryRemove(entry.Key, out _); //RR Set is expired; remove entry } } - public IReadOnlyList QueryRecords(DnsResourceRecordType type, bool serveStale, bool filterSpecialCacheRecords) + public IReadOnlyList QueryRecords(DnsResourceRecordType type, bool serveStale, bool checkForSpecialCacheRecord) { //check for CNAME if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords)) { - IReadOnlyList filteredRecords = FilterExpiredRecords(type, existingCNAMERecords, serveStale, filterSpecialCacheRecords); - if (filteredRecords.Count > 0) + IReadOnlyList rrset = ValidateRRSet(type, existingCNAMERecords, serveStale, checkForSpecialCacheRecord); + if (rrset.Count > 0) { - if ((type == DnsResourceRecordType.CNAME) || (filteredRecords[0].RDATA is DnsCNAMERecord)) - return filteredRecords; + if ((type == DnsResourceRecordType.CNAME) || (rrset[0].RDATA is DnsCNAMERecord)) + return rrset; } } @@ -197,13 +139,13 @@ namespace DnsServerCore.Dns.Zones List anyRecords = new List(); foreach (KeyValuePair> entry in _entries) - anyRecords.AddRange(entry.Value); + anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true)); - return FilterExpiredRecords(type, anyRecords, serveStale, true); + return anyRecords; } if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) - return FilterExpiredRecords(type, existingRecords, serveStale, filterSpecialCacheRecords); + return ValidateRRSet(type, existingRecords, serveStale, checkForSpecialCacheRecord); return Array.Empty(); }