BlockListZoneManager: updated implementation of allowed list by having a separate allowed zone. Updated parser to detect inline comment. Added IsAllowed() method. Updated Query() to check count before query.

This commit is contained in:
Shreyas Zare
2023-02-25 13:06:04 +05:30
parent d62483bd51
commit a53d5df19c

View File

@@ -43,6 +43,8 @@ namespace DnsServerCore.Dns.ZoneManagers
readonly List<Uri> _allowListUrls = new List<Uri>();
readonly List<Uri> _blockListUrls = new List<Uri>();
IReadOnlyDictionary<string, object> _allowListZone = new Dictionary<string, object>();
IReadOnlyDictionary<string, List<Uri>> _blockListZone = new Dictionary<string, List<Uri>>();
DnsSOARecordData _soaRecord;
@@ -109,9 +111,10 @@ namespace DnsServerCore.Dns.ZoneManagers
return word;
}
private Queue<string> ReadListFile(Uri listUrl, bool isAllowList, Dictionary<string, object> allowedDomains)
private Queue<string> ReadListFile(Uri listUrl, bool isAllowList, out Queue<string> exceptionDomains)
{
Queue<string> domains = new Queue<string>();
exceptionDomains = new Queue<string>();
try
{
@@ -141,7 +144,7 @@ namespace DnsServerCore.Dns.ZoneManagers
if (line.Length == 0)
continue; //skip empty line
if (line.StartsWith("#") || line.StartsWith("!"))
if (line.StartsWith('#') || line.StartsWith("!"))
continue; //skip comment line
if (line.StartsWith("||"))
@@ -153,38 +156,35 @@ namespace DnsServerCore.Dns.ZoneManagers
domain = line.Substring(2, i - 2);
options = line.Substring(i + 1);
if (((options.Length == 0) || (options.StartsWith("$") && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain))
domains.Enqueue(domain);
if (((options.Length == 0) || (options.StartsWith('$') && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain))
domains.Enqueue(domain.ToLower());
}
else
{
domain = line.Substring(2);
if (DnsClient.IsDomainNameValid(domain))
domains.Enqueue(domain);
domains.Enqueue(domain.ToLower());
}
}
else if (line.StartsWith("@@||"))
{
//adblock format
if (!isAllowList)
//adblock format - exception syntax
i = line.IndexOf('^');
if (i > -1)
{
i = line.IndexOf('^');
if (i > -1)
{
domain = line.Substring(4, i - 4);
options = line.Substring(i + 1);
domain = line.Substring(4, i - 4);
options = line.Substring(i + 1);
if (((options.Length == 0) || (options.StartsWith("$") && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain))
allowedDomains.TryAdd(domain, null);
}
else
{
domain = line.Substring(4);
if (((options.Length == 0) || (options.StartsWith('$') && (options.Contains("doc") || options.Contains("all")))) && DnsClient.IsDomainNameValid(domain))
exceptionDomains.Enqueue(domain.ToLower());
}
else
{
domain = line.Substring(4);
if (DnsClient.IsDomainNameValid(domain))
allowedDomains.TryAdd(domain, null);
}
if (DnsClient.IsDomainNameValid(domain))
exceptionDomains.Enqueue(domain.ToLower());
}
}
else
@@ -200,7 +200,7 @@ namespace DnsServerCore.Dns.ZoneManagers
{
secondWord = PopWord(ref line);
if (secondWord.Length == 0)
if ((secondWord.Length == 0) || secondWord.StartsWith('#'))
hostname = firstWord;
else
hostname = secondWord;
@@ -267,11 +267,13 @@ namespace DnsServerCore.Dns.ZoneManagers
return null;
}
private static bool IsZoneAllowed(Dictionary<string, object> allowedDomains, string domain)
private bool IsZoneAllowed(string domain)
{
domain = domain.ToLower();
do
{
if (allowedDomains.TryGetValue(domain, out _))
if (_allowListZone.TryGetValue(domain, out _))
return true;
domain = AuthZoneManager.GetParentZone(domain);
@@ -287,37 +289,57 @@ namespace DnsServerCore.Dns.ZoneManagers
public void LoadBlockLists()
{
//read all allowed domains in dictionary
Dictionary<string, object> allowedDomains = new Dictionary<string, object>();
Dictionary<Uri, Queue<string>> allowListQueues = new Dictionary<Uri, Queue<string>>(_allowListUrls.Count);
Dictionary<Uri, Queue<string>> blockListQueues = new Dictionary<Uri, Queue<string>>(_blockListUrls.Count);
int totalAllowedDomains = 0;
int totalBlockedDomains = 0;
//read all allow lists in a queue
foreach (Uri allowListUrl in _allowListUrls)
{
Queue<string> queue = ReadListFile(allowListUrl, true, null);
if (!allowListQueues.ContainsKey(allowListUrl))
{
Queue<string> allowListQueue = ReadListFile(allowListUrl, true, out Queue<string> blockListQueue);
totalAllowedDomains += allowListQueue.Count;
allowListQueues.Add(allowListUrl, allowListQueue);
totalBlockedDomains += blockListQueue.Count;
blockListQueues.Add(allowListUrl, blockListQueue);
}
}
//read all block lists in a queue
foreach (Uri blockListUrl in _blockListUrls)
{
if (!blockListQueues.ContainsKey(blockListUrl))
{
Queue<string> blockListQueue = ReadListFile(blockListUrl, false, out Queue<string> allowListQueue);
totalBlockedDomains += blockListQueue.Count;
blockListQueues.Add(blockListUrl, blockListQueue);
totalAllowedDomains += allowListQueue.Count;
allowListQueues.Add(blockListUrl, allowListQueue);
}
}
//load block list zone
Dictionary<string, object> allowListZone = new Dictionary<string, object>(totalAllowedDomains);
foreach (KeyValuePair<Uri, Queue<string>> allowListQueue in allowListQueues)
{
Queue<string> queue = allowListQueue.Value;
while (queue.Count > 0)
{
string domain = queue.Dequeue();
allowedDomains.TryAdd(domain, null);
allowListZone.TryAdd(domain, null);
}
}
//read all block lists in a queue
Dictionary<Uri, Queue<string>> blockListQueues = new Dictionary<Uri, Queue<string>>(_blockListUrls.Count);
int totalDomains = 0;
foreach (Uri blockListUrl in _blockListUrls)
{
if (!blockListQueues.ContainsKey(blockListUrl))
{
Queue<string> blockListQueue = ReadListFile(blockListUrl, false, allowedDomains);
totalDomains += blockListQueue.Count;
blockListQueues.Add(blockListUrl, blockListQueue);
}
}
//load block list zone
Dictionary<string, List<Uri>> blockListZone = new Dictionary<string, List<Uri>>(totalDomains);
Dictionary<string, List<Uri>> blockListZone = new Dictionary<string, List<Uri>>(totalBlockedDomains);
foreach (KeyValuePair<Uri, Queue<string>> blockListQueue in blockListQueues)
{
@@ -327,9 +349,6 @@ namespace DnsServerCore.Dns.ZoneManagers
{
string domain = queue.Dequeue();
if (IsZoneAllowed(allowedDomains, domain))
continue; //domain is in allowed list so skip adding it to block list zone
if (!blockListZone.TryGetValue(domain, out List<Uri> blockLists))
{
blockLists = new List<Uri>(2);
@@ -340,16 +359,16 @@ namespace DnsServerCore.Dns.ZoneManagers
}
}
//set new blocked zone
//set new allowed and blocked zones
_allowListZone = allowListZone;
_blockListZone = blockListZone;
LogManager log = _dnsServer.LogManager;
if (log != null)
log.Write("DNS Server block list zone was loaded successfully.");
_dnsServer.LogManager?.Write("DNS Server block list zone was loaded successfully.");
}
public void Flush()
{
_allowListZone = new Dictionary<string, object>();
_blockListZone = new Dictionary<string, List<Uri>>();
}
@@ -451,8 +470,19 @@ namespace DnsServerCore.Dns.ZoneManagers
return downloaded || notModified;
}
public bool IsAllowed(DnsDatagram request)
{
if (_allowListZone.Count < 1)
return false;
return IsZoneAllowed(request.Question[0].Name);
}
public DnsDatagram Query(DnsDatagram request)
{
if (_blockListZone.Count < 1)
return null;
DnsQuestionRecord question = request.Question[0];
List<Uri> blockLists = IsZoneBlocked(question.Name, out string blockedDomain);
@@ -572,6 +602,9 @@ namespace DnsServerCore.Dns.ZoneManagers
public List<Uri> BlockListUrls
{ get { return _blockListUrls; } }
public int TotalZonesAllowed
{ get { return _allowListZone.Count; } }
public int TotalZonesBlocked
{ get { return _blockListZone.Count; } }