Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chat completion method #645

Merged
merged 31 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
332648d
implement chat completion
radames May 1, 2024
d63db89
missing import type
radames May 1, 2024
26dd3b1
fix chatCompletion input type
radames May 1, 2024
a90c0e7
🩹
coyotte508 May 2, 2024
a3928f4
✅ Update tests file
coyotte508 May 2, 2024
1d509c3
✅ Update tests
coyotte508 May 2, 2024
4e9ba18
🐛 One more test
coyotte508 May 2, 2024
c535893
✅ More tests
coyotte508 May 2, 2024
63d5151
✅ one more test
coyotte508 May 2, 2024
9d2f737
✅ Fix last test
coyotte508 May 2, 2024
0f881e4
Merge branch 'main' into chatCompletion
coyotte508 May 3, 2024
6a9ad56
remove skips
radames May 4, 2024
23637bf
recorded tapes.json
radames May 4, 2024
ca54d67
add chat chatCompletion hint to change url
radames May 4, 2024
7ff57f2
add chatCompletion test with modelid
radames May 4, 2024
91ec869
tests
radames May 4, 2024
8fd2621
test with error message
radames May 4, 2024
c9b95a5
test
radames May 4, 2024
5ab21ca
better error handling
radames May 5, 2024
074aa76
Merge branch 'main' into chatCompletion
radames May 8, 2024
32ad989
add chat completion example to inference README.md
radames May 8, 2024
3d8bfc6
fix
radames May 8, 2024
5e3a9d6
📝 Update README.md
coyotte508 May 8, 2024
0ca9ad0
return_full_text not compatible here
radames May 8, 2024
72cfa24
remove return_full_text
radames May 8, 2024
779b828
tests
radames May 8, 2024
b66fcf3
Update packages/inference/README.md
radames May 9, 2024
87ee635
fix chat completion example
radames May 9, 2024
b901f5b
♻️ Do not sent `options`
coyotte508 May 11, 2024
5f2b488
record test
radames May 11, 2024
6502858
Merge branch 'main' into chatCompletion
radames May 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,30 @@ const HF_TOKEN = "hf_...";

const inference = new HfInference(HF_TOKEN);

// Chat completion API
const out = await inference.chatCompletion({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }],
max_tokens: 100
});

// Chat completion API on OpenAI (also compatible with MistralAI api, etc.)
const openai = new HfInference(OPENAI_TOKEN).endpoint("https://api.openai.com");
const out = await openai.chatCompletion({
model: "gpt-3.5-turbo",
messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }],
max_tokens: 100
});

// Streaming chat completion API
for await (const chunk of openai.chatCompletionStream({
model: "gpt-3.5-turbo",
messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }],
max_tokens: 100
})) {
console.log(chunk.choices[0].delta.content);
}
radames marked this conversation as resolved.
Show resolved Hide resolved

