diff --git a/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs b/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs index ea9d172c..c36f43ff 100644 --- a/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs +++ b/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs @@ -38,6 +38,7 @@ namespace DnsServerCore.Dns.ZoneManagers readonly DnsServer _dnsServer; readonly string _localCacheFolder; + readonly List _allowListUrls = new List(); readonly List _blockListUrls = new List(); IReadOnlyDictionary> _blockListZone = new Dictionary>(); @@ -105,7 +106,7 @@ namespace DnsServerCore.Dns.ZoneManagers return word; } - private Queue ReadBlockListFile(Uri blockListUrl) + private Queue ReadListFile(Uri listUrl, bool isAllow) { Queue domains = new Queue(); @@ -113,9 +114,9 @@ namespace DnsServerCore.Dns.ZoneManagers { LogManager log = _dnsServer.LogManager; if (log != null) - log.Write("DNS Server is reading block list from: " + blockListUrl.AbsoluteUri); + log.Write("DNS Server is reading " + (isAllow ? "allow" : "block") + " list from: " + listUrl.AbsoluteUri); - using (FileStream fS = new FileStream(GetBlockListFilePath(blockListUrl), FileMode.Open, FileAccess.Read)) + using (FileStream fS = new FileStream(GetBlockListFilePath(listUrl), FileMode.Open, FileAccess.Read)) { //parse hosts file and populate block zone StreamReader sR = new StreamReader(fS, true); @@ -185,13 +186,13 @@ namespace DnsServerCore.Dns.ZoneManagers } if (log != null) - log.Write("DNS Server block list file was read (" + domains.Count + " domains) from: " + blockListUrl.AbsoluteUri); + log.Write("DNS Server " + (isAllow ? "allow" : "block") + " list file was read (" + domains.Count + " domains) from: " + listUrl.AbsoluteUri); } catch (Exception ex) { LogManager log = _dnsServer.LogManager; if (log != null) - log.Write("DNS Server failed to read block list from: " + blockListUrl.AbsoluteUri + "\r\n" + ex.ToString()); + log.Write("DNS Server failed to read " + (isAllow ? "allow" : "block") + " list from: " + listUrl.AbsoluteUri + "\r\n" + ex.ToString()); } return domains; @@ -225,6 +226,21 @@ namespace DnsServerCore.Dns.ZoneManagers public void LoadBlockLists() { + //read all allowed domains in dictionary + Dictionary allowedDomains = new Dictionary(); + + foreach (Uri allowListUri in _allowListUrls) + { + Queue queue = ReadListFile(allowListUri, true); + + while (queue.Count > 0) + { + string domain = queue.Dequeue(); + + allowedDomains.TryAdd(domain, null); + } + } + //read all block lists in a queue Dictionary> blockListQueues = new Dictionary>(_blockListUrls.Count); int totalDomains = 0; @@ -233,7 +249,7 @@ namespace DnsServerCore.Dns.ZoneManagers { if (!blockListQueues.ContainsKey(blockListUrl)) { - Queue blockListQueue = ReadBlockListFile(blockListUrl); + Queue blockListQueue = ReadListFile(blockListUrl, false); totalDomains += blockListQueue.Count; blockListQueues.Add(blockListUrl, blockListQueue); } @@ -250,6 +266,9 @@ namespace DnsServerCore.Dns.ZoneManagers { string domain = queue.Dequeue(); + if (allowedDomains.TryGetValue(domain, out _)) + continue; //domain is in allowed list so skip adding it to block list zone + if (!blockListZone.TryGetValue(domain, out List blockLists)) { blockLists = new List(2); @@ -278,15 +297,15 @@ namespace DnsServerCore.Dns.ZoneManagers bool downloaded = false; bool notmodified = false; - foreach (Uri blockListUrl in _blockListUrls) + async Task DownloadListUrlAsync(Uri listUrl, bool isAllowList) { - string blockListFilePath = GetBlockListFilePath(blockListUrl); - string blockListDownloadFilePath = blockListFilePath + ".downloading"; + string listFilePath = GetBlockListFilePath(listUrl); + string listDownloadFilePath = listFilePath + ".downloading"; try { - if (File.Exists(blockListDownloadFilePath)) - File.Delete(blockListDownloadFilePath); + if (File.Exists(listDownloadFilePath)) + File.Delete(listDownloadFilePath); HttpClientHandler handler = new HttpClientHandler(); handler.Proxy = _dnsServer.Proxy; @@ -294,33 +313,33 @@ namespace DnsServerCore.Dns.ZoneManagers using (HttpClient http = new HttpClient(handler)) { - if (File.Exists(blockListFilePath)) - http.DefaultRequestHeaders.IfModifiedSince = File.GetLastWriteTimeUtc(blockListFilePath); + if (File.Exists(listFilePath)) + http.DefaultRequestHeaders.IfModifiedSince = File.GetLastWriteTimeUtc(listFilePath); - HttpResponseMessage httpResponse = await http.GetAsync(blockListUrl); + HttpResponseMessage httpResponse = await http.GetAsync(listUrl); switch (httpResponse.StatusCode) { case HttpStatusCode.OK: { - using (FileStream fS = new FileStream(blockListDownloadFilePath, FileMode.Create, FileAccess.Write)) + using (FileStream fS = new FileStream(listDownloadFilePath, FileMode.Create, FileAccess.Write)) { Stream httpStream = await httpResponse.Content.ReadAsStreamAsync(); await httpStream.CopyToAsync(fS); } - if (File.Exists(blockListFilePath)) - File.Delete(blockListFilePath); + if (File.Exists(listFilePath)) + File.Delete(listFilePath); - File.Move(blockListDownloadFilePath, blockListFilePath); + File.Move(listDownloadFilePath, listFilePath); if (httpResponse.Content.Headers.LastModified != null) - File.SetLastWriteTimeUtc(blockListFilePath, httpResponse.Content.Headers.LastModified.Value.UtcDateTime); + File.SetLastWriteTimeUtc(listFilePath, httpResponse.Content.Headers.LastModified.Value.UtcDateTime); downloaded = true; LogManager log = _dnsServer.LogManager; if (log != null) - log.Write("DNS Server successfully downloaded block list (" + WebUtilities.GetFormattedSize(new FileInfo(blockListFilePath).Length) + "): " + blockListUrl.AbsoluteUri); + log.Write("DNS Server successfully downloaded " + (isAllowList ? "allow" : "block") + " list (" + WebUtilities.GetFormattedSize(new FileInfo(listFilePath).Length) + "): " + listUrl.AbsoluteUri); } break; @@ -330,7 +349,7 @@ namespace DnsServerCore.Dns.ZoneManagers LogManager log = _dnsServer.LogManager; if (log != null) - log.Write("DNS Server successfully checked for a new update of the block list: " + blockListUrl.AbsoluteUri); + log.Write("DNS Server successfully checked for a new update of the " + (isAllowList ? "allow" : "block") + " list: " + listUrl.AbsoluteUri); } break; @@ -343,10 +362,20 @@ namespace DnsServerCore.Dns.ZoneManagers { LogManager log = _dnsServer.LogManager; if (log != null) - log.Write("DNS Server failed to download block list and will use previously downloaded file (if available): " + blockListUrl.AbsoluteUri + "\r\n" + ex.ToString()); + log.Write("DNS Server failed to download " + (isAllowList ? "allow" : "block") + " list and will use previously downloaded file (if available): " + listUrl.AbsoluteUri + "\r\n" + ex.ToString()); } } + List tasks = new List(); + + foreach (Uri allowListUrl in _allowListUrls) + tasks.Add(DownloadListUrlAsync(allowListUrl, true)); + + foreach (Uri blockListUrl in _blockListUrls) + tasks.Add(DownloadListUrlAsync(blockListUrl, false)); + + await Task.WhenAll(tasks); + if (downloaded) { LoadBlockLists(); @@ -412,6 +441,9 @@ namespace DnsServerCore.Dns.ZoneManagers set { UpdateServerDomain(value); } } + public List AllowListUrls + { get { return _allowListUrls; } } + public List BlockListUrls { get { return _blockListUrls; } }