diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 0c10f1f4..d17a9ffc 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -217,25 +217,19 @@ namespace DnsServerCore.Dns #region private - private void ReadUdpRequestAsync(object parameter) + private async Task ReadUdpRequestAsync(Socket udpListener) { - Socket udpListener = parameter as Socket; - EndPoint remoteEP; byte[] recvBuffer = new byte[512]; - int bytesRecv; - - if (udpListener.AddressFamily == AddressFamily.InterNetwork) - remoteEP = new IPEndPoint(IPAddress.Any, 0); - else - remoteEP = new IPEndPoint(IPAddress.IPv6Any, 0); try { + UdpReceiveFromResult result; + while (true) { try { - bytesRecv = udpListener.ReceiveFrom(recvBuffer, ref remoteEP); + result = await udpListener.ReceiveFromAsync(recvBuffer); } catch (SocketException ex) { @@ -245,7 +239,7 @@ namespace DnsServerCore.Dns case SocketError.HostUnreachable: case SocketError.MessageSize: case SocketError.NetworkReset: - bytesRecv = 0; + result = null; break; default: @@ -253,19 +247,19 @@ namespace DnsServerCore.Dns } } - if (bytesRecv > 0) + if ((result != null) && (result.BytesReceived > 0)) { try { - DnsDatagram request = DnsDatagram.ReadFromUdp(new MemoryStream(recvBuffer, 0, bytesRecv, false)); + DnsDatagram request = DnsDatagram.ReadFromUdp(new MemoryStream(recvBuffer, 0, result.BytesReceived, false)); - _ = ProcessUdpRequestAsync(udpListener, remoteEP as IPEndPoint, request); + _ = ProcessUdpRequestAsync(udpListener, result.RemoteEndPoint as IPEndPoint, request); } catch (Exception ex) { LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex); + log.Write(result.RemoteEndPoint as IPEndPoint, DnsTransportProtocol.Udp, ex); } } } @@ -277,7 +271,7 @@ namespace DnsServerCore.Dns LogManager log = _log; if (log != null) - log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex); + log.Write(ex); throw; } @@ -349,17 +343,8 @@ namespace DnsServerCore.Dns } } - private void AcceptConnectionAsync(object parameter) + private async Task AcceptConnectionAsync(Socket tcpListener, DnsTransportProtocol protocol, bool usingHttps) { - object[] parameters = parameter as object[]; - - Socket tcpListener = parameters[0] as Socket; - DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[1]; - - bool usingHttps = true; - if (parameters.Length > 2) - usingHttps = (bool)parameters[2]; - IPEndPoint localEP = tcpListener.LocalEndPoint as IPEndPoint; try @@ -372,7 +357,7 @@ namespace DnsServerCore.Dns while (true) { - Socket socket = tcpListener.Accept(); + Socket socket = await tcpListener.AcceptAsync(); _ = ProcessConnectionAsync(socket, protocol, usingHttps); } @@ -2231,61 +2216,36 @@ namespace DnsServerCore.Dns } //start reading query packets - int listenerThreadCount = Math.Max(1, Environment.ProcessorCount); + int listenerTaskCount = Math.Max(1, Environment.ProcessorCount); foreach (Socket udpListener in _udpListeners) { - for (int i = 0; i < listenerThreadCount; i++) - { - Thread thread = new Thread(ReadUdpRequestAsync); - thread.Name = "DNS UDP Read Request [" + i + "]"; - thread.IsBackground = true; - thread.Start(udpListener); - } + for (int i = 0; i < listenerTaskCount; i++) + _ = ReadUdpRequestAsync(udpListener); } foreach (Socket tcpListener in _tcpListeners) { - for (int i = 0; i < listenerThreadCount; i++) - { - Thread thread = new Thread(AcceptConnectionAsync); - thread.Name = "DNS TCP Read Request [" + i + "]"; - thread.IsBackground = true; - thread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp }); - } + for (int i = 0; i < listenerTaskCount; i++) + _ = AcceptConnectionAsync(tcpListener, DnsTransportProtocol.Tcp, false); } foreach (Socket httpListener in _httpListeners) { - for (int i = 0; i < listenerThreadCount; i++) - { - Thread thread = new Thread(AcceptConnectionAsync); - thread.Name = "DNS HTTP Read Request [" + i + "]"; - thread.IsBackground = true; - thread.Start(new object[] { httpListener, DnsTransportProtocol.Https, false }); - } + for (int i = 0; i < listenerTaskCount; i++) + _ = AcceptConnectionAsync(httpListener, DnsTransportProtocol.Https, false); } foreach (Socket tlsListener in _tlsListeners) { - for (int i = 0; i < listenerThreadCount; i++) - { - Thread thread = new Thread(AcceptConnectionAsync); - thread.Name = "DNS TLS Read Request [" + i + "]"; - thread.IsBackground = true; - thread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls }); - } + for (int i = 0; i < listenerTaskCount; i++) + _ = AcceptConnectionAsync(tlsListener, DnsTransportProtocol.Tls, false); } foreach (Socket httpsListener in _httpsListeners) { - for (int i = 0; i < listenerThreadCount; i++) - { - Thread thread = new Thread(AcceptConnectionAsync); - thread.Name = "DNS HTTPS Read Request [" + i + "]"; - thread.IsBackground = true; - thread.Start(new object[] { httpsListener, DnsTransportProtocol.Https }); - } + for (int i = 0; i < listenerTaskCount; i++) + _ = AcceptConnectionAsync(httpsListener, DnsTransportProtocol.Https, true); } _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingTimerCallback, null, Timeout.Infinite, Timeout.Infinite);