DnsServer: allowed zone added to override blocked zone. retries value set to 1 & timeout parameter added. Recursive query lock implemented to limit only one thread perform recursive resolution for a given domain & type query. Blocked zone change added to never return NameError. Block zone object property set method implemented to allow swapping new zone.

This commit is contained in:
Shreyas Zare
2018-10-05 23:52:43 +05:30
parent 2863e76a99
commit deebb64a2c

View File

@@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Net;
@@ -59,7 +60,8 @@ namespace DnsServerCore
readonly Zone _authoritativeZoneRoot = new Zone(true);
readonly Zone _cacheZoneRoot = new Zone(false);
readonly Zone _blockedZoneRoot = new Zone(true);
readonly Zone _allowedZoneRoot = new Zone(true);
Zone _blockedZoneRoot = new Zone(true);
readonly IDnsCache _dnsCache;
@@ -69,12 +71,15 @@ namespace DnsServerCore
NameServerAddress[] _forwarders;
DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp;
bool _preferIPv6 = false;
int _retries = 2;
int _retries = 1;
int _timeout = 2000;
int _maxStackCount = 10;
LogManager _log;
LogManager _queryLog;
StatsManager _stats;
readonly ConcurrentDictionary<DnsQuestionRecord, object> _recursiveQueryLocks = new ConcurrentDictionary<DnsQuestionRecord, object>();
volatile ServiceState _state = ServiceState.Stopped;
#endregion
@@ -394,19 +399,57 @@ namespace DnsServerCore
try
{
//query authoritative zone
DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, isRecursionAllowed);
if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !isRecursionAllowed)
return authoritativeResponse;
//query blocked zone
DnsDatagram blockedResponse = _blockedZoneRoot.Query(request);
if (blockedResponse.Header.RCODE != DnsResponseCode.Refused)
{
blockedResponse.Tag = "blocked";
return blockedResponse;
//query allowed zone
DnsDatagram allowedResponse = _allowedZoneRoot.Query(request);
if (allowedResponse.Header.RCODE == DnsResponseCode.Refused)
{
//request domain not in allowed zone
if (blockedResponse.Header.RCODE == DnsResponseCode.NameError)
{
DnsResourceRecord[] answer;
DnsResourceRecord[] authority;
switch (blockedResponse.Question[0].Type)
{
case DnsResourceRecordType.A:
answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedResponse.Question[0].Name, DnsResourceRecordType.A, blockedResponse.Question[0].Class, 60, new DnsARecord(IPAddress.Any)) };
authority = new DnsResourceRecord[] { };
break;
case DnsResourceRecordType.AAAA:
answer = new DnsResourceRecord[] { new DnsResourceRecord(blockedResponse.Question[0].Name, DnsResourceRecordType.AAAA, blockedResponse.Question[0].Class, 60, new DnsAAAARecord(IPAddress.IPv6Any)) };
authority = new DnsResourceRecord[] { };
break;
default:
answer = blockedResponse.Answer;
authority = blockedResponse.Authority;
break;
}
blockedResponse = new DnsDatagram(new DnsHeader(blockedResponse.Header.Identifier, true, blockedResponse.Header.OPCODE, false, false, blockedResponse.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, blockedResponse.Header.QDCOUNT, (ushort)answer.Length, (ushort)authority.Length, 0), blockedResponse.Question, answer, authority, null);
}
//return blocked response
blockedResponse.Tag = "blocked";
return blockedResponse;
}
}
//do recursive query
return ProcessRecursiveQuery(request);
}
catch (Exception ex)
@@ -516,19 +559,7 @@ namespace DnsServerCore
private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, NameServerAddress[] viaNameServers = null)
{
DnsClientProtocol protocol;
if (_forwarders == null)
{
protocol = DnsClient.RecursiveResolveDefaultProtocol;
}
else
{
viaNameServers = _forwarders; //forwarder has higher weightage
protocol = _forwarderProtocol;
}
DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount);
DnsDatagram response = RecursiveResolve(request, viaNameServers);
DnsResourceRecord[] authority;
@@ -553,7 +584,7 @@ namespace DnsServerCore
else
question = new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN);
lastResponse = DnsClient.ResolveViaNameServers(question, _forwarders, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount);
lastResponse = RecursiveResolve(question, _forwarders);
if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0))
break;
@@ -586,6 +617,87 @@ namespace DnsServerCore
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { });
}
private DnsDatagram RecursiveResolve(DnsQuestionRecord questionRecord, NameServerAddress[] viaNameServers)
{
return RecursiveResolve(new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { questionRecord }, null, null, null), viaNameServers);
}
private DnsDatagram RecursiveResolve(DnsDatagram request, NameServerAddress[] viaNameServers)
{
//query cache zone to see if answer available
{
DnsDatagram cacheResponse = _cacheZoneRoot.Query(request);
if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
{
if (cacheResponse.Answer.Length > 0)
return cacheResponse;
else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))
return cacheResponse;
}
}
//recursion with locking
object newLockObj = new object();
object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj);
if (!actualLockObj.Equals(newLockObj))
{
//question already being recursively resolved by another thread, wait till timeout or pulse signal
bool waitTimeout;
lock (actualLockObj)
{
waitTimeout = !Monitor.Wait(actualLockObj, _timeout);
}
if (!waitTimeout)
{
//query cache zone again to see if answer available
DnsDatagram cacheResponse = _cacheZoneRoot.Query(request);
if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
{
if (cacheResponse.Answer.Length > 0)
return cacheResponse;
else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))
return cacheResponse;
}
}
//wait timeout or no response available in cache so respond with server failure
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null);
}
DnsClientProtocol protocol;
if (_forwarders == null)
{
protocol = DnsClient.RecursiveResolveDefaultProtocol;
}
else
{
viaNameServers = _forwarders; //forwarder has higher weightage
protocol = _forwarderProtocol;
}
try
{
return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout);
}
finally
{
//remove question lock
_recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj);
//pulse all waiting threads
lock (newLockObj)
{
Monitor.PulseAll(newLockObj);
}
}
}
#endregion
#region public
@@ -665,8 +777,23 @@ namespace DnsServerCore
public Zone CacheZoneRoot
{ get { return _cacheZoneRoot; } }
public Zone AllowedZoneRoot
{ get { return _allowedZoneRoot; } }
public Zone BlockedZoneRoot
{ get { return _blockedZoneRoot; } }
{
get { return _blockedZoneRoot; }
set
{
if (value == null)
throw new NullReferenceException();
if (!value.IsAuthoritative)
throw new ArgumentException("Blocked zone must be authoritative.");
_blockedZoneRoot = value;
}
}
internal IDnsCache Cache
{ get { return _dnsCache; } }
@@ -734,6 +861,12 @@ namespace DnsServerCore
set { _retries = value; }
}
public int Timeout
{
get { return _timeout; }
set { _timeout = value; }
}
public int MaxStackCount
{
get { return _maxStackCount; }