diff --git a/Apps/AdvancedBlockingApp/App.cs b/Apps/AdvancedBlockingApp/App.cs index f49ef603..2015d9bc 100644 --- a/Apps/AdvancedBlockingApp/App.cs +++ b/Apps/AdvancedBlockingApp/App.cs @@ -39,7 +39,7 @@ using TechnitiumLibrary.Net.Http.Client; namespace AdvancedBlocking { - public sealed class App : IDnsApplication, IDnsAuthoritativeRequestHandler + public sealed class App : IDnsApplication, IDnsRequestBlockingHandler { #region variables @@ -51,6 +51,7 @@ namespace AdvancedBlocking bool _enableBlocking; int _blockListUrlUpdateIntervalHours; + IReadOnlyDictionary _localEndPointGroupMap; IReadOnlyDictionary _networkGroupMap; IReadOnlyDictionary _groups; @@ -242,11 +243,45 @@ namespace AdvancedBlocking return false; } + private string GetGroupName(DnsDatagram request, IPEndPoint remoteEP) + { + if ((request.Metadata is not null) && (request.Metadata.NameServer is not null)) + { + IPEndPoint requestLocalEP = request.Metadata.NameServer.IPEndPoint; + if (requestLocalEP is not null) + { + foreach (KeyValuePair entry in _localEndPointGroupMap) + { + if ((entry.Key.Port == 0) && entry.Key.Address.Equals(requestLocalEP.Address)) + return entry.Value; + + if (entry.Key.Equals(requestLocalEP)) + return entry.Value; + } + } + } + + string groupName = null; + IPAddress remoteIP = remoteEP.Address; + NetworkAddress network = null; + + foreach (KeyValuePair entry in _networkGroupMap) + { + if (entry.Key.Contains(remoteIP) && ((network is null) || (entry.Key.PrefixLength > network.PrefixLength))) + { + network = entry.Key; + groupName = entry.Value; + } + } + + return groupName; + } + #endregion #region public - public Task InitializeAsync(IDnsServer dnsServer, string config) + public async Task InitializeAsync(IDnsServer dnsServer, string config) { _dnsServer = dnsServer; @@ -261,6 +296,19 @@ namespace AdvancedBlocking _enableBlocking = jsonConfig.GetProperty("enableBlocking").GetBoolean(); _blockListUrlUpdateIntervalHours = jsonConfig.GetProperty("blockListUrlUpdateIntervalHours").GetInt32(); + if (jsonConfig.TryReadObjectAsMap("localEndPointGroupMap", + delegate (string localEP, JsonElement jsonGroup) + { + if (!IPEndPoint.TryParse(localEP, out IPEndPoint ep)) + throw new InvalidOperationException("Local end point group map contains an invalid end point: " + localEP); + + return new Tuple(ep, jsonGroup.GetString()); + }, + out Dictionary localEndPointGroupMap)) + { + _localEndPointGroupMap = localEndPointGroupMap; + } + _networkGroupMap = jsonConfig.ReadObjectAsMap("networkGroupMap", delegate (string network, JsonElement jsonGroup) { if (!NetworkAddress.TryParse(network, out NetworkAddress networkAddress)) @@ -355,7 +403,7 @@ namespace AdvancedBlocking _dnsServer.WriteLog("Advanced Blocking app loaded all zones successfully for group: " + group.Key); } - Task.Run(async delegate () + await Task.Run(async delegate () { List loadTasks = new List(); @@ -417,52 +465,41 @@ namespace AdvancedBlocking } }); - return Task.CompletedTask; + if (!jsonConfig.TryGetProperty("localEndPointGroupMap", out _)) + { + config = config.Replace("\"networkGroupMap\"", "\"localEndPointGroupMap\": {\r\n },\r\n \"networkGroupMap\""); + + await File.WriteAllTextAsync(Path.Combine(dnsServer.ApplicationFolder, "dnsApp.config"), config); + } } - public async Task ProcessRequestAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, bool isRecursionAllowed) + public Task IsAllowedAsync(DnsDatagram request, IPEndPoint remoteEP) { if (!_enableBlocking) - return null; - - IPAddress remoteIP = remoteEP.Address; - NetworkAddress network = null; - string groupName = null; - - foreach (KeyValuePair entry in _networkGroupMap) - { - if (entry.Key.Contains(remoteIP) && ((network is null) || (entry.Key.PrefixLength > network.PrefixLength))) - { - network = entry.Key; - groupName = entry.Value; - } - } + return Task.FromResult(false); + string groupName = GetGroupName(request, remoteEP); if ((groupName is null) || !_groups.TryGetValue(groupName, out Group group) || !group.EnableBlocking) - return null; + return Task.FromResult(false); DnsQuestionRecord question = request.Question[0]; - if (!group.IsZoneBlocked(question.Name, out bool allowed, out string blockedDomain, out string blockedRegex, out Uri blockListUrl)) - { - if (allowed) - { - try - { - DnsDatagram internalResponse = await _dnsServer.DirectQueryAsync(request); - if (internalResponse.Tag is null) - internalResponse.Tag = DnsServerResponseType.Recursive; + return Task.FromResult(group.IsZoneAllowed(question.Name)); + } - return internalResponse; - } - catch (Exception ex) - { - _dnsServer.WriteLog("Failed to resolve the request for allowed domain name with QNAME: " + question.Name + "; QTYPE: " + question.Type + "; QCLASS: " + question.Class + "\r\n" + ex.ToString()); - } - } + public Task ProcessRequestAsync(DnsDatagram request, IPEndPoint remoteEP) + { + if (!_enableBlocking) + return Task.FromResult(null); - return null; - } + string groupName = GetGroupName(request, remoteEP); + if ((groupName is null) || !_groups.TryGetValue(groupName, out Group group) || !group.EnableBlocking) + return Task.FromResult(null); + + DnsQuestionRecord question = request.Question[0]; + + if (!group.IsZoneBlocked(question.Name, out string blockedDomain, out string blockedRegex, out Uri blockListUrl)) + return Task.FromResult(null); string GetBlockingReport() { @@ -493,7 +530,7 @@ namespace AdvancedBlocking DnsResourceRecord[] answer = new DnsResourceRecord[] { new DnsResourceRecord(question.Name, DnsResourceRecordType.TXT, question.Class, 60, new DnsTXTRecordData(blockingReport)) }; - return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer) { Tag = DnsServerResponseType.Blocked }; + return Task.FromResult(new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.NoError, request.Question, answer)); } else { @@ -578,7 +615,7 @@ namespace AdvancedBlocking } } - return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, rcode, request.Question, answer, authority, null, request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize, EDnsHeaderFlags.None, options) { Tag = DnsServerResponseType.Blocked }; + return Task.FromResult(new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, rcode, request.Question, answer, authority, null, request.EDNS is null ? ushort.MinValue : _dnsServer.UdpPayloadSize, EDnsHeaderFlags.None, options)); } } @@ -765,26 +802,22 @@ namespace AdvancedBlocking } } - public bool IsZoneBlocked(string domain, out bool allowed, out string blockedDomain, out string blockedRegex, out Uri listUrl) + public bool IsZoneAllowed(string domain) { domain = domain.ToLower(); //allowed, allow list zone, allowedRegex, regex allow list zone, adblock list zone - if (IsZoneFound(_allowed, domain, out _) || IsZoneFound(_allowListZones, domain, out _, out _) || IsMatchFound(_allowedRegex, domain, out _) || IsMatchFound(_regexAllowListZones, domain, out _, out _) || IsZoneAllowed(_adBlockListZones, domain, out _, out _)) - { - //found zone allowed - allowed = true; - blockedDomain = null; - blockedRegex = null; - listUrl = null; - return false; - } + return IsZoneFound(_allowed, domain, out _) || IsZoneFound(_allowListZones, domain, out _, out _) || IsMatchFound(_allowedRegex, domain, out _) || IsMatchFound(_regexAllowListZones, domain, out _, out _) || App.IsZoneAllowed(_adBlockListZones, domain, out _, out _); + } + + public bool IsZoneBlocked(string domain, out string blockedDomain, out string blockedRegex, out Uri listUrl) + { + domain = domain.ToLower(); //blocked if (IsZoneFound(_blocked, domain, out string foundZone1)) { //found zone blocked - allowed = false; blockedDomain = foundZone1; blockedRegex = null; listUrl = null; @@ -795,7 +828,6 @@ namespace AdvancedBlocking if (IsZoneFound(_blockListZones, domain, out string foundZone2, out Uri blockListUrl1)) { //found zone blocked - allowed = false; blockedDomain = foundZone2; blockedRegex = null; listUrl = blockListUrl1; @@ -806,7 +838,6 @@ namespace AdvancedBlocking if (IsMatchFound(_blockedRegex, domain, out string blockedPattern1)) { //found pattern blocked - allowed = false; blockedDomain = null; blockedRegex = blockedPattern1; listUrl = null; @@ -817,7 +848,6 @@ namespace AdvancedBlocking if (IsMatchFound(_regexBlockListZones, domain, out string blockedPattern2, out Uri blockListUrl2)) { //found pattern blocked - allowed = false; blockedDomain = null; blockedRegex = blockedPattern2; listUrl = blockListUrl2; @@ -828,14 +858,12 @@ namespace AdvancedBlocking if (App.IsZoneBlocked(_adBlockListZones, domain, out string foundZone3, out Uri blockListUrl3)) { //found zone blocked - allowed = false; blockedDomain = foundZone3; blockedRegex = null; listUrl = blockListUrl3; return true; } - allowed = false; blockedDomain = null; blockedRegex = null; listUrl = null;