diff --git a/src/link/index.ts b/src/link/index.ts index a731f92..e3d45d8 100644 --- a/src/link/index.ts +++ b/src/link/index.ts @@ -17,6 +17,7 @@ export type TRPCMQTTLinkOptions = { export const mqttLink = ( opts: TRPCMQTTLinkOptions ): TRPCLink => { + // This runs once, at the initialization of the link return runtime => { const { client, @@ -50,39 +51,11 @@ export const mqttLink = ( } }); - const request = async (message: TRPCMQTTRequest) => - new Promise((resolve, reject) => { - const correlationId = randomUUID(); - const onTimeout = () => { - responseEmitter.off(correlationId, onMessage); - reject(new TRPCClientError('Request timed out after ' + requestTimeoutMs + 'ms')); - }; - const timeout = setTimeout(onTimeout, requestTimeoutMs); - const onMessage = (message: TRPCMQTTResponse) => { - clearTimeout(timeout); - resolve(message); - }; - responseEmitter.once(correlationId, onMessage); - if (protocolVersion >= 5) { - // MQTT 5.0+, use the correlationData & responseTopic field - const opts = { - properties: { - responseTopic, - correlationData: Buffer.from(correlationId) - } - }; - client.publish(requestTopic, JSON.stringify(message), opts); - } else { - // MQTT < 5.0, use the message itself - client.publish( - requestTopic, - JSON.stringify({ ...message, correlationId, responseTopic }) - ); - } - }); - return ({ op }) => { + // This runs every time a procedure is called return observable(observer => { + const abortController = new AbortController(); + const { id, type, path } = op; try { @@ -114,13 +87,55 @@ export const mqttLink = ( observer.complete(); }; - request({ - trpc: { - id, - method: type, - params: { path, input } - } - }) + const request = async (message: TRPCMQTTRequest, signal: AbortSignal) => + new Promise((resolve, reject) => { + const correlationId = randomUUID(); + const onTimeout = () => { + responseEmitter.off(correlationId, onMessage); + signal.onabort = null; + reject(new TRPCClientError('Request timed out after ' + requestTimeoutMs + 'ms')); + }; + const onAbort = () => { + // This runs when the request is aborted externally + clearTimeout(timeout); + responseEmitter.off(correlationId, onMessage); + reject(new TRPCClientError('Request aborted')); + }; + const timeout = setTimeout(onTimeout, requestTimeoutMs); + signal.onabort = onAbort; + const onMessage = (message: TRPCMQTTResponse) => { + clearTimeout(timeout); + resolve(message); + }; + responseEmitter.once(correlationId, onMessage); + if (protocolVersion >= 5) { + // MQTT 5.0+, use the correlationData & responseTopic field + const opts = { + properties: { + responseTopic, + correlationData: Buffer.from(correlationId) + } + }; + client.publish(requestTopic, JSON.stringify(message), opts); + } else { + // MQTT < 5.0, use the message itself + client.publish( + requestTopic, + JSON.stringify({ ...message, correlationId, responseTopic }) + ); + } + }); + + request( + { + trpc: { + id, + method: type, + params: { path, input } + } + }, + abortController.signal + ) .then(onMessage) .catch(cause => { observer.error( @@ -133,8 +148,10 @@ export const mqttLink = ( ); } - // eslint-disable-next-line @typescript-eslint/no-empty-function, prettier/prettier - return () => { }; + return () => { + // This runs after every procedure call, whether it was successful, unsuccessful, or externally aborted + abortController.abort(); + }; }); }; }; diff --git a/test/appRouter.ts b/test/appRouter.ts index cb8c42d..fe00d48 100644 --- a/test/appRouter.ts +++ b/test/appRouter.ts @@ -24,5 +24,9 @@ export const appRouter = router({ .mutation(({ input }) => { state.count += input; return state.count; - }) + }), + slow: publicProcedure.query(async () => { + await new Promise(resolve => setTimeout(resolve, 10 * 1000)); + return 'done'; + }) }); diff --git a/test/index.test.ts b/test/index.test.ts index a05586e..1293c17 100644 --- a/test/index.test.ts +++ b/test/index.test.ts @@ -57,6 +57,16 @@ test('countUp mutation', async () => { expect(addTwo).toBe(3); }); +test('abortSignal is handled', async () => { + const controller = new AbortController(); + const promise = client.slow.query(undefined, { + signal: controller.signal + }); + + controller.abort(); + await expect(promise).rejects.toThrow('aborted'); +}); + afterAll(async () => { mqttClient.end(); broker.close();