diff --git a/DnsServerCore/Zone.cs b/DnsServerCore/Zone.cs index 2b8118f7..110aa3a8 100644 --- a/DnsServerCore/Zone.cs +++ b/DnsServerCore/Zone.cs @@ -34,6 +34,7 @@ namespace DnsServerCore readonly bool _authoritativeZone; readonly Zone _parentZone; + readonly string _zoneLabel; readonly string _zoneName; readonly ConcurrentDictionary _zones = new ConcurrentDictionary(); @@ -56,6 +57,7 @@ namespace DnsServerCore { _authoritativeZone = parentZone._authoritativeZone; _parentZone = parentZone; + _zoneLabel = zoneLabel; string zoneName = zoneLabel; @@ -183,8 +185,8 @@ namespace DnsServerCore { foreach (DnsResourceRecord existingRecord in existingRecords) { - if (record.Equals(existingRecord)) - return existingRecords; + if (record.RDATA.Equals(existingRecord.RDATA)) + throw new DnsServerException("Resource record already exists."); } DnsResourceRecord[] newValue = new DnsResourceRecord[existingRecords.Length + 1]; @@ -204,7 +206,7 @@ namespace DnsServerCore for (int i = 0; i < existingRecords.Length; i++) { - if (record.Equals(existingRecords[i])) + if (record.RDATA.Equals(existingRecords[i].RDATA)) { existingRecords[i] = null; recordFound = true; @@ -213,32 +215,59 @@ namespace DnsServerCore } if (!recordFound) - return; + throw new DnsServerException("Resource record does not exists."); - DnsResourceRecord[] newRecords = new DnsResourceRecord[existingRecords.Length - 1]; - - for (int i = 0, j = 0; i < existingRecords.Length; i++) + if (existingRecords.Length == 1) { - if (existingRecords[i] != null) - newRecords[j++] = existingRecords[i]; + DeleteRecords(record.Type); } - - _entries.AddOrUpdate(record.Type, newRecords, delegate (DnsResourceRecordType key, DnsResourceRecord[] oldValue) + else { - return newRecords; - }); + DnsResourceRecord[] newRecords = new DnsResourceRecord[existingRecords.Length - 1]; + + for (int i = 0, j = 0; i < existingRecords.Length; i++) + { + if (existingRecords[i] != null) + newRecords[j++] = existingRecords[i]; + } + + _entries.AddOrUpdate(record.Type, newRecords, delegate (DnsResourceRecordType key, DnsResourceRecord[] oldValue) + { + return newRecords; + }); + } } } private void DeleteRecords(DnsResourceRecordType type) { _entries.TryRemove(type, out DnsResourceRecord[] existingValues); + + DeleteEmptyZones(this); } - private DnsResourceRecord[] GetRecords(DnsResourceRecordType type) + private static void DeleteEmptyZones(Zone currentZone) + { + while (true) + { + if ((currentZone._entries.Count > 0) || (currentZone._zones.Count > 0)) + break; + + currentZone._parentZone._zones.TryRemove(currentZone._zoneLabel, out Zone deletedZone); + + currentZone = currentZone._parentZone; + } + } + + private DnsResourceRecord[] GetRecords(DnsResourceRecordType type, bool includeSubDomainsForANY = false) { if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out DnsResourceRecord[] existingCNAMERecords)) - return existingCNAMERecords; + { + if (_authoritativeZone) + return existingCNAMERecords; + + return FilterExpiredRecords(existingCNAMERecords); + } if ((type == DnsResourceRecordType.ANY) && _authoritativeZone) { @@ -247,11 +276,46 @@ namespace DnsServerCore foreach (KeyValuePair entry in _entries) allRecords.AddRange(entry.Value); + if (includeSubDomainsForANY) + { + foreach (KeyValuePair zone in _zones) + allRecords.AddRange(zone.Value.GetRecords(DnsResourceRecordType.ANY, true)); + } + return allRecords.ToArray(); } if (_entries.TryGetValue(type, out DnsResourceRecord[] existingRecords)) - return existingRecords; + { + if (_authoritativeZone) + return existingRecords; + + return FilterExpiredRecords(existingRecords); + } + + return null; + } + + private DnsResourceRecord[] FilterExpiredRecords(DnsResourceRecord[] records) + { + if (records.Length == 1) + { + if (records[0].TTLValue < 1) + return null; + + return records; + } + + List newRecords = new List(records.Length); + + foreach (DnsResourceRecord record in records) + { + if (record.TTLValue > 0) + newRecords.Add(record); + } + + if (newRecords.Count > 0) + return newRecords.ToArray(); return null; } @@ -292,7 +356,8 @@ namespace DnsServerCore private void GetAuthoritativeZones(List zones) { - if (GetRecords(DnsResourceRecordType.SOA) != null) + DnsResourceRecord[] soa = GetRecords(DnsResourceRecordType.SOA); + if ((soa != null) && (soa[0].Type == DnsResourceRecordType.SOA)) zones.Add(this); foreach (KeyValuePair entry in _zones) @@ -325,10 +390,6 @@ namespace DnsServerCore else { //record type found - - if ((records.Length > 0) && (records[0].Type == DnsResourceRecordType.CNAME)) - records = ResolveCNAME(rootZone, records[0], question.Type); - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, true, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, (ushort)records.Length, 0, 0), request.Question, records, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }); } } @@ -362,9 +423,6 @@ namespace DnsServerCore if (records[0].RDATA is DnsNXRecord) return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.NameError, 1, 0, 1, 0), request.Question, new DnsResourceRecord[] { }, new DnsResourceRecord[] { (records[0].RDATA as DnsNXRecord).Authority }, new DnsResourceRecord[] { }); - if ((records.Length > 0) && (records[0].Type == DnsResourceRecordType.CNAME)) - records = ResolveCNAME(rootZone, records[0], question.Type); - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.NoError, 1, (ushort)records.Length, 0, 0), request.Question, records, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }); } } @@ -395,32 +453,6 @@ namespace DnsServerCore return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.Refused, 1, 0, 0, 0), request.Question, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }); } - private static DnsResourceRecord[] ResolveCNAME(Zone rootZone, DnsResourceRecord cnameRR, DnsResourceRecordType type) - { - if ((type == DnsResourceRecordType.CNAME) || (type == DnsResourceRecordType.ANY)) - return new DnsResourceRecord[] { cnameRR }; - - List recordsList = new List(); - recordsList.Add(cnameRR); - - while (true) - { - DnsResourceRecord[] records = GetRecords(rootZone, (cnameRR.RDATA as DnsCNAMERecord).CNAMEDomainName, type); - - if ((records == null) || (records.Length == 0)) - break; - - recordsList.AddRange(records); - - if (records[0].Type != DnsResourceRecordType.CNAME) - break; - - cnameRR = records[0]; - } - - return recordsList.ToArray(); - } - #endregion #region internal @@ -572,14 +604,54 @@ namespace DnsServerCore public void AddRecord(string domain, DnsResourceRecordType type, uint ttl, DnsResourceRecordData record) { - CreateZone(this, domain).AddRecord(new DnsResourceRecord(domain, type, DnsClass.IN, ttl, record)); + DnsResourceRecord rr = new DnsResourceRecord(domain, type, DnsClass.IN, ttl, record); + CreateZone(this, domain).AddRecord(rr); } - public void DeleteRecord(string domain, DnsResourceRecordType type, uint ttl, DnsResourceRecordData record) + public void UpdateRecord(DnsResourceRecord oldRecord, DnsResourceRecord newRecord) + { + if (oldRecord.Type != newRecord.Type) + throw new DnsServerException("Cannot update record: new record must be of same type."); + + if (oldRecord.Type == DnsResourceRecordType.SOA) + throw new DnsServerException("Cannot update record: use SetRecords() for updating SOA record."); + + Zone zone = FindClosestZone(this, oldRecord.Name); + if (!zone._zoneName.Equals(oldRecord.Name, StringComparison.CurrentCultureIgnoreCase)) + throw new DnsServerException("Cannot update record: old record does not exists."); + + switch (oldRecord.Type) + { + case DnsResourceRecordType.CNAME: + case DnsResourceRecordType.PTR: + if (oldRecord.Name.Equals(newRecord.Name, StringComparison.CurrentCultureIgnoreCase)) + { + zone.SetRecords(newRecord.Type, new DnsResourceRecord[] { newRecord }); + } + else + { + zone.DeleteRecords(oldRecord.Type); + CreateZone(this, newRecord.Name).SetRecords(newRecord.Type, new DnsResourceRecord[] { newRecord }); + } + break; + + default: + zone.DeleteRecord(oldRecord); + + if (oldRecord.Name.Equals(newRecord.Name, StringComparison.CurrentCultureIgnoreCase)) + zone.AddRecord(newRecord); + else + CreateZone(this, newRecord.Name).AddRecord(newRecord); + + break; + } + } + + public void DeleteRecord(string domain, DnsResourceRecordType type, DnsResourceRecordData record) { Zone zone = FindClosestZone(this, domain); if (zone._zoneName.Equals(domain, StringComparison.CurrentCultureIgnoreCase)) - zone.DeleteRecord(new DnsResourceRecord(domain, type, DnsClass.IN, ttl, record)); + zone.DeleteRecord(new DnsResourceRecord(domain, type, DnsClass.IN, 0, record)); } public void DeleteRecords(string domain, DnsResourceRecordType type) @@ -589,7 +661,7 @@ namespace DnsServerCore zone.DeleteRecords(type); } - public DnsResourceRecord[] GetRecords(string domain = "") + public DnsResourceRecord[] GetRecords(string domain = "", bool includeSubDomains = true) { Zone currentZone = this; @@ -605,7 +677,7 @@ namespace DnsServerCore return null; //no zone for given domain } - DnsResourceRecord[] records = currentZone.GetRecords(DnsResourceRecordType.ANY); + DnsResourceRecord[] records = currentZone.GetRecords(DnsResourceRecordType.ANY, includeSubDomains); if (records != null) return records; @@ -666,6 +738,12 @@ namespace DnsServerCore return (DeleteZone(this, domain) != null); } + public void Flush() + { + _zones.Clear(); + _entries.Clear(); + } + #endregion class DnsNXRecord : DnsResourceRecordData