Skip to content

Commit

Permalink
Initial SWE-bench runner and various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
danielcampagnolitg committed Aug 28, 2024
1 parent 414738d commit 34ad474
Show file tree
Hide file tree
Showing 12 changed files with 1,384 additions and 107 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"py": " node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/py.ts",
"code": " node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/code.ts",
"swe": " node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/swe.ts",
"swebench": "node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/swebench.ts",
"scrape": " node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/scrape.ts",
"research": "node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/research.ts",
"util": " node --env-file=variables/local.env -r ts-node/register -r tsconfig-paths/register src/cli/util.ts",
Expand Down
2 changes: 1 addition & 1 deletion src/cli/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export function parseProcessArgs(): CliOptions {
const scriptPath = process.argv[1];
// Script file name without its ".ts" extension, e.g. "swebench" from ".../swebench.ts"
let scriptName = scriptPath.split(path.sep).at(-1);
scriptName = scriptName.substring(0, scriptName.length - 3);
// slice(2) keeps the user arguments after the node binary and script path.
// NOTE: toSpliced(2) would do the OPPOSITE — it returns a copy with the
// elements from index 2 onward REMOVED, i.e. only [node, script].
return parseUserCliArgs(scriptName, process.argv.slice(2));
}

export function parseUserCliArgs(scriptName: string, args: string[]): CliOptions {
Expand Down
2 changes: 1 addition & 1 deletion src/cli/gaia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async function main() {
llms = ClaudeLLMs();
}

// slice(2) drops the node binary and script path, keeping only user arguments.
// (splice(2) happened to work because it RETURNS the removed tail, but mutates
// process.argv; toSpliced(2) is wrong — it returns a copy with the user
// arguments deleted. slice(2) is the non-mutating, correct form.)
const args = process.argv.slice(2);
const questions = JSON.parse(readFileSync(tasksFile).toString()) as GaiaQuestion[];
if (args.length === 0) {
logger.info('Running entire Gaia benchmark...');
Expand Down
72 changes: 72 additions & 0 deletions src/cli/swebench.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import '#fastify/trace-init/trace-init'; // leave an empty line next so this doesn't get sorted from the first line

import { promises as fs, readFileSync } from 'fs';
import { AgentLLMs } from '#agent/agentContext';
import { AGENT_COMPLETED_PARAM_NAME } from '#agent/agentFunctions';
import { RunAgentConfig, startAgent, startAgentAndWait } from '#agent/agentRunner';
import { runAgentWorkflow } from '#agent/agentWorkflowRunner';
import { shutdownTrace } from '#fastify/trace-init/trace-init';
import { GitLab } from '#functions/scm/gitlab';
import { FileSystem } from '#functions/storage/filesystem';
import { UtilFunctions } from '#functions/util';
import { Perplexity } from '#functions/web/perplexity';
import { PublicWeb } from '#functions/web/web';
import { LlmCall } from '#llm/llmCallService/llmCall';
import { ClaudeLLMs } from '#llm/models/anthropic';
import { Claude3_5_Sonnet_Vertex, ClaudeVertexLLMs } from '#llm/models/anthropic-vertex';
import { groqLlama3_70B } from '#llm/models/groq';
import { Gemini_1_5_Flash } from '#llm/models/vertexai';
import { logger } from '#o11y/logger';
import { SWEBenchAgent, SWEInstance } from '#swe/SWEBenchAgent';
import { CodeEditingAgent } from '#swe/codeEditingAgent';
import { sleep } from '#utils/async-utils';
import { appContext, initFirestoreApplicationContext } from '../app';
import { parseProcessArgs, saveAgentId } from './cli';

/**
 * CLI entry point for running a single SWE-bench instance.
 * Reads the instance definition from ./instance.json, runs the SWE-bench
 * inference agent over it, then (only when the ASDF env var is set) falls
 * through to legacy scaffolding that drives the code-editing agent workflow.
 */
async function main() {
	// The benchmark harness writes the task definition to instance.json in the CWD.
	const instance = JSON.parse(readFileSync('instance.json').toString()) as SWEInstance;

	await new SWEBenchAgent().runInference(instance);

	// Everything below is dormant scaffolding, gated behind ASDF for now.
	if (!process.env.ASDF) return;
	// (previously: filtered non-flag CLI args and required an instanceId)

	// Prefer Vertex-hosted Claude (with Firestore app context) when a GCP project is configured.
	let llms: AgentLLMs = ClaudeLLMs();
	if (process.env.GCLOUD_PROJECT) {
		await initFirestoreApplicationContext();
		llms = ClaudeVertexLLMs();
	}

	const { initialPrompt, resumeAgentId } = parseProcessArgs();

	console.log(`Prompt: ${initialPrompt}`);

	const runConfig: RunAgentConfig = {
		agentName: `SWE-Bench ${instance.instance_id}`,
		llms,
		functions: [], //FileSystem,
		initialPrompt,
		resumeAgentId,
		humanInLoop: {
			budget: 4,
		},
	};

	const agentId = await runAgentWorkflow(runConfig, async () => {
		await new CodeEditingAgent().runCodeEditWorkflow(runConfig.initialPrompt);
	});

	// Persist the agent id so a later run can resume this agent.
	if (agentId) saveAgentId('swebench', agentId);

	await shutdownTrace();
}

main().then(
	() => console.log('done'),
	(e) => {
		console.error(e);
		// Signal failure to the shell/CI — without this a crashed run exits 0
		// and the benchmark harness would treat it as a success.
		process.exitCode = 1;
	},
);
31 changes: 24 additions & 7 deletions src/functions/scm/github.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { SourceControlManagement } from '#functions/scm/sourceControlManagement'
import { logger } from '#o11y/logger';
import { functionConfig } from '#user/userService/userContext';
import { envVar } from '#utils/env-var';
import { checkExecResult, execCmd, execCommand, failOnError } from '#utils/exec';
import { checkExecResult, execCmd, execCommand, failOnError, runShellCommand, spawnCommand } from '#utils/exec';
import { GitProject } from './gitProject';

type RequestType = typeof request;
Expand Down Expand Up @@ -86,23 +86,40 @@ export class GitHub implements SourceControlManagement {
// If we're resuming an agent which has already created the branch but not pushed
// then it won't exist remotely, so this will return a non-zero code
if (branchOrCommit) {
// TODO
}
// Fetch all branches and commits
await execCommand(`git -C ${path} fetch --all`, { workingDirectory: path });

// Checkout to the branch or commit
const result = await execCommand(`git -C ${path} checkout ${branchOrCommit}`, { workingDirectory: path });
failOnError(`Failed to checkout ${branchOrCommit} in ${path}`, result);

const result = await execCommand(`git -C ${path} pull`, { workingDirectory: path });
// checkExecResult(result, `Failed to pull ${path}`);
// if (this.checkIfBranch(branchOrCommit)) {
// const pullResult = await execCommand(`git pull`);
// failOnError(`Failed to pull ${path} after checking out ${branchOrCommit}`, pullResult);
// }
}
} else {
logger.info(`Cloning project: ${org}/${project} to ${path}`);
const command = `git clone https://oauth2:${this.config().token}@github.com/${projectPathWithOrg}.git ${path}`;
const result = await execCommand(command, { workingDirectory: path });
const command = `git clone 'https://oauth2:${this.config().token}@github.com/${projectPathWithOrg}.git' ${path}`;
const result = await spawnCommand(command);
// if(result.error) throw result.error
failOnError(`Failed to clone ${projectPathWithOrg}`, result);

const checkoutResult = await execCommand(`git -C ${path} checkout ${branchOrCommit}`, { workingDirectory: path });
failOnError(`Failed to checkout ${branchOrCommit} in ${path}`, checkoutResult);
}
const agent = agentContext();
if (agent) agentContext().memory[`GitHub_project_${org}_${project}_FileSystem_directory`] = path;

return path;
}

