Commit 697e9be

[Inference] Factor makeRequestOptions logic (#1107)
- Update the `makeRequestOptions` function to make the logic (hopefully) more readable, especially since #1077
- Stop supporting model URLs (see the migration sketch below)
- Update tapes

1 parent 16f3d68 commit 697e9be
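
The headline breaking change is the removal of model-URL support. A minimal migration sketch, assuming the usual `HfInference` client surface; the token and endpoint URL below are placeholders:

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token

// Before this commit (deprecated, now throws "Model URLs are no longer supported."):
// await hf.textGeneration({ model: "https://my-endpoint.example.com", inputs: "Hello" });

// After: point at the endpoint explicitly instead of passing a URL through `model`:
const ep = hf.endpoint("https://my-endpoint.example.com"); // placeholder URL
await ep.textGeneration({ inputs: "Hello" });
```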

File tree

8 files changed: +1862 -127 lines

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 137 additions & 114 deletions

```diff
@@ -2,8 +2,8 @@ import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
 import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
 import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
 import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
-import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
-import { omit } from "../utils/omit";
+import type { InferenceProvider } from "../types";
+import type { InferenceTask, Options, RequestArgs } from "../types";
 import { HF_HUB_URL } from "./getDefaultTask";
 import { isUrl } from "./isUrl";
```

```diff
@@ -31,62 +31,49 @@ export async function makeRequestOptions(
     chatCompletion?: boolean;
   }
 ): Promise<{ url: string; info: RequestInit }> {
-  const { accessToken, endpointUrl, provider, ...otherArgs } = args;
-  let { model } = args;
+  const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
+  const provider = maybeProvider ?? "hf-inference";
+
   const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
     options ?? {};
 
-  const headers: Record<string, string> = {};
-  if (accessToken) {
-    headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
+  if (endpointUrl && provider !== "hf-inference") {
+    throw new Error(`Cannot use endpointUrl with a third-party provider.`);
   }
-
-  if (!model && !tasks && taskHint) {
-    const res = await fetch(`${HF_HUB_URL}/api/tasks`);
-
-    if (res.ok) {
-      tasks = await res.json();
-    }
+  if (forceTask && provider !== "hf-inference") {
+    throw new Error(`Cannot use forceTask with a third-party provider.`);
   }
-
-  if (!model && tasks && taskHint) {
-    const taskInfo = tasks[taskHint];
-    if (taskInfo) {
-      model = taskInfo.models[0].id;
-    }
+  if (maybeModel && isUrl(maybeModel)) {
+    throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
   }
 
-  if (!model) {
-    throw new Error("No model provided, and no default model found for this task");
-  }
-  if (provider) {
-    if (!INFERENCE_PROVIDERS.includes(provider)) {
-      throw new Error("Unknown Inference provider");
-    }
-    if (!accessToken) {
-      throw new Error("Specifying an Inference provider requires an accessToken");
+  let model: string;
+  if (!maybeModel) {
+    if (taskHint) {
+      model = mapModel({ model: await loadDefaultModel(taskHint), provider });
+    } else {
+      throw new Error("No model provided, and no default model found for this task");
+      /// TODO : change error message ^
     }
+  } else {
+    model = mapModel({ model: maybeModel, provider });
+  }
 
-  const modelId = (() => {
-    switch (provider) {
-      case "replicate":
-        return REPLICATE_MODEL_IDS[model];
-      case "sambanova":
-        return SAMBANOVA_MODEL_IDS[model];
-      case "together":
-        return TOGETHER_MODEL_IDS[model]?.id;
-      case "fal-ai":
-        return FAL_AI_MODEL_IDS[model];
-      default:
-        return model;
-    }
-  })();
-
-  if (!modelId) {
-    throw new Error(`Model ${model} is not supported for provider ${provider}`);
-  }
+  const url = endpointUrl
+    ? chatCompletion
+      ? endpointUrl + `/v1/chat/completions`
+      : endpointUrl
+    : makeUrl({
+        model,
+        provider: provider ?? "hf-inference",
+        taskHint,
+        chatCompletion: chatCompletion ?? false,
+        forceTask,
+      });
 
-  model = modelId;
+  const headers: Record<string, string> = {};
+  if (accessToken) {
+    headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
   }
 
   const binary = "data" in args && !!args.data;
```
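
The guard clauses above now fail fast before any URL is built. A rough sketch of the observable behavior, with hypothetical arguments inferred from the diff rather than taken from the test suite:

```ts
// Hypothetical internal calls; tokens and URLs are placeholders.
await makeRequestOptions(
  { accessToken: "hf_xxx", provider: "together", endpointUrl: "https://example.com", inputs: "hi" },
  { taskHint: "text-generation" }
);
// -> throws: Cannot use endpointUrl with a third-party provider.

await makeRequestOptions({ accessToken: "hf_xxx", model: "https://example.com/my-model", inputs: "hi" });
// -> throws: Model URLs are no longer supported. Use endpointUrl instead.

await makeRequestOptions({ accessToken: "hf_xxx", inputs: "hi" }, { taskHint: "summarization" });
// -> no model given: resolved via loadDefaultModel("summarization"), then mapped through mapModel()
```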

```diff
@@ -95,73 +82,20 @@ export async function makeRequestOptions(
     headers["Content-Type"] = "application/json";
   }
 
-  if (wait_for_model) {
-    headers["X-Wait-For-Model"] = "true";
-  }
-  if (use_cache === false) {
-    headers["X-Use-Cache"] = "false";
-  }
-  if (dont_load_model) {
-    headers["X-Load-Model"] = "0";
-  }
-  if (provider === "replicate") {
-    headers["Prefer"] = "wait";
-  }
-
-  let url = (() => {
-    if (endpointUrl && isUrl(model)) {
-      throw new TypeError("Both model and endpointUrl cannot be URLs");
-    }
-    if (isUrl(model)) {
-      console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
-      return model;
-    }
-    if (endpointUrl) {
-      return endpointUrl;
+  if (provider === "hf-inference") {
+    if (wait_for_model) {
+      headers["X-Wait-For-Model"] = "true";
     }
-    if (forceTask) {
-      return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
+    if (use_cache === false) {
+      headers["X-Use-Cache"] = "false";
     }
-    if (provider) {
-      if (!accessToken) {
-        throw new Error("Specifying an Inference provider requires an accessToken");
-      }
-      if (accessToken.startsWith("hf_")) {
-        /// TODO we wil proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
-        throw new Error("Inference proxying is not implemented yet");
-      } else {
-        switch (provider) {
-          case "fal-ai":
-            return `${FAL_AI_API_BASE_URL}/${model}`;
-          case "replicate":
-            if (model.includes(":")) {
-              // Versioned models are in the form of `owner/model:version`
-              return `${REPLICATE_API_BASE_URL}/v1/predictions`;
-            } else {
-              // Unversioned models are in the form of `owner/model`
-              return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
-            }
-          case "sambanova":
-            return SAMBANOVA_API_BASE_URL;
-          case "together":
-            if (taskHint === "text-to-image") {
-              return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
-            }
-            return TOGETHER_API_BASE_URL;
-          default:
-            break;
-        }
-      }
+    if (dont_load_model) {
+      headers["X-Load-Model"] = "0";
     }
-
-    return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
-  })();
-
-  if (chatCompletion && !url.endsWith("/chat/completions")) {
-    url += "/v1/chat/completions";
   }
-  if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
-    url += "/v1/completions";
+
+  if (provider === "replicate") {
+    headers["Prefer"] = "wait";
   }
 
   /**
```
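
One consequence of this hunk: the `X-Wait-For-Model`, `X-Use-Cache`, and `X-Load-Model` headers are now scoped to `hf-inference` rather than sent to every provider. A sketch of the header shapes this should produce, with placeholder tokens:

```ts
// Illustrative header shapes per provider after this change (not exhaustive):
const hfHeaders = {
  Authorization: "Bearer hf_xxx",
  "Content-Type": "application/json",
  "X-Wait-For-Model": "true", // hf-inference only, when wait_for_model is set
};
const replicateHeaders = { Authorization: "Bearer r8_xxx", "Content-Type": "application/json", Prefer: "wait" };
const falHeaders = { Authorization: "Key fal_xxx", "Content-Type": "application/json" }; // fal.ai uses the `Key` scheme
```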

```diff
@@ -188,13 +122,102 @@ export async function makeRequestOptions(
     body: binary
       ? args.data
       : JSON.stringify({
-          ...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
-            ? omit(otherArgs, "model")
-            : { ...otherArgs, model }),
+          ...otherArgs,
+          ...(chatCompletion || provider === "together" ? { model } : undefined),
         }),
     ...(credentials ? { credentials } : undefined),
     signal: options?.signal,
   };
 
   return { url, info };
 }
+
+function mapModel(params: { model: string; provider: InferenceProvider }): string {
+  const model = (() => {
+    switch (params.provider) {
+      case "fal-ai":
+        return FAL_AI_MODEL_IDS[params.model];
+      case "replicate":
+        return REPLICATE_MODEL_IDS[params.model];
+      case "sambanova":
+        return SAMBANOVA_MODEL_IDS[params.model];
+      case "together":
+        return TOGETHER_MODEL_IDS[params.model]?.id;
+      case "hf-inference":
+        return params.model;
+    }
+  })();
+
+  if (!model) {
+    throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
+  }
+  return model;
+}
+
+function makeUrl(params: {
+  model: string;
+  provider: InferenceProvider;
+  taskHint: InferenceTask | undefined;
+  chatCompletion: boolean;
+  forceTask?: string | InferenceTask;
+}): string {
+  switch (params.provider) {
+    case "fal-ai":
+      return `${FAL_AI_API_BASE_URL}/${params.model}`;
+    case "replicate": {
+      if (params.model.includes(":")) {
+        /// Versioned model
+        return `${REPLICATE_API_BASE_URL}/v1/predictions`;
+      }
+      /// Evergreen / Canonical model
+      return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
+    }
+    case "sambanova":
+      /// Sambanova API matches OpenAI-like APIs: model is defined in the request body
+      if (params.taskHint === "text-generation" && params.chatCompletion) {
+        return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
+      }
+      return SAMBANOVA_API_BASE_URL;
+    case "together": {
+      /// Together API matches OpenAI-like APIs: model is defined in the request body
+      if (params.taskHint === "text-to-image") {
+        return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
+      }
+      if (params.taskHint === "text-generation") {
+        if (params.chatCompletion) {
+          return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
+        }
+        return `${TOGETHER_API_BASE_URL}/v1/completions`;
+      }
+      return TOGETHER_API_BASE_URL;
+    }
+    default: {
+      const url = params.forceTask
+        ? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
+        : `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
+      if (params.taskHint === "text-generation" && params.chatCompletion) {
+        return url + `/v1/chat/completions`;
+      }
+      return url;
+    }
+  }
+}
+async function loadDefaultModel(task: InferenceTask): Promise<string> {
+  if (!tasks) {
+    tasks = await loadTaskInfo();
+  }
+  const taskInfo = tasks[task];
+  if ((taskInfo?.models.length ?? 0) <= 0) {
+    throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
+  }
+  return taskInfo.models[0].id;
+}
+
+async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
+  const res = await fetch(`${HF_HUB_URL}/api/tasks`);
+
+  if (!res.ok) {
+    throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
+  }
+  return await res.json();
+}
```
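
To make the routing concrete, here is a hedged sketch of the URLs `makeUrl` should return for a few representative inputs, assuming the base-URL constants carry their usual production values (`https://fal.run`, `https://api.replicate.com`):

```ts
// Expected outputs, inferred from the switch above; not executed against the real constants.
makeUrl({ model: "fal-ai/flux/schnell", provider: "fal-ai", taskHint: "text-to-image", chatCompletion: false });
// -> "https://fal.run/fal-ai/flux/schnell"

makeUrl({ model: "owner/model:abc123", provider: "replicate", taskHint: undefined, chatCompletion: false });
// -> "https://api.replicate.com/v1/predictions" (versioned: the version travels in the body)

makeUrl({ model: "meta-llama/Llama-3.3-70B-Instruct", provider: "hf-inference", taskHint: "text-generation", chatCompletion: true });
// -> `${HF_INFERENCE_API_BASE_URL}/models/meta-llama/Llama-3.3-70B-Instruct/v1/chat/completions`
```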

packages/inference/src/providers/fal-ai.ts

Lines changed: 1 addition & 1 deletion

```diff
@@ -7,7 +7,7 @@ type FalAiId = string;
 /**
  * Mapping from HF model ID -> fal.ai app id
  */
-export const FAL_AI_MODEL_IDS: Record<ModelId, FalAiId> = {
+export const FAL_AI_MODEL_IDS: Partial<Record<ModelId, FalAiId>> = {
   /** text-to-image */
   "black-forest-labs/FLUX.1-schnell": "fal-ai/flux/schnell",
   "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
```

packages/inference/src/providers/replicate.ts

Lines changed: 1 addition & 1 deletion

````diff
@@ -14,7 +14,7 @@ type ReplicateId = string;
  * 'https://api.replicate.com/v1/models'
  * ```
  */
-export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
+export const REPLICATE_MODEL_IDS: Partial<Record<ModelId, ReplicateId>> = {
   /** text-to-image */
   "black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
   "ByteDance/SDXL-Lightning":
````

packages/inference/src/providers/sambanova.ts

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@ type SambanovaId = string;
 /**
  * https://community.sambanova.ai/t/supported-models/193
  */
-export const SAMBANOVA_MODEL_IDS: Record<ModelId, SambanovaId> = {
+export const SAMBANOVA_MODEL_IDS: Partial<Record<ModelId, SambanovaId>> = {
   /** Chat completion / conversational */
   "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
   "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
```

packages/inference/src/providers/together.ts

Lines changed: 2 additions & 3 deletions

```diff
@@ -10,9 +10,8 @@ type TogetherId = string;
 /**
  * https://docs.together.ai/reference/models-1
  */
-export const TOGETHER_MODEL_IDS: Record<
-  ModelId,
-  { id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }
+export const TOGETHER_MODEL_IDS: Partial<
+  Record<ModelId, { id: TogetherId; type: "chat" | "embedding" | "image" | "language" | "moderation" }>
 > = {
   /** text-to-image */
   "black-forest-labs/FLUX.1-Canny-dev": { id: "black-forest-labs/FLUX.1-canny", type: "image" },
```

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 5 additions & 1 deletion

```diff
@@ -32,9 +32,13 @@ export async function* streamingRequest<T>(
     if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
       throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
     }
-    if (output.error) {
+    if (typeof output.error === "string") {
       throw new Error(output.error);
     }
+    if (output.error && "message" in output.error && typeof output.error.message === "string") {
+      /// OpenAI errors
+      throw new Error(output.error.message);
+    }
   }
 
   throw new Error(`Server response contains error: ${response.status}`);
```
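
The streaming fix narrows `output.error` instead of assuming it is a string: the classic Inference API returns `{ error: "..." }`, while OpenAI-compatible routes return `{ error: { message: "..." } }`. A self-contained sketch of the same narrowing, with shapes assumed from the diff rather than a formal spec:

```ts
type StreamError = string | { message?: unknown } | null | undefined;

function extractErrorMessage(error: StreamError): string | undefined {
  if (typeof error === "string") {
    return error; // classic HF Inference API shape: { error: "..." }
  }
  if (error && "message" in error && typeof error.message === "string") {
    return error.message; // OpenAI-style shape: { error: { message: "..." } }
  }
  return undefined;
}
```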
