diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py
index 5ff1e24c..b1cb0b24 100644
--- a/gptme/llm/llm_openai.py
+++ b/gptme/llm/llm_openai.py
@@ -90,7 +90,7 @@ def get_client(provider: Provider) -> "OpenAI":
     return clients[provider]
 
 
-def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
+def _prep_o1(msgs: Iterable[Message]) -> Generator[Message, None, None]:
     # prepare messages for OpenAI O1, which doesn't support the system role
     # and requires the first message to be from the user
     for msg in msgs:
@@ -101,6 +101,44 @@ def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
         yield msg
 
 
+def _merge_consecutive(msgs: Iterable[Message]) -> Generator[Message, None, None]:
+    # if consecutive messages from same role, merge them
+    last_message = None
+    for msg in msgs:
+        if last_message is None:
+            last_message = msg
+            continue
+
+        if last_message.role == msg.role:
+            last_message = last_message.replace(
+                content=f"{last_message.content}\n{msg.content}"
+            )
+            continue
+        else:
+            yield last_message
+            last_message = msg
+
+    if last_message:
+        yield last_message
+
+
+assert (
+    len(
+        list(
+            _merge_consecutive(
+                [Message(role="user", content="a"), Message(role="user", content="b")]
+            )
+        )
+    )
+    == 1
+)
+
+
+def _prep_deepseek_reasoner(msgs: list[Message]) -> Generator[Message, None, None]:
+    yield msgs[0]
+    yield from _merge_consecutive(_prep_o1(msgs[1:]))
+
+
 def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
     # This will generate code and such, so we need appropriate temperature and top_p params
     # top_p controls diversity, temperature controls randomness
@@ -114,14 +152,16 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
     from openai import NOT_GIVEN  # fmt: skip
 
     is_o1 = base_model.startswith("o1")
+    is_deepseek_reasoner = base_model == "deepseek-reasoner"
+    is_reasoner = is_o1 or is_deepseek_reasoner
 
     messages_dicts, tools_dict = _prepare_messages_for_api(messages, model, tools)
 
     response = client.chat.completions.create(
         model=base_model,
         messages=messages_dicts,  # type: ignore
-        temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
-        top_p=TOP_P if not is_o1 else NOT_GIVEN,
+        temperature=TEMPERATURE if not is_reasoner else NOT_GIVEN,
+        top_p=TOP_P if not is_reasoner else NOT_GIVEN,
         tools=tools_dict if tools_dict else NOT_GIVEN,
         extra_headers=(openrouter_headers if provider == "openrouter" else {}),
     )
@@ -153,14 +193,16 @@ def stream(
     from openai import NOT_GIVEN  # fmt: skip
 
     is_o1 = base_model.startswith("o1")
+    is_deepseek_reasoner = base_model == "deepseek-reasoner"
+    is_reasoner = is_o1 or is_deepseek_reasoner
 
     messages_dicts, tools_dict = _prepare_messages_for_api(messages, model, tools)
 
     for chunk_raw in client.chat.completions.create(
         model=base_model,
         messages=messages_dicts,  # type: ignore
-        temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
-        top_p=TOP_P if not is_o1 else NOT_GIVEN,
+        temperature=TEMPERATURE if not is_reasoner else NOT_GIVEN,
+        top_p=TOP_P if not is_reasoner else NOT_GIVEN,
         stream=True,
         tools=tools_dict if tools_dict else NOT_GIVEN,
         # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
@@ -449,15 +491,16 @@ def _spec2tool(spec: ToolSpec, model: ModelMeta) -> "ChatCompletionToolParam":
 def _prepare_messages_for_api(
     messages: list[Message], model: str, tools: list[ToolSpec] | None
 ) -> tuple[Iterable[dict], Iterable["ChatCompletionToolParam"] | None]:
-    from . import _get_base_model, get_provider_from_model  # fmt: skip
+    from . import _get_base_model  # fmt: skip
     from .models import get_model  # fmt: skip
 
-    get_provider_from_model(model)
     model_meta = get_model(model)
 
     is_o1 = _get_base_model(model).startswith("o1")
     if is_o1:
        messages = list(_prep_o1(messages))
+    if model_meta.model == "deepseek-reasoner":
+        messages = list(_prep_deepseek_reasoner(messages))
 
     messages_dicts: Iterable[dict] = (
         _process_file(msg, model_meta) for msg in msgs2dicts(messages)
diff --git a/gptme/llm/models.py b/gptme/llm/models.py
index 490b39f9..bf1d6671 100644
--- a/gptme/llm/models.py
+++ b/gptme/llm/models.py
@@ -126,11 +126,17 @@ class _ModelDictMeta(TypedDict):
     "deepseek": {
         "deepseek-chat": {
             "context": 64_000,
-            "max_output": 4000,
+            "max_output": 8192,
+            # 10x better price for cache hits
             "price_input": 0.14,
-            "price_output": 0.28,
-        }
+            "price_output": 1.1,
+        },
+        "deepseek-reasoner": {
+            "context": 64_000,
+            "max_output": 8192,
+            "price_input": 0.55,
+            "price_output": 2.19,
+        },
     },
     # https://groq.com/pricing/
     "groq": {
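
Note for reviewers: below is a minimal, self-contained sketch (not part of the diff) illustrating what the new _merge_consecutive helper does and why _prep_deepseek_reasoner uses it; the deepseek-reasoner endpoint expects roles to alternate, so runs of same-role messages are collapsed into one. SimpleMessage is a hypothetical stand-in for gptme's Message class, which has more fields.

from dataclasses import dataclass, replace as dc_replace
from collections.abc import Generator, Iterable


@dataclass(frozen=True)
class SimpleMessage:
    # Hypothetical stand-in for gptme.message.Message, enough to show the merge behavior.
    role: str
    content: str

    def replace(self, **kwargs) -> "SimpleMessage":
        return dc_replace(self, **kwargs)


def merge_consecutive(msgs: Iterable[SimpleMessage]) -> Generator[SimpleMessage, None, None]:
    # Same idea as _merge_consecutive in the diff: collapse runs of messages
    # from the same role into a single message, joining contents with newlines.
    last: SimpleMessage | None = None
    for msg in msgs:
        if last is None:
            last = msg
        elif last.role == msg.role:
            last = last.replace(content=f"{last.content}\n{msg.content}")
        else:
            yield last
            last = msg
    if last is not None:
        yield last


if __name__ == "__main__":
    msgs = [
        SimpleMessage("system", "You are a helpful assistant."),
        SimpleMessage("user", "a"),
        SimpleMessage("user", "b"),
        SimpleMessage("assistant", "ok"),
    ]
    merged = list(merge_consecutive(msgs))
    # Roles now alternate, which is what _prep_deepseek_reasoner relies on.
    assert [m.role for m in merged] == ["system", "user", "assistant"]
    assert merged[1].content == "a\nb"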