Add 8b model #39

Merged · 1 commit · Aug 15, 2024
src/commands/local.ts (4 changes: 2 additions & 2 deletions)

@@ -1,6 +1,6 @@
 import { cli } from "../cli.js";
 import { llama } from "../plugins/local-llm-rename/llama.js";
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../local-models.js";
+import { DEFAULT_MODEL } from "../local-models.js";
 import { unminify } from "../unminify.js";
 import prettier from "../plugins/prettier.js";
 import babel from "../plugins/babel/babel.js";

@@ -26,7 +26,7 @@ export const local = cli()
     }

     const prompt = await llama({
-      modelPath: getEnsuredModelPath(opts.model),
+      model: opts.model,
       disableGPU: opts.disableGPU,
       seed: opts.seed ? parseInt(opts.seed) : undefined
     });
src/local-models.ts (24 changes: 20 additions & 4 deletions)

@@ -8,20 +8,36 @@ import { showProgress } from "./progress.js";
 import { err } from "./cli-error.js";
 import { homedir } from "os";
 import { join } from "path";
+import { ChatWrapper, Llama3_1ChatWrapper } from "node-llama-cpp";

 const MODEL_DIRECTORY = join(homedir(), ".humanifyjs", "models");

-export const MODELS: { [modelName: string]: URL } = {
-  "2gb": url`https://huggingface.co/bartowski/Phi-3.1-mini-4k-instruct-GGUF/resolve/main/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf?download=true`
+type ModelDefinition = { url: URL; wrapper?: ChatWrapper };
+
+export const MODELS: { [modelName: string]: ModelDefinition } = {
+  "2gb": {
+    url: url`https://huggingface.co/bartowski/Phi-3.1-mini-4k-instruct-GGUF/resolve/main/Phi-3.1-mini-4k-instruct-Q4_K_M.gguf?download=true`
+  },
+  "8b": {
+    url: url`https://huggingface.co/lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf?download=true`,
+    wrapper: new Llama3_1ChatWrapper()
+  }
 };

 async function ensureModelDirectory() {
   await fs.mkdir(MODEL_DIRECTORY, { recursive: true });
 }

+export function getModelWrapper(model: string) {
+  if (!(model in MODELS)) {
+    err(`Model ${model} not found`);
+  }
+  return MODELS[model].wrapper;
+}
+
 export async function downloadModel(model: string) {
   await ensureModelDirectory();
-  const url = MODELS[model];
+  const url = MODELS[model].url;
   if (url === undefined) {
     err(`Model ${model} not found`);
   }

@@ -54,7 +70,7 @@ export function getModelPath(model: string) {
   if (!(model in MODELS)) {
     err(`Model ${model} not found`);
   }
-  const filename = basename(MODELS[model].pathname);
+  const filename = basename(MODELS[model].url.pathname);
   return `${MODEL_DIRECTORY}/${filename}`;
 }
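
For orientation, each registry entry now bundles the download URL with an optional chat wrapper, and the exported helpers read both. A minimal sketch of how the pieces fit together, assuming the "8b" key and helper names exactly as added above; the import path is illustrative:

import { downloadModel, getModelPath, getModelWrapper } from "./local-models.js";

// Fetch the Llama 3.1 8B GGUF from the registry URL into ~/.humanifyjs/models.
await downloadModel("8b");

// Resolve the cached file path from the URL's basename.
const modelPath = getModelPath("8b");
// -> ~/.humanifyjs/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf

// Pick the chat wrapper: a Llama3_1ChatWrapper instance for "8b",
// undefined for "2gb", which has no wrapper defined.
const wrapper = getModelWrapper("8b");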
src/plugins/local-llm-rename/define-filename.llmtest.ts (8 changes: 2 additions & 6 deletions)

@@ -1,13 +1,9 @@
 import test from "node:test";
-import { llama } from "./llama.js";
 import { assertMatches } from "../../test-utils.js";
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../../local-models.js";
 import { defineFilename } from "./define-filename.js";
+import { testPrompt } from "../../test/test-prompt.js";

-const prompt = await llama({
-  seed: 1,
-  modelPath: getEnsuredModelPath(process.env["MODEL"] ?? DEFAULT_MODEL)
-});
+const prompt = await testPrompt();

 test("Defines a good name for a file with a function", async () => {
   const result = await defineFilename(prompt, "const a = b => b + 1;");
src/plugins/local-llm-rename/llama.ts (8 changes: 5 additions & 3 deletions)

@@ -1,5 +1,6 @@
 import { getLlama, LlamaChatSession, LlamaGrammar } from "node-llama-cpp";
 import { Gbnf } from "./gbnf.js";
+import { getModelPath, getModelWrapper } from "../../local-models.js";

 export type Prompt = (
   systemPrompt: string,

@@ -11,12 +12,12 @@ const IS_CI = process.env["CI"] === "true";

 export async function llama(opts: {
   seed?: number;
-  modelPath: string;
+  model: string;
   disableGPU?: boolean;
 }): Promise<Prompt> {
   const llama = await getLlama();
   const model = await llama.loadModel({
-    modelPath: opts?.modelPath,
+    modelPath: getModelPath(opts?.model),
     gpuLayers: (opts?.disableGPU ?? IS_CI) ? 0 : undefined
   });

@@ -26,7 +27,8 @@ export async function llama(opts: {
     const session = new LlamaChatSession({
       contextSequence: context.getSequence(),
       autoDisposeSequence: true,
-      systemPrompt
+      systemPrompt,
+      chatWrapper: getModelWrapper(opts.model)
     });
     const response = await session.promptWithMeta(userPrompt, {
       temperature: 0.8,
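
With this change, callers hand llama() a model name rather than a file path; the path comes from getModelPath() and the per-model wrapper is passed straight to LlamaChatSession. A minimal sketch of the new call shape, assuming the "8b" registry key from this PR; the seed and GPU values are illustrative:

import { llama } from "./llama.js";

const prompt = await llama({
  model: "8b",       // looked up in MODELS; resolved to a GGUF path via getModelPath()
  seed: 1,           // optional
  disableGPU: false  // an explicit value overrides the CI default (opts.disableGPU ?? IS_CI)
});
// For "2gb", getModelWrapper() returns undefined, presumably leaving
// chat-wrapper selection to LlamaChatSession's own default.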
src/test/test-prompt.ts (4 changes: 2 additions & 2 deletions)

@@ -1,8 +1,8 @@
-import { DEFAULT_MODEL, getEnsuredModelPath } from "../local-models.js";
+import { DEFAULT_MODEL } from "../local-models.js";
 import { llama } from "../plugins/local-llm-rename/llama.js";

 export const testPrompt = async () =>
   await llama({
     seed: 1,
-    modelPath: getEnsuredModelPath(process.env["MODEL"] ?? DEFAULT_MODEL)
+    model: process.env["MODEL"] ?? DEFAULT_MODEL
   });
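
The LLM tests now share this helper, so they can target the new model without further changes; a sketch, assuming the model weights have already been downloaded and that MODEL is set before testPrompt() runs:

import { testPrompt } from "./test-prompt.js";

// Hypothetical override: point the tests at the 8b model instead of DEFAULT_MODEL.
process.env["MODEL"] = "8b";
const prompt = await testPrompt(); // loads the weights via llama({ model: "8b", seed: 1 })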