AdvancedBlocking: updated implementation to use new IDnsRequestBlockingHandler interface. Added feature to select group based on DNS server local end point.

This commit is contained in:
Shreyas Zare
2023-10-29 20:27:20 +05:30
parent 4f6ca51638
commit 2d59fe7310

View File

@@ -39,7 +39,7 @@ using TechnitiumLibrary.Net.Http.Client;
namespace AdvancedBlocking
{
public sealed class App : IDnsApplication, IDnsAuthoritativeRequestHandler
public sealed class App : IDnsApplication, IDnsRequestBlockingHandler
{
#region variables
@@ -51,6 +51,7 @@ namespace AdvancedBlocking
bool _enableBlocking;
int _blockListUrlUpdateIntervalHours;
IReadOnlyDictionary<IPEndPoint, string> _localEndPointGroupMap;
IReadOnlyDictionary<NetworkAddress, string> _networkGroupMap;
IReadOnlyDictionary<string, Group> _groups;
@@ -242,11 +243,45 @@ namespace AdvancedBlocking
return false;
}
private string GetGroupName(DnsDatagram request, IPEndPoint remoteEP)
{
if ((request.Metadata is not null) && (request.Metadata.NameServer is not null))
{
IPEndPoint requestLocalEP = request.Metadata.NameServer.IPEndPoint;
if (requestLocalEP is not null)
{
foreach (KeyValuePair<IPEndPoint, string> entry in _localEndPointGroupMap)
{
if ((entry.Key.Port == 0) && entry.Key.Address.Equals(requestLocalEP.Address))
return entry.Value;
if (entry.Key.Equals(requestLocalEP))
return entry.Value;
}
}
}
string groupName = null;
IPAddress remoteIP = remoteEP.Address;
NetworkAddress network = null;
foreach (KeyValuePair<NetworkAddress, string> entry in _networkGroupMap)
{
if (entry.Key.Contains(remoteIP) && ((network is null) || (entry.Key.PrefixLength > network.PrefixLength)))
{
network = entry.Key;
groupName = entry.Value;
}
}
return groupName;
}
#endregion
#region public
public Task InitializeAsync(IDnsServer dnsServer, string config)
public async Task InitializeAsync(IDnsServer dnsServer, string config)
{
_dnsServer = dnsServer;
@@ -261,6 +296,19 @@ namespace AdvancedBlocking
_enableBlocking = jsonConfig.GetProperty("enableBlocking").GetBoolean();
_blockListUrlUpdateIntervalHours = jsonConfig.GetProperty("blockListUrlUpdateIntervalHours").GetInt32();
if (jsonConfig.TryReadObjectAsMap("localEndPointGroupMap",
delegate (string localEP, JsonElement jsonGroup)
{
if (!IPEndPoint.TryParse(localEP, out IPEndPoint ep))
throw new InvalidOperationException("Local end point group map contains an invalid end point: " + localEP);
return new Tuple<IPEndPoint, string>(ep, jsonGroup.GetString());
},
out Dictionary<IPEndPoint, string> localEndPointGroupMap))
{
_localEndPointGroupMap = localEndPointGroupMap;
}
_networkGroupMap = jsonConfig.ReadObjectAsMap("networkGroupMap", delegate (string network, JsonElement jsonGroup)
{
if (!NetworkAddress.TryParse(network, out NetworkAddress networkAddress))
@@ -355,7 +403,7 @@ namespace AdvancedBlocking
_dnsServer.WriteLog("Advanced Blocking app loaded all zones successfully for group: " + group.Key);
}
Task.Run(async delegate ()
await Task.Run(async delegate ()
{
List<Task> loadTasks = new List<Task>();
@@ -417,52 +465,41 @@ namespace AdvancedBlocking
}
});
return Task.CompletedTask;
if (!jsonConfig.TryGetProperty("localEndPointGroupMap", out _))
{
config = config.Replace("\"networkGroupMap\"", "\"localEndPointGroupMap\": {\r\n },\r\n \"networkGroupMap\"");
await File.WriteAllTextAsync(Path.Combine(dnsServer.ApplicationFolder, "dnsApp.config"), config);
}
}
public async Task<DnsDatagram> ProcessRequestAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, bool isRecursionAllowed)
public Task<bool> IsAllowedAsync(DnsDatagram request, IPEndPoint remoteEP)
{
if (!_enableBlocking)
return null;
IPAddress remoteIP = remoteEP.Address;
NetworkAddress network = null;
string groupName = null;
foreach (KeyValuePair<NetworkAddress, string> entry in _networkGroupMap)
{
if (entry.Key.Contains(remoteIP) && ((network is null) || (entry.Key.PrefixLength > network.PrefixLength)))
{
network = entry.Key;
groupName = entry.Value;
}
}
return Task.FromResult(false);
string groupName = GetGroupName(request, remoteEP);
if ((groupName is null) || !_groups.TryGetValue(groupName, out Group group) || !group.EnableBlocking)
return null;
return Task.FromResult(false);
DnsQuestionRecord question = request.Question[0];
if (!group.IsZoneBlocked(question.Name, out bool allowed, out string blockedDomain, out string blockedRegex, out Uri blockListUrl))
{
if (allowed)
{
try
{
DnsDatagram internalResponse = await _dnsServer.DirectQueryAsync(request);
if (internalResponse.Tag is null)
internalResponse.Tag = DnsServerResponseType.Recursive;
return Task.FromResult(group.IsZoneAllowed(question.Name));
}
return internalResponse;
}
catch (Exception ex)
{
_dnsServer.WriteLog("Failed to resolve the request for allowed domain name with QNAME: " + question.Name + "; QTYPE: " + question.Type + "; QCLASS: " + question.Class + "\r\n" + ex.ToString());
}
}
public Task<DnsDatagram> ProcessRequestAsync(DnsDatagram request, IPEndPoint remoteEP)
{
if (!_enableBlocking)
return Task.FromResult<DnsDatagram>(null);
return null;
}
string groupName = GetGroupName(request, remoteEP);
if ((groupName is null) || !_groups.TryGetValue(groupName, out Group group) || !group.EnableBlocking)
return Task.FromResult<DnsDatagram>(null);
DnsQuestionRecord question = request.Question[0];
if (!group.IsZoneBlocked(question.Name, out string blockedDomain, out string blockedRegex, out Uri blockListUrl))
return Task.FromResult<DnsDatagram>(null);
string GetBlockingReport()
{
@@ -493,7 +530,7 @@ namespace AdvancedBlocking
DnsResourceRecord[] answer = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, 60, new DnsTXTRecordData(blockingReport)) };
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer) { Tag = DnsServerResponseType.Blocked };
return Task.FromResult(new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer));
}
else
{
@@ -578,7 +615,7 @@ namespace AdvancedBlocking
}
}
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, rcode, request.Question, answer, authority, null, request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize, EDnsHeaderFlags.None, options) { Tag = DnsServerResponseType.Blocked };
return Task.FromResult(new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, rcode, request.Question, answer, authority, null, request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize, EDnsHeaderFlags.None, options));
}
}
@@ -765,26 +802,22 @@ namespace AdvancedBlocking
}
}
public bool IsZoneBlocked(string domain, out bool allowed, out string blockedDomain, out string blockedRegex, out Uri listUrl)
public bool IsZoneAllowed(string domain)
{
domain = domain.ToLower();
//allowed, allow list zone, allowedRegex, regex allow list zone, adblock list zone
if (IsZoneFound(_allowed, domain, out _) || IsZoneFound(_allowListZones, domain, out _, out _) || IsMatchFound(_allowedRegex, domain, out _) || IsMatchFound(_regexAllowListZones, domain, out _, out _) || IsZoneAllowed(_adBlockListZones, domain, out _, out _))
{
//found zone allowed
allowed = true;
blockedDomain = null;
blockedRegex = null;
listUrl = null;
return false;
}
return IsZoneFound(_allowed, domain, out _) || IsZoneFound(_allowListZones, domain, out _, out _) || IsMatchFound(_allowedRegex, domain, out _) || IsMatchFound(_regexAllowListZones, domain, out _, out _) || App.IsZoneAllowed(_adBlockListZones, domain, out _, out _);
}
public bool IsZoneBlocked(string domain, out string blockedDomain, out string blockedRegex, out Uri listUrl)
{
domain = domain.ToLower();
//blocked
if (IsZoneFound(_blocked, domain, out string foundZone1))
{
//found zone blocked
allowed = false;
blockedDomain = foundZone1;
blockedRegex = null;
listUrl = null;
@@ -795,7 +828,6 @@ namespace AdvancedBlocking
if (IsZoneFound(_blockListZones, domain, out string foundZone2, out Uri blockListUrl1))
{
//found zone blocked
allowed = false;
blockedDomain = foundZone2;
blockedRegex = null;
listUrl = blockListUrl1;
@@ -806,7 +838,6 @@ namespace AdvancedBlocking
if (IsMatchFound(_blockedRegex, domain, out string blockedPattern1))
{
//found pattern blocked
allowed = false;
blockedDomain = null;
blockedRegex = blockedPattern1;
listUrl = null;
@@ -817,7 +848,6 @@ namespace AdvancedBlocking
if (IsMatchFound(_regexBlockListZones, domain, out string blockedPattern2, out Uri blockListUrl2))
{
//found pattern blocked
allowed = false;
blockedDomain = null;
blockedRegex = blockedPattern2;
listUrl = blockListUrl2;
@@ -828,14 +858,12 @@ namespace AdvancedBlocking
if (App.IsZoneBlocked(_adBlockListZones, domain, out string foundZone3, out Uri blockListUrl3))
{
//found zone blocked
allowed = false;
blockedDomain = foundZone3;
blockedRegex = null;
listUrl = blockListUrl3;
return true;
}
allowed = false;
blockedDomain = null;
blockedRegex = null;
listUrl = null;