diff --git a/DnsServerCore/Dns/Zones/AuthZone.cs b/DnsServerCore/Dns/Zones/AuthZone.cs new file mode 100644 index 00000000..97857eee --- /dev/null +++ b/DnsServerCore/Dns/Zones/AuthZone.cs @@ -0,0 +1,143 @@ +using System; +using System.Collections.Generic; +using TechnitiumLibrary.IO; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace DnsServerCore.Dns.Zones +{ + public abstract class AuthZone : Zone + { + #region variables + + protected bool _disabled; + + #endregion + + #region constructor + + protected AuthZone(string name) + : base(name) + { } + + protected AuthZone(string name, DnsSOARecord soa) + : base(name) + { + _entries[DnsResourceRecordType.SOA] = new DnsResourceRecord[] { new DnsResourceRecord(_name, DnsResourceRecordType.SOA, DnsClass.IN, soa.Refresh, soa) }; + _entries[DnsResourceRecordType.NS] = new DnsResourceRecord[] { new DnsResourceRecord(_name, DnsResourceRecordType.NS, DnsClass.IN, soa.Refresh, new DnsNSRecord(soa.MasterNameServer)) }; + } + + protected AuthZone(string name, DnsSOARecord soa, DnsNSRecord ns) + : base(name) + { + _entries[DnsResourceRecordType.SOA] = new DnsResourceRecord[] { new DnsResourceRecord(_name, DnsResourceRecordType.SOA, DnsClass.IN, soa.Refresh, soa) }; + _entries[DnsResourceRecordType.NS] = new DnsResourceRecord[] { new DnsResourceRecord(_name, DnsResourceRecordType.NS, DnsClass.IN, soa.Refresh, ns) }; + } + + #endregion + + #region private + + private IReadOnlyList FilterDisabledRecords(DnsResourceRecordType type, IReadOnlyList records) + { + if (_disabled) + return Array.Empty(); + + if (records.Count == 1) + { + if (records[0].IsDisabled()) + return Array.Empty(); //record disabled + + return records; + } + + List newRecords = new List(records.Count); + + foreach (DnsResourceRecord record in records) + { + if (record.IsDisabled()) + continue; //record disabled + + newRecords.Add(record); + } + + if (newRecords.Count > 1) + { + switch (type) + { + case DnsResourceRecordType.A: + case DnsResourceRecordType.AAAA: + newRecords.Shuffle(); //shuffle records to allow load balancing + break; + } + } + + return newRecords; + } + + #endregion + + #region public + + public IReadOnlyList QueryRecords(DnsResourceRecordType type) + { + //check for CNAME + if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords)) + { + IReadOnlyList filteredRecords = FilterDisabledRecords(type, existingCNAMERecords); + if (filteredRecords.Count > 0) + return existingCNAMERecords; + } + + if (type == DnsResourceRecordType.ANY) + { + List records = new List(_entries.Count * 2); + + foreach (KeyValuePair> entry in _entries) + { + if (entry.Key != DnsResourceRecordType.ANY) + records.AddRange(entry.Value); + } + + return FilterDisabledRecords(type, records); + } + + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + return FilterDisabledRecords(type, existingRecords); + + return Array.Empty(); + } + + public override bool ContainsNameServerRecords() + { + IReadOnlyList records = QueryRecords(DnsResourceRecordType.NS); + return (records.Count > 0) && (records[0].Type == DnsResourceRecordType.NS); + } + + public bool AreAllRecordsDisabled() + { + foreach (KeyValuePair> entry in _entries) + { + foreach (DnsResourceRecord record in entry.Value) + { + if (!record.IsDisabled()) + return false; + } + } + + return true; + } + + #endregion + + #region properties + + public bool Disabled + { + get { return _disabled; } + set { _disabled = value; } + } + + #endregion + } +} diff --git a/DnsServerCore/Dns/Zones/CacheZone.cs b/DnsServerCore/Dns/Zones/CacheZone.cs new file mode 100644 index 00000000..9f5343a3 --- /dev/null +++ b/DnsServerCore/Dns/Zones/CacheZone.cs @@ -0,0 +1,157 @@ +using System; +using System.Collections.Generic; +using TechnitiumLibrary.IO; +using TechnitiumLibrary.Net.Dns; + +namespace DnsServerCore.Dns.Zones +{ + public sealed class CacheZone : Zone + { + #region constructor + + public CacheZone(string name) + : base(name) + { } + + #endregion + + #region private + + private static IReadOnlyList FilterExpiredRecords(DnsResourceRecordType type, IReadOnlyList records, bool serveStale) + { + if (records.Count == 1) + { + if (!serveStale && records[0].IsStale) + return Array.Empty(); //record is stale + + if (records[0].TtlValue < 1u) + return Array.Empty(); //ttl expired + + return records; + } + + List newRecords = new List(records.Count); + + foreach (DnsResourceRecord record in records) + { + if (!serveStale && record.IsStale) + continue; //record is stale + + if (record.TtlValue < 1u) + continue; //ttl expired + + newRecords.Add(record); + } + + if (newRecords.Count > 1) + { + switch (type) + { + case DnsResourceRecordType.A: + case DnsResourceRecordType.AAAA: + newRecords.Shuffle(); //shuffle records to allow load balancing + break; + } + } + + return newRecords; + } + + #endregion + + #region public + + public override void SetRecords(DnsResourceRecordType type, IReadOnlyList records) + { + if ((records.Count > 0) && (records[0].RDATA is DnsCache.DnsFailureRecord)) + { + //call trying to cache failure record + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + { + if ((existingRecords.Count > 0) && !(existingRecords[0].RDATA is DnsCache.DnsFailureRecord)) + return; //skip to avoid overwriting a useful stale record with a failure record to allow serve-stale to work as intended + } + } + + //set records + base.SetRecords(type, records); + + switch (type) + { + case DnsResourceRecordType.CNAME: + case DnsResourceRecordType.SOA: + case DnsResourceRecordType.NS: + //do nothing + break; + + default: + //remove old CNAME entry since current new entry type overlaps any existing CNAME entry in cache + //keeping both entries will create issue with serve stale implementation since stale CNAME entry will be always returned + _entries.TryRemove(DnsResourceRecordType.CNAME, out _); + break; + } + } + + public void RemoveExpiredRecords() + { + foreach (DnsResourceRecordType type in _entries.Keys) + { + IReadOnlyList records = _entries[type]; + + foreach (DnsResourceRecord record in records) + { + if (record.TtlValue < 1u) + { + //record is expired; update entry + List newRecords = new List(records.Count); + + foreach (DnsResourceRecord existingRecord in records) + { + if (existingRecord.TtlValue < 1u) + continue; + + newRecords.Add(existingRecord); + } + + if (newRecords.Count > 0) + { + //try update entry with non-expired records + _entries.TryUpdate(type, newRecords, records); + } + else + { + //all records expired; remove entry + _entries.TryRemove(type, out _); + } + + break; + } + } + } + } + + public IReadOnlyList QueryRecords(DnsResourceRecordType type, bool serveStale) + { + //check for CNAME + if (_entries.TryGetValue(DnsResourceRecordType.CNAME, out IReadOnlyList existingCNAMERecords)) + { + IReadOnlyList filteredRecords = FilterExpiredRecords(type, existingCNAMERecords, serveStale); + if (filteredRecords.Count > 0) + return existingCNAMERecords; + } + + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + return FilterExpiredRecords(type, existingRecords, serveStale); + + return Array.Empty(); + } + + public override bool ContainsNameServerRecords() + { + IReadOnlyList records = QueryRecords(DnsResourceRecordType.NS, false); + return (records.Count > 0) && (records[0].Type == DnsResourceRecordType.NS); + } + + #endregion + } +} diff --git a/DnsServerCore/Dns/Zones/PrimaryZone.cs b/DnsServerCore/Dns/Zones/PrimaryZone.cs new file mode 100644 index 00000000..adfbc3fd --- /dev/null +++ b/DnsServerCore/Dns/Zones/PrimaryZone.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace DnsServerCore.Dns.Zones +{ + public sealed class PrimaryZone : AuthZone + { + #region variables + + readonly bool _internal; + + #endregion + + #region constructor + + public PrimaryZone(string name, DnsSOARecord soa, bool @internal) + : base(name, soa) + { + _internal = @internal; + } + + public PrimaryZone(string name, DnsSOARecord soa, DnsNSRecord ns, bool @internal) + : base(name, soa, ns) + { + _internal = @internal; + } + + public PrimaryZone(string name, bool disabled) + : base(name) + { + _disabled = disabled; + } + + #endregion + + #region public + + public override void SetRecords(DnsResourceRecordType type, IReadOnlyList records) + { + if (type == DnsResourceRecordType.CNAME) + throw new InvalidOperationException("Cannot add CNAME record to zone root."); + + base.SetRecords(type, records); + } + + #endregion + + #region properties + + public bool Internal + { get { return _internal; } } + + #endregion + } +} diff --git a/DnsServerCore/Dns/Zones/SecondaryZone.cs b/DnsServerCore/Dns/Zones/SecondaryZone.cs new file mode 100644 index 00000000..55c7af35 --- /dev/null +++ b/DnsServerCore/Dns/Zones/SecondaryZone.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace DnsServerCore.Dns.Zones +{ + public sealed class SecondaryZone : AuthZone + { + #region constructor + + public SecondaryZone(string name, DnsSOARecord soa) + : base(name, soa) + { } + + public SecondaryZone(string name, bool disabled) + : base(name) + { + _disabled = disabled; + } + + #endregion + + #region public + + public override void SetRecords(DnsResourceRecordType type, IReadOnlyList records) + { + throw new InvalidOperationException("Cannot set records for secondary zone."); + } + + public override void AddRecord(DnsResourceRecord record) + { + throw new InvalidOperationException("Cannot add record for secondary zone."); + } + + public override bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record) + { + throw new InvalidOperationException("Cannot delete record for secondary zone."); + } + + public override bool DeleteRecords(DnsResourceRecordType type) + { + throw new InvalidOperationException("Cannot delete records for secondary zone."); + } + + #endregion + } +} diff --git a/DnsServerCore/Dns/Zones/StubZone.cs b/DnsServerCore/Dns/Zones/StubZone.cs new file mode 100644 index 00000000..fc509130 --- /dev/null +++ b/DnsServerCore/Dns/Zones/StubZone.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace DnsServerCore.Dns.Zones +{ + public sealed class StubZone : AuthZone + { + #region constructor + + public StubZone(string name, DnsSOARecord soa) + : base(name, soa) + { } + + public StubZone(string name, bool disabled) + : base(name) + { + _disabled = disabled; + } + + #endregion + + #region public + + public override void SetRecords(DnsResourceRecordType type, IReadOnlyList records) + { + throw new InvalidOperationException("Cannot set records for stub zone."); + } + + public override void AddRecord(DnsResourceRecord record) + { + throw new InvalidOperationException("Cannot add record for stub zone."); + } + + public override bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record) + { + throw new InvalidOperationException("Cannot delete record for stub zone."); + } + + public override bool DeleteRecords(DnsResourceRecordType type) + { + throw new InvalidOperationException("Cannot delete records for stub zone."); + } + + #endregion + } +} diff --git a/DnsServerCore/Dns/Zones/SubDomainZone.cs b/DnsServerCore/Dns/Zones/SubDomainZone.cs new file mode 100644 index 00000000..49618502 --- /dev/null +++ b/DnsServerCore/Dns/Zones/SubDomainZone.cs @@ -0,0 +1,13 @@ +namespace DnsServerCore.Dns.Zones +{ + public sealed class SubDomainZone : AuthZone + { + #region constructor + + public SubDomainZone(string name) + : base(name) + { } + + #endregion + } +} diff --git a/DnsServerCore/Dns/Zones/Zone.cs b/DnsServerCore/Dns/Zones/Zone.cs new file mode 100644 index 00000000..eedcce1e --- /dev/null +++ b/DnsServerCore/Dns/Zones/Zone.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using TechnitiumLibrary.Net.Dns; + +namespace DnsServerCore.Dns.Zones +{ + public abstract class Zone + { + #region variables + + protected readonly string _name; + protected readonly ConcurrentDictionary> _entries = new ConcurrentDictionary>(); + + #endregion + + #region constructor + + protected Zone(string name) + { + _name = name; + } + + #endregion + + #region public + + public List ListAllRecords() + { + List records = new List(_entries.Count * 2); + + foreach (KeyValuePair> entry in _entries) + records.AddRange(entry.Value); + + return records; + } + + public virtual void SetRecords(DnsResourceRecordType type, IReadOnlyList records) + { + _entries[type] = records; + } + + public virtual void AddRecord(DnsResourceRecord record) + { + switch (record.Type) + { + case DnsResourceRecordType.CNAME: + case DnsResourceRecordType.PTR: + case DnsResourceRecordType.SOA: + throw new InvalidOperationException("Cannot add record: use SetRecords() for " + record.Type.ToString() + " record"); + } + + _entries.AddOrUpdate(record.Type, delegate (DnsResourceRecordType key) + { + return new DnsResourceRecord[] { record }; + }, + delegate (DnsResourceRecordType key, IReadOnlyList existingRecords) + { + foreach (DnsResourceRecord existingRecord in existingRecords) + { + if (record.Equals(existingRecord.RDATA)) + return existingRecords; + } + + List updateRecords = new List(existingRecords.Count + 1); + + updateRecords.AddRange(existingRecords); + updateRecords.Add(record); + + return updateRecords; + }); + } + + public virtual bool DeleteRecords(DnsResourceRecordType type) + { + return _entries.TryRemove(type, out _); + } + + public virtual bool DeleteRecord(DnsResourceRecordType type, DnsResourceRecordData record) + { + if (_entries.TryGetValue(type, out IReadOnlyList existingRecords)) + { + if (existingRecords.Count == 1) + { + if (record.Equals(existingRecords[0].RDATA)) + return _entries.TryRemove(type, out _); + } + else + { + List updateRecords = new List(existingRecords.Count); + + for (int i = 0; i < existingRecords.Count; i++) + { + if (!record.Equals(existingRecords[i].RDATA)) + updateRecords.Add(existingRecords[i]); + } + + return _entries.TryUpdate(type, updateRecords, existingRecords); + } + } + + return false; + } + + public abstract bool ContainsNameServerRecords(); + + #endregion + + #region properties + + public string Name + { get { return _name; } } + + public bool IsEmpty + { get { return _entries.IsEmpty; } } + + #endregion + } +}