From 4049b8a486e19dfb190d409bc52e415b44a81a8f Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 30 Mar 2019 17:01:44 +0530 Subject: [PATCH] Zone: serve stale querying support implemented. Query response record shuffeling implemented to allow load balancing across all IP addresses. Added check in non authoritative zone type to remove CNAME entry if different type of entry is added in the zone to prevent issue with serve stale. --- DnsServerCore/Zone.cs | 103 +++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 32 deletions(-) diff --git a/DnsServerCore/Zone.cs b/DnsServerCore/Zone.cs index 35b15372..474f5a9f 100644 --- a/DnsServerCore/Zone.cs +++ b/DnsServerCore/Zone.cs @@ -43,6 +43,7 @@ namespace DnsServerCore readonly ConcurrentDictionary _entries = new ConcurrentDictionary(); string _serverDomain; + uint _serveStaleTtl; #endregion @@ -96,7 +97,7 @@ namespace DnsServerCore private static string[] ConvertDomainToPath(string domainName) { - DnsDatagram.IsDomainNameValid(domainName, true); + DnsClient.IsDomainNameValid(domainName, true); if (string.IsNullOrEmpty(domainName)) return new string[] { }; @@ -257,7 +258,7 @@ namespace DnsServerCore } } - private DnsResourceRecord[] QueryRecords(DnsResourceRecordType type, bool bypassCNAME) + private DnsResourceRecord[] QueryRecords(DnsResourceRecordType type, bool bypassCNAME, bool serveStale) { if (_authoritativeZone && (type == DnsResourceRecordType.ANY)) { @@ -266,18 +267,25 @@ namespace DnsServerCore foreach (KeyValuePair entry in _entries) allRecords.AddRange(entry.Value); - return FilterExpiredDisabledRecords(allRecords.ToArray()); + return FilterExpiredDisabledRecords(allRecords.ToArray(), serveStale); } if (!bypassCNAME && _entries.TryGetValue(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingCNAMERecords)) { - DnsResourceRecord[] records = FilterExpiredDisabledRecords(existingCNAMERecords); + DnsResourceRecord[] records = FilterExpiredDisabledRecords(existingCNAMERecords, serveStale); if (records != null) return records; } if (_entries.TryGetValue(type, out DnsResourceRecord[] existingRecords)) - return FilterExpiredDisabledRecords(existingRecords); + { + DnsResourceRecord[] records = FilterExpiredDisabledRecords(existingRecords, serveStale); + + if (records != null) + DnsClient.ShuffleArray(records); //shuffle records to allow load balancing + + return records; + } return null; } @@ -309,7 +317,7 @@ namespace DnsServerCore private void ListAuthoritativeZones(List zones) { - DnsResourceRecord[] soa = QueryRecords(DnsResourceRecordType.SOA, true); + DnsResourceRecord[] soa = QueryRecords(DnsResourceRecordType.SOA, true, false); if (soa != null) zones.Add(this); @@ -323,6 +331,25 @@ namespace DnsServerCore { return records; }); + + if (!_authoritativeZone) + { + //this is only applicable for cache zone + switch (type) + { + case DnsResourceRecordType.CNAME: + case DnsResourceRecordType.SOA: + case DnsResourceRecordType.NS: + //do nothing + break; + + default: + //remove old CNAME entry since current new entry type overlaps any existing CNAME entry in cache + //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned + _entries.TryRemove(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingValues); + break; + } + } } private void AddRecord(DnsResourceRecord record) @@ -400,13 +427,16 @@ namespace DnsServerCore DeleteEmptyParentZones(this); } - private DnsResourceRecord[] FilterExpiredDisabledRecords(DnsResourceRecord[] records) + private DnsResourceRecord[] FilterExpiredDisabledRecords(DnsResourceRecord[] records, bool serveStale) { if (records.Length == 1) { - if (records[0].TTLValue < 1) + if (!serveStale && records[0].IsStale) return null; + if (records[0].TTLValue < 1u) + return null; //ttl expired + DnsResourceRecordInfo rrInfo = records[0].Tag as DnsResourceRecordInfo; if ((rrInfo != null) && rrInfo.Disabled) return null; @@ -418,9 +448,12 @@ namespace DnsServerCore foreach (DnsResourceRecord record in records) { - if (record.TTLValue < 1) + if (!serveStale && record.IsStale) continue; + if (record.TTLValue < 1u) + continue; //ttl expired + DnsResourceRecordInfo rrInfo = record.Tag as DnsResourceRecordInfo; if ((rrInfo != null) && rrInfo.Disabled) continue; @@ -461,14 +494,14 @@ namespace DnsServerCore return currentZone; } - private DnsResourceRecord[] QueryClosestCachedNameServers() + private DnsResourceRecord[] QueryClosestCachedNameServers(bool serveStale) { Zone currentZone = this; DnsResourceRecord[] nsRecords = null; while (currentZone != null) { - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true, serveStale); if ((nsRecords != null) && (nsRecords.Length > 0) && (nsRecords[0].RDATA is DnsNSRecord)) return nsRecords; @@ -485,11 +518,11 @@ namespace DnsServerCore while (currentZone != null) { - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.SOA, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.SOA, true, false); if ((nsRecords != null) && (nsRecords.Length > 0) && (nsRecords[0].RDATA as DnsSOARecord).MasterNameServer.Equals(rootZoneServerDomain, StringComparison.CurrentCultureIgnoreCase)) return nsRecords; - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true, false); if ((nsRecords != null) && (nsRecords.Length > 0)) return nsRecords; @@ -508,7 +541,7 @@ namespace DnsServerCore { if (currentZone._entries.ContainsKey(DnsResourceRecordType.SOA)) { - nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS, true, false); if ((nsRecords != null) && (nsRecords.Length > 0)) return nsRecords; @@ -521,7 +554,7 @@ namespace DnsServerCore return null; } - private static DnsResourceRecord[] QueryGlueRecords(Zone rootZone, DnsResourceRecord[] nsRecords) + private static DnsResourceRecord[] QueryGlueRecords(Zone rootZone, DnsResourceRecord[] nsRecords, bool serveStale) { List glueRecords = new List(); @@ -535,13 +568,13 @@ namespace DnsServerCore if ((zone != null) && !zone._disabled) { { - DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.A, true); + DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.A, true, serveStale); if ((records != null) && (records.Length > 0)) glueRecords.AddRange(records); } { - DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.AAAA, true); + DnsResourceRecord[] records = zone.QueryRecords(DnsResourceRecordType.AAAA, true, serveStale); if ((records != null) && (records.Length > 0)) glueRecords.AddRange(records); } @@ -573,7 +606,7 @@ namespace DnsServerCore if (DomainEquals(closestZone._zoneName, domain)) { //zone found - DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false); + DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false, false); if (records == null) { //record type not found @@ -608,7 +641,7 @@ namespace DnsServerCore } else { - additional = QueryGlueRecords(rootZone, closestAuthoritativeNameServers); + additional = QueryGlueRecords(rootZone, closestAuthoritativeNameServers, false); } return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, true, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, (ushort)records.Length, (ushort)closestAuthoritativeNameServers.Length, (ushort)additional.Length), request.Question, records, closestAuthoritativeNameServers, additional); @@ -623,13 +656,13 @@ namespace DnsServerCore else { //zone is delegated - DnsResourceRecord[] additional = QueryGlueRecords(rootZone, closestAuthority); + DnsResourceRecord[] additional = QueryGlueRecords(rootZone, closestAuthority, false); return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, (ushort)closestAuthority.Length, (ushort)additional.Length), request.Question, new DnsResourceRecord[] { }, closestAuthority, additional); } } - private static DnsDatagram QueryCache(Zone rootZone, DnsDatagram request) + private static DnsDatagram QueryCache(Zone rootZone, DnsDatagram request, bool serveStale) { DnsQuestionRecord question = request.Question[0]; string domain = question.Name.ToLower(); @@ -638,7 +671,7 @@ namespace DnsServerCore if (closestZone._zoneName.Equals(domain)) { - DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false); + DnsResourceRecord[] records = closestZone.QueryRecords(question.Type, false, serveStale); if (records != null) { if (records[0].RDATA is DnsEmptyRecord) @@ -667,10 +700,10 @@ namespace DnsServerCore } } - DnsResourceRecord[] nameServers = closestZone.QueryClosestCachedNameServers(); + DnsResourceRecord[] nameServers = closestZone.QueryClosestCachedNameServers(serveStale); if (nameServers != null) { - DnsResourceRecord[] additional = QueryGlueRecords(rootZone, nameServers); + DnsResourceRecord[] additional = QueryGlueRecords(rootZone, nameServers, serveStale); return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.NoError, 1, 0, (ushort)nameServers.Length, (ushort)additional.Length), request.Question, new DnsResourceRecord[] { }, nameServers, additional); } @@ -718,12 +751,12 @@ namespace DnsServerCore return groupedByDomainRecords; } - internal DnsDatagram Query(DnsDatagram request) + internal DnsDatagram Query(DnsDatagram request, bool serveStale = false) { if (_authoritativeZone) return QueryAuthoritative(this, request); - return QueryCache(this, request); + return QueryCache(this, request, serveStale); } internal void CacheResponse(DnsDatagram response) @@ -750,7 +783,7 @@ namespace DnsServerCore ttl = authority.TTLValue; DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, ttl, new DnsNXRecord(authority)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } @@ -777,7 +810,7 @@ namespace DnsServerCore ttl = authority.TTLValue; DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, ttl, new DnsEmptyRecord(authority)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } @@ -792,7 +825,7 @@ namespace DnsServerCore { //empty response from authority name server DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, DEFAULT_RECORD_TTL, new DnsEmptyRecord(null)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); break; @@ -807,7 +840,7 @@ namespace DnsServerCore foreach (DnsQuestionRecord question in response.Question) { DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, DEFAULT_RECORD_TTL, new DnsEmptyRecord(null)); - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } @@ -826,7 +859,7 @@ namespace DnsServerCore //set expiry for cached records foreach (DnsResourceRecord record in allRecords) - record.SetExpiry(); + record.SetExpiry(_serveStaleTtl); SetRecords(allRecords); @@ -842,7 +875,7 @@ namespace DnsServerCore } DnsResourceRecord anyRR = new DnsResourceRecord(response.Question[0].Name, DnsResourceRecordType.ANY, DnsClass.IN, ttl, new DnsANYRecord(response.Answer)); - anyRR.SetExpiry(); + anyRR.SetExpiry(_serveStaleTtl); CreateZone(this, response.Question[0].Name).SetRecords(DnsResourceRecordType.ANY, new DnsResourceRecord[] { anyRR }); } @@ -1034,6 +1067,12 @@ namespace DnsServerCore set { _serverDomain = value; } } + public uint ServeStaleTtl + { + get { return _serveStaleTtl; } + set { _serveStaleTtl = value; } + } + #endregion public class ZoneInfo : IComparable