diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index de7b62ea..7faf5d7c 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -1,6 +1,6 @@ /* Technitium DNS Server -Copyright (C) 2020 Shreyas Zare (shreyas@technitium.com) +Copyright (C) 2021 Shreyas Zare (shreyas@technitium.com) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -17,7 +17,6 @@ along with this program. If not, see . */ -using DnsServerCore.Dns.ResourceRecords; using DnsServerCore.Dns.ZoneManagers; using DnsServerCore.Dns.Zones; using Newtonsoft.Json; @@ -116,10 +115,11 @@ namespace DnsServerCore.Dns Timer _cachePrefetchSamplingTimer; readonly object _cachePrefetchSamplingTimerLock = new object(); + const int CACHE_PREFETCH_SAMPLING_TIMER_INITIAL_INTEVAL = 5000; Timer _cachePrefetchRefreshTimer; readonly object _cachePrefetchRefreshTimerLock = new object(); - const int CACHE_PREFETCH_REFRESH_TIMER_INITIAL_INTEVAL = 60000; + const int CACHE_PREFETCH_REFRESH_TIMER_INITIAL_INTEVAL = 10000; DateTime _cachePrefetchSamplingTimerTriggersOn; IList _cacheRefreshSampleList; @@ -1062,7 +1062,7 @@ namespace DnsServerCore.Dns break; case DnsResourceRecordType.FWD: - if ((response.Authority.Count == 1) && (response.Authority[0].Type == DnsResourceRecordType.FWD) && (response.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) + if ((response.Authority.Count == 1) && (response.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { //do conditional forwarding via "this-server" return ProcessRecursiveQueryAsync(request, null, null, !inAllowedZone, false); @@ -1817,93 +1817,94 @@ namespace DnsServerCore.Dns List cacheRefreshSampleList = new List(eligibleQueries.Count); int cacheRefreshTrigger = (_cachePrefetchSampleIntervalInMinutes + 1) * 60; - foreach (KeyValuePair query in eligibleQueries) + foreach (KeyValuePair eligibleQuery in eligibleQueries) { + if (eligibleQuery.Key.Type == DnsResourceRecordType.ANY) + continue; //dont refresh type ANY queries + + DnsQuestionRecord refreshQuery; + IReadOnlyList viaNameServers = null; IReadOnlyList viaForwarders = null; - AuthZoneInfo zoneInfo = _authZoneManager.GetAuthZoneInfo(query.Key.Name); - if ((zoneInfo != null) && !zoneInfo.Disabled) + //query auth zone for refresh query + DnsDatagram response = _authZoneManager.Query(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, false, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { eligibleQuery.Key })); + + if (response.RCODE == DnsResponseCode.NoError) { - switch (zoneInfo.Type) + if (response.Answer.Count > 0) { - case AuthZoneType.Primary: - case AuthZoneType.Secondary: - //zone is hosted - continue; //no cache refresh for hosted zones + DnsResourceRecord lastRecord = response.Answer[response.Answer.Count - 1]; - case AuthZoneType.Stub: //stub zone refresh via its name servers - { - IReadOnlyList nsRecords = zoneInfo.GetRecords(DnsResourceRecordType.NS); + if ((lastRecord.Type == DnsResourceRecordType.CNAME) && (eligibleQuery.Key.Type != DnsResourceRecordType.CNAME)) + { + //refresh CNAME + refreshQuery = GetCacheRefreshNeededQuery(new DnsQuestionRecord((lastRecord.RDATA as DnsCNAMERecord).Domain, eligibleQuery.Key.Type, eligibleQuery.Key.Class), cacheRefreshTrigger); + } + else + { + //dont refresh; zone is hosted + continue; + } + } + else if (response.Authority.Count > 0) + { + switch (response.Authority[0].Type) + { + case DnsResourceRecordType.NS: + refreshQuery = GetCacheRefreshNeededQuery(eligibleQuery.Key, cacheRefreshTrigger); - List nameServers = new List(); + if ((refreshQuery != null) && refreshQuery.Equals(eligibleQuery.Key)) + viaNameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false); - foreach (DnsResourceRecord nsRecord in nsRecords) + break; + + case DnsResourceRecordType.FWD: + refreshQuery = GetCacheRefreshNeededQuery(eligibleQuery.Key, cacheRefreshTrigger); + + if ((response.Authority.Count == 1) && (response.Authority[0].RDATA as DnsForwarderRecord).Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) { - string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer; + //do conditional forwarding via "this-server" + } + else + { + //do conditional forwarding + List forwarders = new List(response.Authority.Count); - IReadOnlyList glueRecords = nsRecord.GetGlueRecords(); - if (glueRecords.Count > 0) + foreach (DnsResourceRecord rr in response.Authority) { - foreach (DnsResourceRecord glueRecord in glueRecords) + if (rr.Type == DnsResourceRecordType.FWD) { - switch (glueRecord.Type) - { - case DnsResourceRecordType.A: - nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsARecord).Address)); - break; + DnsForwarderRecord fwd = rr.RDATA as DnsForwarderRecord; - case DnsResourceRecordType.AAAA: - if (_preferIPv6) - nameServers.Add(new NameServerAddress(nsDomain, (glueRecord.RDATA as DnsAAAARecord).Address)); - - break; - } + if (!fwd.Forwarder.Equals("this-server", StringComparison.OrdinalIgnoreCase)) + forwarders.Add(fwd.NameServer); } } - else - { - nameServers.Add(new NameServerAddress(nsDomain)); - } + + if (forwarders.Count > 0) + viaForwarders = forwarders; } + break; - viaNameServers = nameServers; - } - break; - - case AuthZoneType.Forwarder: //forwarder zone refresh via its forwarders - { - IReadOnlyList fwdRecords = zoneInfo.GetRecords(DnsResourceRecordType.FWD); - - List forwarders = new List(); - - foreach (DnsResourceRecord fwdRecord in fwdRecords) - { - if (!fwdRecord.IsDisabled()) - forwarders.Add((fwdRecord.RDATA as DnsForwarderRecord).NameServer); - } - - viaForwarders = forwarders; - } - break; + default: + //dont refresh; invalid response + continue; + } } - } - - if (query.Key.Type == DnsResourceRecordType.ANY) - continue; //dont refresh ANY queries - - DnsQuestionRecord refreshQuery = GetCacheRefreshNeededQuery(query.Key, cacheRefreshTrigger); - if (refreshQuery != null) - { - if ((viaNameServers != null) && !refreshQuery.Name.Equals(query.Key.Name, StringComparison.OrdinalIgnoreCase)) + else { - //stub zone case where refresh query is a CNAME of the original query - if (!refreshQuery.Name.Equals(zoneInfo.Name, StringComparison.OrdinalIgnoreCase) && !refreshQuery.Name.EndsWith("." + zoneInfo.Name, StringComparison.OrdinalIgnoreCase)) - viaNameServers = null; //refresh query is a CNAME that is outside of the stub zone so do usual recursive resolution + //dont refresh; invalid response + continue; } - - cacheRefreshSampleList.Add(new CacheRefreshSample(refreshQuery, viaNameServers, viaForwarders)); } + else + { + refreshQuery = GetCacheRefreshNeededQuery(eligibleQuery.Key, cacheRefreshTrigger); + } + + if (refreshQuery != null) + cacheRefreshSampleList.Add(new CacheRefreshSample(refreshQuery, viaNameServers, viaForwarders)); } _cacheRefreshSampleList = cacheRefreshSampleList; @@ -1980,7 +1981,7 @@ namespace DnsServerCore.Dns private void ResetPrefetchTimers() { - if (_cachePrefetchTrigger == 0) + if ((_cachePrefetchTrigger == 0) || !_allowRecursion) { lock (_cachePrefetchSamplingTimerLock) { @@ -2000,8 +2001,8 @@ namespace DnsServerCore.Dns { if (_cachePrefetchSamplingTimer != null) { - _cachePrefetchSamplingTimer.Change(_cachePrefetchSampleIntervalInMinutes * 60 * 1000, Timeout.Infinite); - _cachePrefetchSamplingTimerTriggersOn = DateTime.UtcNow.AddMinutes(_cachePrefetchSampleIntervalInMinutes); + _cachePrefetchSamplingTimer.Change(CACHE_PREFETCH_SAMPLING_TIMER_INITIAL_INTEVAL, Timeout.Infinite); + _cachePrefetchSamplingTimerTriggersOn = DateTime.UtcNow.AddMilliseconds(CACHE_PREFETCH_SAMPLING_TIMER_INITIAL_INTEVAL); } } @@ -2477,7 +2478,11 @@ namespace DnsServerCore.Dns public bool AllowRecursion { get { return _allowRecursion; } - set { _allowRecursion = value; } + set + { + _allowRecursion = value; + ResetPrefetchTimers(); + } } public bool AllowRecursionOnlyForPrivateNetworks