From 3b06087afe19cc1e6b04ec78cefee752201762c2 Mon Sep 17 00:00:00 2001
From: Rob Gordon
Date: Wed, 10 Jan 2024 11:36:38 -0500
Subject: [PATCH] Basic ai editing working

---
 api/_lib/_llm.ts                        |  70 +++++++++
 api/package.json                        |   8 +-
 api/prompt/edit.ts                      |  42 +++++
 app/package.json                        |   2 +-
 app/src/components/EditWithAI.tsx       | 200 ++++++++++++++++++++++++
 app/src/components/EditWithAIButton.tsx |  45 ------
 app/src/lib/graphOptions.ts             |   2 +-
 app/src/pages/Sandbox.tsx               |  38 +----
 pnpm-lock.yaml                          |  49 +++---
 9 files changed, 350 insertions(+), 106 deletions(-)
 create mode 100644 api/_lib/_llm.ts
 create mode 100644 api/prompt/edit.ts
 create mode 100644 app/src/components/EditWithAI.tsx
 delete mode 100644 app/src/components/EditWithAIButton.tsx

diff --git a/api/_lib/_llm.ts b/api/_lib/_llm.ts
new file mode 100644
index 00000000..9db8e18b
--- /dev/null
+++ b/api/_lib/_llm.ts
@@ -0,0 +1,70 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
+import { z, ZodObject } from "zod";
+import { openai } from "./_openai";
+import zodToJsonSchema from "zod-to-json-schema";
+import OpenAI from "openai";
+
+type Schemas<T extends Record<string, ZodObject<any>>> = T;
+
+export async function llmMany<T extends Record<string, ZodObject<any>>>(
+  content: string,
+  schemas: Schemas<T>
+) {
+  try {
+    // if the user passes a key "message" in schemas, throw an error
+    if (schemas.message) throw new Error("Cannot use key 'message' in schemas");
+
+    const completion = await openai.chat.completions.create({
+      messages: [
+        {
+          role: "user",
+          content,
+        },
+      ],
+      tools: Object.entries(schemas).map(([key, schema]) => ({
+        type: "function",
+        function: {
+          name: key,
+          parameters: zodToJsonSchema(schema),
+        },
+      })),
+      model: "gpt-3.5-turbo-1106",
+      // model: "gpt-4-1106-preview",
+    });
+
+    const choice = completion.choices[0];
+
+    if (!choice) throw new Error("No choices returned");
+
+    // Must return the full thing, message and multiple tool calls
+    return simplifyChoice(choice) as SimplifiedChoice<T>;
+  } catch (error) {
+    console.error(error);
+    const message = (error as Error)?.message || "Error with prompt";
+    throw new Error(message);
+  }
+}
+
+type SimplifiedChoice<T extends Record<string, ZodObject<any>>> = {
+  message: string;
+  toolCalls: Array<
+    {
+      [K in keyof T]: {
+        name: K;
+        args: z.infer<T[K]>;
+      };
+    }[keyof T]
+  >;
+};
+
+function simplifyChoice(choice: OpenAI.Chat.Completions.ChatCompletion.Choice) {
+  return {
+    message: choice.message.content || "",
+    toolCalls:
+      choice.message.tool_calls?.map((toolCall) => ({
+        name: toolCall.function.name,
+        // Wish this were type-safe!
+        args: JSON.parse(toolCall.function.arguments ?? "{}"),
+      })) || [],
+  };
+}
diff --git a/api/package.json b/api/package.json
index 2c5a777a..0047e7b8 100644
--- a/api/package.json
+++ b/api/package.json
@@ -18,14 +18,16 @@
     "ajv": "^8.12.0",
     "csv-parse": "^5.3.6",
     "date-fns": "^2.29.3",
-    "graph-selector": "^0.9.11",
+    "graph-selector": "^0.10.0",
     "highlight.js": "^11.8.0",
     "marked": "^4.1.1",
     "moniker": "^0.1.2",
     "notion-to-md": "^2.5.5",
-    "openai": "^4.10.0",
+    "openai": "^4.24.2",
     "shared": "workspace:*",
