diff --git a/DnsServerCore/Dns/DnsServer.cs b/DnsServerCore/Dns/DnsServer.cs index 49a901b3..38f952b2 100644 --- a/DnsServerCore/Dns/DnsServer.cs +++ b/DnsServerCore/Dns/DnsServer.cs @@ -378,35 +378,32 @@ namespace DnsServerCore.Dns } //send response - if (response is not null) + byte[] sendBuffer = new byte[512]; + using (MemoryStream sendBufferStream = new MemoryStream(sendBuffer)) { - byte[] sendBuffer = new byte[512]; - using (MemoryStream sendBufferStream = new MemoryStream(sendBuffer)) + try { - try - { - response.WriteToUdp(sendBufferStream); - } - catch (NotSupportedException) - { - response = new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, true, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question) { Tag = StatsResponseType.Authoritative }; + response.WriteToUdp(sendBufferStream); + } + catch (NotSupportedException) + { + response = new DnsDatagram(response.Identifier, true, response.OPCODE, response.AuthoritativeAnswer, true, response.RecursionDesired, response.RecursionAvailable, response.AuthenticData, response.CheckingDisabled, response.RCODE, response.Question) { Tag = StatsResponseType.Authoritative }; - sendBufferStream.Position = 0; - response.WriteToUdp(sendBufferStream); - } - - //send dns datagram async - await udpListener.SendToAsync(new ArraySegment(sendBuffer, 0, (int)sendBufferStream.Position), SocketFlags.None, remoteEP); + sendBufferStream.Position = 0; + response.WriteToUdp(sendBufferStream); } - LogManager queryLog = _queryLog; - if (queryLog is not null) - queryLog.Write(remoteEP, DnsTransportProtocol.Udp, request, response); - - StatsManager stats = _stats; - if (stats is not null) - stats.Update(response, remoteEP.Address); + //send dns datagram async + await udpListener.SendToAsync(new ArraySegment(sendBuffer, 0, (int)sendBufferStream.Position), SocketFlags.None, remoteEP); } + + LogManager queryLog = _queryLog; + if (queryLog is not null) + queryLog.Write(remoteEP, DnsTransportProtocol.Udp, request, response); + + StatsManager stats = _stats; + if (stats is not null) + stats.Update(response, remoteEP.Address); } catch (Exception ex) { @@ -582,43 +579,36 @@ namespace DnsServerCore.Dns else { //format error - response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question) { Tag = StatsResponseType.Authoritative }; - - LogManager queryLog = _queryLog; - if (queryLog is not null) - queryLog.Write(remoteEP, protocol, request, response); - if (!(request.ParsingException is IOException)) { LogManager log = _log; if (log is not null) log.Write(remoteEP, protocol, request.ParsingException); } + + response = new DnsDatagram(request.Identifier, true, request.OPCODE, false, false, request.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, request.Question) { Tag = StatsResponseType.Authoritative }; } //send response - if (response is not null) + await writeSemaphore.WaitAsync(); + try { - await writeSemaphore.WaitAsync(); - try - { - //send dns datagram - await response.WriteToTcpAsync(stream, writeBuffer); - await stream.FlushAsync(); - } - finally - { - writeSemaphore.Release(); - } - - LogManager queryLog = _queryLog; - if (queryLog is not null) - queryLog.Write(remoteEP, protocol, request, response); - - StatsManager stats = _stats; - if (stats is not null) - stats.Update(response, remoteEP.Address); + //send dns datagram + await response.WriteToTcpAsync(stream, writeBuffer); + await stream.FlushAsync(); } + finally + { + writeSemaphore.Release(); + } + + LogManager queryLog = _queryLog; + if (queryLog is not null) + queryLog.Write(remoteEP, protocol, request, response); + + StatsManager stats = _stats; + if (stats is not null) + stats.Update(response, remoteEP.Address); } catch (IOException) { @@ -776,24 +766,21 @@ namespace DnsServerCore.Dns dnsResponse = new DnsDatagram(dnsRequest.Identifier, true, dnsRequest.OPCODE, false, false, dnsRequest.RecursionDesired, IsRecursionAllowed(remoteEP), false, false, DnsResponseCode.FormatError, dnsRequest.Question) { Tag = StatsResponseType.Authoritative }; } - if (dnsResponse is not null) + using (MemoryStream mS = new MemoryStream(512)) { - using (MemoryStream mS = new MemoryStream(512)) - { - dnsResponse.WriteToUdp(mS); + dnsResponse.WriteToUdp(mS); - mS.Position = 0; - await SendContentAsync(stream, requestConnection, "application/dns-message", mS); - } - - LogManager queryLog = _queryLog; - if (queryLog is not null) - queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse); - - StatsManager stats = _stats; - if (stats is not null) - stats.Update(dnsResponse, remoteEP.Address); + mS.Position = 0; + await SendContentAsync(stream, requestConnection, "application/dns-message", mS); } + + LogManager queryLog = _queryLog; + if (queryLog is not null) + queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse); + + StatsManager stats = _stats; + if (stats is not null) + stats.Update(dnsResponse, remoteEP.Address); } #endregion break; @@ -812,26 +799,24 @@ namespace DnsServerCore.Dns dnsRequest = new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { new DnsQuestionRecord(strName, (DnsResourceRecordType)int.Parse(strType), DnsClass.IN) }); DnsDatagram dnsResponse = await ProcessQueryAsync(dnsRequest, remoteEP, IsRecursionAllowed(remoteEP), protocol); - if (dnsResponse is not null) + + using (MemoryStream mS = new MemoryStream(512)) { - using (MemoryStream mS = new MemoryStream(512)) - { - JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); - dnsResponse.WriteToJson(jsonWriter); - jsonWriter.Flush(); + JsonTextWriter jsonWriter = new JsonTextWriter(new StreamWriter(mS)); + dnsResponse.WriteToJson(jsonWriter); + jsonWriter.Flush(); - mS.Position = 0; - await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", mS); - } - - LogManager queryLog = _queryLog; - if (queryLog is not null) - queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse); - - StatsManager stats = _stats; - if (stats is not null) - stats.Update(dnsResponse, remoteEP.Address); + mS.Position = 0; + await SendContentAsync(stream, requestConnection, "application/dns-json; charset=utf-8", mS); } + + LogManager queryLog = _queryLog; + if (queryLog is not null) + queryLog.Write(remoteEP, protocol, dnsRequest, dnsResponse); + + StatsManager stats = _stats; + if (stats is not null) + stats.Update(dnsResponse, remoteEP.Address); } #endregion break; @@ -1015,7 +1000,7 @@ namespace DnsServerCore.Dns private async Task ProcessQueryAsync(DnsDatagram request, IPEndPoint remoteEP, bool isRecursionAllowed, DnsTransportProtocol protocol) { if (request.IsResponse) - return null; + return new DnsDatagram(request.Identifier, true, DnsOpcode.StandardQuery, false, false, request.RecursionDesired, false, false, false, DnsResponseCode.FormatError, request.Question) { Tag = StatsResponseType.Authoritative }; switch (request.OPCODE) { @@ -1997,6 +1982,7 @@ namespace DnsServerCore.Dns { case DnsResponseCode.NoError: case DnsResponseCode.NxDomain: + case DnsResponseCode.YXDomain: taskCompletionSource.SetResult(response); break; @@ -2850,30 +2836,9 @@ namespace DnsServerCore.Dns _state = ServiceState.Stopped; } - public async Task DirectQueryAsync(DnsQuestionRecord question, int timeout = 2000) + public Task DirectQueryAsync(DnsQuestionRecord question) { - try - { - Task task = ProcessQueryAsync(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp); - - using (CancellationTokenSource timeoutCancellationTokenSource = new CancellationTokenSource()) - { - if (await Task.WhenAny(task, Task.Delay(timeout, timeoutCancellationTokenSource.Token)) != task) - return null; - - timeoutCancellationTokenSource.Cancel(); //stop delay task - } - - return await task; - } - catch (Exception ex) - { - LogManager log = _log; - if (log is not null) - log.Write(ex); - - return null; - } + return ProcessQueryAsync(new DnsDatagram(0, false, DnsOpcode.StandardQuery, false, false, true, false, false, false, DnsResponseCode.NoError, new DnsQuestionRecord[] { question }), new IPEndPoint(IPAddress.Any, 0), true, DnsTransportProtocol.Tcp); } Task IDnsClient.ResolveAsync(DnsQuestionRecord question)