diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 600cbf9c..fe117bc6 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -232,13 +232,29 @@ namespace DnsServerCore.Dns try { - UdpReceiveFromResult result; + EndPoint epAny; + + switch (udpListener.AddressFamily) + { + case AddressFamily.InterNetwork: + epAny = new IPEndPoint(IPAddress.Any, 0); + break; + + case AddressFamily.InterNetworkV6: + epAny = new IPEndPoint(IPAddress.IPv6Any, 0); + break; + + default: + throw new NotSupportedException("AddressFamily not supported."); + } + + SocketReceiveFromResult result; while (true) { try { - result = await udpListener.ReceiveFromAsync(recvBuffer); + result = await udpListener.ReceiveFromAsync(recvBuffer, SocketFlags.None, epAny); } catch (SocketException ex) { @@ -248,7 +264,7 @@ namespace DnsServerCore.Dns case SocketError.HostUnreachable: case SocketError.MessageSize: case SocketError.NetworkReset: - result = null; + result = default; break; default: @@ -256,11 +272,11 @@ namespace DnsServerCore.Dns } } - if ((result != null) && (result.BytesReceived > 0)) + if (result.ReceivedBytes > 0) { try { - DnsDatagram request = DnsDatagram.ReadFromUdp(new MemoryStream(recvBuffer, 0, result.BytesReceived, false)); + DnsDatagram request = DnsDatagram.ReadFromUdp(new MemoryStream(recvBuffer, 0, result.ReceivedBytes, false)); _ = ProcessUdpRequestAsync(udpListener, result.RemoteEndPoint as IPEndPoint, request); } @@ -328,7 +344,7 @@ namespace DnsServerCore.Dns } //send dns datagram async - await udpListener.SendToAsync(sendBuffer, 0, (int)sendBufferStream.Position, remoteEP); + await udpListener.SendToAsync(new ArraySegment(sendBuffer, 0, (int)sendBufferStream.Position), SocketFlags.None, remoteEP); LogManager queryLog = _queryLog; if (queryLog != null) @@ -597,14 +613,20 @@ namespace DnsServerCore.Dns { //intentionally blocking public IP addresses from using DNS-over-HTTP (without TLS) //this feature is intended to be used with an SSL terminated reverse proxy like nginx on private network - await SendErrorAsync(stream, 403, "DNS-over-HTTPS (DoH) queries are supported only on HTTPS."); + await SendErrorAsync(stream, "close", 403, "DNS-over-HTTPS (DoH) queries are supported only on HTTPS."); return; } DnsTransportProtocol protocol = DnsTransportProtocol.Udp; string strRequestAcceptTypes = httpRequest.Headers[HttpRequestHeader.Accept]; - if (!string.IsNullOrEmpty(strRequestAcceptTypes)) + if (string.IsNullOrEmpty(strRequestAcceptTypes)) + { + string strContentType = httpRequest.Headers[HttpRequestHeader.ContentType]; + if (strContentType == "application/dns-message") + protocol = DnsTransportProtocol.Https; + } + else { foreach (string acceptType in strRequestAcceptTypes.Split(',')) { @@ -692,7 +714,7 @@ namespace DnsServerCore.Dns dnsResponse.WriteToUdp(mS); byte[] buffer = mS.ToArray(); - await SendContentAsync(stream, "application/dns-message", buffer); + await SendContentAsync(stream, requestConnection, "application/dns-message", buffer); } LogManager queryLog = _queryLog; @@ -730,7 +752,7 @@ namespace DnsServerCore.Dns jsonWriter.Flush(); byte[] buffer = mS.ToArray(); - await SendContentAsync(stream, "application/dns-json; charset=utf-8", buffer); + await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", buffer); } LogManager queryLog = _queryLog; @@ -746,7 +768,7 @@ namespace DnsServerCore.Dns break; default: - await SendErrorAsync(stream, 406, "Only application/dns-message and application/dns-json types are accepted."); + await SendErrorAsync(stream, requestConnection, 406, "Only application/dns-message and application/dns-json types are accepted."); break; } @@ -760,7 +782,7 @@ namespace DnsServerCore.Dns if (!path.StartsWith("/") || path.Contains("/../") || path.Contains("/.../")) { - await SendErrorAsync(stream, 404); + await SendErrorAsync(stream, requestConnection, 404); break; } @@ -771,11 +793,11 @@ namespace DnsServerCore.Dns if (!path.StartsWith(_dohwwwFolder) || !File.Exists(path)) { - await SendErrorAsync(stream, 404); + await SendErrorAsync(stream, requestConnection, 404); break; } - await SendFileAsync(stream, path); + await SendFileAsync(stream, requestConnection, path); break; } } @@ -798,31 +820,31 @@ namespace DnsServerCore.Dns if (log != null) log.Write(remoteEP, dnsProtocol, ex); - await SendErrorAsync(stream, ex); + await SendErrorAsync(stream, "close", ex); } } - private static async Task SendContentAsync(Stream outputStream, string contentType, byte[] bufferContent) + private static async Task SendContentAsync(Stream outputStream, string connection, string contentType, byte[] bufferContent) { - byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); await outputStream.WriteAsync(bufferContent, 0, bufferContent.Length); await outputStream.FlushAsync(); } - private static Task SendErrorAsync(Stream outputStream, Exception ex) + private static Task SendErrorAsync(Stream outputStream, string connection, Exception ex) { - return SendErrorAsync(outputStream, 500, ex.ToString()); + return SendErrorAsync(outputStream, connection, 500, ex.ToString()); } - private static async Task SendErrorAsync(Stream outputStream, int statusCode, string message = null) + private static async Task SendErrorAsync(Stream outputStream, string connection, int statusCode, string message = null) { try { string statusString = statusCode + " " + GetHttpStatusString((HttpStatusCode)statusCode); byte[] bufferContent = Encoding.UTF8.GetBytes("" + statusString + "

" + statusString + "

" + (message == null ? "" : "

" + message + "

") + ""); - byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 " + statusString + "\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 " + statusString + "\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); await outputStream.WriteAsync(bufferContent, 0, bufferContent.Length); @@ -832,11 +854,11 @@ namespace DnsServerCore.Dns { } } - private static async Task SendFileAsync(Stream outputStream, string filePath) + private static async Task SendFileAsync(Stream outputStream, string connection, string filePath) { using (FileStream fS = new FileStream(filePath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite)) { - byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + WebUtilities.GetContentType(filePath).MediaType + "\r\nContent-Length: " + fS.Length + "\r\nCache-Control: private, max-age=300\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n"); + byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + WebUtilities.GetContentType(filePath).MediaType + "\r\nContent-Length: " + fS.Length + "\r\nCache-Control: private, max-age=300\r\nX-Robots-Tag: noindex, nofollow\r\nConnection: " + connection + "\r\n\r\n"); await outputStream.WriteAsync(bufferHeader, 0, bufferHeader.Length); await fS.CopyToAsync(outputStream); @@ -904,6 +926,12 @@ namespace DnsServerCore.Dns case DnsResourceRecordType.IXFR: return await ProcessZoneTransferQueryAsync(request, remoteEP); + case DnsResourceRecordType.ANY: + if (protocol == DnsTransportProtocol.Udp) //force TCP for ANY request + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, true, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NoError, request.Question); + + break; + case DnsResourceRecordType.MAILB: case DnsResourceRecordType.MAILA: return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.NotImplemented, request.Question); @@ -911,29 +939,28 @@ namespace DnsServerCore.Dns case DnsResourceRecordType.FWD: case DnsResourceRecordType.APP: return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.Refused, request.Question); - - default: - DnsDatagram response; - - //check in allowed zone - bool inAllowedZone = _allowedZoneManager.Query(request).RCODE != DnsResponseCode.Refused; - if (!inAllowedZone) - { - //check in blocked zone and block list zone - response = ProcessBlockedQuery(request); - if (response != null) - return response; - } - - //query authoritative zone - response = await ProcessAuthoritativeQueryAsync(request, remoteEP, inAllowedZone, isRecursionAllowed, protocol); - - if ((response.RCODE != DnsResponseCode.Refused) || !request.RecursionDesired || !isRecursionAllowed) - return response; - - //do recursive query - return await ProcessRecursiveQueryAsync(request, remoteEP, protocol, null, !inAllowedZone, false); } + + DnsDatagram response; + + //check in allowed zone + bool inAllowedZone = _allowedZoneManager.Query(request).RCODE != DnsResponseCode.Refused; + if (!inAllowedZone) + { + //check in blocked zone and block list zone + response = ProcessBlockedQuery(request); + if (response != null) + return response; + } + + //query authoritative zone + response = await ProcessAuthoritativeQueryAsync(request, remoteEP, inAllowedZone, isRecursionAllowed, protocol); + + if ((response.RCODE != DnsResponseCode.Refused) || !request.RecursionDesired || !isRecursionAllowed) + return response; + + //do recursive query + return await ProcessRecursiveQueryAsync(request, remoteEP, protocol, null, !inAllowedZone, false); } catch (InvalidDomainNameException) { @@ -1036,7 +1063,7 @@ namespace DnsServerCore.Dns private async Task ProcessAuthoritativeQueryAsync(DnsDatagram request, IPEndPoint remoteEP, bool inAllowedZone, bool isRecursionAllowed, DnsTransportProtocol protocol) { - DnsDatagram response = _authZoneManager.Query(request); + DnsDatagram response = _authZoneManager.Query(request, isRecursionAllowed); response.Tag = StatsResponseType.Authoritative; bool reprocessResponse; @@ -1112,9 +1139,6 @@ namespace DnsServerCore.Dns } while (reprocessResponse); - if (response.RecursionAvailable != isRecursionAllowed) - response = new DnsDatagram(response.Identifier, response.IsResponse, response.OPCODE, response.AuthoritativeAnswer, response.Truncation, response.RecursionDesired, isRecursionAllowed, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question, response.Answer, response.Authority, response.Additional); - return response; } @@ -1176,7 +1200,7 @@ namespace DnsServerCore.Dns DnsDatagram newRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord(lastDomain, request.Question[0].Type, request.Question[0].Class) }); //query authoritative zone first - lastResponse = _authZoneManager.Query(newRequest); + lastResponse = _authZoneManager.Query(newRequest, isRecursionAllowed); if (lastResponse.RCODE == DnsResponseCode.Refused) { @@ -1319,7 +1343,7 @@ namespace DnsServerCore.Dns DnsDatagram newRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord(lastDomain, request.Question[0].Type, request.Question[0].Class) }); //query authoritative zone first - lastResponse = _authZoneManager.Query(newRequest); + lastResponse = _authZoneManager.Query(newRequest, isRecursionAllowed); if (lastResponse.RCODE == DnsResponseCode.Refused) { @@ -1894,7 +1918,7 @@ namespace DnsServerCore.Dns { reQueryAuthZone = false; - DnsDatagram response = _authZoneManager.Query(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, false, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { eligibleQuerySample })); + DnsDatagram response = _authZoneManager.Query(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, false, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { eligibleQuerySample }), true); switch (response.RCODE) { case DnsResponseCode.Refused: //zone not hosted; do refresh