Improve Messages API code snippets #700

Merged 4 commits on Jun 28, 2024
20 changes: 19 additions & 1 deletion packages/tasks/src/snippets/curl.ts
@@ -10,6 +10,24 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): stri
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
`;

export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
if (model.config?.tokenizer_config?.chat_template) {
// Conversational model detected, so we display a code snippet that features the Messages API
return `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" \\
-H 'Content-Type: application/json' \\
-d '{
"model": "${model.id}",
"messages": [{"role": "user", "content": "What is the capital of France?"}],
"max_tokens": 500,
"stream": false
}'
`;
} else {
return snippetBasic(model, accessToken);
}
};

export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
`curl https://api-inference.huggingface.co/models/${model.id} \\
-X POST \\
@@ -35,7 +53,7 @@ export const curlSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal
translation: snippetBasic,
summarization: snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text-generation": snippetTextGeneration,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
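For illustration, here is a minimal sketch (not part of the diff) of how the new conversational branch could be exercised; the relative import paths, token, and model metadata below are assumptions:

import { snippetTextGeneration } from "./curl.js"; // assumed relative import
import type { ModelDataMinimal } from "./types.js";

// A model counts as conversational when its tokenizer config carries a chat_template.
const chatModel: ModelDataMinimal = {
	id: "meta-llama/Meta-Llama-3-8B-Instruct", // hypothetical model id
	pipeline_tag: "text-generation",
	config: { tokenizer_config: { chat_template: "..." } },
};

// Prints the `curl .../v1/chat/completions` command rendered from the template above;
// a model without a chat_template would fall back to snippetBasic instead.
console.log(snippetTextGeneration(chatModel, "hf_xxx"));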
50 changes: 25 additions & 25 deletions packages/tasks/src/snippets/inputs.ts
@@ -11,30 +11,30 @@ const inputsSummarization = () =>

const inputsTableQuestionAnswering = () =>
`{
"query": "How many stars does the transformers repository have?",
"table": {
"Repository": ["Transformers", "Datasets", "Tokenizers"],
"Stars": ["36542", "4512", "3934"],
"Contributors": ["651", "77", "34"],
"Programming language": [
"Python",
"Python",
"Rust, Python and NodeJS"
]
}
}`;
"query": "How many stars does the transformers repository have?",
"table": {
"Repository": ["Transformers", "Datasets", "Tokenizers"],
"Stars": ["36542", "4512", "3934"],
"Contributors": ["651", "77", "34"],
"Programming language": [
"Python",
"Python",
"Rust, Python and NodeJS"
]
}
}`;

const inputsVisualQuestionAnswering = () =>
`{
"image": "cat.png",
"question": "What is in this image?"
}`;
"image": "cat.png",
"question": "What is in this image?"
}`;

const inputsQuestionAnswering = () =>
`{
"question": "What is my name?",
"context": "My name is Clara and I live in Berkeley."
}`;
"question": "What is my name?",
"context": "My name is Clara and I live in Berkeley."
}`;

const inputsTextClassification = () => `"I like you. I love you"`;

@@ -48,13 +48,13 @@ const inputsFillMask = (model: ModelDataMinimal) => `"The answer to the universe

const inputsSentenceSimilarity = () =>
`{
"source_sentence": "That is a happy person",
"sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day"
]
}`;
"source_sentence": "That is a happy person",
"sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day"
]
}`;

const inputsFeatureExtraction = () => `"Today is a sunny day and I will get some ice cream."`;

46 changes: 40 additions & 6 deletions packages/tasks/src/snippets/js.ts
@@ -7,7 +7,10 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): stri
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify(data),
}
@@ -20,12 +23,34 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
console.log(JSON.stringify(response));
});`;

export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
if (model.config?.tokenizer_config?.chat_template) {
// Conversational model detected, so we display a code snippet that features the Messages API
return `import { HfInference } from "@huggingface/inference";

const inference = new HfInference("${accessToken || `{API_TOKEN}`}");

for await (const chunk of inference.chatCompletionStream({
model: "${model.id}",
messages: [{ role: "user", content: "What is the capital of France?" }],
max_tokens: 500,
})) {
process.stdout.write(chunk.choices[0]?.delta?.content || "");
}
`;
} else {
return snippetBasic(model, accessToken);
}
};

export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
`async function query(data) {
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify(data),
}
@@ -45,7 +70,10 @@ export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string)
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify(data),
}
@@ -62,7 +90,10 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string)
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify(data),
}
@@ -99,7 +130,10 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): strin
const response = await fetch(
"https://api-inference.huggingface.co/models/${model.id}",
{
headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
headers: {
Authorization: "Bearer ${accessToken || `{API_TOKEN}`}"
"Content-Type": "application/json",
},
method: "POST",
body: data,
}
@@ -122,7 +156,7 @@ export const jsSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal,
translation: snippetBasic,
summarization: snippetBasic,
"feature-extraction": snippetBasic,
"text-generation": snippetBasic,
"text-generation": snippetTextGeneration,
"text2text-generation": snippetBasic,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
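As a quick sanity check, a sketch (assumed import path and model objects) of the two branches the JS dispatcher now takes:

import { snippetTextGeneration, snippetBasic } from "./js.js"; // assumed relative import
import type { ModelDataMinimal } from "./types.js";

const chatModel: ModelDataMinimal = {
	id: "HuggingFaceH4/zephyr-7b-beta", // hypothetical model that ships a chat_template
	pipeline_tag: "text-generation",
	config: { tokenizer_config: { chat_template: "..." } },
};
const plainModel: ModelDataMinimal = { id: "gpt2", pipeline_tag: "text-generation" };

// With a chat_template: a streaming snippet built on @huggingface/inference's chatCompletionStream.
console.log(snippetTextGeneration(chatModel, "")); // empty token renders the {API_TOKEN} placeholder

// Without one: the result is exactly the existing fetch-based snippet.
console.log(snippetTextGeneration(plainModel, "hf_xxx") === snippetBasic(plainModel, "hf_xxx")); // true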
31 changes: 27 additions & 4 deletions packages/tasks/src/snippets/python.ts
@@ -2,6 +2,22 @@ import type { PipelineType } from "../pipelines.js";
import { getModelInputSnippet } from "./inputs.js";
import type { ModelDataMinimal } from "./types.js";

export const snippetConversational = (model: ModelDataMinimal, accessToken: string): string =>
`from huggingface_hub import InferenceClient

client = InferenceClient(
"${model.id}",
token="${accessToken || "{API_TOKEN}"}",
)

for message in client.chat_completion(
messages=[{"role": "user", "content": "What is the capital of France?"}],
max_tokens=500,
stream=True,
):
print(message.choices[0].delta.content, end="")
`;

export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
`def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
@@ -107,7 +123,7 @@ output = query({
"inputs": ${getModelInputSnippet(model)},
})`;

export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => string>> = {
// Same order as in tasks/src/pipelines.ts
"text-classification": snippetBasic,
"token-classification": snippetBasic,
@@ -138,15 +154,22 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
};

export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
const body =
model.pipeline_tag && model.pipeline_tag in pythonSnippets ? pythonSnippets[model.pipeline_tag]?.(model) ?? "" : "";
if (model.pipeline_tag === "text-generation" && model.config?.tokenizer_config?.chat_template) {
// Conversational model detected, so we display a code snippet that features the Messages API
return snippetConversational(model, accessToken);
} else {
const body =
model.pipeline_tag && model.pipeline_tag in pythonSnippets
? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
: "";

return `import requests
return `import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

${body}`;
}
}

export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {
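A sketch of the updated dispatch in getPythonInferenceSnippet (import path and model data are assumptions): conversational models short-circuit to snippetConversational, while other pipelines keep the requests-based template:

import { getPythonInferenceSnippet } from "./python.js"; // assumed relative import
import type { ModelDataMinimal } from "./types.js";

const chatModel: ModelDataMinimal = {
	id: "mistralai/Mistral-7B-Instruct-v0.2", // hypothetical conversational model
	pipeline_tag: "text-generation",
	config: { tokenizer_config: { chat_template: "..." } },
};

// Renders the huggingface_hub InferenceClient / chat_completion snippet.
console.log(getPythonInferenceSnippet(chatModel, "hf_xxx"));

// Any other pipeline (no chat_template involved) still renders the requests-based snippet.
const summarizer: ModelDataMinimal = { id: "facebook/bart-large-cnn", pipeline_tag: "summarization" };
console.log(getPythonInferenceSnippet(summarizer, "hf_xxx"));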
2 changes: 1 addition & 1 deletion packages/tasks/src/snippets/types.ts
@@ -5,4 +5,4 @@ import type { ModelData } from "../model-data";
*
* Add more fields as needed.
*/
export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name">;
export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name" | "config">;