DnsServer: implemented support for custom real ip header for DoH. Implemented blocking answer ttl feature. Fixed minor issue with rate limiting detection logging feature. Fixed minor issue with DoH start process. Added missing validation checks for optional protocol port properties. Code refactoring changes done.

This commit is contained in:
Shreyas Zare
2024-10-19 16:58:04 +05:30
parent b6b4877c91
commit c7ce7077c5

View File

@@ -174,6 +174,7 @@ namespace DnsServerCore.Dns
X509Certificate2Collection _certificateCollection; X509Certificate2Collection _certificateCollection;
SslServerAuthenticationOptions _sslServerAuthenticationOptions; SslServerAuthenticationOptions _sslServerAuthenticationOptions;
SslServerAuthenticationOptions _quicSslServerAuthenticationOptions; SslServerAuthenticationOptions _quicSslServerAuthenticationOptions;
string _dnsOverHttpRealIpHeader = "X-Real-IP";
IReadOnlyDictionary<string, TsigKey> _tsigKeys; IReadOnlyDictionary<string, TsigKey> _tsigKeys;
@@ -200,6 +201,7 @@ namespace DnsServerCore.Dns
bool _allowTxtBlockingReport = true; bool _allowTxtBlockingReport = true;
IReadOnlyCollection<NetworkAddress> _blockingBypassList; IReadOnlyCollection<NetworkAddress> _blockingBypassList;
DnsServerBlockingType _blockingType = DnsServerBlockingType.NxDomain; DnsServerBlockingType _blockingType = DnsServerBlockingType.NxDomain;
uint _blockingAnswerTtl = 30;
IReadOnlyCollection<DnsARecordData> _customBlockingARecords = Array.Empty<DnsARecordData>(); IReadOnlyCollection<DnsARecordData> _customBlockingARecords = Array.Empty<DnsARecordData>();
IReadOnlyCollection<DnsAAAARecordData> _customBlockingAAAARecords = Array.Empty<DnsAAAARecordData>(); IReadOnlyCollection<DnsAAAARecordData> _customBlockingAAAARecords = Array.Empty<DnsAAAARecordData>();
@@ -744,17 +746,20 @@ namespace DnsServerCore.Dns
} }
//send response //send response
await writeSemaphore.WaitAsync(); await TechnitiumLibrary.TaskExtensions.TimeoutAsync(async delegate (CancellationToken cancellationToken1)
try
{ {
//send dns datagram await writeSemaphore.WaitAsync(cancellationToken1);
await response.WriteToTcpAsync(stream, writeBuffer); try
await stream.FlushAsync(); {
} //send dns datagram
finally await response.WriteToTcpAsync(stream, writeBuffer, cancellationToken1);
{ await stream.FlushAsync(cancellationToken1);
writeSemaphore.Release(); }
} finally
{
writeSemaphore.Release();
}
}, _tcpSendTimeout);
_queryLog?.Write(remoteEP, protocol, request, response); _queryLog?.Write(remoteEP, protocol, request, response);
_stats.QueueUpdate(request, remoteEP, protocol, response, false); _stats.QueueUpdate(request, remoteEP, protocol, response, false);
@@ -907,7 +912,7 @@ namespace DnsServerCore.Dns
private async Task ProcessDoHRequestAsync(HttpContext context) private async Task ProcessDoHRequestAsync(HttpContext context)
{ {
IPEndPoint remoteEP = context.GetRemoteEndPoint(); IPEndPoint remoteEP = context.GetRemoteEndPoint(_dnsOverHttpRealIpHeader);
DnsDatagram dnsRequest = null; DnsDatagram dnsRequest = null;
try try
@@ -927,7 +932,7 @@ namespace DnsServerCore.Dns
if (!request.IsHttps) if (!request.IsHttps)
{ {
//get the actual connection remote EP //get the actual connection remote EP
IPEndPoint connectionEp = context.GetRemoteEndPoint(true); IPEndPoint connectionEp = context.GetRemoteEndPoint(null);
if (!NetUtilities.IsPrivateIP(connectionEp.Address)) if (!NetUtilities.IsPrivateIP(connectionEp.Address))
{ {
@@ -1033,10 +1038,13 @@ namespace DnsServerCore.Dns
response.ContentType = "application/dns-message"; response.ContentType = "application/dns-message";
response.ContentLength = mS.Length; response.ContentLength = mS.Length;
using (Stream s = response.Body) await TechnitiumLibrary.TaskExtensions.TimeoutAsync(async delegate (CancellationToken cancellationToken1)
{ {
await mS.CopyToAsync(s, 512); await using (Stream s = response.Body)
} {
await mS.CopyToAsync(s, 512, cancellationToken1);
}
}, _tcpSendTimeout);
} }
_queryLog?.Write(remoteEP, DnsTransportProtocol.Https, dnsRequest, dnsResponse); _queryLog?.Write(remoteEP, DnsTransportProtocol.Https, dnsRequest, dnsResponse);
@@ -2640,7 +2648,7 @@ namespace DnsServerCore.Dns
//return meta data //return meta data
string blockedDomain = GetBlockedDomain(); string blockedDomain = GetBlockedDomain();
IReadOnlyList<DnsResourceRecord> answer = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, 60, new DnsTXTRecordData("source=blocked-zone; domain=" + blockedDomain)) }; IReadOnlyList<DnsResourceRecord> answer = [new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, _blockingAnswerTtl, new DnsTXTRecordData("source=blocked-zone; domain=" + blockedDomain))];
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer) { Tag = DnsServerResponseType.Blocked }; return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer) { Tag = DnsServerResponseType.Blocked };
} }
@@ -2652,7 +2660,7 @@ namespace DnsServerCore.Dns
if (_allowTxtBlockingReport && (request.EDNS is not null)) if (_allowTxtBlockingReport && (request.EDNS is not null))
{ {
blockedDomain = GetBlockedDomain(); blockedDomain = GetBlockedDomain();
options = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, "source=blocked-zone; domain=" + blockedDomain)) }; options = [new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, "source=blocked-zone; domain=" + blockedDomain))];
} }
IReadOnlyCollection<DnsARecordData> aRecords; IReadOnlyCollection<DnsARecordData> aRecords;
@@ -2678,7 +2686,7 @@ namespace DnsServerCore.Dns
if (parentDomain is null) if (parentDomain is null)
parentDomain = string.Empty; parentDomain = string.Empty;
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NxDomain, request.Question, null, new DnsResourceRecord[] { new DnsResourceRecord(parentDomain, DnsResourceRecordType.SOA, question.Class, 60, _blockedZoneManager.DnsSOARecord) }, null, request.EDNS is null ? ushort.MinValue : _udpPayloadSize, EDnsHeaderFlags.None, options) { Tag = DnsServerResponseType.Blocked }; return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NxDomain, request.Question, null, [new DnsResourceRecord(parentDomain, DnsResourceRecordType.SOA, question.Class, _blockingAnswerTtl, _blockedZoneManager.DnsSOARecord)], null, request.EDNS is null ? ushort.MinValue : _udpPayloadSize, EDnsHeaderFlags.None, options) { Tag = DnsServerResponseType.Blocked };
default: default:
throw new InvalidOperationException(); throw new InvalidOperationException();
@@ -2691,23 +2699,41 @@ namespace DnsServerCore.Dns
{ {
case DnsResourceRecordType.A: case DnsResourceRecordType.A:
{ {
List<DnsResourceRecord> rrList = new List<DnsResourceRecord>(aRecords.Count); if (aRecords.Count > 0)
{
DnsResourceRecord[] rrList = new DnsResourceRecord[aRecords.Count];
int i = 0;
foreach (DnsARecordData record in aRecords) foreach (DnsARecordData record in aRecords)
rrList.Add(new DnsResourceRecord(question.Name, DnsResourceRecordType.A, question.Class, 60, record)); rrList[i++] = new DnsResourceRecord(question.Name, DnsResourceRecordType.A, question.Class, _blockingAnswerTtl, record);
answer = rrList; answer = rrList;
}
else
{
answer = null;
authority = response.Authority;
}
} }
break; break;
case DnsResourceRecordType.AAAA: case DnsResourceRecordType.AAAA:
{ {
List<DnsResourceRecord> rrList = new List<DnsResourceRecord>(aaaaRecords.Count); if (aaaaRecords.Count > 0)
{
DnsResourceRecord[] rrList = new DnsResourceRecord[aaaaRecords.Count];
int i = 0;
foreach (DnsAAAARecordData record in aaaaRecords) foreach (DnsAAAARecordData record in aaaaRecords)
rrList.Add(new DnsResourceRecord(question.Name, DnsResourceRecordType.AAAA, question.Class, 60, record)); rrList[i++] = new DnsResourceRecord(question.Name, DnsResourceRecordType.AAAA, question.Class, _blockingAnswerTtl, record);
answer = rrList; answer = rrList;
}
else
{
answer = null;
authority = response.Authority;
}
} }
break; break;
@@ -3085,7 +3111,7 @@ namespace DnsServerCore.Dns
} }
//no response available; respond with ServerFailure //no response available; respond with ServerFailure
EDnsOption[] options = [new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Other, "Waiting for resolver"))]; EDnsOption[] options = [new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Other, "Waiting for resolver. Please try again."))];
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, request.CheckingDisabled, DnsResponseCode.ServerFailure, request.Question, null, null, null, _udpPayloadSize, request.DnssecOk ? EDnsHeaderFlags.DNSSEC_OK : EDnsHeaderFlags.None, options); return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, request.CheckingDisabled, DnsResponseCode.ServerFailure, request.Question, null, null, null, _udpPayloadSize, request.DnssecOk ? EDnsHeaderFlags.DNSSEC_OK : EDnsHeaderFlags.None, options);
} }
@@ -4283,8 +4309,11 @@ namespace DnsServerCore.Dns
{ {
_stats.GetLatestClientSubnetStats(_qpmLimitSampleMinutes, _qpmLimitIPv4PrefixLength, _qpmLimitIPv6PrefixLength, out IReadOnlyDictionary<IPAddress, long> qpmLimitClientSubnetStats, out IReadOnlyDictionary<IPAddress, long> qpmLimitErrorClientSubnetStats); _stats.GetLatestClientSubnetStats(_qpmLimitSampleMinutes, _qpmLimitIPv4PrefixLength, _qpmLimitIPv6PrefixLength, out IReadOnlyDictionary<IPAddress, long> qpmLimitClientSubnetStats, out IReadOnlyDictionary<IPAddress, long> qpmLimitErrorClientSubnetStats);
WriteClientSubnetRateLimitLog(_qpmLimitErrorClientSubnetStats, qpmLimitErrorClientSubnetStats, _qpmLimitErrors, "errors"); if (_qpmLimitErrors > 0)
WriteClientSubnetRateLimitLog(_qpmLimitClientSubnetStats, qpmLimitClientSubnetStats, _qpmLimitRequests, "requests"); WriteClientSubnetRateLimitLog(_qpmLimitErrorClientSubnetStats, qpmLimitErrorClientSubnetStats, _qpmLimitErrors, "errors");
if (_qpmLimitRequests > 0)
WriteClientSubnetRateLimitLog(_qpmLimitClientSubnetStats, qpmLimitClientSubnetStats, _qpmLimitRequests, "requests");
_qpmLimitClientSubnetStats = qpmLimitClientSubnetStats; _qpmLimitClientSubnetStats = qpmLimitClientSubnetStats;
_qpmLimitErrorClientSubnetStats = qpmLimitErrorClientSubnetStats; _qpmLimitErrorClientSubnetStats = qpmLimitErrorClientSubnetStats;
@@ -4395,78 +4424,78 @@ namespace DnsServerCore.Dns
private async Task StartDoHAsync() private async Task StartDoHAsync()
{ {
WebApplicationBuilder builder = WebApplication.CreateBuilder();
builder.Environment.ContentRootFileProvider = new PhysicalFileProvider(Path.GetDirectoryName(_dohwwwFolder))
{
UseActivePolling = true,
UsePollingFileWatcher = true
};
builder.Environment.WebRootFileProvider = new PhysicalFileProvider(_dohwwwFolder)
{
UseActivePolling = true,
UsePollingFileWatcher = true
};
IReadOnlyList<IPAddress> localAddresses = GetValidKestralLocalAddresses(_localEndPoints.Convert(delegate (IPEndPoint ep) { return ep.Address; })); IReadOnlyList<IPAddress> localAddresses = GetValidKestralLocalAddresses(_localEndPoints.Convert(delegate (IPEndPoint ep) { return ep.Address; }));
builder.WebHost.ConfigureKestrel(delegate (WebHostBuilderContext context, KestrelServerOptions serverOptions)
{
//bind to http port
if (_enableDnsOverHttp)
{
foreach (IPAddress localAddress in localAddresses)
serverOptions.Listen(localAddress, _dnsOverHttpPort);
}
//bind to https port
if (_enableDnsOverHttps && (_certificateCollection is not null))
{
foreach (IPAddress localAddress in localAddresses)
{
serverOptions.Listen(localAddress, _dnsOverHttpsPort, delegate (ListenOptions listenOptions)
{
listenOptions.Protocols = _enableDnsOverHttp3 ? HttpProtocols.Http1AndHttp2AndHttp3 : HttpProtocols.Http1AndHttp2;
listenOptions.UseHttps(delegate (SslStream stream, SslClientHelloInfo clientHelloInfo, object state, CancellationToken cancellationToken)
{
return ValueTask.FromResult(_sslServerAuthenticationOptions);
}, null);
});
}
}
serverOptions.AddServerHeader = false;
serverOptions.Limits.RequestHeadersTimeout = TimeSpan.FromMilliseconds(_tcpReceiveTimeout);
serverOptions.Limits.KeepAliveTimeout = TimeSpan.FromMilliseconds(_tcpReceiveTimeout);
serverOptions.Limits.MaxRequestHeadersTotalSize = 4096;
serverOptions.Limits.MaxRequestLineSize = serverOptions.Limits.MaxRequestHeadersTotalSize;
serverOptions.Limits.MaxRequestBufferSize = serverOptions.Limits.MaxRequestLineSize;
serverOptions.Limits.MaxRequestBodySize = 64 * 1024;
serverOptions.Limits.MaxResponseBufferSize = 4096;
});
builder.Logging.ClearProviders();
_dohWebService = builder.Build();
_dohWebService.UseDefaultFiles();
_dohWebService.UseStaticFiles(new StaticFileOptions()
{
OnPrepareResponse = delegate (StaticFileResponseContext ctx)
{
ctx.Context.Response.Headers["X-Robots-Tag"] = "noindex, nofollow";
ctx.Context.Response.Headers.CacheControl = "private, max-age=300";
},
ServeUnknownFileTypes = true
});
_dohWebService.UseRouting();
_dohWebService.MapGet("/dns-query", ProcessDoHRequestAsync);
_dohWebService.MapPost("/dns-query", ProcessDoHRequestAsync);
try try
{ {
WebApplicationBuilder builder = WebApplication.CreateBuilder();
builder.Environment.ContentRootFileProvider = new PhysicalFileProvider(Path.GetDirectoryName(_dohwwwFolder))
{
UseActivePolling = true,
UsePollingFileWatcher = true
};
builder.Environment.WebRootFileProvider = new PhysicalFileProvider(_dohwwwFolder)
{
UseActivePolling = true,
UsePollingFileWatcher = true
};
builder.WebHost.ConfigureKestrel(delegate (WebHostBuilderContext context, KestrelServerOptions serverOptions)
{
//bind to http port
if (_enableDnsOverHttp)
{
foreach (IPAddress localAddress in localAddresses)
serverOptions.Listen(localAddress, _dnsOverHttpPort);
}
//bind to https port
if (_enableDnsOverHttps && (_certificateCollection is not null))
{
foreach (IPAddress localAddress in localAddresses)
{
serverOptions.Listen(localAddress, _dnsOverHttpsPort, delegate (ListenOptions listenOptions)
{
listenOptions.Protocols = _enableDnsOverHttp3 ? HttpProtocols.Http1AndHttp2AndHttp3 : HttpProtocols.Http1AndHttp2;
listenOptions.UseHttps(delegate (SslStream stream, SslClientHelloInfo clientHelloInfo, object state, CancellationToken cancellationToken)
{
return ValueTask.FromResult(_sslServerAuthenticationOptions);
}, null);
});
}
}
serverOptions.AddServerHeader = false;
serverOptions.Limits.RequestHeadersTimeout = TimeSpan.FromMilliseconds(_tcpReceiveTimeout);
serverOptions.Limits.KeepAliveTimeout = TimeSpan.FromMilliseconds(_tcpReceiveTimeout);
serverOptions.Limits.MaxRequestHeadersTotalSize = 4096;
serverOptions.Limits.MaxRequestLineSize = serverOptions.Limits.MaxRequestHeadersTotalSize;
serverOptions.Limits.MaxRequestBufferSize = serverOptions.Limits.MaxRequestLineSize;
serverOptions.Limits.MaxRequestBodySize = 64 * 1024;
serverOptions.Limits.MaxResponseBufferSize = 4096;
});
builder.Logging.ClearProviders();
_dohWebService = builder.Build();
_dohWebService.UseDefaultFiles();
_dohWebService.UseStaticFiles(new StaticFileOptions()
{
OnPrepareResponse = delegate (StaticFileResponseContext ctx)
{
ctx.Context.Response.Headers["X-Robots-Tag"] = "noindex, nofollow";
ctx.Context.Response.Headers.CacheControl = "private, max-age=300";
},
ServeUnknownFileTypes = true
});
_dohWebService.UseRouting();
_dohWebService.MapGet("/dns-query", ProcessDoHRequestAsync);
_dohWebService.MapPost("/dns-query", ProcessDoHRequestAsync);
await _dohWebService.StartAsync(); await _dohWebService.StartAsync();
if (_log is not null) if (_log is not null)
@@ -5516,37 +5545,73 @@ namespace DnsServerCore.Dns
public int DnsOverUdpProxyPort public int DnsOverUdpProxyPort
{ {
get { return _dnsOverUdpProxyPort; } get { return _dnsOverUdpProxyPort; }
set { _dnsOverUdpProxyPort = value; } set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverUdpProxyPort), "Port number valid range is from 0 to 65535.");
_dnsOverUdpProxyPort = value;
}
} }
public int DnsOverTcpProxyPort public int DnsOverTcpProxyPort
{ {
get { return _dnsOverTcpProxyPort; } get { return _dnsOverTcpProxyPort; }
set { _dnsOverTcpProxyPort = value; } set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverTcpProxyPort), "Port number valid range is from 0 to 65535.");
_dnsOverTcpProxyPort = value;
}
} }
public int DnsOverHttpPort public int DnsOverHttpPort
{ {
get { return _dnsOverHttpPort; } get { return _dnsOverHttpPort; }
set { _dnsOverHttpPort = value; } set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverHttpPort), "Port number valid range is from 0 to 65535.");
_dnsOverHttpPort = value;
}
} }
public int DnsOverTlsPort public int DnsOverTlsPort
{ {
get { return _dnsOverTlsPort; } get { return _dnsOverTlsPort; }
set { _dnsOverTlsPort = value; } set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverTlsPort), "Port number valid range is from 0 to 65535.");
_dnsOverTlsPort = value;
}
} }
public int DnsOverHttpsPort public int DnsOverHttpsPort
{ {
get { return _dnsOverHttpsPort; } get { return _dnsOverHttpsPort; }
set { _dnsOverHttpsPort = value; } set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverHttpsPort), "Port number valid range is from 0 to 65535.");
_dnsOverHttpsPort = value;
}
} }
public int DnsOverQuicPort public int DnsOverQuicPort
{ {
get { return _dnsOverQuicPort; } get { return _dnsOverQuicPort; }
set { _dnsOverQuicPort = value; } set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverQuicPort), "Port number valid range is from 0 to 65535.");
_dnsOverQuicPort = value;
}
} }
public X509Certificate2Collection CertificateCollection public X509Certificate2Collection CertificateCollection
@@ -5594,6 +5659,22 @@ namespace DnsServerCore.Dns
} }
} }
public string DnsOverHttpRealIpHeader
{
get { return _dnsOverHttpRealIpHeader; }
set
{
if (string.IsNullOrEmpty(value))
_dnsOverHttpRealIpHeader = "X-Real-IP";
else if (value.Length > 255)
throw new ArgumentException("DNS-over-HTTP Real IP header name cannot exceed 255 characters.", nameof(DnsOverHttpRealIpHeader));
else if (value.Contains(' '))
throw new ArgumentException("DNS-over-HTTP Real IP header name cannot contain invalid characters.", nameof(DnsOverHttpRealIpHeader));
else
_dnsOverHttpRealIpHeader = value;
}
}
public IReadOnlyDictionary<string, TsigKey> TsigKeys public IReadOnlyDictionary<string, TsigKey> TsigKeys
{ {
get { return _tsigKeys; } get { return _tsigKeys; }
@@ -5817,6 +5898,22 @@ namespace DnsServerCore.Dns
set { _blockingType = value; } set { _blockingType = value; }
} }
public uint BlockingAnswerTtl
{
get { return _blockingAnswerTtl; }
set
{
if (_blockingAnswerTtl != value)
{
_blockingAnswerTtl = value;
//update SOA MINIMUM values
_blockedZoneManager.UpdateServerDomain();
_blockListZoneManager.UpdateServerDomain();
}
}
}
public IReadOnlyCollection<DnsARecordData> CustomBlockingARecords public IReadOnlyCollection<DnsARecordData> CustomBlockingARecords
{ {
get { return _customBlockingARecords; } get { return _customBlockingARecords; }