diff --git a/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs b/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs index 86a0088d..6fde62ea 100644 --- a/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs +++ b/DnsServerCore/Dns/ZoneManagers/CacheZoneManager.cs @@ -80,35 +80,50 @@ namespace DnsServerCore.Dns.ZoneManagers #region private - private List GetAdditionalRecords(IReadOnlyCollection nsRecords, bool serveStale) + private List GetAdditionalRecords(IReadOnlyCollection refRecords, bool serveStale) { List additionalRecords = new List(); - foreach (DnsResourceRecord nsRecord in nsRecords) + foreach (DnsResourceRecord refRecord in refRecords) { - if (nsRecord.Type != DnsResourceRecordType.NS) - continue; - - CacheZone cacheZone = _root.FindZone((nsRecord.RDATA as DnsNSRecord).NameServer, out _, out _, out _); - if (cacheZone != null) + switch (refRecord.Type) { - { - IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.A, serveStale); - if ((records.Count > 0) && (records[0].RDATA is DnsARecord)) - additionalRecords.AddRange(records); - } + case DnsResourceRecordType.NS: + ResolveAdditionalRecords((refRecord.RDATA as DnsNSRecord).NameServer, serveStale, additionalRecords); + break; - { - IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.AAAA, serveStale); - if ((records.Count > 0) && (records[0].RDATA is DnsAAAARecord)) - additionalRecords.AddRange(records); - } + case DnsResourceRecordType.MX: + ResolveAdditionalRecords((refRecord.RDATA as DnsMXRecord).Exchange, serveStale, additionalRecords); + break; + + case DnsResourceRecordType.SRV: + ResolveAdditionalRecords((refRecord.RDATA as DnsSRVRecord).Target, serveStale, additionalRecords); + break; } } return additionalRecords; } + private void ResolveAdditionalRecords(string domain, bool serveStale, List additionalRecords) + { + CacheZone cacheZone = _root.FindZone(domain, out _, out _, out _); + if (cacheZone != null) + { + { + IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.A, serveStale); + if ((records.Count > 0) && (records[0].RDATA is DnsARecord)) + additionalRecords.AddRange(records); + } + + { + IReadOnlyList records = cacheZone.QueryRecords(DnsResourceRecordType.AAAA, serveStale); + if ((records.Count > 0) && (records[0].RDATA is DnsAAAARecord)) + additionalRecords.AddRange(records); + } + } + } + #endregion #region public @@ -218,8 +233,14 @@ namespace DnsServerCore.Dns.ZoneManagers IReadOnlyList additional = null; - if (request.Question[0].Type == DnsResourceRecordType.NS) - additional = GetAdditionalRecords(answers, serveStale); + switch (request.Question[0].Type) + { + case DnsResourceRecordType.NS: + case DnsResourceRecordType.MX: + case DnsResourceRecordType.SRV: + additional = GetAdditionalRecords(answers, serveStale); + break; + } return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, true, false, false, DnsResponseCode.NoError, request.Question, answers, null, additional); }