diff --git a/DnsServerCore/Dns/Zones/CacheZone.cs b/DnsServerCore/Dns/Zones/CacheZone.cs index bf896269..ddc4b7d7 100644 --- a/DnsServerCore/Dns/Zones/CacheZone.cs +++ b/DnsServerCore/Dns/Zones/CacheZone.cs @@ -76,36 +76,39 @@ namespace DnsServerCore.Dns.Zones public bool SetRecords(DnsResourceRecordType type, IReadOnlyList records, bool serveStale) { + if (records.Count == 0) + return false; + bool isFailureRecord = false; - if (records.Count > 0) + if (records[0].RDATA is DnsCache.DnsSpecialCacheRecordData splRecord) { - if (records[0].RDATA is DnsCache.DnsSpecialCacheRecordData splRecord) + if (splRecord.IsFailureOrBadCache) { - if (splRecord.IsFailureOrBadCache) - { - //call trying to cache failure record - isFailureRecord = true; + //call trying to cache failure record + isFailureRecord = true; - if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) - { - if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecordData existingSplRecord && existingSplRecord.IsFailureOrBadCache) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) - return false; //skip to avoid overwriting a useful record with a failure record - } + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords) && (existingRecords.Count > 0) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) + { + if ((existingRecords[0].RDATA is not DnsCache.DnsSpecialCacheRecordData existingSplRecord) || !existingSplRecord.IsFailureOrBadCache) + return false; //skip to avoid overwriting a useful record with a failure record + + //copy extended errors from existing spl record + splRecord.CopyExtendedDnsErrorsFrom(existingSplRecord); } } - else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet) + } + else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet) + { + //for ns revalidation + if (_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList existingNSRecords)) { - //for ns revalidation - if (_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList existingNSRecords)) + if ((existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet) { - if ((existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet) - { - uint parentSideTtl = existingNS.ParentSideTtl; + uint parentSideTtl = existingNS.ParentSideTtl; - foreach (DnsResourceRecord record in records) - (record.RDATA as DnsNSRecordData).ParentSideTtl = parentSideTtl; - } + foreach (DnsResourceRecord record in records) + (record.RDATA as DnsNSRecordData).ParentSideTtl = parentSideTtl; } } }