Commit 015c602

Preserve the routing from 32k to 4
1 parent 9c530bc · commit 015c602

File tree

1 file changed: +27 -8 lines changed

front/lib/api/assistant/generation.ts

Lines changed: 27 additions & 8 deletions

@@ -9,7 +9,11 @@ import {
   renderRetrievalActionForModel,
   retrievalMetaPrompt,
 } from "@app/lib/api/assistant/actions/retrieval";
-import { getSupportedModelConfig } from "@app/lib/assistant";
+import {
+  getSupportedModelConfig,
+  GPT_4_32K_MODEL_ID,
+  GPT_4_MODEL_CONFIG,
+} from "@app/lib/assistant";
 import { Authenticator } from "@app/lib/auth";
 import { CoreAPI } from "@app/lib/core_api";
 import { redisClient } from "@app/lib/redis";

@@ -328,13 +332,15 @@ export async function* runGeneration(
     return;
   }

-  const contextSize = getSupportedModelConfig(c.model).contextSize;
+  let model = c.model;
+
+  const contextSize = getSupportedModelConfig(model).contextSize;

   const MIN_GENERATION_TOKENS = 2048;

   if (contextSize < MIN_GENERATION_TOKENS) {
     throw new Error(
-      `Model contextSize unexpectedly small for model: ${c.model.providerId} ${c.model.modelId}`
+      `Model contextSize unexpectedly small for model: ${model.providerId} ${model.modelId}`
     );
   }

@@ -343,7 +349,7 @@ export async function* runGeneration(
   // Turn the conversation into a digest that can be presented to the model.
   const modelConversationRes = await renderConversationForModel({
     conversation,
-    model: c.model,
+    model,
     prompt,
     allowedTokenCount: contextSize - MIN_GENERATION_TOKENS,
   });

@@ -356,17 +362,30 @@ export async function* runGeneration(
       messageId: agentMessage.sId,
       error: {
         code: "internal_server_error",
-        message: `Failed tokenization for ${c.model.providerId} ${c.model.modelId}: ${modelConversationRes.error.message}`,
+        message: `Failed tokenization for ${model.providerId} ${model.modelId}: ${modelConversationRes.error.message}`,
       },
     };
     return;
   }

+  // If model is gpt4-32k but tokens used is less than GPT_4_CONTEXT_SIZE-MIN_GENERATION_TOKENS,
+  // then we override the model to gpt4 standard (8k context, cheaper).
+  if (
+    model.modelId === GPT_4_32K_MODEL_ID &&
+    modelConversationRes.value.tokensUsed <
+      GPT_4_MODEL_CONFIG.contextSize - MIN_GENERATION_TOKENS
+  ) {
+    model = {
+      modelId: GPT_4_MODEL_CONFIG.modelId,
+      providerId: GPT_4_MODEL_CONFIG.providerId,
+    };
+  }
+
   const config = cloneBaseConfig(
     DustProdActionRegistry["assistant-v2-generator"].config
   );
-  config.MODEL.provider_id = c.model.providerId;
-  config.MODEL.model_id = c.model.modelId;
+  config.MODEL.provider_id = model.providerId;
+  config.MODEL.model_id = model.modelId;
   config.MODEL.temperature = c.temperature;

   // This is the console.log you want to uncomment to generate inputs for the generator app.

@@ -381,7 +400,7 @@ export async function* runGeneration(
     {
       workspaceId: conversation.owner.sId,
       conversationId: conversation.sId,
-      model: c.model,
+      model: model,
       temperature: c.temperature,
     },
     "[ASSISTANT_TRACE] Generation exection"
