AuthZone: removed code for apex zone into a new ApexZone class. Added DNSSEC related implementation.

This commit is contained in:
Shreyas Zare
2022-01-16 19:31:33 +05:30
parent a7542598c8
commit a11269ece2

View File

@@ -20,38 +20,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
using DnsServerCore.Dns.ResourceRecords;
using System;
using System.Collections.Generic;
using System.Net;
using System.Threading.Tasks;
using TechnitiumLibrary.IO;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
namespace DnsServerCore.Dns.Zones
{
public enum AuthZoneTransfer : byte
{
Deny = 0,
Allow = 1,
AllowOnlyZoneNameServers = 2,
AllowOnlySpecifiedNameServers = 3
}
public enum AuthZoneNotify : byte
{
None = 0,
ZoneNameServers = 1,
SpecifiedNameServers = 2
}
abstract class AuthZone : Zone, IDisposable
{
#region variables
protected bool _disabled;
protected AuthZoneTransfer _zoneTransfer;
protected IReadOnlyCollection<IPAddress> _zoneTransferNameServers;
protected AuthZoneNotify _notify;
protected IReadOnlyCollection<IPAddress> _notifyNameServers;
#endregion
@@ -61,10 +39,6 @@ namespace DnsServerCore.Dns.Zones
: base(zoneInfo.Name)
{
_disabled = zoneInfo.Disabled;
_zoneTransfer = zoneInfo.ZoneTransfer;
_zoneTransferNameServers = zoneInfo.ZoneTransferNameServers;
_notify = zoneInfo.Notify;
_notifyNameServers = zoneInfo.NotifyNameServers;
}
protected AuthZone(string name)
@@ -125,123 +99,29 @@ namespace DnsServerCore.Dns.Zones
return newRecords;
}
private static async Task ResolveNameServerAddressesAsync(DnsServer dnsServer, string nsDomain, int port, DnsTransportProtocol protocol, List<NameServerAddress> outNameServers)
private IReadOnlyList<DnsResourceRecord> AddRRSIGs(IReadOnlyList<DnsResourceRecord> records)
{
try
IReadOnlyList<DnsResourceRecord> rrsigRecords = GetRecords(DnsResourceRecordType.RRSIG);
if (rrsigRecords.Count == 0)
return records;
DnsResourceRecordType type = records[0].Type;
List<DnsResourceRecord> newRecords = new List<DnsResourceRecord>(records.Count + 2);
newRecords.AddRange(records);
foreach (DnsResourceRecord rrsigRecord in rrsigRecords)
{
DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)).WithTimeout(2000);
if (response.Answer.Count > 0)
{
IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseA(response);
foreach (IPAddress address in addresses)
outNameServers.Add(new NameServerAddress(nsDomain, new IPEndPoint(address, port), protocol));
}
if ((rrsigRecord.RDATA as DnsRRSIGRecord).TypeCovered == type)
newRecords.Add(rrsigRecord);
}
catch
{ }
if (dnsServer.PreferIPv6)
{
try
{
DnsDatagram response = await dnsServer.DirectQueryAsync(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)).WithTimeout(2000);
if (response.Answer.Count > 0)
{
IReadOnlyList<IPAddress> addresses = DnsClient.ParseResponseAAAA(response);
foreach (IPAddress address in addresses)
outNameServers.Add(new NameServerAddress(nsDomain, new IPEndPoint(address, port), protocol));
}
}
catch
{ }
}
}
private static Task ResolveNameServerAddressesAsync(DnsServer dnsServer, DnsResourceRecord nsRecord, List<NameServerAddress> outNameServers)
{
switch (nsRecord.Type)
{
case DnsResourceRecordType.NS:
{
string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer;
IReadOnlyList<DnsResourceRecord> glueRecords = nsRecord.GetGlueRecords();
if (glueRecords.Count > 0)
{
foreach (DnsResourceRecord glueRecord in glueRecords)
{
switch (glueRecord.Type)
{
case DnsResourceRecordType.A:
outNameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address));
break;
case DnsResourceRecordType.AAAA:
if (dnsServer.PreferIPv6)
outNameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address));
break;
}
}
return Task.CompletedTask;
}
else
{
return ResolveNameServerAddressesAsync(dnsServer, nsDomain, 53, DnsTransportProtocol.Udp, outNameServers);
}
}
default:
throw new InvalidOperationException();
}
return newRecords;
}
#endregion
#region protected
protected void CleanupHistory(List<DnsResourceRecord> history)
{
DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord;
DateTime expiry = DateTime.UtcNow.AddSeconds(-soa.Expire);
int index = 0;
while (index < history.Count)
{
//check difference sequence
if (history[index].GetDeletedOn() > expiry)
break; //found record to keep
//skip to next difference sequence
index++;
int soaCount = 1;
while (index < history.Count)
{
if (history[index].Type == DnsResourceRecordType.SOA)
{
soaCount++;
if (soaCount == 3)
break;
}
index++;
}
}
if (index == history.Count)
{
//delete entire history
history.Clear();
return;
}
//remove expired records
history.RemoveRange(0, index);
}
protected bool SetRecords(DnsResourceRecordType type, IReadOnlyList<DnsResourceRecord> records, out IReadOnlyList<DnsResourceRecord> deletedRecords)
{
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
@@ -296,91 +176,6 @@ namespace DnsServerCore.Dns.Zones
#region public
public async Task<IReadOnlyList<NameServerAddress>> GetPrimaryNameServerAddressesAsync(DnsServer dnsServer)
{
DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0];
IReadOnlyList<NameServerAddress> primaryNameServers = soaRecord.GetPrimaryNameServers();
if (primaryNameServers.Count > 0)
{
List<NameServerAddress> resolvedNameServers = new List<NameServerAddress>(primaryNameServers.Count * 2);
foreach (NameServerAddress nameServer in primaryNameServers)
{
if (nameServer.IPEndPoint is null)
{
await ResolveNameServerAddressesAsync(dnsServer, nameServer.Host, nameServer.Port, nameServer.Protocol, resolvedNameServers);
}
else
{
resolvedNameServers.Add(nameServer);
}
}
return resolvedNameServers;
}
string primaryNameServer = (soaRecord.RDATA as DnsSOARecord).PrimaryNameServer;
IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords
List<NameServerAddress> nameServers = new List<NameServerAddress>(nsRecords.Count * 2);
foreach (DnsResourceRecord nsRecord in nsRecords)
{
if (nsRecord.IsDisabled())
continue;
if (primaryNameServer.Equals((nsRecord.RDATA as DnsNSRecord).NameServer, StringComparison.OrdinalIgnoreCase))
{
//found primary NS
await ResolveNameServerAddressesAsync(dnsServer, nsRecord, nameServers);
break;
}
}
if (nameServers.Count < 1)
await ResolveNameServerAddressesAsync(dnsServer, primaryNameServer, 53, DnsTransportProtocol.Udp, nameServers);
return nameServers;
}
public async Task<IReadOnlyList<NameServerAddress>> GetSecondaryNameServerAddressesAsync(DnsServer dnsServer)
{
string primaryNameServer = (_entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord).PrimaryNameServer;
IReadOnlyList<DnsResourceRecord> nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords
List<NameServerAddress> nameServers = new List<NameServerAddress>(nsRecords.Count * 2);
foreach (DnsResourceRecord nsRecord in nsRecords)
{
if (nsRecord.IsDisabled())
continue;
if (primaryNameServer.Equals((nsRecord.RDATA as DnsNSRecord).NameServer, StringComparison.OrdinalIgnoreCase))
continue; //skip primary name server
await ResolveNameServerAddressesAsync(dnsServer, nsRecord, nameServers);
}
return nameServers;
}
public async Task<IReadOnlyList<NameServerAddress>> GetAllNameServerAddressesAsync(DnsServer dnsServer)
{
IReadOnlyList<NameServerAddress> primaryNameServers = await GetPrimaryNameServerAddressesAsync(dnsServer);
IReadOnlyList<NameServerAddress> secondaryNameServers = await GetSecondaryNameServerAddressesAsync(dnsServer);
if (secondaryNameServers.Count < 1)
return primaryNameServers;
List<NameServerAddress> allNameServers = new List<NameServerAddress>(primaryNameServers.Count + secondaryNameServers.Count);
allNameServers.AddRange(primaryNameServers);
allNameServers.AddRange(secondaryNameServers);
return allNameServers;
}
public void SyncRecords(Dictionary<DnsResourceRecordType, List<DnsResourceRecord>> newEntries)
{
//remove entires of type that do not exists in new entries
@@ -579,16 +374,8 @@ namespace DnsServerCore.Dns.Zones
AddRecord(newRecord);
}
public virtual IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type)
public virtual IReadOnlyList<DnsResourceRecord> QueryRecords(DnsResourceRecordType type, bool dnssecOk)
{
//check for CNAME
if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
{
IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingCNAMERecords);
if (filteredRecords.Count > 0)
return filteredRecords;
}
if (type == DnsResourceRecordType.ANY)
{
List<DnsResourceRecord> records = new List<DnsResourceRecord>(_entries.Count * 2);
@@ -611,11 +398,29 @@ namespace DnsServerCore.Dns.Zones
return FilterDisabledRecords(type, records);
}
//check for CNAME
if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList<DnsResourceRecord> existingCNAMERecords))
{
IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingCNAMERecords);
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AddRRSIGs(filteredRecords);
return filteredRecords;
}
}
if (_entries.TryGetValue(type, out IReadOnlyList<DnsResourceRecord> existingRecords))
{
IReadOnlyList<DnsResourceRecord> filteredRecords = FilterDisabledRecords(type, existingRecords);
if (filteredRecords.Count > 0)
{
if (dnssecOk)
return AddRRSIGs(filteredRecords);
return filteredRecords;
}
}
switch (type)
@@ -665,42 +470,6 @@ namespace DnsServerCore.Dns.Zones
set { _disabled = value; }
}
public virtual AuthZoneTransfer ZoneTransfer
{
get { return _zoneTransfer; }
set { _zoneTransfer = value; }
}
public IReadOnlyCollection<IPAddress> ZoneTransferNameServers
{
get { return _zoneTransferNameServers; }
set
{
if ((value is not null) && (value.Count > byte.MaxValue))
throw new ArgumentOutOfRangeException(nameof(ZoneTransferNameServers), "Name server addresses cannot be more than 255.");
_zoneTransferNameServers = value;
}
}
public virtual AuthZoneNotify Notify
{
get { return _notify; }
set { _notify = value; }
}
public IReadOnlyCollection<IPAddress> NotifyNameServers
{
get { return _notifyNameServers; }
set
{
if ((value is not null) && (value.Count > byte.MaxValue))
throw new ArgumentOutOfRangeException(nameof(NotifyNameServers), "Name server addresses cannot be more than 255.");
_notifyNameServers = value;
}
}
public virtual bool IsActive
{
get { return !_disabled; }