DnsServer: implemented DoT and DoH protocol support.

This commit is contained in:
Shreyas Zare
2019-03-02 20:05:05 +05:30
parent ce43a25e43
commit 2f8c1a97c8

View File

@@ -17,12 +17,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using Newtonsoft.Json;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using TechnitiumLibrary.IO;
using TechnitiumLibrary.Net;
@@ -47,14 +52,20 @@ namespace DnsServerCore
#region variables
const int UDP_LISTENER_THREAD_COUNT = 3;
const int LISTENER_THREAD_COUNT = 3;
IPEndPoint[] _localEPs;
IPAddress[] _localIPs;
List<Socket> _udpListeners = new List<Socket>();
List<Socket> _tcpListeners = new List<Socket>();
List<Socket> _tlsListeners = new List<Socket>();
List<Socket> _httpsListeners = new List<Socket>();
List<Thread> _listenerThreads = new List<Thread>();
bool _enableDoT = false;
bool _enableDoH = false;
X509Certificate2 _certificate;
readonly Zone _authoritativeZoneRoot = new Zone(true);
readonly Zone _cacheZoneRoot = new Zone(false);
readonly Zone _allowedZoneRoot = new Zone(true);
@@ -66,8 +77,8 @@ namespace DnsServerCore
bool _allowRecursionOnlyForPrivateNetworks = false;
NetProxy _proxy;
NameServerAddress[] _forwarders;
DnsClientProtocol _forwarderProtocol = DnsClientProtocol.Udp;
DnsClientProtocol _recursiveResolveProtocol = DnsClientProtocol.Udp;
DnsTransportProtocol _forwarderProtocol = DnsTransportProtocol.Udp;
DnsTransportProtocol _recursiveResolveProtocol = DnsTransportProtocol.Udp;
bool _preferIPv6 = false;
int _retries = 3;
int _timeout = 2000;
@@ -102,30 +113,16 @@ namespace DnsServerCore
}
public DnsServer()
: this(new IPEndPoint[] { new IPEndPoint(IPAddress.Any, 53), new IPEndPoint(IPAddress.IPv6Any, 53) })
: this(new IPAddress[] { IPAddress.Any, IPAddress.IPv6Any })
{ }
public DnsServer(IPAddress localIP)
: this(new IPEndPoint(localIP, 53))
{ }
public DnsServer(IPEndPoint localEP)
: this(new IPEndPoint[] { localEP })
: this(new IPAddress[] { localIP })
{ }
public DnsServer(IPAddress[] localIPs)
{
_localEPs = new IPEndPoint[localIPs.Length];
for (int i = 0; i < _localEPs.Length; i++)
_localEPs[i] = new IPEndPoint(localIPs[i], 53);
_dnsCache = new DnsCache(_cacheZoneRoot);
}
public DnsServer(IPEndPoint[] localEPs)
{
_localEPs = localEPs;
_localIPs = localIPs;
_dnsCache = new DnsCache(_cacheZoneRoot);
}
@@ -212,7 +209,7 @@ namespace DnsServerCore
{
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, false, ex);
log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex);
}
}
}
@@ -224,7 +221,7 @@ namespace DnsServerCore
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, false, ex);
log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex);
throw;
}
@@ -240,7 +237,7 @@ namespace DnsServerCore
try
{
DnsDatagram response = ProcessQuery(request, remoteEP, false);
DnsDatagram response = ProcessQuery(request, remoteEP, DnsTransportProtocol.Udp);
//send response
if (response != null)
@@ -266,7 +263,7 @@ namespace DnsServerCore
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(remoteEP as IPEndPoint, false, request, response);
queryLog.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, request, response);
StatsManager stats = _stats;
if (stats != null)
@@ -280,17 +277,21 @@ namespace DnsServerCore
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(remoteEP as IPEndPoint, false, request, null);
queryLog.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, request, null);
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, false, ex);
log.Write(remoteEP as IPEndPoint, DnsTransportProtocol.Udp, ex);
}
}
private void AcceptTcpConnectionAsync(object parameter)
private void AcceptConnectionAsync(object parameter)
{
Socket tcpListener = parameter as Socket;
object[] parameters = parameter as object[];
Socket tcpListener = parameters[0] as Socket;
DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[1];
IPEndPoint localEP = tcpListener.LocalEndPoint as IPEndPoint;
try
@@ -305,7 +306,51 @@ namespace DnsServerCore
{
Socket socket = tcpListener.Accept();
ThreadPool.QueueUserWorkItem(ReadTcpRequestAsync, socket);
ThreadPool.QueueUserWorkItem(delegate (object state)
{
EndPoint remoteEP = null;
try
{
remoteEP = socket.RemoteEndPoint;
switch (protocol)
{
case DnsTransportProtocol.Tcp:
ReadStreamRequest(new NetworkStream(socket), remoteEP, protocol);
break;
case DnsTransportProtocol.Tls:
SslStream tlsStream = new SslStream(new NetworkStream(socket));
tlsStream.AuthenticateAsServer(_certificate);
ReadStreamRequest(tlsStream, remoteEP, protocol);
break;
case DnsTransportProtocol.Https:
SslStream httpsStream = new SslStream(new NetworkStream(socket));
httpsStream.AuthenticateAsServer(_certificate);
ProcessHttpsRequest(httpsStream, remoteEP);
break;
}
}
catch (IOException)
{
//ignore IO exceptions
}
catch (Exception ex)
{
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, protocol, ex);
}
finally
{
if (socket != null)
socket.Dispose();
}
});
}
}
catch (Exception ex)
@@ -315,24 +360,21 @@ namespace DnsServerCore
LogManager log = _log;
if (log != null)
log.Write(localEP, true, ex);
log.Write(localEP, protocol, ex);
throw;
}
}
private void ReadTcpRequestAsync(object parameter)
private void ReadStreamRequest(Stream stream, EndPoint remoteEP, DnsTransportProtocol protocol)
{
Socket tcpSocket = parameter as Socket;
DnsDatagram request = null;
EndPoint remoteEP = null;
try
{
remoteEP = tcpSocket.RemoteEndPoint as IPEndPoint;
Stream tcpStream = new WriteBufferedStream(new NetworkStream(tcpSocket), 2048);
OffsetStream recvDatagramStream = new OffsetStream(tcpStream, 0, 0);
MemoryStream sendBufferStream = new MemoryStream(64);
OffsetStream recvDatagramStream = new OffsetStream(stream, 0, 0);
Stream writeBufferedStream = new WriteBufferedStream(stream, 2048);
MemoryStream writeBuffer = new MemoryStream(64);
byte[] lengthBuffer = new byte[2];
ushort length;
@@ -341,7 +383,7 @@ namespace DnsServerCore
request = null;
//read dns datagram length
tcpStream.ReadBytes(lengthBuffer, 0, 2);
stream.ReadBytes(lengthBuffer, 0, 2);
Array.Reverse(lengthBuffer, 0, 2);
length = BitConverter.ToUInt16(lengthBuffer, 0);
@@ -350,7 +392,7 @@ namespace DnsServerCore
request = new DnsDatagram(recvDatagramStream);
//process request async
ThreadPool.QueueUserWorkItem(ProcessTcpRequestAsync, new object[] { remoteEP, tcpStream, request, sendBufferStream });
ThreadPool.QueueUserWorkItem(ProcessStreamRequestAsync, new object[] { writeBufferedStream, writeBuffer, remoteEP, request, protocol });
}
}
catch (IOException)
@@ -361,57 +403,53 @@ namespace DnsServerCore
{
LogManager queryLog = _queryLog;
if ((queryLog != null) && (request != null))
queryLog.Write(remoteEP as IPEndPoint, true, request, null);
queryLog.Write(remoteEP as IPEndPoint, protocol, request, null);
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, true, ex);
}
finally
{
if (tcpSocket != null)
tcpSocket.Dispose();
log.Write(remoteEP as IPEndPoint, protocol, ex);
}
}
private void ProcessTcpRequestAsync(object parameter)
private void ProcessStreamRequestAsync(object parameter)
{
object[] parameters = parameter as object[];
EndPoint remoteEP = parameters[0] as EndPoint;
Stream tcpStream = parameters[1] as Stream;
DnsDatagram request = parameters[2] as DnsDatagram;
MemoryStream sendBufferStream = parameters[3] as MemoryStream;
Stream stream = parameters[0] as Stream;
MemoryStream writeBuffer = parameters[1] as MemoryStream;
EndPoint remoteEP = parameters[2] as EndPoint;
DnsDatagram request = parameters[3] as DnsDatagram;
DnsTransportProtocol protocol = (DnsTransportProtocol)parameters[4];
try
{
DnsDatagram response = ProcessQuery(request, remoteEP, true);
DnsDatagram response = ProcessQuery(request, remoteEP, protocol);
//send response
if (response != null)
{
lock (tcpStream)
lock (stream)
{
//write dns datagram
sendBufferStream.Position = 0;
response.WriteTo(sendBufferStream);
writeBuffer.Position = 0;
response.WriteTo(writeBuffer);
//write dns datagram length
ushort length = Convert.ToUInt16(sendBufferStream.Position);
ushort length = Convert.ToUInt16(writeBuffer.Position);
byte[] lengthBuffer = BitConverter.GetBytes(length);
Array.Reverse(lengthBuffer, 0, 2);
tcpStream.Write(lengthBuffer);
stream.Write(lengthBuffer);
//send dns datagram
sendBufferStream.Position = 0;
sendBufferStream.CopyTo(tcpStream, 512, length);
writeBuffer.Position = 0;
writeBuffer.CopyTo(stream, 512, length);
tcpStream.Flush();
stream.Flush();
}
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(remoteEP as IPEndPoint, true, request, response);
queryLog.Write(remoteEP as IPEndPoint, protocol, request, response);
StatsManager stats = _stats;
if (stats != null)
@@ -426,11 +464,309 @@ namespace DnsServerCore
{
LogManager queryLog = _queryLog;
if ((queryLog != null) && (request != null))
queryLog.Write(remoteEP as IPEndPoint, true, request, null);
queryLog.Write(remoteEP as IPEndPoint, protocol, request, null);
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, true, ex);
log.Write(remoteEP as IPEndPoint, protocol, ex);
}
}
private void ProcessHttpsRequest(Stream stream, EndPoint remoteEP)
{
DnsDatagram dnsRequest = null;
DnsTransportProtocol dnsProtocol = DnsTransportProtocol.Https;
try
{
while (true)
{
string requestMethod;
string requestPath;
NameValueCollection requestQueryString = new NameValueCollection();
string requestProtocol;
WebHeaderCollection requestHeaders = new WebHeaderCollection();
#region parse http request
using (MemoryStream mS = new MemoryStream())
{
//read http request header into memory stream
int byteRead;
int crlfCount = 0;
while (true)
{
byteRead = stream.ReadByte();
switch (byteRead)
{
case '\r':
case '\n':
crlfCount++;
break;
case -1:
throw new EndOfStreamException();
default:
crlfCount = 0;
break;
}
mS.WriteByte((byte)byteRead);
if (crlfCount == 4)
break; //http request completed
}
mS.Position = 0;
StreamReader sR = new StreamReader(mS);
string[] requestParts = sR.ReadLine().Split(new char[] { ' ' }, 3);
if (requestParts.Length != 3)
throw new InvalidDataException("Invalid HTTP request.");
requestMethod = requestParts[0];
string pathAndQueryString = requestParts[1];
requestProtocol = requestParts[2];
string[] requestPathAndQueryParts = pathAndQueryString.Split(new char[] { '?' }, 2);
requestPath = requestPathAndQueryParts[0];
string queryString = null;
if (requestPathAndQueryParts.Length > 1)
queryString = requestPathAndQueryParts[1];
if (!string.IsNullOrEmpty(queryString))
{
foreach (string item in queryString.Split(new char[] { '&' }, StringSplitOptions.RemoveEmptyEntries))
{
string[] itemParts = item.Split(new char[] { '=' }, 2);
string name = itemParts[0];
string value = null;
if (itemParts.Length > 1)
value = itemParts[1];
requestQueryString.Add(name, value);
}
}
while (true)
{
string line = sR.ReadLine();
if (string.IsNullOrEmpty(line))
break;
string[] parts = line.Split(new char[] { ':' }, 2);
if (parts.Length != 2)
throw new InvalidDataException("Invalid HTTP request.");
requestHeaders.Add(parts[0], parts[1]);
}
}
#endregion
string requestConnection = requestHeaders[HttpRequestHeader.Connection];
if (string.IsNullOrEmpty(requestConnection))
requestConnection = "close";
if (requestPath != "/dns-query")
{
Send404(stream);
return;
}
DnsTransportProtocol protocol = DnsTransportProtocol.Udp;
string strRequestAcceptTypes = requestHeaders[HttpRequestHeader.Accept];
if (!string.IsNullOrEmpty(strRequestAcceptTypes))
{
protocol = DnsTransportProtocol.Udp;
foreach (string acceptType in strRequestAcceptTypes.Split(','))
{
if (acceptType == "application/dns-message")
{
protocol = DnsTransportProtocol.Https;
break;
}
else if (acceptType == "application/dns-json")
{
protocol = DnsTransportProtocol.HttpsJson;
dnsProtocol = DnsTransportProtocol.HttpsJson;
break;
}
}
}
switch (protocol)
{
case DnsTransportProtocol.Https:
#region https wire format
{
switch (requestMethod)
{
case "GET":
string strRequest = requestQueryString["dns"];
if (string.IsNullOrEmpty(strRequest))
throw new ArgumentNullException("dns");
//convert from base64url to base64
strRequest = strRequest.Replace('-', '+');
strRequest = strRequest.Replace('_', '/');
//add padding
int x = strRequest.Length % 4;
if (x > 0)
strRequest = strRequest.PadRight(strRequest.Length - x + 4, '=');
dnsRequest = new DnsDatagram(new MemoryStream(Convert.FromBase64String(strRequest)));
break;
case "POST":
string strContentType = requestHeaders[HttpRequestHeader.ContentType];
if (strContentType != "application/dns-message")
throw new NotSupportedException("DNS request type not supported: " + strContentType);
dnsRequest = new DnsDatagram(stream);
break;
default:
throw new NotSupportedException("DoH request type not supported."); ;
}
DnsDatagram dnsResponse = ProcessQuery(dnsRequest, remoteEP, protocol);
if (dnsResponse != null)
{
using (MemoryStream mS = new MemoryStream())
{
dnsResponse.WriteTo(mS);
byte[] buffer = mS.ToArray();
Send200(stream, "application/dns-message", buffer);
}
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(remoteEP as IPEndPoint, protocol, dnsRequest, dnsResponse);
StatsManager stats = _stats;
if (stats != null)
stats.Update(dnsResponse, (remoteEP as IPEndPoint).Address);
}
}
#endregion
break;
case DnsTransportProtocol.HttpsJson:
#region https json format
{
string strName = requestQueryString["name"];
if (string.IsNullOrEmpty(strName))
throw new ArgumentNullException("name");
string strType = requestQueryString["type"];
if (string.IsNullOrEmpty(strType))
strType = "1";
dnsRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord(strName, (DnsResourceRecordType)int.Parse(strType), DnsClass.IN) }, null, null, null);
DnsDatagram dnsResponse = ProcessQuery(dnsRequest, remoteEP, protocol);
if (dnsResponse != null)
{
using (MemoryStream mS = new MemoryStream())
{
JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS));
dnsResponse.WriteTo(jsonWriter);
jsonWriter.Flush();
byte[] buffer = mS.ToArray();
Send200(stream, "application/dns-json; charset=utf-8", buffer);
}
LogManager queryLog = _queryLog;
if (queryLog != null)
queryLog.Write(remoteEP as IPEndPoint, protocol, dnsRequest, dnsResponse);
StatsManager stats = _stats;
if (stats != null)
stats.Update(dnsResponse, (remoteEP as IPEndPoint).Address);
}
}
#endregion
break;
default:
Send406(stream, "Only application/dns-message and application/dns-json types are accepted.");
return;
}
if (requestConnection.Equals("close", StringComparison.CurrentCultureIgnoreCase))
break;
}
}
catch (IOException)
{
//ignore IO exceptions
}
catch (Exception ex)
{
LogManager queryLog = _queryLog;
if ((queryLog != null) && (dnsRequest != null))
queryLog.Write(remoteEP as IPEndPoint, dnsProtocol, dnsRequest, null);
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, dnsProtocol, ex);
}
}
private static void Send404(Stream outputStream)
{
byte[] bufferContent = Encoding.UTF8.GetBytes("<h1>404 Not Found</h1>");
byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 404 Not Found\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n");
using (MemoryStream mS = new MemoryStream())
{
mS.Write(bufferHeader, 0, bufferHeader.Length);
mS.Write(bufferContent, 0, bufferContent.Length);
byte[] buffer = mS.ToArray();
outputStream.Write(buffer, 0, buffer.Length);
}
}
private static void Send406(Stream outputStream, string message)
{
byte[] bufferContent = Encoding.UTF8.GetBytes("<h1>406 Not Acceptable</h1><p>" + message + "</p>");
byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 406 Not Acceptable\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n");
using (MemoryStream mS = new MemoryStream())
{
mS.Write(bufferHeader, 0, bufferHeader.Length);
mS.Write(bufferContent, 0, bufferContent.Length);
byte[] buffer = mS.ToArray();
outputStream.Write(buffer, 0, buffer.Length);
}
}
private static void Send200(Stream outputStream, string contentType, byte[] bufferContent)
{
byte[] bufferHeader = Encoding.UTF8.GetBytes("HTTP/1.1 200 OK\r\nDate: " + DateTime.UtcNow.ToString("r") + "\r\nContent-Type: " + contentType + "\r\nContent-Length: " + bufferContent.Length + "\r\nX-Robots-Tag: noindex, nofollow\r\n\r\n");
using (MemoryStream mS = new MemoryStream())
{
mS.Write(bufferHeader, 0, bufferHeader.Length);
mS.Write(bufferContent, 0, bufferContent.Length);
byte[] buffer = mS.ToArray();
outputStream.Write(buffer, 0, buffer.Length);
}
}
@@ -455,7 +791,7 @@ namespace DnsServerCore
return true;
}
private DnsDatagram ProcessQuery(DnsDatagram request, EndPoint remoteEP, bool tcp)
internal DnsDatagram ProcessQuery(DnsDatagram request, EndPoint remoteEP, DnsTransportProtocol protocol)
{
if (request.Header.IsResponse)
return null;
@@ -536,7 +872,7 @@ namespace DnsServerCore
{
LogManager log = _log;
if (log != null)
log.Write(remoteEP as IPEndPoint, tcp, ex);
log.Write(remoteEP as IPEndPoint, protocol, ex);
return new DnsDatagram(new DnsHeader(request.Header.Identifier, true, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, isRecursionAllowed, false, false, DnsResponseCode.ServerFailure, request.Header.QDCOUNT, 0, 0, 0), request.Question, null, null, null);
}
@@ -560,6 +896,7 @@ namespace DnsServerCore
if ((lastRR.Type != questionType) && (lastRR.Type == DnsResourceRecordType.CNAME) && (questionType != DnsResourceRecordType.ANY))
{
//resolve cname record
List<DnsResourceRecord> responseAnswer = new List<DnsResourceRecord>();
responseAnswer.AddRange(response.Answer);
@@ -570,26 +907,42 @@ namespace DnsServerCore
{
DnsDatagram cnameRequest = new DnsDatagram(new DnsHeader(0, false, DnsOpcode.StandardQuery, false, false, request.Header.RecursionDesired, false, false, false, DnsResponseCode.NoError, 1, 0, 0, 0), new DnsQuestionRecord[] { new DnsQuestionRecord((lastRR.RDATA as DnsCNAMERecord).CNAMEDomainName, questionType, DnsClass.IN) }, null, null, null);
//query authoritative zone first
lastResponse = _authoritativeZoneRoot.Query(cnameRequest);
if (lastResponse.Header.RCODE == DnsResponseCode.Refused)
{
if (!cnameRequest.Header.RecursionDesired || !isRecursionAllowed)
break;
//not found in auth zone
if (!isRecursionAllowed || !cnameRequest.Header.RecursionDesired)
break; //break since no recursion allowed/desired
//do recursion
lastResponse = ProcessRecursiveQuery(cnameRequest);
cacheHit &= ("cacheHit".Equals(lastResponse.Tag));
}
else if ((lastResponse.Header.RCODE == DnsResponseCode.NoError) && (lastResponse.Answer.Length == 0) && (lastResponse.Authority.Length > 0) && (lastResponse.Authority[0].Type == DnsResourceRecordType.NS))
{
//found delegated zone
if (!isRecursionAllowed || !cnameRequest.Header.RecursionDesired)
break; //break since no recursion allowed/desired
//do recursive resolution using delegated authority name servers
NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(lastResponse, _preferIPv6);
lastResponse = ProcessRecursiveQuery(cnameRequest, nameServers);
cacheHit &= ("cacheHit".Equals(lastResponse.Tag));
}
//check last response
if ((lastResponse.Header.RCODE != DnsResponseCode.NoError) || (lastResponse.Answer.Length == 0))
break;
break; //cannot proceed to resolve cname further
responseAnswer.AddRange(lastResponse.Answer);
lastRR = lastResponse.Answer[lastResponse.Answer.Length - 1];
if (lastRR.Type != DnsResourceRecordType.CNAME)
break;
break; //cname was resolved
}
DnsResponseCode rcode;
@@ -628,7 +981,7 @@ namespace DnsServerCore
else if ((response.Authority.Length > 0) && (response.Authority[0].Type == DnsResourceRecordType.NS) && isRecursionAllowed)
{
//do recursive resolution using response authority name servers
NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6, false);
NameServerAddress[] nameServers = NameServerAddress.GetNameServersFromResponse(response, _preferIPv6);
return ProcessRecursiveQuery(request, nameServers);
}
@@ -748,7 +1101,7 @@ namespace DnsServerCore
}
//select protocol
DnsClientProtocol protocol;
DnsTransportProtocol protocol;
if ((viaNameServers == null) && (_forwarders != null))
{
@@ -793,9 +1146,11 @@ namespace DnsServerCore
_state = ServiceState.Starting;
//bind on all local end points
for (int i = 0; i < _localEPs.Length; i++)
for (int i = 0; i < _localIPs.Length; i++)
{
Socket udpListener = new Socket(_localEPs[i].AddressFamily, SocketType.Dgram, ProtocolType.Udp);
IPEndPoint dnsEP = new IPEndPoint(_localIPs[i], 53);
Socket udpListener = new Socket(dnsEP.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
#region this code ignores ICMP port unreachable responses which creates SocketException in ReceiveFrom()
@@ -812,50 +1167,102 @@ namespace DnsServerCore
try
{
udpListener.Bind(_localEPs[i]);
udpListener.Bind(dnsEP);
_udpListeners.Add(udpListener);
LogManager log = _log;
if (log != null)
log.Write(_localEPs[i], false, "DNS Server was bound successfully.");
log.Write(dnsEP, DnsTransportProtocol.Udp, "DNS Server was bound successfully.");
}
catch (Exception ex)
{
LogManager log = _log;
if (log != null)
log.Write(_localEPs[i], false, ex);
log.Write(dnsEP, DnsTransportProtocol.Udp, ex);
udpListener.Dispose();
}
Socket tcpListener = new Socket(_localEPs[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp);
Socket tcpListener = new Socket(dnsEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
tcpListener.Bind(_localEPs[i]);
tcpListener.Bind(dnsEP);
tcpListener.Listen(100);
_tcpListeners.Add(tcpListener);
LogManager log = _log;
if (log != null)
log.Write(_localEPs[i], true, "DNS Server was bound successfully.");
log.Write(dnsEP, DnsTransportProtocol.Tcp, "DNS Server was bound successfully.");
}
catch (Exception ex)
{
LogManager log = _log;
if (log != null)
log.Write(_localEPs[i], true, ex);
log.Write(dnsEP, DnsTransportProtocol.Tcp, ex);
tcpListener.Dispose();
}
if (_enableDoT && (_certificate != null))
{
IPEndPoint tlsEP = new IPEndPoint(_localIPs[i], 853);
Socket tlsListener = new Socket(tlsEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
tlsListener.Bind(tlsEP);
tlsListener.Listen(100);
_tlsListeners.Add(tlsListener);
LogManager log = _log;
if (log != null)
log.Write(tlsEP, DnsTransportProtocol.Tls, "DNS Server was bound successfully.");
}
catch (Exception ex)
{
LogManager log = _log;
if (log != null)
log.Write(tlsEP, DnsTransportProtocol.Tls, ex);
tlsListener.Dispose();
}
}
if (_enableDoH && (_certificate != null))
{
IPEndPoint httpsEP = new IPEndPoint(_localIPs[i], 443);
Socket httpsListener = new Socket(httpsEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
httpsListener.Bind(httpsEP);
httpsListener.Listen(100);
_httpsListeners.Add(httpsListener);
LogManager log = _log;
if (log != null)
log.Write(httpsEP, DnsTransportProtocol.Https, "DNS Server was bound successfully.");
}
catch (Exception ex)
{
LogManager log = _log;
if (log != null)
log.Write(httpsEP, DnsTransportProtocol.Https, ex);
httpsListener.Dispose();
}
}
}
//start reading query packets
foreach (Socket udpListener in _udpListeners)
{
for (int i = 0; i < UDP_LISTENER_THREAD_COUNT; i++)
for (int i = 0; i < LISTENER_THREAD_COUNT; i++)
{
Thread listenerThread = new Thread(ReadUdpRequestAsync);
listenerThread.IsBackground = true;
@@ -867,11 +1274,38 @@ namespace DnsServerCore
foreach (Socket tcpListener in _tcpListeners)
{
Thread listenerThread = new Thread(AcceptTcpConnectionAsync);
listenerThread.IsBackground = true;
listenerThread.Start(tcpListener);
for (int i = 0; i < LISTENER_THREAD_COUNT; i++)
{
Thread listenerThread = new Thread(AcceptConnectionAsync);
listenerThread.IsBackground = true;
listenerThread.Start(new object[] { tcpListener, DnsTransportProtocol.Tcp });
_listenerThreads.Add(listenerThread);
_listenerThreads.Add(listenerThread);
}
}
foreach (Socket tlsListener in _tlsListeners)
{
for (int i = 0; i < LISTENER_THREAD_COUNT; i++)
{
Thread listenerThread = new Thread(AcceptConnectionAsync);
listenerThread.IsBackground = true;
listenerThread.Start(new object[] { tlsListener, DnsTransportProtocol.Tls });
_listenerThreads.Add(listenerThread);
}
}
foreach (Socket httpsListener in _httpsListeners)
{
for (int i = 0; i < LISTENER_THREAD_COUNT; i++)
{
Thread listenerThread = new Thread(AcceptConnectionAsync);
listenerThread.IsBackground = true;
listenerThread.Start(new object[] { httpsListener, DnsTransportProtocol.Https });
_listenerThreads.Add(listenerThread);
}
}
_state = ServiceState.Running;
@@ -890,9 +1324,17 @@ namespace DnsServerCore
foreach (Socket tcpListener in _tcpListeners)
tcpListener.Dispose();
foreach (Socket tlsListener in _tlsListeners)
tlsListener.Dispose();
foreach (Socket httpsListener in _httpsListeners)
httpsListener.Dispose();
_listenerThreads.Clear();
_udpListeners.Clear();
_tcpListeners.Clear();
_tlsListeners.Clear();
_httpsListeners.Clear();
_state = ServiceState.Stopped;
}
@@ -901,16 +1343,10 @@ namespace DnsServerCore
#region properties
public IPEndPoint[] LocalEndPoints
public IPAddress[] LocalAddresses
{
get { return _localEPs; }
set
{
if (_state != ServiceState.Stopped)
throw new InvalidOperationException("DNS Server is already running.");
_localEPs = value;
}
get { return _localIPs; }
set { _localIPs = value; }
}
public string ServerDomain
@@ -924,6 +1360,30 @@ namespace DnsServerCore
}
}
public bool EnableDoT
{
get { return _enableDoT; }
set { _enableDoT = value; }
}
public bool EnableDoH
{
get { return _enableDoH; }
set { _enableDoH = value; }
}
public X509Certificate2 Certificate
{
get { return _certificate; }
set
{
if (!value.HasPrivateKey)
throw new ArgumentException("Tls certificate does not contain private key.");
_certificate = value;
}
}
public Zone AuthoritativeZoneRoot
{ get { return _authoritativeZoneRoot; } }
@@ -976,13 +1436,13 @@ namespace DnsServerCore
set { _forwarders = value; }
}
public DnsClientProtocol ForwarderProtocol
public DnsTransportProtocol ForwarderProtocol
{
get { return _forwarderProtocol; }
set { _forwarderProtocol = value; }
}
public DnsClientProtocol RecursiveResolveProtocol
public DnsTransportProtocol RecursiveResolveProtocol
{
get { return _recursiveResolveProtocol; }
set { _recursiveResolveProtocol = value; }