Skip to content

Commit

Permalink
feat: added support for deepseek-reasoner
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Jan 20, 2025
1 parent 0cd60a0 commit 43ed919
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
57 changes: 50 additions & 7 deletions gptme/llm/llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_client(provider: Provider) -> "OpenAI":
return clients[provider]


def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
def _prep_o1(msgs: Iterable[Message]) -> Generator[Message, None, None]:
# prepare messages for OpenAI O1, which doesn't support the system role
# and requires the first message to be from the user
for msg in msgs:
Expand All @@ -101,6 +101,44 @@ def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
yield msg


def _merge_consecutive(msgs: Iterable[Message]) -> Generator[Message, None, None]:
# if consecutive messages from same role, merge them
last_message = None
for msg in msgs:
if last_message is None:
last_message = msg
continue

if last_message.role == msg.role:
last_message = last_message.replace(
content=f"{last_message.content}\n{msg.content}"
)
continue
else:
yield last_message
last_message = msg

if last_message:
yield last_message


assert (
len(
list(
_merge_consecutive(
[Message(role="user", content="a"), Message(role="user", content="b")]
)
)
)
== 1
)


def _prep_deepseek_reasoner(msgs: list[Message]) -> Generator[Message, None, None]:
yield msgs[0]
yield from _merge_consecutive(_prep_o1(msgs[1:]))


def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
# This will generate code and such, so we need appropriate temperature and top_p params
# top_p controls diversity, temperature controls randomness
Expand All @@ -114,14 +152,16 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> s
from openai import NOT_GIVEN # fmt: skip

is_o1 = base_model.startswith("o1")
is_deepseek_reasoner = base_model == "deepseek-reasoner"
is_reasoner = is_o1 or is_deepseek_reasoner

messages_dicts, tools_dict = _prepare_messages_for_api(messages, model, tools)

response = client.chat.completions.create(
model=base_model,
messages=messages_dicts, # type: ignore
temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
top_p=TOP_P if not is_o1 else NOT_GIVEN,
temperature=TEMPERATURE if not is_reasoner else NOT_GIVEN,
top_p=TOP_P if not is_reasoner else NOT_GIVEN,
tools=tools_dict if tools_dict else NOT_GIVEN,
extra_headers=(openrouter_headers if provider == "openrouter" else {}),
)
Expand Down Expand Up @@ -153,14 +193,16 @@ def stream(
from openai import NOT_GIVEN # fmt: skip

is_o1 = base_model.startswith("o1")
is_deepseek_reasoner = base_model == "deepseek-reasoner"
is_reasoner = is_o1 or is_deepseek_reasoner

messages_dicts, tools_dict = _prepare_messages_for_api(messages, model, tools)

for chunk_raw in client.chat.completions.create(
model=base_model,
messages=messages_dicts, # type: ignore
temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
top_p=TOP_P if not is_o1 else NOT_GIVEN,
temperature=TEMPERATURE if not is_reasoner else NOT_GIVEN,
top_p=TOP_P if not is_reasoner else NOT_GIVEN,
stream=True,
tools=tools_dict if tools_dict else NOT_GIVEN,
# the llama-cpp-python server needs this explicitly set, otherwise unreliable results
Expand Down Expand Up @@ -449,15 +491,16 @@ def _spec2tool(spec: ToolSpec, model: ModelMeta) -> "ChatCompletionToolParam":
def _prepare_messages_for_api(
messages: list[Message], model: str, tools: list[ToolSpec] | None
) -> tuple[Iterable[dict], Iterable["ChatCompletionToolParam"] | None]:
from . import _get_base_model, get_provider_from_model # fmt: skip
from . import _get_base_model # fmt: skip
from .models import get_model # fmt: skip

get_provider_from_model(model)
model_meta = get_model(model)

is_o1 = _get_base_model(model).startswith("o1")
if is_o1:
messages = list(_prep_o1(messages))
if model_meta.model == "deepseek-reasoner":
messages = list(_prep_deepseek_reasoner(messages))

messages_dicts: Iterable[dict] = (
_process_file(msg, model_meta) for msg in msgs2dicts(messages)
Expand Down
12 changes: 9 additions & 3 deletions gptme/llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,17 @@ class _ModelDictMeta(TypedDict):
"deepseek": {
"deepseek-chat": {
"context": 64_000,
"max_output": 4000,
"max_output": 8192,
# 10x better price for cache hits
"price_input": 0.14,
"price_output": 0.28,
}
"price_output": 1.1,
},
"deepseek-reasoner": {
"context": 64_000,
"max_output": 8192,
"price_input": 0.55,
"price_output": 2.19,
},
},
# https://groq.com/pricing/
"groq": {
Expand Down

0 comments on commit 43ed919

Please sign in to comment.