AuthZone: updated dns record versioning method implementation to support DNSSEC. Added methods for DNSSEC related operations. Minor code refactoring done.

This commit is contained in:
Shreyas Zare
2022-02-19 13:05:26 +05:30
parent 9d9c0e24db
commit d12a45cff4

View File

@@ -99,7 +99,7 @@ namespace DnsServerCore.Dns.Zones
return newRecords;
}
private IReadOnlyList<DnsResourceRecord> AddRRSIGs(IReadOnlyList<DnsResourceRecord> records)
private IReadOnlyList<DnsResourceRecord> AppendRRSigTo(IReadOnlyList<DnsResourceRecord> records)
{
IReadOnlyList<DnsResourceRecord> rrsigRecords = GetRecords(DnsResourceRecordType.RRSIG);
if (rrsigRecords.Count == 0)
@@ -120,9 +120,9 @@ namespace DnsServerCore.Dns.Zones
#endregion
#region protected
#region versioning
protected bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, out IReadOnlyList<DnsResourceRecord> deletedRecords)
internal bool TrySetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, out IReadOnlyList<DnsResourceRecord> deletedRecords)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
@@ -131,12 +131,12 @@ namespace DnsServerCore.Dns.Zones
}
else
{
deletedRecords = null;
deletedRecords = Array.Empty<DnsResourceRecord>();
return _entries.TryAdd(type, records);
}
}
protected bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord)
internal bool TryDeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
@@ -164,6 +164,9 @@ namespace DnsServerCore.Dns.Zones
updatedRecords.Add(existingRecord);
}
if (deletedRecord is null)
return false; //not found
return _entries.TryUpdate(type, updatedRecords, existingRecords);
}
}
@@ -172,6 +175,344 @@ namespace DnsServerCore.Dns.Zones
return false;
}
internal bool TryDeleteRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, out IReadOnlyList<DnsResourceRecord> deletedRecords)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
if (existingRecords.Count == 1)
{
DnsResourceRecord existingRecord = existingRecords[0];
foreach (DnsResourceRecord record in records)
{
if (record.RDATA.Equals(existingRecord.RDATA))
{
if (_entries.TryRemove(type, out IReadOnlyList<DnsResourceRecord> removedRecords))
{
deletedRecords = removedRecords;
return true;
}
}
}
}
else
{
List<DnsResourceRecord> deleted = new List<DnsResourceRecord>(records.Count);
List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(existingRecords.Count);
foreach (DnsResourceRecord existingRecord in existingRecords)
{
bool found = false;
foreach (DnsResourceRecord record in records)
{
if (record.RDATA.Equals(existingRecord.RDATA))
{
found = true;
break;
}
}
if (found)
deleted.Add(existingRecord);
else
updatedRecords.Add(existingRecord);
}
if (deleted.Count > 0)
{
deletedRecords = deleted;
if (updatedRecords.Count > 0)
return _entries.TryUpdate(type, updatedRecords, existingRecords);
return _entries.TryRemove(type, out _);
}
}
}
deletedRecords = null;
return false;
}
internal void AddOrUpdateRRSigRecords(IReadOnlyList<DnsResourceRecord> newRRSigRecords, out IReadOnlyList<DnsResourceRecord> deletedRRSigRecords)
{
IReadOnlyList<DnsResourceRecord> deleted = null;
_entries.AddOrUpdate(DnsResourceRecordType.RRSIG, delegate (DnsResourceRecordType key)
{
deleted = Array.Empty<DnsResourceRecord>();
return newRRSigRecords;
},
delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
{
List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(existingRecords.Count + newRRSigRecords.Count);
List<DnsResourceRecord> deletedRecords = new List<DnsResourceRecord>();
foreach (DnsResourceRecord existingRecord in existingRecords)
{
bool found = false;
DnsRRSIGRecord existingRRSig = existingRecord.RDATA as DnsRRSIGRecord;
foreach (DnsResourceRecord newRRSigRecord in newRRSigRecords)
{
DnsRRSIGRecord newRRSig = newRRSigRecord.RDATA as DnsRRSIGRecord;
if ((newRRSig.TypeCovered == existingRRSig.TypeCovered) && (newRRSig.KeyTag == existingRRSig.KeyTag))
{
deletedRecords.Add(existingRecord);
found = true;
break;
}
}
if (!found)
updatedRecords.Add(existingRecord);
}
updatedRecords.AddRange(newRRSigRecords);
deleted = deletedRecords;
return updatedRecords;
});
deletedRRSigRecords = deleted;
}
#endregion
#region DNSSEC
internal IReadOnlyList<DnsResourceRecord> SignAllRRSets()
{
List<DnsResourceRecord> rrsigRecords = new List<DnsResourceRecord>(_entries.Count);
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
if (entry.Key == DnsResourceRecordType.RRSIG)
continue;
rrsigRecords.AddRange(SignRRSet(entry.Value));
}
return rrsigRecords;
}
internal IReadOnlyList<DnsResourceRecord> RemoveAllDnssecRecords()
{
List<DnsResourceRecord> allRemovedRecords = new List<DnsResourceRecord>();
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.DNSKEY:
case DnsResourceRecordType.RRSIG:
case DnsResourceRecordType.NSEC:
case DnsResourceRecordType.NSEC3PARAM:
case DnsResourceRecordType.NSEC3:
if (_entries.TryRemove(entry.Key, out IReadOnlyList<DnsResourceRecord> removedRecords))
allRemovedRecords.AddRange(removedRecords);
break;
}
}
return allRemovedRecords;
}
internal IReadOnlyList<DnsResourceRecord> RemoveNSecRecordsWithRRSig()
{
List<DnsResourceRecord> allRemovedRecords = new List<DnsResourceRecord>(2);
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.NSEC:
if (_entries.TryRemove(entry.Key, out IReadOnlyList<DnsResourceRecord> removedRecords))
allRemovedRecords.AddRange(removedRecords);
break;
case DnsResourceRecordType.RRSIG:
List<DnsResourceRecord> recordsToRemove = new List<DnsResourceRecord>(1);
foreach (DnsResourceRecord rrsigRecord in entry.Value)
{
DnsRRSIGRecord rrsig = rrsigRecord.RDATA as DnsRRSIGRecord;
if (rrsig.TypeCovered == DnsResourceRecordType.NSEC)
recordsToRemove.Add(rrsigRecord);
}
if (recordsToRemove.Count > 0)
{
if (TryDeleteRecords(DnsResourceRecordType.RRSIG, recordsToRemove, out IReadOnlyList<DnsResourceRecord> deletedRecords))
allRemovedRecords.AddRange(deletedRecords);
}
break;
}
}
return allRemovedRecords;
}
internal IReadOnlyList<DnsResourceRecord> RemoveNSec3RecordsWithRRSig()
{
List<DnsResourceRecord> allRemovedRecords = new List<DnsResourceRecord>(2);
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.NSEC3:
case DnsResourceRecordType.NSEC3PARAM:
if (_entries.TryRemove(entry.Key, out IReadOnlyList<DnsResourceRecord> removedRecords))
allRemovedRecords.AddRange(removedRecords);
break;
case DnsResourceRecordType.RRSIG:
List<DnsResourceRecord> recordsToRemove = new List<DnsResourceRecord>(1);
foreach (DnsResourceRecord rrsigRecord in entry.Value)
{
DnsRRSIGRecord rrsig = rrsigRecord.RDATA as DnsRRSIGRecord;
switch (rrsig.TypeCovered)
{
case DnsResourceRecordType.NSEC3:
case DnsResourceRecordType.NSEC3PARAM:
recordsToRemove.Add(rrsigRecord);
break;
}
}
if (recordsToRemove.Count > 0)
{
if (TryDeleteRecords(DnsResourceRecordType.RRSIG, recordsToRemove, out IReadOnlyList<DnsResourceRecord> deletedRecords))
allRemovedRecords.AddRange(deletedRecords);
}
break;
}
}
return allRemovedRecords;
}
internal bool HasOnlyNSec3Records()
{
if (!_entries.ContainsKey(DnsResourceRecordType.NSEC3))
return false;
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
switch (entry.Key)
{
case DnsResourceRecordType.NSEC3:
case DnsResourceRecordType.RRSIG:
break;
default:
//found non NSEC3 records
return false;
}
}
return true;
}
internal IReadOnlyList<DnsResourceRecord> RefreshSignatures()
{
if (!_entries.TryGetValue(DnsResourceRecordType.RRSIG, out IReadOnlyList<DnsResourceRecord> rrsigRecords))
throw new InvalidOperationException();
List<DnsResourceRecordType> typesToRefresh = new List<DnsResourceRecordType>();
DateTime utcNow = DateTime.UtcNow;
foreach (DnsResourceRecord rrsigRecord in rrsigRecords)
{
DnsRRSIGRecord rrsig = rrsigRecord.RDATA as DnsRRSIGRecord;
uint signatureValidityPeriod = rrsig.SignatureExpirationValue - rrsig.SignatureInceptionValue;
uint refreshPeriod = signatureValidityPeriod / 3;
if (utcNow > DateTime.UnixEpoch.AddSeconds(rrsig.SignatureExpirationValue - refreshPeriod))
typesToRefresh.Add(rrsig.TypeCovered);
}
List<DnsResourceRecord> newRRSigRecords = new List<DnsResourceRecord>(typesToRefresh.Count);
foreach (DnsResourceRecordType type in typesToRefresh)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> records))
newRRSigRecords.AddRange(SignRRSet(records));
}
return newRRSigRecords;
}
internal virtual IReadOnlyList<DnsResourceRecord> SignRRSet(IReadOnlyList<DnsResourceRecord> records)
{
throw new InvalidOperationException();
}
internal IReadOnlyList<DnsResourceRecord> GetUpdatedNSecRRSet(string nextDomainName, uint ttl)
{
List<DnsResourceRecordType> types = new List<DnsResourceRecordType>(_entries.Count);
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
types.Add(entry.Key);
if (!_entries.ContainsKey(DnsResourceRecordType.NSEC))
types.Add(DnsResourceRecordType.NSEC);
if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG))
types.Add(DnsResourceRecordType.RRSIG);
types.Sort();
DnsNSECRecord newNSecRecord = new DnsNSECRecord(nextDomainName, types);
if (!_entries.TryGetValue(DnsResourceRecordType.NSEC, out IReadOnlyList<DnsResourceRecord> existingRecords) || !existingRecords[0].RDATA.Equals(newNSecRecord))
return new DnsResourceRecord[] { new DnsResourceRecord(_name, DnsResourceRecordType.NSEC, DnsClass.IN, ttl, newNSecRecord) };
return Array.Empty<DnsResourceRecord>();
}
internal IReadOnlyList<DnsResourceRecord> CreateNSec3RRSet(string hashedOwnerName, byte[] nextHashedOwnerName, uint ttl, ushort iterations, byte[] salt)
{
List<DnsResourceRecordType> types = new List<DnsResourceRecordType>(_entries.Count);
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
types.Add(entry.Key);
if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG))
types.Add(DnsResourceRecordType.RRSIG);
types.Sort();
DnsNSEC3Record newNSec3 = new DnsNSEC3Record(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, iterations, salt, nextHashedOwnerName, types);
return new DnsResourceRecord[] { new DnsResourceRecord(hashedOwnerName, DnsResourceRecordType.NSEC3, DnsClass.IN, ttl, newNSec3) };
}
internal DnsResourceRecord GetPartialNSec3Record(string zoneName, uint ttl, ushort iterations, byte[] salt)
{
List<DnsResourceRecordType> types = new List<DnsResourceRecordType>(_entries.Count);
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
types.Add(entry.Key);
if (!_entries.ContainsKey(DnsResourceRecordType.RRSIG))
types.Add(DnsResourceRecordType.RRSIG);
types.Sort();
DnsNSEC3Record newNSec3Record = new DnsNSEC3Record(DnssecNSEC3HashAlgorithm.SHA1, DnssecNSEC3Flags.None, iterations, salt, Array.Empty<byte>(), types);
return new DnsResourceRecord(newNSec3Record.ComputeHashedOwnerName(_name) + "." + zoneName, DnsResourceRecordType.NSEC3, DnsClass.IN, ttl, newNSec3Record);
}
#endregion
#region public
@@ -359,7 +700,7 @@ namespace DnsServerCore.Dns.Zones
public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata)
{
return DeleteRecord(type, rdata, out _);
return TryDeleteRecord(type, rdata, out _);
}
public virtual void UpdateRecord(DnsResourceRecord oldRecord, DnsResourceRecord newRecord)
@@ -405,7 +746,7 @@ namespace DnsServerCore.Dns.Zones
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AddRRSIGs(filteredRecords);
return AppendRRSigTo(filteredRecords);
return filteredRecords;
}
@@ -417,7 +758,7 @@ namespace DnsServerCore.Dns.Zones
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AddRRSIGs(filteredRecords);
return AppendRRSigTo(filteredRecords);
return filteredRecords;
}