diff --git a/DnsServerCore/Dns/Trees/ZoneTree.cs b/DnsServerCore/Dns/Trees/ZoneTree.cs index 5dcaa1ce..f21beecd 100644 --- a/DnsServerCore/Dns/Trees/ZoneTree.cs +++ b/DnsServerCore/Dns/Trees/ZoneTree.cs @@ -18,7 +18,6 @@ along with this program. If not, see . */ using DnsServerCore.Dns.Zones; -using System; using System.Collections.Generic; using System.Threading; @@ -90,61 +89,78 @@ namespace DnsServerCore.Dns.Trees return key; } - private static bool KeysMatch(byte[] key1, byte[] key2) + private static bool KeysMatch(byte[] key1, byte[] key2, bool matchWildcard) { - //com.example.*. - //com.example.*.www. - //com.example.abc.www. - - int i = 0; - int j = 0; - - while ((i < key1.Length) && (j < key2.Length)) + if (matchWildcard) { - if (key1[i] == 1) //[*] + //com.example.*. + //com.example.*.www. + //com.example.abc.www. + + int i = 0; + int j = 0; + + while ((i < key1.Length) && (j < key2.Length)) { - if (i == key1.Length - 2) - return true; - - //skip j to next label - while (j < key2.Length) + if (key1[i] == 1) //[*] { - if (key2[j] == 0) //[.] - break; + if (i == key1.Length - 2) + return true; - j++; - } + //skip j to next label + while (j < key2.Length) + { + if (key2[j] == 0) //[.] + break; - i++; - continue; - } - - if (key2[j] == 1) //[*] - { - if (j == key2.Length - 2) - return true; - - //skip i to next label - while (i < key1.Length) - { - if (key1[i] == 0) //[.] - break; + j++; + } i++; + continue; } + if (key2[j] == 1) //[*] + { + if (j == key2.Length - 2) + return true; + + //skip i to next label + while (i < key1.Length) + { + if (key1[i] == 0) //[.] + break; + + i++; + } + + j++; + continue; + } + + if (key1[i] != key2[j]) + return false; + + i++; j++; - continue; } - if (key1[i] != key2[j]) + return (i == key1.Length) && (j == key2.Length); + } + else + { + //exact match + if (key1.Length != key2.Length) return false; - i++; - j++; - } + for (int i = 0; i < key1.Length; i++) + { + if (key1[i] != key2[i]) + return false; + } - return (i == key1.Length) && (j == key2.Length); + return true; + } } #endregion @@ -190,34 +206,45 @@ namespace DnsServerCore.Dns.Trees return (i == mainKey.Length) && (j < testKey.Length); } - protected TNode FindZoneNode(byte[] key, out Node closestNode, out Node closestAuthorityNode, out TSubDomainZone closestSubDomain, out TSubDomainZone closestDelegation, out TApexZone closestAuthority) + protected TNode FindZoneNode(byte[] key, bool matchWildcard, out Node currentNode, out Node closestSubDomainNode, out Node closestAuthorityNode, out TSubDomainZone closestSubDomain, out TSubDomainZone closestDelegation, out TApexZone closestAuthority) { - closestNode = _root; + closestSubDomainNode = null; closestAuthorityNode = null; closestSubDomain = null; closestDelegation = null; closestAuthority = null; + currentNode = _root; Node wildcard = null; int i = 0; while (i <= key.Length) { - //find authority zone - NodeValue value = closestNode.Value; - if (value is not null) + //inspect the current node + NodeValue value = currentNode.Value; + if ((value is not null) && (value.Key.Length <= key.Length)) { - TNode zoneValue = value.Value; - if ((zoneValue is not null) && IsKeySubDomain(value.Key, key)) + TNode zoneNode = value.Value; + if ((zoneNode is not null) && IsKeySubDomain(value.Key, key)) { + //find closest values + TSubDomainZone subDomain = null; TApexZone authority = null; - GetClosestValuesForZone(zoneValue, ref closestSubDomain, ref closestDelegation, ref authority); + GetClosestValuesForZone(zoneNode, ref subDomain, ref closestDelegation, ref authority); + + if (subDomain is not null) + { + closestSubDomain = subDomain; + closestSubDomainNode = currentNode; + } if (authority is not null) { closestAuthority = authority; - closestAuthorityNode = closestNode; + closestAuthorityNode = currentNode; + + wildcard = null; //clear previous wildcard node from the previous authority } } } @@ -225,13 +252,18 @@ namespace DnsServerCore.Dns.Trees if (i == key.Length) break; - Node[] children = closestNode.Children; + Node[] children = currentNode.Children; if (children is null) break; - Node child = Volatile.Read(ref children[1]); //[*] - if (child is not null) - wildcard = child; + Node child; + + if (matchWildcard) + { + child = Volatile.Read(ref children[1]); //[*] + if (child is not null) + wildcard = child; + } child = Volatile.Read(ref children[key[i]]); if (child is null) @@ -248,34 +280,41 @@ namespace DnsServerCore.Dns.Trees break; } - closestNode = wildcard; + currentNode = wildcard; wildcard = null; continue; } - closestNode = child; + currentNode = child; i++; } { - NodeValue value = closestNode.Value; + NodeValue value = currentNode.Value; if (value is not null) { //match exact + wildcard keys - if (KeysMatch(key, value.Key)) + if (KeysMatch(key, value.Key, matchWildcard)) { - //update authority since the matched zone is apex zone - TNode zoneValue = value.Value; - if (zoneValue is not null) + //find closest values since the matched zone may be apex zone + TNode zoneNode = value.Value; + if (zoneNode is not null) { + TSubDomainZone subDomain = null; TApexZone authority = null; - GetClosestValuesForZone(zoneValue, ref closestSubDomain, ref closestDelegation, ref authority); + GetClosestValuesForZone(zoneNode, ref subDomain, ref closestDelegation, ref authority); + + if (subDomain is not null) + { + closestSubDomain = subDomain; + closestSubDomainNode = currentNode; + } if (authority is not null) { closestAuthority = authority; - closestAuthorityNode = closestNode; + closestAuthorityNode = currentNode; } } @@ -286,7 +325,7 @@ namespace DnsServerCore.Dns.Trees if (wildcard is not null) { - //wildcard node found + //inspect wildcard node value NodeValue value = wildcard.Value; if (value is null) { @@ -301,7 +340,7 @@ namespace DnsServerCore.Dns.Trees if (value is not null) { //match wildcard keys - if (KeysMatch(key, value.Key)) + if (KeysMatch(key, value.Key, true)) return value.Value; //found matching wildcard value } } @@ -310,7 +349,7 @@ namespace DnsServerCore.Dns.Trees else { //match wildcard keys - if (KeysMatch(key, value.Key)) + if (KeysMatch(key, value.Key, true)) return value.Value; //found matching wildcard value } } @@ -327,13 +366,10 @@ namespace DnsServerCore.Dns.Trees public void ListSubDomains(string domain, List subDomains) { - if (domain is null) - throw new ArgumentNullException(nameof(domain)); - byte[] bKey = ConvertToByteKey(domain); - _ = _root.FindNodeValue(bKey, out Node closestNode); - Node current = closestNode; + _ = _root.FindNodeValue(bKey, out Node currentNode); + Node current = currentNode; NodeValue value; do @@ -348,7 +384,7 @@ namespace DnsServerCore.Dns.Trees subDomains.Add(label); } } - else if ((current.K == 0) && (current.Depth > closestNode.Depth)) //[.] + else if ((current.K == 0) && (current.Depth > currentNode.Depth)) //[.] { byte[] nodeKey = GetNodeKey(current); if (IsKeySubDomain(bKey, nodeKey)) @@ -359,7 +395,7 @@ namespace DnsServerCore.Dns.Trees } } - current = GetNextChildZoneNode(current, closestNode.Depth); + current = GetNextChildZoneNode(current, currentNode.Depth); } while (current is not null); }