/**
 * Checks whether a local branch with the given name exists.
 * @param ref branch name (without the refs/heads/ prefix)
 * @returns true if `refs/heads/<ref>` resolves to a local ref
 */
async checkIfBranch(ref: string): Promise<boolean> {
	// Quote the ref: branch names may legally contain characters the shell
	// would otherwise interpret (e.g. '#', '~' is invalid but '{', '@' occur).
	// NOTE(review): a ref containing a single quote would still break out of
	// the quoting — consider an argv-style exec API if refs are untrusted.
	const result = await execCommand(`git show-ref 'refs/heads/${ref}'`);
	// Non-zero exit means the ref was not found.
	if (result.exitCode) return false;
	return result.stdout.trim().length > 0;
}

@func()
async createMergeRequest(title: string, description: string, sourceBranch: string, targetBranch: string): Promise<string> {
// TODO git push
Expand Down
124 changes: 123 additions & 1 deletion src/llm/models/anthropic-vertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
import { AgentLLMs, addCost, agentContext } from '#agent/agentContext';
import { BaseLLM } from '../base-llm';
import { MaxTokensError } from '../errors';
import { GenerateTextOptions, LLM, combinePrompts, logTextGeneration } from '../llm';
import { GenerateTextOptions, LLM, LlmMessage, combinePrompts, logTextGeneration } from '../llm';
import Message = Anthropic.Message;
import { LlmCall } from '#llm/llmCallService/llmCall';
import { CallerId } from '#llm/llmCallService/llmCallService';
Expand Down Expand Up @@ -210,6 +210,128 @@ class AnthropicVertexLLM extends BaseLLM {
});
}

// Error when
// {"error":{"code":400,"message":"Project `1234567890` is not allowed to use Publisher Model `projects/project-id/locations/us-central1/publishers/anthropic/models/claude-3-haiku@20240307`","status":"FAILED_PRECONDITION"}}
@cacheRetry({ backOffMs: 5000 })
// @logTextGeneration
/**
 * Message-array variant of text generation against Anthropic on Vertex.
 * Extracts an optional leading system message, flattens the remaining
 * messages into a single user prompt, calls the Messages API, persists the
 * request/response as an LlmCall, records cost/token telemetry on the active
 * span, and returns the text of the first content block.
 *
 * NOTE(review): role information on non-system messages is discarded — all
 * remaining messages are joined with '\n' and sent as one user turn; confirm
 * multi-turn conversations are not expected here.
 *
 * @param messages conversation messages; messages[0] may be a system message
 * @param opts optional generation options (id for tracing, stopSequences)
 * @returns the text of the first content block of the response
 * @throws RetryableError on retryable API failures (429/529)
 * @throws MaxTokensError when the response was truncated at max_tokens
 * @throws Error on API error payloads or non-text/empty response content
 */
