Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions app/components/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import { useReferralCode, useReferralStats } from '~/lib/hooks/useReferralCode';
import { useUsage } from '~/lib/stores/usage';
import { hasAnyApiKeySet, hasApiKeySet } from '~/lib/common/apiKey';
import { chatSyncState } from '~/lib/stores/startup/chatSyncState';
import { customSystemPromptStore } from '~/lib/stores/customSystemPrompt';

const logger = createScopedLogger('Chat');

Expand Down Expand Up @@ -348,6 +349,7 @@ export const Chat = memo(
);

const characterCounts = chatContextManager.current.calculatePromptCharacterCounts(preparedMessages);
const customSystemPrompt = customSystemPromptStore.get();

return {
messages: preparedMessages,
Expand All @@ -367,6 +369,7 @@ export const Chat = memo(
featureFlags: {
enableResend,
},
customSystemPrompt,
};
},
maxSteps: 64,
Expand Down
397 changes: 273 additions & 124 deletions app/components/chat/MessageInput.tsx

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions app/lib/.server/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,21 @@ export async function chatAction({ request }: ActionFunctionArgs) {
recordRawPromptsForDebugging?: boolean;
collapsedMessages: boolean;
promptCharacterCounts?: PromptCharacterCounts;
customSystemPrompt?: string | null;
featureFlags: {
enableResend?: boolean;
};
};
const { messages, firstUserMessage, chatInitialId, deploymentName, token, teamSlug, recordRawPromptsForDebugging } =
body;
const {
messages,
firstUserMessage,
chatInitialId,
deploymentName,
token,
teamSlug,
recordRawPromptsForDebugging,
customSystemPrompt,
} = body;

if (getEnv('DISABLE_BEDROCK') === '1' && body.modelProvider === 'Bedrock') {
body.modelProvider = 'Anthropic';
Expand Down Expand Up @@ -186,6 +195,7 @@ export async function chatAction({ request }: ActionFunctionArgs) {
featureFlags: {
enableResend: body.featureFlags.enableResend ?? false,
},
customSystemPrompt: typeof customSystemPrompt === 'string' ? customSystemPrompt : undefined,
});

