diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 688a7bc9..f80df940 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -128,6 +128,7 @@ namespace DnsServerCore.Dns IList _cacheRefreshSampleList; Timer _cacheMaintenanceTimer; + readonly object _cacheMaintenanceTimerLock = new object(); const int CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL = 15 * 60 * 1000; const int CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL = 15 * 60 * 1000; @@ -229,7 +230,9 @@ namespace DnsServerCore.Dns private async Task ReadUdpRequestAsync(Socket udpListener) { - byte[] recvBuffer = new byte[1500]; + const int BUFFER_SIZE = 512; + byte[] recvBuffer = new byte[BUFFER_SIZE]; + MemoryStream recvBufferStream = new MemoryStream(recvBuffer); try { @@ -253,6 +256,8 @@ namespace DnsServerCore.Dns while (true) { + recvBufferStream.SetLength(BUFFER_SIZE); //resetting length before using buffer + try { result = await udpListener.ReceiveFromAsync(recvBuffer, SocketFlags.None, epAny); @@ -277,12 +282,12 @@ namespace DnsServerCore.Dns { try { - using (MemoryStream mS = new MemoryStream(recvBuffer, 0, result.ReceivedBytes, false)) - { - DnsDatagram request = DnsDatagram.ReadFromUdp(mS); + recvBufferStream.Position = 0; + recvBufferStream.SetLength(result.ReceivedBytes); - _ = ProcessUdpRequestAsync(udpListener, result.RemoteEndPoint as IPEndPoint, request); - } + DnsDatagram request = DnsDatagram.ReadFromUdp(recvBufferStream); + + _ = ProcessUdpRequestAsync(udpListener, result.RemoteEndPoint as IPEndPoint, request); } catch (EndOfStreamException) { @@ -325,9 +330,12 @@ namespace DnsServerCore.Dns else { //format error - LogManager log = _log; - if (log != null) - log.Write(remoteEP, DnsTransportProtocol.Udp, request.ParsingException); + if (!(request.ParsingException is IOException)) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP, DnsTransportProtocol.Udp, request.ParsingException); + } //format error response response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); @@ -540,9 +548,12 @@ namespace DnsServerCore.Dns if (queryLog != null) queryLog.Write(remoteEP, protocol, request, response); - LogManager log = _log; - if (log != null) - log.Write(remoteEP, protocol, request.ParsingException); + if (!(request.ParsingException is IOException)) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP, protocol, request.ParsingException); + } } //send response @@ -684,9 +695,9 @@ namespace DnsServerCore.Dns if (strContentType != "application/dns-message") throw new NotSupportedException("DNS request type not supported: " + strContentType); - using (MemoryStream mS = new MemoryStream()) + using (MemoryStream mS = new MemoryStream(32)) { - await httpRequest.InputStream.CopyToAsync(mS, 512); + await httpRequest.InputStream.CopyToAsync(mS, 32); mS.Position = 0; dnsRequest = DnsDatagram.ReadFromUdp(mS); @@ -707,9 +718,12 @@ namespace DnsServerCore.Dns else { //format error - LogManager log = _log; - if (log != null) - log.Write(remoteEP, protocol, dnsRequest.ParsingException); + if (!(dnsRequest.ParsingException is IOException)) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP, protocol, dnsRequest.ParsingException); + } //format error response dnsResponse = new DnsDatagram(dnsRequest.Identifier, true, dnsRequest.OPCODE, false, false, dnsRequest.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, dnsRequest.Question); @@ -717,12 +731,12 @@ namespace DnsServerCore.Dns if (dnsResponse != null) { - using (MemoryStream mS = new MemoryStream()) + using (MemoryStream mS = new MemoryStream(512)) { dnsResponse.WriteToUdp(mS); - byte[] buffer = mS.ToArray(); - await SendContentAsync(stream, requestConnection, "application/dns-message", buffer); + mS.Position = 0; + await SendContentAsync(stream, requestConnection, "application/dns-message", mS); } LogManager queryLog = _queryLog; @@ -753,14 +767,14 @@ namespace DnsServerCore.Dns DnsDatagram dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); if (dnsResponse != null) { - using (MemoryStream mS = new MemoryStream()) + using (MemoryStream mS = new MemoryStream(512)) { JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); dnsResponse.WriteToJson(jsonWriter); jsonWriter.Flush(); - byte[] buffer = mS.ToArray(); - await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", buffer); + mS.Position = 0; + await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", mS); } LogManager queryLog = _queryLog; @@ -832,12 +846,12 @@ namespace DnsServerCore.Dns } } - private static async Task SendContentAsync(Stream outputStream, string connection, string contentType, byte[] bufferContent) + private static async Task SendContentAsync(Stream outputStream, string connection, string contentType, Stream content) { - byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + content.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); await outputStream.WriteAsync(bufferHeader); - await outputStream.WriteAsync(bufferContent); + await content.CopyToAsync(outputStream); await outputStream.FlushAsync(); } @@ -964,7 +978,7 @@ namespace DnsServerCore.Dns catch (InvalidDomainNameException) { //format error response - return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); + return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.FormatError, request.Question); } catch (Exception ex) { @@ -1153,7 +1167,14 @@ namespace DnsServerCore.Dns AuthZoneInfo zoneInfo = _authZoneManager.GetAuthZoneInfo(appResourceRecord.Name); DnsDatagram appResponse = await appRequestHandler.ProcessRequestAsync(request, remoteEP, zoneInfo.Name, appResourceRecord.TtlValue, appRecord.Data, isRecursionAllowed, application.DnsServer); - if (appResponse != null) + if (appResponse == null) + { + //return no error response with SOA + IReadOnlyList authority = zoneInfo.GetRecords(DnsResourceRecordType.SOA); + + return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, request.Question, null, authority); + } + else { if (appResponse.AuthoritativeAnswer) appResponse.Tag = StatsResponseType.Authoritative; @@ -1175,12 +1196,12 @@ namespace DnsServerCore.Dns log.Write(remoteEP, protocol, "DNS application '" + appRecord.AppName + "' was not found: " + appResourceRecord.Name); } - //return no error response with SOA + //return server failure response with SOA { AuthZoneInfo zoneInfo = _authZoneManager.GetAuthZoneInfo(request.Question[0].Name); IReadOnlyList authority = zoneInfo.GetRecords(DnsResourceRecordType.SOA); - return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, request.Question, null, authority); + return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.ServerFailure, request.Question, null, authority); } } @@ -1883,27 +1904,29 @@ namespace DnsServerCore.Dns DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sample.SampleQuestion }); DnsDatagram response = await ProcessRecursiveQueryAsync(request, new IPEndPoint(IPAddress.Any, 0), DnsTransportProtocol.Udp, sample.ViaForwarders, true); - bool removeFromSampleList = true; + bool addBackToSampleList = false; DateTime utcNow = DateTime.UtcNow; foreach (DnsResourceRecord answer in response.Answer) { if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TtlValue) < _cachePrefetchSamplingTimerTriggersOn)) { - //answer expires before next sampling so dont remove from list to allow refreshing it - removeFromSampleList = false; + //answer expires before next sampling so add back to the list to allow refreshing it + addBackToSampleList = true; break; } } - if (removeFromSampleList) - cacheRefreshSampleList[sampleQuestionIndex] = null; + if (addBackToSampleList) + cacheRefreshSampleList[sampleQuestionIndex] = sample; //put back into sample list to allow refreshing it again } catch (Exception ex) { LogManager log = _log; if (log != null) log.Write(ex); + + cacheRefreshSampleList[sampleQuestionIndex] = sample; //put back into sample list to allow refreshing it again } } @@ -2094,7 +2117,10 @@ namespace DnsServerCore.Dns if (!IsCacheRefreshNeeded(sample.SampleQuestion, _cachePrefetchTrigger + 2)) continue; - _ = RefreshCacheAsync(cacheRefreshSampleList, sample, i); + cacheRefreshSampleList[i] = null; //remove from sample list to avoid concurrent refresh attempt + + int sampleQuestionIndex = i; + _ = Task.Run(delegate () { return RefreshCacheAsync(cacheRefreshSampleList, sample, sampleQuestionIndex); }); //run task in threadpool since its long running } } } @@ -2126,6 +2152,14 @@ namespace DnsServerCore.Dns if (log != null) log.Write(ex); } + finally + { + lock (_cacheMaintenanceTimerLock) + { + if (_cacheMaintenanceTimer != null) + _cacheMaintenanceTimer.Change(CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL, Timeout.Infinite); + } + } } private void ResetPrefetchTimers() @@ -2165,20 +2199,28 @@ namespace DnsServerCore.Dns private void UpdateThisServer() { - if (_thisServer == null) + if ((_localEndPoints == null) || (_localEndPoints.Count == 0)) { - if ((_localEndPoints == null) || (_localEndPoints.Count == 0)) - _thisServer = new NameServerAddress(_serverDomain, IPAddress.Loopback); - else if (_localEndPoints[0].Address.Equals(IPAddress.Any)) - _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.Loopback, _localEndPoints[0].Port)); - else if (_localEndPoints[0].Equals(IPAddress.IPv6Any)) - _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.IPv6Loopback, _localEndPoints[0].Port)); - else - _thisServer = new NameServerAddress(_serverDomain, _localEndPoints[0]); + _thisServer = new NameServerAddress(_serverDomain, IPAddress.Loopback); } else { - _thisServer = new NameServerAddress(_serverDomain, _thisServer.IPEndPoint); + foreach (IPEndPoint localEndPoint in _localEndPoints) + { + if (localEndPoint.Address.Equals(IPAddress.Any)) + { + _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.Loopback, localEndPoint.Port)); + return; + } + + if (localEndPoint.Address.Equals(IPAddress.IPv6Any)) + { + _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.IPv6Loopback, localEndPoint.Port)); + return; + } + } + + _thisServer = new NameServerAddress(_serverDomain, _localEndPoints[0]); } } @@ -2449,7 +2491,7 @@ namespace DnsServerCore.Dns _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingTimerCallback, null, Timeout.Infinite, Timeout.Infinite); _cachePrefetchRefreshTimer = new Timer(CachePrefetchRefreshTimerCallback, null, Timeout.Infinite, Timeout.Infinite); - _cacheMaintenanceTimer = new Timer(CacheMaintenanceTimerCallback, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL); + _cacheMaintenanceTimer = new Timer(CacheMaintenanceTimerCallback, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, Timeout.Infinite); _state = ServiceState.Running; @@ -2482,10 +2524,13 @@ namespace DnsServerCore.Dns } } - if (_cacheMaintenanceTimer != null) + lock (_cacheMaintenanceTimerLock) { - _cacheMaintenanceTimer.Dispose(); - _cacheMaintenanceTimer = null; + if (_cacheMaintenanceTimer != null) + { + _cacheMaintenanceTimer.Dispose(); + _cacheMaintenanceTimer = null; + } } foreach (Socket udpListener in _udpListeners)