Simple retry strategy for minor errors #3380

Merged · 9 commits · Sep 25, 2024
1 change: 1 addition & 0 deletions packages/ai/.gitignore
@@ -1,2 +1,3 @@
/dist
.env

1 change: 1 addition & 0 deletions packages/ai/package.json
@@ -26,6 +26,7 @@
"axios": "^1.7.2",
"chalk": "^5.3.0",
"clsx": "^2.1.1",
"dotenv": "^16.4.5",
"next": "14.2.4",
"openai": "^4.56.1",
"ts-node": "^10.9.2",
54 changes: 35 additions & 19 deletions packages/ai/src/LLMStep.ts
@@ -14,6 +14,8 @@ import { LlmId } from "./modelConfigs.js";
import { PromptPair } from "./prompts.js";
import { Workflow } from "./workflows/Workflow.js";

export type ErrorType = "CRITICAL" | "MINOR";

export type StepState =
| {
kind: "PENDING";
@@ -24,8 +26,9 @@
}
| {
kind: "FAILED";
errorType: ErrorType;
durationMs: number;
error: string;
message: string;
};

export type StepShape<
@@ -43,6 +46,7 @@ type ExecuteContext<Shape extends StepShape> = {
): void;
queryLLM(promptPair: PromptPair): Promise<string | null>;
log(log: LogEntry): void;
fail(errorType: ErrorType, message: string): void;
// workflow: Workflow; // intentionally not exposed, but if you need it, add it here
};

@@ -69,9 +73,10 @@ export class LLMStepTemplate<const Shape extends StepShape = StepShape> {

instantiate(
workflow: Workflow,
inputs: Inputs<Shape>
inputs: Inputs<Shape>,
retryingStep?: LLMStepInstance<Shape> | undefined
): LLMStepInstance<Shape> {
return new LLMStepInstance(this, workflow, inputs);
return new LLMStepInstance(this, workflow, inputs, retryingStep);
}
}

@@ -87,17 +92,23 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {
constructor(
public readonly template: LLMStepTemplate<Shape>,
public readonly workflow: Workflow,
private readonly inputs: Inputs<Shape>
public readonly inputs: Inputs<Shape>,
public retryingStep?: LLMStepInstance<Shape> | undefined
) {
this.startTime = Date.now();
this.id = crypto.randomUUID();
this.logger = new Logger();
this.inputs = inputs;
}

getLogs(): TimestampedLogEntry[] {
return this.logger.logs;
}

isRetrying(): boolean {
return !!this.retryingStep;
}

getConversationMessages(): Message[] {
return this.conversationMessages;
}
@@ -109,25 +120,32 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {

const limits = this.workflow.checkResourceLimits();
if (limits) {
this.criticalError(limits);
this.fail("CRITICAL", limits);
return;
}

const executeContext: ExecuteContext<Shape> = {
setOutput: (key, value) => this.setOutput(key, value),
log: (log) => this.log(log),
queryLLM: (promptPair) => this.queryLLM(promptPair),
fail: (errorType, message) => this.fail(errorType, message),
};

try {
await this.template.execute(executeContext, this.inputs);
} catch (error) {
this.criticalError(
this.fail(
"MINOR",
error instanceof Error ? error.message : String(error)
);
return;
}
this.complete();

const hasFailed = (this.state as StepState).kind === "FAILED";

if (!hasFailed) {
this.state = { kind: "DONE", durationMs: this.calculateDuration() };
}
}

getState() {
@@ -167,7 +185,10 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {
value: Outputs<Shape>[K] | Outputs<Shape>[K]["value"]
): void {
if (key in this.outputs) {
this.criticalError(`Output ${key} is already set`);
this.fail(
"CRITICAL",
`Output ${key} is already set. This is a bug with the workflow code.`
);
return;
}

Expand All @@ -186,26 +207,20 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {
this.logger.log(log);
}

private criticalError(error: string) {
this.log({ type: "error", message: error });
private fail(errorType: ErrorType, message: string) {
this.log({ type: "error", message });
this.state = {
kind: "FAILED",
durationMs: this.calculateDuration(),
error,
errorType,
message,
Collaborator:

There's some awkwardness here that makes me want to suggest exceptions as a failure mechanism.

I remember you don't like exceptions, so I'll explain my observations first.

1. You log the error here (reasonable), but then sometimes you log it again, e.g. in FixCodeUntilItRuns:

       context.fail("MINOR", newCode.value);
       context.log({
         type: "codeRunError",
         error: newCode.value,
       });

   This will lead to duplicate error log entries.

   Because you needed to type it as codeRunError... ok, one way around it is to pass the error to fail(), like fail(errorType: ErrorType, message: Extract<LogEntry, { type: "error" | "codeRunError" }>), and then pass that. Fixable.

2. But then there's also the awkwardness that the step code can call fail() and then continue to do things, which would be mostly ignored: the step is already failed, so there should be nothing left to do.

   This also interferes with LLMStep.run(): you had to explicitly cast this.state as StepState there, etc.

So, how about we just use throw new MinorStepError(message) as the main way to fail? Then run() becomes more natural: if we catch an error, we minor-fail and either log the message, or log String(error) if it's not a minor error. And we complete() otherwise.

One reason why I like this is that we already support exceptions; steps can throw, and we don't ask step implementations to carefully try/catch their code. So they might as well throw more meaningful errors.

In other words, the fail("MINOR", "message") function is unnecessary; it's identical to throw new Error("message").

Contributor Author:

Good idea, will experiment here.
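
For illustration, a minimal sketch of the exception-based approach suggested above (MinorStepError is a hypothetical class name, not something this PR implements):

    // Hypothetical: a dedicated error class for recoverable step failures.
    class MinorStepError extends Error {}

    // Inside LLMStepInstance.run(), fail() would fold into ordinary
    // exception handling:
    try {
      await this.template.execute(executeContext, this.inputs);
      this.complete();
    } catch (error) {
      // Minor errors carry their own message; anything else is stringified.
      this.fail(
        "MINOR",
        error instanceof MinorStepError ? error.message : String(error)
      );
    }

Step implementations would then write throw new MinorStepError(msg) where they currently call context.fail("MINOR", msg).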

};
}

private calculateDuration() {
return Date.now() - this.startTime;
}

private complete() {
if (this.state.kind === "FAILED") {
return;
}
this.state = { kind: "DONE", durationMs: this.calculateDuration() };
}

private addConversationMessage(message: Message): void {
this.conversationMessages.push(message);
}
@@ -259,7 +274,8 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {

return completion.content;
} catch (error) {
this.criticalError(
this.fail(
"MINOR",
`Error in queryLLM: ${error instanceof Error ? error.message : error}`
);
return null;
4 changes: 4 additions & 0 deletions packages/ai/src/scripts/createSquiggle.ts
@@ -1,5 +1,9 @@
import { config } from "dotenv";

import { SquiggleWorkflow } from "../workflows/SquiggleWorkflow.js";

config();

async function main() {
const prompt =
"Generate a function that takes a list of numbers and returns the sum of the numbers";
6 changes: 6 additions & 0 deletions packages/ai/src/scripts/editSquiggle.ts
@@ -1,9 +1,15 @@
import { config } from "dotenv";

import { SquiggleWorkflow } from "../workflows/SquiggleWorkflow.js";

config();

async function main() {
const initialCode = `
foo = 0 to 100
bar = 30
foo
bar
`;

const { totalPrice, runTimeMs, llmRunCount, code, isValid, logSummary } =
9 changes: 5 additions & 4 deletions packages/ai/src/steps/fixCodeUntilItRunsStep.ts
@@ -56,16 +56,17 @@

const completion = await context.queryLLM(promptPair);
if (completion) {
const nextState = await diffCompletionContentToCode(
const newCodeResult = await diffCompletionContentToCode(
completion,
code.value
);
if (nextState.ok) {
context.setOutput("code", nextState.value);
if (newCodeResult.ok) {
context.setOutput("code", newCodeResult.value);
} else {
context.fail("MINOR", newCodeResult.value);
context.log({
type: "codeRunError",
error: nextState.value,
error: newCodeResult.value,
});
}
}
16 changes: 9 additions & 7 deletions packages/ai/src/workflows/SquiggleWorkflow.ts
@@ -38,16 +38,18 @@ export class SquiggleWorkflow extends ControlledWorkflow {

protected configureControllerLoop(): void {
this.workflow.addEventListener("stepFinished", ({ data: { step } }) => {
if (step.getState().kind !== "DONE") {
return;
}

// output name is hardcoded, should we scan all outputs?
const code = step.getOutputs()["code"];
if (code?.kind !== "code") {
return;
const state = step.getState();

if (state.kind === "FAILED") {
if (state.errorType === "MINOR") {
this.workflow.addRetryOfPreviousStep();
}
return true;
}

if (code === undefined || code.kind !== "code") return;

if (code.value.type === "success") {
this.workflow.addStep(adjustToFeedbackStep, {
prompt: this.prompt,
29 changes: 27 additions & 2 deletions packages/ai/src/workflows/Workflow.ts
@@ -84,6 +84,8 @@ export type WorkflowEventListener<T extends WorkflowEventType> = (
* See `ControlledWorkflow` for a common base class that controls the workflow
* by injecting new steps based on events.
*/

const MAX_RETRIES = 5;
export class Workflow {
private steps: LLMStepInstance[] = [];
private priceLimit: number;
@@ -121,10 +123,15 @@

addStep<S extends StepShape>(
template: LLMStepTemplate<S>,
inputs: Inputs<S>
inputs: Inputs<S>,
options?: { retryingStep?: LLMStepInstance<S> }
): LLMStepInstance<S> {
// sorry for "any"; countervariance issues
const step: LLMStepInstance<any> = template.instantiate(this, inputs);
const step: LLMStepInstance<any> = template.instantiate(
this,
inputs,
options?.retryingStep
);
this.steps.push(step);
this.dispatchEvent({
type: "stepAdded",
@@ -133,6 +140,24 @@
return step;
}

addRetryOfPreviousStep() {
const lastStep = this.steps.at(-1);
if (!lastStep) return;

const retryingStep = lastStep.retryingStep || lastStep;
const retryAttempts = this.getCurrentRetryAttempts(retryingStep.id);

if (retryAttempts >= MAX_RETRIES) {
return;
}

this.addStep(retryingStep.template, retryingStep.inputs, { retryingStep });
}

public getCurrentRetryAttempts(stepId: string): number {
return this.steps.filter((step) => step.retryingStep?.id === stepId).length;
}

private async runNextStep(): Promise<void> {
const step = this.getCurrentStep();

7 changes: 5 additions & 2 deletions packages/hub/src/app/ai/WorkflowSummaryList.tsx
@@ -1,3 +1,4 @@
import { orderBy } from "lodash";
import { FC } from "react";

import { SerializedWorkflow } from "@quri/squiggle-ai";
@@ -9,9 +10,11 @@ export const WorkflowSummaryList: FC<{
selectedWorkflow: SerializedWorkflow | undefined;
selectWorkflow: (id: string) => void;
}> = ({ workflows, selectedWorkflow, selectWorkflow }) => {
const sortedWorkflows = orderBy(workflows, ["timestamp"], ["desc"]);

return (
<div className="flex flex-col space-y-2">
{workflows.map((workflow) => (
<div className="flex max-h-[400px] flex-col space-y-2 overflow-y-auto pr-2">
{sortedWorkflows.map((workflow) => (
<WorkflowSummaryItem
key={workflow.id}
workflow={workflow}
@@ -10,7 +10,9 @@ export const WorkflowActions: FC<{
height: number;
onNodeClick?: (node: SerializedStep) => void;
}> = ({ workflow, height, onNodeClick }) => {
const [selectedNodeIndex, setSelectedNodeIndex] = useState<number | null>(0);
const [selectedNodeIndex, setSelectedNodeIndex] = useState<number | null>(
workflow.steps.length - 1
);
const prevStepsLengthRef = useRef(workflow.steps.length);

useEffect(() => {
12 changes: 11 additions & 1 deletion packages/hub/src/app/ai/WorkflowViewer/index.tsx
@@ -121,7 +121,17 @@ export const WorkflowViewer: FC<WorkflowViewerProps> = ({
case "loading":
return <LoadingWorkflowViewer {...props} workflow={workflow} />;
case "error":
return <div className="text-red-700">{workflow.result}</div>;
return (
<div className="mt-2 rounded-md border border-red-300 bg-red-50 p-4">
<h3 className="mb-2 text-lg font-semibold text-red-800">
Server Error
</h3>
<p className="mb-4 text-red-700">{workflow.result}</p>
<p className="text-sm text-red-600">
Please try refreshing the page or attempt your action again.
</p>
</div>
);
default:
throw workflow satisfies never;
}
2 changes: 1 addition & 1 deletion packages/hub/src/app/ai/useSquiggleWorkflows.tsx
@@ -83,7 +83,7 @@ export function useSquiggleWorkflows(initialWorkflows: SerializedWorkflow[]) {
updateWorkflow(id, (workflow) => ({
...workflow,
status: "error",
result: `Error: ${error instanceof Error ? error.toString() : "Unknown error"}`,
result: `Server error: ${error instanceof Error ? error.toString() : "Unknown error"}.`,
}));
}
},
4 changes: 3 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default.