diff --git a/src/__test__/trpcToOpenApi.test.ts b/src/__test__/trpcToOpenApi.test.ts index 3f7cc1a..b9e2aec 100644 --- a/src/__test__/trpcToOpenApi.test.ts +++ b/src/__test__/trpcToOpenApi.test.ts @@ -69,6 +69,7 @@ describe("trpcToOpenApi", () => { }, }, }, + components: {}, }); }); @@ -140,6 +141,7 @@ describe("trpcToOpenApi", () => { }, }, }, + components: {}, }); }); @@ -197,4 +199,56 @@ describe("trpcToOpenApi", () => { ]); }); }); + + describe("globalHeaders", () => { + it("includes headers in every endpoint", () => { + const t = initTRPC.create(); + const router = t.router({ + createThing: t.procedure + .input(z.object({ name: z.string() })) + .mutation(() => undefined), + getThing: t.procedure + .input(z.object({ name: z.string() })) + .query(() => undefined), + }); + + const openApiSpec = trpcToOpenApi({ + apiTitle: "My API", + apiVersion: "1.0", + basePath: "", + router, + globalHeaders: { + MyHeader: { + in: "header", + name: "X-My-Header", + schema: { type: "string" }, + required: false, + }, + }, + }); + + expect(openApiSpec.components).toEqual({ + parameters: { + MyHeader: { + in: "header", + name: "X-My-Header", + schema: { type: "string" }, + required: false, + }, + }, + }); + + const expectedHeaderReferences = [ + { $ref: "#/components/parameters/MyHeader" }, + ]; + + expect(openApiSpec.paths?.["/createThing"]?.post?.parameters).toEqual( + expectedHeaderReferences, + ); + + expect( + openApiSpec.paths?.["/getThing"]?.get?.parameters?.slice(1), + ).toEqual(expectedHeaderReferences); + }); + }); }); diff --git a/src/trpcToOpenApi.ts b/src/trpcToOpenApi.ts index a93657b..1730f77 100644 --- a/src/trpcToOpenApi.ts +++ b/src/trpcToOpenApi.ts @@ -4,7 +4,7 @@ import { type ProcedureType, type RouterRecord, } from "@trpc/server/unstable-core-do-not-import"; -import { type OpenAPIV3, type OpenAPIV3_1 } from "openapi-types"; +import { OpenAPIV3, type OpenAPIV3_1 } from "openapi-types"; import { type ZodSchema } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { OpenApiMeta } from "./meta.js"; @@ -16,19 +16,32 @@ export function trpcToOpenApi({ apiVersion, basePath, router, + globalHeaders, }: { apiTitle: string; apiVersion: string; basePath: string; router: AnyTRPCRouter; + globalHeaders?: Record; }): OpenAPIV3_1.Document { + const headerParameters = + globalHeaders != null + ? Object.keys(globalHeaders).map( + (headerKey): OpenAPIV3_1.ReferenceObject => ({ + $ref: `#/components/parameters/${headerKey}`, + }), + ) + : undefined; + return { openapi: "3.1.0", info: { title: apiTitle, version: apiVersion }, - paths: getPathsForRouterRecord( + paths: getPathsForRouterRecord({ basePath, - router._def.procedures as RouterRecord, - ), + routerRecord: router._def.procedures as RouterRecord, + additionalParameters: headerParameters, + }), + components: globalHeaders != null ? { parameters: globalHeaders } : {}, }; } @@ -41,10 +54,17 @@ const PROCEDURE_TYPE_HTTP_METHOD_MAP: Record< subscription: undefined, }; -function getPathsForRouterRecord( - basePath: string, - routerRecord: RouterRecord, -): OpenAPIV3_1.PathsObject { +function getPathsForRouterRecord({ + basePath, + routerRecord, + additionalParameters, +}: { + basePath: string; + routerRecord: RouterRecord; + additionalParameters: + | (OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.ParameterObject)[] + | undefined; +}): OpenAPIV3_1.PathsObject { const paths: OpenAPIV3_1.PathsObject = {}; for (const [procedureName, procedureOrRouterRecord] of entries( @@ -57,8 +77,13 @@ function getPathsForRouterRecord( basePath, procedureName: String(procedureName), procedure: procedureOrRouterRecord, + additionalParameters, }) - : getPathsForRouterRecord(basePath, procedureOrRouterRecord), + : getPathsForRouterRecord({ + basePath, + routerRecord: procedureOrRouterRecord, + additionalParameters, + }), ); } @@ -69,10 +94,14 @@ function getPathsForProcedure({ basePath, procedureName, procedure, + additionalParameters, }: { basePath: string; procedureName: string; procedure: AnyProcedure; + additionalParameters: + | (OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.ParameterObject)[] + | undefined; }): OpenAPIV3_1.PathsObject { const def = procedure._def as unknown as AnyProcedure["_def"] & ProcedureBuilderDef; @@ -108,6 +137,13 @@ function getPathsForProcedure({ content, }; } + + if (additionalParameters != null) { + operation.parameters = [ + ...(operation.parameters ?? []), + ...(additionalParameters ?? []), + ]; + } } return {