diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 17319a90..0d06bdb3 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -24,11 +24,9 @@ using DnsServerCore.Dns.Trees; using DnsServerCore.Dns.ZoneManagers; using DnsServerCore.Dns.Zones; using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core; -using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.AspNetCore.StaticFiles; using Microsoft.Extensions.FileProviders; using Microsoft.Extensions.Logging; @@ -94,6 +92,7 @@ namespace DnsServerCore.Dns static readonly IPEndPoint IPENDPOINT_ANY_0 = new IPEndPoint(IPAddress.Any, 0); static readonly IReadOnlyCollection _aRecords = new DnsARecordData[] { new DnsARecordData(IPAddress.Any) }; static readonly IReadOnlyCollection _aaaaRecords = new DnsAAAARecordData[] { new DnsAAAARecordData(IPAddress.IPv6Any) }; + static readonly List quicApplicationProtocols = new List() { new SslApplicationProtocol("doq") }; string _serverDomain; readonly string _configFolder; @@ -149,7 +148,9 @@ namespace DnsServerCore.Dns int _dnsOverTlsPort = 853; int _dnsOverHttpsPort = 443; int _dnsOverQuicPort = 853; - X509Certificate2 _certificate; + X509Certificate2Collection _certificateCollection; + SslServerAuthenticationOptions _sslServerAuthenticationOptions; + SslServerAuthenticationOptions _quicSslServerAuthenticationOptions; IReadOnlyDictionary _tsigKeys; @@ -547,7 +548,7 @@ namespace DnsServerCore.Dns case DnsTransportProtocol.Tls: SslStream tlsStream = new SslStream(new NetworkStream(socket)); - await tlsStream.AuthenticateAsServerAsync(_certificate).WithTimeout(_tcpReceiveTimeout); + await tlsStream.AuthenticateAsServerAsync(_sslServerAuthenticationOptions).WithTimeout(_tcpReceiveTimeout); await ReadStreamRequestAsync(tlsStream, remoteEP, protocol); break; @@ -2015,10 +2016,10 @@ namespace DnsServerCore.Dns else { //return NODATA/NXDOMAIN response - if (request.Question[0].Name.Length > appResourceRecord.Name.Length) - rcode = DnsResponseCode.NxDomain; - else + if ((request.Question[0].Name.Length == appResourceRecord.Name.Length) || appResourceRecord.Name.StartsWith('*')) rcode = DnsResponseCode.NoError; + else + rcode = DnsResponseCode.NxDomain; authority = zoneInfo.GetApexRecords(DnsResourceRecordType.SOA); } @@ -3089,7 +3090,7 @@ namespace DnsServerCore.Dns } else { - IReadOnlyList options = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Other, "Server exception")) }; + IReadOnlyList options = new EDnsOption[] { new EDnsOption(EDnsOptionCode.EXTENDED_DNS_ERROR, new EDnsExtendedDnsErrorOptionData(EDnsExtendedDnsErrorCode.Other, "Resolver exception")) }; DnsDatagram failureResponse = new DnsDatagram(0, true, DnsOpcode.StandardQuery, false, false, true, true, false, dnssecValidation, DnsResponseCode.ServerFailure, new DnsQuestionRecord[] { question }, null, null, null, _udpPayloadSize, dnssecValidation ? EDnsHeaderFlags.DNSSEC_OK : EDnsHeaderFlags.None, options); taskCompletionSource.SetResult(new RecursiveResolveResponse(failureResponse, failureResponse)); @@ -3787,22 +3788,17 @@ namespace DnsServerCore.Dns } //bind to https port - if (_enableDnsOverHttps && (_certificate is not null)) + if (_enableDnsOverHttps && (_certificateCollection is not null)) { - serverOptions.ConfigureHttpsDefaults(delegate (HttpsConnectionAdapterOptions configureOptions) - { - configureOptions.ServerCertificateSelector = delegate (ConnectionContext context, string dnsName) - { - return _certificate; - }; - }); - foreach (IPAddress localAddress in localAddresses) { serverOptions.Listen(localAddress, _dnsOverHttpsPort, delegate (ListenOptions listenOptions) { listenOptions.Protocols = HttpProtocols.Http1AndHttp2AndHttp3; - listenOptions.UseHttps(); + listenOptions.UseHttps(delegate (SslStream stream, SslClientHelloInfo clientHelloInfo, object state, CancellationToken cancellationToken) + { + return ValueTask.FromResult(_sslServerAuthenticationOptions); + }, null); }); } } @@ -3847,7 +3843,7 @@ namespace DnsServerCore.Dns if (_enableDnsOverHttp) _log?.Write(new IPEndPoint(localAddress, _dnsOverHttpPort), "Http", "DNS Server was bound successfully."); - if (_enableDnsOverHttps && (_certificate is not null)) + if (_enableDnsOverHttps && (_certificateCollection is not null)) _log?.Write(new IPEndPoint(localAddress, _dnsOverHttpsPort), "Https", "DNS Server was bound successfully."); } } @@ -3863,7 +3859,7 @@ namespace DnsServerCore.Dns if (_enableDnsOverHttp) _log?.Write(new IPEndPoint(localAddress, _dnsOverHttpPort), "Http", "DNS Server failed to bind."); - if (_enableDnsOverHttps && (_certificate is not null)) + if (_enableDnsOverHttps && (_certificateCollection is not null)) _log?.Write(new IPEndPoint(localAddress, _dnsOverHttpsPort), "Https", "DNS Server failed to bind."); } @@ -4036,7 +4032,7 @@ namespace DnsServerCore.Dns tcpListener?.Dispose(); } - if (_enableDnsOverTls && (_certificate is not null)) + if (_enableDnsOverTls && (_certificateCollection is not null)) { IPEndPoint tlsEP = new IPEndPoint(localEP.Address, _dnsOverTlsPort); Socket tlsListener = null; @@ -4060,7 +4056,7 @@ namespace DnsServerCore.Dns } } - if (_enableDnsOverQuic && (_certificate is not null)) + if (_enableDnsOverQuic && (_certificateCollection is not null)) { IPEndPoint quicEP = new IPEndPoint(localEP.Address, _dnsOverQuicPort); QuicListener quicListener = null; @@ -4071,7 +4067,7 @@ namespace DnsServerCore.Dns { ListenEndPoint = quicEP, ListenBacklog = _listenBacklog, - ApplicationProtocols = new List() { new SslApplicationProtocol("doq") }, + ApplicationProtocols = quicApplicationProtocols, ConnectionOptionsCallback = delegate (QuicConnection quicConnection, SslClientHelloInfo sslClientHello, CancellationToken cancellationToken) { QuicServerConnectionOptions serverConnectionOptions = new QuicServerConnectionOptions() @@ -4081,11 +4077,7 @@ namespace DnsServerCore.Dns MaxInboundUnidirectionalStreams = 0, MaxInboundBidirectionalStreams = _quicMaxInboundStreams, IdleTimeout = TimeSpan.FromMilliseconds(_quicIdleTimeout), - ServerAuthenticationOptions = new SslServerAuthenticationOptions - { - ApplicationProtocols = new List() { new SslApplicationProtocol("doq") }, - ServerCertificate = _certificate - } + ServerAuthenticationOptions = _quicSslServerAuthenticationOptions }; return ValueTask.FromResult(serverConnectionOptions); @@ -4155,7 +4147,7 @@ namespace DnsServerCore.Dns } } - if (_enableDnsOverHttp || (_enableDnsOverHttps && (_certificate is not null))) + if (_enableDnsOverHttp || (_enableDnsOverHttps && (_certificateCollection is not null))) await StartDoHAsync(); _cachePrefetchSamplingTimer = new Timer(CachePrefetchSamplingTimerCallback, null, Timeout.Infinite, Timeout.Infinite); @@ -4609,15 +4601,48 @@ namespace DnsServerCore.Dns set { _dnsOverQuicPort = value; } } - public X509Certificate2 Certificate + public X509Certificate2Collection CertificateCollection { - get { return _certificate; } + get { return _certificateCollection; } set { - if ((value is not null) && !value.HasPrivateKey) - throw new ArgumentException("Tls certificate does not contain private key."); + if (value is null) + { + _certificateCollection = null; + _sslServerAuthenticationOptions = null; + _quicSslServerAuthenticationOptions = null; + } + else + { + X509Certificate2 serverCertificate = null; - _certificate = value; + foreach (X509Certificate2 certificate in value) + { + if (certificate.HasPrivateKey) + { + serverCertificate = certificate; + break; + } + } + + if (serverCertificate is null) + throw new ArgumentException("DNS Server TLS certificate file must contain a certificate with private key."); + + _certificateCollection = value; + + SslStreamCertificateContext certificateContext = SslStreamCertificateContext.Create(serverCertificate, _certificateCollection, false); + + _sslServerAuthenticationOptions = new SslServerAuthenticationOptions() + { + ServerCertificateContext = certificateContext + }; + + _quicSslServerAuthenticationOptions = new SslServerAuthenticationOptions() + { + ApplicationProtocols = quicApplicationProtocols, + ServerCertificateContext = certificateContext + }; + } } }