Adapt generate-snippets-fixtures script for providers #1137

Merged
3 changes: 2 additions & 1 deletion packages/tasks-gen/package.json
@@ -26,6 +26,7 @@
"type-fest": "^3.13.1"
},
"dependencies": {
"@huggingface/tasks": "workspace:^"
"@huggingface/tasks": "workspace:^",
"@huggingface/inference": "workspace:^"
}
}
3 changes: 3 additions & 0 deletions packages/tasks-gen/pnpm-lock.yaml

Some generated files are not rendered by default.

52 changes: 36 additions & 16 deletions packages/tasks-gen/scripts/generate-snippets-fixtures.ts
@@ -19,7 +19,7 @@ import { existsSync as pathExists } from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path/posix";

import type { InferenceSnippet } from "@huggingface/tasks";
import type { InferenceProvider, InferenceSnippet } from "@huggingface/tasks";
import { snippets } from "@huggingface/tasks";

type LANGUAGE = "sh" | "js" | "py";
@@ -28,6 +28,7 @@ const TEST_CASES: {
testName: string;
model: snippets.ModelDataMinimal;
languages: LANGUAGE[];
providers: InferenceProvider[];
opts?: Record<string, unknown>;
}[] = [
{
@@ -39,6 +40,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference", "together"],
opts: { streaming: false },
},
{
@@ -50,6 +52,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference"],
opts: { streaming: true },
},
{
@@ -61,6 +64,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference"],
opts: { streaming: false },
},
{
@@ -72,6 +76,7 @@
inference: "",
},
languages: ["sh", "js", "py"],
providers: ["hf-inference"],
opts: { streaming: true },
},
{
Expand All @@ -82,6 +87,7 @@ const TEST_CASES: {
tags: [],
inference: "",
},
providers: ["hf-inference"],
languages: ["sh", "js", "py"],
},
] as const;
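For reference, a complete TEST_CASES entry now looks roughly like the sketch below. Only the providers field and its values are taken verbatim from this diff; the testName and the model fields hidden by the collapsed context are illustrative assumptions:

{
    testName: "conversational-llm-non-stream", // hypothetical name
    model: {
        id: "meta-llama/Llama-3.1-8B-Instruct", // id borrowed from the fixtures below
        pipeline_tag: "text-generation", // assumed
        tags: ["conversational"], // assumed
        inference: "",
    },
    languages: ["sh", "js", "py"],
    providers: ["hf-inference", "together"],
    opts: { streaming: false },
},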
@@ -113,31 +119,41 @@ function getFixtureFolder(testName: string): string {
function generateInferenceSnippet(
model: snippets.ModelDataMinimal,
language: LANGUAGE,
provider: InferenceProvider,
opts?: Record<string, unknown>
): InferenceSnippet[] {
const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", opts);
const generatedSnippets = GET_SNIPPET_FN[language](model, "api_token", provider, opts);
return Array.isArray(generatedSnippets) ? generatedSnippets : [generatedSnippets];
}

async function getExpectedInferenceSnippet(testName: string, language: LANGUAGE): Promise<InferenceSnippet[]> {
async function getExpectedInferenceSnippet(
testName: string,
language: LANGUAGE,
provider: InferenceProvider
): Promise<InferenceSnippet[]> {
const fixtureFolder = getFixtureFolder(testName);
const files = await fs.readdir(fixtureFolder);

const expectedSnippets: InferenceSnippet[] = [];
for (const file of files.filter((file) => file.endsWith("." + language)).sort()) {
const client = path.basename(file).split(".").slice(1, -1).join("."); // e.g. '0.huggingface.js.js' => "huggingface.js"
for (const file of files.filter((file) => file.endsWith("." + language) && file.includes(`.${provider}.`)).sort()) {
const client = path.basename(file).split(".").slice(1, -2).join("."); // e.g. '0.huggingface.js.replicate.js' => "huggingface.js"
const content = await fs.readFile(path.join(fixtureFolder, file), { encoding: "utf-8" });
expectedSnippets.push(client === "default" ? { content } : { client, content });
}
return expectedSnippets;
}

async function saveExpectedInferenceSnippet(testName: string, language: LANGUAGE, snippets: InferenceSnippet[]) {
async function saveExpectedInferenceSnippet(
testName: string,
language: LANGUAGE,
provider: InferenceProvider,
snippets: InferenceSnippet[]
) {
const fixtureFolder = getFixtureFolder(testName);
await fs.mkdir(fixtureFolder, { recursive: true });

for (const [index, snippet] of snippets.entries()) {
const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${language}`);
const file = path.join(fixtureFolder, `${index}.${snippet.client ?? "default"}.${provider}.${language}`);
await fs.writeFile(file, snippet.content);
}
}
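Fixture file names now follow the ${index}.${client}.${provider}.${language} pattern, so the client is recovered by dropping the leading index and the trailing provider and language segments. A minimal standalone check of that parsing, using an assumed file name:

const file = "0.huggingface.js.hf-inference.js"; // illustrative fixture name
const client = file.split(".").slice(1, -2).join("."); // drops the index "0" and the trailing "hf-inference.js"
console.log(client); // "huggingface.js"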
@@ -147,13 +163,15 @@ if (import.meta.vitest) {
const { describe, expect, it } = import.meta.vitest;

describe("inference API snippets", () => {
TEST_CASES.forEach(({ testName, model, languages, opts }) => {
TEST_CASES.forEach(({ testName, model, languages, providers, opts }) => {
describe(testName, () => {
languages.forEach((language) => {
it(language, async () => {
const generatedSnippets = generateInferenceSnippet(model, language, opts);
const expectedSnippets = await getExpectedInferenceSnippet(testName, language);
expect(generatedSnippets).toEqual(expectedSnippets);
providers.forEach((provider) => {
it(language, async () => {
const generatedSnippets = generateInferenceSnippet(model, language, provider, opts);
const expectedSnippets = await getExpectedInferenceSnippet(testName, language, provider);
expect(generatedSnippets).toEqual(expectedSnippets);
});
});
});
});
@@ -166,11 +184,13 @@ if (import.meta.vitest) {
await fs.rm(path.join(rootDirFinder(), "snippets-fixtures"), { recursive: true, force: true });

console.debug(" 🏭 Generating new fixtures...");
TEST_CASES.forEach(({ testName, model, languages, opts }) => {
console.debug(` ${testName} (${languages.join(", ")})`);
TEST_CASES.forEach(({ testName, model, languages, providers, opts }) => {
console.debug(` ${testName} (${languages.join(", ")}) (${providers.join(", ")})`);
languages.forEach(async (language) => {
const generatedSnippets = generateInferenceSnippet(model, language, opts);
await saveExpectedInferenceSnippet(testName, language, generatedSnippets);
providers.forEach(async (provider) => {
const generatedSnippets = generateInferenceSnippet(model, language, provider, opts);
await saveExpectedInferenceSnippet(testName, language, provider, generatedSnippets);
});
});
});
console.log("✅ All done!");
@@ -0,0 +1,14 @@
curl 'https://huggingface.co/api/inference-proxy/together/v1/chat/completions' \
-H 'Authorization: Bearer api_token' \
-H 'Content-Type: application/json' \
--data '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"messages": [
{
"role": "user",
"content": "What is the capital of France?"
}
],
"max_tokens": 500,
"stream": false
}'
@@ -0,0 +1,17 @@
import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const chatCompletion = await client.chatCompletion({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [
{
role: "user",
content: "What is the capital of France?"
}
],
provider: "hf-inference",
max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
@@ -10,6 +10,7 @@ const chatCompletion = await client.chatCompletion({
content: "What is the capital of France?"
}
],
provider: "together",
max_tokens: 500
});

@@ -0,0 +1,21 @@
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token"
)

messages = [
{
"role": "user",
"content": "What is the capital of France?"
}
]

completion = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
max_tokens=500
)

print(completion.choices[0].message)
@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")
client = InferenceClient(
provider="together",
api_key="api_token"
)

messages = [
{
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
@@ -0,0 +1,19 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://huggingface.co/api/inference-proxy/together",
apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [
{
role: "user",
content: "What is the capital of France?"
}
],
max_tokens: 500
});

console.log(chatCompletion.choices[0].message);
@@ -0,0 +1,21 @@
from openai import OpenAI

client = OpenAI(
base_url="https://huggingface.co/api/inference-proxy/together",
api_key="api_token"
)

messages = [
{
"role": "user",
"content": "What is the capital of France?"
}
]

completion = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
max_tokens=500
)

print(completion.choices[0].message)
@@ -12,6 +12,7 @@ const stream = client.chatCompletionStream({
content: "What is the capital of France?"
}
],
provider: "hf-inference",
max_tokens: 500
});

@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")
client = InferenceClient(
provider="hf-inference",
api_key="api_token"
)

messages = [
{
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
apiKey: "api_token"
});

let out = "";
@@ -21,6 +21,7 @@ const chatCompletion = await client.chatCompletion({
]
}
],
provider: "hf-inference",
max_tokens: 500
});

@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")
client = InferenceClient(
provider="hf-inference",
api_key="api_token"
)

messages = [
{
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
});

const chatCompletion = await client.chat.completions.create({
@@ -23,6 +23,7 @@ const stream = client.chatCompletionStream({
]
}
],
provider: "hf-inference",
max_tokens: 500
});

@@ -1,6 +1,9 @@
from huggingface_hub import InferenceClient

client = InferenceClient(api_key="api_token")
client = InferenceClient(
provider="hf-inference",
api_key="api_token"
)

messages = [
{
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api-inference.huggingface.co/v1/",
apiKey: "api_token"
apiKey: "api_token"
});

let out = "";
@@ -0,0 +1,11 @@
import { HfInference } from "@huggingface/inference";

const client = new HfInference("api_token");

const image = await client.textToImage({
model: "black-forest-labs/FLUX.1-schnell",
inputs: "Astronaut riding a horse",
parameters: { num_inference_steps: 5 },
provider: "hf-inference",
});
/// Use the generated image (it's a Blob)
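As an aside, one way to consume that Blob in Node is sketched below; this is not part of the fixture and assumes Node 18+ in an ESM context like the snippet above (the output path is illustrative):

import { writeFile } from "node:fs/promises";

const buffer = Buffer.from(await image.arrayBuffer()); // Blob -> ArrayBuffer -> Buffer
await writeFile("astronaut.png", buffer);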