DnsServer: implemented support for custom real ip header for DoH. Implemented blocking answer ttl feature. Fixed minor issue with rate limiting detection logging feature. Fixed minor issue with DoH start process. Added missing validation checks for optional protocol port properties. Code refactoring changes done.

This commit is contained in:
Shreyas Zare
2024-10-19 16:58:04 +05:30
parent b6b4877c91
commit c7ce7077c5

View File

@@ -174,6 +174,7 @@ namespace DnsServerCore.Dns
X509Certificate2Collection _certificateCollection;
SslServerAuthenticationOptions _sslServerAuthenticationOptions;
SslServerAuthenticationOptions _quicSslServerAuthenticationOptions;
string _dnsOverHttpRealIpHeader = "X-Real-IP";
IReadOnlyDictionary<string, TsigKey> _tsigKeys;
@@ -200,6 +201,7 @@ namespace DnsServerCore.Dns
bool _allowTxtBlockingReport = true;
IReadOnlyCollection<NetworkAddress> _blockingBypassList;
DnsServerBlockingType _blockingType = DnsServerBlockingType.NxDomain;
uint _blockingAnswerTtl = 30;
IReadOnlyCollection<DnsARecordData> _customBlockingARecords = Array.Empty<DnsARecordData>();
IReadOnlyCollection<DnsAAAARecordData> _customBlockingAAAARecords = Array.Empty<DnsAAAARecordData>();
@@ -744,17 +746,20 @@ namespace DnsServerCore.Dns
}
//send response
await writeSemaphore.WaitAsync();
await TechnitiumLibrary.TaskExtensions.TimeoutAsync(async delegate (CancellationToken cancellationToken1)
{
await writeSemaphore.WaitAsync(cancellationToken1);
try
{
//send dns datagram
await response.WriteToTcpAsync(stream, writeBuffer);
await stream.FlushAsync();
await response.WriteToTcpAsync(stream, writeBuffer, cancellationToken1);
await stream.FlushAsync(cancellationToken1);
}
finally
{
writeSemaphore.Release();
}
}, _tcpSendTimeout);
_queryLog?.Write(remoteEP, protocol, request, response);
_stats.QueueUpdate(request, remoteEP, protocol, response, false);
@@ -907,7 +912,7 @@ namespace DnsServerCore.Dns
private async Task ProcessDoHRequestAsync(HttpContext context)
{
IPEndPoint remoteEP = context.GetRemoteEndPoint();
IPEndPoint remoteEP = context.GetRemoteEndPoint(_dnsOverHttpRealIpHeader);
DnsDatagram dnsRequest = null;
try
@@ -927,7 +932,7 @@ namespace DnsServerCore.Dns
if (!request.IsHttps)
{
//get the actual connection remote EP
IPEndPoint connectionEp = context.GetRemoteEndPoint(true);
IPEndPoint connectionEp = context.GetRemoteEndPoint(null);
if (!NetUtilities.IsPrivateIP(connectionEp.Address))
{
@@ -1033,10 +1038,13 @@ namespace DnsServerCore.Dns
response.ContentType = "application/dns-message";
response.ContentLength = mS.Length;
using (Stream s = response.Body)
await TechnitiumLibrary.TaskExtensions.TimeoutAsync(async delegate (CancellationToken cancellationToken1)
{
await mS.CopyToAsync(s, 512);
await using (Stream s = response.Body)
{
await mS.CopyToAsync(s, 512, cancellationToken1);
}
}, _tcpSendTimeout);
}
_queryLog?.Write(remoteEP, DnsTransportProtocol.Https, dnsRequest, dnsResponse);
@@ -2640,7 +2648,7 @@ namespace DnsServerCore.Dns
//return meta data
string blockedDomain = GetBlockedDomain();
IReadOnlyList<DnsResourceRecord> answer = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, 60, new DnsTXTRecordData("source=blocked-zone; domain=" + blockedDomain)) };
IReadOnlyList<DnsResourceRecord> answer = [new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, _blockingAnswerTtl, new DnsTXTRecordData("source=blocked-zone; domain=" + blockedDomain))];
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer) { Tag = DnsServerResponseType.Blocked };
}
@@ -2652,7 +2660,7 @@ namespace DnsServerCore.Dns
if (_allowTxtBlockingReport && (request.EDNS is not null))
{
blockedDomain = GetBlockedDomain();
options = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, "source=blocked-zone; domain=" + blockedDomain)) };
options = [new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Blocked, "source=blocked-zone; domain=" + blockedDomain))];
}
IReadOnlyCollection<DnsARecordData> aRecords;
@@ -2678,7 +2686,7 @@ namespace DnsServerCore.Dns
if (parentDomain is null)
parentDomain = string.Empty;
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NxDomain, request.Question, null, new DnsResourceRecord[] { new DnsResourceRecord(parentDomain, DnsResourceRecordType.SOA, question.Class, 60, _blockedZoneManager.DnsSOARecord) }, null, request.EDNS is null ? ushort.MinValue : _udpPayloadSize, EDnsHeaderFlags.None, options) { Tag = DnsServerResponseType.Blocked };
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NxDomain, request.Question, null, [new DnsResourceRecord(parentDomain, DnsResourceRecordType.SOA, question.Class, _blockingAnswerTtl, _blockedZoneManager.DnsSOARecord)], null, request.EDNS is null ? ushort.MinValue : _udpPayloadSize, EDnsHeaderFlags.None, options) { Tag = DnsServerResponseType.Blocked };
default:
throw new InvalidOperationException();
@@ -2691,24 +2699,42 @@ namespace DnsServerCore.Dns
{
case DnsResourceRecordType.A:
{
List<DnsResourceRecord> rrList = new List<DnsResourceRecord>(aRecords.Count);
if (aRecords.Count > 0)
{
DnsResourceRecord[] rrList = new DnsResourceRecord[aRecords.Count];
int i = 0;
foreach (DnsARecordData record in aRecords)
rrList.Add(new DnsResourceRecord(question.Name, DnsResourceRecordType.A, question.Class, 60, record));
rrList[i++] = new DnsResourceRecord(question.Name, DnsResourceRecordType.A, question.Class, _blockingAnswerTtl, record);
answer = rrList;
}
else
{
answer = null;
authority = response.Authority;
}
}
break;
case DnsResourceRecordType.AAAA:
{
List<DnsResourceRecord> rrList = new List<DnsResourceRecord>(aaaaRecords.Count);
if (aaaaRecords.Count > 0)
{
DnsResourceRecord[] rrList = new DnsResourceRecord[aaaaRecords.Count];
int i = 0;
foreach (DnsAAAARecordData record in aaaaRecords)
rrList.Add(new DnsResourceRecord(question.Name, DnsResourceRecordType.AAAA, question.Class, 60, record));
rrList[i++] = new DnsResourceRecord(question.Name, DnsResourceRecordType.AAAA, question.Class, _blockingAnswerTtl, record);
answer = rrList;
}
else
{
answer = null;
authority = response.Authority;
}
}
break;
default:
@@ -3085,7 +3111,7 @@ namespace DnsServerCore.Dns
}
//no response available; respond with ServerFailure
EDnsOption[] options = [new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Other, "Waiting for resolver"))];
EDnsOption[] options = [new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Other, "Waiting for resolver. Please try again."))];
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, request.CheckingDisabled, DnsResponseCode.ServerFailure, request.Question, null, null, null, _udpPayloadSize, request.DnssecOk ? EDnsHeaderFlags.DNSSEC_OK : EDnsHeaderFlags.None, options);
}
@@ -4283,7 +4309,10 @@ namespace DnsServerCore.Dns
{
_stats.GetLatestClientSubnetStats(_qpmLimitSampleMinutes, _qpmLimitIPv4PrefixLength, _qpmLimitIPv6PrefixLength, out IReadOnlyDictionary<IPAddress, long> qpmLimitClientSubnetStats, out IReadOnlyDictionary<IPAddress, long> qpmLimitErrorClientSubnetStats);
if (_qpmLimitErrors > 0)
WriteClientSubnetRateLimitLog(_qpmLimitErrorClientSubnetStats, qpmLimitErrorClientSubnetStats, _qpmLimitErrors, "errors");
if (_qpmLimitRequests > 0)
WriteClientSubnetRateLimitLog(_qpmLimitClientSubnetStats, qpmLimitClientSubnetStats, _qpmLimitRequests, "requests");
_qpmLimitClientSubnetStats = qpmLimitClientSubnetStats;
@@ -4394,6 +4423,10 @@ namespace DnsServerCore.Dns
#region doh web service
private async Task StartDoHAsync()
{
IReadOnlyList<IPAddress> localAddresses = GetValidKestralLocalAddresses(_localEndPoints.Convert(delegate (IPEndPoint ep) { return ep.Address; }));
try
{
WebApplicationBuilder builder = WebApplication.CreateBuilder();
@@ -4409,8 +4442,6 @@ namespace DnsServerCore.Dns
UsePollingFileWatcher = true
};
IReadOnlyList<IPAddress> localAddresses = GetValidKestralLocalAddresses(_localEndPoints.Convert(delegate (IPEndPoint ep) { return ep.Address; }));
builder.WebHost.ConfigureKestrel(delegate (WebHostBuilderContext context, KestrelServerOptions serverOptions)
{
//bind to http port
@@ -4465,8 +4496,6 @@ namespace DnsServerCore.Dns
_dohWebService.MapGet("/dns-query", ProcessDoHRequestAsync);
_dohWebService.MapPost("/dns-query", ProcessDoHRequestAsync);
try
{
await _dohWebService.StartAsync();
if (_log is not null)
@@ -5516,37 +5545,73 @@ namespace DnsServerCore.Dns
public int DnsOverUdpProxyPort
{
get { return _dnsOverUdpProxyPort; }
set { _dnsOverUdpProxyPort = value; }
set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverUdpProxyPort), "Port number valid range is from 0 to 65535.");
_dnsOverUdpProxyPort = value;
}
}
public int DnsOverTcpProxyPort
{
get { return _dnsOverTcpProxyPort; }
set { _dnsOverTcpProxyPort = value; }
set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverTcpProxyPort), "Port number valid range is from 0 to 65535.");
_dnsOverTcpProxyPort = value;
}
}
public int DnsOverHttpPort
{
get { return _dnsOverHttpPort; }
set { _dnsOverHttpPort = value; }
set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverHttpPort), "Port number valid range is from 0 to 65535.");
_dnsOverHttpPort = value;
}
}
public int DnsOverTlsPort
{
get { return _dnsOverTlsPort; }
set { _dnsOverTlsPort = value; }
set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverTlsPort), "Port number valid range is from 0 to 65535.");
_dnsOverTlsPort = value;
}
}
public int DnsOverHttpsPort
{
get { return _dnsOverHttpsPort; }
set { _dnsOverHttpsPort = value; }
set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverHttpsPort), "Port number valid range is from 0 to 65535.");
_dnsOverHttpsPort = value;
}
}
public int DnsOverQuicPort
{
get { return _dnsOverQuicPort; }
set { _dnsOverQuicPort = value; }
set
{
if ((value < ushort.MinValue) || (value > ushort.MaxValue))
throw new ArgumentOutOfRangeException(nameof(DnsOverQuicPort), "Port number valid range is from 0 to 65535.");
_dnsOverQuicPort = value;
}
}
public X509Certificate2Collection CertificateCollection
@@ -5594,6 +5659,22 @@ namespace DnsServerCore.Dns
}
}
public string DnsOverHttpRealIpHeader
{
get { return _dnsOverHttpRealIpHeader; }
set
{
if (string.IsNullOrEmpty(value))
_dnsOverHttpRealIpHeader = "X-Real-IP";
else if (value.Length > 255)
throw new ArgumentException("DNS-over-HTTP Real IP header name cannot exceed 255 characters.", nameof(DnsOverHttpRealIpHeader));
else if (value.Contains(' '))
throw new ArgumentException("DNS-over-HTTP Real IP header name cannot contain invalid characters.", nameof(DnsOverHttpRealIpHeader));
else
_dnsOverHttpRealIpHeader = value;
}
}
public IReadOnlyDictionary<string, TsigKey> TsigKeys
{
get { return _tsigKeys; }
@@ -5817,6 +5898,22 @@ namespace DnsServerCore.Dns
set { _blockingType = value; }
}
public uint BlockingAnswerTtl
{
get { return _blockingAnswerTtl; }
set
{
if (_blockingAnswerTtl != value)
{
_blockingAnswerTtl = value;
//update SOA MINIMUM values
_blockedZoneManager.UpdateServerDomain();
_blockListZoneManager.UpdateServerDomain();
}
}
}
public IReadOnlyCollection<DnsARecordData> CustomBlockingARecords
{
get { return _customBlockingARecords; }