return new Response(dataStream, {
Expand Down
13 changes: 12 additions & 1 deletion app/lib/.server/llm/convex-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export async function convexAgent(args: {
featureFlags: {
enableResend: boolean;
};
customSystemPrompt?: string;
}) {
const {
chatInitialId,
Expand All @@ -73,6 +74,7 @@ export async function convexAgent(args: {
collapsedMessages,
promptCharacterCounts,
featureFlags,
customSystemPrompt,
} = args;
console.debug('Starting agent with model provider', modelProvider);
if (userApiKey) {
Expand Down Expand Up @@ -111,9 +113,18 @@ export async function convexAgent(args: {
role: 'system' as const,
content: generalSystemPrompt(opts),
},
...cleanupAssistantMessages(messages),
];

const trimmedCustomPrompt = customSystemPrompt?.trim();
if (trimmedCustomPrompt) {
messagesForDataStream.push({
role: 'system' as const,
content: trimmedCustomPrompt,
});
}

messagesForDataStream.push(...cleanupAssistantMessages(messages));

if (modelProvider === 'Bedrock') {
messagesForDataStream[messagesForDataStream.length - 1].providerOptions = {
bedrock: {
Expand Down
3 changes: 3 additions & 0 deletions app/lib/stores/customSystemPrompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import { atom } from 'nanostores';

export const customSystemPromptStore = atom<string | undefined>(undefined);
4 changes: 4 additions & 0 deletions app/lib/stores/startup/useInitialMessages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import * as lz4 from 'lz4-wasm';
import { getConvexSiteUrl } from '~/lib/convexSiteUrl';
import { subchatIndexStore } from '~/lib/stores/subchats';
import { useStore } from '@nanostores/react';
import { customSystemPromptStore } from '~/lib/stores/customSystemPrompt';

export interface InitialMessages {
loadedChatId: string;
Expand Down Expand Up @@ -43,6 +44,7 @@ export function useInitialMessages(chatId: string | undefined):
});
if (chatInfo === null) {
setInitialMessages(null);
customSystemPromptStore.set(undefined);
return;
}
if (subchatIndex === undefined) {
Expand Down Expand Up @@ -75,6 +77,7 @@ export function useInitialMessages(chatId: string | undefined):
deserialized: [],
loadedSubchatIndex: subchatIndexToFetch,
});
customSystemPromptStore.set(chatInfo.customSystemPrompt ?? undefined);
return;
}
const content = await initialMessagesResponse.arrayBuffer();
Expand Down Expand Up @@ -118,6 +121,7 @@ export function useInitialMessages(chatId: string | undefined):
loadedSubchatIndex: subchatIndexToFetch,
});
description.set(chatInfo.description);
customSystemPromptStore.set(chatInfo.customSystemPrompt ?? undefined);
} catch (error) {
toast.error('Failed to load chat messages from Convex. Try reloading the page.');
console.error('Error fetching initial messages:', error);
Expand Down
22 changes: 22 additions & 0 deletions convex/messages.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,28 @@ describe("messages", () => {
await expect(
t.mutation(api.messages.setDescription, { sessionId, id: chatId, description: "test" }),
).rejects.toThrow();
await expect(
t.mutation(api.messages.setCustomSystemPrompt, { sessionId, id: chatId, customSystemPrompt: "Prompt" }),
).rejects.toThrow();
});

test("set custom system prompt", async () => {
const { sessionId, chatId } = await createChat(t);

const prompt = "Use concise language.";
await t.mutation(api.messages.setCustomSystemPrompt, { sessionId, id: chatId, customSystemPrompt: prompt });

const chatWithPrompt = await t.query(api.messages.get, { sessionId, id: chatId });
expect(chatWithPrompt?.customSystemPrompt).toBe(prompt);

await t.mutation(api.messages.setCustomSystemPrompt, { sessionId, id: chatId, customSystemPrompt: null });
const chatWithoutPrompt = await t.query(api.messages.get, { sessionId, id: chatId });
expect(chatWithoutPrompt?.customSystemPrompt).toBeUndefined();

const tooLongPrompt = "a".repeat(2001);
await expect(
t.mutation(api.messages.setCustomSystemPrompt, { sessionId, id: chatId, customSystemPrompt: tooLongPrompt }),
).rejects.toThrow("Custom system prompt must be at most 2000 characters long");
});

test("store chat without snapshot", async () => {
Expand Down
34 changes: 34 additions & 0 deletions convex/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export type SerializedMessage = Omit<AIMessage, "createdAt" | "content"> & {

export const CHAT_NOT_FOUND_ERROR = new ConvexError({ code: "NotFound", message: "Chat not found" });

const MAX_CUSTOM_SYSTEM_PROMPT_LENGTH = 2000;

export const initializeChat = mutation({
args: {
sessionId: v.id("sessions"),
Expand Down Expand Up @@ -102,6 +104,36 @@ export const setDescription = mutation({
},
});

export const setCustomSystemPrompt = mutation({
args: {
sessionId: v.id("sessions"),
id: v.string(),
customSystemPrompt: v.union(v.string(), v.null()),
},
returns: v.null(),
handler: async (ctx, args) => {
const { id, customSystemPrompt } = args;
const existing = await getChatByIdOrUrlIdEnsuringAccess(ctx, { id, sessionId: args.sessionId });

if (!existing) {
throw CHAT_NOT_FOUND_ERROR;
}

const normalizedPrompt = customSystemPrompt === null ? "" : customSystemPrompt.trim();

if (normalizedPrompt.length > MAX_CUSTOM_SYSTEM_PROMPT_LENGTH) {
throw new ConvexError({
code: "InvalidArgument",
message: `Custom system prompt must be at most ${MAX_CUSTOM_SYSTEM_PROMPT_LENGTH} characters long`,
});
}

await ctx.db.patch(existing._id, {
customSystemPrompt: normalizedPrompt.length > 0 ? normalizedPrompt : undefined,
});
},
});

export async function getChat(ctx: QueryCtx, id: string, sessionId: Id<"sessions">) {
const chat = await getChatByIdOrUrlIdEnsuringAccess(ctx, { id, sessionId });

Expand All @@ -114,6 +146,7 @@ export async function getChat(ctx: QueryCtx, id: string, sessionId: Id<"sessions
initialId: chat.initialId,
urlId: chat.urlId,
description: chat.description,
customSystemPrompt: chat.customSystemPrompt,
timestamp: chat.timestamp,
snapshotId: chat.snapshotId,
subchatIndex: chat.lastSubchatIndex,
Expand All @@ -131,6 +164,7 @@ export const get = query({
initialId: v.string(),
urlId: v.optional(v.string()),
description: v.optional(v.string()),
customSystemPrompt: v.optional(v.string()),
timestamp: v.string(),
snapshotId: v.optional(v.id("_storage")),
subchatIndex: v.optional(v.number()),
Expand Down
1 change: 1 addition & 0 deletions convex/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ export default defineSchema({
initialId: v.string(),
urlId: v.optional(v.string()),
description: v.optional(v.string()),
customSystemPrompt: v.optional(v.string()),
timestamp: v.string(),
metadata: v.optional(v.any()), // TODO migration to remove this column
snapshotId: v.optional(v.id("_storage")),
Expand Down