async generateText2(messages: LlmMessage[], opts?: GenerateTextOptions): Promise<string> {
	return await withActiveSpan(`generateText ${opts?.id ?? ''}`, async (span) => {
		const maxOutputTokens = 4096;

		// Peel off a leading system message, if present; the API takes it separately.
		let systemPrompt: string | undefined;
		if (messages[0].role === 'system') {
			systemPrompt = messages[0].text;
			span.setAttribute('systemPrompt', systemPrompt);
			messages = messages.slice(1);
		}

		// Flatten the remaining messages into one user prompt (roles dropped — see note above).
		const userPrompt = messages.map((msg) => msg.text).join('\n');
		const combinedPrompt = combinePrompts(userPrompt, systemPrompt);

		span.setAttributes({
			userPrompt,
			inputChars: combinedPrompt.length,
			model: this.model,
			service: this.service,
			caller: agentContext().callStack.at(-1) ?? '',
		});
		if (opts?.id) span.setAttribute('id', opts.id);

		// Kick off persistence of the request now; awaited later, after the API call returns.
		const llmCallSave: Promise<LlmCall> = appContext().llmCallService.saveRequest({
			userPrompt,
			systemPrompt,
			llmId: this.getId(),
			agentId: agentContext().agentId,
			callStack: agentContext().callStack.join(' > '),
		});
		const requestTime = Date.now();

		let message: Message;
		try {
			message = await this.api().messages.create({
				system: systemPrompt ? [{ type: 'text', text: systemPrompt }] : undefined,
				messages: [
					{
						role: 'user',
						content: userPrompt,
					},
				],
				model: this.model,
				max_tokens: maxOutputTokens,
				stop_sequences: opts?.stopSequences,
			});
		} catch (e) {
			// Re-classify rate-limit/overload errors so the retry decorator can back off.
			if (this.isRetryableError(e)) {
				throw new RetryableError(e);
			}
			throw e;
		}

		// This started happening randomly!
		// Workaround: the SDK sometimes returns the raw JSON string instead of a parsed object.
		if (typeof message === 'string') {
			message = JSON.parse(message);
		}

		// The API can return an error payload in the message body rather than throwing.
		const errorMessage = message as any;
		if (errorMessage.type === 'error') {
			throw new Error(`${errorMessage.error.type} ${errorMessage.error.message}`);
		}

		if (!message.content.length) throw new Error(`Response Message did not have any content: ${JSON.stringify(message)}`);

		if (message.content[0].type !== 'text') throw new Error(`Message content type was not text. Was ${message.content[0].type}`);

		const responseText = (message.content[0] as TextBlock).text;

		const finishTime = Date.now();
		// Non-streaming call, so "time to first token" here equals total request time.
		const timeToFirstToken = finishTime - requestTime;

		const llmCall: LlmCall = await llmCallSave;

		// TODO
		// NOTE(review): exact token counts are available from usage, but cost below is
		// still estimated from character counts — reconcile these.
		const inputTokens = message.usage.input_tokens;
		const outputTokens = message.usage.output_tokens;

		const inputCost = this.calculateInputCost(combinedPrompt);
		const outputCost = this.calculateOutputCost(responseText);
		const cost = inputCost + outputCost;
		addCost(cost);

		llmCall.responseText = responseText;
		llmCall.timeToFirstToken = timeToFirstToken;
		llmCall.totalTime = finishTime - requestTime;
		llmCall.cost = cost;
		llmCall.inputTokens = inputTokens;
		llmCall.outputTokens = outputTokens;

		span.setAttributes({
			inputTokens,
			outputTokens,
			response: responseText,
			inputCost: inputCost.toFixed(4),
			outputCost: outputCost.toFixed(4),
			cost: cost.toFixed(4),
			outputChars: responseText.length,
			callStack: agentContext().callStack.join(' > '),
		});

		// Best-effort persistence of the response; a save failure must not fail the generation.
		try {
			await appContext().llmCallService.saveResponse(llmCall);
		} catch (e) {
			// queue to save
			logger.error(e);
		}

		// Truncated output is treated as an error — callers expect a complete response.
		if (message.stop_reason === 'max_tokens') {
			// TODO we can replay with request with the current response appended so the LLM can complete it
			logger.error('= RESPONSE exceeded max tokens ===============================');
			logger.debug(responseText);
			throw new MaxTokensError(maxOutputTokens, responseText);
		}
		return responseText;
	});
}

isRetryableError(e: any) {
if (e.status === 429 || e.status === 529) return true;
if (e.error?.code === 429 || e.error?.code === 529) return true;
Expand Down
Loading

0 comments on commit 34ad474

Please sign in to comment.