@@ -2,8 +2,8 @@ import { FAL_AI_API_BASE_URL, FAL_AI_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
- import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
- import { omit } from "../utils/omit";
+ import type { InferenceProvider } from "../types";
+ import type { InferenceTask, Options, RequestArgs } from "../types";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

@@ -31,62 +31,49 @@ export async function makeRequestOptions(
		chatCompletion?: boolean;
	}
): Promise<{ url: string; info: RequestInit }> {
- 	const { accessToken, endpointUrl, provider, ...otherArgs } = args;
- 	let { model } = args;
+ 	const { accessToken, endpointUrl, provider: maybeProvider, model: maybeModel, ...otherArgs } = args;
+ 	const provider = maybeProvider ?? "hf-inference";
+
	const { forceTask, includeCredentials, taskHint, wait_for_model, use_cache, dont_load_model, chatCompletion } =
		options ?? {};

- 	const headers: Record<string, string> = {};
- 	if (accessToken) {
- 		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
+ 	if (endpointUrl && provider !== "hf-inference") {
+ 		throw new Error(`Cannot use endpointUrl with a third-party provider.`);
	}
-
- 	if (!model && !tasks && taskHint) {
- 		const res = await fetch(`${HF_HUB_URL}/api/tasks`);
-
- 		if (res.ok) {
- 			tasks = await res.json();
- 		}
+ 	if (forceTask && provider !== "hf-inference") {
+ 		throw new Error(`Cannot use forceTask with a third-party provider.`);
	}
-
- 	if (!model && tasks && taskHint) {
- 		const taskInfo = tasks[taskHint];
- 		if (taskInfo) {
- 			model = taskInfo.models[0].id;
- 		}
+ 	if (maybeModel && isUrl(maybeModel)) {
+ 		throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
	}

- 	if (!model) {
- 		throw new Error("No model provided, and no default model found for this task");
- 	}
- 	if (provider) {
- 		if (!INFERENCE_PROVIDERS.includes(provider)) {
- 			throw new Error("Unknown Inference provider");
- 		}
- 		if (!accessToken) {
- 			throw new Error("Specifying an Inference provider requires an accessToken");
+ 	let model: string;
+ 	if (!maybeModel) {
+ 		if (taskHint) {
+ 			model = mapModel({ model: await loadDefaultModel(taskHint), provider });
+ 		} else {
+ 			throw new Error("No model provided, and no default model found for this task");
+ 			/// TODO: change error message ^
		}
+ 	} else {
+ 		model = mapModel({ model: maybeModel, provider });
+ 	}

- 	const modelId = (() => {
- 		switch (provider) {
- 			case "replicate":
- 				return REPLICATE_MODEL_IDS[model];
- 			case "sambanova":
- 				return SAMBANOVA_MODEL_IDS[model];
- 			case "together":
- 				return TOGETHER_MODEL_IDS[model]?.id;
- 			case "fal-ai":
- 				return FAL_AI_MODEL_IDS[model];
- 			default:
- 				return model;
- 		}
- 	})();
-
- 	if (!modelId) {
- 		throw new Error(`Model ${model} is not supported for provider ${provider}`);
- 	}
+ 	const url = endpointUrl
+ 		? chatCompletion
+ 			? endpointUrl + `/v1/chat/completions`
+ 			: endpointUrl
+ 		: makeUrl({
+ 				model,
+ 				provider: provider ?? "hf-inference",
+ 				taskHint,
+ 				chatCompletion: chatCompletion ?? false,
+ 				forceTask,
+ 		  });

- 	model = modelId;
+ 	const headers: Record<string, string> = {};
+ 	if (accessToken) {
+ 		headers["Authorization"] = provider === "fal-ai" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
	}

	const binary = "data" in args && !!args.data;
@@ -95,73 +82,20 @@ export async function makeRequestOptions(
		headers["Content-Type"] = "application/json";
	}

- 	if (wait_for_model) {
- 		headers["X-Wait-For-Model"] = "true";
- 	}
- 	if (use_cache === false) {
- 		headers["X-Use-Cache"] = "false";
- 	}
- 	if (dont_load_model) {
- 		headers["X-Load-Model"] = "0";
- 	}
- 	if (provider === "replicate") {
- 		headers["Prefer"] = "wait";
- 	}
-
- 	let url = (() => {
- 		if (endpointUrl && isUrl(model)) {
- 			throw new TypeError("Both model and endpointUrl cannot be URLs");
- 		}
- 		if (isUrl(model)) {
- 			console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
- 			return model;
- 		}
- 		if (endpointUrl) {
- 			return endpointUrl;
+ 	if (provider === "hf-inference") {
+ 		if (wait_for_model) {
+ 			headers["X-Wait-For-Model"] = "true";
		}
- 		if (forceTask) {
- 			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${forceTask}/${model}`;
+ 		if (use_cache === false) {
+ 			headers["X-Use-Cache"] = "false";
		}
- 		if (provider) {
- 			if (!accessToken) {
- 				throw new Error("Specifying an Inference provider requires an accessToken");
- 			}
- 			if (accessToken.startsWith("hf_")) {
- 				/// TODO we will proxy the request server-side (using our own keys) and handle billing for it on the user's HF account.
- 				throw new Error("Inference proxying is not implemented yet");
- 			} else {
- 				switch (provider) {
- 					case "fal-ai":
- 						return `${FAL_AI_API_BASE_URL}/${model}`;
- 					case "replicate":
- 						if (model.includes(":")) {
- 							// Versioned models are in the form of `owner/model:version`
- 							return `${REPLICATE_API_BASE_URL}/v1/predictions`;
- 						} else {
- 							// Unversioned models are in the form of `owner/model`
- 							return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
- 						}
- 					case "sambanova":
- 						return SAMBANOVA_API_BASE_URL;
- 					case "together":
- 						if (taskHint === "text-to-image") {
- 							return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
- 						}
- 						return TOGETHER_API_BASE_URL;
- 					default:
- 						break;
- 				}
- 			}
+ 		if (dont_load_model) {
+ 			headers["X-Load-Model"] = "0";
		}
-
- 		return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
- 	})();
-
- 	if (chatCompletion && !url.endsWith("/chat/completions")) {
- 		url += "/v1/chat/completions";
	}
- 	if (provider === "together" && taskHint === "text-generation" && !chatCompletion) {
- 		url += "/v1/completions";
+
+ 	if (provider === "replicate") {
+ 		headers["Prefer"] = "wait";
	}

	/**
@@ -188,13 +122,102 @@ export async function makeRequestOptions(
		body: binary
			? args.data
			: JSON.stringify({
- 					...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate" || provider === "fal-ai"
- 						? omit(otherArgs, "model")
- 						: { ...otherArgs, model }),
+ 					...otherArgs,
+ 					...(chatCompletion || provider === "together" ? { model } : undefined),
			  }),
		...(credentials ? { credentials } : undefined),
		signal: options?.signal,
	};

	return { url, info };
}
+
+ function mapModel(params: { model: string; provider: InferenceProvider }): string {
+ 	const model = (() => {
+ 		switch (params.provider) {
+ 			case "fal-ai":
+ 				return FAL_AI_MODEL_IDS[params.model];
+ 			case "replicate":
+ 				return REPLICATE_MODEL_IDS[params.model];
+ 			case "sambanova":
+ 				return SAMBANOVA_MODEL_IDS[params.model];
+ 			case "together":
+ 				return TOGETHER_MODEL_IDS[params.model]?.id;
+ 			case "hf-inference":
+ 				return params.model;
+ 		}
+ 	})();
+
+ 	if (!model) {
+ 		throw new Error(`Model ${params.model} is not supported for provider ${params.provider}`);
+ 	}
+ 	return model;
+ }
+
+ function makeUrl(params: {
+ 	model: string;
+ 	provider: InferenceProvider;
+ 	taskHint: InferenceTask | undefined;
+ 	chatCompletion: boolean;
+ 	forceTask?: string | InferenceTask;
+ }): string {
+ 	switch (params.provider) {
+ 		case "fal-ai":
+ 			return `${FAL_AI_API_BASE_URL}/${params.model}`;
+ 		case "replicate": {
+ 			if (params.model.includes(":")) {
+ 				/// Versioned model
+ 				return `${REPLICATE_API_BASE_URL}/v1/predictions`;
+ 			}
+ 			/// Evergreen / Canonical model
+ 			return `${REPLICATE_API_BASE_URL}/v1/models/${params.model}/predictions`;
+ 		}
+ 		case "sambanova":
+ 			/// Sambanova API matches OpenAI-like APIs: model is defined in the request body
+ 			if (params.taskHint === "text-generation" && params.chatCompletion) {
+ 				return `${SAMBANOVA_API_BASE_URL}/v1/chat/completions`;
+ 			}
+ 			return SAMBANOVA_API_BASE_URL;
+ 		case "together": {
+ 			/// Together API matches OpenAI-like APIs: model is defined in the request body
+ 			if (params.taskHint === "text-to-image") {
+ 				return `${TOGETHER_API_BASE_URL}/v1/images/generations`;
+ 			}
+ 			if (params.taskHint === "text-generation") {
+ 				if (params.chatCompletion) {
+ 					return `${TOGETHER_API_BASE_URL}/v1/chat/completions`;
+ 				}
+ 				return `${TOGETHER_API_BASE_URL}/v1/completions`;
+ 			}
+ 			return TOGETHER_API_BASE_URL;
+ 		}
+ 		default: {
+ 			const url = params.forceTask
+ 				? `${HF_INFERENCE_API_BASE_URL}/pipeline/${params.forceTask}/${params.model}`
+ 				: `${HF_INFERENCE_API_BASE_URL}/models/${params.model}`;
+ 			if (params.taskHint === "text-generation" && params.chatCompletion) {
+ 				return url + `/v1/chat/completions`;
+ 			}
+ 			return url;
+ 		}
+ 	}
+ }
+ async function loadDefaultModel(task: InferenceTask): Promise<string> {
+ 	if (!tasks) {
+ 		tasks = await loadTaskInfo();
+ 	}
+ 	const taskInfo = tasks[task];
+ 	if ((taskInfo?.models.length ?? 0) <= 0) {
+ 		throw new Error(`No default model defined for task ${task}, please define the model explicitly.`);
+ 	}
+ 	return taskInfo.models[0].id;
+ }
+
+ async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[] }>> {
+ 	const res = await fetch(`${HF_HUB_URL}/api/tasks`);
+
+ 	if (!res.ok) {
+ 		throw new Error("Failed to load tasks definitions from Hugging Face Hub.");
+ 	}
+ 	return await res.json();
+ }
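
For context, here is a minimal usage sketch of the refactored entry point. It is not part of the diff: the model ID, API key, and `messages` payload are illustrative assumptions about what `RequestArgs` carries for this task. It shows a chat-completion request routed to a third-party provider, where `makeUrl` resolves the endpoint and `mapModel` translates the Hub model ID; `provider` now defaults to "hf-inference" when omitted.

// Hypothetical usage sketch — identifiers below are illustrative, not from this PR.
const { url, info } = await makeRequestOptions(
	{
		accessToken: "my-together-api-key", // provider key; for "fal-ai" it would be sent as `Key ...` rather than `Bearer ...`
		provider: "together",
		model: "meta-llama/Llama-3.1-8B-Instruct", // translated via TOGETHER_MODEL_IDS by mapModel()
		messages: [{ role: "user", content: "Hello!" }], // forwarded into the JSON body via ...otherArgs
	},
	{ taskHint: "text-generation", chatCompletion: true }
);
// url resolves to `${TOGETHER_API_BASE_URL}/v1/chat/completions`, and the body
// also carries the mapped `model`, since chatCompletion || provider === "together" holds.
const response = await fetch(url, info);

Omitting `model` falls back to `loadDefaultModel(taskHint)`, which fetches the Hub's task definitions from `${HF_HUB_URL}/api/tasks` once, caches them in the module-level `tasks` variable, and picks the task's first recommended model.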