[Inference] Prepare release #1112

Merged 3 commits on Jan 17, 2025
20 changes: 18 additions & 2 deletions README.md
@@ -27,7 +27,7 @@ await uploadFile({
}
});

-// Use Inference API
+// Use HF Inference API

await inference.chatCompletion({
model: "meta-llama/Llama-3.1-8B-Instruct",
@@ -53,7 +53,7 @@ await inference.textToImage({

This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.

-- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless) and Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
+- [@huggingface/inference](packages/inference/README.md): Use Inference API (serverless), Inference Endpoints (dedicated) and third-party Inference providers to make calls to 100,000+ Machine Learning models
- [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
- [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
- [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
@@ -144,6 +144,22 @@ for await (const chunk of inference.chatCompletionStream({
console.log(chunk.choices[0].delta.content);
}

// Using a third-party provider:
await inference.chatCompletion({
  model: "meta-llama/Llama-3.1-8B-Instruct",
  messages: [{ role: "user", content: "Hello, nice to meet you!" }],
  max_tokens: 512,
  provider: "sambanova"
})

await inference.textToImage({
  model: "black-forest-labs/FLUX.1-dev",
  inputs: "a picture of a green bird",
  provider: "together"
})

// You can also omit "model" to use the recommended model for the task
await inference.translation({
inputs: "My name is Wolfgang and I live in Amsterdam",
28 changes: 28 additions & 0 deletions packages/inference/README.md
@@ -42,6 +42,34 @@ const hf = new HfInference('your access token')

Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
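
One possible shape for such a proxy, sketched with Express (the `/api/inference` route and the `HF_TOKEN` environment variable are illustrative, not part of this library): the browser sends the request payload, and the token is attached server-side only.

```ts
import express from "express";

const app = express();
app.use(express.json());

// The browser posts { model, ...payload }; the token never leaves the server.
app.post("/api/inference", async (req, res) => {
  const { model, ...payload } = req.body;
  const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
    method: "POST",
    headers: {
      Authorization: `Bearer ${process.env.HF_TOKEN}`, // stored server-side only
      "Content-Type": "application/json",
    },
    body: JSON.stringify(payload),
  });
  res.status(response.status).json(await response.json());
});

app.listen(3000);
```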

### Requesting third-party inference providers

You can request inference from third-party providers with the inference client.

Currently, we support the following providers: [Fal.ai](https://fal.ai), [Replicate](https://replicate.com), [Together](https://together.xyz) and [Sambanova](https://sambanova.ai).

To make a request to a third-party provider, pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
```ts
const accessToken = "hf_..."; // Either an HF access token, or an API key from the 3rd-party provider (Replicate in this example)

const client = new HfInference(accessToken);
await client.textToImage({
  provider: "replicate",
  model: "black-forest-labs/FLUX.1-dev",
  inputs: "A black forest cake"
})
```

When authenticated with a Hugging Face access token, the request is routed through https://huggingface.co.
When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
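
As a sketch, the same call under the two kinds of credentials (the `r8_...` key is a hypothetical Replicate API key):

```ts
const routed = new HfInference("hf_..."); // routed through huggingface.co
const direct = new HfInference("r8_..."); // sent straight to Replicate's API

await direct.textToImage({
  provider: "replicate",
  model: "black-forest-labs/FLUX.1-dev",
  inputs: "A black forest cake",
});
```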

Only a subset of models is supported when requesting third-party providers. You can check the list of supported models per pipeline task here:
- [Fal.ai supported models](./src/providers/fal-ai.ts)
- [Replicate supported models](./src/providers/replicate.ts)
- [Sambanova supported models](./src/providers/sambanova.ts)
- [Together supported models](./src/providers/together.ts)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
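
These mappings are also exported as constants (see the `src/index.ts` change below), so you can inspect them at runtime; a minimal sketch:

```ts
import { FAL_AI_SUPPORTED_MODEL_IDS } from "@huggingface/inference";

// Mappings are keyed by task, then by HF model id.
console.log(FAL_AI_SUPPORTED_MODEL_IDS["text-to-image"]);
```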

#### Tree-shaking

You can import the functions you need directly from the module instead of using the `HfInference` class.
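
For example, a sketch using the standalone `textToImage` export, which lets bundlers drop the tasks you don't use:

```ts
import { textToImage } from "@huggingface/inference";

await textToImage({
  accessToken: "hf_...",
  model: "black-forest-labs/FLUX.1-dev",
  inputs: "An astronaut riding a horse",
});
```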
1 change: 1 addition & 0 deletions packages/inference/src/index.ts
@@ -1,3 +1,4 @@
+export type { ProviderMapping } from "./providers/types";
export { HfInference, HfInferenceEndpoint } from "./HfInference";
export { InferenceOutputError } from "./lib/InferenceOutputError";
export { FAL_AI_SUPPORTED_MODEL_IDS } from "./providers/fal-ai";
27 changes: 14 additions & 13 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,3 +1,4 @@
+import type { WidgetType } from "@huggingface/tasks";
import { HF_HUB_URL, HF_INFERENCE_API_URL } from "../config";
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
import { REPLICATE_API_BASE_URL, REPLICATE_SUPPORTED_MODEL_IDS } from "../providers/replicate";
@@ -65,21 +66,21 @@ export async function makeRequestOptions(
? "hf-token"
: "provider-key"
: includeCredentials === "include"
? "credentials-include"
: "none";
? "credentials-include"
: "none";

	const url = endpointUrl
		? chatCompletion
			? endpointUrl + `/v1/chat/completions`
			: endpointUrl
		: makeUrl({
-				authMethod,
-				chatCompletion: chatCompletion ?? false,
-				forceTask,
-				model,
-				provider: provider ?? "hf-inference",
-				taskHint,
-			});
+				authMethod,
+				chatCompletion: chatCompletion ?? false,
+				forceTask,
+				model,
+				provider: provider ?? "hf-inference",
+				taskHint,
+			});

const headers: Record<string, string> = {};
if (accessToken) {
@@ -133,9 +134,9 @@ export async function makeRequestOptions(
		body: binary
			? args.data
			: JSON.stringify({
-					...otherArgs,
-					...(chatCompletion || provider === "together" ? { model } : undefined),
-				}),
+					...otherArgs,
+					...(chatCompletion || provider === "together" ? { model } : undefined),
+				}),
		...(credentials ? { credentials } : undefined),
		signal: options?.signal,
	};
@@ -155,7 +156,7 @@ function mapModel(params: {
	if (!params.taskHint) {
		throw new Error("taskHint must be specified when using a third-party provider");
	}
-	const task = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
+	const task: WidgetType = params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
	const model = (() => {
		switch (params.provider) {
			case "fal-ai":
5 changes: 3 additions & 2 deletions packages/inference/src/providers/types.ts
@@ -1,5 +1,6 @@
-import type { InferenceTask, ModelId } from "../types";
+import type { WidgetType } from "@huggingface/tasks";
+import type { ModelId } from "../types";

export type ProviderMapping<ProviderId extends string> = Partial<
-	Record<InferenceTask | "conversational", Partial<Record<ModelId, ProviderId>>>
+	Record<WidgetType, Partial<Record<ModelId, ProviderId>>>
>;
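
For context, a sketch of a mapping the widened type admits (the model ids are illustrative):

```ts
import type { ProviderMapping } from "@huggingface/inference";

const falAiIds: ProviderMapping<string> = {
  // WidgetType keys cover every widget task, including "conversational".
  "text-to-image": {
    "black-forest-labs/FLUX.1-dev": "fal-ai/flux/dev",
  },
};
```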