diff --git a/src/Microsoft.AspNetCore.NodeServices/HostingModels/HttpNodeInstance.cs b/src/Microsoft.AspNetCore.NodeServices/HostingModels/HttpNodeInstance.cs index 0ab2782..0c11371 100644 --- a/src/Microsoft.AspNetCore.NodeServices/HostingModels/HttpNodeInstance.cs +++ b/src/Microsoft.AspNetCore.NodeServices/HostingModels/HttpNodeInstance.cs @@ -4,6 +4,7 @@ using System.IO; using System.Net.Http; using System.Text; using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Newtonsoft.Json; @@ -57,15 +58,17 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels return $"--port {port}"; } - protected override async Task InvokeExportAsync(NodeInvocationInfo invocationInfo) + protected override async Task InvokeExportAsync( + NodeInvocationInfo invocationInfo, CancellationToken cancellationToken) { var payloadJson = JsonConvert.SerializeObject(invocationInfo, jsonSerializerSettings); var payload = new StringContent(payloadJson, Encoding.UTF8, "application/json"); - var response = await _client.PostAsync("http://localhost:" + _portNumber, payload); + var response = await _client.PostAsync("http://localhost:" + _portNumber, payload, cancellationToken); if (!response.IsSuccessStatusCode) { - var responseErrorString = await response.Content.ReadAsStringAsync(); + // Unfortunately there's no true way to cancel ReadAsStringAsync calls, hence AbandonIfCancelled + var responseErrorString = await response.Content.ReadAsStringAsync().OrThrowOnCancellation(cancellationToken); throw new Exception("Call to Node module failed with error: " + responseErrorString); } @@ -81,11 +84,11 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels typeof(T).FullName); } - var responseString = await response.Content.ReadAsStringAsync(); + var responseString = await response.Content.ReadAsStringAsync().OrThrowOnCancellation(cancellationToken); return (T)(object)responseString; case "application/json": - var responseJson = await response.Content.ReadAsStringAsync(); + var responseJson = await response.Content.ReadAsStringAsync().OrThrowOnCancellation(cancellationToken); return JsonConvert.DeserializeObject(responseJson, jsonSerializerSettings); case "application/octet-stream": @@ -97,7 +100,7 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels typeof(T).FullName + ". Instead you must use the generic type System.IO.Stream."); } - return (T)(object)(await response.Content.ReadAsStreamAsync()); + return (T)(object)(await response.Content.ReadAsStreamAsync().OrThrowOnCancellation(cancellationToken)); default: throw new InvalidOperationException("Unexpected response content type: " + responseContentType.MediaType); diff --git a/src/Microsoft.AspNetCore.NodeServices/HostingModels/INodeInstance.cs b/src/Microsoft.AspNetCore.NodeServices/HostingModels/INodeInstance.cs index cac69f2..68a4319 100644 --- a/src/Microsoft.AspNetCore.NodeServices/HostingModels/INodeInstance.cs +++ b/src/Microsoft.AspNetCore.NodeServices/HostingModels/INodeInstance.cs @@ -1,10 +1,11 @@ using System; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.NodeServices.HostingModels { public interface INodeInstance : IDisposable { - Task InvokeExportAsync(string moduleName, string exportNameOrNull, params object[] args); + Task InvokeExportAsync(CancellationToken cancellationToken, string moduleName, string exportNameOrNull, params object[] args); } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.NodeServices/HostingModels/OutOfProcessNodeInstance.cs b/src/Microsoft.AspNetCore.NodeServices/HostingModels/OutOfProcessNodeInstance.cs index 641c098..9a583d3 100644 --- a/src/Microsoft.AspNetCore.NodeServices/HostingModels/OutOfProcessNodeInstance.cs +++ b/src/Microsoft.AspNetCore.NodeServices/HostingModels/OutOfProcessNodeInstance.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -67,7 +68,8 @@ If you haven't yet installed node-inspector, you can do so as follows: ConnectToInputOutputStreams(); } - public async Task InvokeExportAsync(string moduleName, string exportNameOrNull, params object[] args) + public async Task InvokeExportAsync( + CancellationToken cancellationToken, string moduleName, string exportNameOrNull, params object[] args) { if (_nodeProcess.HasExited || _nodeProcessNeedsRestart) { @@ -79,15 +81,17 @@ If you haven't yet installed node-inspector, you can do so as follows: throw new NodeInvocationException(message, null, nodeInstanceUnavailable: true); } - // Wait until the connection is established. This will throw if the connection fails to initialize. - await _connectionIsReadySource.Task; + // Wait until the connection is established. This will throw if the connection fails to initialize, + // or if cancellation is requested first. Note that we can't really cancel the "establishing connection" + // task because that's shared with all callers, but we can stop waiting for it if this call is cancelled. + await _connectionIsReadySource.Task.OrThrowOnCancellation(cancellationToken); return await InvokeExportAsync(new NodeInvocationInfo { ModuleName = moduleName, ExportedFunctionName = exportNameOrNull, Args = args - }); + }, cancellationToken); } public void Dispose() @@ -96,7 +100,9 @@ If you haven't yet installed node-inspector, you can do so as follows: GC.SuppressFinalize(this); } - protected abstract Task InvokeExportAsync(NodeInvocationInfo invocationInfo); + protected abstract Task InvokeExportAsync( + NodeInvocationInfo invocationInfo, + CancellationToken cancellationToken); // This method is virtual, as it provides a way to override the NODE_PATH or the path to node.exe protected virtual ProcessStartInfo PrepareNodeProcessStartInfo( diff --git a/src/Microsoft.AspNetCore.NodeServices/HostingModels/SocketNodeInstance.cs b/src/Microsoft.AspNetCore.NodeServices/HostingModels/SocketNodeInstance.cs index 2acfb78..2acfc0e 100644 --- a/src/Microsoft.AspNetCore.NodeServices/HostingModels/SocketNodeInstance.cs +++ b/src/Microsoft.AspNetCore.NodeServices/HostingModels/SocketNodeInstance.cs @@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels _socketAddress = socketAddress; } - protected override async Task InvokeExportAsync(NodeInvocationInfo invocationInfo) + protected override async Task InvokeExportAsync(NodeInvocationInfo invocationInfo, CancellationToken cancellationToken) { if (_connectionHasFailed) { @@ -70,7 +70,12 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels if (_virtualConnectionClient == null) { - await EnsureVirtualConnectionClientCreated(); + // Although we could pass the cancellationToken into EnsureVirtualConnectionClientCreated and + // have it signal cancellations upstream, that would be a bad thing to do, because all callers + // wait for the same connection task. There's no reason why the first caller should have the + // special ability to cancel the connection process in a way that would affect subsequent + // callers. So, each caller just independently stops awaiting connection if that call is cancelled. + await EnsureVirtualConnectionClientCreated().OrThrowOnCancellation(cancellationToken); } // For each invocation, we open a new virtual connection. This gives an API equivalent to opening a new @@ -83,7 +88,7 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels virtualConnection = _virtualConnectionClient.OpenVirtualConnection(); // Send request - await WriteJsonLineAsync(virtualConnection, invocationInfo); + await WriteJsonLineAsync(virtualConnection, invocationInfo, cancellationToken); // Determine what kind of response format is expected if (typeof(T) == typeof(Stream)) @@ -96,7 +101,7 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels else { // Parse and return non-streamed JSON response - var response = await ReadJsonAsync>(virtualConnection); + var response = await ReadJsonAsync>(virtualConnection, cancellationToken); if (response.ErrorMessage != null) { throw new NodeInvocationException(response.ErrorMessage, response.ErrorDetails); @@ -163,27 +168,27 @@ namespace Microsoft.AspNetCore.NodeServices.HostingModels base.Dispose(disposing); } - private static async Task WriteJsonLineAsync(Stream stream, object serializableObject) + private static async Task WriteJsonLineAsync(Stream stream, object serializableObject, CancellationToken cancellationToken) { var json = JsonConvert.SerializeObject(serializableObject, jsonSerializerSettings); var bytes = Encoding.UTF8.GetBytes(json + '\n'); - await stream.WriteAsync(bytes, 0, bytes.Length); + await stream.WriteAsync(bytes, 0, bytes.Length, cancellationToken); } - private static async Task ReadJsonAsync(Stream stream) + private static async Task ReadJsonAsync(Stream stream, CancellationToken cancellationToken) { - var json = Encoding.UTF8.GetString(await ReadAllBytesAsync(stream)); + var json = Encoding.UTF8.GetString(await ReadAllBytesAsync(stream, cancellationToken)); return JsonConvert.DeserializeObject(json, jsonSerializerSettings); } - private static async Task ReadAllBytesAsync(Stream input) + private static async Task ReadAllBytesAsync(Stream input, CancellationToken cancellationToken) { byte[] buffer = new byte[16 * 1024]; using (var ms = new MemoryStream()) { int read; - while ((read = await input.ReadAsync(buffer, 0, buffer.Length)) > 0) + while ((read = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken)) > 0) { ms.Write(buffer, 0, read); } diff --git a/src/Microsoft.AspNetCore.NodeServices/INodeServices.cs b/src/Microsoft.AspNetCore.NodeServices/INodeServices.cs index 3aa09e3..fbc32d0 100644 --- a/src/Microsoft.AspNetCore.NodeServices/INodeServices.cs +++ b/src/Microsoft.AspNetCore.NodeServices/INodeServices.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.NodeServices @@ -6,8 +7,11 @@ namespace Microsoft.AspNetCore.NodeServices public interface INodeServices : IDisposable { Task InvokeAsync(string moduleName, params object[] args); + Task InvokeAsync(CancellationToken cancellationToken, string moduleName, params object[] args); Task InvokeExportAsync(string moduleName, string exportedFunctionName, params object[] args); + Task InvokeExportAsync(CancellationToken cancellationToken, string moduleName, string exportedFunctionName, params object[] args); + [Obsolete("Use InvokeAsync instead")] Task Invoke(string moduleName, params object[] args); diff --git a/src/Microsoft.AspNetCore.NodeServices/NodeServicesImpl.cs b/src/Microsoft.AspNetCore.NodeServices/NodeServicesImpl.cs index 037cc96..38d3055 100644 --- a/src/Microsoft.AspNetCore.NodeServices/NodeServicesImpl.cs +++ b/src/Microsoft.AspNetCore.NodeServices/NodeServicesImpl.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.NodeServices.HostingModels; @@ -34,19 +35,29 @@ namespace Microsoft.AspNetCore.NodeServices return InvokeExportAsync(moduleName, null, args); } - public Task InvokeExportAsync(string moduleName, string exportedFunctionName, params object[] args) + public Task InvokeAsync(CancellationToken cancellationToken, string moduleName, params object[] args) { - return InvokeExportWithPossibleRetryAsync(moduleName, exportedFunctionName, args, allowRetry: true); + return InvokeExportAsync(cancellationToken, moduleName, null, args); } - public async Task InvokeExportWithPossibleRetryAsync(string moduleName, string exportedFunctionName, object[] args, bool allowRetry) + public Task InvokeExportAsync(string moduleName, string exportedFunctionName, params object[] args) + { + return InvokeExportWithPossibleRetryAsync(moduleName, exportedFunctionName, args, /* allowRetry */ true, CancellationToken.None); + } + + public Task InvokeExportAsync(CancellationToken cancellationToken, string moduleName, string exportedFunctionName, params object[] args) + { + return InvokeExportWithPossibleRetryAsync(moduleName, exportedFunctionName, args, /* allowRetry */ true, cancellationToken); + } + + public async Task InvokeExportWithPossibleRetryAsync(string moduleName, string exportedFunctionName, object[] args, bool allowRetry, CancellationToken cancellationToken) { ThrowAnyOutstandingDelayedDisposalException(); var nodeInstance = GetOrCreateCurrentNodeInstance(); try { - return await nodeInstance.InvokeExportAsync(moduleName, exportedFunctionName, args); + return await nodeInstance.InvokeExportAsync(cancellationToken, moduleName, exportedFunctionName, args); } catch (NodeInvocationException ex) { @@ -69,7 +80,7 @@ namespace Microsoft.AspNetCore.NodeServices // One the next call, don't allow retries, because we could get into an infinite retry loop, or a long retry // loop that masks an underlying problem. A newly-created Node instance should be able to accept invocations, // or something more serious must be wrong. - return await InvokeExportWithPossibleRetryAsync(moduleName, exportedFunctionName, args, allowRetry: false); + return await InvokeExportWithPossibleRetryAsync(moduleName, exportedFunctionName, args, /* allowRetry */ false, cancellationToken); } else { diff --git a/src/Microsoft.AspNetCore.NodeServices/Util/TaskExtensions.cs b/src/Microsoft.AspNetCore.NodeServices/Util/TaskExtensions.cs new file mode 100644 index 0000000..75cfdb1 --- /dev/null +++ b/src/Microsoft.AspNetCore.NodeServices/Util/TaskExtensions.cs @@ -0,0 +1,30 @@ +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.NodeServices +{ + internal static class TaskExtensions + { + public static Task OrThrowOnCancellation(this Task task, CancellationToken cancellationToken) + { + return task.IsCompleted + ? task // If the task is already completed, no need to wrap it in a further layer of task + : task.ContinueWith( + _ => {}, // If the task completes, allow execution to continue + cancellationToken, + TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + } + + public static Task OrThrowOnCancellation(this Task task, CancellationToken cancellationToken) + { + return task.IsCompleted + ? task // If the task is already completed, no need to wrap it in a further layer of task + : task.ContinueWith( + t => t.Result, // If the task completes, pass through its result + cancellationToken, + TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + } + } +} \ No newline at end of file