DnsServer: refactored the DirectQueryAsync() method to return the ProcessQueryAsync() task directly. Minor refactoring done.

This commit is contained in:
Shreyas Zare
2021-06-19 14:15:52 +05:30
parent c45bf39435
commit dd09a9f477

View File

@@ -378,35 +378,32 @@ namespace DnsServerCore.Dns
}
//send response
if (response is not null)
byte[] sendBuffer = new byte[512];
using (MemoryStream sendBufferStream = new MemoryStream(sendBuffer))
{
byte[] sendBuffer = new byte[512];
using (MemoryStream sendBufferStream = new MemoryStream(sendBuffer))
try
{
try
{
response.WriteToUdp(sendBufferStream);
}
catch (NotSupportedException)
{
response = new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, true, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question) { Tag = StatsResponseType.Authoritative };
response.WriteToUdp(sendBufferStream);
}
catch (NotSupportedException)
{
response = new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, true, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question) { Tag = StatsResponseType.Authoritative };
sendBufferStream.Position = 0;
response.WriteToUdp(sendBufferStream);
}
//send dns datagram async
await udpListener.SendToAsync(new ArraySegment<byte>(sendBuffer, 0, (int)sendBufferStream.Position), SocketFlags.None, remoteEP);
sendBufferStream.Position = 0;
response.WriteToUdp(sendBufferStream);
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, DnsTransportProtocol.Udp, request, response);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(response, remoteEP.Address);
//send dns datagram async
await udpListener.SendToAsync(new ArraySegment<byte>(sendBuffer, 0, (int)sendBufferStream.Position), SocketFlags.None, remoteEP);
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, DnsTransportProtocol.Udp, request, response);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(response, remoteEP.Address);
}
catch (Exception ex)
{
@@ -582,43 +579,36 @@ namespace DnsServerCore.Dns
else
{
//format error
response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question) { Tag = StatsResponseType.Authoritative };
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, request, response);
if (!(request.ParsingException is IOException))
{
LogManager log = _log;
if (log is not null)
log.Write(remoteEP, protocol, request.ParsingException);
}
response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question) { Tag = StatsResponseType.Authoritative };
}
//send response
if (response is not null)
await writeSemaphore.WaitAsync();
try
{
await writeSemaphore.WaitAsync();
try
{
//send dns datagram
await response.WriteToTcpAsync(stream, writeBuffer);
await stream.FlushAsync();
}
finally
{
writeSemaphore.Release();
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, request, response);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(response, remoteEP.Address);
//send dns datagram
await response.WriteToTcpAsync(stream, writeBuffer);
await stream.FlushAsync();
}
finally
{
writeSemaphore.Release();
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, request, response);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(response, remoteEP.Address);
}
catch (IOException)
{
@@ -776,24 +766,21 @@ namespace DnsServerCore.Dns
dnsResponse = new DnsDatagram(dnsRequest.Identifier, true, dnsRequest.OPCODE, false, false, dnsRequest.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, dnsRequest.Question) { Tag = StatsResponseType.Authoritative };
}
if (dnsResponse is not null)
using (MemoryStream mS = new MemoryStream(512))
{
using (MemoryStream mS = new MemoryStream(512))
{
dnsResponse.WriteToUdp(mS);
dnsResponse.WriteToUdp(mS);
mS.Position = 0;
await SendContentAsync(stream, requestConnection, "application/dns-message", mS);
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(dnsResponse, remoteEP.Address);
mS.Position = 0;
await SendContentAsync(stream, requestConnection, "application/dns-message", mS);
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(dnsResponse, remoteEP.Address);
}
#endregion
break;
@@ -812,26 +799,24 @@ namespace DnsServerCore.Dns
dnsRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord(strName, (DnsResourceRecordType)int.Parse(strType), DnsClass.IN) });
DnsDatagram dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol);
if (dnsResponse is not null)
using (MemoryStream mS = new MemoryStream(512))
{
using (MemoryStream mS = new MemoryStream(512))
{
JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS));
dnsResponse.WriteToJson(jsonWriter);
jsonWriter.Flush();
JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS));
dnsResponse.WriteToJson(jsonWriter);
jsonWriter.Flush();
mS.Position = 0;
await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", mS);
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(dnsResponse, remoteEP.Address);
mS.Position = 0;
await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", mS);
}
LogManager queryLog = _queryLog;
if (queryLog is not null)
queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse);
StatsManager stats = _stats;
if (stats is not null)
stats.Update(dnsResponse, remoteEP.Address);
}
#endregion
break;
@@ -1015,7 +1000,7 @@ namespace DnsServerCore.Dns
private async Task<DnsDatagram> ProcessQueryAsync(DnsDatagram request, IPEndPoint remoteEP, bool isRecursionAllowed, DnsTransportProtocol protocol)
{
if (request.IsResponse)
return null;
return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.FormatError, request.Question) { Tag = StatsResponseType.Authoritative };
switch (request.OPCODE)
{
@@ -1997,6 +1982,7 @@ namespace DnsServerCore.Dns
{
case DnsResponseCode.NoError:
case DnsResponseCode.NxDomain:
case DnsResponseCode.YXDomain:
taskCompletionSource.SetResult(response);
break;
@@ -2850,30 +2836,9 @@ namespace DnsServerCore.Dns
_state = ServiceState.Stopped;
}
public async Task<DnsDatagram> DirectQueryAsync(DnsQuestionRecord question, int timeout = 2000)
public Task<DnsDatagram> DirectQueryAsync(DnsQuestionRecord question)
{
try
{
Task<DnsDatagram> task = ProcessQueryAsync(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp);
using (CancellationTokenSource timeoutCancellationTokenSource = new CancellationTokenSource())
{
if (await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token)) != task)
return null;
timeoutCancellationTokenSource.Cancel(); //stop delay task
}
return await task;
}
catch (Exception ex)
{
LogManager log = _log;
if (log is not null)
log.Write(ex);
return null;
}
return ProcessQueryAsync(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp);
}
Task<DnsDatagram> IDnsClient.ResolveAsync(DnsQuestionRecord question)