diff --git a/DnsServerCore/WebService.cs b/DnsServerCore/WebService.cs index 0d58d020..3f54b543 100644 --- a/DnsServerCore/WebService.cs +++ b/DnsServerCore/WebService.cs @@ -31,6 +31,8 @@ using System.IO; using System.IO.Compression; using System.Net; using System.Net.Http; +using System.Net.Security; +using System.Net.Sockets; using System.Reflection; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; @@ -71,16 +73,26 @@ namespace DnsServerCore DnsServer _dnsServer; DhcpServer _dhcpServer; - int _webServicePort; + IReadOnlyList _webServiceLocalAddresses = new IPAddress[] { IPAddress.Any, IPAddress.IPv6Any }; + int _webServiceHttpPort = 5380; + int _webServiceTlsPort = 53443; + bool _webServiceEnableTls; + bool _webServiceHttpToTlsRedirect; + string _webServiceTlsCertificatePath; + string _webServiceTlsCertificatePassword; + DateTime _webServiceTlsCertificateLastModifiedOn; + HttpListener _webService; - Thread _webServiceThread; + IReadOnlyList _webServiceTlsListeners; + X509Certificate2 _webServiceTlsCertificate; readonly IndependentTaskScheduler _webServiceTaskScheduler = new IndependentTaskScheduler(ThreadPriority.AboveNormal); string _webServiceHostname; - string _tlsCertificatePath; - string _tlsCertificatePassword; + string _dnsTlsCertificatePath; + string _dnsTlsCertificatePassword; + DateTime _dnsTlsCertificateLastModifiedOn; + Timer _tlsCertificateUpdateTimer; - DateTime _tlsCertificateLastModifiedOn; const int TLS_CERTIFICATE_UPDATE_TIMER_INITIAL_INTERVAL = 60000; const int TLS_CERTIFICATE_UPDATE_TIMER_INTERVAL = 60000; @@ -135,7 +147,7 @@ namespace DnsServerCore #region IDisposable - private bool _disposed = false; + private bool _disposed; protected virtual void Dispose(bool disposing) { @@ -171,27 +183,122 @@ namespace DnsServerCore #region private - private void AcceptWebRequestAsync(object state) + private async Task AcceptWebRequestAsync() { try { while (true) { - HttpListenerContext context = _webService.GetContext(); + HttpListenerContext context = await _webService.GetContextAsync(); - _ = Task.Factory.StartNew(delegate () + if ((_webServiceTlsListeners != null) && (_webServiceTlsListeners.Count > 0) && _webServiceHttpToTlsRedirect) { - return ProcessRequestAsync(context.Request, context.Response); - }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _webServiceTaskScheduler); + IPEndPoint remoteEP = context.Request.RemoteEndPoint; + + if ((remoteEP != null) && !IPAddress.IsLoopback(remoteEP.Address)) + { + string domain = _webServiceTlsCertificate.GetNameInfo(X509NameType.DnsName, false); + string redirectUri = "https://" + domain + ":" + _webServiceTlsPort + context.Request.Url.PathAndQuery; + + context.Response.Redirect(redirectUri); + context.Response.Close(); + + continue; + } + } + + _ = ProcessRequestAsync(context.Request, context.Response); } } + catch (HttpListenerException ex) + { + if (ex.ErrorCode == 995) + return; //web service stopping + + _log.Write(ex); + } catch (Exception ex) { if ((_state == ServiceState.Stopping) || (_state == ServiceState.Stopped)) return; //web service stopping _log.Write(ex); - throw; + } + } + + private async Task AcceptTlsWebRequestAsync(Socket tlsListener) + { + try + { + while (true) + { + Socket socket = await tlsListener.AcceptAsync(); + + _ = TlsToHttpTunnelAsync(socket); + } + } + catch (SocketException ex) + { + if (ex.SocketErrorCode == SocketError.OperationAborted) + return; //web service stopping + + _log.Write(ex); + } + catch (Exception ex) + { + if ((_state == ServiceState.Stopping) || (_state == ServiceState.Stopped)) + return; //web service stopping + + _log.Write(ex); + } + } + + private async Task TlsToHttpTunnelAsync(Socket socket) + { + bool dispose = true; + Socket tunnel = null; + + try + { + if (_webServiceLocalAddresses.Count < 1) + return; + + SslStream sslStream = new SslStream(new NetworkStream(socket, true)); + + await sslStream.AuthenticateAsServerAsync(_webServiceTlsCertificate); + + IPEndPoint httpEP; + + if (_webServiceLocalAddresses[0].Equals(IPAddress.Any)) + httpEP = new IPEndPoint(IPAddress.Loopback, _webServiceHttpPort); + else if (_webServiceLocalAddresses[0].Equals(IPAddress.IPv6Any)) + httpEP = new IPEndPoint(IPAddress.IPv6Loopback, _webServiceHttpPort); + else + httpEP = new IPEndPoint(_webServiceLocalAddresses[0], _webServiceHttpPort); + + tunnel = new Socket(httpEP.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + tunnel.Connect(httpEP); + + NetworkStream tunnelStream = new NetworkStream(tunnel, true); + + _ = sslStream.CopyToAsync(tunnelStream).ContinueWith(delegate (Task prevTask) { sslStream.Dispose(); tunnelStream.Dispose(); }); + _ = tunnelStream.CopyToAsync(sslStream).ContinueWith(delegate (Task prevTask) { sslStream.Dispose(); tunnelStream.Dispose(); }); + + dispose = false; + } + catch (Exception ex) + { + _log.Write(ex); + } + finally + { + if (dispose) + { + socket.Dispose(); + + if (tunnel != null) + tunnel.Dispose(); + } } } @@ -266,13 +373,17 @@ namespace DnsServerCore return; case "/api/restoreSettings": - await RestoreSettingsAsync(request); + await RestoreSettingsAsync(request, jsonWriter); break; case "/api/getStats": await GetStats(request, jsonWriter); break; + case "/api/getTopStats": + await GetTopStats(request, jsonWriter); + break; + case "/api/flushDnsCache": FlushCache(request); break; @@ -909,12 +1020,9 @@ namespace DnsServerCore jsonWriter.WritePropertyName("version"); jsonWriter.WriteValue(GetCleanVersion(_currentVersion)); - jsonWriter.WritePropertyName("serverDomain"); + jsonWriter.WritePropertyName("dnsServerDomain"); jsonWriter.WriteValue(_dnsServer.ServerDomain); - jsonWriter.WritePropertyName("webServicePort"); - jsonWriter.WriteValue(_webServicePort); - jsonWriter.WritePropertyName("dnsServerLocalEndPoints"); jsonWriter.WriteStartArray(); @@ -923,6 +1031,32 @@ namespace DnsServerCore jsonWriter.WriteEndArray(); + jsonWriter.WritePropertyName("webServiceLocalAddresses"); + jsonWriter.WriteStartArray(); + + foreach (IPAddress localAddress in _webServiceLocalAddresses) + jsonWriter.WriteValue(localAddress.ToString()); + + jsonWriter.WriteEndArray(); + + jsonWriter.WritePropertyName("webServiceHttpPort"); + jsonWriter.WriteValue(_webServiceHttpPort); + + jsonWriter.WritePropertyName("webServiceEnableTls"); + jsonWriter.WriteValue(_webServiceEnableTls); + + jsonWriter.WritePropertyName("webServiceHttpToTlsRedirect"); + jsonWriter.WriteValue(_webServiceHttpToTlsRedirect); + + jsonWriter.WritePropertyName("webServiceTlsPort"); + jsonWriter.WriteValue(_webServiceTlsPort); + + jsonWriter.WritePropertyName("webServiceTlsCertificatePath"); + jsonWriter.WriteValue(_webServiceTlsCertificatePath); + + jsonWriter.WritePropertyName("webServiceTlsCertificatePassword"); + jsonWriter.WriteValue("************"); + jsonWriter.WritePropertyName("enableDnsOverHttp"); jsonWriter.WriteValue(_dnsServer.EnableDnsOverHttp); @@ -932,10 +1066,10 @@ namespace DnsServerCore jsonWriter.WritePropertyName("enableDnsOverHttps"); jsonWriter.WriteValue(_dnsServer.EnableDnsOverHttps); - jsonWriter.WritePropertyName("tlsCertificatePath"); - jsonWriter.WriteValue(_tlsCertificatePath); + jsonWriter.WritePropertyName("dnsTlsCertificatePath"); + jsonWriter.WriteValue(_dnsTlsCertificatePath); - jsonWriter.WritePropertyName("tlsCertificatePassword"); + jsonWriter.WritePropertyName("dnsTlsCertificatePassword"); jsonWriter.WriteValue("************"); jsonWriter.WritePropertyName("preferIPv6"); @@ -1074,9 +1208,12 @@ namespace DnsServerCore private void SetDnsSettings(HttpListenerRequest request, JsonTextWriter jsonWriter) { - string strServerDomain = request.QueryString["serverDomain"]; - if (!string.IsNullOrEmpty(strServerDomain)) - _dnsServer.ServerDomain = strServerDomain; + bool restartDnsService = false; + bool restartWebService = false; + + string strDnsServerDomain = request.QueryString["dnsServerDomain"]; + if (!string.IsNullOrEmpty(strDnsServerDomain)) + _dnsServer.ServerDomain = strDnsServerDomain; string strDnsServerLocalEndPoints = request.QueryString["dnsServerLocalEndPoints"]; if (strDnsServerLocalEndPoints != null) @@ -1094,46 +1231,177 @@ namespace DnsServerCore localEndPoints.Add(nameServer.IPEndPoint); } - _dnsServer.LocalEndPoints = localEndPoints; + if (localEndPoints.Count > 0) + { + if (_dnsServer.LocalEndPoints.Count != localEndPoints.Count) + { + restartDnsService = true; + } + else + { + foreach (IPEndPoint currentLocalEP in _dnsServer.LocalEndPoints) + { + if (!localEndPoints.Contains(currentLocalEP)) + { + restartDnsService = true; + break; + } + } + } + + _dnsServer.LocalEndPoints = localEndPoints; + } } - int oldWebServicePort = _webServicePort; - - string strWebServicePort = request.QueryString["webServicePort"]; - if (!string.IsNullOrEmpty(strWebServicePort)) - _webServicePort = int.Parse(strWebServicePort); - - string enableDnsOverHttp = request.QueryString["enableDnsOverHttp"]; - if (!string.IsNullOrEmpty(enableDnsOverHttp)) - _dnsServer.EnableDnsOverHttp = bool.Parse(enableDnsOverHttp); - - string strEnableDnsOverTls = request.QueryString["enableDnsOverTls"]; - if (!string.IsNullOrEmpty(strEnableDnsOverTls)) - _dnsServer.EnableDnsOverTls = bool.Parse(strEnableDnsOverTls); - - string strEnableDnsOverHttps = request.QueryString["enableDnsOverHttps"]; - if (!string.IsNullOrEmpty(strEnableDnsOverHttps)) - _dnsServer.EnableDnsOverHttps = bool.Parse(strEnableDnsOverHttps); - - string strTlsCertificatePath = request.QueryString["tlsCertificatePath"]; - string strTlsCertificatePassword = request.QueryString["tlsCertificatePassword"]; - if (string.IsNullOrEmpty(strTlsCertificatePath)) + string strWebServiceLocalAddresses = request.QueryString["webServiceLocalAddresses"]; + if (strWebServiceLocalAddresses != null) { - StopTlsCertificateUpdateTimer(); - _tlsCertificatePath = null; - _tlsCertificatePassword = ""; + if (string.IsNullOrEmpty(strWebServiceLocalAddresses)) + strWebServiceLocalAddresses = "0.0.0.0,[::]"; + + string[] strLocalAddresses = strWebServiceLocalAddresses.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries); + List localAddresses = new List(strLocalAddresses.Length); + + for (int i = 0; i < strLocalAddresses.Length; i++) + { + if (IPAddress.TryParse(strLocalAddresses[i], out IPAddress localAddress)) + localAddresses.Add(localAddress); + } + + if (localAddresses.Count > 0) + { + if (_webServiceLocalAddresses.Count != localAddresses.Count) + { + restartWebService = true; + } + else + { + foreach (IPAddress currentlocalAddress in _webServiceLocalAddresses) + { + if (!localAddresses.Contains(currentlocalAddress)) + { + restartWebService = true; + break; + } + } + } + + _webServiceLocalAddresses = localAddresses; + } + } + + int oldWebServiceHttpPort = _webServiceHttpPort; + + string strWebServiceHttpPort = request.QueryString["webServiceHttpPort"]; + if (!string.IsNullOrEmpty(strWebServiceHttpPort)) + { + _webServiceHttpPort = int.Parse(strWebServiceHttpPort); + + if (oldWebServiceHttpPort != _webServiceHttpPort) + restartWebService = true; + } + + string strWebServiceEnableTls = request.QueryString["webServiceEnableTls"]; + if (!string.IsNullOrEmpty(strWebServiceEnableTls)) + { + bool oldWebServiceEnableTls = _webServiceEnableTls; + + _webServiceEnableTls = bool.Parse(strWebServiceEnableTls); + + if (oldWebServiceEnableTls != _webServiceEnableTls) + restartWebService = true; + } + + string strWebServiceHttpToTlsRedirect = request.QueryString["webServiceHttpToTlsRedirect"]; + if (!string.IsNullOrEmpty(strWebServiceHttpToTlsRedirect)) + _webServiceHttpToTlsRedirect = bool.Parse(strWebServiceHttpToTlsRedirect); + + + string strWebServiceTlsPort = request.QueryString["webServiceTlsPort"]; + if (!string.IsNullOrEmpty(strWebServiceTlsPort)) + { + int oldWebServiceTlsPort = _webServiceTlsPort; + + _webServiceTlsPort = int.Parse(strWebServiceTlsPort); + + if (oldWebServiceTlsPort != _webServiceTlsPort) + restartWebService = true; + } + + string strWebServiceTlsCertificatePath = request.QueryString["webServiceTlsCertificatePath"]; + string strWebServiceTlsCertificatePassword = request.QueryString["webServiceTlsCertificatePassword"]; + if (string.IsNullOrEmpty(strWebServiceTlsCertificatePath)) + { + _webServiceTlsCertificatePath = null; + _webServiceTlsCertificatePassword = ""; } else { - if (strTlsCertificatePassword == "************") - strTlsCertificatePassword = _tlsCertificatePassword; + if (strWebServiceTlsCertificatePassword == "************") + strWebServiceTlsCertificatePassword = _webServiceTlsCertificatePassword; - if ((strTlsCertificatePath != _tlsCertificatePath) || (strTlsCertificatePassword != _tlsCertificatePassword)) + if ((strWebServiceTlsCertificatePath != _webServiceTlsCertificatePath) || (strWebServiceTlsCertificatePassword != _webServiceTlsCertificatePassword)) { - LoadTlsCertificate(strTlsCertificatePath, strTlsCertificatePassword); + LoadWebServiceTlsCertificate(strWebServiceTlsCertificatePath, strWebServiceTlsCertificatePassword); - _tlsCertificatePath = strTlsCertificatePath; - _tlsCertificatePassword = strTlsCertificatePassword; + _webServiceTlsCertificatePath = strWebServiceTlsCertificatePath; + _webServiceTlsCertificatePassword = strWebServiceTlsCertificatePassword; + + StartTlsCertificateUpdateTimer(); + } + } + + string enableDnsOverHttp = request.QueryString["enableDnsOverHttp"]; + if (!string.IsNullOrEmpty(enableDnsOverHttp)) + { + bool oldEnableDnsOverHttp = _dnsServer.EnableDnsOverHttp; + + _dnsServer.EnableDnsOverHttp = bool.Parse(enableDnsOverHttp); + + if (oldEnableDnsOverHttp != _dnsServer.EnableDnsOverHttp) + restartDnsService = true; + } + + string strEnableDnsOverTls = request.QueryString["enableDnsOverTls"]; + if (!string.IsNullOrEmpty(strEnableDnsOverTls)) + { + bool oldEnableDnsOverTls = _dnsServer.EnableDnsOverTls; + + _dnsServer.EnableDnsOverTls = bool.Parse(strEnableDnsOverTls); + + if (oldEnableDnsOverTls != _dnsServer.EnableDnsOverTls) + restartDnsService = true; + } + + string strEnableDnsOverHttps = request.QueryString["enableDnsOverHttps"]; + if (!string.IsNullOrEmpty(strEnableDnsOverHttps)) + { + bool oldEnableDnsOverHttps = _dnsServer.EnableDnsOverHttps; + + _dnsServer.EnableDnsOverHttps = bool.Parse(strEnableDnsOverHttps); + + if (oldEnableDnsOverHttps != _dnsServer.EnableDnsOverHttps) + restartDnsService = true; + } + + string strDnsTlsCertificatePath = request.QueryString["dnsTlsCertificatePath"]; + string strDnsTlsCertificatePassword = request.QueryString["dnsTlsCertificatePassword"]; + if (string.IsNullOrEmpty(strDnsTlsCertificatePath)) + { + _dnsTlsCertificatePath = null; + _dnsTlsCertificatePassword = ""; + } + else + { + if (strDnsTlsCertificatePassword == "************") + strDnsTlsCertificatePassword = _dnsTlsCertificatePassword; + + if ((strDnsTlsCertificatePath != _dnsTlsCertificatePath) || (strDnsTlsCertificatePassword != _dnsTlsCertificatePassword)) + { + LoadDnsTlsCertificate(strDnsTlsCertificatePath, strDnsTlsCertificatePassword); + + _dnsTlsCertificatePath = strDnsTlsCertificatePath; + _dnsTlsCertificatePassword = strDnsTlsCertificatePassword; StartTlsCertificateUpdateTimer(); } @@ -1166,13 +1434,7 @@ namespace DnsServerCore string strMaxLogFileDays = request.QueryString["maxLogFileDays"]; if (!string.IsNullOrEmpty(strMaxLogFileDays)) - { - int maxLogFileDays = int.Parse(strMaxLogFileDays); - if (maxLogFileDays < 1) - throw new ArgumentOutOfRangeException("Parameter 'maxLogFileDays' must be greater than 1."); - - _log.MaxLogFileDays = maxLogFileDays; - } + _log.MaxLogFileDays = int.Parse(strMaxLogFileDays); string strAllowRecursion = request.QueryString["allowRecursion"]; if (!string.IsNullOrEmpty(strAllowRecursion)) @@ -1285,13 +1547,13 @@ namespace DnsServerCore string[] strBlockListUrlList = strBlockListUrls.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries); - if (oldWebServicePort != _webServicePort) + if (oldWebServiceHttpPort != _webServiceHttpPort) { for (int i = 0; i < strBlockListUrlList.Length; i++) { - if (strBlockListUrlList[i].Contains("http://localhost:" + oldWebServicePort + "/blocklist.txt")) + if (strBlockListUrlList[i].Contains("http://localhost:" + oldWebServiceHttpPort + "/blocklist.txt")) { - strBlockListUrlList[i] = "http://localhost:" + _webServicePort + "/blocklist.txt"; + strBlockListUrlList[i] = "http://localhost:" + _webServiceHttpPort + "/blocklist.txt"; updated = true; break; } @@ -1345,12 +1607,64 @@ namespace DnsServerCore _blockListUpdateIntervalHours = blockListUpdateIntervalHours; } - _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] DNS Settings were updated {serverDomain: " + _dnsServer.ServerDomain + "; dnsServerLocalEndPoints: " + strDnsServerLocalEndPoints + "; webServicePort: " + _webServicePort + "; enableDnsOverHttp: " + _dnsServer.EnableDnsOverHttp + "; enableDnsOverTls: " + _dnsServer.EnableDnsOverTls + "; enableDnsOverHttps: " + _dnsServer.EnableDnsOverHttps + "; tlsCertificatePath: " + _tlsCertificatePath + "; preferIPv6: " + _dnsServer.PreferIPv6 + "; logQueries: " + (_dnsServer.QueryLogManager != null) + "; allowRecursion: " + _dnsServer.AllowRecursion + "; allowRecursionOnlyForPrivateNetworks: " + _dnsServer.AllowRecursionOnlyForPrivateNetworks + "; proxyType: " + strProxyType + "; forwarders: " + strForwarders + "; forwarderProtocol: " + strForwarderProtocol + "; blockListUrl: " + strBlockListUrls + ";}"); - SaveConfigFile(); _log.Save(); + _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] DNS Settings were updated {dnsServerDomain: " + _dnsServer.ServerDomain + "; dnsServerLocalEndPoints: " + strDnsServerLocalEndPoints + "; webServiceLocalAddresses: " + strWebServiceLocalAddresses + "; webServiceHttpPort: " + _webServiceHttpPort + "; webServiceEnableTls: " + strWebServiceEnableTls + "; webServiceHttpToTlsRedirect: " + strWebServiceHttpToTlsRedirect + "; webServiceTlsPort: " + strWebServiceTlsPort + "; webServiceTlsCertificatePath: " + strWebServiceTlsCertificatePath + "; enableDnsOverHttp: " + _dnsServer.EnableDnsOverHttp + "; enableDnsOverTls: " + _dnsServer.EnableDnsOverTls + "; enableDnsOverHttps: " + _dnsServer.EnableDnsOverHttps + "; dnsTlsCertificatePath: " + _dnsTlsCertificatePath + "; preferIPv6: " + _dnsServer.PreferIPv6 + "; enableLogging: " + strEnableLogging + "; logQueries: " + (_dnsServer.QueryLogManager != null) + "; useLocalTime: " + strUseLocalTime + "; logFolder: " + strLogFolder + "; maxLogFileDays: " + strMaxLogFileDays + "; allowRecursion: " + _dnsServer.AllowRecursion + "; allowRecursionOnlyForPrivateNetworks: " + _dnsServer.AllowRecursionOnlyForPrivateNetworks + "; randomizeName: " + strRandomizeName + "; serveStale: " + strServeStale + "; serveStaleTtl: " + strServeStaleTtl + "; cachePrefetchEligibility: " + strCachePrefetchEligibility + "; cachePrefetchTrigger: " + strCachePrefetchTrigger + "; cachePrefetchSampleIntervalInMinutes: " + strCachePrefetchSampleIntervalInMinutes + "; cachePrefetchSampleEligibilityHitsPerHour: " + strCachePrefetchSampleEligibilityHitsPerHour + "; proxyType: " + strProxyType + "; forwarders: " + strForwarders + "; forwarderProtocol: " + strForwarderProtocol + "; blockListUrl: " + strBlockListUrls + "; blockListUpdateIntervalHours: " + strBlockListUpdateIntervalHours + ";}"); + + if ((_webServiceTlsCertificatePath == null) && (_dnsTlsCertificatePath == null)) + StopTlsCertificateUpdateTimer(); + GetDnsSettings(jsonWriter); + + RestartService(restartDnsService, restartWebService); + } + + private void RestartService(bool restartDnsService, bool restartWebService) + { + if (restartDnsService) + { + _ = Task.Run(delegate () + { + _log.Write("Attempting to restart DNS service."); + + try + { + _dnsServer.Stop(); + _dnsServer.Start(); + + _log.Write("DNS service was restarted successfully."); + } + catch (Exception ex) + { + _log.Write("Failed to restart DNS service."); + _log.Write(ex); + } + }); + } + + if (restartWebService) + { + _ = Task.Run(async delegate () + { + await Task.Delay(2000); //wait for this HTTP response to be delivered before stopping web server + + _log.Write("Attempting to restart web service."); + + try + { + StopWebService(); + StartWebService(); + + _log.Write("Web service was restarted successfully."); + } + catch (Exception ex) + { + _log.Write("Failed to restart web service."); + _log.Write(ex); + } + }); + } } private async Task BackupSettingsAsync(HttpListenerRequest request, HttpListenerResponse response) @@ -1534,14 +1848,16 @@ namespace DnsServerCore { File.Delete(tmpFile); } - catch - { } + catch (Exception ex) + { + _log.Write(ex); + } } _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] Settings backup zip file was exported."); } - private async Task RestoreSettingsAsync(HttpListenerRequest request) + private async Task RestoreSettingsAsync(HttpListenerRequest request, JsonTextWriter jsonWriter) { bool blockLists = false; bool logs = false; @@ -1646,6 +1962,13 @@ namespace DnsServerCore if (logs) { + //delete existing log files + string[] logFiles = Directory.GetFiles(_log.LogFolderAbsolutePath, "*.log", SearchOption.TopDirectoryOnly); + foreach (string logFile in logFiles) + { + File.Delete(logFile); + } + //extract log files from backup foreach (ZipArchiveEntry entry in backupZip.Entries) { @@ -1666,6 +1989,13 @@ namespace DnsServerCore if (blockLists) { + //delete existing block list files + string[] blockListFiles = Directory.GetFiles(Path.Combine(_configFolder, "blocklists"), "*", SearchOption.TopDirectoryOnly); + foreach (string blockListFile in blockListFiles) + { + File.Delete(blockListFile); + } + //extract block list files from backup foreach (ZipArchiveEntry entry in backupZip.Entries) { @@ -1789,9 +2119,16 @@ namespace DnsServerCore { File.Delete(tmpFile); } - catch - { } + catch (Exception ex) + { + _log.Write(ex); + } } + + if (dnsSettings) + RestartService(true, true); + + GetDnsSettings(jsonWriter); } private void ForceUpdateBlockLists(HttpListenerRequest request) @@ -1848,11 +2185,17 @@ namespace DnsServerCore jsonWriter.WriteValue(item.Value); } + jsonWriter.WritePropertyName("zones"); + jsonWriter.WriteValue(_dnsServer.AuthZoneManager.TotalZones); + jsonWriter.WritePropertyName("allowedZones"); jsonWriter.WriteValue(_dnsServer.AllowedZoneManager.TotalZonesAllowed); jsonWriter.WritePropertyName("blockedZones"); - jsonWriter.WriteValue(_dnsServer.BlockedZoneManager.TotalZonesBlocked + _dnsServer.BlockListZoneManager.TotalZonesBlocked); + jsonWriter.WriteValue(_dnsServer.BlockedZoneManager.TotalZonesBlocked); + + jsonWriter.WritePropertyName("blockListZones"); + jsonWriter.WriteValue(_dnsServer.BlockListZoneManager.TotalZonesBlocked); jsonWriter.WriteEndObject(); } @@ -2159,6 +2502,160 @@ namespace DnsServerCore jsonWriter.WriteEndObject(); } + private async Task GetTopStats(HttpListenerRequest request, JsonTextWriter jsonWriter) + { + string strType = request.QueryString["type"]; + if (string.IsNullOrEmpty(strType)) + strType = "lastHour"; + + string strStatsType = request.QueryString["statsType"]; + if (string.IsNullOrEmpty(strStatsType)) + throw new WebServiceException("Parameter 'statsType' missing."); + + string strLimit = request.QueryString["limit"]; + if (string.IsNullOrEmpty(strLimit)) + strLimit = "1000"; + + TopStatsType statsType = (TopStatsType)Enum.Parse(typeof(TopStatsType), strStatsType, true); + int limit = int.Parse(strLimit); + + List> topStatsData; + + switch (strType) + { + case "lastHour": + topStatsData = _dnsServer.StatsManager.GetLastHourTopStats(statsType, limit); + break; + + case "lastDay": + topStatsData = _dnsServer.StatsManager.GetLastDayTopStats(statsType, limit); + break; + + case "lastWeek": + topStatsData = _dnsServer.StatsManager.GetLastWeekTopStats(statsType, limit); + break; + + case "lastMonth": + topStatsData = _dnsServer.StatsManager.GetLastMonthTopStats(statsType, limit); + break; + + case "lastYear": + topStatsData = _dnsServer.StatsManager.GetLastYearTopStats(statsType, limit); + break; + + default: + throw new WebServiceException("Unknown stats type requested: " + strType); + } + + switch (statsType) + { + case TopStatsType.TopClients: + { + IDictionary clientIpMap = _dhcpServer.GetAddressClientMap(); + + jsonWriter.WritePropertyName("topClients"); + jsonWriter.WriteStartArray(); + + foreach (KeyValuePair item in topStatsData) + { + jsonWriter.WriteStartObject(); + + jsonWriter.WritePropertyName("name"); + jsonWriter.WriteValue(item.Key); + + if (clientIpMap.TryGetValue(item.Key, out string clientDomain)) + { + jsonWriter.WritePropertyName("domain"); + jsonWriter.WriteValue(clientDomain); + } + else + { + IPAddress address = IPAddress.Parse(item.Key); + + if (IPAddress.IsLoopback(address)) + { + jsonWriter.WritePropertyName("domain"); + jsonWriter.WriteValue("localhost"); + } + else + { + try + { + DnsDatagram ptrResponse = await _dnsServer.DirectQueryAsync(new DnsQuestionRecord(address, DnsClass.IN), 200); + if ((ptrResponse != null) && (ptrResponse.Answer.Count > 0)) + { + IReadOnlyList ptrDomains = DnsClient.ParseResponsePTR(ptrResponse); + if (ptrDomains != null) + { + jsonWriter.WritePropertyName("domain"); + jsonWriter.WriteValue(ptrDomains[0]); + } + } + } + catch + { } + } + } + + jsonWriter.WritePropertyName("hits"); + jsonWriter.WriteValue(item.Value); + + jsonWriter.WriteEndObject(); + } + + jsonWriter.WriteEndArray(); + } + break; + + case TopStatsType.TopDomains: + { + jsonWriter.WritePropertyName("topDomains"); + jsonWriter.WriteStartArray(); + + foreach (KeyValuePair item in topStatsData) + { + jsonWriter.WriteStartObject(); + + jsonWriter.WritePropertyName("name"); + jsonWriter.WriteValue(item.Key); + + jsonWriter.WritePropertyName("hits"); + jsonWriter.WriteValue(item.Value); + + jsonWriter.WriteEndObject(); + } + + jsonWriter.WriteEndArray(); + } + break; + + case TopStatsType.TopBlockedDomains: + { + jsonWriter.WritePropertyName("topBlockedDomains"); + jsonWriter.WriteStartArray(); + + foreach (KeyValuePair item in topStatsData) + { + jsonWriter.WriteStartObject(); + + jsonWriter.WritePropertyName("name"); + jsonWriter.WriteValue(item.Key); + + jsonWriter.WritePropertyName("hits"); + jsonWriter.WriteValue(item.Value); + + jsonWriter.WriteEndObject(); + } + + jsonWriter.WriteEndArray(); + } + break; + + default: + throw new NotSupportedException(); + } + } + private void FlushCache(HttpListenerRequest request) { _dnsServer.CacheZoneManager.Flush(); @@ -2689,7 +3186,7 @@ namespace DnsServerCore _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] " + zoneInfo.Type.ToString() + " zone was deleted: " + domain); - _dnsServer.AuthZoneManager.DeleteZoneFile(domain); + _dnsServer.AuthZoneManager.DeleteZoneFile(zoneInfo.Name); } private void EnableZone(HttpListenerRequest request) @@ -2710,9 +3207,12 @@ namespace DnsServerCore zoneInfo.Disabled = false; - _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] " + zoneInfo.Type.ToString() + " zone was enabled: " + domain); + _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] " + zoneInfo.Type.ToString() + " zone was enabled: " + zoneInfo.Name); _dnsServer.AuthZoneManager.SaveZoneFile(zoneInfo.Name); + + //delete cache for this zone to allow rebuilding cache data as needed by stub or forwarder zones + _dnsServer.CacheZoneManager.DeleteZone(zoneInfo.Name); } private void DisableZone(HttpListenerRequest request) @@ -2733,7 +3233,7 @@ namespace DnsServerCore zoneInfo.Disabled = true; - _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] " + zoneInfo.Type.ToString() + " zone was disabled: " + domain); + _log.Write(GetRequestRemoteEndPoint(request), "[" + GetSession(request).Username + "] " + zoneInfo.Type.ToString() + " zone was disabled: " + zoneInfo.Name); _dnsServer.AuthZoneManager.SaveZoneFile(zoneInfo.Name); } @@ -4590,16 +5090,34 @@ namespace DnsServerCore { _tlsCertificateUpdateTimer = new Timer(delegate (object state) { - try + if (!string.IsNullOrEmpty(_webServiceTlsCertificatePath)) { - FileInfo fileInfo = new FileInfo(_tlsCertificatePath); + try + { + FileInfo fileInfo = new FileInfo(_webServiceTlsCertificatePath); - if (fileInfo.Exists && (fileInfo.LastWriteTimeUtc != _tlsCertificateLastModifiedOn)) - LoadTlsCertificate(_tlsCertificatePath, _tlsCertificatePassword); + if (fileInfo.Exists && (fileInfo.LastWriteTimeUtc != _webServiceTlsCertificateLastModifiedOn)) + LoadWebServiceTlsCertificate(_webServiceTlsCertificatePath, _webServiceTlsCertificatePassword); + } + catch (Exception ex) + { + _log.Write("DNS Server encountered an error while updating Web Service TLS Certificate: " + _webServiceTlsCertificatePath + "\r\n" + ex.ToString()); + } } - catch (Exception ex) + + if (!string.IsNullOrEmpty(_dnsTlsCertificatePath)) { - _log.Write("DNS Server encountered an error while updating TLS Certificate: " + _tlsCertificatePath + "\r\n" + ex.ToString()); + try + { + FileInfo fileInfo = new FileInfo(_dnsTlsCertificatePath); + + if (fileInfo.Exists && (fileInfo.LastWriteTimeUtc != _dnsTlsCertificateLastModifiedOn)) + LoadDnsTlsCertificate(_dnsTlsCertificatePath, _dnsTlsCertificatePassword); + } + catch (Exception ex) + { + _log.Write("DNS Server encountered an error while updating DNS Server TLS Certificate: " + _dnsTlsCertificatePath + "\r\n" + ex.ToString()); + } } }, null, TLS_CERTIFICATE_UPDATE_TIMER_INITIAL_INTERVAL, TLS_CERTIFICATE_UPDATE_TIMER_INTERVAL); @@ -4615,23 +5133,44 @@ namespace DnsServerCore } } - private void LoadTlsCertificate(string tlsCertificatePath, string tlsCertificatePassword) + private void LoadWebServiceTlsCertificate(string tlsCertificatePath, string tlsCertificatePassword) { FileInfo fileInfo = new FileInfo(tlsCertificatePath); if (!fileInfo.Exists) - throw new ArgumentException("Tls certificate file does not exists: " + tlsCertificatePath); + throw new ArgumentException("Web Service TLS certificate file does not exists: " + tlsCertificatePath); if (Path.GetExtension(tlsCertificatePath) != ".pfx") - throw new ArgumentException("Tls certificate file must be PKCS #12 formatted with .pfx extension: " + tlsCertificatePath); + throw new ArgumentException("Web Service TLS certificate file must be PKCS #12 formatted with .pfx extension: " + tlsCertificatePath); X509Certificate2 certificate = new X509Certificate2(tlsCertificatePath, tlsCertificatePassword); if (!certificate.Verify()) - throw new ArgumentException("Tls certificate is invalid."); + throw new ArgumentException("Web Service TLS certificate is invalid."); + + _webServiceTlsCertificate = certificate; + _webServiceTlsCertificateLastModifiedOn = fileInfo.LastWriteTimeUtc; + + _log.Write("Web Service TLS certificate was loaded: " + tlsCertificatePath); + } + + private void LoadDnsTlsCertificate(string tlsCertificatePath, string tlsCertificatePassword) + { + FileInfo fileInfo = new FileInfo(tlsCertificatePath); + + if (!fileInfo.Exists) + throw new ArgumentException("DNS Server TLS certificate file does not exists: " + tlsCertificatePath); + + if (Path.GetExtension(tlsCertificatePath) != ".pfx") + throw new ArgumentException("DNS Server TLS certificate file must be PKCS #12 formatted with .pfx extension: " + tlsCertificatePath); + + X509Certificate2 certificate = new X509Certificate2(tlsCertificatePath, tlsCertificatePassword); + + if (!certificate.Verify()) + throw new ArgumentException("DNS Server TLS certificate is invalid."); _dnsServer.Certificate = certificate; - _tlsCertificateLastModifiedOn = fileInfo.LastWriteTimeUtc; + _dnsTlsCertificateLastModifiedOn = fileInfo.LastWriteTimeUtc; _log.Write("DNS Server TLS certificate was loaded: " + tlsCertificatePath); } @@ -4680,7 +5219,46 @@ namespace DnsServerCore case 12: case 13: _dnsServer.ServerDomain = bR.ReadShortString(); - _webServicePort = bR.ReadInt32(); + _webServiceHttpPort = bR.ReadInt32(); + + if (version >= 13) + { + { + int count = bR.ReadByte(); + if (count > 0) + { + IPAddress[] localAddresses = new IPAddress[count]; + + for (int i = 0; i < count; i++) + localAddresses[i] = IPAddressExtension.Parse(bR); + + _webServiceLocalAddresses = localAddresses; + } + } + + _webServiceTlsPort = bR.ReadInt32(); + _webServiceEnableTls = bR.ReadBoolean(); + _webServiceHttpToTlsRedirect = bR.ReadBoolean(); + _webServiceTlsCertificatePath = bR.ReadShortString(); + _webServiceTlsCertificatePassword = bR.ReadShortString(); + + if (_webServiceTlsCertificatePath.Length == 0) + _webServiceTlsCertificatePath = null; + + if (_webServiceTlsCertificatePath != null) + { + try + { + LoadWebServiceTlsCertificate(_webServiceTlsCertificatePath, _webServiceTlsCertificatePassword); + } + catch (Exception ex) + { + _log.Write("DNS Server encountered an error while loading Web Service TLS certificate: " + _webServiceTlsCertificatePath + "\r\n" + ex.ToString()); + } + + StartTlsCertificateUpdateTimer(); + } + } _dnsServer.PreferIPv6 = bR.ReadBoolean(); @@ -4847,21 +5425,21 @@ namespace DnsServerCore _dnsServer.EnableDnsOverHttp = bR.ReadBoolean(); _dnsServer.EnableDnsOverTls = bR.ReadBoolean(); _dnsServer.EnableDnsOverHttps = bR.ReadBoolean(); - _tlsCertificatePath = bR.ReadShortString(); - _tlsCertificatePassword = bR.ReadShortString(); + _dnsTlsCertificatePath = bR.ReadShortString(); + _dnsTlsCertificatePassword = bR.ReadShortString(); - if (_tlsCertificatePath.Length == 0) - _tlsCertificatePath = null; + if (_dnsTlsCertificatePath.Length == 0) + _dnsTlsCertificatePath = null; - if (_tlsCertificatePath != null) + if (_dnsTlsCertificatePath != null) { try { - LoadTlsCertificate(_tlsCertificatePath, _tlsCertificatePassword); + LoadDnsTlsCertificate(_dnsTlsCertificatePath, _dnsTlsCertificatePassword); } catch (Exception ex) { - _log.Write("DNS Server encountered an error while loading TLS certificate: " + _tlsCertificatePath + "\r\n" + ex.ToString()); + _log.Write("DNS Server encountered an error while loading DNS Server TLS certificate: " + _dnsTlsCertificatePath + "\r\n" + ex.ToString()); } StartTlsCertificateUpdateTimer(); @@ -4871,7 +5449,7 @@ namespace DnsServerCore break; default: - throw new InvalidDataException("DnsServer config version not supported."); + throw new InvalidDataException("DNS Server config version not supported."); } } @@ -4900,8 +5478,6 @@ namespace DnsServerCore _log.Write("DNS Server config file was not found: " + configFile); _log.Write("DNS Server is restoring default config file."); - _webServicePort = 5380; - SetCredentials("admin", "admin"); _dnsServer.AllowRecursion = true; @@ -4931,7 +5507,28 @@ namespace DnsServerCore bW.Write((byte)13); //version bW.WriteShortString(_dnsServer.ServerDomain); - bW.Write(_webServicePort); + bW.Write(_webServiceHttpPort); + + { + bW.Write(Convert.ToByte(_webServiceLocalAddresses.Count)); + + foreach (IPAddress localAddress in _webServiceLocalAddresses) + localAddress.WriteTo(bW); + } + + bW.Write(_webServiceTlsPort); + bW.Write(_webServiceEnableTls); + bW.Write(_webServiceHttpToTlsRedirect); + + if (_webServiceTlsCertificatePath == null) + bW.WriteShortString(string.Empty); + else + bW.WriteShortString(_webServiceTlsCertificatePath); + + if (_webServiceTlsCertificatePassword == null) + bW.WriteShortString(string.Empty); + else + bW.WriteShortString(_webServiceTlsCertificatePassword); bW.Write(_dnsServer.PreferIPv6); @@ -5026,15 +5623,15 @@ namespace DnsServerCore bW.Write(_dnsServer.EnableDnsOverTls); bW.Write(_dnsServer.EnableDnsOverHttps); - if (_tlsCertificatePath == null) + if (_dnsTlsCertificatePath == null) bW.WriteShortString(string.Empty); else - bW.WriteShortString(_tlsCertificatePath); + bW.WriteShortString(_dnsTlsCertificatePath); - if (_tlsCertificatePassword == null) + if (_dnsTlsCertificatePassword == null) bW.WriteShortString(string.Empty); else - bW.WriteShortString(_tlsCertificatePassword); + bW.WriteShortString(_dnsTlsCertificatePassword); //write config mS.Position = 0; @@ -5048,6 +5645,113 @@ namespace DnsServerCore _log.Write("DNS Server config file was saved: " + configFile); } + private void StartWebService() + { + //HTTP service + try + { + string webServiceHostname = null; + + _webService = new HttpListener(); + + foreach (IPAddress webServiceLocalAddress in _webServiceLocalAddresses) + { + string host; + + if (webServiceLocalAddress.Equals(IPAddress.Any) || webServiceLocalAddress.Equals(IPAddress.IPv6Any)) + { + host = "+"; + } + else + { + if (webServiceLocalAddress.AddressFamily == AddressFamily.InterNetworkV6) + host = "[" + webServiceLocalAddress.ToString() + "]"; + else + host = webServiceLocalAddress.ToString(); + + if (webServiceHostname == null) + webServiceHostname = host; + } + + _webService.Prefixes.Add("http://" + host + ":" + _webServiceHttpPort + "/"); + } + + _webService.Start(); + + _webServiceHostname = webServiceHostname ?? Environment.MachineName.ToLower(); + } + catch (Exception ex) + { + _log.Write("Web Service failed to bind using default hostname. Attempting to bind again using 'localhost' hostname.\r\n" + ex.ToString()); + + _webService = new HttpListener(); + _webService.Prefixes.Add("http://localhost:" + _webServiceHttpPort + "/"); + _webService.Start(); + + _webServiceHostname = "localhost"; + } + + _webService.IgnoreWriteExceptions = true; + + _ = Task.Factory.StartNew(delegate () + { + return AcceptWebRequestAsync(); + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _webServiceTaskScheduler); + + _log.Write(new IPEndPoint(IPAddress.Any, _webServiceHttpPort), "HTTP Web Service was started successfully."); + + //TLS service + if (_webServiceEnableTls && (_webServiceTlsCertificate != null)) + { + List webServiceTlsListeners = new List(); + + try + { + foreach (IPAddress webServiceLocalAddress in _webServiceLocalAddresses) + { + Socket tlsListener = new Socket(webServiceLocalAddress.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + + tlsListener.Bind(new IPEndPoint(webServiceLocalAddress, _webServiceTlsPort)); + tlsListener.Listen(10); + + webServiceTlsListeners.Add(tlsListener); + } + + foreach (Socket tlsListener in webServiceTlsListeners) + { + _ = Task.Factory.StartNew(delegate () + { + return AcceptTlsWebRequestAsync(tlsListener); + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _webServiceTaskScheduler); + } + + _webServiceTlsListeners = webServiceTlsListeners; + + _log.Write(new IPEndPoint(IPAddress.Any, _webServiceHttpPort), "TLS Web Service was started successfully."); + } + catch (Exception ex) + { + _log.Write("TLS Web Service failed to start.\r\n" + ex.ToString()); + + foreach (Socket tlsListener in webServiceTlsListeners) + tlsListener.Dispose(); + } + } + } + + private void StopWebService() + { + _webService.Stop(); + + if (_webServiceTlsListeners != null) + { + foreach (Socket tlsListener in _webServiceTlsListeners) + tlsListener.Dispose(); + + _webServiceTlsListeners = null; + } + } + #endregion #region public @@ -5065,12 +5769,12 @@ namespace DnsServerCore try { //get initial server domain - string serverDomain = Environment.MachineName.ToLower(); - if (!DnsClient.IsDomainNameValid(serverDomain)) - serverDomain = "dns-server-1"; //use this name instead since machine name is not a valid domain name + string dnsServerDomain = Environment.MachineName.ToLower(); + if (!DnsClient.IsDomainNameValid(dnsServerDomain)) + dnsServerDomain = "dns-server-1"; //use this name instead since machine name is not a valid domain name //init dns server - _dnsServer = new DnsServer(serverDomain, _configFolder, Path.Combine(_appFolder, "dohwww"), _log); + _dnsServer = new DnsServer(dnsServerDomain, _configFolder, Path.Combine(_appFolder, "dohwww"), _log); //init dhcp server _dhcpServer = new DhcpServer(Path.Combine(_configFolder, "scopes"), _log); @@ -5122,39 +5826,15 @@ namespace DnsServerCore _dhcpServer.Start(); //start web service - try - { - _webService = new HttpListener(); - _webService.Prefixes.Add("http://+:" + _webServicePort + "/"); - _webService.Start(); - - _webServiceHostname = Environment.MachineName.ToLower(); - } - catch (Exception ex) - { - _log.Write("Web Service failed to bind using default hostname. Attempting to bind again using 'localhost' hostname.\r\n" + ex.ToString()); - - _webService = new HttpListener(); - _webService.Prefixes.Add("http://localhost:" + _webServicePort + "/"); - _webService.Start(); - - _webServiceHostname = "localhost"; - } - - _webService.IgnoreWriteExceptions = true; - - _webServiceThread = new Thread(AcceptWebRequestAsync); - _webServiceThread.Name = "WebService"; - _webServiceThread.IsBackground = true; - _webServiceThread.Start(); + StartWebService(); _state = ServiceState.Running; - _log.Write(new IPEndPoint(IPAddress.Any, _webServicePort), "Web Service (v" + _currentVersion.ToString() + ") was started successfully."); + _log.Write("DNS Server (v" + _currentVersion.ToString() + ") was started successfully."); } catch (Exception ex) { - _log.Write("Failed to start Web Service (v" + _currentVersion.ToString() + ")\r\n" + ex.ToString()); + _log.Write("Failed to start DNS Server (v" + _currentVersion.ToString() + ")\r\n" + ex.ToString()); throw; } } @@ -5168,7 +5848,7 @@ namespace DnsServerCore try { - _webService.Stop(); + StopWebService(); _dnsServer.Stop(); _dhcpServer.Stop(); @@ -5177,11 +5857,11 @@ namespace DnsServerCore _state = ServiceState.Stopped; - _log.Write(new IPEndPoint(IPAddress.Loopback, _webServicePort), "Web Service (v" + _currentVersion.ToString() + ") was stopped successfully."); + _log.Write("DNS Server (v" + _currentVersion.ToString() + ") was stopped successfully."); } catch (Exception ex) { - _log.Write("Failed to stop Web Service (v" + _currentVersion.ToString() + ")\r\n" + ex.ToString()); + _log.Write("Failed to stop DNS Server (v" + _currentVersion.ToString() + ")\r\n" + ex.ToString()); throw; } } @@ -5193,8 +5873,8 @@ namespace DnsServerCore public string ConfigFolder { get { return _configFolder; } } - public int WebServicePort - { get { return _webServicePort; } } + public int WebServiceHttpPort + { get { return _webServiceHttpPort; } } public string WebServiceHostname { get { return _webServiceHostname; } }