diff --git a/DnsServerCore/Dns/Zones/AuthZone.cs b/DnsServerCore/Dns/Zones/AuthZone.cs index d77d15b5..2cb15b3a 100644 --- a/DnsServerCore/Dns/Zones/AuthZone.cs +++ b/DnsServerCore/Dns/Zones/AuthZone.cs @@ -99,7 +99,7 @@ namespace DnsServerCore.Dns.Zones return newRecords; } - private IReadOnlyList AddRRSIGs(IReadOnlyList records) + private IReadOnlyList AppendRRSigTo(IReadOnlyList records) { IReadOnlyList rrsigRecords = GetRecords(DnsResourceRecordType.RRSIG); if (rrsigRecords.Count == 0) @@ -120,9 +120,9 @@ namespace DnsServerCore.Dns.Zones #endregion - #region protected + #region versioning - protected bool SetRecords(DnsResourceRecordType type, IReadOnlyList records, out IReadOnlyList deletedRecords) + internal bool TrySetRecords(DnsResourceRecordType type, IReadOnlyList records, out IReadOnlyList deletedRecords) { if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) { @@ -131,12 +131,12 @@ namespace DnsServerCore.Dns.Zones } else { - deletedRecords = null; + deletedRecords = Array.Empty(); return _entries.TryAdd(type, records); } } - protected bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord) + internal bool TryDeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord) { if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) { @@ -164,6 +164,9 @@ namespace DnsServerCore.Dns.Zones updatedRecords.Add(existingRecord); } + if (deletedRecord is null) + return false; //not found + return _entries.TryUpdate(type, updatedRecords, existingRecords); } } @@ -172,6 +175,344 @@ namespace DnsServerCore.Dns.Zones return false; } + internal bool TryDeleteRecords(DnsResourceRecordType type, IReadOnlyList records, out IReadOnlyList deletedRecords) + { + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + { + if (existingRecords.Count == 1) + { + DnsResourceRecord existingRecord = existingRecords[0]; + + foreach (DnsResourceRecord record in records) + { + if (record.RDATA.Equals(existingRecord.RDATA)) + { + if (_entries.TryRemove(type, out IReadOnlyList removedRecords)) + { + deletedRecords = removedRecords; + return true; + } + } + } + } + else + { + List deleted = new List(records.Count); + List updatedRecords = new List(existingRecords.Count); + + foreach (DnsResourceRecord existingRecord in existingRecords) + { + bool found = false; + + foreach (DnsResourceRecord record in records) + { + if (record.RDATA.Equals(existingRecord.RDATA)) + { + found = true; + break; + } + } + + if (found) + deleted.Add(existingRecord); + else + updatedRecords.Add(existingRecord); + } + + if (deleted.Count > 0) + { + deletedRecords = deleted; + + if (updatedRecords.Count > 0) + return _entries.TryUpdate(type, updatedRecords, existingRecords); + + return _entries.TryRemove(type, out _); + } + } + } + + deletedRecords = null; + return false; + } + + internal void AddOrUpdateRRSigRecords(IReadOnlyList newRRSigRecords, out IReadOnlyList deletedRRSigRecords) + { + IReadOnlyList deleted = null; + + _entries.AddOrUpdate(DnsResourceRecordType.RRSIG, delegate (DnsResourceRecordType key) + { + deleted = Array.Empty(); + return newRRSigRecords; + }, + delegate (DnsResourceRecordType key, IReadOnlyList existingRecords) + { + List updatedRecords = new List(existingRecords.Count + newRRSigRecords.Count); + List deletedRecords = new List(); + + foreach (DnsResourceRecord existingRecord in existingRecords) + { + bool found = false; + DnsRRSIGRecord existingRRSig = existingRecord.RDATA as DnsRRSIGRecord; + + foreach (DnsResourceRecord newRRSigRecord in newRRSigRecords) + { + DnsRRSIGRecord newRRSig = newRRSigRecord.RDATA as DnsRRSIGRecord; + + if ((newRRSig.TypeCovered == existingRRSig.TypeCovered) && (newRRSig.KeyTag == existingRRSig.KeyTag)) + { + deletedRecords.Add(existingRecord); + found = true; + break; + } + } + + if (!found) + updatedRecords.Add(existingRecord); + } + + updatedRecords.AddRange(newRRSigRecords); + + deleted = deletedRecords; + return updatedRecords; + }); + + deletedRRSigRecords = deleted; + } + + #endregion + + #region DNSSEC + + internal IReadOnlyList SignAllRRSets() + { + List rrsigRecords = new List(_entries.Count); + + foreach (KeyValuePair> entry in _entries) + { + if (entry.Key == DnsResourceRecordType.RRSIG) + continue; + + rrsigRecords.AddRange(SignRRSet(entry.Value)); + } + + return rrsigRecords; + } + + internal IReadOnlyList RemoveAllDnssecRecords() + { + List allRemovedRecords = new List(); + + foreach (KeyValuePair> entry in _entries) + { + switch (entry.Key) + { + case DnsResourceRecordType.DNSKEY: + case DnsResourceRecordType.RRSIG: + case DnsResourceRecordType.NSEC: + case DnsResourceRecordType.NSEC3PARAM: + case DnsResourceRecordType.NSEC3: + if (_entries.TryRemove(entry.Key, out IReadOnlyList removedRecords)) + allRemovedRecords.AddRange(removedRecords); + + break; + } + } + + return allRemovedRecords; + } + + internal IReadOnlyList RemoveNSecRecordsWithRRSig() + { + List allRemovedRecords = new List(2); + + foreach (KeyValuePair> entry in _entries) + { + switch (entry.Key) + { + case DnsResourceRecordType.NSEC: + if (_entries.TryRemove(entry.Key, out IReadOnlyList removedRecords)) + allRemovedRecords.AddRange(removedRecords); + + break; + + case DnsResourceRecordType.RRSIG: + List recordsToRemove = new List(1); + + foreach (DnsResourceRecord rrsigRecord in entry.Value) + { + DnsRRSIGRecord rrsig = rrsigRecord.RDATA as DnsRRSIGRecord; + if (rrsig.TypeCovered == DnsResourceRecordType.NSEC) + recordsToRemove.Add(rrsigRecord); + } + + if (recordsToRemove.Count > 0) + { + if (TryDeleteRecords(DnsResourceRecordType.RRSIG, recordsToRemove, out IReadOnlyList deletedRecords)) + allRemovedRecords.AddRange(deletedRecords); + } + + break; + } + } + + return allRemovedRecords; + } + + internal IReadOnlyList RemoveNSec3RecordsWithRRSig() + { + List allRemovedRecords = new List(2); + + foreach (KeyValuePair> entry in _entries) + { + switch (entry.Key) + { + case DnsResourceRecordType.NSEC3: + case DnsResourceRecordType.NSEC3PARAM: + if (_entries.TryRemove(entry.Key, out IReadOnlyList removedRecords)) + allRemovedRecords.AddRange(removedRecords); + + break; + + case DnsResourceRecordType.RRSIG: + List recordsToRemove = new List(1); + + foreach (DnsResourceRecord rrsigRecord in entry.Value) + { + DnsRRSIGRecord rrsig = rrsigRecord.RDATA as DnsRRSIGRecord; + switch (rrsig.TypeCovered) + { + case DnsResourceRecordType.NSEC3: + case DnsResourceRecordType.NSEC3PARAM: + recordsToRemove.Add(rrsigRecord); + break; + } + } + + if (recordsToRemove.Count > 0) + { + if (TryDeleteRecords(DnsResourceRecordType.RRSIG, recordsToRemove, out IReadOnlyList deletedRecords)) + allRemovedRecords.AddRange(deletedRecords); + } + + break; + } + } + + return allRemovedRecords; + } + + internal bool HasOnlyNSec3Records() + { + if (!_entries.ContainsKey(DnsResourceRecordType.NSEC3)) + return false; + + foreach (KeyValuePair> entry in _entries) + { + switch (entry.Key) + { + case DnsResourceRecordType.NSEC3: + case DnsResourceRecordType.RRSIG: + break; + + default: + //found non NSEC3 records + return false; + } + } + + return true; + } + + internal IReadOnlyList RefreshSignatures() + { + if (!_entries.TryGetValue(DnsResourceRecordType.RRSIG, out IReadOnlyList rrsigRecords)) + throw new InvalidOperationException(); + + List typesToRefresh = new List(); + DateTime utcNow = DateTime.UtcNow; + + foreach (DnsResourceRecord rrsigRecord in rrsigRecords) + { + DnsRRSIGRecord rrsig = rrsigRecord.RDATA as DnsRRSIGRecord; + + uint signatureValidityPeriod = rrsig.SignatureExpirationValue - rrsig.SignatureInceptionValue; + uint refreshPeriod = signatureValidityPeriod / 3; + + if (utcNow > DateTime.UnixEpoch.AddSeconds(rrsig.SignatureExpirationValue - refreshPeriod)) + typesToRefresh.Add(rrsig.TypeCovered); + } + + List newRRSigRecords = new List(typesToRefresh.Count); + + foreach (DnsResourceRecordType type in typesToRefresh) + { + if (_entries.TryGetValue(type, out IReadOnlyList records)) + newRRSigRecords.AddRange(SignRRSet(records)); + } + + return newRRSigRecords; + } + + internal virtual IReadOnlyList SignRRSet(IReadOnlyList records) + { + throw new InvalidOperationException(); + } + + internal IReadOnlyList GetUpdatedNSecRRSet(string nextDomainName, uint ttl) + { + List types = new List(_entries.Count); + + foreach (KeyValuePair> entry in _entries) + types.Add(entry.Key); + + if (!_entries.ContainsKey(DnsResourceRecordType.NSEC)) + types.Add(DnsResourceRecordType.NSEC); + + if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG)) + types.Add(DnsResourceRecordType.RRSIG); + + types.Sort(); + + DnsNSECRecord newNSecRecord = new DnsNSECRecord(nextDomainName, types); + + if (!_entries.TryGetValue(DnsResourceRecordType.NSEC, out IReadOnlyList existingRecords) || !existingRecords[0].RDATA.Equals(newNSecRecord)) + return new DnsResourceRecord[] { new DnsResourceRecord(_name, DnsResourceRecordType.NSEC, DnsClass.IN, ttl, newNSecRecord) }; + + return Array.Empty(); + } + + internal IReadOnlyList CreateNSec3RRSet(string hashedOwnerName, byte[] nextHashedOwnerName, uint ttl, ushort iterations, byte[] salt) + { + List types = new List(_entries.Count); + + foreach (KeyValuePair> entry in _entries) + types.Add(entry.Key); + + if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG)) + types.Add(DnsResourceRecordType.RRSIG); + + types.Sort(); + + DnsNSEC3Record newNSec3 = new DnsNSEC3Record(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, iterations, salt, nextHashedOwnerName, types); + return new DnsResourceRecord[] { new DnsResourceRecord(hashedOwnerName, DnsResourceRecordType.NSEC3, DnsClass.IN, ttl, newNSec3) }; + } + + internal DnsResourceRecord GetPartialNSec3Record(string zoneName, uint ttl, ushort iterations, byte[] salt) + { + List types = new List(_entries.Count); + + foreach (KeyValuePair> entry in _entries) + types.Add(entry.Key); + + if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG)) + types.Add(DnsResourceRecordType.RRSIG); + + types.Sort(); + + DnsNSEC3Record newNSec3Record = new DnsNSEC3Record(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, iterations, salt, Array.Empty(), types); + return new DnsResourceRecord(newNSec3Record.ComputeHashedOwnerName(_name) + "." + zoneName, DnsResourceRecordType.NSEC3, DnsClass.IN, ttl, newNSec3Record); + } + #endregion #region public @@ -359,7 +700,7 @@ namespace DnsServerCore.Dns.Zones public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata) { - return DeleteRecord(type, rdata, out _); + return TryDeleteRecord(type, rdata, out _); } public virtual void UpdateRecord(DnsResourceRecord oldRecord, DnsResourceRecord newRecord) @@ -405,7 +746,7 @@ namespace DnsServerCore.Dns.Zones if (filteredRecords.Count > 0) { if (dnssecOk) - return AddRRSIGs(filteredRecords); + return AppendRRSigTo(filteredRecords); return filteredRecords; } @@ -417,7 +758,7 @@ namespace DnsServerCore.Dns.Zones if (filteredRecords.Count > 0) { if (dnssecOk) - return AddRRSIGs(filteredRecords); + return AppendRRSigTo(filteredRecords); return filteredRecords; }