// You can also omit "model" to use the recommended model for the task
await inference.translation({
model: 't5-base',
Expand Down
149 changes: 122 additions & 27 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ It works with both [Inference API (serverless)](https://huggingface.co/docs/api-

Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).

You can also try out a live [interactive notebook](https://observablehq.com/@huggingface/hello-huggingface-js-inference), see some demos on [hf.co/huggingfacejs](https://huggingface.co/huggingfacejs), or watch a [Scrimba tutorial that explains how Inference Endpoints works](https://scrimba.com/scrim/cod8248f5adfd6e129582c523).
You can also try out a live [interactive notebook](https://observablehq.com/@huggingface/hello-huggingface-js-inference), see some demos on [hf.co/huggingfacejs](https://huggingface.co/huggingfacejs), or watch a [Scrimba tutorial that explains how Inference Endpoints works](https://scrimba.com/scrim/cod8248f5adfd6e129582c523).

## Getting Started

Expand All @@ -30,7 +30,6 @@ import { HfInference } from "https://esm.sh/@huggingface/inference"
import { HfInference } from "npm:@huggingface/inference"
```


### Initialize

```typescript
Expand All @@ -43,7 +42,6 @@ const hf = new HfInference('your access token')

Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.


#### Tree-shaking

You can import the functions you need directly from the module instead of using the `HfInference` class.
Expand All @@ -63,6 +61,85 @@ This will enable tree-shaking by your bundler.

## Natural Language Processing

### Text Generation

Generates text from an input prompt.

[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-text-generation)

```typescript
await hf.textGeneration({
model: 'gpt2',
inputs: 'The answer to the universe is'
})

for await (const output of hf.textGenerationStream({
model: "google/flan-t5-xxl",
inputs: 'repeat "one two three four"',
parameters: { max_new_tokens: 250 }
})) {
console.log(output.token.text, output.generated_text);
}
```

### Text Generation (Chat Completion API Compatible)

Using the `chatCompletion` method, you can generate text with models compatible with the OpenAI Chat Completion API. All models served by [TGI](https://api-inference.huggingface.co/framework/text-generation-inference) on Hugging Face support Messages API.

[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-chat-completion)

```typescript
// Non-streaming API
const out = await hf.chatCompletion({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [{ role: "user", content: "Complete the this sentence with words one plus one is equal " }],
max_tokens: 500,
temperature: 0.1,
seed: 0,
});

// Streaming API
let out = "";
for await (const chunk of hf.chatCompletionStream({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [
{ role: "user", content: "Complete the equation 1+1= ,just the answer" },
],
max_tokens: 500,
temperature: 0.1,
seed: 0,
})) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
}
}
```

It's also possible to call Mistral or OpenAI endpoints directly:

```typescript
const openai = new HfInference(OPENAI_TOKEN).endpoint("https://api.openai.com");

let out = "";
for await (const chunk of openai.chatCompletionStream({
model: "gpt-3.5-turbo",
messages: [
{ role: "user", content: "Complete the equation 1+1= ,just the answer" },
],
max_tokens: 500,
temperature: 0.1,
seed: 0,
})) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
}
}
radames marked this conversation as resolved.
Show resolved Hide resolved

// For mistral AI:
// endpointUrl: "https://api.mistral.ai"
// model: "mistral-tiny"
```

### Fill Mask

Tries to fill in a hole with a missing word (token to be precise).
Expand Down Expand Up @@ -131,27 +208,6 @@ await hf.textClassification({
})
```

### Text Generation

Generates text from an input prompt.

[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-text-generation)

```typescript
await hf.textGeneration({
model: 'gpt2',
inputs: 'The answer to the universe is'
})

for await (const output of hf.textGenerationStream({
model: "google/flan-t5-xxl",
inputs: 'repeat "one two three four"',
parameters: { max_new_tokens: 250 }
})) {
console.log(output.token.text, output.generated_text);
}
```

### Token Classification

Used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
Expand All @@ -177,9 +233,9 @@ await hf.translation({
model: 'facebook/mbart-large-50-many-to-many-mmt',
inputs: textToTranslate,
parameters: {
"src_lang": "en_XX",
"tgt_lang": "fr_XX"
}
"src_lang": "en_XX",
"tgt_lang": "fr_XX"
}
})
```

Expand Down Expand Up @@ -497,13 +553,52 @@ for await (const output of hf.streamingRequest({
}
```

You can use any OpenAI Chat Completion API-compatible provider with the `chatCompletion` method.
radames marked this conversation as resolved.
Show resolved Hide resolved

```typescript
// Chat Completion Example
const MISTRAL_KEY = process.env.MISTRAL_KEY;
const hf = new HfInference(MISTRAL_KEY);
const ep = hf.endpoint("https://api.mistral.ai/v1/chat/completions");
const stream = ep.streamingRequest({
model: "mistral-tiny",
messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
});
let out = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
console.log(out);
}
}
```

## Custom Inference Endpoints

Learn more about using your own inference endpoints [here](https://hf.co/docs/inference-endpoints/)

```typescript
const gpt2 = hf.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

// Chat Completion Example
const ep = hf.endpoint(
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2/v1/chat/completions"
);
const stream = ep.chatCompletionStream({
model: "tgi",
messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }],
max_tokens: 500,
temperature: 0.1,
seed: 0,
});
let out = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
console.log(out);
}
}
```

By default, all calls to the inference endpoint will wait until the model is
Expand Down
8 changes: 4 additions & 4 deletions packages/inference/src/HfInference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
) => ReturnType<Task[key]>;
};

type TaskWithNoAccessTokenNoModel = {
type TaskWithNoAccessTokenNoEndpointUrl = {
[key in keyof Task]: (
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
options?: Parameters<Task[key]>[1]
) => ReturnType<Task[key]>;
};
Expand Down Expand Up @@ -57,12 +57,12 @@ export class HfInferenceEndpoint {
enumerable: false,
value: (params: RequestArgs, options: Options) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
});
}
}
}

export interface HfInference extends TaskWithNoAccessToken {}

export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
8 changes: 8 additions & 0 deletions packages/inference/src/lib/isEmpty.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export function isObjectEmpty(object: object): boolean {
for (const prop in object) {
if (Object.prototype.hasOwnProperty.call(object, prop)) {
return false;
}
}
return true;
}
17 changes: 12 additions & 5 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isObjectEmpty } from "../lib/isEmpty";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

Expand All @@ -24,8 +26,7 @@ export async function makeRequestOptions(
taskHint?: InferenceTask;
}
): Promise<{ url: string; info: RequestInit }> {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { accessToken, model: _model, ...otherArgs } = args;
const { accessToken, endpointUrl, ...otherArgs } = args;
let { model } = args;
const {
forceTask: task,
Expand Down Expand Up @@ -78,10 +79,16 @@ export async function makeRequestOptions(
}

const url = (() => {
if (endpointUrl && isUrl(model)) {
throw new TypeError("Both model and endpointUrl cannot be URLs");
}
if (isUrl(model)) {
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
return model;
}

if (endpointUrl) {
return endpointUrl;
}
if (task) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
}
Expand All @@ -105,8 +112,8 @@ export async function makeRequestOptions(
body: binary
? args.data
: JSON.stringify({
...otherArgs,
options: options && otherOptions,
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
...(otherOptions && !isObjectEmpty(otherOptions) && { options: otherOptions }),
}),
...(credentials && { credentials }),
signal: options?.signal,
Expand Down
13 changes: 12 additions & 1 deletion packages/inference/src/tasks/custom/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ export async function request<T>(
task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
/** Is chat completion compatible */
chatCompletion?: boolean;
radames marked this conversation as resolved.
Show resolved Hide resolved
}
): Promise<T> {
const { url, info } = await makeRequestOptions(args, options);
const { url: _url, info } = await makeRequestOptions(args, options);
let url = _url;
if (options?.chatCompletion) {
if (!url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}
}
const response = await (options?.fetch ?? fetch)(url, info);

if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
Expand All @@ -26,6 +34,9 @@ export async function request<T>(
if (!response.ok) {
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
const output = await response.json();
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
}
if (output.error) {
throw new Error(output.error);
}
Expand Down
16 changes: 15 additions & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,17 @@ export async function* streamingRequest<T>(
task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
/** Is chat completion compatible */
chatCompletion?: boolean;
}
): AsyncGenerator<T> {
const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
const { url: _url, info } = await makeRequestOptions({ ...args, stream: true }, options);
let url = _url;
if (options?.chatCompletion) {
if (!url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}
}
const response = await (options?.fetch ?? fetch)(url, info);

if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
Expand All @@ -27,6 +35,9 @@ export async function* streamingRequest<T>(
if (!response.ok) {
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
const output = await response.json();
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
}
if (output.error) {
throw new Error(output.error);
}
Expand Down Expand Up @@ -67,6 +78,9 @@ export async function* streamingRequest<T>(
onChunk(value);
for (const event of events) {
if (event.data.length > 0) {
if (event.data === "[DONE]") {
return;
}
const data = JSON.parse(event.data);
if (typeof data === "object" && data !== null && "error" in data) {
throw new Error(data.error);
Expand Down
2 changes: 2 additions & 0 deletions packages/inference/src/tasks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ export * from "./nlp/textGenerationStream";
export * from "./nlp/tokenClassification";
export * from "./nlp/translation";
export * from "./nlp/zeroShotClassification";
export * from "./nlp/chatCompletion";
export * from "./nlp/chatCompletionStream";

// Multimodal tasks
export * from "./multimodal/documentQuestionAnswering";
Expand Down
Loading
Loading