diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs
index 64577eb0..79b256bd 100644
--- a/DnsServerCore/DnsServer.cs
+++ b/DnsServerCore/DnsServer.cs
@@ -18,6 +18,7 @@ along with this program. If not, see .
*/
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
@@ -59,7 +60,8 @@ namespace DnsServerCore
readonly Zone _authoritativeZoneRoot = new Zone(true);
readonly Zone _cacheZoneRoot = new Zone(false);
- readonly Zone _blockedZoneRoot = new Zone(true);
+ readonly Zone _allowedZoneRoot = new Zone(true);
+ Zone _blockedZoneRoot = new Zone(true);
readonly IDnsCache _dnsCache;
@@ -69,12 +71,15 @@ namespace DnsServerCore
NameServerAddress[] _forwarders;
DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp;
bool _preferIPv6 = false;
- int _retries = 2;
+ int _retries = 1;
+ int _timeout = 2000;
int _maxStackCount = 10;
LogManager _log;
LogManager _queryLog;
StatsManager _stats;
+ readonly ConcurrentDictionary _recursiveQueryLocks = new ConcurrentDictionary();
+
volatile ServiceState _state = ServiceState.Stopped;
#endregion
@@ -394,19 +399,57 @@ namespace DnsServerCore
try
{
+ //query authoritative zone
DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, isRecursionAllowed);
if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !isRecursionAllowed)
return authoritativeResponse;
+ //query blocked zone
DnsDatagram blockedResponse = _blockedZoneRoot.Query(request);
if (blockedResponse.Header.RCODE != DnsResponseCode.Refused)
{
- blockedResponse.Tag = "blocked";
- return blockedResponse;
+ //query allowed zone
+ DnsDatagram allowedResponse = _allowedZoneRoot.Query(request);
+
+ if (allowedResponse.Header.RCODE == DnsResponseCode.Refused)
+ {
+ //request domain not in allowed zone
+
+ if (blockedResponse.Header.RCODE == DnsResponseCode.NameError)
+ {
+ DnsResourceRecord[] answer;
+ DnsResourceRecord[] authority;
+
+ switch (blockedResponse.Question[0].Type)
+ {
+ case DnsResourceRecordType.A:
+ answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedResponse.Question[0].Name, DnsResourceRecordType.A, blockedResponse.Question[0].Class, 60, new DnsARecord(IPAddress.Any)) };
+ authority = new DnsResourceRecord[] { };
+ break;
+
+ case DnsResourceRecordType.AAAA:
+ answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedResponse.Question[0].Name, DnsResourceRecordType.AAAA, blockedResponse.Question[0].Class, 60, new DnsAAAARecord(IPAddress.IPv6Any)) };
+ authority = new DnsResourceRecord[] { };
+ break;
+
+ default:
+ answer = blockedResponse.Answer;
+ authority = blockedResponse.Authority;
+ break;
+ }
+
+ blockedResponse = new DnsDatagram(new DnsHeader(blockedResponse.Header.Identifier, true, blockedResponse.Header.OPCODE, false, false, blockedResponse.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, blockedResponse.Header.QDCOUNT, (ushort)answer.Length, (ushort)authority.Length, 0), blockedResponse.Question, answer, authority, null);
+ }
+
+ //return blocked response
+ blockedResponse.Tag = "blocked";
+ return blockedResponse;
+ }
}
+ //do recursive query
return ProcessRecursiveQuery(request);
}
catch (Exception ex)
@@ -516,19 +559,7 @@ namespace DnsServerCore
private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, NameServerAddress[] viaNameServers = null)
{
- DnsClientProtocol protocol;
-
- if (_forwarders == null)
- {
- protocol = DnsClient.RecursiveResolveDefaultProtocol;
- }
- else
- {
- viaNameServers = _forwarders; //forwarder has higher weightage
- protocol = _forwarderProtocol;
- }
-
- DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount);
+ DnsDatagram response = RecursiveResolve(request, viaNameServers);
DnsResourceRecord[] authority;
@@ -553,7 +584,7 @@ namespace DnsServerCore
else
question = new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN);
- lastResponse = DnsClient.ResolveViaNameServers(question, _forwarders, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount);
+ lastResponse = RecursiveResolve(question, _forwarders);
if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0))
break;
@@ -586,6 +617,87 @@ namespace DnsServerCore
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { });
}
+ private DnsDatagram RecursiveResolve(DnsQuestionRecord questionRecord, NameServerAddress[] viaNameServers)
+ {
+ return RecursiveResolve(new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { questionRecord }, null, null, null), viaNameServers);
+ }
+
+ private DnsDatagram RecursiveResolve(DnsDatagram request, NameServerAddress[] viaNameServers)
+ {
+ //query cache zone to see if answer available
+ {
+ DnsDatagram cacheResponse = _cacheZoneRoot.Query(request);
+
+ if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
+ {
+ if (cacheResponse.Answer.Length > 0)
+ return cacheResponse;
+ else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))
+ return cacheResponse;
+ }
+ }
+
+ //recursion with locking
+ object newLockObj = new object();
+ object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj);
+
+ if (!actualLockObj.Equals(newLockObj))
+ {
+ //question already being recursively resolved by another thread, wait till timeout or pulse signal
+ bool waitTimeout;
+
+ lock (actualLockObj)
+ {
+ waitTimeout = !Monitor.Wait(actualLockObj, _timeout);
+ }
+
+ if (!waitTimeout)
+ {
+ //query cache zone again to see if answer available
+ DnsDatagram cacheResponse = _cacheZoneRoot.Query(request);
+
+ if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
+ {
+ if (cacheResponse.Answer.Length > 0)
+ return cacheResponse;
+ else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))
+ return cacheResponse;
+ }
+ }
+
+ //wait timeout or no response available in cache so respond with server failure
+ return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null);
+ }
+
+ DnsClientProtocol protocol;
+
+ if (_forwarders == null)
+ {
+ protocol = DnsClient.RecursiveResolveDefaultProtocol;
+ }
+ else
+ {
+ viaNameServers = _forwarders; //forwarder has higher weightage
+ protocol = _forwarderProtocol;
+ }
+
+ try
+ {
+ return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout);
+ }
+ finally
+ {
+ //remove question lock
+ _recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj);
+
+ //pulse all waiting threads
+ lock (newLockObj)
+ {
+ Monitor.PulseAll(newLockObj);
+ }
+ }
+ }
+
#endregion
#region public
@@ -665,8 +777,23 @@ namespace DnsServerCore
public Zone CacheZoneRoot
{ get { return _cacheZoneRoot; } }
+ public Zone AllowedZoneRoot
+ { get { return _allowedZoneRoot; } }
+
public Zone BlockedZoneRoot
- { get { return _blockedZoneRoot; } }
+ {
+ get { return _blockedZoneRoot; }
+ set
+ {
+ if (value == null)
+ throw new NullReferenceException();
+
+ if (!value.IsAuthoritative)
+ throw new ArgumentException("Blocked zone must be authoritative.");
+
+ _blockedZoneRoot = value;
+ }
+ }
internal IDnsCache Cache
{ get { return _dnsCache; } }
@@ -734,6 +861,12 @@ namespace DnsServerCore
set { _retries = value; }
}
+ public int Timeout
+ {
+ get { return _timeout; }
+ set { _timeout = value; }
+ }
+
public int MaxStackCount
{
get { return _maxStackCount; }