[Inference] Factor makeRequestOptions logic (#1107)
- Update the `makeRequestOptions` function to make the logic (hopefully)
  more readable, especially since #1077
- Stop supporting model URLs
- Update tapes
SBrandeis authored Jan 15, 2025
1 parent 16f3d68 commit 697e9be
Showing 8 changed files with 1,862 additions and 127 deletions.
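For orientation, here is a rough sketch of the call shape after this change. This is illustrative only, not part of the diff: the model IDs, token values, and the `inputs` payload are placeholders, and `makeRequestOptions` is an internal helper rather than a documented public API. The key points it shows: a plain model ID plus a `provider`, or a dedicated `endpointUrl`; passing a URL as `model` now throws.

```ts
// Hypothetical usage sketch — identifiers and values below are illustrative only.
import { makeRequestOptions } from "./lib/makeRequestOptions";

// Third-party provider: pass a plain HF model ID; it is mapped to the provider's own ID.
const { url, info } = await makeRequestOptions(
	{ model: "black-forest-labs/FLUX.1-schnell", provider: "fal-ai", accessToken: "key-xxx", inputs: "an astronaut" },
	{ taskHint: "text-to-image" }
);

// Dedicated endpoint: use `endpointUrl` — passing a URL as `model` now throws.
await makeRequestOptions(
	{ endpointUrl: "https://example.endpoints.huggingface.cloud", accessToken: "hf_xxx", inputs: "Hello" },
	{ taskHint: "text-generation", chatCompletion: true }
);
```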
packages/inference/src/lib/makeRequestOptions.ts (251 changes: 137 additions & 114 deletions)
```diff
@@ -2,8 +2,8 @@ import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
-import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
-import { omit } from "../utils/omit";
+import type { InferenceProvider } from "../types";
+import type { InferenceTask, Options, RequestArgs } from "../types";
 import { HF_HUB_URL } from "./getDefaultTask";
 import { isUrl } from "./isUrl";
@@ -31,62 +31,49 @@ export async function makeRequestOptions(
 		chatCompletion?: boolean;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { accessToken, endpointUrl, provider, ...otherArgs } = args;
-	let { model } = args;
+	const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
+	const provider = maybeProvider ?? "hf-inference";

 	const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
 		options ?? {};

-	const headers: Record<string, string> = {};
-	if (accessToken) {
-		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
-	}
-
-	if (!model && !tasks && taskHint) {
-		const res = await fetch(`${HF_HUB_URL}/api/tasks`);
-
-		if (res.ok) {
-			tasks = await res.json();
-		}
-	}
-
-	if (!model && tasks && taskHint) {
-		const taskInfo = tasks[taskHint];
-		if (taskInfo) {
-			model = taskInfo.models[0].id;
-		}
-	}
-
-	if (!model) {
-		throw new Error("No model provided, and no default model found for this task");
-	}
-	if (provider) {
-		if (!INFERENCE_PROVIDERS.includes(provider)) {
-			throw new Error("Unknown Inference provider");
-		}
-		if (!accessToken) {
-			throw new Error("Specifying an Inference provider requires an accessToken");
-		}
-	}
-
-	const modelId = (() => {
-		switch (provider) {
-			case "replicate":
-				return REPLICATE_MODEL_IDS[model];
-			case "sambanova":
-				return SAMBANOVA_MODEL_IDS[model];
-			case "together":
-				return TOGETHER_MODEL_IDS[model]?.id;
-			case "fal-ai":
-				return FAL_AI_MODEL_IDS[model];
-			default:
-				return model;
-		}
-	})();
-
-	if (!modelId) {
-		throw new Error(`Model ${model} is not supported for provider ${provider}`);
-	}
-
-	model = modelId;
+	if (endpointUrl && provider !== "hf-inference") {
+		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
+	}
+	if (forceTask && provider !== "hf-inference") {
+		throw new Error(`Cannot use forceTask with a third-party provider.`);
+	}
+	if (maybeModel && isUrl(maybeModel)) {
+		throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
+	}
+
+	let model: string;
+	if (!maybeModel) {
+		if (taskHint) {
+			model = mapModel({ model: await loadDefaultModel(taskHint), provider });
+		} else {
+			throw new Error("No model provided, and no default model found for this task");
+			/// TODO : change error message ^
+		}
+	} else {
+		model = mapModel({ model: maybeModel, provider });
+	}
+
+	const url = endpointUrl
+		? chatCompletion
+			? endpointUrl + `/v1/chat/completions`
+			: endpointUrl
+		: makeUrl({
+				model,
+				provider: provider ?? "hf-inference",
+				taskHint,
+				chatCompletion: chatCompletion ?? false,
+				forceTask,
+			});
+
+	const headers: Record<string, string> = {};
+	if (accessToken) {
+		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
+	}

 	const binary = "data" in args && !!args.data;
@@ -95,73 +82,20 @@ export async function makeRequestOptions(
 		headers["Content-Type"] = "application/json";
 	}

-	if (wait_for_model) {
-		headers["X-Wait-For-Model"] = "true";
-	}
-	if (use_cache === false) {
-		headers["X-Use-Cache"] = "false";
-	}
-	if (dont_load_model) {
-		headers["X-Load-Model"] = "0";
-	}
-	if (provider === "replicate") {
-		headers["Prefer"] = "wait";
-	}
-
-	let url = (() => {
-		if (endpointUrl && isUrl(model)) {
-			throw new TypeError("Both model and endpointUrl cannot be URLs");
-		}
-		if (isUrl(model)) {
-			console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
-			return model;
-		}
-		if (endpointUrl) {
-			return endpointUrl;
-		}
-		if (forceTask) {
-			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
-		}
-		if (provider) {
-			if (!accessToken) {
-				throw new Error("Specifying an Inference provider requires an accessToken");
-			}
-			if (accessToken.startsWith("hf_")) {
-				/// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
-				throw new Error("Inference proxying is not implemented yet");
-			} else {
-				switch (provider) {
-					case "fal-ai":
-						return `${FAL_AI_API_BASE_URL}/${model}`;
-					case "replicate":
-						if (model.includes(":")) {
-							// Versioned models are in the form of `owner/model:version`
-							return `${REPLICATE_API_BASE_URL}/v1/predictions`;
-						} else {
-							// Unversioned models are in the form of `owner/model`
-							return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
-						}
-					case "sambanova":
-						return SAMBANOVA_API_BASE_URL;
-					case "together":
-						if (taskHint === "text-to-image") {
-							return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
-						}
-						return TOGETHER_API_BASE_URL;
-					default:
-						break;
-				}
-			}
-		}
-
-		return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
-	})();
-
-	if (chatCompletion && !url.endsWith("/chat/completions")) {
-		url += "/v1/chat/completions";
-	}
-	if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
-		url += "/v1/completions";
-	}
+	if (provider === "hf-inference") {
+		if (wait_for_model) {
+			headers["X-Wait-For-Model"] = "true";
+		}
+		if (use_cache === false) {
+			headers["X-Use-Cache"] = "false";
+		}
+		if (dont_load_model) {
+			headers["X-Load-Model"] = "0";
+		}
+	}
+
+	if (provider === "replicate") {
+		headers["Prefer"] = "wait";
+	}

 	/**
Expand All @@ -188,13 +122,102 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
? omit(otherArgs, "model")
: { ...otherArgs, model }),
...otherArgs,
...(chatCompletion || provider === "together" ? { model } : undefined),
}),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
};

return { url, info };
}

function mapModel(params: { model: string; provider: InferenceProvider }): string {
const model = (() => {
switch (params.provider) {
case "fal-ai":
return FAL_AI_MODEL_IDS[params.model];
case "replicate":
return REPLICATE_MODEL_IDS[params.model];
case "sambanova":
return SAMBANOVA_MODEL_IDS[params.model];
case "together":
return TOGETHER_MODEL_IDS[params.model]?.id;
case "hf-inference":
return params.model;
}
})();

if (!model) {
throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
}
return model;
}

function makeUrl(params: {
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
chatCompletion: boolean;
forceTask?: string | InferenceTask;
}): string {
switch (params.provider) {
case "fal-ai":
return `${FAL_AI_API_BASE_URL}/${params.model}`;
case "replicate": {
if (params.model.includes(":")) {
/// Versioned model
return `${REPLICATE_API_BASE_URL}/v1/predictions`;
}
/// Evergreen / Canonical model
return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
}
case "sambanova":
/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
if (params.taskHint === "text-generation" && params.chatCompletion) {
return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
}
return SAMBANOVA_API_BASE_URL;
case "together": {
/// Together API matches OpenAI-like APIs: model is defined in the request body
if (params.taskHint === "text-to-image") {
return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
}
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
}
return `${TOGETHER_API_BASE_URL}/v1/completions`;
}
return TOGETHER_API_BASE_URL;
}
default: {
const url = params.forceTask
? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
: `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
if (params.taskHint === "text-generation" && params.chatCompletion) {
return url + `/v1/chat/completions`;
}
return url;
}
}
}
async function loadDefaultModel(task: InferenceTask): Promise<string> {
if (!tasks) {
tasks = await loadTaskInfo();
}
const taskInfo = tasks[task];
if ((taskInfo?.models.length ?? 0) <= 0) {
throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
}
return taskInfo.models[0].id;
}

async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
const res = await fetch(`${HF_HUB_URL}/api/tasks`);

if (!res.ok) {
throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
}
return await res.json();
}
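To make the new routing concrete, here is roughly how the factored helpers resolve a few representative inputs. This is a sketch against the code above, not runnable as-is: `makeUrl` is module-private, and the model IDs are examples.

```ts
// Illustrative expectations only — `makeUrl` is not exported from the module.
makeUrl({ model: "fal-ai/flux/schnell", provider: "fal-ai", taskHint: "text-to-image", chatCompletion: false });
// -> `${FAL_AI_API_BASE_URL}/fal-ai/flux/schnell`

makeUrl({ model: "owner/model:abc123", provider: "replicate", taskHint: undefined, chatCompletion: false });
// -> `${REPLICATE_API_BASE_URL}/v1/predictions` (a versioned model goes to the generic predictions route)

makeUrl({ model: "gpt2", provider: "hf-inference", taskHint: "text-generation", chatCompletion: true });
// -> `${HF_INFERENCE_API_BASE_URL}/models/gpt2/v1/chat/completions`
```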
packages/inference/src/providers/fal-ai.ts (2 changes: 1 addition & 1 deletion)
```diff
@@ -7,7 +7,7 @@ type FalAiId = string;
 /**
  * Mapping from HF model ID -> fal.ai app id
  */
-export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
+export const FAL_AI_MODEL_IDS: Partial<Record<ModelId, FalAiId>> = {
 	/** text-to-image */
 	"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
 	"black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
```
packages/inference/src/providers/replicate.ts (2 changes: 1 addition & 1 deletion)
```diff
@@ -14,7 +14,7 @@ type ReplicateId = string;
  * 'https://api.replicate.com/v1/models'
  * ```
  */
-export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
+export const REPLICATE_MODEL_IDS: Partial<Record<ModelId, ReplicateId>> = {
 	/** text-to-image */
 	"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
 	"ByteDance/SDXL-Lightning":
```
packages/inference/src/providers/sambanova.ts (2 changes: 1 addition & 1 deletion)
```diff
@@ -15,7 +15,7 @@ type SambanovaId = string;
 /**
  * https://community.sambanova.ai/t/supported-models/193
  */
-export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
+export const SAMBANOVA_MODEL_IDS: Partial<Record<ModelId, SambanovaId>> = {
 	/** Chat completion / conversational */
 	"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
 	"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
```
packages/inference/src/providers/together.ts (5 changes: 2 additions & 3 deletions)
```diff
@@ -10,9 +10,8 @@ type TogetherId = string;
 /**
  * https://docs.together.ai/reference/models-1
  */
-export const TOGETHER_MODEL_IDS: Record<
-	ModelId,
-	{ id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
+export const TOGETHER_MODEL_IDS: Partial<
+	Record<ModelId, { id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }>
 > = {
 	/** text-to-image */
 	"black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
```
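The switch from `Record<ModelId, ...>` to `Partial<Record<ModelId, ...>>` in all four provider mappings makes lookups for unmapped models type as possibly `undefined`, which is exactly what `mapModel` above checks before throwing its not-supported error. A minimal sketch of the difference, using a hypothetical table and model ID:

```ts
type ModelId = string;

// With Partial, indexing is typed `string | undefined`, so a missing mapping must be handled.
const MAPPING: Partial<Record<ModelId, string>> = {
	"black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
};

const mapped = MAPPING["some-org/unmapped-model"]; // string | undefined
if (!mapped) {
	throw new Error(`Model some-org/unmapped-model is not supported for provider fal-ai`);
}
```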
packages/inference/src/tasks/custom/streamingRequest.ts (6 changes: 5 additions & 1 deletion)
```diff
@@ -32,9 +32,13 @@ export async function* streamingRequest<T>(
 		if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
 			throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
 		}
-		if (output.error) {
+		if (typeof output.error === "string") {
 			throw new Error(output.error);
 		}
+		if (output.error && "message" in output.error && typeof output.error.message === "string") {
+			/// OpenAI errors
+			throw new Error(output.error.message);
+		}
 	}

 	throw new Error(`Server response contains error: ${response.status}`);
```
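The new branches distinguish the two error payload shapes the server can stream back: a bare string, or an OpenAI-style object with a nested `message`. A standalone sketch of the same extraction logic (the helper name is ours, not the library's):

```ts
// Hypothetical helper mirroring the branches in the diff above.
function extractErrorMessage(output: { error?: unknown }): string | undefined {
	if (typeof output.error === "string") {
		return output.error;
	}
	if (
		output.error &&
		typeof output.error === "object" &&
		"message" in output.error &&
		typeof (output.error as { message?: unknown }).message === "string"
	) {
		// OpenAI-style errors nest the human-readable message one level down.
		return (output.error as { message: string }).message;
	}
	return undefined;
}
```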
