diff --git a/DnsServerCore/Zone.cs b/DnsServerCore/Zone.cs index 35b15372..474f5a9f 100644 --- a/DnsServerCore/Zone.cs +++ b/DnsServerCore/Zone.cs @@ -43,6 +43,7 @@ namespace DnsServerCore readonly ConcurrentDictionary _entries = new ConcurrentDictionary(); string _serverDomain; + uint _serveStaleTtl; #endregion @@ -96,7 +97,7 @@ namespace DnsServerCore private static string[] ConvertDomainToPath(string domainName) { - DnsDatagram.IsDomainNameValid(domainName, true); + DnsClient.IsDomainNameValid(domainName, true); if (string.IsNullOrEmpty(domainName)) return new string[] { }; @@ -257,7 +258,7 @@ namespace DnsServerCore } } - private DnsResourceRecord[] QueryRecords(DnsResourceRecordType type, bool bypassCNAME) + private DnsResourceRecord[] QueryRecords(DnsResourceRecordType type, bool bypassCNAME, bool serveStale) { if (_authoritativeZone && (type == DnsResourceRecordType.ANY)) { @@ -266,18 +267,25 @@ namespace DnsServerCore foreach (KeyValuePair entry in _entries) allRecords.AddRange(entry.Value); - return FilterExpiredDisabledRecords(allRecords.ToArray()); + return FilterExpiredDisabledRecords(allRecords.ToArray(), serveStale); } if (!bypassCNAME && _entries.TryGetValue(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingCNAMERecords)) { - DnsResourceRecord[] records = FilterExpiredDisabledRecords(existingCNAMERecords); + DnsResourceRecord[] records = FilterExpiredDisabledRecords(existingCNAMERecords, serveStale); if (records != null) return records; } if (_entries.TryGetValue(type, out DnsResourceRecord[] existingRecords)) - return FilterExpiredDisabledRecords(existingRecords); + { + DnsResourceRecord[] records = FilterExpiredDisabledRecords(existingRecords, serveStale); + + if (records != null) + DnsClient.ShuffleArray(records); //shuffle records to allow load balancing + + return records; + } return null; } @@ -309,7 +317,7 @@ namespace DnsServerCore private void ListAuthoritativeZones(List zones) { - DnsResourceRecord[] soa = QueryRecords(DnsResourceRecordType.SOA, true); + DnsResourceRecord[] soa = QueryRecords(DnsResourceRecordType.SOA, true, false); if (soa != null) zones.Add(this); @@ -323,6 +331,25 @@ namespace DnsServerCore { return records; }); + + if (!_authoritativeZone) + { + //this is only applicable for cache zone + switch (type) + { + case DnsResourceRecordType.CNAME: + case DnsResourceRecordType.SOA: + case DnsResourceRecordType.NS: + //do nothing + break; + + default: + //remove old CNAME entry since current new entry type overlaps any existing CNAME entry in cache + //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned + _entries.TryRemove(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingValues); + break; + } + } } private void AddRecord(DnsResourceRecord record) @@ -400,13 +427,16 @@ namespace DnsServerCore DeleteEmptyParentZones(this); } - private DnsResourceRecord[] FilterExpiredDisabledRecords(DnsResourceRecord[] records) + private DnsResourceRecord[] FilterExpiredDisabledRecords(DnsResourceRecord[] records, bool serveStale) { if (records.Length == 1) { - if (records[0].TTLValue < 1) + if (!serveStale && records[0].IsStale) return null; + if (records[0].TTLValue < 1u) + return null; //ttl expired + DnsResourceRecordInfo rrInfo = records[0].Tag as DnsResourceRecordInfo; if ((rrInfo != null) && rrInfo.Disabled) return null; @@ -418,9 +448,12 @@ namespace DnsServerCore foreach (DnsResourceRecord record in records) { - if (record.TTLValue < 1) + if (!serveStale && record.IsStale) continue; + if (record.TTLValue < 1u) + continue; //ttl expired + DnsResourceRecordInfo rrInfo = record.Tag as DnsResourceRecordInfo; if ((rrInfo != null) && rrInfo.Disabled) continue; @@ -461,14 +494,14 @@ namespace DnsServerCore return currentZone; } - private DnsResourceRecord[] QueryClosestCachedNameServers() + private DnsResourceRecord[] QueryClosestCachedNameServers(bool serveStale) { Zone currentZone = this; DnsResourceRecord[] nsRecords = null; while (currentZone != null) { - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true, serveStale); if ((nsRecords != null) && (nsRecords.Length > 0) && (nsRecords[0].RDATA is DnsNSRecord)) return nsRecords; @@ -485,11 +518,11 @@ namespace DnsServerCore while (currentZone != null) { - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.SOA, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.SOA, true, false); if ((nsRecords != null) && (nsRecords.Length > 0) && (nsRecords[0].RDATA as DnsSOARecord).MasterNameServer.Equals(rootZoneServerDomain, StringComparison.CurrentCultureIgnoreCase)) return nsRecords; - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true, false); if ((nsRecords != null) && (nsRecords.Length > 0)) return nsRecords; @@ -508,7 +541,7 @@ namespace DnsServerCore { if (currentZone._entries.ContainsKey(DnsResourceRecordType.SOA)) { - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true, false); if ((nsRecords != null) && (nsRecords.Length > 0)) return nsRecords; @@ -521,7 +554,7 @@ namespace DnsServerCore return null; } - private static DnsResourceRecord[] QueryGlueRecords(Zone rootZone, DnsResourceRecord[] nsRecords) + private static DnsResourceRecord[] QueryGlueRecords(Zone rootZone, DnsResourceRecord[] nsRecords, bool serveStale) { List glueRecords = new List(); @@ -535,13 +568,13 @@ namespace DnsServerCore if ((zone != null) && !zone._disabled) { { - DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.A, true); + DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.A, true, serveStale); if ((records != null) && (records.Length > 0)) glueRecords.AddRange(records); } { - DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.AAAA, true); + DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.AAAA, true, serveStale); if ((records != null) && (records.Length > 0)) glueRecords.AddRange(records); } @@ -573,7 +606,7 @@ namespace DnsServerCore if (DomainEquals(closestZone._zoneName, domain)) { //zone found - DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false); + DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false, false); if (records == null) { //record type not found @@ -608,7 +641,7 @@ namespace DnsServerCore } else { - additional = QueryGlueRecords(rootZone, closestAuthoritativeNameServers); + additional = QueryGlueRecords(rootZone, closestAuthoritativeNameServers, false); } return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, true, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, (ushort)records.Length, (ushort)closestAuthoritativeNameServers.Length, (ushort)additional.Length), request.Question, records, closestAuthoritativeNameServers, additional); @@ -623,13 +656,13 @@ namespace DnsServerCore else { //zone is delegated - DnsResourceRecord[] additional = QueryGlueRecords(rootZone, closestAuthority); + DnsResourceRecord[] additional = QueryGlueRecords(rootZone, closestAuthority, false); return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, (ushort)closestAuthority.Length, (ushort)additional.Length), request.Question, new DnsResourceRecord[] { }, closestAuthority, additional); } } - private static DnsDatagram QueryCache(Zone rootZone, DnsDatagram request) + private static DnsDatagram QueryCache(Zone rootZone, DnsDatagram request, bool serveStale) { DnsQuestionRecord question = request.Question[0]; string domain = question.Name.ToLower(); @@ -638,7 +671,7 @@ namespace DnsServerCore if (closestZone._zoneName.Equals(domain)) { - DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false); + DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false, serveStale); if (records != null) { if (records[0].RDATA is DnsEmptyRecord) @@ -667,10 +700,10 @@ namespace DnsServerCore } } - DnsResourceRecord[] nameServers = closestZone.QueryClosestCachedNameServers(); + DnsResourceRecord[] nameServers = closestZone.QueryClosestCachedNameServers(serveStale); if (nameServers != null) { - DnsResourceRecord[] additional = QueryGlueRecords(rootZone, nameServers); + DnsResourceRecord[] additional = QueryGlueRecords(rootZone, nameServers, serveStale); return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.NoError, 1, 0, (ushort)nameServers.Length, (ushort)additional.Length), request.Question, new DnsResourceRecord[] { }, nameServers, additional); } @@ -718,12 +751,12 @@ namespace DnsServerCore return groupedByDomainRecords; } - internal DnsDatagram Query(DnsDatagram request) + internal DnsDatagram Query(DnsDatagram request, bool serveStale = false) { if (_authoritativeZone) return QueryAuthoritative(this, request); - return QueryCache(this, request); + return QueryCache(this, request, serveStale); } internal void CacheResponse(DnsDatagram response) @@ -750,7 +783,7 @@ namespace DnsServerCore ttl = authority.TTLValue; DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, ttl, new DnsNXRecord(authority)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } @@ -777,7 +810,7 @@ namespace DnsServerCore ttl = authority.TTLValue; DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, ttl, new DnsEmptyRecord(authority)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } @@ -792,7 +825,7 @@ namespace DnsServerCore { //empty response from authority name server DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, DEFAULT_RECORD_TTL, new DnsEmptyRecord(null)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); break; @@ -807,7 +840,7 @@ namespace DnsServerCore foreach (DnsQuestionRecord question in response.Question) { DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, DEFAULT_RECORD_TTL, new DnsEmptyRecord(null)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } @@ -826,7 +859,7 @@ namespace DnsServerCore //set expiry for cached records foreach (DnsResourceRecord record in allRecords) - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); SetRecords(allRecords); @@ -842,7 +875,7 @@ namespace DnsServerCore } DnsResourceRecord anyRR = new DnsResourceRecord(response.Question[0].Name, DnsResourceRecordType.ANY, DnsClass.IN, ttl, new DnsANYRecord(response.Answer)); - anyRR.SetExpiry(); + anyRR.SetExpiry(_serveStaleTtl); CreateZone(this, response.Question[0].Name).SetRecords(DnsResourceRecordType.ANY, new DnsResourceRecord[] { anyRR }); } @@ -1034,6 +1067,12 @@ namespace DnsServerCore set { _serverDomain = value; } } + public uint ServeStaleTtl + { + get { return _serveStaleTtl; } + set { _serveStaleTtl = value; } + } + #endregion public class ZoneInfo : IComparable