Commit b92ac82

Merge pull request #3422 from continuedev/v0.8.64-vscode-release

V0.8.64-vscode-release

sestinj authored Dec 17, 2024
2 parents d2d3c30 + d291dbd commit b92ac82

Showing 12 changed files with 96 additions and 32 deletions.
4 changes: 4 additions & 0 deletions .changes/unreleased/Fixed-20241216-220802.yaml

```diff
@@ -0,0 +1,4 @@
+project: extensions/vscode
+kind: Fixed
+body: Fix tool use bug for models that don't support tools
+time: 2024-12-16T22:08:02.399772-08:00
```
4 changes: 4 additions & 0 deletions .changes/unreleased/Fixed-20241216-220818.yaml

```diff
@@ -0,0 +1,4 @@
+project: extensions/vscode
+kind: Fixed
+body: Autodetect mistral API key type
+time: 2024-12-16T22:08:18.369444-08:00
```
3 changes: 2 additions & 1 deletion core/llm/autodetect.ts

```diff
@@ -85,8 +85,9 @@ const MODEL_SUPPORTS_IMAGES: string[] = [
   "llama3.2",
 ];
 
-function modelSupportsTools(modelName: string) {
+function modelSupportsTools(modelName: string, provider: string) {
   return (
+    provider === "anthropic" &&
     modelName.includes("claude") &&
     (modelName.includes("3-5") || modelName.includes("3.5"))
   );
```
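The effect of this change: tool support is now gated on the provider as well as the model name, so a Claude model reached through a non-Anthropic provider no longer advertises tool support. A standalone sketch of the updated check (the model/provider strings in the calls below are illustrative, not part of the commit):

```ts
// Both gates must pass: Anthropic as the provider, and a 3.5-series Claude model.
function modelSupportsTools(modelName: string, provider: string): boolean {
  return (
    provider === "anthropic" &&
    modelName.includes("claude") &&
    (modelName.includes("3-5") || modelName.includes("3.5"))
  );
}

console.log(modelSupportsTools("claude-3-5-sonnet-latest", "anthropic")); // true
console.log(modelSupportsTools("claude-3-5-sonnet-latest", "bedrock")); // false: provider gate
console.log(modelSupportsTools("gpt-4o", "openai")); // false: model gate
```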
20 changes: 13 additions & 7 deletions core/llm/constructMessages.ts

```diff
@@ -1,9 +1,9 @@
-import { ChatHistoryItem, ChatMessage, MessagePart } from "../index.js";
-import { normalizeToMessageParts } from "../util/messageContent.js";
+import { ChatHistoryItem, ChatMessage, MessagePart } from "../";
+import { normalizeToMessageParts } from "../util/messageContent";
 
-import { modelSupportsTools } from "./autodetect.js";
+import { modelSupportsTools } from "./autodetect";
 
-const CUSTOM_SYS_MSG_MODEL_FAMILIES = ["sonnet", "gpt-4o", "mistral-large"];
+const CUSTOM_SYS_MSG_MODEL_FAMILIES = ["sonnet"];
 
 const SYSTEM_MESSAGE = `When generating new code:
@@ -66,11 +66,15 @@ Always follow these guidelines when generating code responses.`;
 const TOOL_USE_RULES = `When using tools, follow the following guidelines:
 - Avoid calling tools unless they are absolutely necessary. For example, if you are asked a simple programming question you do not need web search. As another example, if the user asks you to explain something about code, do not create a new file.`;
 
-function constructSystemPrompt(model: string): string | null {
+function constructSystemPrompt(
+  model: string,
+  provider: string,
+  useTools: boolean,
+): string | null {
   if (CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))) {
     return SYSTEM_MESSAGE + "\n\n" + TOOL_USE_RULES;
   }
-  if (modelSupportsTools(model)) {
+  if (useTools && modelSupportsTools(model, provider)) {
     return TOOL_USE_RULES;
   }
 
@@ -83,10 +87,12 @@ const CANCELED_TOOL_CALL_MESSAGE =
 export function constructMessages(
   history: ChatHistoryItem[],
   model: string,
+  provider: string,
+  useTools: boolean,
 ): ChatMessage[] {
   const msgs: ChatMessage[] = [];
 
-  const systemMessage = constructSystemPrompt(model);
+  const systemMessage = constructSystemPrompt(model, provider, useTools);
   if (systemMessage) {
     msgs.push({
       role: "system" as const,
```
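Two behavioral changes are folded into this file: the custom system message is now reserved for "sonnet"-family models (gpt-4o and mistral-large were dropped from the list), and the tool-use rules are injected only when the user has tools enabled and the model/provider pair passes the gate above. A condensed, self-contained sketch of the resulting branching (the two message constants are shortened stand-ins for the real ones):

```ts
// Condensed sketch of the new system-prompt selection.
const CUSTOM_SYS_MSG_MODEL_FAMILIES = ["sonnet"];
const SYSTEM_MESSAGE = "When generating new code: ...";
const TOOL_USE_RULES = "When using tools, follow the following guidelines: ...";

function modelSupportsTools(modelName: string, provider: string): boolean {
  return (
    provider === "anthropic" &&
    modelName.includes("claude") &&
    (modelName.includes("3-5") || modelName.includes("3.5"))
  );
}

function constructSystemPrompt(
  model: string,
  provider: string,
  useTools: boolean,
): string | null {
  // Sonnet-family models get the full custom system message plus tool rules.
  if (CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))) {
    return SYSTEM_MESSAGE + "\n\n" + TOOL_USE_RULES;
  }
  // Otherwise, tool rules are added only when tools are on and supported.
  if (useTools && modelSupportsTools(model, provider)) {
    return TOOL_USE_RULES;
  }
  return null; // no system message is pushed at all
}

console.log(constructSystemPrompt("claude-3-5-sonnet-latest", "anthropic", true)); // full prompt
console.log(constructSystemPrompt("gpt-4o", "openai", true)); // null (dropped from the custom list)
```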
18 changes: 11 additions & 7 deletions core/llm/index.ts

```diff
@@ -138,7 +138,7 @@ export abstract class BaseLLM implements ILLM {
 
   private _llmOptions: LLMOptions;
 
-  private openaiAdapter?: BaseLlmApi;
+  protected openaiAdapter?: BaseLlmApi;
 
   constructor(_options: LLMOptions) {
     this._llmOptions = _options;
@@ -212,12 +212,7 @@ export abstract class BaseLLM implements ILLM {
     this.projectId = options.projectId;
     this.profile = options.profile;
 
-    this.openaiAdapter = constructLlmApi({
-      provider: this.providerName as any,
-      apiKey: this.apiKey ?? "",
-      apiBase: this.apiBase,
-      requestOptions: this.requestOptions,
-    });
+    this.openaiAdapter = this.createOpenAiAdapter();
 
     this.maxEmbeddingBatchSize =
       options.maxEmbeddingBatchSize ?? DEFAULT_MAX_BATCH_SIZE;
@@ -226,6 +221,15 @@ export abstract class BaseLLM implements ILLM {
     this.embeddingId = `${this.constructor.name}::${this.model}::${this.maxEmbeddingChunkSize}`;
   }
 
+  protected createOpenAiAdapter() {
+    return constructLlmApi({
+      provider: this.providerName as any,
+      apiKey: this.apiKey ?? "",
+      apiBase: this.apiBase,
+      requestOptions: this.requestOptions,
+    });
+  }
+
   listModels(): Promise<string[]> {
     return Promise.resolve([]);
   }
```
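Promoting openaiAdapter from private to protected, and extracting adapter construction into a protected createOpenAiAdapter, turns the adapter into an extension point: a subclass can mutate its connection settings after construction and rebuild the adapter. The Mistral change below uses exactly this hook once its API-key probe resolves. A hypothetical illustration of the pattern (RegionalProvider, its method, and its URL are invented; the abstract members BaseLLM requires are elided, so this fragment is not compilable on its own):

```ts
// Hypothetical fragment: with openaiAdapter/createOpenAiAdapter protected,
// a provider can change endpoints after construction and rebuild the
// adapter so later requests use the new base URL.
class RegionalProvider extends BaseLLM {
  static providerName = "regional-provider"; // invented name

  pointAtRegion(region: string) {
    this.apiBase = `https://${region}.example.com/v1/`; // invented URL
    this.openaiAdapter = this.createOpenAiAdapter();
  }
}
```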
31 changes: 31 additions & 0 deletions core/llm/llms/Mistral.ts

```diff
@@ -3,6 +3,8 @@ import { codestralEditPrompt } from "../templates/edit/codestral.js";
 
 import OpenAI from "./OpenAI.js";
 
+type MistralApiKeyType = "mistral" | "codestral";
+
 class Mistral extends OpenAI {
   static providerName = "mistral";
   static defaultOptions: Partial<LLMOptions> = {
@@ -14,6 +16,17 @@ class Mistral extends OpenAI {
     maxEmbeddingBatchSize: 128,
   };
 
+  private async autodetectApiKeyType(): Promise<MistralApiKeyType> {
+    const mistralResp = await fetch("https://api.mistral.ai/v1/models", {
+      method: "GET",
+      headers: this._getHeaders(),
+    });
+    if (mistralResp.status === 401) {
+      return "codestral";
+    }
+    return "mistral";
+  }
+
   constructor(options: LLMOptions) {
     super(options);
     if (
@@ -26,6 +39,24 @@ class Mistral extends OpenAI {
     if (!this.apiBase?.endsWith("/")) {
       this.apiBase += "/";
     }
+
+    // Unless the user explicitly specifies, we will autodetect the API key type and adjust the API base accordingly
+    if (!options.apiBase) {
+      this.autodetectApiKeyType()
+        .then((keyType) => {
+          switch (keyType) {
+            case "codestral":
+              this.apiBase = "https://codestral.mistral.ai/v1/";
+              break;
+            case "mistral":
+              this.apiBase = "https://api.mistral.ai/v1/";
+              break;
+          }
+
+          this.openaiAdapter = this.createOpenAiAdapter();
+        })
+        .catch((err: any) => {});
+    }
   }
 
   private static modelConversion: { [key: string]: string } = {
```
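The detection heuristic rests on one observation: the general api.mistral.ai endpoint rejects a Codestral-scoped key with a 401, so a failed probe implies the key belongs on codestral.mistral.ai. A standalone sketch of the probe (header construction simplified to a bearer token; the real method reuses the class's _getHeaders()):

```ts
// Simplified probe: classify a key by whether api.mistral.ai accepts it.
async function detectMistralKeyType(
  apiKey: string,
): Promise<"mistral" | "codestral"> {
  const resp = await fetch("https://api.mistral.ai/v1/models", {
    method: "GET",
    headers: { Authorization: `Bearer ${apiKey}` },
  });
  // 401 here is taken to mean "valid key, wrong endpoint" => Codestral.
  return resp.status === 401 ? "codestral" : "mistral";
}
```

Note that the constructor runs the probe fire-and-forget: apiBase is swapped and openaiAdapter rebuilt only when the promise resolves, and errors are swallowed so a failed probe leaves the default base in place.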
7 changes: 5 additions & 2 deletions core/llm/openaiTypeConverters.ts

```diff
@@ -67,7 +67,7 @@ export function toChatBody(
     },
   }));
 
-  return {
+  const params: ChatCompletionCreateParams = {
     messages: messages.map(toChatMessage),
     model: options.model,
     max_tokens: options.maxTokens,
@@ -77,9 +77,12 @@
     presence_penalty: options.presencePenalty,
     stream: options.stream ?? true,
     stop: options.stop,
-    tools,
     prediction: options.prediction,
   };
+  if (tools?.length) {
+    params.tools = tools;
+  }
+  return params;
 }
 
 export function toCompleteBody(
```
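Besides typing the body as ChatCompletionCreateParams, the change stops sending a tools key at all unless at least one tool is defined; presumably this is what fixes the changelog's "tool use bug for models that don't support tools," since an empty array still serializes into the request body where backends without tool support can reject it. A minimal illustration of the serialization difference:

```ts
// An empty "tools" array still serializes into the request body; a key
// that is never set simply disappears from the JSON sent to the provider.
const withEmptyTools = { model: "m", tools: [] as object[] };
const withoutToolsKey: { model: string; tools?: object[] } = { model: "m" };

console.log(JSON.stringify(withEmptyTools)); // {"model":"m","tools":[]}
console.log(JSON.stringify(withoutToolsKey)); // {"model":"m"}
```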
2 changes: 2 additions & 0 deletions extensions/vscode/CHANGELOG.md

```diff
@@ -11,6 +11,8 @@ Pre-release Changes
 ### Fixed
 * Display more mid-line completions
 * Restored syntax highlighting
+* Fix tool use bug for models that don't support tools
+* Autodetect mistral API key type
 
 ## 0.8.62 - 2024-12-10
 ### Added
```
2 changes: 1 addition & 1 deletion extensions/vscode/package.json

```diff
@@ -2,7 +2,7 @@
   "name": "continue",
   "icon": "media/icon.png",
   "author": "Continue Dev, Inc",
-  "version": "0.9.245",
+  "version": "0.8.64",
   "repository": {
     "type": "git",
     "url": "https://github.com/continuedev/continue"
```
18 changes: 9 additions & 9 deletions gui/src/components/mainInput/InputToolbar.tsx

```diff
@@ -9,7 +9,15 @@ import {
   vscForeground,
   vscInputBackground,
 } from "..";
+import { useAppDispatch, useAppSelector } from "../../redux/hooks";
 import { selectUseActiveFile } from "../../redux/selectors";
+import { selectDefaultModel } from "../../redux/slices/configSlice";
+import {
+  selectHasCodeToEdit,
+  selectIsInEditMode,
+} from "../../redux/slices/sessionSlice";
+import { exitEditMode } from "../../redux/thunks";
+import { loadLastSession } from "../../redux/thunks/session";
 import {
   getAltKeyLabel,
   getFontSize,
@@ -20,14 +28,6 @@ import { ToolTip } from "../gui/Tooltip";
 import ModelSelect from "../modelSelection/ModelSelect";
 import HoverItem from "./InputToolbar/HoverItem";
 import ToggleToolsButton from "./InputToolbar/ToggleToolsButton";
-import { useAppDispatch, useAppSelector } from "../../redux/hooks";
-import { selectDefaultModel } from "../../redux/slices/configSlice";
-import {
-  selectHasCodeToEdit,
-  selectIsInEditMode,
-} from "../../redux/slices/sessionSlice";
-import { exitEditMode } from "../../redux/thunks";
-import { loadLastSession } from "../../redux/thunks/session";
 
 const StyledDiv = styled.div<{ isHidden?: boolean }>`
   padding-top: 4px;
@@ -93,7 +93,7 @@ function InputToolbar(props: InputToolbarProps) {
   const isEnterDisabled = props.disabled || isEditModeAndNoCodeToEdit;
   const shouldRenderToolsButton =
     defaultModel &&
-    modelSupportsTools(defaultModel.model) &&
+    modelSupportsTools(defaultModel.model, defaultModel.provider) &&
     !props.toolbarOptions?.hideTools;
 
   const supportsImages =
```
14 changes: 10 additions & 4 deletions gui/src/redux/thunks/streamResponse.ts

```diff
@@ -4,18 +4,18 @@ import { InputModifiers, MessageContent, SlashCommandDescription } from "core";
 import { constructMessages } from "core/llm/constructMessages";
 import { renderChatMessage } from "core/util/messageContent";
 import posthog from "posthog-js";
+import { v4 as uuidv4 } from "uuid";
+import { selectDefaultModel } from "../slices/configSlice";
 import {
   submitEditorAndInitAtIndex,
   updateHistoryItemAtIndex,
 } from "../slices/sessionSlice";
 import { ThunkApiType } from "../store";
 import { gatherContext } from "./gatherContext";
-import { streamThunkWrapper } from "./streamThunkWrapper";
 import { resetStateForNewMessage } from "./resetStateForNewMessage";
 import { streamNormalInput } from "./streamNormalInput";
 import { streamSlashCommand } from "./streamSlashCommand";
-import { selectDefaultModel } from "../slices/configSlice";
-import { v4 as uuidv4 } from "uuid";
+import { streamThunkWrapper } from "./streamThunkWrapper";
 
 const getSlashCommandForInput = (
   input: MessageContent,
@@ -61,6 +61,7 @@ export const streamResponseThunk = createAsyncThunk<
     await dispatch(
       streamThunkWrapper(async () => {
         const state = getState();
+        const useTools = state.ui.useTools;
         const defaultModel = selectDefaultModel(state);
         const slashCommands = state.config.config.slashCommands || [];
         const inputIndex = index ?? state.session.history.length;
@@ -94,7 +95,12 @@
 
         // Construct messages from updated history
         const updatedHistory = getState().session.history;
-        const messages = constructMessages(updatedHistory, defaultModel.model);
+        const messages = constructMessages(
+          updatedHistory,
+          defaultModel.model,
+          defaultModel.provider,
+          useTools,
+        );
 
         posthog.capture("step run", {
           step_name: "User Input",
```
5 changes: 4 additions & 1 deletion gui/src/redux/thunks/streamResponseAfterToolCall.ts

```diff
@@ -9,9 +9,9 @@ import {
   streamUpdate,
 } from "../slices/sessionSlice";
 import { ThunkApiType } from "../store";
-import { streamThunkWrapper } from "./streamThunkWrapper";
 import { resetStateForNewMessage } from "./resetStateForNewMessage";
 import { streamNormalInput } from "./streamNormalInput";
+import { streamThunkWrapper } from "./streamThunkWrapper";
 
 export const streamResponseAfterToolCall = createAsyncThunk<
   void,
@@ -26,6 +26,7 @@ export const streamResponseAfterToolCall = createAsyncThunk<
     await dispatch(
       streamThunkWrapper(async () => {
         const state = getState();
+        const useTools = state.ui.useTools;
         const initialHistory = state.session.history;
         const defaultModel = selectDefaultModel(state);
 
@@ -59,6 +60,8 @@
         const messages = constructMessages(
           [...updatedHistory],
           defaultModel.model,
+          defaultModel.provider,
+          useTools,
         );
         unwrapResult(await dispatch(streamNormalInput(messages)));
       }),
```
