diff --git a/DnsServerCore/Dns/Zones/CacheZoneManager.cs b/DnsServerCore/Dns/Zones/CacheZoneManager.cs new file mode 100644 index 00000000..bd9ead33 --- /dev/null +++ b/DnsServerCore/Dns/Zones/CacheZoneManager.cs @@ -0,0 +1,212 @@ +using System; +using System.Collections.Generic; +using TechnitiumLibrary.Net.Dns; +using TechnitiumLibrary.Net.Dns.ResourceRecords; + +namespace DnsServerCore.Dns.Zones +{ + public class CacheZoneManager : DnsCache + { + #region variables + + const uint NEGATIVE_RECORD_TTL = 300u; + const uint MINIMUM_RECORD_TTL = 10u; + const uint SERVE_STALE_TTL = 7 * 24 * 60 * 60; //7 days serve stale ttl as per draft-ietf-dnsop-serve-stale-04 + + readonly protected ZoneTree _root = new ZoneTree(); + + #endregion + + #region constructor + + public CacheZoneManager() + : base(NEGATIVE_RECORD_TTL, MINIMUM_RECORD_TTL, SERVE_STALE_TTL) + { } + + #endregion + + #region protected + + protected override void CacheRecords(IReadOnlyList resourceRecords) + { + if (resourceRecords.Count == 1) + { + CacheZone zone = _root.GetOrAdd(resourceRecords[0].Name, delegate (string key) + { + return new CacheZone(resourceRecords[0].Name); + }); + + zone.SetRecords(resourceRecords[0].Type, resourceRecords); + } + else + { + Dictionary>> groupedByDomainRecords = DnsResourceRecord.GroupRecords(resourceRecords); + + //add grouped records + foreach (KeyValuePair>> groupedByTypeRecords in groupedByDomainRecords) + { + CacheZone zone = _root.GetOrAdd(groupedByTypeRecords.Key, delegate (string key) + { + return new CacheZone(groupedByTypeRecords.Key); + }); + + foreach (KeyValuePair> groupedRecords in groupedByTypeRecords.Value) + zone.SetRecords(groupedRecords.Key, groupedRecords.Value); + } + } + } + + #endregion + + #region private + + private List GetAdditionalRecords(IReadOnlyCollection nsRecords, bool serveStale) + { + List additionalRecords = new List(); + + foreach (DnsResourceRecord nsRecord in nsRecords) + { + if (nsRecord.Type != DnsResourceRecordType.NS) + continue; + + CacheZone cacheZone = _root.FindZone((nsRecord.RDATA as DnsNSRecord).NSDomainName, out _, out _, out _); + if (cacheZone != null) + { + { + IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.A, serveStale); + if ((records.Count > 0) && (records[0].RDATA is DnsARecord)) + additionalRecords.AddRange(records); + } + + { + IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.AAAA, serveStale); + if ((records.Count > 0) && (records[0].RDATA is DnsAAAARecord)) + additionalRecords.AddRange(records); + } + } + } + + return additionalRecords; + } + + #endregion + + #region public + + public void DoMaintenance() + { + foreach (CacheZone zone in _root) + { + zone.RemoveExpiredRecords(); + + if (zone.IsEmpty) + _root.TryRemove(zone.Name, out _); //remove empty zone + } + } + + public void Flush() + { + _root.Clear(); + } + + public bool DeleteZone(string domain) + { + return _root.TryRemove(domain, out _); + } + + public List ListSubDomains(string domain) + { + return _root.ListSubDomains(domain); + } + + public List ListAllRecords(string domain) + { + if (_root.TryGet(domain, out CacheZone zone)) + return zone.ListAllRecords(); + + return new List(0); + } + + public DnsDatagram QueryClosestDelegation(DnsDatagram request) + { + _ = _root.FindZone(request.Question[0].Name, out CacheZone delegation, out _, out _); + if (delegation == null) + { + //no cached delegation found + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.Refused, request.Question); + } + + //return closest name servers in delegation + IReadOnlyList authority = delegation.QueryRecords(DnsResourceRecordType.NS, false); + List additional = GetAdditionalRecords(authority, false); + + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, authority, additional); + } + + public override DnsDatagram Query(DnsDatagram request, bool serveStale = false) + { + CacheZone zone = _root.FindZone(request.Question[0].Name, out CacheZone delegation, out _, out _); + if (zone == null) + { + //zone not found + if (delegation == null) + { + //no cached delegation found + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.Refused, request.Question); + } + + //return closest name servers in delegation + IReadOnlyList authority = delegation.QueryRecords(DnsResourceRecordType.NS, serveStale); + List additional = GetAdditionalRecords(authority, serveStale); + + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, authority, additional); + } + + //zone found + IReadOnlyList answers = zone.QueryRecords(request.Question[0].Type, serveStale); + if (answers.Count > 0) + { + if (answers[0].RDATA is DnsEmptyRecord) + { + DnsResourceRecord[] authority = null; + DnsResourceRecord soaRecord = (answers[0].RDATA as DnsEmptyRecord).Authority; + if (soaRecord != null) + authority = new DnsResourceRecord[] { soaRecord }; + + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, null, authority); + } + + if (answers[0].RDATA is DnsNXRecord) + { + DnsResourceRecord[] authority = null; + DnsResourceRecord soaRecord = (answers[0].RDATA as DnsNXRecord).Authority; + if (soaRecord != null) + authority = new DnsResourceRecord[] { soaRecord }; + + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NameError, request.Question, null, authority); + } + + if (answers[0].RDATA is DnsANYRecord) + { + DnsANYRecord anyRR = answers[0].RDATA as DnsANYRecord; + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, anyRR.Records); + } + + if (answers[0].RDATA is DnsFailureRecord) + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, (answers[0].RDATA as DnsFailureRecord).RCODE, request.Question); + + IReadOnlyList additional = null; + + if (request.Question[0].Type == DnsResourceRecordType.NS) + additional = GetAdditionalRecords(answers, serveStale); + + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, answers, null, additional); + } + + //found nothing in cache + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.Refused, request.Question); + } + + #endregion + } +}