From 352a2a199971ac5f707d4ab0abc7e811ed7cfa86 Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 11 Aug 2018 12:09:44 +0530 Subject: [PATCH] DnsServer: removed obsolete code usage. Code refactoring done. Fixed tcp response reading bug caused by small recv buffer size than the packet length. ProcessAuthoritativeQuery() updated to do recursion if auth zone has delegated subdomain. --- DnsServerCore/DnsServer.cs | 280 ++++++++++++++++++++----------------- 1 file changed, 148 insertions(+), 132 deletions(-) diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs index 729e073b..c4c60cef 100644 --- a/DnsServerCore/DnsServer.cs +++ b/DnsServerCore/DnsServer.cs @@ -126,23 +126,22 @@ namespace DnsServerCore #endregion - EndPoint remoteEP = null; - FixMemoryStream recvBufferStream = new FixMemoryStream(128); + EndPoint remoteEP; + byte[] recvBuffer = new byte[512]; int bytesRecv; + if (_udpListener.AddressFamily == AddressFamily.InterNetwork) + remoteEP = new IPEndPoint(IPAddress.Any, 0); + else + remoteEP = new IPEndPoint(IPAddress.IPv6Any, 0); + try { while (true) { - - if (_udpListener.AddressFamily == AddressFamily.InterNetwork) - remoteEP = new IPEndPoint(IPAddress.Any, 0); - else - remoteEP = new IPEndPoint(IPAddress.IPv6Any, 0); - try { - bytesRecv = _udpListener.ReceiveFrom(recvBufferStream.Buffer, ref remoteEP); + bytesRecv = _udpListener.ReceiveFrom(recvBuffer, ref remoteEP); } catch (SocketException ex) { @@ -160,15 +159,16 @@ namespace DnsServerCore if (bytesRecv > 0) { - recvBufferStream.Position = 0; - recvBufferStream.SetLength(bytesRecv); - try { - ThreadPool.QueueUserWorkItem(ProcessUdpRequestAsync, new object[] { remoteEP, new DnsDatagram(recvBufferStream) }); + ThreadPool.QueueUserWorkItem(ProcessUdpRequestAsync, new object[] { remoteEP, new DnsDatagram(new MemoryStream(recvBuffer, 0, bytesRecv, false)) }); + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP as IPEndPoint, ex); } - catch - { } } } } @@ -176,7 +176,7 @@ namespace DnsServerCore { LogManager log = _log; if (log != null) - log.Write((IPEndPoint)remoteEP, ex); + log.Write(remoteEP as IPEndPoint, ex); if (_state == ServiceState.Running) throw; @@ -197,13 +197,14 @@ namespace DnsServerCore //send response if (response != null) { - FixMemoryStream sendBufferStream = new FixMemoryStream(512); + byte[] sendBuffer = new byte[512]; + MemoryStream sendBufferStream = new MemoryStream(sendBuffer); try { response.WriteTo(sendBufferStream); } - catch (EndOfStreamException) + catch (NotSupportedException) { DnsHeader header = response.Header; response = new DnsDatagram(new DnsHeader(header.Identifier, true, header.OPCODE, header.AuthoritativeAnswer, true, header.RecursionDesired, header.RecursionAvailable, header.AuthenticData, header.CheckingDisabled, header.RCODE, header.QDCOUNT, 0, 0, 0), response.Question, null, null, null); @@ -213,22 +214,22 @@ namespace DnsServerCore } //send dns datagram - _udpListener.SendTo(sendBufferStream.Buffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); + _udpListener.SendTo(sendBuffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write((IPEndPoint)remoteEP, false, request, response); + queryLog.Write(remoteEP as IPEndPoint, false, request, response); } } catch (Exception ex) { LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write((IPEndPoint)remoteEP, false, request, null); + queryLog.Write(remoteEP as IPEndPoint, false, request, null); LogManager log = _log; if (log != null) - log.Write((IPEndPoint)remoteEP, ex); + log.Write(remoteEP as IPEndPoint, ex); } } @@ -265,70 +266,65 @@ namespace DnsServerCore try { - FixMemoryStream recvBufferStream = new FixMemoryStream(128); - MemoryStream sendBufferStream = new MemoryStream(512); - int bytesRecv; + NetworkStream recvStream = new NetworkStream(tcpSocket); + OffsetStream recvDatagramStream = new OffsetStream(recvStream, 0, 0); + MemoryStream sendBufferStream = null; + byte[] sendBuffer = null; + ushort length; while (true) { //read dns datagram length - bytesRecv = tcpSocket.Receive(recvBufferStream.Buffer, 0, 2, SocketFlags.None); - if (bytesRecv < 1) - return; //do nothing - - Array.Reverse(recvBufferStream.Buffer, 0, 2); - short length = BitConverter.ToInt16(recvBufferStream.Buffer, 0); + { + byte[] lengthBuffer = recvStream.ReadBytes(2); + Array.Reverse(lengthBuffer, 0, 2); + length = BitConverter.ToUInt16(lengthBuffer, 0); + } //read dns datagram - int offset = 0; - while (offset < length) + recvDatagramStream.Reset(0, length, 0); + request = new DnsDatagram(recvDatagramStream); + + DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint); + + //send response + if (response != null) { - bytesRecv = tcpSocket.Receive(recvBufferStream.Buffer, offset, length, SocketFlags.None); - if (bytesRecv < 1) - throw new SocketException(); + if (sendBufferStream == null) + sendBufferStream = new MemoryStream(64); - offset += bytesRecv; - } + //write dns datagram + sendBufferStream.Position = 0; + response.WriteTo(sendBufferStream); - bytesRecv = length; + //prepare final buffer + length = Convert.ToUInt16(sendBufferStream.Position); - if (bytesRecv > 0) - { - recvBufferStream.Position = 0; - recvBufferStream.SetLength(bytesRecv); + if ((sendBuffer == null) || (sendBuffer.Length < length + 2)) + sendBuffer = new byte[length + 2]; - request = new DnsDatagram(recvBufferStream); - DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint); + //copy datagram length + byte[] lengthBuffer = BitConverter.GetBytes(length); + sendBuffer[0] = lengthBuffer[1]; + sendBuffer[1] = lengthBuffer[0]; - //send response - if (response != null) - { - //write dns datagram - sendBufferStream.Position = 0; - response.WriteTo(sendBufferStream); + //copy datagram + sendBufferStream.Position = 0; + sendBufferStream.Read(sendBuffer, 2, length); - //prepare final buffer - byte[] lengthBytes = BitConverter.GetBytes(Convert.ToInt16(sendBufferStream.Position)); - byte[] buffer = new byte[sendBufferStream.Position + 2]; + //send dns datagram + tcpSocket.Send(sendBuffer, 0, length + 2, SocketFlags.None); - //copy datagram length - buffer[0] = lengthBytes[1]; - buffer[1] = lengthBytes[0]; - - //copy datagram - sendBufferStream.Position = 0; - sendBufferStream.Read(buffer, 2, buffer.Length - 2); - - //send dns datagram - tcpSocket.Send(buffer, 0, buffer.Length, SocketFlags.None); - - LogManager queryLog = _queryLog; - if (queryLog != null) - queryLog.Write((IPEndPoint)tcpSocket.RemoteEndPoint, true, request, response); - } + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write((IPEndPoint)tcpSocket.RemoteEndPoint, true, request, response); } } } + catch (IOException) + { + //ignore IO exceptions + } catch (Exception ex) { LogManager queryLog = _queryLog; @@ -372,11 +368,13 @@ namespace DnsServerCore if (request.Header.IsResponse) return null; + bool isRecursionAllowed = IsRecursionAllowed(remoteEP); + switch (request.Header.OPCODE) { case DnsOpcode.StandardQuery: if ((request.Question.Length != 1) || (request.Question[0].Class != DnsClass.IN)) - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); switch (request.Question[0].Type) { @@ -384,14 +382,14 @@ namespace DnsServerCore case DnsResourceRecordType.AXFR: case DnsResourceRecordType.MAILB: case DnsResourceRecordType.MAILA: - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } try { - DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, remoteEP); + DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, isRecursionAllowed); - if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !IsRecursionAllowed(remoteEP)) + if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !isRecursionAllowed) return authoritativeResponse; return ProcessRecursiveQuery(request); @@ -400,104 +398,122 @@ namespace DnsServerCore { LogManager log = _log; if (log != null) - log.Write((IPEndPoint)remoteEP, ex); + log.Write(remoteEP as IPEndPoint, ex); - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } default: - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } } - private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, EndPoint remoteEP) + private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, bool isRecursionAllowed) { DnsDatagram response = _authoritativeZoneRoot.Query(request); - if ((response.Header.RCODE == DnsResponseCode.NoError) && (response.Answer.Length > 0)) + if (response.Header.RCODE == DnsResponseCode.NoError) { - DnsResourceRecordType questionType = request.Question[0].Type; - DnsResourceRecord lastRR = response.Answer[response.Answer.Length - 1]; - - if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) + if (response.Answer.Length > 0) { - List responseAnswer = new List(); - responseAnswer.AddRange(response.Answer); + DnsResourceRecordType questionType = request.Question[0].Type; + DnsResourceRecord lastRR = response.Answer[response.Answer.Length - 1]; - DnsDatagram lastResponse; - - while (true) + if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) { - DnsDatagram cnameRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN) }, null, null, null); + List responseAnswer = new List(); + responseAnswer.AddRange(response.Answer); - lastResponse = _authoritativeZoneRoot.Query(cnameRequest); + DnsDatagram lastResponse; + + while (true) + { + DnsDatagram cnameRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN) }, null, null, null); + + lastResponse = _authoritativeZoneRoot.Query(cnameRequest); + + if (lastResponse.Header.RCODE == DnsResponseCode.Refused) + { + if (!cnameRequest.Header.RecursionDesired || !isRecursionAllowed) + break; + + lastResponse = ProcessRecursiveQuery(cnameRequest); + } + + if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) + break; + + responseAnswer.AddRange(lastResponse.Answer); + + lastRR = lastResponse.Answer[lastResponse.Answer.Length - 1]; + + if (lastRR.Type != DnsResourceRecordType.CNAME) + break; + } + + DnsResponseCode rcode; + DnsResourceRecord[] authority; + DnsResourceRecord[] additional; if (lastResponse.Header.RCODE == DnsResponseCode.Refused) { - if (!cnameRequest.Header.RecursionDesired || !IsRecursionAllowed(remoteEP)) - break; - - lastResponse = ProcessRecursiveQuery(cnameRequest); - } - - if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) - break; - - responseAnswer.AddRange(lastResponse.Answer); - - lastRR = lastResponse.Answer[lastResponse.Answer.Length - 1]; - - if (lastRR.Type != DnsResourceRecordType.CNAME) - break; - } - - DnsResponseCode rcode; - DnsResourceRecord[] authority; - DnsResourceRecord[] additional; - - if (lastResponse.Header.RCODE == DnsResponseCode.Refused) - { - rcode = DnsResponseCode.NoError; - authority = new DnsResourceRecord[] { }; - additional = new DnsResourceRecord[] { }; - } - else - { - rcode = lastResponse.Header.RCODE; - - if (lastResponse.Header.AuthoritativeAnswer) - { - authority = lastResponse.Authority; - additional = lastResponse.Additional; + rcode = DnsResponseCode.NoError; + authority = new DnsResourceRecord[] { }; + additional = new DnsResourceRecord[] { }; } else { - if ((lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + rcode = lastResponse.Header.RCODE; + + if (lastResponse.Header.AuthoritativeAnswer) + { authority = lastResponse.Authority; + additional = lastResponse.Additional; + } else - authority = new DnsResourceRecord[] { }; + { + if ((lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + authority = lastResponse.Authority; + else + authority = new DnsResourceRecord[] { }; - additional = new DnsResourceRecord[] { }; + additional = new DnsResourceRecord[] { }; + } } - } - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional); + } + } + else if ((response.Authority.Length > 0) && (response.Authority[0].Type == DnsResourceRecordType.NS) && isRecursionAllowed) + { + if (_forwarders != null) + return ProcessRecursiveQuery(request); //do recursive resolution using forwarders + + //do recursive resolution using response authority name servers + NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false); + + return ProcessRecursiveQuery(request, nameServers); } } return response; } - private DnsDatagram ProcessRecursiveQuery(DnsDatagram request) + private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, NameServerAddress[] viaNameServers = null) { DnsClientProtocol protocol; if (_forwarders == null) + { protocol = DnsClient.RecursiveResolveDefaultProtocol; + } else + { + viaNameServers = _forwarders; //forwarder has higher weightage protocol = _forwarderProtocol; + } - DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], _forwarders, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount); + DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount); DnsResourceRecord[] authority; @@ -647,7 +663,7 @@ namespace DnsServerCore { try { - forwarder.RecursiveResolveDomainName(_dnsCache, _proxy, DnsClient.RecursiveResolveDefaultProtocol, _retries); + forwarder.RecursiveResolveDomainName(_dnsCache, _proxy, _preferIPv6, DnsClient.RecursiveResolveDefaultProtocol, _retries); } catch { }