Skip to content

Commit

Permalink
feat: add reasoning effort to agent config model + support in core (#…
Browse files Browse the repository at this point in the history
…9603)

Co-authored-by: Henry Fontanier <henry@dust.tt>
  • Loading branch information
fontanierh and Henry Fontanier authored Dec 23, 2024
1 parent b81d98f commit a6aa46e
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 11 deletions.
3 changes: 3 additions & 0 deletions core/src/blocks/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ impl Block for Chat {
if let Some(Value::String(s)) = v.get("response_format") {
extras["response_format"] = json!(s.clone());
}
if let Some(Value::String(s)) = v.get("reasoning_effort") {
extras["reasoning_effort"] = json!(s.clone());
}

match extras.as_object().unwrap().keys().len() {
0 => None,
Expand Down
3 changes: 3 additions & 0 deletions core/src/blocks/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,9 @@ impl Block for LLM {
if let Some(Value::String(s)) = v.get("response_format") {
extras["response_format"] = json!(s.clone());
}
if let Some(Value::String(s)) = v.get("reasoning_effort") {
extras["reasoning_effort"] = json!(s.clone());
}

match extras.as_object().unwrap().keys().len() {
0 => None,
Expand Down
10 changes: 8 additions & 2 deletions core/src/providers/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ impl LLM for AzureOpenAILLM {
}
}

let (openai_user, response_format) = match &extras {
None => (None, None),
let (openai_user, response_format, reasoning_effort) = match &extras {
None => (None, None, None),
Some(v) => (
match v.get("openai_user") {
Some(Value::String(u)) => Some(u.to_string()),
Expand All @@ -458,6 +458,10 @@ impl LLM for AzureOpenAILLM {
Some(Value::String(f)) => Some(f.to_string()),
_ => None,
},
match v.get("reasoning_effort") {
Some(Value::String(r)) => Some(r.to_string()),
_ => None,
},
),
};

Expand Down Expand Up @@ -501,6 +505,7 @@ impl LLM for AzureOpenAILLM {
},
response_format,
openai_user,
reasoning_effort,
event_sender,
)
.await?
Expand Down Expand Up @@ -531,6 +536,7 @@ impl LLM for AzureOpenAILLM {
None => 0.0,
},
response_format,
reasoning_effort,
openai_user,
)
.await?
Expand Down
32 changes: 23 additions & 9 deletions core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,7 @@ pub async fn streamed_chat_completion(
presence_penalty: f32,
frequency_penalty: f32,
response_format: Option<String>,
reasoning_effort: Option<String>,
user: Option<String>,
event_sender: Option<UnboundedSender<Value>>,
) -> Result<(OpenAIChatCompletion, Option<String>)> {
Expand Down Expand Up @@ -1071,14 +1072,17 @@ pub async fn streamed_chat_completion(
if tools.len() > 0 {
body["tools"] = json!(tools);
}
if tool_choice.is_some() {
if let Some(tool_choice) = tool_choice {
body["tool_choice"] = json!(tool_choice);
}
if response_format.is_some() {
if let Some(response_format) = response_format {
body["response_format"] = json!({
"type": response_format.unwrap(),
"type": response_format,
});
}
if let Some(reasoning_effort) = reasoning_effort {
body["reasoning_effort"] = json!(reasoning_effort);
}

let client = builder
.body(body.to_string())
Expand Down Expand Up @@ -1436,6 +1440,7 @@ pub async fn chat_completion(
presence_penalty: f32,
frequency_penalty: f32,
response_format: Option<String>,
reasoning_effort: Option<String>,
user: Option<String>,
) -> Result<(OpenAIChatCompletion, Option<String>)> {
let mut body = json!({
Expand All @@ -1449,7 +1454,7 @@ pub async fn chat_completion(
if user.is_some() {
body["user"] = json!(user);
}
if model_id.is_some() {
if let Some(model_id) = model_id {
body["model"] = json!(model_id);
}
if let Some(mt) = max_tokens {
Expand All @@ -1459,17 +1464,20 @@ pub async fn chat_completion(
body["stop"] = json!(stop);
}

if response_format.is_some() {
if let Some(response_format) = response_format {
body["response_format"] = json!({
"type": response_format.unwrap(),
"type": response_format,
});
}
if tools.len() > 0 {
body["tools"] = json!(tools);
}
if tool_choice.is_some() {
if let Some(tool_choice) = tool_choice {
body["tool_choice"] = json!(tool_choice);
}
if let Some(reasoning_effort) = reasoning_effort {
body["reasoning_effort"] = json!(reasoning_effort);
}

let mut req = reqwest::Client::new()
.post(uri.to_string())
Expand Down Expand Up @@ -2005,8 +2013,8 @@ impl LLM for OpenAILLM {
}
}

let (openai_org_id, openai_user, response_format) = match &extras {
None => (None, None, None),
let (openai_org_id, openai_user, response_format, reasoning_effort) = match &extras {
None => (None, None, None, None),
Some(v) => (
match v.get("openai_organization_id") {
Some(Value::String(o)) => Some(o.to_string()),
Expand All @@ -2020,6 +2028,10 @@ impl LLM for OpenAILLM {
Some(Value::String(f)) => Some(f.to_string()),
_ => None,
},
match v.get("reasoning_effort") {
Some(Value::String(r)) => Some(r.to_string()),
_ => None,
},
),
};

Expand Down Expand Up @@ -2065,6 +2077,7 @@ impl LLM for OpenAILLM {
None => 0.0,
},
response_format,
reasoning_effort,
openai_user,
event_sender.clone(),
)
Expand Down Expand Up @@ -2095,6 +2108,7 @@ impl LLM for OpenAILLM {
None => 0.0,
},
response_format,
reasoning_effort,
openai_user,
)
.await?
Expand Down
2 changes: 2 additions & 0 deletions core/src/providers/togetherai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ impl LLM for TogetherAILLM {
},
None,
None,
None,
event_sender.clone(),
)
.await?
Expand Down Expand Up @@ -225,6 +226,7 @@ impl LLM for TogetherAILLM {
},
None,
None,
None,
)
.await?
};
Expand Down
6 changes: 6 additions & 0 deletions front/lib/models/assistant/agent.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type {
AgentConfigurationScope,
AgentReasoningEffort,
AgentStatus,
GlobalAgentStatus,
ModelIdType,
Expand Down Expand Up @@ -34,6 +35,7 @@ export class AgentConfiguration extends BaseModel<AgentConfiguration> {
declare providerId: ModelProviderIdType;
declare modelId: ModelIdType;
declare temperature: number;
declare reasoningEffort: AgentReasoningEffort | null;

declare pictureUrl: string;

Expand Down Expand Up @@ -107,6 +109,10 @@ AgentConfiguration.init(
allowNull: false,
defaultValue: 0.7,
},
reasoningEffort: {
type: DataTypes.STRING,
allowNull: true,
},
maxStepsPerRun: {
type: DataTypes.INTEGER,
allowNull: true,
Expand Down
2 changes: 2 additions & 0 deletions front/migrations/db/migration_136.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- Migration created on Dec 23, 2024
ALTER TABLE "public"."agent_configurations" ADD COLUMN "reasoningEffort" VARCHAR(255);
5 changes: 5 additions & 0 deletions types/src/front/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,13 @@ export type AgentUsageType = {

export type AgentRecentAuthors = readonly string[];

export type AgentReasoningEffort = "low" | "medium" | "high";

export type AgentModelConfigurationType = {
providerId: ModelProviderIdType;
modelId: ModelIdType;
temperature: number;
reasoningEffort?: AgentReasoningEffort;
};

export type LightAgentConfigurationType = {
Expand Down Expand Up @@ -228,6 +231,8 @@ export type LightAgentConfigurationType = {
//
// Example: [[1,2], [3,4]] means (1 OR 2) AND (3 OR 4)
requestedGroupIds: string[][];

reasoningEffort?: AgentReasoningEffort;
};

export type AgentConfigurationType = LightAgentConfigurationType & {
Expand Down

0 comments on commit a6aa46e

Please sign in to comment.