From f5da78c5791830776a0000fbc6d3de9b97c7daee Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 10 Apr 2021 13:58:48 +0530 Subject: [PATCH] DnsServer: reusing memory stream in ReadUdpRequestAsync(). Updated ProcessAPPAsync() to return server failure when app or class path is not found. Updated CachePrefetchRefreshTimerCallback() and RefreshCacheAsync() to prevent double refresh attempts. Updated CachePrefetchRefreshTimerCallback() to use threadpool tasks for better concurrency. Updated CacheMaintenanceTimerCallback() to reset due time in finally for next interval callback and using sync lock. Fixed bug in UpdateThisServer() logic. --- DnsServerCore/Dns/DnsServer.cs | 145 +++++++++++++++++++++------------ 1 file changed, 95 insertions(+), 50 deletions(-) diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 688a7bc9..f80df940 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -128,6 +128,7 @@ namespace DnsServerCore.Dns IList _cacheRefreshSampleList; Timer _cacheMaintenanceTimer; + readonly object _cacheMaintenanceTimerLock = new object(); const int CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL = 15 * 60 * 1000; const int CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL = 15 * 60 * 1000; @@ -229,7 +230,9 @@ namespace DnsServerCore.Dns private async Task ReadUdpRequestAsync(Socket udpListener) { - byte[] recvBuffer = new byte[1500]; + const int BUFFER_SIZE = 512; + byte[] recvBuffer = new byte[BUFFER_SIZE]; + MemoryStream recvBufferStream = new MemoryStream(recvBuffer); try { @@ -253,6 +256,8 @@ namespace DnsServerCore.Dns while (true) { + recvBufferStream.SetLength(BUFFER_SIZE); //resetting length before using buffer + try { result = await udpListener.ReceiveFromAsync(recvBuffer, SocketFlags.None, epAny); @@ -277,12 +282,12 @@ namespace DnsServerCore.Dns { try { - using (MemoryStream mS = new MemoryStream(recvBuffer, 0, result.ReceivedBytes, false)) - { - DnsDatagram request = DnsDatagram.ReadFromUdp(mS); + recvBufferStream.Position = 0; + recvBufferStream.SetLength(result.ReceivedBytes); - _ = ProcessUdpRequestAsync(udpListener, result.RemoteEndPoint as IPEndPoint, request); - } + DnsDatagram request = DnsDatagram.ReadFromUdp(recvBufferStream); + + _ = ProcessUdpRequestAsync(udpListener, result.RemoteEndPoint as IPEndPoint, request); } catch (EndOfStreamException) { @@ -325,9 +330,12 @@ namespace DnsServerCore.Dns else { //format error - LogManager log = _log; - if (log != null) - log.Write(remoteEP, DnsTransportProtocol.Udp, request.ParsingException); + if (!(request.ParsingException is IOException)) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP, DnsTransportProtocol.Udp, request.ParsingException); + } //format error response response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); @@ -540,9 +548,12 @@ namespace DnsServerCore.Dns if (queryLog != null) queryLog.Write(remoteEP, protocol, request, response); - LogManager log = _log; - if (log != null) - log.Write(remoteEP, protocol, request.ParsingException); + if (!(request.ParsingException is IOException)) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP, protocol, request.ParsingException); + } } //send response @@ -684,9 +695,9 @@ namespace DnsServerCore.Dns if (strContentType != "application/dns-message") throw new NotSupportedException("DNS request type not supported: " + strContentType); - using (MemoryStream mS = new MemoryStream()) + using (MemoryStream mS = new MemoryStream(32)) { - await httpRequest.InputStream.CopyToAsync(mS, 512); + await httpRequest.InputStream.CopyToAsync(mS, 32); mS.Position = 0; dnsRequest = DnsDatagram.ReadFromUdp(mS); @@ -707,9 +718,12 @@ namespace DnsServerCore.Dns else { //format error - LogManager log = _log; - if (log != null) - log.Write(remoteEP, protocol, dnsRequest.ParsingException); + if (!(dnsRequest.ParsingException is IOException)) + { + LogManager log = _log; + if (log != null) + log.Write(remoteEP, protocol, dnsRequest.ParsingException); + } //format error response dnsResponse = new DnsDatagram(dnsRequest.Identifier, true, dnsRequest.OPCODE, false, false, dnsRequest.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, dnsRequest.Question); @@ -717,12 +731,12 @@ namespace DnsServerCore.Dns if (dnsResponse != null) { - using (MemoryStream mS = new MemoryStream()) + using (MemoryStream mS = new MemoryStream(512)) { dnsResponse.WriteToUdp(mS); - byte[] buffer = mS.ToArray(); - await SendContentAsync(stream, requestConnection, "application/dns-message", buffer); + mS.Position = 0; + await SendContentAsync(stream, requestConnection, "application/dns-message", mS); } LogManager queryLog = _queryLog; @@ -753,14 +767,14 @@ namespace DnsServerCore.Dns DnsDatagram dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); if (dnsResponse != null) { - using (MemoryStream mS = new MemoryStream()) + using (MemoryStream mS = new MemoryStream(512)) { JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); dnsResponse.WriteToJson(jsonWriter); jsonWriter.Flush(); - byte[] buffer = mS.ToArray(); - await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", buffer); + mS.Position = 0; + await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", mS); } LogManager queryLog = _queryLog; @@ -832,12 +846,12 @@ namespace DnsServerCore.Dns } } - private static async Task SendContentAsync(Stream outputStream, string connection, string contentType, byte[] bufferContent) + private static async Task SendContentAsync(Stream outputStream, string connection, string contentType, Stream content) { - byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + content.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); await outputStream.WriteAsync(bufferHeader); - await outputStream.WriteAsync(bufferContent); + await content.CopyToAsync(outputStream); await outputStream.FlushAsync(); } @@ -964,7 +978,7 @@ namespace DnsServerCore.Dns catch (InvalidDomainNameException) { //format error response - return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question); + return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.FormatError, request.Question); } catch (Exception ex) { @@ -1153,7 +1167,14 @@ namespace DnsServerCore.Dns AuthZoneInfo zoneInfo = _authZoneManager.GetAuthZoneInfo(appResourceRecord.Name); DnsDatagram appResponse = await appRequestHandler.ProcessRequestAsync(request, remoteEP, zoneInfo.Name, appResourceRecord.TtlValue, appRecord.Data, isRecursionAllowed, application.DnsServer); - if (appResponse != null) + if (appResponse == null) + { + //return no error response with SOA + IReadOnlyList authority = zoneInfo.GetRecords(DnsResourceRecordType.SOA); + + return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, request.Question, null, authority); + } + else { if (appResponse.AuthoritativeAnswer) appResponse.Tag = StatsResponseType.Authoritative; @@ -1175,12 +1196,12 @@ namespace DnsServerCore.Dns log.Write(remoteEP, protocol, "DNS application '" + appRecord.AppName + "' was not found: " + appResourceRecord.Name); } - //return no error response with SOA + //return server failure response with SOA { AuthZoneInfo zoneInfo = _authZoneManager.GetAuthZoneInfo(request.Question[0].Name); IReadOnlyList authority = zoneInfo.GetRecords(DnsResourceRecordType.SOA); - return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, request.Question, null, authority); + return new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.ServerFailure, request.Question, null, authority); } } @@ -1883,27 +1904,29 @@ namespace DnsServerCore.Dns DnsDatagram request = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { sample.SampleQuestion }); DnsDatagram response = await ProcessRecursiveQueryAsync(request, new IPEndPoint(IPAddress.Any, 0), DnsTransportProtocol.Udp, sample.ViaForwarders, true); - bool removeFromSampleList = true; + bool addBackToSampleList = false; DateTime utcNow = DateTime.UtcNow; foreach (DnsResourceRecord answer in response.Answer) { if ((answer.OriginalTtlValue > _cachePrefetchEligibility) && (utcNow.AddSeconds(answer.TtlValue) < _cachePrefetchSamplingTimerTriggersOn)) { - //answer expires before next sampling so dont remove from list to allow refreshing it - removeFromSampleList = false; + //answer expires before next sampling so add back to the list to allow refreshing it + addBackToSampleList = true; break; } } - if (removeFromSampleList) - cacheRefreshSampleList[sampleQuestionIndex] = null; + if (addBackToSampleList) + cacheRefreshSampleList[sampleQuestionIndex] = sample; //put back into sample list to allow refreshing it again } catch (Exception ex) { LogManager log = _log; if (log != null) log.Write(ex); + + cacheRefreshSampleList[sampleQuestionIndex] = sample; //put back into sample list to allow refreshing it again } } @@ -2094,7 +2117,10 @@ namespace DnsServerCore.Dns if (!IsCacheRefreshNeeded(sample.SampleQuestion, _cachePrefetchTrigger + 2)) continue; - _ = RefreshCacheAsync(cacheRefreshSampleList, sample, i); + cacheRefreshSampleList[i] = null; //remove from sample list to avoid concurrent refresh attempt + + int sampleQuestionIndex = i; + _ = Task.Run(delegate () { return RefreshCacheAsync(cacheRefreshSampleList, sample, sampleQuestionIndex); }); //run task in threadpool since its long running } } } @@ -2126,6 +2152,14 @@ namespace DnsServerCore.Dns if (log != null) log.Write(ex); } + finally + { + lock (_cacheMaintenanceTimerLock) + { + if (_cacheMaintenanceTimer != null) + _cacheMaintenanceTimer.Change(CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL, Timeout.Infinite); + } + } } private void ResetPrefetchTimers() @@ -2165,20 +2199,28 @@ namespace DnsServerCore.Dns private void UpdateThisServer() { - if (_thisServer == null) + if ((_localEndPoints == null) || (_localEndPoints.Count == 0)) { - if ((_localEndPoints == null) || (_localEndPoints.Count == 0)) - _thisServer = new NameServerAddress(_serverDomain, IPAddress.Loopback); - else if (_localEndPoints[0].Address.Equals(IPAddress.Any)) - _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.Loopback, _localEndPoints[0].Port)); - else if (_localEndPoints[0].Equals(IPAddress.IPv6Any)) - _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.IPv6Loopback, _localEndPoints[0].Port)); - else - _thisServer = new NameServerAddress(_serverDomain, _localEndPoints[0]); + _thisServer = new NameServerAddress(_serverDomain, IPAddress.Loopback); } else { - _thisServer = new NameServerAddress(_serverDomain, _thisServer.IPEndPoint); + foreach (IPEndPoint localEndPoint in _localEndPoints) + { + if (localEndPoint.Address.Equals(IPAddress.Any)) + { + _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.Loopback, localEndPoint.Port)); + return; + } + + if (localEndPoint.Address.Equals(IPAddress.IPv6Any)) + { + _thisServer = new NameServerAddress(_serverDomain, new IPEndPoint(IPAddress.IPv6Loopback, localEndPoint.Port)); + return; + } + } + + _thisServer = new NameServerAddress(_serverDomain, _localEndPoints[0]); } } @@ -2449,7 +2491,7 @@ namespace DnsServerCore.Dns _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingTimerCallback, null, Timeout.Infinite, Timeout.Infinite); _cachePrefetchRefreshTimer = new Timer(CachePrefetchRefreshTimerCallback, null, Timeout.Infinite, Timeout.Infinite); - _cacheMaintenanceTimer = new Timer(CacheMaintenanceTimerCallback, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, CACHE_MAINTENANCE_TIMER_PERIODIC_INTERVAL); + _cacheMaintenanceTimer = new Timer(CacheMaintenanceTimerCallback, null, CACHE_MAINTENANCE_TIMER_INITIAL_INTEVAL, Timeout.Infinite); _state = ServiceState.Running; @@ -2482,10 +2524,13 @@ namespace DnsServerCore.Dns } } - if (_cacheMaintenanceTimer != null) + lock (_cacheMaintenanceTimerLock) { - _cacheMaintenanceTimer.Dispose(); - _cacheMaintenanceTimer = null; + if (_cacheMaintenanceTimer != null) + { + _cacheMaintenanceTimer.Dispose(); + _cacheMaintenanceTimer = null; + } } foreach (Socket udpListener in _udpListeners)