From 3cfbec888314c5e4862e8960ffcf39664ac8d59d Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sat, 29 Aug 2020 14:23:58 +0530 Subject: [PATCH] AuthZoneManager: Zone creation, loading and sync checks added. --- .../Dns/ZoneManagers/AuthZoneManager.cs | 106 +++++++++++++----- 1 file changed, 76 insertions(+), 30 deletions(-) diff --git a/DnsServerCore/Dns/ZoneManagers/AuthZoneManager.cs b/DnsServerCore/Dns/ZoneManagers/AuthZoneManager.cs index ac522803..ee62148a 100644 --- a/DnsServerCore/Dns/ZoneManagers/AuthZoneManager.cs +++ b/DnsServerCore/Dns/ZoneManagers/AuthZoneManager.cs @@ -25,6 +25,7 @@ using System.IO; using System.Net; using System.Text; using System.Threading; +using System.Threading.Tasks; using TechnitiumLibrary.Net.Dns; using TechnitiumLibrary.Net.Dns.ResourceRecords; @@ -185,22 +186,37 @@ namespace DnsServerCore.Dns.ZoneManagers if (_root.TryAdd(zone)) return zone; + if (_root.TryGet(zoneInfo.Name, out AuthZone existingZone) && (existingZone is SubDomainZone)) + { + _root[zoneInfo.Name] = zone; + return zone; + } + throw new DnsServerException("Zone already exists: " + zoneInfo.Name); } - private void LoadRecords(IReadOnlyList records) + private void LoadRecords(AuthZone authZone, IReadOnlyList records) { Dictionary>> groupedByDomainRecords = DnsResourceRecord.GroupRecords(records); foreach (KeyValuePair>> groupedByTypeRecords in groupedByDomainRecords) { - AuthZone zone = GetOrAddZone(groupedByTypeRecords.Key); + if (authZone.Name.Equals(groupedByTypeRecords.Key, StringComparison.OrdinalIgnoreCase)) + { + foreach (KeyValuePair> groupedRecords in groupedByTypeRecords.Value) + authZone.LoadRecords(groupedRecords.Key, groupedRecords.Value); + } + else + { + AuthZone zone = GetOrAddZone(groupedByTypeRecords.Key); + if (zone is SubDomainZone) + { + foreach (KeyValuePair> groupedRecords in groupedByTypeRecords.Value) + zone.LoadRecords(groupedRecords.Key, groupedRecords.Value); - foreach (KeyValuePair> groupedRecords in groupedByTypeRecords.Value) - zone.LoadRecords(groupedRecords.Key, groupedRecords.Value); - - if (zone is SubDomainZone) - (zone as SubDomainZone).AutoUpdateState(); + (zone as SubDomainZone).AutoUpdateState(); + } + } } } @@ -383,17 +399,7 @@ namespace DnsServerCore.Dns.ZoneManagers } } - public AuthZoneInfo CreatePrimaryZone(string domain, string primaryNameServer, bool @internal) - { - AuthZone authZone = new PrimaryZone(_dnsServer, domain, primaryNameServer, @internal); - - if (_root.TryAdd(authZone)) - return new AuthZoneInfo(authZone); - - return null; - } - - internal AuthZoneInfo CreatePrimaryZone(string domain, DnsSOARecord soaRecord, DnsNSRecord ns) + internal AuthZoneInfo InternalCreatePrimaryZone(string domain, DnsSOARecord soaRecord, DnsNSRecord ns) { AuthZone authZone = new PrimaryZone(_dnsServer, domain, soaRecord, ns); @@ -403,9 +409,25 @@ namespace DnsServerCore.Dns.ZoneManagers return null; } - public AuthZoneInfo CreateSecondaryZone(string domain, string primaryNameServerAddresses = null) + public AuthZoneInfo CreatePrimaryZone(string domain, string primaryNameServer, bool @internal) { - AuthZone authZone = new SecondaryZone(_dnsServer, domain, primaryNameServerAddresses); + AuthZone authZone = new PrimaryZone(_dnsServer, domain, primaryNameServer, @internal); + + if (_root.TryAdd(authZone)) + return new AuthZoneInfo(authZone); + + if (_root.TryGet(domain, out AuthZone existingZone) && (existingZone is SubDomainZone)) + { + _root[domain] = authZone; + return new AuthZoneInfo(authZone); + } + + return null; + } + + public async Task CreateSecondaryZoneAsync(string domain, string primaryNameServerAddresses = null) + { + AuthZone authZone = await SecondaryZone.CreateAsync(_dnsServer, domain, primaryNameServerAddresses); if (_root.TryAdd(authZone)) { @@ -413,12 +435,19 @@ namespace DnsServerCore.Dns.ZoneManagers return new AuthZoneInfo(authZone); } + if (_root.TryGet(domain, out AuthZone existingZone) && (existingZone is SubDomainZone)) + { + _root[domain] = authZone; + (authZone as SecondaryZone).RefreshZone(); + return new AuthZoneInfo(authZone); + } + return null; } - public AuthZoneInfo CreateStubZone(string domain, string primaryNameServerAddresses = null) + public async Task CreateStubZoneAsync(string domain, string primaryNameServerAddresses = null) { - AuthZone authZone = new StubZone(_dnsServer, domain, primaryNameServerAddresses); + AuthZone authZone = await StubZone.CreateAsync(_dnsServer, domain, primaryNameServerAddresses); if (_root.TryAdd(authZone)) { @@ -426,6 +455,13 @@ namespace DnsServerCore.Dns.ZoneManagers return new AuthZoneInfo(authZone); } + if (_root.TryGet(domain, out AuthZone existingZone) && (existingZone is SubDomainZone)) + { + _root[domain] = authZone; + (authZone as StubZone).RefreshZone(); + return new AuthZoneInfo(authZone); + } + return null; } @@ -436,6 +472,12 @@ namespace DnsServerCore.Dns.ZoneManagers if (_root.TryAdd(authZone)) return new AuthZoneInfo(authZone); + if (_root.TryGet(domain, out AuthZone existingZone) && (existingZone is SubDomainZone)) + { + _root[domain] = authZone; + return new AuthZoneInfo(authZone); + } + return null; } @@ -541,7 +583,7 @@ namespace DnsServerCore.Dns.ZoneManagers int i = 0; - if (syncRecords[0].Type == DnsResourceRecordType.SOA) + if ((syncRecords.Count > 1) && (syncRecords[0].Type == DnsResourceRecordType.SOA) && (syncRecords[syncRecords.Count - 1].Type == DnsResourceRecordType.SOA)) i = 1; //skip first SOA in AXFR if (domain.Length == 0) @@ -611,7 +653,11 @@ namespace DnsServerCore.Dns.ZoneManagers foreach (KeyValuePair>> newEntries in newRecordsGroupedByDomain) { AuthZone zone = GetOrAddZone(newEntries.Key); - zone.SyncRecords(newEntries.Value, dontRemoveRecords); + + if (zone.Name.Equals(domain, StringComparison.OrdinalIgnoreCase)) + zone.SyncRecords(newEntries.Value, dontRemoveRecords); + else if (zone is SubDomainZone) + zone.SyncRecords(newEntries.Value, dontRemoveRecords); } } @@ -909,12 +955,12 @@ namespace DnsServerCore.Dns.ZoneManagers AuthZoneInfo zoneInfo = new AuthZoneInfo(records[0].Name, zoneType, false); //create zone - Zone authZone = CreateEmptyZone(zoneInfo); + AuthZone authZone = CreateEmptyZone(zoneInfo); try { //load records - LoadRecords(records); + LoadRecords(authZone, records); } catch { @@ -963,12 +1009,12 @@ namespace DnsServerCore.Dns.ZoneManagers AuthZoneInfo zoneInfo = new AuthZoneInfo(records[0].Name, zoneType, zoneDisabled); //create zone - Zone authZone = CreateEmptyZone(zoneInfo); + AuthZone authZone = CreateEmptyZone(zoneInfo); try { //load records - LoadRecords(records); + LoadRecords(authZone, records); } catch { @@ -993,7 +1039,7 @@ namespace DnsServerCore.Dns.ZoneManagers AuthZoneInfo zoneInfo = new AuthZoneInfo(bR); //create zone - Zone authZone = CreateEmptyZone(zoneInfo); + AuthZone authZone = CreateEmptyZone(zoneInfo); //read all zone records DnsResourceRecord[] records = new DnsResourceRecord[bR.ReadInt32()]; @@ -1008,7 +1054,7 @@ namespace DnsServerCore.Dns.ZoneManagers try { //load records - LoadRecords(records); + LoadRecords(authZone, records); } catch {