From 1e79190c414cb709d3f8a55e9dcb9c0012accfcd Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 4 Nov 2017 14:33:15 +0530 Subject: [PATCH] Dns Zone: added disable zone feature. Bug fixes and code refactoring done. --- DnsServerCore/Zone.cs | 359 +++++++++++++++++++++++++++--------------- 1 file changed, 231 insertions(+), 128 deletions(-) diff --git a/DnsServerCore/Zone.cs b/DnsServerCore/Zone.cs index 5205517f..bc549386 100644 --- a/DnsServerCore/Zone.cs +++ b/DnsServerCore/Zone.cs @@ -37,6 +37,8 @@ namespace DnsServerCore readonly string _zoneLabel; readonly string _zoneName; + bool _disabled; + readonly ConcurrentDictionary _zones = new ConcurrentDictionary(); readonly ConcurrentDictionary _entries = new ConcurrentDictionary(); @@ -134,18 +136,20 @@ namespace DnsServerCore currentZone = nextZone; else return currentZone; + + if (currentZone._disabled) + return currentZone; } return currentZone; } - private static Zone DeleteZone(Zone rootZone, string domain) + private static Zone GetZone(Zone rootZone, string domain) { Zone currentZone = rootZone; string[] path = ConvertDomainToPath(domain); - //find parent zone - for (int i = 0; i < path.Length - 1; i++) + for (int i = 0; i < path.Length; i++) { string nextZoneName = path[i]; @@ -155,24 +159,130 @@ namespace DnsServerCore return null; } - if (currentZone._zones.TryRemove(path[path.Length - 1], out Zone deletedZone)) - return deletedZone; + return currentZone; + } + + private static Zone[] DeleteZone(Zone rootZone, string domain) + { + Zone currentZone = GetZone(rootZone, domain); + if (currentZone == null) + return null; + + if (!currentZone._authoritativeZone && (currentZone._zoneName.Equals("root-servers.net", StringComparison.CurrentCultureIgnoreCase))) + return null; //cannot delete root-servers.net + + currentZone._entries.Clear(); + + List deletedSubDomains = new List(); + + DeleteSubDomains(currentZone, deletedSubDomains); + + DeleteEmptyParentZones(currentZone); + + return deletedSubDomains.ToArray(); + } + + private static bool DeleteSubDomains(Zone currentZone, List deletedSubDomains) + { + if (currentZone._authoritativeZone) + { + if (currentZone._entries.ContainsKey(DnsResourceRecordType.SOA)) + return false; //this is a zone so return false + } + else + { + //cache zone + if (currentZone._zoneName.Equals("root-servers.net", StringComparison.CurrentCultureIgnoreCase)) + return false; //cannot delete root-servers.net + } + + currentZone._entries.Clear(); + deletedSubDomains.Add(currentZone); + + List subDomainsToDelete = new List(); + + foreach (KeyValuePair zone in currentZone._zones) + { + if (DeleteSubDomains(zone.Value, deletedSubDomains)) + subDomainsToDelete.Add(zone.Value); + } + + foreach (Zone subDomain in subDomainsToDelete) + currentZone._zones.TryRemove(subDomain._zoneLabel, out Zone deletedValue); + + return (currentZone._zones.Count == 0); + } + + private static void DeleteEmptyParentZones(Zone currentZone) + { + while (true) + { + if ((currentZone._entries.Count > 0) || (currentZone._zones.Count > 0)) + break; + + currentZone._parentZone._zones.TryRemove(currentZone._zoneLabel, out Zone deletedZone); + + currentZone = currentZone._parentZone; + } + } + + private DnsResourceRecord[] QueryRecords(DnsResourceRecordType type, bool bypassCNAME = false) + { + if (!bypassCNAME && _entries.TryGetValue(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingCNAMERecords)) + { + if (_authoritativeZone) + return existingCNAMERecords; + + return FilterExpiredRecords(existingCNAMERecords); + } + + if (_entries.TryGetValue(type, out DnsResourceRecord[] existingRecords)) + { + if (_authoritativeZone) + return existingRecords; + + return FilterExpiredRecords(existingRecords); + } return null; } - private static DnsResourceRecord[] GetRecords(Zone rootZone, string domain, DnsResourceRecordType type) + private DnsResourceRecord[] GetAllRecords(bool includeSubDomains) { - Zone closestZone = FindClosestZone(rootZone, domain); + List allRecords = new List(); - if (closestZone._zoneName.Equals(domain, StringComparison.CurrentCultureIgnoreCase)) - return closestZone.GetRecords(type); + foreach (KeyValuePair entry in _entries) + { + if (entry.Key != DnsResourceRecordType.ANY) + allRecords.AddRange(entry.Value); + } - return null; + if (includeSubDomains) + { + foreach (KeyValuePair zone in _zones) + { + if (!zone.Value._entries.ContainsKey(DnsResourceRecordType.SOA)) + allRecords.AddRange(zone.Value.GetAllRecords(true)); + } + } + + return allRecords.ToArray(); } private void SetRecords(DnsResourceRecordType type, DnsResourceRecord[] records) { + if (type == DnsResourceRecordType.CNAME) + { + //delete all sub zones and entries except SOA + _zones.Clear(); + + foreach (DnsResourceRecordType key in _entries.Keys) + { + if (key != DnsResourceRecordType.SOA) + _entries.TryRemove(key, out DnsResourceRecord[] removedValues); + } + } + _entries.AddOrUpdate(type, records, delegate (DnsResourceRecordType key, DnsResourceRecord[] existingRecords) { return records; @@ -243,57 +353,7 @@ namespace DnsServerCore { _entries.TryRemove(type, out DnsResourceRecord[] existingValues); - DeleteEmptyZones(this); - } - - private static void DeleteEmptyZones(Zone currentZone) - { - while (true) - { - if ((currentZone._entries.Count > 0) || (currentZone._zones.Count > 0)) - break; - - currentZone._parentZone._zones.TryRemove(currentZone._zoneLabel, out Zone deletedZone); - - currentZone = currentZone._parentZone; - } - } - - private DnsResourceRecord[] GetRecords(DnsResourceRecordType type, bool includeSubDomainsForANY = false) - { - if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingCNAMERecords)) - { - if (_authoritativeZone) - return existingCNAMERecords; - - return FilterExpiredRecords(existingCNAMERecords); - } - - if ((type == DnsResourceRecordType.ANY) && _authoritativeZone) - { - List allRecords = new List(); - - foreach (KeyValuePair entry in _entries) - allRecords.AddRange(entry.Value); - - if (includeSubDomainsForANY) - { - foreach (KeyValuePair zone in _zones) - allRecords.AddRange(zone.Value.GetRecords(DnsResourceRecordType.ANY, true)); - } - - return allRecords.ToArray(); - } - - if (_entries.TryGetValue(type, out DnsResourceRecord[] existingRecords)) - { - if (_authoritativeZone) - return existingRecords; - - return FilterExpiredRecords(existingRecords); - } - - return null; + DeleteEmptyParentZones(this); } private DnsResourceRecord[] FilterExpiredRecords(DnsResourceRecord[] records) @@ -327,7 +387,7 @@ namespace DnsServerCore while (currentZone != null) { - nsRecords = currentZone.GetRecords(DnsResourceRecordType.NS); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.NS); if ((nsRecords != null) && (nsRecords.Length > 0) && (nsRecords[0].Type == DnsResourceRecordType.NS)) return nsRecords; @@ -344,7 +404,7 @@ namespace DnsServerCore while (currentZone != null) { - nsRecords = currentZone.GetRecords(DnsResourceRecordType.SOA); + nsRecords = currentZone.QueryRecords(DnsResourceRecordType.SOA); if ((nsRecords != null) && (nsRecords.Length > 0) && (nsRecords[0].Type == DnsResourceRecordType.SOA)) return nsRecords; @@ -354,16 +414,27 @@ namespace DnsServerCore return null; } + private static DnsResourceRecord[] GetGlueRecords(Zone rootZone, string domain, DnsResourceRecordType type) + { + Zone currentZone = GetZone(rootZone, domain); + if (currentZone != null) + { + DnsResourceRecord[] records = currentZone.QueryRecords(type); + if ((records != null) && (records.Length > 0) && (records[0].Type == type)) + return records; + } + + return null; + } + private void GetAuthoritativeZones(List zones) { - DnsResourceRecord[] soa = GetRecords(DnsResourceRecordType.SOA); + DnsResourceRecord[] soa = QueryRecords(DnsResourceRecordType.SOA, true); if ((soa != null) && (soa[0].Type == DnsResourceRecordType.SOA)) zones.Add(this); foreach (KeyValuePair entry in _zones) - { entry.Value.GetAuthoritativeZones(zones); - } } private static DnsDatagram QueryAuthoritative(Zone rootZone, DnsDatagram request) @@ -373,10 +444,13 @@ namespace DnsServerCore Zone closestZone = FindClosestZone(rootZone, domain); + if (closestZone._disabled) + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.Refused, 1, 0, 0, 0), request.Question, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }); + if (closestZone._zoneName.Equals(domain)) { //zone found - DnsResourceRecord[] records = closestZone.GetRecords(question.Type); + DnsResourceRecord[] records = closestZone.QueryRecords(question.Type); if (records == null) { //record type not found @@ -414,7 +488,7 @@ namespace DnsServerCore if (closestZone._zoneName.Equals(domain)) { - DnsResourceRecord[] records = closestZone.GetRecords(question.Type); + DnsResourceRecord[] records = closestZone.QueryRecords(question.Type); if (records != null) { if (records[0].RDATA is DnsEmptyRecord) @@ -442,11 +516,11 @@ namespace DnsServerCore { string nsDomain = (nameServer.RDATA as DnsNSRecord).NSDomainName; - DnsResourceRecord[] glueAs = GetRecords(rootZone, nsDomain, DnsResourceRecordType.A); + DnsResourceRecord[] glueAs = GetGlueRecords(rootZone, nsDomain, DnsResourceRecordType.A); if (glueAs != null) glueRecords.AddRange(glueAs); - DnsResourceRecord[] glueAAAAs = GetRecords(rootZone, nsDomain, DnsResourceRecordType.AAAA); + DnsResourceRecord[] glueAAAAs = GetGlueRecords(rootZone, nsDomain, DnsResourceRecordType.AAAA); if (glueAAAAs != null) glueRecords.AddRange(glueAAAAs); } @@ -627,8 +701,8 @@ namespace DnsServerCore if (oldRecord.Type == DnsResourceRecordType.SOA) throw new DnsServerException("Cannot update record: use SetRecords() for updating SOA record."); - Zone zone = FindClosestZone(this, oldRecord.Name); - if (!zone._zoneName.Equals(oldRecord.Name, StringComparison.CurrentCultureIgnoreCase)) + Zone currentZone = GetZone(this, oldRecord.Name); + if (currentZone == null) throw new DnsServerException("Cannot update record: old record does not exists."); switch (oldRecord.Type) @@ -637,20 +711,20 @@ namespace DnsServerCore case DnsResourceRecordType.PTR: if (oldRecord.Name.Equals(newRecord.Name, StringComparison.CurrentCultureIgnoreCase)) { - zone.SetRecords(newRecord.Type, new DnsResourceRecord[] { newRecord }); + currentZone.SetRecords(newRecord.Type, new DnsResourceRecord[] { newRecord }); } else { - zone.DeleteRecords(oldRecord.Type); + currentZone.DeleteRecords(oldRecord.Type); CreateZone(this, newRecord.Name).SetRecords(newRecord.Type, new DnsResourceRecord[] { newRecord }); } break; default: - zone.DeleteRecord(oldRecord); + currentZone.DeleteRecord(oldRecord); if (oldRecord.Name.Equals(newRecord.Name, StringComparison.CurrentCultureIgnoreCase)) - zone.AddRecord(newRecord); + currentZone.AddRecord(newRecord); else CreateZone(this, newRecord.Name).AddRecord(newRecord); @@ -660,35 +734,25 @@ namespace DnsServerCore public void DeleteRecord(string domain, DnsResourceRecordType type, DnsResourceRecordData record) { - Zone zone = FindClosestZone(this, domain); - if (zone._zoneName.Equals(domain, StringComparison.CurrentCultureIgnoreCase)) - zone.DeleteRecord(new DnsResourceRecord(domain, type, DnsClass.IN, 0, record)); + Zone currentZone = GetZone(this, domain); + if (currentZone != null) + currentZone.DeleteRecord(new DnsResourceRecord(domain, type, DnsClass.IN, 0, record)); } public void DeleteRecords(string domain, DnsResourceRecordType type) { - Zone zone = FindClosestZone(this, domain); - if (zone._zoneName.Equals(domain, StringComparison.CurrentCultureIgnoreCase)) - zone.DeleteRecords(type); + Zone currentZone = GetZone(this, domain); + if (currentZone != null) + currentZone.DeleteRecords(type); } - public DnsResourceRecord[] GetRecords(string domain = "", bool includeSubDomains = true) + public DnsResourceRecord[] GetAllRecords(string domain = "", bool includeSubDomains = true) { - Zone currentZone = this; + Zone currentZone = GetZone(this, domain); + if (currentZone == null) + return null; - string[] path = ConvertDomainToPath(domain); - - for (int i = 0; i < path.Length; i++) - { - string nextZoneName = path[i]; - - if (currentZone._zones.TryGetValue(nextZoneName, out Zone nextZone)) - currentZone = nextZone; - else - return null; //no zone for given domain - } - - DnsResourceRecord[] records = currentZone.GetRecords(DnsResourceRecordType.ANY, includeSubDomains); + DnsResourceRecord[] records = currentZone.GetAllRecords(includeSubDomains); if (records != null) return records; @@ -697,19 +761,9 @@ namespace DnsServerCore public string[] ListSubZones(string domain = "") { - Zone currentZone = this; - - string[] path = ConvertDomainToPath(domain); - - for (int i = 0; i < path.Length; i++) - { - string nextZoneName = path[i]; - - if (currentZone._zones.TryGetValue(nextZoneName, out Zone nextZone)) - currentZone = nextZone; - else - return new string[] { }; //no zone for given domain - } + Zone currentZone = GetZone(this, domain); + if (currentZone == null) + return new string[] { }; //no zone for given domain string[] subZoneNames = new string[currentZone._zones.Keys.Count]; currentZone._zones.Keys.CopyTo(subZoneNames, 0); @@ -717,36 +771,49 @@ namespace DnsServerCore return subZoneNames; } - public string[] ListAuthoritativeZones(string domain = "") + public ZoneInfo[] ListAuthoritativeZones(string domain = "") { - Zone currentZone = this; - - string[] path = ConvertDomainToPath(domain); - - for (int i = 0; i < path.Length; i++) - { - string nextZoneName = path[i]; - - if (currentZone._zones.TryGetValue(nextZoneName, out Zone nextZone)) - currentZone = nextZone; - else - return new string[] { }; //no zone for given domain - } + Zone currentZone = GetZone(this, domain); + if (currentZone == null) + return new ZoneInfo[] { }; //no zone for given domain List zones = new List(); currentZone.GetAuthoritativeZones(zones); - List zoneNames = new List(); + List zoneNames = new List(); foreach (Zone zone in zones) - zoneNames.Add(zone._zoneName); + zoneNames.Add(new ZoneInfo(zone)); return zoneNames.ToArray(); } - public bool DeleteZone(string domain) + public string[] DeleteZone(string domain) { - return (DeleteZone(this, domain) != null); + Zone[] deletedZones = DeleteZone(this, domain); + if (deletedZones == null) + return new string[] { }; + + List deletedZoneNames = new List(); + + foreach (Zone deletedZone in deletedZones) + deletedZoneNames.Add(deletedZone._zoneName); + + return deletedZoneNames.ToArray(); + } + + public void DisableZone(string domain) + { + Zone currentZone = GetZone(this, domain); + if (currentZone != null) + currentZone._disabled = true; + } + + public void EnableZone(string domain) + { + Zone currentZone = GetZone(this, domain); + if (currentZone != null) + currentZone._disabled = false; } public void Flush() @@ -757,6 +824,42 @@ namespace DnsServerCore #endregion + public class ZoneInfo + { + #region variables + + readonly string _zoneName; + readonly bool _disabled; + + #endregion + + #region constructor + + public ZoneInfo(string zoneName, bool disabled) + { + _zoneName = zoneName; + _disabled = disabled; + } + + public ZoneInfo(Zone zone) + { + _zoneName = zone._zoneName; + _disabled = zone._disabled; + } + + #endregion + + #region properties + + public string ZoneName + { get { return _zoneName; } } + + public bool Disabled + { get { return _disabled; } } + + #endregion + } + class DnsNXRecord : DnsResourceRecordData { #region variables