diff --git a/DnsServerCore/DnsServer.cs b/DnsServerCore/DnsServer.cs index 1e57c58e..f0f18a4c 100644 --- a/DnsServerCore/DnsServer.cs +++ b/DnsServerCore/DnsServer.cs @@ -1,5 +1,5 @@ /* -Technitium Library +Technitium DNS Server Copyright (C) 2017 Shreyas Zare (shreyas@technitium.com) This program is free software: you can redistribute it and/or modify @@ -24,7 +24,7 @@ using System.Net; using System.Net.Sockets; using System.Threading; using TechnitiumLibrary.IO; -using TechnitiumLibrary.Net; +using TechnitiumLibrary.Net.Dns; namespace DnsServerCore { @@ -35,18 +35,20 @@ namespace DnsServerCore const int TCP_SOCKET_SEND_TIMEOUT = 30000; const int TCP_SOCKET_RECV_TIMEOUT = 60000; - Socket _udpListener; - Thread _udpListenerThread; + readonly Socket _udpListener; + readonly Thread _udpListenerThread; - Socket _tcpListener; - Thread _tcpListenerThread; + readonly Socket _tcpListener; + readonly Thread _tcpListenerThread; - Zone _authoritativeZoneRoot = new Zone(true); - Zone _cacheZoneRoot = new Zone(false); + readonly Zone _authoritativeZoneRoot = new Zone(true); + readonly Zone _cacheZoneRoot = new Zone(false); - bool _allowRecursion; + readonly IDnsCache _dnsCache; + + bool _allowRecursion = false; NameServerAddress[] _forwarders; - bool _enableIPv6 = false; + bool _preferIPv6 = false; int _retries = 2; #endregion @@ -57,12 +59,14 @@ namespace DnsServerCore : this(new IPEndPoint(IPAddress.IPv6Any, 53)) { } - public DnsServer(IPAddress localIP, int port = 53) - : this(new IPEndPoint(localIP, port)) + public DnsServer(IPAddress localIP) + : this(new IPEndPoint(localIP, 53)) { } public DnsServer(IPEndPoint localEP) { + _dnsCache = new DnsCache(_cacheZoneRoot); + _udpListener = new Socket(AddressFamily.InterNetworkV6, SocketType.Dgram, ProtocolType.Udp); _udpListener.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.IPv6Only, false); _udpListener.Bind(localEP); @@ -75,11 +79,11 @@ namespace DnsServerCore //start reading query packets _udpListenerThread = new Thread(ReadUdpQueryPacketsAsync); _udpListenerThread.IsBackground = true; - _udpListenerThread.Start(_udpListener); + _udpListenerThread.Start(); _tcpListenerThread = new Thread(AcceptTcpConnectionAsync); _tcpListenerThread.IsBackground = true; - _tcpListenerThread.Start(_tcpListener); + _tcpListenerThread.Start(); } #endregion @@ -88,11 +92,6 @@ namespace DnsServerCore bool _disposed = false; - ~DnsServer() - { - Dispose(false); - } - public void Dispose() { Dispose(true); @@ -122,21 +121,19 @@ namespace DnsServerCore private void ReadUdpQueryPacketsAsync(object parameter) { - Socket udpListener = parameter as Socket; - EndPoint remoteEP; FixMemoryStream recvBufferStream = new FixMemoryStream(128); FixMemoryStream sendBufferStream = new FixMemoryStream(512); int bytesRecv; - if (udpListener.AddressFamily == AddressFamily.InterNetwork) + if (_udpListener.AddressFamily == AddressFamily.InterNetwork) remoteEP = new IPEndPoint(IPAddress.Any, 0); else remoteEP = new IPEndPoint(IPAddress.IPv6Any, 0); while (true) { - bytesRecv = udpListener.ReceiveFrom(recvBufferStream.Buffer, ref remoteEP); + bytesRecv = _udpListener.ReceiveFrom(recvBufferStream.Buffer, ref remoteEP); if (bytesRecv > 0) { @@ -167,7 +164,7 @@ namespace DnsServerCore } //send dns datagram - udpListener.SendTo(sendBufferStream.Buffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); + _udpListener.SendTo(sendBufferStream.Buffer, 0, (int)sendBufferStream.Position, SocketFlags.None, remoteEP); } } catch @@ -178,11 +175,9 @@ namespace DnsServerCore private void AcceptTcpConnectionAsync(object parameter) { - Socket tcpListener = parameter as Socket; - while (true) { - Socket socket = tcpListener.Accept(); + Socket socket = _tcpListener.Accept(); socket.NoDelay = true; socket.SendTimeout = TCP_SOCKET_SEND_TIMEOUT; @@ -285,14 +280,17 @@ namespace DnsServerCore switch (request.Header.OPCODE) { case DnsOpcode.StandardQuery: + if (request.Question.Length != 1) + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, request.Header.RecursionDesired, _allowRecursion, false, false, DnsResponseCode.Refused, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null); + try { - DnsDatagram authoritativeResponse = Zone.Query(_authoritativeZoneRoot, request, _enableIPv6); + DnsDatagram authoritativeResponse = _authoritativeZoneRoot.Query(request); - if ((authoritativeResponse.Header.RCODE != DnsResponseCode.Refused) || !request.Header.RecursionDesired || !_allowRecursion) + if ((authoritativeResponse.Header.AuthoritativeAnswer) || !request.Header.RecursionDesired || !_allowRecursion) return authoritativeResponse; - return RecursiveQuery(request); + return ProcessRecursiveQuery(request); } catch { @@ -304,171 +302,42 @@ namespace DnsServerCore } } - public DnsDatagram RecursiveQuery(DnsDatagram request) + public DnsDatagram ProcessRecursiveQuery(DnsDatagram request) { - DnsDatagram originalRequest = request; - List responses = new List(1); + DnsDatagram response = DnsClient.ResolveViaNameServers(_forwarders, request.Question[0], _dnsCache, null, _preferIPv6, false, _retries); - while (true) + if ((response.Header.RCODE == DnsResponseCode.NoError) && (response.Answer.Length > 0)) { - DnsDatagram response = Resolve(request); - responses.Add(response); - - if (response.Header.RCODE != DnsResponseCode.NoError) - break; - - if (response.Answer.Length == 0) - break; - - List newQuestions = new List(); - - foreach (DnsQuestionRecord question in request.Question) + if ((response.Answer[0].Type == DnsResourceRecordType.CNAME) && (request.Question[0].Type != DnsResourceRecordType.CNAME) && (request.Question[0].Type != DnsResourceRecordType.ANY)) { - for (int i = 0; i < response.Answer.Length; i++) - { - DnsResourceRecord answerRecord = response.Answer[i]; + DnsResourceRecord cnameRR = response.Answer[0]; - if ((answerRecord.Type == DnsResourceRecordType.CNAME) && question.Name.Equals(answerRecord.Name, StringComparison.CurrentCultureIgnoreCase)) - { - string cnameDomain = (answerRecord.RDATA as DnsCNAMERecord).CNAMEDomainName; - bool containsAnswer = false; - - for (int j = i + 1; j < response.Answer.Length; j++) - { - DnsResourceRecord answer = response.Answer[j]; - - if ((answer.Type == question.Type) && cnameDomain.Equals(answer.Name, StringComparison.CurrentCultureIgnoreCase)) - { - containsAnswer = true; - break; - } - } - - if (!containsAnswer) - newQuestions.Add(new DnsQuestionRecord((answerRecord.RDATA as DnsCNAMERecord).CNAMEDomainName, question.Type, question.Class)); - - break; - } - } - } - - if (newQuestions.Count == 0) - break; - - request = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, Convert.ToUInt16(newQuestions.Count), 0, 0, 0), newQuestions.ToArray(), null, null, null); - } - - return MergeResponseAnswers(originalRequest, responses); - } - - private DnsDatagram Resolve(DnsDatagram request) - { - DnsDatagram cacheResponse = Zone.Query(_cacheZoneRoot, request, _enableIPv6); - - if (cacheResponse.Header.RCODE != DnsResponseCode.Refused) - return cacheResponse; - - List responses = new List(); - - foreach (DnsQuestionRecord questionRecord in request.Question) - { - NameServerAddress[] nameServers; - - if (_forwarders == null) - { - nameServers = NameServerAddress.GetNameServersFromResponse(cacheResponse, _enableIPv6); - - if (nameServers.Length == 0) - { - if (_enableIPv6) - nameServers = DnsClient.ROOT_NAME_SERVERS_IPv6; - else - nameServers = DnsClient.ROOT_NAME_SERVERS_IPv4; - } - } - else - { - nameServers = _forwarders; - } - - int hopCount = 0; - bool working = true; - bool tcp = false; - - while (working && ((hopCount++) < 64)) - { - DnsClient client = new DnsClient(nameServers, _enableIPv6, tcp, _retries); - - DnsDatagram response = client.Resolve(questionRecord); - - if (response.Header.Truncation) - { - tcp = true; - continue; - } - - Zone.CacheResponse(_cacheZoneRoot, response); - - switch (response.Header.RCODE) - { - case DnsResponseCode.NoError: - if ((response.Answer.Length > 0) || (response.Authority.Length == 0)) - { - responses.Add(response); - working = false; - } - else - { - nameServers = NameServerAddress.GetNameServersFromResponse(response, _enableIPv6); - - if (nameServers.Length == 0) - { - responses.Add(response); - working = false; - } - } - break; - - default: - responses.Add(response); - working = false; - break; - } - } - } - - return MergeResponseAnswers(request, responses); - } - - private DnsDatagram MergeResponseAnswers(DnsDatagram request, List responses) - { - switch (responses.Count) - { - case 0: - return null; - - case 1: - DnsDatagram responseReceived = responses[0]; - - if (responseReceived.Answer.Length == 0) - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, true, true, false, false, responseReceived.Header.RCODE, request.Header.QDCOUNT, responseReceived.Header.ANCOUNT, responseReceived.Header.NSCOUNT, 0), request.Question, responseReceived.Answer, responseReceived.Authority, null); - else - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, true, true, false, false, responseReceived.Header.RCODE, request.Header.QDCOUNT, responseReceived.Header.ANCOUNT, 0, 0), request.Question, responseReceived.Answer, null, null); - - default: List responseAnswer = new List(); - List responseAuthority = new List(); + responseAnswer.Add(cnameRR); - foreach (DnsDatagram response in responses) + while (true) { - responseAnswer.AddRange(response.Answer); + DnsDatagram cnameResponse = DnsClient.ResolveViaNameServers(_forwarders, (cnameRR.RDATA as DnsCNAMERecord).CNAMEDomainName, request.Question[0].Type, _dnsCache, null, _preferIPv6, false, _retries); - if ((response.Answer.Length == 0) && (response.Authority != null)) - responseAuthority.AddRange(response.Authority); + if (cnameResponse.Header.RCODE != DnsResponseCode.NoError) + break; + + if (cnameResponse.Answer.Length == 0) + break; + + responseAnswer.AddRange(cnameResponse.Answer); + + if (cnameResponse.Answer[0].Type != DnsResourceRecordType.CNAME) + break; + + cnameRR = cnameResponse.Answer[0]; } - return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, request.Header.OPCODE, false, false, true, true, false, false, responses[0].Header.RCODE, request.Header.QDCOUNT, Convert.ToUInt16(responseAnswer.Count), Convert.ToUInt16(responseAuthority.Count), 0), request.Question, responseAnswer.ToArray(), responseAuthority.ToArray(), null); + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, DnsResponseCode.NoError, 1, Convert.ToUInt16(responseAnswer.Count), 0, 0), request.Question, responseAnswer.ToArray(), new DnsResourceRecord[] { }, new DnsResourceRecord[] { }); + } } + + return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, false, false, DnsResponseCode.NoError, 1, Convert.ToUInt16(response.Answer.Length), 0, 0), request.Question, response.Answer, new DnsResourceRecord[] { }, new DnsResourceRecord[] { }); } #endregion @@ -493,10 +362,10 @@ namespace DnsServerCore set { _forwarders = value; } } - public bool EnableIPv6 + public bool PreferIPv6 { - get { return _enableIPv6; } - set { _enableIPv6 = value; } + get { return _preferIPv6; } + set { _preferIPv6 = value; } } public int Retries @@ -506,29 +375,37 @@ namespace DnsServerCore } #endregion + + class DnsCache : IDnsCache + { + #region variables + + readonly Zone _cacheZoneRoot; + + #endregion + + #region constructor + + public DnsCache(Zone cacheZoneRoot) + { + _cacheZoneRoot = cacheZoneRoot; + } + + #endregion + + #region public + + public DnsDatagram Query(DnsDatagram request) + { + return _cacheZoneRoot.Query(request); + } + + public void CacheResponse(DnsDatagram response) + { + _cacheZoneRoot.CacheResponse(response); + } + + #endregion + } } - - public class DnsServerException : Exception - { - #region constructors - - public DnsServerException() - : base() - { } - - public DnsServerException(string message) - : base(message) - { } - - public DnsServerException(string message, Exception innerException) - : base(message, innerException) - { } - - protected DnsServerException(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) - : base(info, context) - { } - - #endregion - } - }