feat: added support for deepseek-reasoner #410

Merged: 1 commit, Jan 20, 2025
57 changes: 50 additions & 7 deletions gptme/llm/llm_openai.py
@@ -90,7 +90,7 @@ def get_client(provider: Provider) -> "OpenAI":
return clients[provider]


- def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
+ def _prep_o1(msgs: Iterable[Message]) -> Generator[Message, None, None]:
# prepare messages for OpenAI O1, which doesn't support the system role
# and requires the first message to be from the user
for msg in msgs:
@@ -101,6 +101,44 @@ def _prep_o1(msgs: list[Message]) -> Generator[Message, None, None]:
yield msg


def _merge_consecutive(msgs: Iterable[Message]) -> Generator[Message, None, None]:
# if consecutive messages from same role, merge them
last_message = None
for msg in msgs:
if last_message is None:
last_message = msg
continue

if last_message.role == msg.role:
last_message = last_message.replace(
content=f"{last_message.content}\n{msg.content}"
)
continue
else:
yield last_message
last_message = msg

if last_message:
yield last_message


assert (
len(
list(
_merge_consecutive(
[Message(role="user", content="a"), Message(role="user", content="b")]
)
)
)
== 1
)
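
The inline assert above acts as a quick smoke test for the merge. As a further illustration (hypothetical checks, not part of this PR), alternating roles are left untouched while runs of same-role messages collapse into a single message joined by newlines:

# Hypothetical checks, illustrative only:
merged = list(
    _merge_consecutive(
        [
            Message(role="user", content="a"),
            Message(role="assistant", content="b"),
            Message(role="user", content="c"),
            Message(role="user", content="d"),
        ]
    )
)
assert [m.role for m in merged] == ["user", "assistant", "user"]
assert merged[-1].content == "c\nd"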


def _prep_deepseek_reasoner(msgs: list[Message]) -> Generator[Message, None, None]:
yield msgs[0]
yield from _merge_consecutive(_prep_o1(msgs[1:]))
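
The deepseek-reasoner endpoint keeps the leading system prompt but, as the merge logic above suggests, is assumed to require strictly alternating user/assistant turns after it, so the remaining messages are first run through _prep_o1 and then merged. A hedged sketch of the intended effect, assuming _prep_o1 re-roles later system messages to user as its comment describes (the example conversation is hypothetical):

# Hypothetical illustration, not part of the PR:
conversation = [
    Message(role="system", content="You are gptme."),
    Message(role="system", content="Available tools: shell"),
    Message(role="user", content="List the files here."),
    Message(role="assistant", content="Running `ls`..."),
]
prepped = list(_prep_deepseek_reasoner(conversation))
# The second system message becomes a user message and is merged with the
# following user message, so the roles should come out strictly alternating:
assert [m.role for m in prepped] == ["system", "user", "assistant"]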


def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
# This will generate code and such, so we need appropriate temperature and top_p params
# top_p controls diversity, temperature controls randomness
@@ -114,14 +152,16 @@ def chat(messages: list[Message], model: str, tools: list[ToolSpec] | None) -> str:
from openai import NOT_GIVEN # fmt: skip

is_o1 = base_model.startswith("o1")
is_deepseek_reasoner = base_model == "deepseek-reasoner"
is_reasoner = is_o1 or is_deepseek_reasoner

messages_dicts, tools_dict = _prepare_messages_for_api(messages, model, tools)

response = client.chat.completions.create(
model=base_model,
messages=messages_dicts, # type: ignore
- temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
- top_p=TOP_P if not is_o1 else NOT_GIVEN,
+ temperature=TEMPERATURE if not is_reasoner else NOT_GIVEN,
+ top_p=TOP_P if not is_reasoner else NOT_GIVEN,
tools=tools_dict if tools_dict else NOT_GIVEN,
extra_headers=(openrouter_headers if provider == "openrouter" else {}),
)
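
Reasoner models (o1 and deepseek-reasoner) either reject or ignore explicit sampling parameters, so the request passes NOT_GIVEN for temperature and top_p whenever is_reasoner is set; the same gating is repeated in stream() below. A minimal standalone sketch of the idea (the helper name and signature are illustrative, not gptme's API):

# Hypothetical helper, shown only to isolate the gating logic used above.
from typing import Any

def sampling_kwargs(base_model: str, temperature: float, top_p: float) -> dict[str, Any]:
    """Return sampling kwargs, or nothing at all for reasoner models."""
    is_reasoner = base_model.startswith("o1") or base_model == "deepseek-reasoner"
    return {} if is_reasoner else {"temperature": temperature, "top_p": top_p}

# sampling_kwargs("deepseek-reasoner", 0.0, 0.1)  -> {}
# sampling_kwargs("gpt-4o", 0.0, 0.1)             -> {"temperature": 0.0, "top_p": 0.1}
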
@@ -153,14 +193,16 @@ def stream(
from openai import NOT_GIVEN # fmt: skip

is_o1 = base_model.startswith("o1")
is_deepseek_reasoner = base_model == "deepseek-reasoner"
is_reasoner = is_o1 or is_deepseek_reasoner

messages_dicts, tools_dict = _prepare_messages_for_api(messages, model, tools)

for chunk_raw in client.chat.completions.create(
model=base_model,
messages=messages_dicts, # type: ignore
- temperature=TEMPERATURE if not is_o1 else NOT_GIVEN,
- top_p=TOP_P if not is_o1 else NOT_GIVEN,
+ temperature=TEMPERATURE if not is_reasoner else NOT_GIVEN,
+ top_p=TOP_P if not is_reasoner else NOT_GIVEN,
stream=True,
tools=tools_dict if tools_dict else NOT_GIVEN,
# the llama-cpp-python server needs this explicitly set, otherwise unreliable results
@@ -449,15 +491,16 @@ def _spec2tool(spec: ToolSpec, model: ModelMeta) -> "ChatCompletionToolParam":
def _prepare_messages_for_api(
messages: list[Message], model: str, tools: list[ToolSpec] | None
) -> tuple[Iterable[dict], Iterable["ChatCompletionToolParam"] | None]:
- from . import _get_base_model, get_provider_from_model  # fmt: skip
+ from . import _get_base_model  # fmt: skip
from .models import get_model # fmt: skip

- get_provider_from_model(model)
model_meta = get_model(model)

is_o1 = _get_base_model(model).startswith("o1")
if is_o1:
messages = list(_prep_o1(messages))
if model_meta.model == "deepseek-reasoner":
messages = list(_prep_deepseek_reasoner(messages))

messages_dicts: Iterable[dict] = (
_process_file(msg, model_meta) for msg in msgs2dicts(messages)
12 changes: 9 additions & 3 deletions gptme/llm/models.py
@@ -126,11 +126,17 @@ class _ModelDictMeta(TypedDict):
"deepseek": {
"deepseek-chat": {
"context": 64_000,
"max_output": 4000,
"max_output": 8192,
# 10x better price for cache hits
"price_input": 0.14,
"price_output": 0.28,
}
"price_output": 1.1,
},
"deepseek-reasoner": {
"context": 64_000,
"max_output": 8192,
"price_input": 0.55,
"price_output": 2.19,
},
},
# https://groq.com/pricing/
"groq": {
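
The new deepseek-reasoner entry above carries the same metadata shape as the other models. Assuming the price_input/price_output fields are USD per million tokens (consistent with the surrounding entries), a hypothetical cost estimate from this metadata could look like:

# Hypothetical helper, not part of gptme:
def estimate_cost_usd(meta: dict, input_tokens: int, output_tokens: int) -> float:
    return (
        input_tokens / 1_000_000 * meta["price_input"]
        + output_tokens / 1_000_000 * meta["price_output"]
    )

reasoner_meta = {"price_input": 0.55, "price_output": 2.19}
# 10k prompt tokens + 2k completion tokens: 0.0055 + 0.00438, roughly $0.0099
print(round(estimate_cost_usd(reasoner_meta, 10_000, 2_000), 4))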