diff --git a/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs b/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs index 0dc0bcfb..9862cca2 100644 --- a/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs +++ b/DnsServerCore/Dns/ZoneManagers/BlockListZoneManager.cs @@ -43,6 +43,8 @@ namespace DnsServerCore.Dns.ZoneManagers readonly List _allowListUrls = new List(); readonly List _blockListUrls = new List(); + + IReadOnlyDictionary _allowListZone = new Dictionary(); IReadOnlyDictionary> _blockListZone = new Dictionary>(); DnsSOARecordData _soaRecord; @@ -109,9 +111,10 @@ namespace DnsServerCore.Dns.ZoneManagers return word; } - private Queue ReadListFile(Uri listUrl, bool isAllowList, Dictionary allowedDomains) + private Queue ReadListFile(Uri listUrl, bool isAllowList, out Queue exceptionDomains) { Queue domains = new Queue(); + exceptionDomains = new Queue(); try { @@ -141,7 +144,7 @@ namespace DnsServerCore.Dns.ZoneManagers if (line.Length == 0) continue; //skip empty line - if (line.StartsWith("#") || line.StartsWith("!")) + if (line.StartsWith('#') || line.StartsWith("!")) continue; //skip comment line if (line.StartsWith("||")) @@ -153,38 +156,35 @@ namespace DnsServerCore.Dns.ZoneManagers domain = line.Substring(2, i - 2); options = line.Substring(i + 1); - if (((options.Length == 0) || (options.StartsWith("$") && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain)) - domains.Enqueue(domain); + if (((options.Length == 0) || (options.StartsWith('$') && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain)) + domains.Enqueue(domain.ToLower()); } else { domain = line.Substring(2); if (DnsClient.IsDomainNameValid(domain)) - domains.Enqueue(domain); + domains.Enqueue(domain.ToLower()); } } else if (line.StartsWith("@@||")) { - //adblock format - if (!isAllowList) + //adblock format - exception syntax + i = line.IndexOf('^'); + if (i > -1) { - i = line.IndexOf('^'); - if (i > -1) - { - domain = line.Substring(4, i - 4); - options = line.Substring(i + 1); + domain = line.Substring(4, i - 4); + options = line.Substring(i + 1); - if (((options.Length == 0) || (options.StartsWith("$") && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain)) - allowedDomains.TryAdd(domain, null); - } - else - { - domain = line.Substring(4); + if (((options.Length == 0) || (options.StartsWith('$') && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain)) + exceptionDomains.Enqueue(domain.ToLower()); + } + else + { + domain = line.Substring(4); - if (DnsClient.IsDomainNameValid(domain)) - allowedDomains.TryAdd(domain, null); - } + if (DnsClient.IsDomainNameValid(domain)) + exceptionDomains.Enqueue(domain.ToLower()); } } else @@ -200,7 +200,7 @@ namespace DnsServerCore.Dns.ZoneManagers { secondWord = PopWord(ref line); - if (secondWord.Length == 0) + if ((secondWord.Length == 0) || secondWord.StartsWith('#')) hostname = firstWord; else hostname = secondWord; @@ -267,11 +267,13 @@ namespace DnsServerCore.Dns.ZoneManagers return null; } - private static bool IsZoneAllowed(Dictionary allowedDomains, string domain) + private bool IsZoneAllowed(string domain) { + domain = domain.ToLower(); + do { - if (allowedDomains.TryGetValue(domain, out _)) + if (_allowListZone.TryGetValue(domain, out _)) return true; domain = AuthZoneManager.GetParentZone(domain); @@ -287,37 +289,57 @@ namespace DnsServerCore.Dns.ZoneManagers public void LoadBlockLists() { - //read all allowed domains in dictionary - Dictionary allowedDomains = new Dictionary(); + Dictionary> allowListQueues = new Dictionary>(_allowListUrls.Count); + Dictionary> blockListQueues = new Dictionary>(_blockListUrls.Count); + int totalAllowedDomains = 0; + int totalBlockedDomains = 0; + //read all allow lists in a queue foreach (Uri allowListUrl in _allowListUrls) { - Queue queue = ReadListFile(allowListUrl, true, null); + if (!allowListQueues.ContainsKey(allowListUrl)) + { + Queue allowListQueue = ReadListFile(allowListUrl, true, out Queue blockListQueue); + + totalAllowedDomains += allowListQueue.Count; + allowListQueues.Add(allowListUrl, allowListQueue); + + totalBlockedDomains += blockListQueue.Count; + blockListQueues.Add(allowListUrl, blockListQueue); + } + } + + //read all block lists in a queue + foreach (Uri blockListUrl in _blockListUrls) + { + if (!blockListQueues.ContainsKey(blockListUrl)) + { + Queue blockListQueue = ReadListFile(blockListUrl, false, out Queue allowListQueue); + + totalBlockedDomains += blockListQueue.Count; + blockListQueues.Add(blockListUrl, blockListQueue); + + totalAllowedDomains += allowListQueue.Count; + allowListQueues.Add(blockListUrl, allowListQueue); + } + } + + //load block list zone + Dictionary allowListZone = new Dictionary(totalAllowedDomains); + + foreach (KeyValuePair> allowListQueue in allowListQueues) + { + Queue queue = allowListQueue.Value; while (queue.Count > 0) { string domain = queue.Dequeue(); - allowedDomains.TryAdd(domain, null); + allowListZone.TryAdd(domain, null); } } - //read all block lists in a queue - Dictionary> blockListQueues = new Dictionary>(_blockListUrls.Count); - int totalDomains = 0; - - foreach (Uri blockListUrl in _blockListUrls) - { - if (!blockListQueues.ContainsKey(blockListUrl)) - { - Queue blockListQueue = ReadListFile(blockListUrl, false, allowedDomains); - totalDomains += blockListQueue.Count; - blockListQueues.Add(blockListUrl, blockListQueue); - } - } - - //load block list zone - Dictionary> blockListZone = new Dictionary>(totalDomains); + Dictionary> blockListZone = new Dictionary>(totalBlockedDomains); foreach (KeyValuePair> blockListQueue in blockListQueues) { @@ -327,9 +349,6 @@ namespace DnsServerCore.Dns.ZoneManagers { string domain = queue.Dequeue(); - if (IsZoneAllowed(allowedDomains, domain)) - continue; //domain is in allowed list so skip adding it to block list zone - if (!blockListZone.TryGetValue(domain, out List blockLists)) { blockLists = new List(2); @@ -340,16 +359,16 @@ namespace DnsServerCore.Dns.ZoneManagers } } - //set new blocked zone + //set new allowed and blocked zones + _allowListZone = allowListZone; _blockListZone = blockListZone; - LogManager log = _dnsServer.LogManager; - if (log != null) - log.Write("DNS Server block list zone was loaded successfully."); + _dnsServer.LogManager?.Write("DNS Server block list zone was loaded successfully."); } public void Flush() { + _allowListZone = new Dictionary(); _blockListZone = new Dictionary>(); } @@ -451,8 +470,19 @@ namespace DnsServerCore.Dns.ZoneManagers return downloaded || notModified; } + public bool IsAllowed(DnsDatagram request) + { + if (_allowListZone.Count < 1) + return false; + + return IsZoneAllowed(request.Question[0].Name); + } + public DnsDatagram Query(DnsDatagram request) { + if (_blockListZone.Count < 1) + return null; + DnsQuestionRecord question = request.Question[0]; List blockLists = IsZoneBlocked(question.Name, out string blockedDomain); @@ -572,6 +602,9 @@ namespace DnsServerCore.Dns.ZoneManagers public List BlockListUrls { get { return _blockListUrls; } } + public int TotalZonesAllowed + { get { return _allowListZone.Count; } } + public int TotalZonesBlocked { get { return _blockListZone.Count; } }