From 76ab8c8726f8bb7b69f935b871d7d018a8b634ed Mon Sep 17 00:00:00 2001 From: Erwan d'Orgeville Date: Fri, 24 May 2024 14:41:32 -0400 Subject: [PATCH] feat: add `createContext` support & associated test --- src/adapter/index.ts | 12 +++++++----- test/appRouter.ts | 11 ++++++++++- test/factory.ts | 5 +++-- test/index.test.ts | 9 +++++++++ 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/adapter/index.ts b/src/adapter/index.ts index b73abee..15165c5 100644 --- a/src/adapter/index.ts +++ b/src/adapter/index.ts @@ -24,12 +24,13 @@ export type CreateMQTTHandlerOptions = { router: TRouter; onError?: OnErrorFunction; verbose?: boolean; + createContext?: () => Promise>; }; export const createMQTTHandler = ( opts: CreateMQTTHandlerOptions ) => { - const { client, requestTopic: requestTopic, router, onError, verbose } = opts; + const { client, requestTopic: requestTopic, router, onError, verbose, createContext } = opts; const protocolVersion = client.options.protocolVersion ?? 4; client.subscribe(requestTopic); @@ -43,7 +44,7 @@ export const createMQTTHandler = ( const correlationId = packet.properties?.correlationData?.toString(); const responseTopic = packet.properties?.responseTopic?.toString(); if (!correlationId || !responseTopic) return; - const res = await handleMessage(router, msg, onError); + const res = await handleMessage(router, msg, onError, createContext); if (!res) return; client.publish(responseTopic, Buffer.from(JSON.stringify({ trpc: res })), { properties: { @@ -62,7 +63,7 @@ export const createMQTTHandler = ( return; } if (!correlationId || !responseTopic) return; - const res = await handleMessage(router, msg, onError); + const res = await handleMessage(router, msg, onError, createContext); if (!res) return; client.publish(responseTopic, Buffer.from(JSON.stringify({ trpc: res, correlationId }))); } @@ -72,7 +73,8 @@ export const createMQTTHandler = ( async function handleMessage( router: TRouter, msg: ConsumeMessage, - onError?: OnErrorFunction + onError?: OnErrorFunction, + createContext?: () => Promise> ) { const { transformer } = router._def._config; @@ -85,7 +87,7 @@ async function handleMessage( const { id, params } = trpc; const type = MQTT_METHOD_PROCEDURE_TYPE_MAP[trpc.method] ?? ('query' as const); - const ctx: inferRouterContext | undefined = undefined; + const ctx: inferRouterContext | undefined = await createContext?.(); try { const path = params.path; diff --git a/test/appRouter.ts b/test/appRouter.ts index fe00d48..a22df5e 100644 --- a/test/appRouter.ts +++ b/test/appRouter.ts @@ -2,7 +2,13 @@ import { initTRPC } from '@trpc/server'; export type AppRouter = typeof appRouter; -const t = initTRPC.create(); +export async function createContext() { + return { hello: 'world' }; +} + +export type Context = Awaited>; + +const t = initTRPC.context().create(); const publicProcedure = t.procedure; const router = t.router; @@ -28,5 +34,8 @@ export const appRouter = router({ slow: publicProcedure.query(async () => { await new Promise(resolve => setTimeout(resolve, 10 * 1000)); return 'done'; + }), + getContext: publicProcedure.query(({ ctx }) => { + return ctx; }) }); diff --git a/test/factory.ts b/test/factory.ts index b017a7c..6c6057d 100644 --- a/test/factory.ts +++ b/test/factory.ts @@ -6,7 +6,7 @@ import { createServer } from 'net'; import { createMQTTHandler } from '../src/adapter'; import { mqttLink } from '../src/link'; -import { AppRouter, appRouter } from './appRouter'; +import { type AppRouter, appRouter, createContext } from './appRouter'; export function factory() { const requestTopic = 'rpc/request'; @@ -20,7 +20,8 @@ export function factory() { createMQTTHandler({ client: mqttClient, requestTopic, - router: appRouter + router: appRouter, + createContext }); const client = createTRPCProxyClient({ diff --git a/test/index.test.ts b/test/index.test.ts index 961fda6..b39187f 100644 --- a/test/index.test.ts +++ b/test/index.test.ts @@ -48,3 +48,12 @@ describe('procedures', () => { }); }); }); + +describe('context', () => { + test('getContext query', async () => { + await withFactory(async ({ client }) => { + const ctx = await client.getContext.query(); + expect(ctx).toEqual({ hello: 'world' }); + }); + }); +});