diff --git a/DnsServerCore/Zone.cs b/DnsServerCore/Zone.cs index 31935787..f1a69876 100644 --- a/DnsServerCore/Zone.cs +++ b/DnsServerCore/Zone.cs @@ -77,15 +77,15 @@ namespace DnsServerCore { nsRecords.Add(new DnsResourceRecord("", DnsResourceRecordType.NS, DnsClass.IN, 172800, new DnsNSRecord(rootNameServer.Domain))); - CreateZone(this, rootNameServer.Domain).SetRecord(DnsResourceRecordType.A, new DnsResourceRecord[] { new DnsResourceRecord(rootNameServer.Domain, DnsResourceRecordType.A, DnsClass.IN, 172800, new DnsARecord(rootNameServer.EndPoint.Address)) }); + CreateZone(this, rootNameServer.Domain).SetRecords(DnsResourceRecordType.A, new DnsResourceRecord[] { new DnsResourceRecord(rootNameServer.Domain, DnsResourceRecordType.A, DnsClass.IN, 172800, new DnsARecord(rootNameServer.EndPoint.Address)) }); } foreach (NameServerAddress rootNameServer in DnsClient.ROOT_NAME_SERVERS_IPv6) { - CreateZone(this, rootNameServer.Domain).SetRecord(DnsResourceRecordType.AAAA, new DnsResourceRecord[] { new DnsResourceRecord(rootNameServer.Domain, DnsResourceRecordType.AAAA, DnsClass.IN, 172800, new DnsARecord(rootNameServer.EndPoint.Address)) }); + CreateZone(this, rootNameServer.Domain).SetRecords(DnsResourceRecordType.AAAA, new DnsResourceRecord[] { new DnsResourceRecord(rootNameServer.Domain, DnsResourceRecordType.AAAA, DnsClass.IN, 172800, new DnsAAAARecord(rootNameServer.EndPoint.Address)) }); } - SetRecord(DnsResourceRecordType.NS, nsRecords.ToArray()); + SetRecords(DnsResourceRecordType.NS, nsRecords.ToArray()); } private static string[] ConvertDomainToPath(string domainName) @@ -169,9 +169,9 @@ namespace DnsServerCore return null; } - private void SetRecord(DnsResourceRecordType type, DnsResourceRecord[] records) + private void SetRecords(DnsResourceRecordType type, DnsResourceRecord[] records) { - DnsResourceRecord[] existingRecords = _entries.GetOrAdd(type, delegate (DnsResourceRecordType key) + DnsResourceRecord[] existingRecords = _entries.AddOrUpdate(type, records, delegate (DnsResourceRecordType key, DnsResourceRecord[] oldValue) { return records; }); @@ -367,6 +367,42 @@ namespace DnsServerCore #region internal + internal static Dictionary>> GroupRecords(ICollection records) + { + Dictionary>> groupedByDomainRecords = new Dictionary>>(); + + foreach (DnsResourceRecord record in records) + { + Dictionary> groupedByTypeRecords; + + if (groupedByDomainRecords.ContainsKey(record.Name)) + { + groupedByTypeRecords = groupedByDomainRecords[record.Name]; + } + else + { + groupedByTypeRecords = new Dictionary>(); + groupedByDomainRecords.Add(record.Name, groupedByTypeRecords); + } + + List groupedRecords; + + if (groupedByTypeRecords.ContainsKey(record.Type)) + { + groupedRecords = groupedByTypeRecords[record.Type]; + } + else + { + groupedRecords = new List(); + groupedByTypeRecords.Add(record.Type, groupedRecords); + } + + groupedRecords.Add(record); + } + + return groupedByDomainRecords; + } + internal DnsDatagram Query(DnsDatagram request) { if (_authoritativeZone) @@ -396,7 +432,7 @@ namespace DnsServerCore DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, DEFAULT_RECORD_TTL, new DnsNXRecord(authority)); record.SetExpiry(); - CreateZone(this, question.Name).SetRecord(question.Type, new DnsResourceRecord[] { record }); + CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } } } @@ -413,7 +449,7 @@ namespace DnsServerCore DnsResourceRecord record = new DnsResourceRecord(question.Name, question.Type, DnsClass.IN, DEFAULT_RECORD_TTL, new DnsEmptyRecord(authority)); record.SetExpiry(); - CreateZone(this, question.Name).SetRecord(question.Type, new DnsResourceRecord[] { record }); + CreateZone(this, question.Name).SetRecords(question.Type, new DnsResourceRecord[] { record }); } } } @@ -431,78 +467,51 @@ namespace DnsServerCore allRecords.AddRange(response.Authority); allRecords.AddRange(response.Additional); - #region group all records by domain and type - - Dictionary>> cacheEntries = new Dictionary>>(); - + //set expiry for cached records foreach (DnsResourceRecord record in allRecords) - { - Dictionary> cacheTypeEntries; + record.SetExpiry(); - if (cacheEntries.ContainsKey(record.Name)) - { - cacheTypeEntries = cacheEntries[record.Name]; - } - else - { - cacheTypeEntries = new Dictionary>(); - cacheEntries.Add(record.Name, cacheTypeEntries); - } - - List cacheRREntries; - - if (cacheTypeEntries.ContainsKey(record.Type)) - { - cacheRREntries = cacheTypeEntries[record.Type]; - } - else - { - cacheRREntries = new List(); - cacheTypeEntries.Add(record.Type, cacheRREntries); - } - - cacheRREntries.Add(record); - } - - #endregion - - //add grouped entries into cache - foreach (KeyValuePair>> cacheEntry in cacheEntries) - { - string domain = cacheEntry.Key; - - foreach (KeyValuePair> cacheTypeEntry in cacheEntry.Value) - { - DnsResourceRecordType type = cacheTypeEntry.Key; - DnsResourceRecord[] records = cacheTypeEntry.Value.ToArray(); - - foreach (DnsResourceRecord record in records) - record.SetExpiry(); - - CreateZone(this, domain).SetRecord(type, records); - } - } + SetRecords(allRecords); //cache for ANY request if (response.Question[0].Type == DnsResourceRecordType.ANY) - CreateZone(this, response.Question[0].Name).SetRecord(DnsResourceRecordType.ANY, response.Answer); + CreateZone(this, response.Question[0].Name).SetRecords(DnsResourceRecordType.ANY, response.Answer); } #endregion #region public - public void SetRecord(string domain, DnsResourceRecordType type, uint ttl, DnsResourceRecordData[] records) + public void SetRecords(string domain, DnsResourceRecordType type, uint ttl, DnsResourceRecordData[] records) { DnsResourceRecord[] resourceRecords = new DnsResourceRecord[records.Length]; for (int i = 0; i < records.Length; i++) resourceRecords[i] = new DnsResourceRecord(domain, type, DnsClass.IN, ttl, records[i]); - CreateZone(this, domain).SetRecord(type, resourceRecords); + CreateZone(this, domain).SetRecords(type, resourceRecords); } - public DnsResourceRecord[] GetAllRecords(string domain = "") + public void SetRecords(ICollection records) + { + Dictionary>> groupedByDomainRecords = GroupRecords(records); + + //add grouped records + foreach (KeyValuePair>> groupedByTypeRecords in groupedByDomainRecords) + { + string domain = groupedByTypeRecords.Key; + + foreach (KeyValuePair> groupedRecords in groupedByTypeRecords.Value) + { + DnsResourceRecordType type = groupedRecords.Key; + DnsResourceRecord[] resourceRecords = groupedRecords.Value.ToArray(); + + CreateZone(this, domain).SetRecords(type, resourceRecords); + } + } + } + + public DnsResourceRecord[] GetRecords(string domain = "") { Zone currentZone = this; @@ -518,7 +527,11 @@ namespace DnsServerCore return new DnsResourceRecord[] { }; //no zone for given domain } - return currentZone.GetRecords(DnsResourceRecordType.ANY); + DnsResourceRecord[] records = currentZone.GetRecords(DnsResourceRecordType.ANY); + if (records != null) + return records; + + return new DnsResourceRecord[] { }; } public string[] ListSubZones(string domain = "") @@ -543,7 +556,7 @@ namespace DnsServerCore return subZoneNames; } - public string[] ListAllAuthoritativeZones(string domain = "") + public string[] ListAuthoritativeZones(string domain = "") { Zone currentZone = this; @@ -570,6 +583,11 @@ namespace DnsServerCore return zoneNames.ToArray(); } + public bool DeleteZone(string domain) + { + return (DeleteZone(this, domain) != null); + } + #endregion class DnsNXRecord : DnsResourceRecordData