Files
DnsServer/Apps/DnsRebindBlockingApp/App.cs
2023-12-10 10:07:20 +08:00

107 lines
3.8 KiB
C#

using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Text.Json;
using System.Threading.Tasks;
using DnsServerCore.ApplicationCommon;
using TechnitiumLibrary.Net;
using TechnitiumLibrary.Net.Dns;
using TechnitiumLibrary.Net.Dns.ResourceRecords;
namespace DnsRebindBlocking
{
/// <summary>
/// DNS post-processor that removes A/AAAA records pointing into configured
/// private networks from non-authoritative responses, unless the record's
/// name falls under one of the configured private domains.
/// </summary>
public class App: IDnsApplication, IDnsPostProcessor
{
    // Cached so the options object is not re-created on every initialization.
    private static readonly JsonSerializerOptions SerializerOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase
    };

    private AppConfig Config = null!;
    private HashSet<NetworkAddress> PrivateNetworks = new();
    private IDnsServer DnsServer = null!;

    /// <summary>No resources are held, so there is nothing to release.</summary>
    public void Dispose()
    {
        // Nothing to dispose of.
    }

    /// <summary>
    /// Parses the JSON app config and builds the private-network list.
    /// </summary>
    /// <param name="dnsServer">Host server; used for logging and for its own domain name.</param>
    /// <param name="config">JSON document deserialized into <c>AppConfig</c> (camelCase property names).</param>
    /// <returns>An already-completed task.</returns>
    /// <exception cref="JsonException">The config is malformed or is the JSON literal <c>null</c>.</exception>
    public Task InitializeAsync(IDnsServer dnsServer, string config)
    {
        DnsServer = dnsServer;

        // Deserialize returns null for a literal "null" document; fail with a clear
        // message instead of a NullReferenceException further down.
        var parsedConfig = JsonSerializer.Deserialize<AppConfig>(config, SerializerOptions)
            ?? throw new JsonException("The app config must be a JSON object.");

        dnsServer.WriteLog($"Initializing. Enabled: {parsedConfig.Enabled}");

        // Build the new state in locals and publish it at the end, so a response being
        // post-processed during a re-initialization never enumerates a set that is
        // being cleared/refilled.
        var networks = new HashSet<NetworkAddress>();
        foreach (var privateNetwork in parsedConfig.PrivateNetworks)
        {
            if (!NetworkAddress.TryParse(privateNetwork, out NetworkAddress networkAddress))
            {
                // Skip bad entries instead of storing the failed out value, which would
                // make every later Contains() check throw.
                dnsServer.WriteLog($"Ignoring invalid private network entry: {privateNetwork}");
                continue;
            }

            networks.Add(networkAddress);
        }

        // Add the ServerDomain to the PrivateDomains list so the server does not block its own name.
        // Lowercased because IsZoneFound lowercases names before looking them up.
        parsedConfig.PrivateDomains.Add(dnsServer.ServerDomain.ToLowerInvariant());

        PrivateNetworks = networks;
        Config = parsedConfig;
        return Task.CompletedTask;
    }

    public string Description => "Block DNS responses with protected IP ranges to prevent DNS rebinding attacks.";

    /// <summary>
    /// Returns the response with rebinding records removed from the answer and
    /// additional sections; the authority section is passed through unchanged.
    /// </summary>
    /// <param name="request">The original request (unused).</param>
    /// <param name="remoteEP">The client end point (unused).</param>
    /// <param name="protocol">The transport protocol (unused).</param>
    /// <param name="response">The response to filter.</param>
    /// <returns>
    /// The original response when filtering is disabled or the response is
    /// authoritative; otherwise a clone carrying the filtered sections.
    /// </returns>
    public Task<DnsDatagram> PostProcessAsync(DnsDatagram request, IPEndPoint remoteEP, DnsTransportProtocol protocol, DnsDatagram response)
    {
        // Do not filter authoritative responses, because in this case any rebinding is intentional.
        if (!Config.Enabled || response.AuthoritativeAnswer)
        {
            return Task.FromResult(response);
        }

        var answers = response.Answer.Where(res => !IsFilteredRebind(res)).ToList();
        var additional = response.Additional.Where(res => !IsFilteredRebind(res)).ToList();
        return Task.FromResult(response.Clone(answers, response.Authority, additional));
    }

    /// <summary>
    /// Decides whether a record must be dropped: it is an A/AAAA record whose
    /// address lies in a private network and whose name is not under a private domain.
    /// </summary>
    private bool IsFilteredRebind(DnsResourceRecord record)
    {
        if (record.Type != DnsResourceRecordType.A && record.Type != DnsResourceRecordType.AAAA)
        {
            return false;
        }

        IPAddress address;
        switch (record.RDATA)
        {
            case DnsARecordData a:
                address = a.Address;
                break;
            case DnsAAAARecordData aaaa:
                address = aaaa.Address;
                break;
            default:
                // Record type says A/AAAA but the data is of another kind; leave it alone.
                return false;
        }

        return PrivateNetworks.Any(net => net.Contains(address))
            && !IsZoneFound(Config.PrivateDomains, record.Name, out _);
    }

    /// <summary>
    /// Strips the left-most label from <paramref name="domain"/>.
    /// </summary>
    /// <returns>The parent zone, or null when the name contains no dot.</returns>
    private static string? GetParentZone(string domain)
    {
        var i = domain.IndexOf('.');
        // Don't return the root zone.
        return i > -1 ? domain[(i + 1)..] : null;
    }

    /// <summary>
    /// Looks up <paramref name="domain"/> and then each of its parent zones in
    /// <paramref name="domains"/>, most specific first.
    /// </summary>
    /// <param name="domains">Zone names to match against.</param>
    /// <param name="domain">The name to test; lowercased before the lookup.</param>
    /// <param name="foundZone">The matching zone, or null when nothing matched.</param>
    /// <returns>True when the name or one of its parents is in the set.</returns>
    private static bool IsZoneFound(IReadOnlySet<string> domains, string domain, out string? foundZone)
    {
        // Invariant casing: DNS names must not be lowercased with culture-specific rules.
        var currentDomain = domain.ToLowerInvariant();
        do
        {
            if (domains.Contains(currentDomain))
            {
                foundZone = currentDomain;
                return true;
            }

            currentDomain = GetParentZone(currentDomain);
        }
        while (currentDomain is not null);

        foundZone = null;
        return false;
    }
}
}