-    "stripe": "^11.11.0"
+    "stripe": "^11.11.0",
+    "zod": "^3.22.4",
+    "zod-to-json-schema": "^3.22.3"
   },
   "devDependencies": {
     "@swc/jest": "^0.2.24",
diff --git a/api/prompt/edit.ts b/api/prompt/edit.ts
new file mode 100644
index 00000000..c6e3cfb0
--- /dev/null
+++ b/api/prompt/edit.ts
@@ -0,0 +1,42 @@
+import { VercelApiHandler } from "@vercel/node";
+import { llmMany } from "../_lib/_llm";
+import { z } from "zod";
+
+const nodeSchema = z.object({
+  // id: z.string(),
+  // classes: z.string(),
+  label: z.string(),
+});
+
+const edgeSchema = z.object({
+  from: z.string(),
+  to: z.string(),
+  label: z.string().optional().default(""),
+});
+
+const graphSchema = z.object({
+  nodes: z.array(nodeSchema),
+  edges: z.array(edgeSchema),
+});
+
+const handler: VercelApiHandler = async (req, res) => {
+  const { graph, prompt } = req.body;
+  if (!graph || !prompt) {
+    throw new Error("Missing graph or prompt");
+  }
+
+  const result = await llmMany(
+    `You are an AI flowchart assistant. Help the user create a flowchart or diagram. Here is the current state of the flowchart:
+${JSON.stringify(graph, null, 2)}
+
+Here is the user's message:
+${prompt}`,
+    {
+      updateGraph: graphSchema,
+    }
+  );
+
+  res.json(result);
+};
+
+export default handler;
diff --git a/app/package.json b/app/package.json
index 16b6c2f2..a6a366a7 100644
--- a/app/package.json
+++ b/app/package.json
@@ -77,7 +77,7 @@
     "file-saver": "^2.0.5",
     "formulaic": "workspace:*",
     "framer-motion": "^10.13.1",
-    "graph-selector": "^0.9.12",
+    "graph-selector": "^0.10.0",
     "gray-matter": "^4.0.2",
     "highlight.js": "^11.7.0",
     "immer": "^9.0.16",
diff --git a/app/src/components/EditWithAI.tsx b/app/src/components/EditWithAI.tsx
new file mode 100644
index 00000000..87ee3324
--- /dev/null
+++ b/app/src/components/EditWithAI.tsx
@@ -0,0 +1,200 @@
+import { MagicWand, Microphone } from "phosphor-react";
+import { Button2, IconButton2 } from "../ui/Shared";
+import * as Popover from "@radix-ui/react-popover";
+import { Trans, t } from "@lingui/macro";
+import { useCallback, useRef, useState } from "react";
+import { useDoc } from "../lib/useDoc";
+import { parse, stringify, Graph as GSGraph } from "graph-selector";
+import { useMutation } from "react-query";
+
+// The Graph type we send to AI is slightly different from internal representation
+type GraphForAI = {
+  nodes: {
+    label: string;
+    id?: string;
+  }[];
+  edges: {
+    label: string;
+    from: string;
+    to: string;
+  }[];
+};
+
+export function EditWithAI() {
+  const [isOpen, setIsOpen] = useState(false);
+  const { mutate: edit, isLoading } = useMutation({
+    mutationFn: async (body: { prompt: string; graph: GraphForAI }) => {
+      // /api/prompt/edit
+      const response = await fetch("/api/prompt/edit", {
+        method: "POST",
+        body: JSON.stringify(body),
+        headers: {
+          "Content-Type": "application/json",
+        },
+      });
+      const data = await response.json();
+      return data as {
+        message: string;
+        toolCalls: {
+          name: "updateGraph";
+          args: GraphForAI;
+        }[];
+      };
+    },
+    onMutate: () => setIsOpen(false),
+    onSuccess(data) {
+      if (data.message) {
+        window.alert(data.message);
+      }
+
+      for (const { name, args } of data.toolCalls) {
+        switch (name) {
+          case "updateGraph": {
+            const newText = toGraphSelector(args);
+            useDoc.setState({ text: newText }, false, "EditWithAI");
+            break;
+          }
+        }
+      }
+    },
+  });
+  const handleSubmit = useCallback(
+    (e: React.FormEvent<HTMLFormElement>) => {
+      e.preventDefault();
+
+      const formData = new FormData(e.currentTarget);
+      const prompt = formData.get("prompt") as string;
+      if (!prompt) return;
+
+      const text = useDoc.getState().text;
+      const _graph = parse(text);
+
+      const graph: GraphForAI = {
+        nodes: _graph.nodes.map((node) => {
+          if (isCustomID(node.data.id)) {
+            return {
+              label: node.data.label,
+              id: node.data.id,
+            };
+          }
+
+          return {
+            label: node.data.label,
+          };
+        }),
+        edges: _graph.edges.map((edge) => {
+          // Because generated edges internally use a custom ID,
+          // we need to find the label, unless the user is using a custom ID
+
+          let from = edge.source;
+          if (!isCustomID(from)) {
+            // find the from node
+            const fromNode = _graph.nodes.find((node) => node.data.id === from);
+            if (!fromNode) throw new Error("from node not found");
+            from = fromNode.data.label;
+          }
+
+          let to = edge.target;
+          if (!isCustomID(to)) {
+            // find the to node
+            const toNode = _graph.nodes.find((node) => node.data.id === to);
+            if (!toNode) throw new Error("to node not found");
+            to = toNode.data.label;
+          }
+
+          return {
+            label: edge.data.label,
+            from,
+            to,
+          };
+        }),
+      };
+
+      edit({ prompt, graph });
+    },
+    [edit]
+  );
+
+  const formRef = useRef<HTMLFormElement>(null);
+
+  return (
+    <Popover.Root open={isOpen} onOpenChange={setIsOpen}>
+      <Popover.Trigger asChild>
+        <Button2
+          leftIcon={<MagicWand />}
+          color="zinc"
+          size="sm"
+          rounded
+          className="aria-[expanded=true]:bg-zinc-700"
+          isLoading={isLoading}
+        >
+          <Trans>Edit with AI</Trans>
+        </Button2>
+      </Popover.Trigger>
+
+
+
+