diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs index 64577eb0..79b256bd 100644 --- a/DnsServerCore/DnsServer.cs +++ b/DnsServerCore/DnsServer.cs @@ -18,6 +18,7 @@ along with this program. If not, see . */ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Net; @@ -59,7 +60,8 @@ namespace DnsServerCore readonly Zone _authoritativeZoneRoot = new Zone(true); readonly Zone _cacheZoneRoot = new Zone(false); - readonly Zone _blockedZoneRoot = new Zone(true); + readonly Zone _allowedZoneRoot = new Zone(true); + Zone _blockedZoneRoot = new Zone(true); readonly IDnsCache _dnsCache; @@ -69,12 +71,15 @@ namespace DnsServerCore NameServerAddress[] _forwarders; DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp; bool _preferIPv6 = false; - int _retries = 2; + int _retries = 1; + int _timeout = 2000; int _maxStackCount = 10; LogManager _log; LogManager _queryLog; StatsManager _stats; + readonly ConcurrentDictionary _recursiveQueryLocks = new ConcurrentDictionary(); + volatile ServiceState _state = ServiceState.Stopped; #endregion @@ -394,19 +399,57 @@ namespace DnsServerCore try { + //query authoritative zone DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, isRecursionAllowed); if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !isRecursionAllowed) return authoritativeResponse; + //query blocked zone DnsDatagram blockedResponse = _blockedZoneRoot.Query(request); if (blockedResponse.Header.RCODE != DnsResponseCode.Refused) { - blockedResponse.Tag = "blocked"; - return blockedResponse; + //query allowed zone + DnsDatagram allowedResponse = _allowedZoneRoot.Query(request); + + if (allowedResponse.Header.RCODE == DnsResponseCode.Refused) + { + //request domain not in allowed zone + + if (blockedResponse.Header.RCODE == DnsResponseCode.NameError) + { + DnsResourceRecord[] answer; + DnsResourceRecord[] authority; + + switch (blockedResponse.Question[0].Type) + { + case DnsResourceRecordType.A: + answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedResponse.Question[0].Name, DnsResourceRecordType.A, blockedResponse.Question[0].Class, 60, new DnsARecord(IPAddress.Any)) }; + authority = new DnsResourceRecord[] { }; + break; + + case DnsResourceRecordType.AAAA: + answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedResponse.Question[0].Name, DnsResourceRecordType.AAAA, blockedResponse.Question[0].Class, 60, new DnsAAAARecord(IPAddress.IPv6Any)) }; + authority = new DnsResourceRecord[] { }; + break; + + default: + answer = blockedResponse.Answer; + authority = blockedResponse.Authority; + break; + } + + blockedResponse = new DnsDatagram(new DnsHeader(blockedResponse.Header.Identifier, true, blockedResponse.Header.OPCODE, false, false, blockedResponse.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, blockedResponse.Header.QDCOUNT, (ushort)answer.Length, (ushort)authority.Length, 0), blockedResponse.Question, answer, authority, null); + } + + //return blocked response + blockedResponse.Tag = "blocked"; + return blockedResponse; + } } + //do recursive query return ProcessRecursiveQuery(request); } catch (Exception ex) @@ -516,19 +559,7 @@ namespace DnsServerCore private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, NameServerAddress[] viaNameServers = null) { - DnsClientProtocol protocol; - - if (_forwarders == null) - { - protocol = DnsClient.RecursiveResolveDefaultProtocol; - } - else - { - viaNameServers = _forwarders; //forwarder has higher weightage - protocol = _forwarderProtocol; - } - - DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount); + DnsDatagram response = RecursiveResolve(request, viaNameServers); DnsResourceRecord[] authority; @@ -553,7 +584,7 @@ namespace DnsServerCore else question = new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN); - lastResponse = DnsClient.ResolveViaNameServers(question, _forwarders, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount); + lastResponse = RecursiveResolve(question, _forwarders); if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) break; @@ -586,6 +617,87 @@ namespace DnsServerCore return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { }); } + private DnsDatagram RecursiveResolve(DnsQuestionRecord questionRecord, NameServerAddress[] viaNameServers) + { + return RecursiveResolve(new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { questionRecord }, null, null, null), viaNameServers); + } + + private DnsDatagram RecursiveResolve(DnsDatagram request, NameServerAddress[] viaNameServers) + { + //query cache zone to see if answer available + { + DnsDatagram cacheResponse = _cacheZoneRoot.Query(request); + + if (cacheResponse.Header.RCODE != DnsResponseCode.Refused) + { + if (cacheResponse.Answer.Length > 0) + return cacheResponse; + else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + return cacheResponse; + } + } + + //recursion with locking + object newLockObj = new object(); + object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj); + + if (!actualLockObj.Equals(newLockObj)) + { + //question already being recursively resolved by another thread, wait till timeout or pulse signal + bool waitTimeout; + + lock (actualLockObj) + { + waitTimeout = !Monitor.Wait(actualLockObj, _timeout); + } + + if (!waitTimeout) + { + //query cache zone again to see if answer available + DnsDatagram cacheResponse = _cacheZoneRoot.Query(request); + + if (cacheResponse.Header.RCODE != DnsResponseCode.Refused) + { + if (cacheResponse.Answer.Length > 0) + return cacheResponse; + else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + return cacheResponse; + } + } + + //wait timeout or no response available in cache so respond with server failure + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + } + + DnsClientProtocol protocol; + + if (_forwarders == null) + { + protocol = DnsClient.RecursiveResolveDefaultProtocol; + } + else + { + viaNameServers = _forwarders; //forwarder has higher weightage + protocol = _forwarderProtocol; + } + + try + { + return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout); + } + finally + { + //remove question lock + _recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj); + + //pulse all waiting threads + lock (newLockObj) + { + Monitor.PulseAll(newLockObj); + } + } + } + #endregion #region public @@ -665,8 +777,23 @@ namespace DnsServerCore public Zone CacheZoneRoot { get { return _cacheZoneRoot; } } + public Zone AllowedZoneRoot + { get { return _allowedZoneRoot; } } + public Zone BlockedZoneRoot - { get { return _blockedZoneRoot; } } + { + get { return _blockedZoneRoot; } + set + { + if (value == null) + throw new NullReferenceException(); + + if (!value.IsAuthoritative) + throw new ArgumentException("Blocked zone must be authoritative."); + + _blockedZoneRoot = value; + } + } internal IDnsCache Cache { get { return _dnsCache; } } @@ -734,6 +861,12 @@ namespace DnsServerCore set { _retries = value; } } + public int Timeout + { + get { return _timeout; } + set { _timeout = value; } + } + public int MaxStackCount { get { return _maxStackCount; }