diff --git a/DnsServerCore/Dns/Zones/AuthZone.cs b/DnsServerCore/Dns/Zones/AuthZone.cs index ad55a9e9..ee35cef0 100644 --- a/DnsServerCore/Dns/Zones/AuthZone.cs +++ b/DnsServerCore/Dns/Zones/AuthZone.cs @@ -125,91 +125,194 @@ namespace DnsServerCore.Dns.Zones return newRecords; } - private static async Task> GetNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord record) + private static async Task> ResolveNameServerAddressesAsync(DnsServer dnsServer, string nsDomain) { - string nsDomain; - - switch (record.Type) - { - case DnsResourceRecordType.NS: - nsDomain = (record.RDATA as DnsNSRecord).NameServer; - break; - - case DnsResourceRecordType.SOA: - nsDomain = (record.RDATA as DnsSOARecord).PrimaryNameServer; - break; - - default: - throw new InvalidOperationException(); - } - List nameServers = new List(2); - IReadOnlyList glueRecords = record.GetGlueRecords(); - if (glueRecords.Count > 0) + try { - foreach (DnsResourceRecord glueRecord in glueRecords) + DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)).WithTimeout(2000); + if (response.Answer.Count > 0) { - switch (glueRecord.Type) - { - case DnsResourceRecordType.A: - nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address)); - break; - - case DnsResourceRecordType.AAAA: - if (dnsServer.PreferIPv6) - nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address)); - - break; - } + IReadOnlyList addresses = DnsClient.ParseResponseA(response); + foreach (IPAddress address in addresses) + nameServers.Add(new NameServerAddress(nsDomain, address)); } } - else + catch + { } + + if (dnsServer.PreferIPv6) { - //resolve addresses try { - DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)).WithTimeout(2000); + DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)).WithTimeout(2000); if (response.Answer.Count > 0) { - IReadOnlyList addresses = DnsClient.ParseResponseA(response); + IReadOnlyList addresses = DnsClient.ParseResponseAAAA(response); foreach (IPAddress address in addresses) nameServers.Add(new NameServerAddress(nsDomain, address)); } } catch { } - - if (dnsServer.PreferIPv6) - { - try - { - DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)).WithTimeout(2000); - if (response.Answer.Count > 0) - { - IReadOnlyList addresses = DnsClient.ParseResponseAAAA(response); - foreach (IPAddress address in addresses) - nameServers.Add(new NameServerAddress(nsDomain, address)); - } - } - catch - { } - } } return nameServers; } + private static Task> ResolveNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord nsRecord) + { + switch (nsRecord.Type) + { + case DnsResourceRecordType.NS: + { + string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer; + + IReadOnlyList glueRecords = nsRecord.GetGlueRecords(); + if (glueRecords.Count > 0) + { + List nameServers = new List(2); + + foreach (DnsResourceRecord glueRecord in glueRecords) + { + switch (glueRecord.Type) + { + case DnsResourceRecordType.A: + nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address)); + break; + + case DnsResourceRecordType.AAAA: + if (dnsServer.PreferIPv6) + nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address)); + + break; + } + } + + return Task.FromResult(nameServers as IReadOnlyList); + } + else + { + return ResolveNameServerAddressesAsync(dnsServer, nsDomain); + } + } + + default: + throw new InvalidOperationException(); + } + } + + #endregion + + #region protected + + protected void CleanupHistory(List history) + { + DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord; + DateTime expiry = DateTime.UtcNow.AddSeconds(-soa.Expire); + int index = 0; + + while (index < history.Count) + { + //check difference sequence + if (history[index].GetDeletedOn() > expiry) + break; //found record to keep + + //skip to next difference sequence + index++; + int soaCount = 1; + + while (index < history.Count) + { + if (history[index].Type == DnsResourceRecordType.SOA) + { + soaCount++; + + if (soaCount == 3) + break; + } + + index++; + } + } + + if (index == history.Count) + { + //delete entire history + history.Clear(); + return; + } + + //remove expired records + history.RemoveRange(0, index); + } + + protected bool SetRecords(DnsResourceRecordType type, IReadOnlyList records, out IReadOnlyList deletedRecords) + { + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + { + deletedRecords = existingRecords; + return _entries.TryUpdate(type, records, existingRecords); + } + else + { + deletedRecords = null; + return _entries.TryAdd(type, records); + } + } + + protected bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord) + { + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + { + if (existingRecords.Count == 1) + { + if (rdata.Equals(existingRecords[0].RDATA)) + { + if (_entries.TryRemove(type, out IReadOnlyList removedRecords)) + { + deletedRecord = removedRecords[0]; + return true; + } + } + } + else + { + deletedRecord = null; + List updatedRecords = new List(existingRecords.Count); + + foreach (DnsResourceRecord existingRecord in existingRecords) + { + if ((deletedRecord is null) && rdata.Equals(existingRecord.RDATA)) + deletedRecord = existingRecord; + else + updatedRecords.Add(existingRecord); + } + + return _entries.TryUpdate(type, updatedRecords, existingRecords); + } + } + + deletedRecord = null; + return false; + } + #endregion #region public public async Task> GetPrimaryNameServerAddressesAsync(DnsServer dnsServer) { + DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0]; + + IReadOnlyList primaryNameServers = soaRecord.GetPrimaryNameServers(); + if (primaryNameServers.Count > 0) + return primaryNameServers; + List nameServers = new List(); - DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0]; - DnsSOARecord soa = soaRecord.RDATA as DnsSOARecord; + string primaryNameServer = (soaRecord.RDATA as DnsSOARecord).PrimaryNameServer; IReadOnlyList nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords foreach (DnsResourceRecord nsRecord in nsRecords) @@ -217,21 +320,16 @@ namespace DnsServerCore.Dns.Zones if (nsRecord.IsDisabled()) continue; - string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer; - - if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase)) + if (primaryNameServer.Equals((nsRecord.RDATA as DnsNSRecord).NameServer, StringComparison.OrdinalIgnoreCase)) { //found primary NS - nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord)); + nameServers.AddRange(await ResolveNameServerAddressesAsync(dnsServer, nsRecord)); break; } } - foreach (NameServerAddress nameServer in await GetNameServerAddressesAsync(dnsServer, soaRecord)) - { - if (!nameServers.Contains(nameServer)) - nameServers.Add(nameServer); - } + if (nameServers.Count < 1) + nameServers.AddRange(await ResolveNameServerAddressesAsync(dnsServer, primaryNameServer)); return nameServers; } @@ -240,7 +338,7 @@ namespace DnsServerCore.Dns.Zones { List nameServers = new List(); - DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord; + string primaryNameServer = (_entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord).PrimaryNameServer; IReadOnlyList nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords foreach (DnsResourceRecord nsRecord in nsRecords) @@ -248,27 +346,33 @@ namespace DnsServerCore.Dns.Zones if (nsRecord.IsDisabled()) continue; - string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer; - - if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase)) + if (primaryNameServer.Equals((nsRecord.RDATA as DnsNSRecord).NameServer, StringComparison.OrdinalIgnoreCase)) continue; //skip primary name server - nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord)); + nameServers.AddRange(await ResolveNameServerAddressesAsync(dnsServer, nsRecord)); } return nameServers; } - public void SyncRecords(Dictionary> newEntries, bool dontRemoveRecords) + public Task> GetNameServerAddressesAsync(DnsServer dnsServer) { - if (!dontRemoveRecords) + DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0]; + + IReadOnlyList primaryNameServers = soaRecord.GetPrimaryNameServers(); + if (primaryNameServers.Count > 0) + return Task.FromResult(primaryNameServers); + + return GetSecondaryNameServerAddressesAsync(dnsServer); + } + + public void SyncRecords(Dictionary> newEntries) + { + //remove entires of type that do not exists in new entries + foreach (KeyValuePair> entry in _entries) { - //remove entires of type that do not exists in new entries - foreach (KeyValuePair> entry in _entries) - { - if (!newEntries.ContainsKey(entry.Key)) - _entries.TryRemove(entry.Key, out _); - } + if (!newEntries.ContainsKey(entry.Key)) + _entries.TryRemove(entry.Key, out _); } //set new entries into zone @@ -298,10 +402,14 @@ namespace DnsServerCore.Dns.Zones if (newEntry.Value.Count != 1) continue; //skip invalid SOA record - if ((this is SecondaryZone) || (this is StubZone)) + if (this is SecondaryZone) { - //copy existing SOA record's glue addresses to new SOA record - newEntry.Value[0].SetGlueRecords(_entries[DnsResourceRecordType.SOA][0].GetGlueRecords()); + //copy existing SOA record's glue addresses and comments to new SOA record + DnsResourceRecord existingSoaRecord = _entries[DnsResourceRecordType.SOA][0]; + DnsResourceRecord newSoaRecord = newEntry.Value[0]; + + newSoaRecord.SetPrimaryNameServers(existingSoaRecord.GetPrimaryNameServers()); + newSoaRecord.SetComments(existingSoaRecord.GetComments()); } } @@ -310,6 +418,89 @@ namespace DnsServerCore.Dns.Zones } } + public void SyncRecords(Dictionary> deletedEntries = null, Dictionary> addedEntries = null) + { + if (deletedEntries is not null) + { + foreach (KeyValuePair> deletedEntry in deletedEntries) + { + if (_entries.TryGetValue(deletedEntry.Key, out IReadOnlyList existingRecords)) + { + List updatedRecords = new List(Math.Max(0, existingRecords.Count - deletedEntry.Value.Count)); + + foreach (DnsResourceRecord existingRecord in existingRecords) + { + bool deleted = false; + + foreach (DnsResourceRecord deletedRecord in deletedEntry.Value) + { + if (existingRecord.RDATA.Equals(deletedRecord.RDATA)) + { + deleted = true; + break; + } + } + + if (!deleted) + updatedRecords.Add(existingRecord); + } + + if (existingRecords.Count > updatedRecords.Count) + { + if (updatedRecords.Count > 0) + _entries[deletedEntry.Key] = updatedRecords; + else + _entries.TryRemove(deletedEntry.Key, out _); + } + } + } + } + + if (addedEntries is not null) + { + foreach (KeyValuePair> addedEntry in addedEntries) + { + _entries.AddOrUpdate(addedEntry.Key, addedEntry.Value, delegate (DnsResourceRecordType key, IReadOnlyList existingRecords) + { + List updatedRecords = new List(existingRecords.Count + addedEntry.Value.Count); + + updatedRecords.AddRange(existingRecords); + + foreach (DnsResourceRecord addedRecord in addedEntry.Value) + { + bool exists = false; + + foreach (DnsResourceRecord existingRecord in existingRecords) + { + if (addedRecord.RDATA.Equals(existingRecord.RDATA)) + { + exists = true; + break; + } + } + + if (!exists) + updatedRecords.Add(addedRecord); + } + + if (updatedRecords.Count > existingRecords.Count) + return updatedRecords; + else + return existingRecords; + }); + } + } + } + + public void SyncGlueRecords(IReadOnlyCollection deletedGlueRecords, IReadOnlyCollection addedGlueRecords) + { + if (_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList nsRecords)) + { + foreach (DnsResourceRecord nsRecord in nsRecords) + nsRecord.SyncGlueRecords(deletedGlueRecords, addedGlueRecords); + } + } + public void LoadRecords(DnsResourceRecordType type, IReadOnlyList records) { _entries[type] = records; @@ -339,16 +530,16 @@ namespace DnsServerCore.Dns.Zones { foreach (DnsResourceRecord existingRecord in existingRecords) { - if (record.Equals(existingRecord.RDATA)) + if (record.RDATA.Equals(existingRecord.RDATA)) return existingRecords; } - List updateRecords = new List(existingRecords.Count + 1); + List updatedRecords = new List(existingRecords.Count + 1); - updateRecords.AddRange(existingRecords); - updateRecords.Add(record); + updatedRecords.AddRange(existingRecords); + updatedRecords.Add(record); - return updateRecords; + return updatedRecords; }); } @@ -357,30 +548,21 @@ namespace DnsServerCore.Dns.Zones return _entries.TryRemove(type, out _); } - public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record) + public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata) { - if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) - { - if (existingRecords.Count == 1) - { - if (record.Equals(existingRecords[0].RDATA)) - return _entries.TryRemove(type, out _); - } - else - { - List updateRecords = new List(existingRecords.Count); + return DeleteRecord(type, rdata, out _); + } - for (int i = 0; i < existingRecords.Count; i++) - { - if (!record.Equals(existingRecords[i].RDATA)) - updateRecords.Add(existingRecords[i]); - } + public virtual void UpdateRecord(DnsResourceRecord oldRecord, DnsResourceRecord newRecord) + { + if (oldRecord.Type == DnsResourceRecordType.SOA) + throw new InvalidOperationException("Cannot update record: use SetRecords() for " + oldRecord.Type.ToString() + " record"); - return _entries.TryUpdate(type, updateRecords, existingRecords); - } - } + if (oldRecord.Type != newRecord.Type) + throw new InvalidOperationException("Old and new record types do not match."); - return false; + DeleteRecord(oldRecord.Type, oldRecord.RDATA); + AddRecord(newRecord); } public virtual IReadOnlyList QueryRecords(DnsResourceRecordType type)