Merge pull request #3378 from quantified-uncertainty/db-workflows
Store AI workflows in DB
OAGr authored Sep 23, 2024
2 parents 790f109 + c062441 commit 4b51a84
Showing 23 changed files with 576 additions and 269 deletions.
14 changes: 7 additions & 7 deletions packages/ai/src/LLMClient.ts
@@ -87,21 +87,21 @@ export interface LlmMetrics {
   apiCalls: number;
   inputTokens: number;
   outputTokens: number;
-  LlmId: LlmId;
+  llmId: LlmId;
 }
 
 export function calculatePriceMultipleCalls(
   metrics: Partial<Record<LlmId, LlmMetrics>>
 ): number {
   let totalCost = 0;
 
-  for (const [LlmId, { inputTokens, outputTokens }] of Object.entries(
+  for (const [llmId, { inputTokens, outputTokens }] of Object.entries(
     metrics
   )) {
-    const modelConfig = MODEL_CONFIGS.find((model) => model.id === LlmId);
+    const modelConfig = MODEL_CONFIGS.find((model) => model.id === llmId);
 
     if (!modelConfig) {
-      console.warn(`No pricing information found for LLM: ${LlmId}`);
+      console.warn(`No pricing information found for LLM: ${llmId}`);
       continue;
     }
 
@@ -132,7 +132,7 @@ export class LLMClient {
   private anthropicClient?: Anthropic;
 
   constructor(
-    public LlmId: LlmId,
+    public llmId: LlmId,
     openaiApiKey?: string,
     anthropicApiKey?: string
   ) {
@@ -168,11 +168,11 @@ export class LLMClient {
     conversationHistory: Message[]
   ): Promise<StandardizedChatCompletion> {
     const selectedModelConfig = MODEL_CONFIGS.find(
-      (model) => model.id === this.LlmId
+      (model) => model.id === this.llmId
     );
 
     if (!selectedModelConfig) {
-      throw new Error(`No model config found for LLM: ${this.LlmId}`);
+      throw new Error(`No model config found for LLM: ${this.llmId}`);
     }
 
     try {
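The rename from `LlmId` to `llmId` is cosmetic (camelCase for a field that holds a value, versus the `LlmId` type), but it touches every call site. A minimal usage sketch of `calculatePriceMultipleCalls`; the import paths are illustrative, and "some-model" is a placeholder since `LlmId` is really a literal union of `MODEL_CONFIGS` ids:

import { calculatePriceMultipleCalls, type LlmMetrics } from "./LLMClient.js";
import type { LlmId } from "./modelConfigs.js";

// Placeholder id; a real call would use one of the ids from MODEL_CONFIGS.
const llmId = "some-model" as LlmId;

const metrics: Partial<Record<LlmId, LlmMetrics>> = {};
metrics[llmId] = { apiCalls: 3, inputTokens: 12_000, outputTokens: 4_000, llmId };

// Sums cost across models, skipping (with a warning) any id without pricing info.
const totalCost = calculatePriceMultipleCalls(metrics);
console.log(`estimated cost: $${totalCost.toFixed(4)}`);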
4 changes: 2 additions & 2 deletions packages/ai/src/LLMStep.ts
@@ -150,7 +150,7 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {
     const totalCost = calculatePriceMultipleCalls(
       this.llmMetricsList.reduce(
         (acc, metrics) => {
-          acc[metrics.LlmId] = metrics;
+          acc[metrics.llmId] = metrics;
           return acc;
         },
         {} as Record<LlmId, LlmMetrics>
@@ -236,7 +236,7 @@ export class LLMStepInstance<const Shape extends StepShape = StepShape> {
       apiCalls: 1,
       inputTokens: completion?.usage?.prompt_tokens ?? 0,
       outputTokens: completion?.usage?.completion_tokens ?? 0,
-      LlmId: workflow.llmConfig.llmId,
+      llmId: workflow.llmConfig.llmId,
     });
 
     if (!completion?.content) {
8 changes: 4 additions & 4 deletions packages/ai/src/generateSummary.ts
@@ -35,8 +35,8 @@ function generateOverview(workflow: Workflow): string {
   let overview = `- Total Steps: ${steps.length}\n`;
   overview += `- Total Time: ${(totalTime / 1000).toFixed(2)} seconds\n`;
 
-  for (const [LlmId, metrics] of Object.entries(metricsByLLM)) {
-    overview += `- ${LlmId}:\n`;
+  for (const [llmId, metrics] of Object.entries(metricsByLLM)) {
+    overview += `- ${llmId}:\n`;
     overview += `  - API Calls: ${metrics.apiCalls}\n`;
     overview += `  - Input Tokens: ${metrics.inputTokens}\n`;
     overview += `  - Output Tokens: ${metrics.outputTokens}\n`;
@@ -82,8 +82,8 @@ function generateDetailedStepLogs(workflow: Workflow): string {
     detailedLogs += `- ⏱️ Duration: ${step.getDuration() / 1000} seconds\n`;
 
     step.llmMetricsList.forEach((metrics) => {
-      const cost = calculatePriceMultipleCalls({ [metrics.LlmId]: metrics });
-      detailedLogs += `- ${metrics.LlmId}:\n`;
+      const cost = calculatePriceMultipleCalls({ [metrics.llmId]: metrics });
+      detailedLogs += `- ${metrics.llmId}:\n`;
       detailedLogs += `  - API Calls: ${metrics.apiCalls}\n`;
       detailedLogs += `  - Input Tokens: ${metrics.inputTokens}\n`;
       detailedLogs += `  - Output Tokens: ${metrics.outputTokens}\n`;
1 change: 1 addition & 0 deletions packages/ai/src/index.ts
@@ -4,6 +4,7 @@ export {
   type SerializedMessage,
   type SerializedStep,
   type SerializedWorkflow,
+  serializedWorkflowSchema,
   workflowMessageSchema,
   type WorkflowResult,
 } from "./types.js";
2 changes: 1 addition & 1 deletion packages/ai/src/modelConfigs.ts
@@ -68,7 +68,7 @@ export const MODEL_CONFIGS = [
     contextWindow: 131072,
     name: "Llama 3.1",
   },
-] as const;
+] as const satisfies ModelConfig[];
 
 export type LlmId = (typeof MODEL_CONFIGS)[number]["id"];
 export type LlmName = (typeof MODEL_CONFIGS)[number]["name"];
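Switching from plain `as const` to `as const satisfies ModelConfig[]` keeps the literal element types (so `LlmId` above stays a union of concrete id strings) while also checking each entry against the `ModelConfig` shape at the definition site. A self-contained sketch of the pattern with a made-up `Config` type; note the `readonly` modifier used here, since `as const` produces a readonly tuple:

type Config = { id: string; name: string };

const CONFIGS = [
  { id: "a", name: "Alpha" },
  { id: "b", name: "Beta" },
] as const satisfies readonly Config[];

// Ids remain narrow literal types instead of widening to string,
// while a typo in a field name would be caught against Config.
type ConfigId = (typeof CONFIGS)[number]["id"]; // "a" | "b"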
87 changes: 63 additions & 24 deletions packages/ai/src/types.ts
@@ -1,6 +1,17 @@
 import { z } from "zod";
 
-import { type SquiggleWorkflowInput } from "./workflows/SquiggleWorkflow.js";
+// This could be defined in SquiggleWorkflow.ts, but it would cause a dependency on server-only modules.
+export const squiggleWorkflowInputSchema = z.discriminatedUnion("type", [
+  z.object({
+    type: z.literal("Create"),
+    prompt: z.string(),
+  }),
+  z.object({
+    type: z.literal("Edit"),
+    source: z.string(),
+    prompt: z.string().optional(),
+  }),
+]);
 
 // Protocol for streaming workflow changes between server and client.
 
@@ -54,21 +65,16 @@ const stepSchema = z.object({
 
 export type SerializedStep = z.infer<typeof stepSchema>;
 
-// SquiggleWorkflowResult type
+// Messages that incrementally update the SerializedWorkflow.
+// They are used for streaming updates from the server to the client.
+// They are similar to Workflow events, but not exactly the same. They must be JSON-serializable.
+// See `addStreamingListeners` in workflows/streaming.ts for how they are used.
 
-export const workflowResultSchema = z.object({
-  code: z.string().describe("Squiggle code snippet"),
-  isValid: z.boolean(),
-  totalPrice: z.number(),
-  runTimeMs: z.number(),
-  llmRunCount: z.number(),
-  logSummary: z.string(), // markdown
+const workflowStartedSchema = z.object({
+  id: z.string(),
+  timestamp: z.number(),
 });
 
-export type WorkflowResult = z.infer<typeof workflowResultSchema>;
-
-// Messages that incrementally update the SerializedWorkflow
-
 const stepAddedSchema = stepSchema.omit({
   state: true,
   outputs: true,
@@ -81,7 +87,24 @@ const stepUpdatedSchema = stepSchema.partial().required({
   outputs: true,
 });
 
+// WorkflowResult type
+
+export const workflowResultSchema = z.object({
+  code: z.string().describe("Squiggle code snippet"),
+  isValid: z.boolean(),
+  totalPrice: z.number(),
+  runTimeMs: z.number(),
+  llmRunCount: z.number(),
+  logSummary: z.string(), // markdown
+});
+
+export type WorkflowResult = z.infer<typeof workflowResultSchema>;
+
 export const workflowMessageSchema = z.discriminatedUnion("kind", [
+  z.object({
+    kind: z.literal("workflowStarted"),
+    content: workflowStartedSchema,
+  }),
   z.object({
     kind: z.literal("finalResult"),
     content: workflowResultSchema,
@@ -100,14 +123,30 @@ export type WorkflowMessage = z.infer<typeof workflowMessageSchema>;
 
 // Client-side representation of a workflow
 
-export type SerializedWorkflow = {
-  id: string;
-  timestamp: Date;
-  input: SquiggleWorkflowInput; // FIXME - SquiggleWorkflow-specific
-  steps: SerializedStep[];
-  currentStep?: string;
-} & (
-  | { status: "loading"; result?: undefined }
-  | { status: "finished"; result: WorkflowResult }
-  | { status: "error"; result: string }
-);
+const commonWorkflowFields = {
+  id: z.string(),
+  timestamp: z.number(), // milliseconds since epoch
+  input: squiggleWorkflowInputSchema, // FIXME - SquiggleWorkflow-specific
+  steps: z.array(stepSchema),
+  currentStep: z.string().optional(),
+};
+
+export const serializedWorkflowSchema = z.discriminatedUnion("status", [
+  z.object({
+    ...commonWorkflowFields,
+    status: z.literal("loading"),
+    result: z.undefined(),
+  }),
+  z.object({
+    ...commonWorkflowFields,
+    status: z.literal("finished"),
+    result: workflowResultSchema,
+  }),
+  z.object({
+    ...commonWorkflowFields,
+    status: z.literal("error"),
+    result: z.string(),
+  }),
+]);
+
+export type SerializedWorkflow = z.infer<typeof serializedWorkflowSchema>;
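With `serializedWorkflowSchema` defined as a zod discriminated union, the hand-written `SerializedWorkflow` type is now inferred from the schema, and workflows loaded from the DB (or received over the stream) can be validated at runtime. A sketch of client-side consumption, assuming a raw JSON chunk; the import path is illustrative (`index.ts` above re-exports these from the package entry point):

import { workflowMessageSchema } from "./types.js";

// Hypothetical chunk as it might arrive from the streaming endpoint:
const rawChunk = JSON.stringify({
  kind: "workflowStarted",
  content: { id: "abc-123", timestamp: Date.now() },
});

const message = workflowMessageSchema.parse(JSON.parse(rawChunk));
if (message.kind === "workflowStarted") {
  // zod narrows `content` to { id: string; timestamp: number } in this branch.
  console.log(`workflow ${message.content.id} started`);
}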
2 changes: 2 additions & 0 deletions packages/ai/src/workflows/ControlledWorkflow.ts
@@ -60,6 +60,8 @@ export abstract class ControlledWorkflow {
       start: async (controller) => {
         addStreamingListeners(this.workflow, controller);
 
+        this.workflow.prepareToStart();
+
         // Important! `configure` should be called after all event listeners are set up.
         // We want to capture `stepAdded` events.
         this.configure();
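`prepareToStart()` has to run after `addStreamingListeners` (so the `workflowStarted` event is observed) but before `configure()` adds any steps. A minimal, self-contained sketch of why subscription order matters — these are stand-in classes, not the real ones:

type Ev = { type: string };

class TinyEmitter {
  private listeners: Array<(e: Ev) => void> = [];
  on(listener: (e: Ev) => void) {
    this.listeners.push(listener);
  }
  dispatch(type: string) {
    this.listeners.forEach((l) => l({ type }));
  }
}

const emitter = new TinyEmitter();
emitter.on((e) => console.log("saw:", e.type)); // subscribe first...
emitter.dispatch("workflowStarted"); // ...so this event is captured
// Dispatching before .on() would silently drop the event.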
7 changes: 4 additions & 3 deletions packages/ai/src/workflows/SquiggleWorkflow.ts
@@ -1,14 +1,15 @@
+import { z } from "zod";
+
 import { PromptArtifact, SourceArtifact } from "../Artifact.js";
 import { adjustToFeedbackStep } from "../steps/adjustToFeedbackStep.js";
 import { fixCodeUntilItRunsStep } from "../steps/fixCodeUntilItRunsStep.js";
 import { generateCodeStep } from "../steps/generateCodeStep.js";
 import { runAndFormatCodeStep } from "../steps/runAndFormatCodeStep.js";
+import { squiggleWorkflowInputSchema } from "../types.js";
 import { ControlledWorkflow } from "./ControlledWorkflow.js";
 import { LlmConfig } from "./Workflow.js";
 
-export type SquiggleWorkflowInput =
-  | { type: "Create"; prompt: string }
-  | { type: "Edit"; source: string; prompt?: string };
+export type SquiggleWorkflowInput = z.infer<typeof squiggleWorkflowInputSchema>;
 
 /**
  * This is a basic workflow for generating Squiggle code.
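Deriving `SquiggleWorkflowInput` via `z.infer` makes the schema the single source of truth for the input shape. A quick sketch of what the inferred type looks like (paths as in the file above):

import { z } from "zod";
import { squiggleWorkflowInputSchema } from "../types.js";

type SquiggleWorkflowInput = z.infer<typeof squiggleWorkflowInputSchema>;
// Equivalent to the hand-written union it replaces:
// | { type: "Create"; prompt: string }
// | { type: "Edit"; source: string; prompt?: string }

const input: SquiggleWorkflowInput = { type: "Create", prompt: "model the sun" };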
32 changes: 23 additions & 9 deletions packages/ai/src/workflows/Workflow.ts
@@ -32,6 +32,10 @@ export const llmConfigDefault: LlmConfig = {
 };
 
 export type WorkflowEventShape =
+  | {
+      type: "workflowStarted";
+      payload?: undefined;
+    }
   | {
       type: "stepAdded";
       payload: {
@@ -84,9 +88,10 @@ export class Workflow {
   private steps: LLMStepInstance[] = [];
   private priceLimit: number;
   private durationLimitMs: number;
-  private startTime: number;
 
   public llmClient: LLMClient;
+  public id: string;
+  public startTime: number;
 
   constructor(
     public llmConfig: LlmConfig = llmConfigDefault,
@@ -95,7 +100,9 @@
   ) {
     this.priceLimit = llmConfig.priceLimit;
     this.durationLimitMs = llmConfig.durationLimitMinutes * 1000 * 60;
+
     this.startTime = Date.now();
+    this.id = crypto.randomUUID();
 
     this.llmClient = new LLMClient(
       llmConfig.llmId,
@@ -104,6 +111,14 @@
     );
   }
 
+  // This is a hook that ControlledWorkflow can use to prepare the workflow.
+  // It's a bit of a hack; we need to dispatch this event after we've configured the event handlers,
+  // but before we add any steps.
+  // So we can't do this in either the constructor or `runUntilComplete`.
+  prepareToStart() {
+    this.dispatchEvent({ type: "workflowStarted" });
+  }
+
   addStep<S extends StepShape>(
     template: LLMStepTemplate<S>,
     inputs: Inputs<S>
@@ -143,9 +158,8 @@
     while (!this.isProcessComplete()) {
       await this.runNextStep();
     }
-    this.dispatchEvent({
-      type: "allStepsFinished",
-    });
+
+    this.dispatchEvent({ type: "allStepsFinished" });
   }
 
   checkResourceLimits(): string | undefined {
@@ -216,12 +230,12 @@
     return this.getSteps().reduce(
       (acc, step) => {
         step.llmMetricsList.forEach((metrics) => {
-          if (!acc[metrics.LlmId]) {
-            acc[metrics.LlmId] = { ...metrics };
+          if (!acc[metrics.llmId]) {
+            acc[metrics.llmId] = { ...metrics };
           } else {
-            acc[metrics.LlmId].apiCalls += metrics.apiCalls;
-            acc[metrics.LlmId].inputTokens += metrics.inputTokens;
-            acc[metrics.LlmId].outputTokens += metrics.outputTokens;
+            acc[metrics.llmId].apiCalls += metrics.apiCalls;
+            acc[metrics.llmId].inputTokens += metrics.inputTokens;
+            acc[metrics.llmId].outputTokens += metrics.outputTokens;
           }
         });
         return acc;
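The reduce above groups per-call metrics by `llmId`, summing counts per model. A standalone sketch of the same aggregation pattern, with a hypothetical metrics shape:

type Metrics = {
  llmId: string;
  apiCalls: number;
  inputTokens: number;
  outputTokens: number;
};

const calls: Metrics[] = [
  { llmId: "model-a", apiCalls: 1, inputTokens: 100, outputTokens: 40 },
  { llmId: "model-a", apiCalls: 1, inputTokens: 250, outputTokens: 90 },
  { llmId: "model-b", apiCalls: 1, inputTokens: 80, outputTokens: 20 },
];

const byModel = calls.reduce<Record<string, Metrics>>((acc, m) => {
  if (!acc[m.llmId]) {
    acc[m.llmId] = { ...m }; // clone so the += below never mutates the source
  } else {
    acc[m.llmId].apiCalls += m.apiCalls;
    acc[m.llmId].inputTokens += m.inputTokens;
    acc[m.llmId].outputTokens += m.outputTokens;
  }
  return acc;
}, {});
// byModel["model-a"] -> { llmId: "model-a", apiCalls: 2, inputTokens: 350, outputTokens: 130 }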
(Diffs for the remaining 14 changed files are not shown.)
