From 02c8ae8400276879eef86fd66e6881e809c53369 Mon Sep 17 00:00:00 2001 From: Shreyas Zare Date: Sun, 28 Jun 2020 19:38:34 +0530 Subject: [PATCH] AuthZone: added checks in GetNameServerAddresses() to prevent exceptions. Made GetPrimaryNameServerAddresses() public, used GetRecords() for getting NS records and updated implementation to use addresses from NS record and from glue records. Made GetSecondaryNameServerAddresses() public and used GetRecords() for getting NS records. Updated GetRecords() to return empty array if no records are found. Optimized ContainsNameServerRecords() code. --- DnsServerCore/Dns/Zones/AuthZone.cs | 99 +++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 26 deletions(-) diff --git a/DnsServerCore/Dns/Zones/AuthZone.cs b/DnsServerCore/Dns/Zones/AuthZone.cs index 1865c568..7ca7ca6d 100644 --- a/DnsServerCore/Dns/Zones/AuthZone.cs +++ b/DnsServerCore/Dns/Zones/AuthZone.cs @@ -136,24 +136,34 @@ namespace DnsServerCore.Dns.Zones else { //resolve addresses - DnsDatagram response = dnsServer.DirectQuery(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)); - if (response != null) + try { - IReadOnlyList addresses = DnsClient.ParseResponseA(response); - foreach (IPAddress address in addresses) - nameServers.Add(new NameServerAddress(nsDomain, address)); - } - - if (dnsServer.PreferIPv6) - { - response = dnsServer.DirectQuery(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)); - if (response != null) + DnsDatagram response = dnsServer.DirectQuery(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.A, DnsClass.IN)); + if ((response != null) && (response.Answer.Count > 0)) { - IReadOnlyList addresses = DnsClient.ParseResponseAAAA(response); + IReadOnlyList addresses = DnsClient.ParseResponseA(response); foreach (IPAddress address in addresses) nameServers.Add(new NameServerAddress(nsDomain, address)); } } + catch + { } + + if (dnsServer.PreferIPv6) + { + try + { + DnsDatagram response = dnsServer.DirectQuery(new DnsQuestionRecord(nsDomain, DnsResourceRecordType.AAAA, DnsClass.IN)); + if ((response != null) && (response.Answer.Count > 0)) + { + IReadOnlyList addresses = DnsClient.ParseResponseAAAA(response); + foreach (IPAddress address in addresses) + nameServers.Add(new NameServerAddress(nsDomain, address)); + } + } + catch + { } + } } return nameServers; @@ -161,24 +171,52 @@ namespace DnsServerCore.Dns.Zones #endregion - #region protected + #region public - protected IReadOnlyList GetPrimaryNameServerAddresses(DnsServer dnsServer) + public IReadOnlyList GetPrimaryNameServerAddresses(DnsServer dnsServer) { - DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0]; + List nameServers = new List(); - return GetNameServerAddresses(dnsServer, soaRecord); + DnsResourceRecord soaRecord = _entries[DnsResourceRecordType.SOA][0]; + DnsSOARecord soa = soaRecord.RDATA as DnsSOARecord; + IReadOnlyList nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords + + foreach (DnsResourceRecord nsRecord in nsRecords) + { + if (nsRecord.IsDisabled()) + continue; + + string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer; + + if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase)) + { + //found primary NS + nameServers.AddRange(GetNameServerAddresses(dnsServer, nsRecord)); + break; + } + } + + foreach (NameServerAddress nameServer in GetNameServerAddresses(dnsServer, soaRecord)) + { + if (!nameServers.Contains(nameServer)) + nameServers.Add(nameServer); + } + + return nameServers; } - protected IReadOnlyList GetSecondaryNameServerAddresses(DnsServer dnsServer) + public IReadOnlyList GetSecondaryNameServerAddresses(DnsServer dnsServer) { List nameServers = new List(); DnsSOARecord soa = _entries[DnsResourceRecordType.SOA][0].RDATA as DnsSOARecord; - IReadOnlyList nsRecords = QueryRecords(DnsResourceRecordType.NS); + IReadOnlyList nsRecords = GetRecords(DnsResourceRecordType.NS); //stub zone has no authority so cant use QueryRecords foreach (DnsResourceRecord nsRecord in nsRecords) { + if (nsRecord.IsDisabled()) + continue; + string nsDomain = (nsRecord.RDATA as DnsNSRecord).NameServer; if (soa.PrimaryNameServer.Equals(nsDomain, StringComparison.OrdinalIgnoreCase)) @@ -190,10 +228,6 @@ namespace DnsServerCore.Dns.Zones return nameServers; } - #endregion - - #region public - public void SyncRecords(Dictionary> newEntries, bool dontRemoveRecords) { if (!dontRemoveRecords) @@ -236,7 +270,7 @@ namespace DnsServerCore.Dns.Zones if ((this is SecondaryZone) || (this is StubZone)) { //copy existing SOA record's glue addresses to new SOA record - newEntry.Value[0].SetGlueRecords(_entries[DnsResourceRecordType.SOA][0].GetGlueRecords()); + newEntry.Value[0].SyncGlueRecords(_entries[DnsResourceRecordType.SOA][0].GetGlueRecords()); } } @@ -363,13 +397,26 @@ namespace DnsServerCore.Dns.Zones public IReadOnlyList GetRecords(DnsResourceRecordType type) { - return _entries[type]; + if (_entries.TryGetValue(type, out IReadOnlyList records)) + return records; + + return Array.Empty(); } public override bool ContainsNameServerRecords() { - IReadOnlyList records = QueryRecords(DnsResourceRecordType.NS); - return (records.Count > 0) && (records[0].Type == DnsResourceRecordType.NS); + if (!_entries.TryGetValue(DnsResourceRecordType.NS, out IReadOnlyList records)) + return false; + + foreach (DnsResourceRecord record in records) + { + if (record.IsDisabled()) + continue; + + return true; + } + + return false; } #endregion