Skip to content

Commit

Permalink
Add support for the Replicate inference provider
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c committed Dec 17, 2024
1 parent d96a18e commit aa7d1ca
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 16 deletions.
13 changes: 12 additions & 1 deletion packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
Expand Down Expand Up @@ -64,6 +65,9 @@ export async function makeRequestOptions(
throw new Error("Specifying an Inference provider requires an accessToken");
}
switch (provider) {
case "replicate":
model = REPLICATE_MODEL_IDS[model];
break;
case "sambanova":
model = SAMBANOVA_MODEL_IDS[model];
break;
Expand All @@ -90,6 +94,9 @@ export async function makeRequestOptions(
if (dont_load_model) {
headers["X-Load-Model"] = "0";
}
if (provider === "replicate") {
headers["Prefer"] = "wait";
}

let url = (() => {
if (endpointUrl && isUrl(model)) {
Expand All @@ -115,6 +122,8 @@ export async function makeRequestOptions(
} else {
/// This is an external key
switch (provider) {
case "replicate":
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
case "sambanova":
return SAMBANOVA_API_BASE_URL;
case "together":
Expand Down Expand Up @@ -151,7 +160,9 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : { ...otherArgs, model }),
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate"
? omit(otherArgs, "model")
: { ...otherArgs, model }),
}),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
Expand Down
18 changes: 18 additions & 0 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import type { ModelId } from "../types";

/** Root URL of the Replicate HTTP API; request paths (e.g. /v1/models/.../predictions) are appended to it. */
export const REPLICATE_API_BASE_URL = "https://api.replicate.com";

/**
 * A model identifier in the format expected by the Replicate API
 * (e.g. "black-forest-labs/flux-schnell"), as opposed to the
 * Hugging Face Hub `ModelId` it is mapped from.
 *
 * NOTE(review): mirrors the provider-id alias declared in sambanova.ts — keep the two in sync.
 */
type ReplicateId = string;

/**
 * Mapping from Hugging Face Hub model ids to the corresponding Replicate model ids.
 * Only models listed here can be routed to Replicate.
 *
 * The models available on Replicate can be listed with:
 * curl -s \
 * -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
 * https://api.replicate.com/v1/models
 */
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
};
8 changes: 7 additions & 1 deletion packages/inference/src/tasks/custom/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,19 @@ export async function request<T>(
}

if (!response.ok) {
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
if (
["application/json", "application/problem+json"].some(
(contentType) => response.headers.get("Content-Type")?.startsWith(contentType)
)
) {
const output = await response.json();
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
}
if (output.error) {
throw new Error(JSON.stringify(output.error));
} else {
throw new Error(output);
}
}
throw new Error("An error occurred while fetching the blob");
Expand Down
35 changes: 25 additions & 10 deletions packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ export type TextToImageArgs = BaseArgs & {
inputs: string;

/**
* Same param but for external providers like Together
* Same param but for external providers like Together, Replicate
*/
prompt?: string;
response_format?: "base64";
input?: {
prompt: string;
};

parameters?: {
/**
Expand All @@ -38,14 +41,16 @@ export type TextToImageArgs = BaseArgs & {
};
};

export type TextToImageOutput = Blob;

interface Base64ImageGeneration {
id: string;
model: string;
data: Array<{
b64_json: string;
}>;
}
export type TextToImageOutput = Blob;
interface OutputUrlImageGeneration {
output: string[];
}

/**
* This task reads some text input and outputs an image.
Expand All @@ -56,16 +61,26 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
args.prompt = args.inputs;
args.inputs = "";
args.response_format = "base64";
} else if (args.provider === "replicate") {
args.input = { prompt: args.inputs };
delete (args as unknown as { inputs: unknown }).inputs;
}
const res = await request<TextToImageOutput | Base64ImageGeneration>(args, {
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
...options,
taskHint: "text-to-image",
});
if (res && typeof res === "object" && Array.isArray(res.data) && res.data[0].b64_json) {
const base64Data = res.data[0].b64_json;
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
const blob = await base64Response.blob();
return blob;
if (res && typeof res === "object") {
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
const base64Data = res.data[0].b64_json;
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
const blob = await base64Response.blob();
return blob;
}
if ("output" in res && Array.isArray(res.output)) {
const urlResponse = await fetch(res.output[0]);
const blob = await urlResponse.blob();
return blob;
}
}
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
Expand Down
8 changes: 4 additions & 4 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export interface Options {

export type InferenceTask = Exclude<PipelineType, "other">;

export const INFERENCE_PROVIDERS = ["sambanova", "together", "hf-inference"] as const;
export const INFERENCE_PROVIDERS = ["replicate", "sambanova", "together", "hf-inference"] as const;
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

export interface BaseArgs {
Expand All @@ -54,19 +54,19 @@ export interface BaseArgs {
*
* Can be created for free in hf.co/settings/token
*
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together...
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together, Replicate...
*/
accessToken?: string;

/**
* The model to use.
* The HF model to use.
*
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
*
* /!\ Legacy behavior allows this to be an URL, but this is deprecated and will be removed in the future.
* Use the `endpointUrl` parameter instead.
*/
model?: string;
model?: ModelId;

/**
* The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
Expand Down
10 changes: 10 additions & 0 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,16 @@ describe.concurrent(
});
expect(res).toBeInstanceOf(Blob);
});

it("textToImage replicate", async () => {
const hf = new HfInference(env.REPLICATE_KEY);
const res = await hf.textToImage({
model: "black-forest-labs/FLUX.1-schnell",
provider: "replicate",
inputs: "black forest gateau cake spelling out the words FLUX SCHNELL, tasty, food photography, dynamic shot",
});
expect(res).toBeInstanceOf(Blob);
});
},
TIMEOUT
);

0 comments on commit aa7d1ca

Please sign in to comment.