Skip to content

Commit 85beb4e

Browse files
committed
feat(backend): if no temperature is sent with chat request, use agent temperature
1 parent 3195de9 commit 85beb4e

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

src/backend/routers/chat.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Generator
22

3-
from fastapi import APIRouter, Depends, Request
3+
from fastapi import APIRouter, Depends
44
from sse_starlette.sse import EventSourceResponse
55

66
from backend.chat.custom.custom import CustomChat
@@ -31,7 +31,6 @@
3131
async def chat_stream(
3232
session: DBSessionDep,
3333
chat_request: CohereChatRequest,
34-
request: Request,
3534
ctx: Context = Depends(get_context),
3635
) -> Generator[ChatResponseEvent, Any, None]:
3736
"""
@@ -58,7 +57,7 @@ async def chat_stream(
5857
managed_tools,
5958
next_message_position,
6059
ctx,
61-
) = process_chat(session, chat_request, request, ctx)
60+
) = process_chat(session, chat_request, ctx)
6261

6362
return EventSourceResponse(
6463
generate_chat_stream(
@@ -86,7 +85,6 @@ async def chat_stream(
8685
async def regenerate_chat_stream(
8786
session: DBSessionDep,
8887
chat_request: CohereChatRequest,
89-
request: Request,
9088
ctx: Context = Depends(get_context),
9189
) -> EventSourceResponse:
9290
"""
@@ -127,7 +125,7 @@ async def regenerate_chat_stream(
127125
previous_response_message_ids,
128126
managed_tools,
129127
ctx,
130-
) = process_message_regeneration(session, chat_request, request, ctx)
128+
) = process_message_regeneration(session, chat_request, ctx)
131129

132130
return EventSourceResponse(
133131
generate_chat_stream(
@@ -155,7 +153,6 @@ async def regenerate_chat_stream(
155153
async def chat(
156154
session: DBSessionDep,
157155
chat_request: CohereChatRequest,
158-
request: Request,
159156
ctx: Context = Depends(get_context),
160157
) -> NonStreamedChatResponse:
161158
"""
@@ -197,7 +194,7 @@ async def chat(
197194
managed_tools,
198195
next_message_position,
199196
ctx,
200-
) = process_chat(session, chat_request, request, ctx)
197+
) = process_chat(session, chat_request, ctx)
201198

202199
response = await generate_chat_response(
203200
session,

src/backend/services/chat.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import nltk
66
from cohere.types import StreamedChatResponse
7-
from fastapi import HTTPException, Request
7+
from fastapi import HTTPException
88
from fastapi.encoders import jsonable_encoder
99

1010
from backend.chat.collate import to_dict
@@ -74,19 +74,17 @@ def generate_tools_preamble(chat_request: CohereChatRequest) -> str:
7474

7575
def process_chat(
7676
session: DBSessionDep,
77-
chat_request: BaseChatRequest,
78-
request: Request,
77+
chat_request: CohereChatRequest,
7978
ctx: Context,
8079
) -> tuple[
81-
DBSessionDep, BaseChatRequest, Union[list[str], None], Message, str, str, dict
80+
DBSessionDep, CohereChatRequest, Union[list[str], None], Message, str, str, Context
8281
]:
8382
"""
8483
Process a chat request.
8584
8685
Args:
87-
chat_request (BaseChatRequest): Chat request data.
86+
chat_request (CohereChatRequest): Chat request data.
8887
session (DBSessionDep): Database session.
89-
request (Request): Request object.
9088
ctx (Context): Context object.
9189
9290
Returns:
@@ -124,6 +122,10 @@ def process_chat(
124122
chat_request.model = agent.model
125123
chat_request.preamble = agent.preamble
126124

125+
# If temperature is not defined in the chat request, use the temperature from the agent
126+
if not chat_request.temperature:
127+
chat_request.temperature = agent.temperature
128+
127129
should_store = chat_request.chat_history is None and not is_custom_tool_call(
128130
chat_request
129131
)
@@ -193,7 +195,6 @@ def process_chat(
193195
def process_message_regeneration(
194196
session: DBSessionDep,
195197
chat_request: CohereChatRequest,
196-
request: Request,
197198
ctx: Context,
198199
) -> tuple[Any, CohereChatRequest, Message, list[str], bool, Context]:
199200
"""
@@ -202,7 +203,6 @@ def process_message_regeneration(
202203
Args:
203204
session (DBSessionDep): Database session.
204205
chat_request (CohereChatRequest): Chat request data.
205-
request (Request): Request object.
206206
ctx (Context): Context object.
207207
208208
Returns:
@@ -224,6 +224,10 @@ def process_message_regeneration(
224224
# Set the agent settings in the chat request
225225
chat_request.preamble = agent.preamble
226226

227+
# If temperature is not defined in the chat request, use the temperature from the agent
228+
if not chat_request.temperature:
229+
chat_request.temperature = agent.temperature
230+
227231
conversation_id = chat_request.conversation_id
228232
ctx.with_conversation_id(conversation_id)
229233

src/backend/tests/unit/factories/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class Meta:
1919
description = factory.Faker("sentence")
2020
preamble = factory.Faker("sentence")
2121
version = factory.Faker("random_int")
22-
temperature = factory.Faker("pyfloat")
22+
temperature = factory.Faker("pyfloat", min_value=0.0, max_value=1.0)
2323
created_at = factory.Faker("date_time")
2424
updated_at = factory.Faker("date_time")
2525
tools = factory.List(

src/backend/tests/unit/routers/test_chat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ def test_streaming_fail_chat_missing_message(
311311
"loc": ["body", "message"],
312312
"msg": "Field required",
313313
"input": {},
314-
"url": "https://errors.pydantic.dev/2.10/v/missing",
315314
}
316315
]
317316
}

0 commit comments

Comments
 (0)