
Commit ece8725 (merge of two parents: f49e855 and 506c17a)

10 files changed: +91, -36 lines

app/api/config/route.ts (+1)

@@ -13,6 +13,7 @@ const DANGER_CONFIG = {
   hideBalanceQuery: serverConfig.hideBalanceQuery,
   disableFastLink: serverConfig.disableFastLink,
   customModels: serverConfig.customModels,
+  defaultModel: serverConfig.defaultModel,
 };

 declare global {
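This exposes the server-side defaultModel to clients through the /api/config response, alongside customModels. A minimal sketch of reading it on the client, mirroring the fetch that app/store/access.ts (below) performs; the field name comes from this diff, and error handling is omitted:

// Hypothetical helper: read the new field from /api/config.
async function fetchDefaultModel(): Promise<string> {
  const res = await fetch("/api/config", { method: "post", body: null });
  const config: { defaultModel?: string } = await res.json();
  return config.defaultModel ?? "";
}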

app/client/platforms/google.ts (+4 −9)

@@ -21,11 +21,10 @@ export class GeminiProApi implements LLMApi {
   }
   async chat(options: ChatOptions): Promise<void> {
     // const apiClient = this;
-    const visionModel = isVisionModel(options.config.model);
     let multimodal = false;
     const messages = options.messages.map((v) => {
       let parts: any[] = [{ text: getMessageTextContent(v) }];
-      if (visionModel) {
+      if (isVisionModel(options.config.model)) {
         const images = getMessageImages(v);
         if (images.length > 0) {
           multimodal = true;

@@ -117,17 +116,12 @@ export class GeminiProApi implements LLMApi {
     const controller = new AbortController();
     options.onController?.(controller);
     try {
-      let googleChatPath = visionModel
-        ? Google.VisionChatPath(modelConfig.model)
-        : Google.ChatPath(modelConfig.model);
-      let chatPath = this.path(googleChatPath);
-
       // let baseUrl = accessStore.googleUrl;

       if (!baseUrl) {
         baseUrl = isApp
-          ? DEFAULT_API_HOST + "/api/proxy/google/" + googleChatPath
-          : chatPath;
+          ? DEFAULT_API_HOST + "/api/proxy/google/" + Google.ChatPath(modelConfig.model)
+          : this.path(Google.ChatPath(modelConfig.model));
       }

       if (isApp) {

@@ -145,6 +139,7 @@ export class GeminiProApi implements LLMApi {
         () => controller.abort(),
         REQUEST_TIMEOUT_MS,
       );
+
       if (shouldStream) {
         let responseText = "";
         let remainText = "";
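The dropped VisionChatPath branch was dead weight: as the app/constant.ts hunk below shows, VisionChatPath and ChatPath built the identical URL, so vision and text requests can share one path builder. For illustration, using ChatPath as defined in this repo:

// Google.ChatPath from app/constant.ts; both model families hit the same route.
const ChatPath = (modelName: string) =>
  `v1beta/models/${modelName}:generateContent`;

ChatPath("gemini-pro");        // "v1beta/models/gemini-pro:generateContent"
ChatPath("gemini-pro-vision"); // "v1beta/models/gemini-pro-vision:generateContent"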

app/client/platforms/openai.ts (+1 −1)

@@ -129,7 +129,7 @@ export class ChatGPTApi implements LLMApi {
     };

     // add max_tokens to vision model
-    if (visionModel) {
+    if (visionModel && modelConfig.model.includes("preview")) {
       requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
     }
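The added includes("preview") check narrows the max_tokens override to preview-suffixed vision models, so vision-capable non-preview models (gpt-4-turbo, for example) no longer receive it. A sketch of the branch with hypothetical values:

// Hypothetical config: only a preview vision model receives the 4000 floor.
const model = "gpt-4-vision-preview";
const configuredMaxTokens = 1024; // stands in for modelConfig.max_tokens
if (model.includes("preview")) {
  const maxTokens = Math.max(configuredMaxTokens, 4000); // 4000
}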

app/components/chat.tsx (+22 −7)

@@ -448,10 +448,20 @@ export function ChatActions(props: {
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
   const allModels = useAllModels();
-  const models = useMemo(
-    () => allModels.filter((m) => m.available),
-    [allModels],
-  );
+  const models = useMemo(() => {
+    const filteredModels = allModels.filter((m) => m.available);
+    const defaultModel = filteredModels.find((m) => m.isDefault);
+
+    if (defaultModel) {
+      const arr = [
+        defaultModel,
+        ...filteredModels.filter((m) => m !== defaultModel),
+      ];
+      return arr;
+    } else {
+      return filteredModels;
+    }
+  }, [allModels]);
   const [showModelSelector, setShowModelSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);

@@ -467,7 +477,10 @@ export function ChatActions(props: {
     // switch to first available model
     const isUnavaliableModel = !models.some((m) => m.name === currentModel);
     if (isUnavaliableModel && models.length > 0) {
-      const nextModel = models[0].name as ModelType;
+      // switch to the default model if one exists
+      let nextModel: ModelType = (
+        models.find((model) => model.isDefault) || models[0]
+      ).name;
       chatStore.updateCurrentSession(
         (session) => (session.mask.modelConfig.model = nextModel),
       );

@@ -1102,11 +1115,13 @@ function _Chat() {
     };
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, []);
-
+
   const handlePaste = useCallback(
     async (event: React.ClipboardEvent<HTMLTextAreaElement>) => {
       const currentModel = chatStore.currentSession().mask.modelConfig.model;
-      if(!isVisionModel(currentModel)){return;}
+      if (!isVisionModel(currentModel)) {
+        return;
+      }
       const items = (event.clipboardData || window.clipboardData).items;
       for (const item of items) {
         if (item.kind === "file" && item.type.startsWith("image/")) {
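The rewritten memo performs a stable move-to-front: if an available model carries isDefault, it leads the selector list while the rest keep their relative order, and the same flag drives the fallback in the second hunk, which now prefers the default model over models[0] when the current model disappears. A self-contained sketch with hypothetical models:

type Model = { name: string; available: boolean; isDefault?: boolean };

// Hypothetical input: the default model sits in the middle of the list.
const allModels: Model[] = [
  { name: "gpt-3.5-turbo", available: true },
  { name: "gpt-4-0613", available: true, isDefault: true },
  { name: "gpt-4-32k", available: false },
];

const filtered = allModels.filter((m) => m.available);
const def = filtered.find((m) => m.isDefault);
const models = def ? [def, ...filtered.filter((m) => m !== def)] : filtered;
// models.map((m) => m.name) => ["gpt-4-0613", "gpt-3.5-turbo"]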

app/config/server.ts (+4)

@@ -21,6 +21,7 @@ declare global {
   ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
   DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
   CUSTOM_MODELS?: string; // to control custom models
+  DEFAULT_MODEL?: string; // to control the default model in every new chat window

   // azure only
   AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name}

@@ -59,12 +60,14 @@ export const getServerSideConfig = () => {

   const disableGPT4 = !!process.env.DISABLE_GPT4;
   let customModels = process.env.CUSTOM_MODELS ?? "";
+  let defaultModel = process.env.DEFAULT_MODEL ?? "";

   if (disableGPT4) {
     if (customModels) customModels += ",";
     customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
       .map((m) => "-" + m.name)
       .join(",");
+    if (defaultModel.startsWith("gpt-4")) defaultModel = "";
   }

   const isAzure = !!process.env.AZURE_URL;

@@ -116,6 +119,7 @@ export const getServerSideConfig = () => {
     hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
     disableFastLink: !!process.env.DISABLE_FAST_LINK,
     customModels,
+    defaultModel,
     whiteWebDevEndpoints,
   };
 };
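DEFAULT_MODEL is read next to CUSTOM_MODELS, and the existing DISABLE_GPT4 branch now also clears it when it names a model from the disabled family. A minimal sketch of that guard with hypothetical env values:

// Hypothetical env: DEFAULT_MODEL=gpt-4-0613, DISABLE_GPT4=1
let defaultModel = process.env.DEFAULT_MODEL ?? ""; // "gpt-4-0613"
const disableGPT4 = !!process.env.DISABLE_GPT4; // true
if (disableGPT4 && defaultModel.startsWith("gpt-4")) {
  defaultModel = ""; // a disabled model family cannot be the default
}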

app/constant.ts (−11)

@@ -99,7 +99,6 @@ export const Azure = {
 export const Google = {
   ExampleEndpoint: "https://generativelanguage.googleapis.com/",
   ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
-  VisionChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
 };

 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang

@@ -128,8 +127,6 @@ export const KnowledgeCutOffDate: Record<string, string> = {
   "gpt-4-turbo": "2023-12",
   "gpt-4-turbo-2024-04-09": "2023-12",
   "gpt-4-turbo-preview": "2023-12",
-  "gpt-4-1106-preview": "2023-04",
-  "gpt-4-0125-preview": "2023-12",
   "gpt-4-vision-preview": "2023-04",
   // After improvements,
   // it's now easier to add "KnowledgeCutOffDate" instead of stupid hardcoding it, as was done previously.

@@ -139,19 +136,11 @@ export const KnowledgeCutOffDate: Record<string, string> = {

 const openaiModels = [
   "gpt-3.5-turbo",
-  "gpt-3.5-turbo-0301",
-  "gpt-3.5-turbo-0613",
   "gpt-3.5-turbo-1106",
   "gpt-3.5-turbo-0125",
-  "gpt-3.5-turbo-16k",
-  "gpt-3.5-turbo-16k-0613",
   "gpt-4",
-  "gpt-4-0314",
   "gpt-4-0613",
-  "gpt-4-1106-preview",
-  "gpt-4-0125-preview",
   "gpt-4-32k",
-  "gpt-4-32k-0314",
   "gpt-4-32k-0613",
   "gpt-4-turbo",
   "gpt-4-turbo-preview",

app/store/access.ts (+9)

@@ -8,6 +8,7 @@ import { getHeaders } from "../client/api";
 import { getClientConfig } from "../config/client";
 import { createPersistStore } from "../utils/store";
 import { ensure } from "../utils/clone";
+import { DEFAULT_CONFIG } from "./config";

 let fetchState = 0; // 0 not fetch, 1 fetching, 2 done

@@ -48,6 +49,7 @@ const DEFAULT_ACCESS_STATE = {
   disableGPT4: false,
   disableFastLink: false,
   customModels: "",
+  defaultModel: "",
 };

 export const useAccessStore = createPersistStore(

@@ -100,6 +102,13 @@ export const useAccessStore = createPersistStore(
       },
     })
       .then((res) => res.json())
+      .then((res) => {
+        // Set default model from env request
+        let defaultModel = res.defaultModel ?? "";
+        DEFAULT_CONFIG.modelConfig.model =
+          defaultModel !== "" ? defaultModel : "gpt-3.5-turbo";
+        return res;
+      })
       .then((res: DangerConfig) => {
         console.log("[Config] got config from server", res);
         set(() => ({ ...res }));
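Design note: the inserted .then link mutates the shared DEFAULT_CONFIG.modelConfig.model, so sessions created after the /api/config round-trip pick up the server's default, with "gpt-3.5-turbo" restored as the fallback when the server sends an empty value; sessions that already exist keep their configured model.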

app/styles/globals.scss (+5)

@@ -86,6 +86,7 @@
     @include dark;
   }
 }
+
 html {
   height: var(--full-height);

@@ -110,6 +111,10 @@ body {
   @media only screen and (max-width: 600px) {
     background-color: var(--second);
   }
+
+  *:focus-visible {
+    outline: none;
+  }
 }

 ::-webkit-scrollbar {

app/utils/hooks.ts (+3 −2)

@@ -1,14 +1,15 @@
 import { useMemo } from "react";
 import { useAccessStore, useAppConfig } from "../store";
-import { collectModels } from "./model";
+import { collectModels, collectModelsWithDefaultModel } from "./model";

 export function useAllModels() {
   const accessStore = useAccessStore();
   const configStore = useAppConfig();
   const models = useMemo(() => {
-    return collectModels(
+    return collectModelsWithDefaultModel(
       configStore.models,
       [configStore.customModels, accessStore.customModels].join(","),
+      accessStore.defaultModel,
     );
   }, [accessStore.customModels, configStore.customModels, configStore.models]);
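One caveat: the useMemo dependency list was not extended with accessStore.defaultModel, so a change to the default model alone will not recompute the list until one of the listed dependencies changes too.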

app/utils/model.ts (+42 −6)

@@ -1,5 +1,11 @@
 import { LLMModel } from "../client/api";

+const customProvider = (modelName: string) => ({
+  id: modelName,
+  providerName: "",
+  providerType: "custom",
+});
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,

@@ -11,6 +17,7 @@ export function collectModelTable(
     name: string;
     displayName: string;
     provider?: LLMModel["provider"]; // Marked as optional
+    isDefault?: boolean;
   }
 > = {};

@@ -22,12 +29,6 @@ export function collectModelTable(
     };
   });

-  const customProvider = (modelName: string) => ({
-    id: modelName,
-    providerName: "",
-    providerType: "custom",
-  });
-
   // server custom models
   customModels
     .split(",")

@@ -52,6 +53,27 @@ export function collectModelTable(
       };
     }
   });
+
+  return modelTable;
+}
+
+export function collectModelTableWithDefaultModel(
+  models: readonly LLMModel[],
+  customModels: string,
+  defaultModel: string,
+) {
+  let modelTable = collectModelTable(models, customModels);
+  if (defaultModel && defaultModel !== "") {
+    delete modelTable[defaultModel];
+    modelTable[defaultModel] = {
+      name: defaultModel,
+      displayName: defaultModel,
+      available: true,
+      provider:
+        modelTable[defaultModel]?.provider ?? customProvider(defaultModel),
+      isDefault: true,
+    };
+  }
   return modelTable;
 }

@@ -67,3 +89,17 @@ export function collectModels(

   return allModels;
 }
+
+export function collectModelsWithDefaultModel(
+  models: readonly LLMModel[],
+  customModels: string,
+  defaultModel: string,
+) {
+  const modelTable = collectModelTableWithDefaultModel(
+    models,
+    customModels,
+    defaultModel,
+  );
+  const allModels = Object.values(modelTable);
+  return allModels;
+}
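A self-contained usage sketch of the new helpers, with hypothetical model entries whose provider shape mirrors the customProvider factory above. Note that the delete runs before the object literal is evaluated, so the modelTable[defaultModel]?.provider lookup is always undefined at that point and the fallback customProvider(defaultModel) always wins:

import { collectModelsWithDefaultModel } from "./model";

// Hypothetical inputs; the provider shape matches customProvider above.
const openaiProvider = {
  id: "openai",
  providerName: "OpenAI",
  providerType: "openai",
};
const models = [
  { name: "gpt-3.5-turbo", available: true, provider: openaiProvider },
  { name: "gpt-4-0613", available: true, provider: openaiProvider },
];

const all = collectModelsWithDefaultModel(models, "", "gpt-4-0613");
// The "gpt-4-0613" entry is re-inserted with isDefault: true; the useMemo in
// app/components/chat.tsx (above) then moves it to the front of the picker.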
