DnsServer: updated ProcessUpdateQueryAsync() to add last modified and comments for the added records. Updated PrepareRecursiveResolveResponse() to fix issue with CD flag case when DO flag is unset. Code refactoring changes done.

This commit is contained in:
Shreyas Zare
2025-02-15 12:41:37 +05:30
parent 5094e6a481
commit 100348f0fc

View File

@@ -1698,6 +1698,10 @@ namespace DnsServerCore.Dns
IReadOnlyList<DnsResourceRecord> existingRRSet = _authZoneManager.GetRecords(zoneInfo.Name, uRecord.Name, uRecord.Type); IReadOnlyList<DnsResourceRecord> existingRRSet = _authZoneManager.GetRecords(zoneInfo.Name, uRecord.Name, uRecord.Type);
AddToOriginalRRSets(uRecord.Name, uRecord.Type, existingRRSet); AddToOriginalRRSets(uRecord.Name, uRecord.Type, existingRRSet);
GenericRecordInfo recordInfo = uRecord.GetAuthGenericRecordInfo();
recordInfo.LastModified = DateTime.UtcNow;
recordInfo.Comments = "Via Dynamic Updates (RFC 2136)";
_authZoneManager.SetRecord(zoneInfo.Name, uRecord); _authZoneManager.SetRecord(zoneInfo.Name, uRecord);
} }
else if (uRecord.Type == DnsResourceRecordType.DNAME) else if (uRecord.Type == DnsResourceRecordType.DNAME)
@@ -1705,6 +1709,10 @@ namespace DnsServerCore.Dns
IReadOnlyList<DnsResourceRecord> existingRRSet = _authZoneManager.GetRecords(zoneInfo.Name, uRecord.Name, uRecord.Type); IReadOnlyList<DnsResourceRecord> existingRRSet = _authZoneManager.GetRecords(zoneInfo.Name, uRecord.Name, uRecord.Type);
AddToOriginalRRSets(uRecord.Name, uRecord.Type, existingRRSet); AddToOriginalRRSets(uRecord.Name, uRecord.Type, existingRRSet);
GenericRecordInfo recordInfo = uRecord.GetAuthGenericRecordInfo();
recordInfo.LastModified = DateTime.UtcNow;
recordInfo.Comments = "Via Dynamic Updates (RFC 2136)";
_authZoneManager.SetRecord(zoneInfo.Name, uRecord); _authZoneManager.SetRecord(zoneInfo.Name, uRecord);
} }
else if (uRecord.Type == DnsResourceRecordType.SOA) else if (uRecord.Type == DnsResourceRecordType.SOA)
@@ -1715,6 +1723,10 @@ namespace DnsServerCore.Dns
IReadOnlyList<DnsResourceRecord> existingRRSet = _authZoneManager.GetRecords(zoneInfo.Name, uRecord.Name, uRecord.Type); IReadOnlyList<DnsResourceRecord> existingRRSet = _authZoneManager.GetRecords(zoneInfo.Name, uRecord.Name, uRecord.Type);
AddToOriginalRRSets(uRecord.Name, uRecord.Type, existingRRSet); AddToOriginalRRSets(uRecord.Name, uRecord.Type, existingRRSet);
GenericRecordInfo recordInfo = uRecord.GetAuthGenericRecordInfo();
recordInfo.LastModified = DateTime.UtcNow;
recordInfo.Comments = "Via Dynamic Updates (RFC 2136)";
_authZoneManager.SetRecord(zoneInfo.Name, uRecord); _authZoneManager.SetRecord(zoneInfo.Name, uRecord);
} }
else else
@@ -1728,6 +1740,10 @@ namespace DnsServerCore.Dns
if (uRecord.Type == DnsResourceRecordType.NS) if (uRecord.Type == DnsResourceRecordType.NS)
uRecord.SyncGlueRecords(request.Additional); uRecord.SyncGlueRecords(request.Additional);
GenericRecordInfo recordInfo = uRecord.GetAuthGenericRecordInfo();
recordInfo.LastModified = DateTime.UtcNow;
recordInfo.Comments = "Via Dynamic Updates (RFC 2136)";
_authZoneManager.AddRecord(zoneInfo.Name, uRecord); _authZoneManager.AddRecord(zoneInfo.Name, uRecord);
} }
} }
@@ -3707,10 +3723,14 @@ namespace DnsServerCore.Dns
//get a tailored response for the request //get a tailored response for the request
bool dnssecOk = request.DnssecOk; bool dnssecOk = request.DnssecOk;
if (dnssecOk && request.CheckingDisabled) if (request.CheckingDisabled)
{ {
DnsDatagram cdResponse = resolveResponse.CheckingDisabledResponse; DnsDatagram cdResponse = resolveResponse.CheckingDisabledResponse;
bool authenticData = false; bool authenticData = false;
IReadOnlyList<DnsResourceRecord> cdAnswer;
IReadOnlyList<DnsResourceRecord> cdAuthority;
IReadOnlyList<DnsResourceRecord> cdAdditional = RemoveOPTFromAdditional(cdResponse.Additional, dnssecOk);
EDnsHeaderFlags ednsFlags;
if (dnssecOk) if (dnssecOk)
{ {
@@ -3740,9 +3760,19 @@ namespace DnsServerCore.Dns
} }
} }
} }
cdAnswer = cdResponse.Answer;
cdAuthority = cdResponse.Authority;
ednsFlags = EDnsHeaderFlags.DNSSEC_OK;
}
else
{
cdAnswer = FilterDnssecRecords(cdResponse.Answer);
cdAuthority = FilterDnssecRecords(cdResponse.Authority);
ednsFlags = EDnsHeaderFlags.None;
} }
DnsDatagram finalCdResponse = new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, authenticData, true, cdResponse.RCODE, request.Question, cdResponse.Answer, cdResponse.Authority, RemoveOPTFromAdditional(cdResponse.Additional, true), _udpPayloadSize, EDnsHeaderFlags.DNSSEC_OK, cdResponse.EDNS?.Options); DnsDatagram finalCdResponse = new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, true, true, authenticData, true, cdResponse.RCODE, request.Question, cdAnswer, cdAuthority, cdAdditional, _udpPayloadSize, ednsFlags, cdResponse.EDNS?.Options);
DnsDatagramMetadata metadata = cdResponse.Metadata; DnsDatagramMetadata metadata = cdResponse.Metadata;
if (metadata is not null) if (metadata is not null)
finalCdResponse.SetMetadata(metadata.NameServer, metadata.RoundTripTime); finalCdResponse.SetMetadata(metadata.NameServer, metadata.RoundTripTime);
@@ -3971,6 +4001,39 @@ namespace DnsServerCore.Dns
} }
} }
private static IReadOnlyList<DnsResourceRecord> FilterDnssecRecords(IReadOnlyList<DnsResourceRecord> records)
{
foreach (DnsResourceRecord record1 in records)
{
switch (record1.Type)
{
case DnsResourceRecordType.RRSIG:
case DnsResourceRecordType.NSEC:
case DnsResourceRecordType.NSEC3:
List<DnsResourceRecord> noDnssecRecords = new List<DnsResourceRecord>();
foreach (DnsResourceRecord record2 in records)
{
switch (record2.Type)
{
case DnsResourceRecordType.RRSIG:
case DnsResourceRecordType.NSEC:
case DnsResourceRecordType.NSEC3:
break;
default:
noDnssecRecords.Add(record2);
break;
}
}
return noDnssecRecords;
}
}
return records;
}
private static IReadOnlyList<DnsResourceRecord> RemoveOPTFromAdditional(IReadOnlyList<DnsResourceRecord> additional, bool dnssecOk) private static IReadOnlyList<DnsResourceRecord> RemoveOPTFromAdditional(IReadOnlyList<DnsResourceRecord> additional, bool dnssecOk)
{ {
if (additional.Count == 0) if (additional.Count == 0)
@@ -4521,7 +4584,7 @@ namespace DnsServerCore.Dns
private async Task StartDoHAsync() private async Task StartDoHAsync()
{ {
IReadOnlyList<IPAddress> localAddresses = GetValidKestralLocalAddresses(_localEndPoints.Convert(delegate (IPEndPoint ep) { return ep.Address; })); IReadOnlyList<IPAddress> localAddresses = WebUtilities.GetValidKestralLocalAddresses(_localEndPoints.Convert(delegate (IPEndPoint ep) { return ep.Address; }));
try try
{ {
@@ -4668,89 +4731,6 @@ namespace DnsServerCore.Dns
} }
} }
internal static IReadOnlyList<IPAddress> GetValidKestralLocalAddresses(IReadOnlyList<IPAddress> localAddresses)
{
List<IPAddress> supportedLocalAddresses = new List<IPAddress>(localAddresses.Count);
foreach (IPAddress localAddress in localAddresses)
{
switch (localAddress.AddressFamily)
{
case AddressFamily.InterNetwork:
if (Socket.OSSupportsIPv4)
{
if (!supportedLocalAddresses.Contains(localAddress))
supportedLocalAddresses.Add(localAddress);
}
break;
case AddressFamily.InterNetworkV6:
if (Socket.OSSupportsIPv6)
{
if (!supportedLocalAddresses.Contains(localAddress))
supportedLocalAddresses.Add(localAddress);
}
break;
}
}
bool containsUnicastAddress = false;
foreach (IPAddress localAddress in supportedLocalAddresses)
{
if (!localAddress.Equals(IPAddress.Any) && !localAddress.Equals(IPAddress.IPv6Any))
{
containsUnicastAddress = true;
break;
}
}
List<IPAddress> newLocalAddresses = new List<IPAddress>(supportedLocalAddresses.Count);
if (containsUnicastAddress)
{
//replace any with loopback address
foreach (IPAddress localAddress in supportedLocalAddresses)
{
if (localAddress.Equals(IPAddress.Any))
{
if (!newLocalAddresses.Contains(IPAddress.Loopback))
newLocalAddresses.Add(IPAddress.Loopback);
}
else if (localAddress.Equals(IPAddress.IPv6Any))
{
if (!newLocalAddresses.Contains(IPAddress.IPv6Loopback))
newLocalAddresses.Add(IPAddress.IPv6Loopback);
}
else
{
if (!newLocalAddresses.Contains(localAddress))
newLocalAddresses.Add(localAddress);
}
}
}
else
{
//remove "0.0.0.0" if [::] exists
foreach (IPAddress localAddress in supportedLocalAddresses)
{
if (localAddress.Equals(IPAddress.Any))
{
if (!supportedLocalAddresses.Contains(IPAddress.IPv6Any))
newLocalAddresses.Add(localAddress);
}
else
{
newLocalAddresses.Add(localAddress);
}
}
}
return newLocalAddresses;
}
#endregion #endregion
#region public #region public