diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs index 1af81cff..ee884f46 100644 --- a/DnsServerCore/DnsServer.cs +++ b/DnsServerCore/DnsServer.cs @@ -47,8 +47,8 @@ namespace DnsServerCore #region variables const int UDP_LISTENER_THREAD_COUNT = 3; - const int TCP_SOCKET_SEND_TIMEOUT = 10000; - const int TCP_SOCKET_RECV_TIMEOUT = 60000; + const int TCP_SOCKET_SEND_TIMEOUT = 30000; + const int TCP_SOCKET_RECV_TIMEOUT = 120000; readonly IPEndPoint _localEP; @@ -70,8 +70,9 @@ namespace DnsServerCore NetProxy _proxy; NameServerAddress[] _forwarders; DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp; + DnsClientProtocol _recursiveResolveProtocol = DnsClientProtocol.Udp; bool _preferIPv6 = false; - int _retries = 1; + int _retries = 3; int _timeout = 2000; int _maxStackCount = 10; LogManager _log; @@ -118,7 +119,7 @@ namespace DnsServerCore #region private - private void ReadUdpQueryPacketsAsync(object parameter) + private void ReadUdpRequestAsync(object parameter) { EndPoint remoteEP; byte[] recvBuffer = new byte[512]; @@ -245,11 +246,10 @@ namespace DnsServerCore { Socket socket = _tcpListener.Accept(); - socket.NoDelay = true; socket.SendTimeout = TCP_SOCKET_SEND_TIMEOUT; socket.ReceiveTimeout = TCP_SOCKET_RECV_TIMEOUT; - ThreadPool.QueueUserWorkItem(ProcessTcpRequestAsync, socket); + ThreadPool.QueueUserWorkItem(ReadTcpRequestAsync, socket); } } catch (ThreadAbortException) @@ -267,70 +267,33 @@ namespace DnsServerCore } } - private void ProcessTcpRequestAsync(object parameter) + private void ReadTcpRequestAsync(object parameter) { Socket tcpSocket = parameter as Socket; DnsDatagram request = null; try { - NetworkStream recvStream = new NetworkStream(tcpSocket); - OffsetStream recvDatagramStream = new OffsetStream(recvStream, 0, 0); - MemoryStream sendBufferStream = null; - byte[] sendBuffer = null; + NetworkStream tcpStream = new NetworkStream(tcpSocket); + OffsetStream recvDatagramStream = new OffsetStream(tcpStream, 0, 0); + MemoryStream sendBufferStream = new MemoryStream(64); ushort length; while (true) { + request = null; + //read dns datagram length - { - byte[] lengthBuffer = recvStream.ReadBytes(2); - Array.Reverse(lengthBuffer, 0, 2); - length = BitConverter.ToUInt16(lengthBuffer, 0); - } + byte[] lengthBuffer = tcpStream.ReadBytes(2); + Array.Reverse(lengthBuffer, 0, 2); + length = BitConverter.ToUInt16(lengthBuffer, 0); //read dns datagram recvDatagramStream.Reset(0, length, 0); request = new DnsDatagram(recvDatagramStream); - DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint); - - //send response - if (response != null) - { - if (sendBufferStream == null) - sendBufferStream = new MemoryStream(64); - - //write dns datagram - sendBufferStream.Position = 0; - response.WriteTo(sendBufferStream); - - //prepare final buffer - length = Convert.ToUInt16(sendBufferStream.Position); - - if ((sendBuffer == null) || (sendBuffer.Length < length + 2)) - sendBuffer = new byte[length + 2]; - - //copy datagram length - byte[] lengthBuffer = BitConverter.GetBytes(length); - sendBuffer[0] = lengthBuffer[1]; - sendBuffer[1] = lengthBuffer[0]; - - //copy datagram - sendBufferStream.Position = 0; - sendBufferStream.Read(sendBuffer, 2, length); - - //send dns datagram - tcpSocket.Send(sendBuffer, 0, length + 2, SocketFlags.None); - - LogManager queryLog = _queryLog; - if (queryLog != null) - queryLog.Write(tcpSocket.RemoteEndPoint as IPEndPoint, true, request, response); - - StatsManager stats = _stats; - if (stats != null) - stats.Update(response, (tcpSocket.RemoteEndPoint as IPEndPoint).Address); - } + //process request async + ThreadPool.QueueUserWorkItem(ProcessTcpRequestAsync, new object[] { request, tcpSocket, tcpStream, sendBufferStream }); } } catch (IOException) @@ -354,6 +317,66 @@ namespace DnsServerCore } } + private void ProcessTcpRequestAsync(object parameter) + { + object[] parameters = parameter as object[]; + + DnsDatagram request = parameters[0] as DnsDatagram; + Socket tcpSocket = parameters[1] as Socket; + NetworkStream tcpStream = parameters[2] as NetworkStream; + MemoryStream sendBufferStream = parameters[3] as MemoryStream; + + try + { + DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint); + + //send response + if (response != null) + { + lock (tcpSocket) + { + //write dns datagram + sendBufferStream.Position = 0; + response.WriteTo(sendBufferStream); + + //write dns datagram length + ushort length = Convert.ToUInt16(sendBufferStream.Position); + byte[] lengthBuffer = BitConverter.GetBytes(length); + Array.Reverse(lengthBuffer, 0, 2); + tcpStream.Write(lengthBuffer); + + //send dns datagram + sendBufferStream.Position = 0; + sendBufferStream.CopyTo(tcpStream, 512, length); + + tcpStream.Flush(); + } + + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write(tcpSocket.RemoteEndPoint as IPEndPoint, true, request, response); + + StatsManager stats = _stats; + if (stats != null) + stats.Update(response, (tcpSocket.RemoteEndPoint as IPEndPoint).Address); + } + } + catch (IOException) + { + //ignore IO exceptions + } + catch (Exception ex) + { + LogManager queryLog = _queryLog; + if ((queryLog != null) && (request != null)) + queryLog.Write(tcpSocket.RemoteEndPoint as IPEndPoint, true, request, null); + + LogManager log = _log; + if (log != null) + log.Write(tcpSocket.RemoteEndPoint as IPEndPoint, ex); + } + } + private bool IsRecursionAllowed(EndPoint remoteEP) { if (!_allowRecursion) @@ -469,6 +492,7 @@ namespace DnsServerCore private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, bool isRecursionAllowed) { DnsDatagram response = _authoritativeZoneRoot.Query(request); + response.Tag = "cacheHit"; if (response.Header.RCODE == DnsResponseCode.NoError) { @@ -483,6 +507,7 @@ namespace DnsServerCore responseAnswer.AddRange(response.Answer); DnsDatagram lastResponse; + bool cacheHit = (response.Tag == "cacheHit"); while (true) { @@ -496,6 +521,7 @@ namespace DnsServerCore break; lastResponse = ProcessRecursiveQuery(cnameRequest); + cacheHit &= (lastResponse.Tag == "cacheHit"); } if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) @@ -539,7 +565,7 @@ namespace DnsServerCore } } - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional) { Tag = (cacheHit ? "cacheHit" : null) }; } } else if ((response.Authority.Length > 0) && (response.Authority[0].Type == DnsResourceRecordType.NS) && isRecursionAllowed) @@ -574,6 +600,7 @@ namespace DnsServerCore responseAnswer.AddRange(response.Answer); DnsDatagram lastResponse; + bool cacheHit = (response.Tag == "cacheHit"); while (true) { @@ -585,6 +612,7 @@ namespace DnsServerCore question = new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN); lastResponse = RecursiveResolve(question, _forwarders); + cacheHit &= (lastResponse.Tag == "cacheHit"); if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) break; @@ -597,7 +625,7 @@ namespace DnsServerCore break; if (lastRR.Type != DnsResourceRecordType.CNAME) - throw new DnsServerException("Invalid response received from Dns server."); + throw new DnsServerException("Invalid response received from DNS server."); } if ((lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.SOA)) @@ -605,7 +633,7 @@ namespace DnsServerCore else authority = new DnsResourceRecord[] { }; - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, lastResponse.Header.RCODE, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, 0), request.Question, responseAnswer.ToArray(), authority, new DnsResourceRecord[] { }); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, lastResponse.Header.RCODE, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, 0), request.Question, responseAnswer.ToArray(), authority, new DnsResourceRecord[] { }) { Tag = (cacheHit ? "cacheHit" : null) }; } } @@ -614,7 +642,7 @@ namespace DnsServerCore else authority = new DnsResourceRecord[] { }; - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { }); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, response.Header.RCODE, 1, (ushort)response.Answer.Length, (ushort)authority.Length, 0), request.Question, response.Answer, authority, new DnsResourceRecord[] { }) { Tag = response.Tag }; } private DnsDatagram RecursiveResolve(DnsQuestionRecord questionRecord, NameServerAddress[] viaNameServers) @@ -630,50 +658,52 @@ namespace DnsServerCore if (cacheResponse.Header.RCODE != DnsResponseCode.Refused) { - if (cacheResponse.Answer.Length > 0) - return cacheResponse; - else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + if ((cacheResponse.Answer.Length > 0) || ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))) + { + cacheResponse.Tag = "cacheHit"; return cacheResponse; + } } } //recursion with locking - object newLockObj = new object(); - object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj); - - if (!actualLockObj.Equals(newLockObj)) { - //question already being recursively resolved by another thread, wait till timeout or pulse signal - bool waitTimeout; + object newLockObj = new object(); + object actualLockObj = _recursiveQueryLocks.GetOrAdd(request.Question[0], newLockObj); - lock (actualLockObj) + if (!actualLockObj.Equals(newLockObj)) { - waitTimeout = !Monitor.Wait(actualLockObj, _timeout); - } - - if (!waitTimeout) - { - //query cache zone again to see if answer available - DnsDatagram cacheResponse = _cacheZoneRoot.Query(request); - - if (cacheResponse.Header.RCODE != DnsResponseCode.Refused) + //question already being recursively resolved by another thread, wait till timeout or pulse signal + lock (actualLockObj) { - if (cacheResponse.Answer.Length > 0) - return cacheResponse; - else if ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA)) - return cacheResponse; + Monitor.Wait(actualLockObj, _timeout); } - } - //wait timeout or no response available in cache so respond with server failure - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + //query cache zone again to see if answer available + { + DnsDatagram cacheResponse = _cacheZoneRoot.Query(request); + + if (cacheResponse.Header.RCODE != DnsResponseCode.Refused) + { + if ((cacheResponse.Answer.Length > 0) || ((cacheResponse.Authority.Length == 0) || (cacheResponse.Authority[0].Type == DnsResourceRecordType.SOA))) + { + cacheResponse.Tag = "cacheHit"; + return cacheResponse; + } + } + } + + //no response available in cache so respond with server failure + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, true, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + } } + //select protocol DnsClientProtocol protocol; if (_forwarders == null) { - protocol = DnsClient.RecursiveResolveDefaultProtocol; + protocol = _recursiveResolveProtocol; } else { @@ -683,17 +713,18 @@ namespace DnsServerCore try { - return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout); + return DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount, _timeout, _recursiveResolveProtocol); } finally { //remove question lock - _recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj); - - //pulse all waiting threads - lock (newLockObj) + if (_recursiveQueryLocks.TryRemove(request.Question[0], out object lockObj)) { - Monitor.PulseAll(newLockObj); + //pulse all waiting threads + lock (lockObj) + { + Monitor.PulseAll(lockObj); + } } } } @@ -734,7 +765,7 @@ namespace DnsServerCore for (int i = 0; i < UDP_LISTENER_THREAD_COUNT; i++) { - _udpListenerThreads[i] = new Thread(ReadUdpQueryPacketsAsync); + _udpListenerThreads[i] = new Thread(ReadUdpRequestAsync); _udpListenerThreads[i].IsBackground = true; _udpListenerThreads[i].Start(); } @@ -841,6 +872,12 @@ namespace DnsServerCore set { _forwarderProtocol = value; } } + public DnsClientProtocol RecursiveResolveProtocol + { + get { return _recursiveResolveProtocol; } + set { _recursiveResolveProtocol = value; } + } + public bool PreferIPv6 { get { return _preferIPv6; }