From bf47b3b726d130142d3006788a29f5ea09620114 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Tue, 7 Jan 2025 15:24:26 +0100 Subject: [PATCH] add fal-ai as a provider --- .../inference/src/lib/makeRequestOptions.ts | 20 ++++++++++++------- .../inference/src/tasks/custom/request.ts | 6 +++--- .../inference/src/tasks/cv/textToImage.ts | 6 +++++- packages/inference/src/types.ts | 2 +- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index 62423ebd1..50b63647e 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,3 +1,4 @@ +import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai"; import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate"; import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova"; import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together"; @@ -9,7 +10,8 @@ import { isUrl } from "./isUrl"; const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co"; /** - * Loaded from huggingface.co/api/tasks if needed + * Lazy-loaded from huggingface.co/api/tasks when needed + * Used to determine the default model to use when it's not user defined */ let tasks: Record | null = null; @@ -36,7 +38,7 @@ export async function makeRequestOptions( const headers: Record = {}; if (accessToken) { - headers["Authorization"] = `Bearer ${accessToken}`; + headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`; } if (!model && !tasks && taskHint) { @@ -74,6 +76,9 @@ export async function makeRequestOptions( case "together": model = TOGETHER_MODEL_IDS[model]?.id ?? model; break; + case "fal-ai": + model = FAL_AI_MODEL_IDS[model]; + break; default: break; } @@ -120,8 +125,9 @@ export async function makeRequestOptions( /// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account. throw new Error("Inference proxying is not implemented yet"); } else { - /// This is an external key switch (provider) { + case 'fal-ai': + return `${FAL_AI_API_BASE_URL}/${model}`; case "replicate": return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`; case "sambanova": @@ -160,10 +166,10 @@ export async function makeRequestOptions( body: binary ? args.data : JSON.stringify({ - ...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" - ? omit(otherArgs, "model") - : { ...otherArgs, model }), - }), + ...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai" + ? omit(otherArgs, "model") + : { ...otherArgs, model }), + }), ...(credentials ? { credentials } : undefined), signal: options?.signal, }; diff --git a/packages/inference/src/tasks/custom/request.ts b/packages/inference/src/tasks/custom/request.ts index f23616c18..22067bb3a 100644 --- a/packages/inference/src/tasks/custom/request.ts +++ b/packages/inference/src/tasks/custom/request.ts @@ -2,7 +2,7 @@ import type { InferenceTask, Options, RequestArgs } from "../../types"; import { makeRequestOptions } from "../../lib/makeRequestOptions"; /** - * Primitive to make custom calls to Inference Endpoints + * Primitive to make custom calls to the inference provider */ export async function request( args: RequestArgs, @@ -35,8 +35,8 @@ export async function request( if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) { throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`); } - if (output.error) { - throw new Error(JSON.stringify(output.error)); + if (output.error || output.detail) { + throw new Error(JSON.stringify(output.error ?? output.detail)); } else { throw new Error(output); } diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts index 12a7a9995..d8527d653 100644 --- a/packages/inference/src/tasks/cv/textToImage.ts +++ b/packages/inference/src/tasks/cv/textToImage.ts @@ -57,7 +57,7 @@ interface OutputUrlImageGeneration { * Recommended model: stabilityai/stable-diffusion-2 */ export async function textToImage(args: TextToImageArgs, options?: Options): Promise { - if (args.provider === "together") { + if (args.provider === "together" || args.provider === "fal-ai") { args.prompt = args.inputs; args.inputs = ""; args.response_format = "base64"; @@ -70,6 +70,10 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro taskHint: "text-to-image", }); if (res && typeof res === "object") { + if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) { + const image = await fetch(res.images[0].url); + return await image.blob(); + } if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) { const base64Data = res.data[0].b64_json; const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`); diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index 33325bbd6..3b70538b3 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -45,7 +45,7 @@ export interface Options { export type InferenceTask = Exclude; -export const INFERENCE_PROVIDERS = ["replicate", "sambanova", "together", "hf-inference"] as const; +export const INFERENCE_PROVIDERS = ["fal-ai", "replicate", "sambanova", "together", "hf-inference"] as const; export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number]; export interface BaseArgs {