AuthZone: added GetNameServerAddressesAsync() for stub refresh. Updated implementation for IXFR support. Implemented UpdateRecord() method.

This commit is contained in:
Shreyas Zare
2021-07-10 13:19:25 +05:30
parent 31a3bcb92e
commit d8a91ba3ac

View File

@@ -125,91 +125,194 @@ namespace DnsServerCore.Dns.Zones
return newRecords;
}
private static async Task<IReadOnlyList<NameServerAddress>> GetNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord record)
private static async Task<IReadOnlyList<NameServerAddress>> ResolveNameServerAddressesAsync(DnsServer dnsServer, string nsDomain)
{
string nsDomain;
switch (record.Type)
{
case DnsResourceRecordType.NS:
nsDomain = (record.RDATA as DnsNSRecord).NameServer;
break;
case DnsResourceRecordType.SOA:
nsDomain = (record.RDATA as DnsSOARecord).PrimaryNameServer;
break;
default:
throw new InvalidOperationException();
}
List<NameServerAddress> nameServers = new List<NameServerAddress>(2);
IReadOnlyList<DnsResourceRecord> glueRecords = record.GetGlueRecords();
if (glueRecords.Count > 0)
try
{
foreach (DnsResourceRecord glueRecord in glueRecords)
DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)).WithTimeout(2000);
if (response.Answer.Count > 0)
{
switch (glueRecord.Type)
{
case DnsResourceRecordType.A:
nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address));
break;
case DnsResourceRecordType.AAAA:
if (dnsServer.PreferIPv6)
nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address));
break;
}
IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseA(response);
foreach (IPAddress address in addresses)
nameServers.Add(new NameServerAddress(nsDomain, address));
}
}
else
catch
{ }
if (dnsServer.PreferIPv6)
{
//resolve addresses
try
{
DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)).WithTimeout(2000);
DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)).WithTimeout(2000);
if (response.Answer.Count > 0)
{
IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseA(response);
IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseAAAA(response);
foreach (IPAddress address in addresses)
nameServers.Add(new NameServerAddress(nsDomain, address));
}
}
catch
{ }
if (dnsServer.PreferIPv6)
{
try
{
DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)).WithTimeout(2000);
if (response.Answer.Count > 0)
{
IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseAAAA(response);
foreach (IPAddress address in addresses)
nameServers.Add(new NameServerAddress(nsDomain, address));
}
}
catch
{ }
}
}
return nameServers;
}
private static Task<IReadOnlyList<NameServerAddress>> ResolveNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord nsRecord)
{
switch (nsRecord.Type)
{
case DnsResourceRecordType.NS:
{
string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;
IReadOnlyList<DnsResourceRecord> glueRecords = nsRecord.GetGlueRecords();
if (glueRecords.Count > 0)
{
List<NameServerAddress> nameServers = new List<NameServerAddress>(2);
foreach (DnsResourceRecord glueRecord in glueRecords)
{
switch (glueRecord.Type)
{
case DnsResourceRecordType.A:
nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address));
break;
case DnsResourceRecordType.AAAA:
if (dnsServer.PreferIPv6)
nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address));
break;
}
}
return Task.FromResult(nameServers as IReadOnlyList<NameServerAddress>);
}
else
{
return ResolveNameServerAddressesAsync(dnsServer, nsDomain);
}
}
default:
throw new InvalidOperationException();
}
}
#endregion
#region protected
protected void CleanupHistory(List<DnsResourceRecord> history)
{
DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord;
DateTime expiry = DateTime.UtcNow.AddSeconds(-soa.Expire);
int index = 0;
while (index < history.Count)
{
//check difference sequence
if (history[index].GetDeletedOn() > expiry)
break; //found record to keep
//skip to next difference sequence
index++;
int soaCount = 1;
while (index < history.Count)
{
if (history[index].Type == DnsResourceRecordType.SOA)
{
soaCount++;
if (soaCount == 3)
break;
}
index++;
}
}
if (index == history.Count)
{
//delete entire history
history.Clear();
return;
}
//remove expired records
history.RemoveRange(0, index);
}
protected bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, out IReadOnlyList<DnsResourceRecord> deletedRecords)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
deletedRecords = existingRecords;
return _entries.TryUpdate(type, records, existingRecords);
}
else
{
deletedRecords = null;
return _entries.TryAdd(type, records);
}
}
protected bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata, out DnsResourceRecord deletedRecord)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
if (existingRecords.Count == 1)
{
if (rdata.Equals(existingRecords[0].RDATA))
{
if (_entries.TryRemove(type, out IReadOnlyList<DnsResourceRecord> removedRecords))
{
deletedRecord = removedRecords[0];
return true;
}
}
}
else
{
deletedRecord = null;
List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(existingRecords.Count);
foreach (DnsResourceRecord existingRecord in existingRecords)
{
if ((deletedRecord is null) && rdata.Equals(existingRecord.RDATA))
deletedRecord = existingRecord;
else
updatedRecords.Add(existingRecord);
}
return _entries.TryUpdate(type, updatedRecords, existingRecords);
}
}
deletedRecord = null;
return false;
}
#endregion
#region public
public async Task<IReadOnlyList<NameServerAddress>> GetPrimaryNameServerAddressesAsync(DnsServer dnsServer)
{
DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
IReadOnlyList<NameServerAddress> primaryNameServers = soaRecord.GetPrimaryNameServers();
if (primaryNameServers.Count > 0)
return primaryNameServers;
List<NameServerAddress> nameServers = new List<NameServerAddress>();
DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
DnsSOARecord soa = soaRecord.RDATA as DnsSOARecord;
string primaryNameServer = (soaRecord.RDATA as DnsSOARecord).PrimaryNameServer;
IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords
foreach (DnsResourceRecord nsRecord in nsRecords)
@@ -217,21 +320,16 @@ namespace DnsServerCore.Dns.Zones
if (nsRecord.IsDisabled())
continue;
string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;
if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
if (primaryNameServer.Equals((nsRecord.RDATA as DnsNSRecord).NameServer, StringComparison.OrdinalIgnoreCase))
{
//found primary NS
nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
nameServers.AddRange(await ResolveNameServerAddressesAsync(dnsServer, nsRecord));
break;
}
}
foreach (NameServerAddress nameServer in await GetNameServerAddressesAsync(dnsServer, soaRecord))
{
if (!nameServers.Contains(nameServer))
nameServers.Add(nameServer);
}
if (nameServers.Count < 1)
nameServers.AddRange(await ResolveNameServerAddressesAsync(dnsServer, primaryNameServer));
return nameServers;
}
@@ -240,7 +338,7 @@ namespace DnsServerCore.Dns.Zones
{
List<NameServerAddress> nameServers = new List<NameServerAddress>();
DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord;
string primaryNameServer = (_entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord).PrimaryNameServer;
IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords
foreach (DnsResourceRecord nsRecord in nsRecords)
@@ -248,27 +346,33 @@ namespace DnsServerCore.Dns.Zones
if (nsRecord.IsDisabled())
continue;
string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;
if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase))
if (primaryNameServer.Equals((nsRecord.RDATA as DnsNSRecord).NameServer, StringComparison.OrdinalIgnoreCase))
continue; //skip primary name server
nameServers.AddRange(await GetNameServerAddressesAsync(dnsServer, nsRecord));
nameServers.AddRange(await ResolveNameServerAddressesAsync(dnsServer, nsRecord));
}
return nameServers;
}
public void SyncRecords(Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> newEntries, bool dontRemoveRecords)
public Task<IReadOnlyList<NameServerAddress>> GetNameServerAddressesAsync(DnsServer dnsServer)
{
if (!dontRemoveRecords)
DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
IReadOnlyList<NameServerAddress> primaryNameServers = soaRecord.GetPrimaryNameServers();
if (primaryNameServers.Count > 0)
return Task.FromResult(primaryNameServers);
return GetSecondaryNameServerAddressesAsync(dnsServer);
}
public void SyncRecords(Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> newEntries)
{
//remove entires of type that do not exists in new entries
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
//remove entires of type that do not exists in new entries
foreach (KeyValuePair<DnsResourceRecordType, IReadOnlyList<DnsResourceRecord>> entry in _entries)
{
if (!newEntries.ContainsKey(entry.Key))
_entries.TryRemove(entry.Key, out _);
}
if (!newEntries.ContainsKey(entry.Key))
_entries.TryRemove(entry.Key, out _);
}
//set new entries into zone
@@ -298,10 +402,14 @@ namespace DnsServerCore.Dns.Zones
if (newEntry.Value.Count != 1)
continue; //skip invalid SOA record
if ((this is SecondaryZone) || (this is StubZone))
if (this is SecondaryZone)
{
//copy existing SOA record's glue addresses to new SOA record
newEntry.Value[0].SetGlueRecords(_entries[DnsResourceRecordType.SOA][0].GetGlueRecords());
//copy existing SOA record's glue addresses and comments to new SOA record
DnsResourceRecord existingSoaRecord = _entries[DnsResourceRecordType.SOA][0];
DnsResourceRecord newSoaRecord = newEntry.Value[0];
newSoaRecord.SetPrimaryNameServers(existingSoaRecord.GetPrimaryNameServers());
newSoaRecord.SetComments(existingSoaRecord.GetComments());
}
}
@@ -310,6 +418,89 @@ namespace DnsServerCore.Dns.Zones
}
}
public void SyncRecords(Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> deletedEntries = null, Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> addedEntries = null)
{
if (deletedEntries is not null)
{
foreach (KeyValuePair<DnsResourceRecordType, List<DnsResourceRecord>> deletedEntry in deletedEntries)
{
if (_entries.TryGetValue(deletedEntry.Key, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(Math.Max(0, existingRecords.Count - deletedEntry.Value.Count));
foreach (DnsResourceRecord existingRecord in existingRecords)
{
bool deleted = false;
foreach (DnsResourceRecord deletedRecord in deletedEntry.Value)
{
if (existingRecord.RDATA.Equals(deletedRecord.RDATA))
{
deleted = true;
break;
}
}
if (!deleted)
updatedRecords.Add(existingRecord);
}
if (existingRecords.Count > updatedRecords.Count)
{
if (updatedRecords.Count > 0)
_entries[deletedEntry.Key] = updatedRecords;
else
_entries.TryRemove(deletedEntry.Key, out _);
}
}
}
}
if (addedEntries is not null)
{
foreach (KeyValuePair<DnsResourceRecordType, List<DnsResourceRecord>> addedEntry in addedEntries)
{
_entries.AddOrUpdate(addedEntry.Key, addedEntry.Value, delegate (DnsResourceRecordType key, IReadOnlyList<DnsResourceRecord> existingRecords)
{
List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(existingRecords.Count + addedEntry.Value.Count);
updatedRecords.AddRange(existingRecords);
foreach (DnsResourceRecord addedRecord in addedEntry.Value)
{
bool exists = false;
foreach (DnsResourceRecord existingRecord in existingRecords)
{
if (addedRecord.RDATA.Equals(existingRecord.RDATA))
{
exists = true;
break;
}
}
if (!exists)
updatedRecords.Add(addedRecord);
}
if (updatedRecords.Count > existingRecords.Count)
return updatedRecords;
else
return existingRecords;
});
}
}
}
public void SyncGlueRecords(IReadOnlyCollection<DnsResourceRecord> deletedGlueRecords, IReadOnlyCollection<DnsResourceRecord> addedGlueRecords)
{
if (_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList<DnsResourceRecord> nsRecords))
{
foreach (DnsResourceRecord nsRecord in nsRecords)
nsRecord.SyncGlueRecords(deletedGlueRecords, addedGlueRecords);
}
}
public void LoadRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records)
{
_entries[type] = records;
@@ -339,16 +530,16 @@ namespace DnsServerCore.Dns.Zones
{
foreach (DnsResourceRecord existingRecord in existingRecords)
{
if (record.Equals(existingRecord.RDATA))
if (record.RDATA.Equals(existingRecord.RDATA))
return existingRecords;
}
List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count + 1);
List<DnsResourceRecord> updatedRecords = new List<DnsResourceRecord>(existingRecords.Count + 1);
updateRecords.AddRange(existingRecords);
updateRecords.Add(record);
updatedRecords.AddRange(existingRecords);
updatedRecords.Add(record);
return updateRecords;
return updatedRecords;
});
}
@@ -357,30 +548,21 @@ namespace DnsServerCore.Dns.Zones
return _entries.TryRemove(type, out _);
}
public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record)
public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData rdata)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
if (existingRecords.Count == 1)
{
if (record.Equals(existingRecords[0].RDATA))
return _entries.TryRemove(type, out _);
}
else
{
List<DnsResourceRecord> updateRecords = new List<DnsResourceRecord>(existingRecords.Count);
return DeleteRecord(type, rdata, out _);
}
for (int i = 0; i < existingRecords.Count; i++)
{
if (!record.Equals(existingRecords[i].RDATA))
updateRecords.Add(existingRecords[i]);
}
public virtual void UpdateRecord(DnsResourceRecord oldRecord, DnsResourceRecord newRecord)
{
if (oldRecord.Type == DnsResourceRecordType.SOA)
throw new InvalidOperationException("Cannot update record: use SetRecords() for " + oldRecord.Type.ToString() + " record");
return _entries.TryUpdate(type, updateRecords, existingRecords);
}
}
if (oldRecord.Type != newRecord.Type)
throw new InvalidOperationException("Old and new record types do not match.");
return false;
DeleteRecord(oldRecord.Type, oldRecord.RDATA);
AddRecord(newRecord);
}
public virtual IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type)