diff --git a/examples/sdk-core/README.md b/examples/sdk-core/README.md index d3b1059..aca4517 100644 --- a/examples/sdk-core/README.md +++ b/examples/sdk-core/README.md @@ -57,6 +57,7 @@ See `examples/nextjs-realtime` or `examples/react-vite` for runnable demos. - `realtime/live-avatar.ts` - Live avatar (audio-driven avatar with playAudio or mic input) - `realtime/connection-events.ts` - Handling connection state and errors - `realtime/prompt-update.ts` - Updating prompt dynamically +- `realtime/custom-model.ts` - Using a custom model definition (e.g., preview/experimental models) ## API Reference diff --git a/examples/sdk-core/realtime/custom-model.ts b/examples/sdk-core/realtime/custom-model.ts new file mode 100644 index 0000000..9e20dc6 --- /dev/null +++ b/examples/sdk-core/realtime/custom-model.ts @@ -0,0 +1,62 @@ +/** + * Custom Model Definition Example + * + * Demonstrates how to define and use a custom model that isn't + * built into the SDK. This is useful for preview/experimental models + * or private deployments. + * + * Browser-only example - requires WebRTC APIs + * See examples/nextjs-realtime or examples/react-vite for runnable demos + */ + +import { createDecartClient } from "@decartai/sdk"; +import type { CustomModelDefinition } from "@decartai/sdk"; + +async function main() { + // Define a custom model that isn't in the SDK's built-in registry. + // This works for any model that conforms to the CustomModelDefinition shape. + const lucy2RtPreview: CustomModelDefinition = { + name: "lucy_2_rt_preview", + urlPath: "/v1/stream", + fps: 20, + width: 1280, + height: 720, + }; + + // Get webcam stream using the custom model's settings + const stream = await navigator.mediaDevices.getUserMedia({ + audio: true, + video: { + frameRate: lucy2RtPreview.fps, + width: lucy2RtPreview.width, + height: lucy2RtPreview.height, + }, + }); + + const client = createDecartClient({ + apiKey: process.env.DECART_API_KEY!, + }); + + // Pass the custom model directly to realtime.connect() + const realtimeClient = await client.realtime.connect(stream, { + model: lucy2RtPreview, + onRemoteStream: (transformedStream) => { + const video = document.getElementById("output") as HTMLVideoElement; + video.srcObject = transformedStream; + }, + initialState: { + prompt: { + text: "cinematic lighting, film grain", + enhance: true, + }, + }, + }); + + console.log("Session ID:", realtimeClient.sessionId); + console.log("Connected:", realtimeClient.isConnected()); + + // Update prompt dynamically, same as built-in models + realtimeClient.setPrompt("watercolor painting style"); +} + +main(); diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index ddbbbca..ddcc3c5 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -47,6 +47,7 @@ export type { export type { ConnectionState } from "./realtime/types"; export type { WebRTCStats } from "./realtime/webrtc-stats"; export { + type CustomModelDefinition, type ImageModelDefinition, type ImageModels, isImageModel, diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index 7a5391a..f1869b0 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { modelDefinitionSchema, type RealTimeModels } from "../shared/model"; +import { type CustomModelDefinition, type ModelDefinition, modelDefinitionSchema } from "../shared/model"; import { modelStateSchema } from "../shared/types"; import { classifyWebrtcError, type DecartSDKError } from "../utils/errors"; import type { Logger } from "../utils/logger"; @@ -94,7 +94,9 @@ const realTimeClientConnectOptionsSchema = z.object({ initialState: realTimeClientInitialStateSchema.optional(), customizeOffer: createAsyncFunctionSchema(z.function()).optional(), }); -export type RealTimeClientConnectOptions = z.infer; +export type RealTimeClientConnectOptions = Omit, "model"> & { + model: ModelDefinition | CustomModelDefinition; +}; export type Events = { connectionChange: ConnectionState; @@ -189,7 +191,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { customizeOffer: options.customizeOffer as ((offer: RTCSessionDescriptionInit) => Promise) | undefined, vp8MinBitrate: 300, vp8StartBitrate: 600, - modelName: options.model.name as RealTimeModels, + modelName: options.model.name, initialImage, initialPrompt, }); diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts index 908b76f..d6f8139 100644 --- a/packages/sdk/src/realtime/webrtc-connection.ts +++ b/packages/sdk/src/realtime/webrtc-connection.ts @@ -1,5 +1,5 @@ import mitt from "mitt"; -import type { RealTimeModels } from "../shared/model"; + import type { Logger } from "../utils/logger"; import { buildUserAgent } from "../utils/user-agent"; import type { DiagnosticEmitter, IceCandidateEvent } from "./diagnostics"; @@ -24,7 +24,7 @@ interface ConnectionCallbacks { customizeOffer?: (offer: RTCSessionDescriptionInit) => Promise; vp8MinBitrate?: number; vp8StartBitrate?: number; - modelName?: RealTimeModels; + modelName?: string; initialImage?: string; initialPrompt?: { text: string; enhance?: boolean }; logger?: Logger; diff --git a/packages/sdk/src/realtime/webrtc-manager.ts b/packages/sdk/src/realtime/webrtc-manager.ts index 5ebdadf..71408fb 100644 --- a/packages/sdk/src/realtime/webrtc-manager.ts +++ b/packages/sdk/src/realtime/webrtc-manager.ts @@ -1,5 +1,5 @@ import pRetry, { AbortError } from "p-retry"; -import type { RealTimeModels } from "../shared/model"; + import type { Logger } from "../utils/logger"; import type { DiagnosticEmitter } from "./diagnostics"; import type { ConnectionState, OutgoingMessage } from "./types"; @@ -16,7 +16,7 @@ export interface WebRTCConfig { customizeOffer?: (offer: RTCSessionDescriptionInit) => Promise; vp8MinBitrate?: number; vp8StartBitrate?: number; - modelName?: RealTimeModels; + modelName?: string; initialImage?: string; initialPrompt?: { text: string; enhance?: boolean }; } diff --git a/packages/sdk/src/shared/model.ts b/packages/sdk/src/shared/model.ts index b3dcef0..d2f6d4b 100644 --- a/packages/sdk/src/shared/model.ts +++ b/packages/sdk/src/shared/model.ts @@ -208,6 +208,15 @@ export type ModelDefinition = { inputSchema: T extends keyof ModelInputSchemas ? ModelInputSchemas[T] : z.ZodTypeAny; }; +/** + * A model definition with an arbitrary (non-registry) model name. + * Use this when providing your own model configuration. + */ +export type CustomModelDefinition = Omit & { + name: string; + inputSchema?: z.ZodTypeAny; +}; + /** * Type alias for model definitions that support synchronous processing. * Only image models support the sync/process API. @@ -221,7 +230,7 @@ export type ImageModelDefinition = ModelDefinition; export type VideoModelDefinition = ModelDefinition; export const modelDefinitionSchema = z.object({ - name: modelSchema, + name: z.string(), urlPath: z.string(), queueUrlPath: z.string().optional(), fps: z.number().min(1), diff --git a/packages/sdk/tests/unit.test.ts b/packages/sdk/tests/unit.test.ts index 4a14ca6..111f89c 100644 --- a/packages/sdk/tests/unit.test.ts +++ b/packages/sdk/tests/unit.test.ts @@ -3117,3 +3117,33 @@ describe("VideoStall Diagnostic", () => { expect(event.data.durationMs).toBe(1500); }); }); + +describe("CustomModelDefinition", () => { + it("allows arbitrary model names in modelDefinitionSchema", async () => { + const { modelDefinitionSchema } = await import("../src/shared/model.js"); + + const customModel = { + name: "lucy_2_rt_preview", + urlPath: "/v1/stream", + fps: 20, + width: 1280, + height: 720, + }; + + const result = modelDefinitionSchema.safeParse(customModel); + expect(result.success).toBe(true); + }); + + it("rejects invalid custom model definitions", async () => { + const { modelDefinitionSchema } = await import("../src/shared/model.js"); + + const invalidModel = { + name: "my_custom_model", + urlPath: "/v1/stream", + // missing fps, width, height + }; + + const result = modelDefinitionSchema.safeParse(invalidModel); + expect(result.success).toBe(false); + }); +});