DnsServer: updated TCP timeout defaults for keeping alive existing connection for reuse. Updated default retry setting to 3 for limiting ServerFailure responses. Implemented DNS over TCP recommendations as per RFC 7766. Implemented cache hit feature. Fixed issues with RecursiveResolve() to limit ServerFailure responses.

This commit is contained in:
Shreyas Zare
2018-12-09 17:05:02 +05:30
parent 4c9c22aae9
commit f624707e8b

View File

@@ -47,8 +47,8 @@ namespace DnsServerCore
#region variables
const int UDP_LISTENER_THREAD_COUNT = 3;
const int TCP_SOCKET_SEND_TIMEOUT = 10000;
const int TCP_SOCKET_RECV_TIMEOUT = 60000;
const int TCP_SOCKET_SEND_TIMEOUT = 30000;
const int TCP_SOCKET_RECV_TIMEOUT = 120000;
readonly IPEndPoint _localEP;
@@ -70,8 +70,9 @@ namespace DnsServerCore
NetProxy _proxy;
NameServerAddress[] _forwarders;
DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp;
DnsClientProtocol _recursiveResolveProtocol = DnsClientProtocol.Udp;
bool _preferIPv6 = false;
int _retries = 1;
int _retries = 3;
int _timeout = 2000;
int _maxStackCount = 10;
LogManager _log;
@@ -118,7 +119,7 @@ namespace DnsServerCore
#region private
private void ReadUdpQueryPacketsAsync(object parameter)
private void ReadUdpRequestAsync(object parameter)
{
EndPoint remoteEP;
byte[] recvBuffer = new byte[512];
@@ -245,11 +246,10 @@ namespace DnsServerCore
{
Socket socket = _tcpListener.Accept();
socket.NoDelay = true;
socket.SendTimeout = TCP_SOCKET_SEND_TIMEOUT;
socket.ReceiveTimeout = TCP_SOCKET_RECV_TIMEOUT;
ThreadPool.QueueUserWorkItem(ProcessTcpRequestAsync, socket);
ThreadPool.QueueUserWorkItem(ReadTcpRequestAsync, socket);
}
}
catch (ThreadAbortException)
@@ -267,70 +267,33 @@ namespace DnsServerCore
}
}
private void ProcessTcpRequestAsync(object parameter)
private void ReadTcpRequestAsync(object parameter)
{
Socket tcpSocket = parameter as Socket;
DnsDatagram request = null;
try
{
NetworkStream recvStream = new NetworkStream(tcpSocket);
OffsetStream recvDatagramStream = new OffsetStream(recvStream, 0, 0);
MemoryStream sendBufferStream = null;
byte[] sendBuffer = null;
NetworkStream tcpStream = new NetworkStream(tcpSocket);
OffsetStream recvDatagramStream = new OffsetStream(tcpStream, 0, 0);
MemoryStream sendBufferStream = new MemoryStream(64);
ushort length;
while (true)
{
request = null;
//read dns datagram length
{
byte[] lengthBuffer = recvStream.ReadBytes(2);
Array.Reverse(lengthBuffer, 0, 2);
length = BitConverter.ToUInt16(lengthBuffer, 0);
}
byte[] lengthBuffer = tcpStream.ReadBytes(2);
Array.Reverse(lengthBuffer, 0, 2);
length = BitConverter.ToUInt16(lengthBuffer, 0);
//read dns datagram
recvDatagramStream.Reset(0, length, 0);
request = new DnsDatagram(recvDatagramStream);
DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint);
//send response
if (response != null)
{
if (sendBufferStream == null)
sendBufferStream = new MemoryStream(64);
//write dns datagram
sendBufferStream.Position = 0;
response.WriteTo(sendBufferStream);
//prepare final buffer
length = Convert.ToUInt16(sendBufferStream.Position);
if ((sendBuffer == null) || (sendBuffer.Length < length + 2))
sendBuffer = new byte[length + 2];
//copy datagram length
byte[] lengthBuffer = BitConverter.GetBytes(length);
sendBuffer[0] = lengthBuffer[1];
sendBuffer[1] = lengthBuffer[0];
//copy datagram
sendBufferStream.Position = 0;
sendBufferStream.Read(sendBuffer, 2, length);
//send dns datagram
tcpSocket.Send(sendBuffer, 0, length + 2, SocketFlags.None);
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(tcpSocket.RemoteEndPoint as IPEndPoint, true, request, response);
StatsManager stats = _stats;
if (stats != null)
stats.Update(response, (tcpSocket.RemoteEndPoint as IPEndPoint).Address);
}
//process request async
ThreadPool.QueueUserWorkItem(ProcessTcpRequestAsync, new object[] { request, tcpSocket, tcpStream, sendBufferStream });
}
}
catch (IOException)
@@ -354,6 +317,66 @@ namespace DnsServerCore
}
}
private void ProcessTcpRequestAsync(object parameter)
{
object[] parameters = parameter as object[];
DnsDatagram request = parameters[0] as DnsDatagram;
Socket tcpSocket = parameters[1] as Socket;
NetworkStream tcpStream = parameters[2] as NetworkStream;
MemoryStream sendBufferStream = parameters[3] as MemoryStream;
try
{
DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint);
//send response
if (response != null)
{
lock (tcpSocket)
{
//write dns datagram
sendBufferStream.Position = 0;
response.WriteTo(sendBufferStream);
//write dns datagram length
ushort length = Convert.ToUInt16(sendBufferStream.Position);
byte[] lengthBuffer = BitConverter.GetBytes(length);
Array.Reverse(lengthBuffer, 0, 2);
tcpStream.Write(lengthBuffer);
//send dns datagram
sendBufferStream.Position = 0;
sendBufferStream.CopyTo(tcpStream, 512, length);
tcpStream.Flush();
}
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(tcpSocket.RemoteEndPoint as IPEndPoint, true, request, response);
StatsManager stats = _stats;
if (stats != null)
stats.Update(response, (tcpSocket.RemoteEndPoint as IPEndPoint).Address);
}
}
catch (IOException)
{
//ignore IO exceptions
}
catch (Exception ex)
{
LogManager queryLog = _queryLog;
if ((queryLog != null) && (request != null))
queryLog.Write(tcpSocket.RemoteEndPoint as IPEndPoint, true, request, null);
LogManager log = _log;
if (log != null)
log.Write(tcpSocket.RemoteEndPoint as IPEndPoint, ex);
}
}
private bool IsRecursionAllowed(EndPoint remoteEP)
{
if (!_allowRecursion)
@@ -469,6 +492,7 @@ namespace DnsServerCore
private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, bool isRecursionAllowed)
{
DnsDatagram response = _authoritativeZoneRoot.Query(request);
response.Tag = "cacheHit";
if (response.Header.RCODE == DnsResponseCode.NoError)
{
@@ -483,6 +507,7 @@ namespace DnsServerCore
responseAnswer.AddRange(response.Answer);
DnsDatagram lastResponse;
bool cacheHit = (response.Tag == "cacheHit");
while (true)
{
@@ -496,6 +521,7 @@ namespace DnsServerCore
break;
lastResponse = ProcessRecursiveQuery(cnameRequest);
cacheHit &= (lastResponse.Tag == "cacheHit");
}
if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0))
@@ -539,7 +565,7 @@ namespace DnsServerCore
}
}
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional);
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional) { Tag = (cacheHit ? "cacheHit" : null) };
}
}
else if ((response.Authority.Length > 0) && (response.Authority[0].Type == DnsResourceRecordType.NS) && isRecursionAllowed)
@@ -574,6 +600,7 @@ namespace DnsServerCore
responseAnswer.AddRange(response.Answer);
DnsDatagram lastResponse;
bool cacheHit = (response.Tag == "cacheHit");
while (true)
{
@@ -585,6 +612,7 @@ namespace DnsServerCore
question = new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN);
lastResponse = RecursiveResolve(question, _forwarders);
cacheHit &= (lastResponse.Tag == "cacheHit");
if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0))
break;
@@ -597,7 +625,7 @@ namespace DnsServerCore
break;
if (lastRR.Type != DnsResourceRecordType.CNAME)
throw new DnsServerException("Invalid response received from Dns server.");
throw new DnsServerException("Invalid response received from DNS server.");
}
if ((lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.SOA))
@@ -605,7 +633,7 @@ namespace DnsServerCore
else
authority = new DnsResourceRecord[] { };
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, lastResponse.Header.RCODE, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, 0), request.Question, responseAnswer.ToArray(), authority, new DnsResourceRecord[] { });
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, lastResponse.Header.RCODE, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, 0), request.Question, responseAnswer.ToArray(), authority, new DnsResourceRecord[] { }) { Tag = (cacheHit ? "cacheHit" : null) };
}
}
@@ -614,7 +642,7 @@ namespace DnsServerCore
else
authority = new DnsResourceRecord[] { };
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { });
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { }) { Tag = response.Tag };
}
private DnsDatagram RecursiveResolve(DnsQuestionRecord questionRecord, NameServerAddress[] viaNameServers)
@@ -630,50 +658,52 @@ namespace DnsServerCore
if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
{
if (cacheResponse.Answer.Length > 0)
return cacheResponse;
else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))
if ((cacheResponse.Answer.Length > 0) || ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA)))
{
cacheResponse.Tag = "cacheHit";
return cacheResponse;
}
}
}
//recursion with locking
object newLockObj = new object();
object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj);
if (!actualLockObj.Equals(newLockObj))
{
//question already being recursively resolved by another thread, wait till timeout or pulse signal
bool waitTimeout;
object newLockObj = new object();
object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj);
lock (actualLockObj)
if (!actualLockObj.Equals(newLockObj))
{
waitTimeout = !Monitor.Wait(actualLockObj, _timeout);
}
if (!waitTimeout)
{
//query cache zone again to see if answer available
DnsDatagram cacheResponse = _cacheZoneRoot.Query(request);
if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
//question already being recursively resolved by another thread, wait till timeout or pulse signal
lock (actualLockObj)
{
if (cacheResponse.Answer.Length > 0)
return cacheResponse;
else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))
return cacheResponse;
Monitor.Wait(actualLockObj, _timeout);
}
}
//wait timeout or no response available in cache so respond with server failure
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null);
//query cache zone again to see if answer available
{
DnsDatagram cacheResponse = _cacheZoneRoot.Query(request);
if (cacheResponse.Header.RCODE != DnsResponseCode.Refused)
{
if ((cacheResponse.Answer.Length > 0) || ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA)))
{
cacheResponse.Tag = "cacheHit";
return cacheResponse;
}
}
}
//no response available in cache so respond with server failure
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null);
}
}
//select protocol
DnsClientProtocol protocol;
if (_forwarders == null)
{
protocol = DnsClient.RecursiveResolveDefaultProtocol;
protocol = _recursiveResolveProtocol;
}
else
{
@@ -683,17 +713,18 @@ namespace DnsServerCore
try
{
return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout);
return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout, _recursiveResolveProtocol);
}
finally
{
//remove question lock
_recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj);
//pulse all waiting threads
lock (newLockObj)
if (_recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj))
{
Monitor.PulseAll(newLockObj);
//pulse all waiting threads
lock (lockObj)
{
Monitor.PulseAll(lockObj);
}
}
}
}
@@ -734,7 +765,7 @@ namespace DnsServerCore
for (int i = 0; i < UDP_LISTENER_THREAD_COUNT; i++)
{
_udpListenerThreads[i] = new Thread(ReadUdpQueryPacketsAsync);
_udpListenerThreads[i] = new Thread(ReadUdpRequestAsync);
_udpListenerThreads[i].IsBackground = true;
_udpListenerThreads[i].Start();
}
@@ -841,6 +872,12 @@ namespace DnsServerCore
set { _forwarderProtocol = value; }
}
public DnsClientProtocol RecursiveResolveProtocol
{
get { return _recursiveResolveProtocol; }
set { _recursiveResolveProtocol = value; }
}
public bool PreferIPv6
{
get { return _preferIPv6; }