diff --git a/DnsServerCore/Dns/Zones/CacheZone.cs b/DnsServerCore/Dns/Zones/CacheZone.cs index 90ea3c92..9675f2f3 100644 --- a/DnsServerCore/Dns/Zones/CacheZone.cs +++ b/DnsServerCore/Dns/Zones/CacheZone.cs @@ -76,14 +76,37 @@ namespace DnsServerCore.Dns.Zones public bool SetRecords(DnsResourceRecordType type, IReadOnlyList records, bool serveStale) { - bool isFailureRecord = (records.Count > 0) && records[0].RDATA is DnsCache.DnsSpecialCacheRecord splRecord && splRecord.IsFailureOrBadCache; - if (isFailureRecord) + bool isFailureRecord = false; + + if (records.Count > 0) { - //call trying to cache failure record - if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + if (records[0].RDATA is DnsCache.DnsSpecialCacheRecord splRecord) { - if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecord existingSplRecord && existingSplRecord.IsFailureOrBadCache) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) - return false; //skip to avoid overwriting a useful record with a failure record + if (splRecord.IsFailureOrBadCache) + { + //call trying to cache failure record + isFailureRecord = true; + + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + { + if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsSpecialCacheRecord existingSplRecord && existingSplRecord.IsFailureOrBadCache) && !DnsResourceRecord.IsRRSetExpired(existingRecords, serveStale)) + return false; //skip to avoid overwriting a useful record with a failure record + } + } + } + else if ((type == DnsResourceRecordType.NS) && (records[0].RDATA is DnsNSRecordData ns) && !ns.IsParentSideTtlSet) + { + //for ns revalidation + if (_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList existingNSRecords)) + { + if ((existingNSRecords.Count > 0) && (existingNSRecords[0].RDATA is DnsNSRecordData existingNS) && existingNS.IsParentSideTtlSet) + { + uint parentSideTtl = existingNS.ParentSideTtl; + + foreach (DnsResourceRecord record in records) + (record.RDATA as DnsNSRecordData).ParentSideTtl = parentSideTtl; + } + } } } @@ -201,7 +224,12 @@ namespace DnsServerCore.Dns.Zones List anyRecords = new List(_entries.Count * 2); foreach (KeyValuePair> entry in _entries) + { + if (entry.Key == DnsResourceRecordType.DS) + continue; + anyRecords.AddRange(ValidateRRSet(type, entry.Value, serveStale, true)); + } return anyRecords;