diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs index 729e073b..c4c60cef 100644 --- a/DnsServerCore/DnsServer.cs +++ b/DnsServerCore/DnsServer.cs @@ -126,23 +126,22 @@ namespace DnsServerCore #endregion - EndPoint remoteEP = null; - FixMemoryStream recvBufferStream = new FixMemoryStream(128); + EndPoint remoteEP; + byte[] recvBuffer = new byte[512]; int bytesRecv; + if (_udpListener.AddressFamily == AddressFamily.InterNetwork) + remoteEP = new IPEndPoint(IPAddress.Any, 0); + else + remoteEP = new IPEndPoint(IPAddress.IPv6Any, 0); + try { while (true) { - - if (_udpListener.AddressFamily == AddressFamily.InterNetwork) - remoteEP = new IPEndPoint(IPAddress.Any, 0); - else - remoteEP = new IPEndPoint(IPAddress.IPv6Any, 0); - try { - bytesRecv = _udpListener.ReceiveFrom(recvBufferStream.Buffer, ref remoteEP); + bytesRecv = _udpListener.ReceiveFrom(recvBuffer, ref remoteEP); } catch (SocketException ex) { @@ -160,15 +159,16 @@ namespace DnsServerCore if (bytesRecv > 0) { - recvBufferStream.Position = 0; - recvBufferStream.SetLength(bytesRecv); - try { - ThreadPool.QueueUserWorkItem(ProcessUdpRequestAsync, new object[] { remoteEP, new DnsDatagram(recvBufferStream) }); + ThreadPool.QueueUserWorkItem(ProcessUdpRequestAsync, new object[] { remoteEP, new DnsDatagram(new MemoryStream(recvBuffer, 0, bytesRecv, false)) }); + } + catch (Exception ex) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP as IPEndPoint, ex); } - catch - { } } } } @@ -176,7 +176,7 @@ namespace DnsServerCore { LogManager log = _log; if (log != null) - log.Write((IPEndPoint)remoteEP, ex); + log.Write(remoteEP as IPEndPoint, ex); if (_state == ServiceState.Running) throw; @@ -197,13 +197,14 @@ namespace DnsServerCore //send response if (response != null) { - FixMemoryStream sendBufferStream = new FixMemoryStream(512); + byte[] sendBuffer = new byte[512]; + MemoryStream sendBufferStream = new MemoryStream(sendBuffer); try { response.WriteTo(sendBufferStream); } - catch (EndOfStreamException) + catch (NotSupportedException) { DnsHeader header = response.Header; response = new DnsDatagram(new DnsHeader(header.Identifier, true, header.OPCODE, header.AuthoritativeAnswer, true, header.RecursionDesired, header.RecursionAvailable, header.AuthenticData, header.CheckingDisabled, header.RCODE, header.QDCOUNT, 0, 0, 0), response.Question, null, null, null); @@ -213,22 +214,22 @@ namespace DnsServerCore } //send dns datagram - _udpListener.SendTo(sendBufferStream.Buffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); + _udpListener.SendTo(sendBuffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write((IPEndPoint)remoteEP, false, request, response); + queryLog.Write(remoteEP as IPEndPoint, false, request, response); } } catch (Exception ex) { LogManager queryLog = _queryLog; if (queryLog != null) - queryLog.Write((IPEndPoint)remoteEP, false, request, null); + queryLog.Write(remoteEP as IPEndPoint, false, request, null); LogManager log = _log; if (log != null) - log.Write((IPEndPoint)remoteEP, ex); + log.Write(remoteEP as IPEndPoint, ex); } } @@ -265,70 +266,65 @@ namespace DnsServerCore try { - FixMemoryStream recvBufferStream = new FixMemoryStream(128); - MemoryStream sendBufferStream = new MemoryStream(512); - int bytesRecv; + NetworkStream recvStream = new NetworkStream(tcpSocket); + OffsetStream recvDatagramStream = new OffsetStream(recvStream, 0, 0); + MemoryStream sendBufferStream = null; + byte[] sendBuffer = null; + ushort length; while (true) { //read dns datagram length - bytesRecv = tcpSocket.Receive(recvBufferStream.Buffer, 0, 2, SocketFlags.None); - if (bytesRecv < 1) - return; //do nothing - - Array.Reverse(recvBufferStream.Buffer, 0, 2); - short length = BitConverter.ToInt16(recvBufferStream.Buffer, 0); + { + byte[] lengthBuffer = recvStream.ReadBytes(2); + Array.Reverse(lengthBuffer, 0, 2); + length = BitConverter.ToUInt16(lengthBuffer, 0); + } //read dns datagram - int offset = 0; - while (offset < length) + recvDatagramStream.Reset(0, length, 0); + request = new DnsDatagram(recvDatagramStream); + + DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint); + + //send response + if (response != null) { - bytesRecv = tcpSocket.Receive(recvBufferStream.Buffer, offset, length, SocketFlags.None); - if (bytesRecv < 1) - throw new SocketException(); + if (sendBufferStream == null) + sendBufferStream = new MemoryStream(64); - offset += bytesRecv; - } + //write dns datagram + sendBufferStream.Position = 0; + response.WriteTo(sendBufferStream); - bytesRecv = length; + //prepare final buffer + length = Convert.ToUInt16(sendBufferStream.Position); - if (bytesRecv > 0) - { - recvBufferStream.Position = 0; - recvBufferStream.SetLength(bytesRecv); + if ((sendBuffer == null) || (sendBuffer.Length < length + 2)) + sendBuffer = new byte[length + 2]; - request = new DnsDatagram(recvBufferStream); - DnsDatagram response = ProcessQuery(request, tcpSocket.RemoteEndPoint); + //copy datagram length + byte[] lengthBuffer = BitConverter.GetBytes(length); + sendBuffer[0] = lengthBuffer[1]; + sendBuffer[1] = lengthBuffer[0]; - //send response - if (response != null) - { - //write dns datagram - sendBufferStream.Position = 0; - response.WriteTo(sendBufferStream); + //copy datagram + sendBufferStream.Position = 0; + sendBufferStream.Read(sendBuffer, 2, length); - //prepare final buffer - byte[] lengthBytes = BitConverter.GetBytes(Convert.ToInt16(sendBufferStream.Position)); - byte[] buffer = new byte[sendBufferStream.Position + 2]; + //send dns datagram + tcpSocket.Send(sendBuffer, 0, length + 2, SocketFlags.None); - //copy datagram length - buffer[0] = lengthBytes[1]; - buffer[1] = lengthBytes[0]; - - //copy datagram - sendBufferStream.Position = 0; - sendBufferStream.Read(buffer, 2, buffer.Length - 2); - - //send dns datagram - tcpSocket.Send(buffer, 0, buffer.Length, SocketFlags.None); - - LogManager queryLog = _queryLog; - if (queryLog != null) - queryLog.Write((IPEndPoint)tcpSocket.RemoteEndPoint, true, request, response); - } + LogManager queryLog = _queryLog; + if (queryLog != null) + queryLog.Write((IPEndPoint)tcpSocket.RemoteEndPoint, true, request, response); } } } + catch (IOException) + { + //ignore IO exceptions + } catch (Exception ex) { LogManager queryLog = _queryLog; @@ -372,11 +368,13 @@ namespace DnsServerCore if (request.Header.IsResponse) return null; + bool isRecursionAllowed = IsRecursionAllowed(remoteEP); + switch (request.Header.OPCODE) { case DnsOpcode.StandardQuery: if ((request.Question.Length != 1) || (request.Question[0].Class != DnsClass.IN)) - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); switch (request.Question[0].Type) { @@ -384,14 +382,14 @@ namespace DnsServerCore case DnsResourceRecordType.AXFR: case DnsResourceRecordType.MAILB: case DnsResourceRecordType.MAILA: - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } try { - DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, remoteEP); + DnsDatagram authoritativeResponse = ProcessAuthoritativeQuery(request, isRecursionAllowed); - if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !IsRecursionAllowed(remoteEP)) + if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !isRecursionAllowed) return authoritativeResponse; return ProcessRecursiveQuery(request); @@ -400,104 +398,122 @@ namespace DnsServerCore { LogManager log = _log; if (log != null) - log.Write((IPEndPoint)remoteEP, ex); + log.Write(remoteEP as IPEndPoint, ex); - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } default: - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); } } - private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, EndPoint remoteEP) + private DnsDatagram ProcessAuthoritativeQuery(DnsDatagram request, bool isRecursionAllowed) { DnsDatagram response = _authoritativeZoneRoot.Query(request); - if ((response.Header.RCODE == DnsResponseCode.NoError) && (response.Answer.Length > 0)) + if (response.Header.RCODE == DnsResponseCode.NoError) { - DnsResourceRecordType questionType = request.Question[0].Type; - DnsResourceRecord lastRR = response.Answer[response.Answer.Length - 1]; - - if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) + if (response.Answer.Length > 0) { - List responseAnswer = new List(); - responseAnswer.AddRange(response.Answer); + DnsResourceRecordType questionType = request.Question[0].Type; + DnsResourceRecord lastRR = response.Answer[response.Answer.Length - 1]; - DnsDatagram lastResponse; - - while (true) + if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY)) { - DnsDatagram cnameRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN) }, null, null, null); + List responseAnswer = new List(); + responseAnswer.AddRange(response.Answer); - lastResponse = _authoritativeZoneRoot.Query(cnameRequest); + DnsDatagram lastResponse; + + while (true) + { + DnsDatagram cnameRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN) }, null, null, null); + + lastResponse = _authoritativeZoneRoot.Query(cnameRequest); + + if (lastResponse.Header.RCODE == DnsResponseCode.Refused) + { + if (!cnameRequest.Header.RecursionDesired || !isRecursionAllowed) + break; + + lastResponse = ProcessRecursiveQuery(cnameRequest); + } + + if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) + break; + + responseAnswer.AddRange(lastResponse.Answer); + + lastRR = lastResponse.Answer[lastResponse.Answer.Length - 1]; + + if (lastRR.Type != DnsResourceRecordType.CNAME) + break; + } + + DnsResponseCode rcode; + DnsResourceRecord[] authority; + DnsResourceRecord[] additional; if (lastResponse.Header.RCODE == DnsResponseCode.Refused) { - if (!cnameRequest.Header.RecursionDesired || !IsRecursionAllowed(remoteEP)) - break; - - lastResponse = ProcessRecursiveQuery(cnameRequest); - } - - if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0)) - break; - - responseAnswer.AddRange(lastResponse.Answer); - - lastRR = lastResponse.Answer[lastResponse.Answer.Length - 1]; - - if (lastRR.Type != DnsResourceRecordType.CNAME) - break; - } - - DnsResponseCode rcode; - DnsResourceRecord[] authority; - DnsResourceRecord[] additional; - - if (lastResponse.Header.RCODE == DnsResponseCode.Refused) - { - rcode = DnsResponseCode.NoError; - authority = new DnsResourceRecord[] { }; - additional = new DnsResourceRecord[] { }; - } - else - { - rcode = lastResponse.Header.RCODE; - - if (lastResponse.Header.AuthoritativeAnswer) - { - authority = lastResponse.Authority; - additional = lastResponse.Additional; + rcode = DnsResponseCode.NoError; + authority = new DnsResourceRecord[] { }; + additional = new DnsResourceRecord[] { }; } else { - if ((lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + rcode = lastResponse.Header.RCODE; + + if (lastResponse.Header.AuthoritativeAnswer) + { authority = lastResponse.Authority; + additional = lastResponse.Additional; + } else - authority = new DnsResourceRecord[] { }; + { + if ((lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.SOA)) + authority = lastResponse.Authority; + else + authority = new DnsResourceRecord[] { }; - additional = new DnsResourceRecord[] { }; + additional = new DnsResourceRecord[] { }; + } } - } - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, lastResponse.Header.AuthoritativeAnswer, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, rcode, 1, (ushort)responseAnswer.Count, (ushort)authority.Length, (ushort)additional.Length), request.Question, responseAnswer.ToArray(), authority, additional); + } + } + else if ((response.Authority.Length > 0) && (response.Authority[0].Type == DnsResourceRecordType.NS) && isRecursionAllowed) + { + if (_forwarders != null) + return ProcessRecursiveQuery(request); //do recursive resolution using forwarders + + //do recursive resolution using response authority name servers + NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false); + + return ProcessRecursiveQuery(request, nameServers); } } return response; } - private DnsDatagram ProcessRecursiveQuery(DnsDatagram request) + private DnsDatagram ProcessRecursiveQuery(DnsDatagram request, NameServerAddress[] viaNameServers = null) { DnsClientProtocol protocol; if (_forwarders == null) + { protocol = DnsClient.RecursiveResolveDefaultProtocol; + } else + { + viaNameServers = _forwarders; //forwarder has higher weightage protocol = _forwarderProtocol; + } - DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], _forwarders, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount); + DnsDatagram response = DnsClient.ResolveViaNameServers(request.Question[0], viaNameServers, _dnsCache, _proxy, _preferIPv6, protocol, _retries, _maxStackCount); DnsResourceRecord[] authority; @@ -647,7 +663,7 @@ namespace DnsServerCore { try { - forwarder.RecursiveResolveDomainName(_dnsCache, _proxy, DnsClient.RecursiveResolveDefaultProtocol, _retries); + forwarder.RecursiveResolveDomainName(_dnsCache, _proxy, _preferIPv6, DnsClient.RecursiveResolveDefaultProtocol, _retries); } catch { }