From 95dbc9d9c2ef75239d87dd2d849bb5439b380203 Mon Sep 17 00:00:00 2001 From: Henry Fontanier Date: Mon, 23 Dec 2024 11:21:25 +0100 Subject: [PATCH] feat: add reasoning effort to agent config model + support in core --- core/src/blocks/chat.rs | 3 +++ core/src/blocks/llm.rs | 3 +++ core/src/providers/azure_openai.rs | 10 +++++++-- core/src/providers/openai.rs | 32 +++++++++++++++++++-------- core/src/providers/togetherai.rs | 2 ++ front/lib/models/assistant/agent.ts | 6 +++++ front/migrations/db/migration_136.sql | 26 ++++++++++++++++++++++ types/src/front/assistant/agent.ts | 5 +++++ 8 files changed, 76 insertions(+), 11 deletions(-) create mode 100644 front/migrations/db/migration_136.sql diff --git a/core/src/blocks/chat.rs b/core/src/blocks/chat.rs index 816dce2d3589a..c1cbd0a2c034e 100644 --- a/core/src/blocks/chat.rs +++ b/core/src/blocks/chat.rs @@ -290,6 +290,9 @@ impl Block for Chat { if let Some(Value::String(s)) = v.get("response_format") { extras["response_format"] = json!(s.clone()); } + if let Some(Value::String(s)) = v.get("reasoning_effort") { + extras["reasoning_effort"] = json!(s.clone()); + } match extras.as_object().unwrap().keys().len() { 0 => None, diff --git a/core/src/blocks/llm.rs b/core/src/blocks/llm.rs index b75fc828f8bec..95de154eece38 100644 --- a/core/src/blocks/llm.rs +++ b/core/src/blocks/llm.rs @@ -406,6 +406,9 @@ impl Block for LLM { if let Some(Value::String(s)) = v.get("response_format") { extras["response_format"] = json!(s.clone()); } + if let Some(Value::String(s)) = v.get("reasoning_effort") { + extras["reasoning_effort"] = json!(s.clone()); + } match extras.as_object().unwrap().keys().len() { 0 => None, diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index 3170328809470..404a9bdaff520 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -447,8 +447,8 @@ impl LLM for AzureOpenAILLM { } } - let (openai_user, response_format) = match &extras { - None => 
(None, None), + let (openai_user, response_format, reasoning_effort) = match &extras { + None => (None, None, None), Some(v) => ( match v.get("openai_user") { Some(Value::String(u)) => Some(u.to_string()), @@ -458,6 +458,10 @@ impl LLM for AzureOpenAILLM { Some(Value::String(f)) => Some(f.to_string()), _ => None, }, + match v.get("reasoning_effort") { + Some(Value::String(r)) => Some(r.to_string()), + _ => None, + }, ), }; @@ -501,6 +505,7 @@ impl LLM for AzureOpenAILLM { }, response_format, openai_user, + reasoning_effort, event_sender, ) .await? @@ -531,6 +536,7 @@ impl LLM for AzureOpenAILLM { None => 0.0, }, response_format, + reasoning_effort, openai_user, ) .await? diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index cf29bffc1e00f..175fb2945aade 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -1014,6 +1014,7 @@ pub async fn streamed_chat_completion( presence_penalty: f32, frequency_penalty: f32, response_format: Option, + reasoning_effort: Option, user: Option, event_sender: Option>, ) -> Result<(OpenAIChatCompletion, Option)> { @@ -1071,14 +1072,17 @@ pub async fn streamed_chat_completion( if tools.len() > 0 { body["tools"] = json!(tools); } - if tool_choice.is_some() { + if let Some(tool_choice) = tool_choice { body["tool_choice"] = json!(tool_choice); } - if response_format.is_some() { + if let Some(response_format) = response_format { body["response_format"] = json!({ - "type": response_format.unwrap(), + "type": response_format, }); } + if let Some(reasoning_effort) = reasoning_effort { + body["reasoning_effort"] = json!(reasoning_effort); + } let client = builder .body(body.to_string()) @@ -1436,6 +1440,7 @@ pub async fn chat_completion( presence_penalty: f32, frequency_penalty: f32, response_format: Option, + reasoning_effort: Option, user: Option, ) -> Result<(OpenAIChatCompletion, Option)> { let mut body = json!({ @@ -1449,7 +1454,7 @@ pub async fn chat_completion( if user.is_some() { 
body["user"] = json!(user); } - if model_id.is_some() { + if let Some(model_id) = model_id { body["model"] = json!(model_id); } if let Some(mt) = max_tokens { @@ -1459,17 +1464,20 @@ pub async fn chat_completion( body["stop"] = json!(stop); } - if response_format.is_some() { + if let Some(response_format) = response_format { body["response_format"] = json!({ - "type": response_format.unwrap(), + "type": response_format, }); } if tools.len() > 0 { body["tools"] = json!(tools); } - if tool_choice.is_some() { + if let Some(tool_choice) = tool_choice { body["tool_choice"] = json!(tool_choice); } + if let Some(reasoning_effort) = reasoning_effort { + body["reasoning_effort"] = json!(reasoning_effort); + } let mut req = reqwest::Client::new() .post(uri.to_string()) @@ -2005,8 +2013,8 @@ impl LLM for OpenAILLM { } } - let (openai_org_id, openai_user, response_format) = match &extras { - None => (None, None, None), + let (openai_org_id, openai_user, response_format, reasoning_effort) = match &extras { + None => (None, None, None, None), Some(v) => ( match v.get("openai_organization_id") { Some(Value::String(o)) => Some(o.to_string()), @@ -2020,6 +2028,10 @@ impl LLM for OpenAILLM { Some(Value::String(f)) => Some(f.to_string()), _ => None, }, + match v.get("reasoning_effort") { + Some(Value::String(r)) => Some(r.to_string()), + _ => None, + }, ), }; @@ -2065,6 +2077,7 @@ impl LLM for OpenAILLM { None => 0.0, }, response_format, + reasoning_effort, openai_user, event_sender.clone(), ) @@ -2095,6 +2108,7 @@ impl LLM for OpenAILLM { None => 0.0, }, response_format, + reasoning_effort, openai_user, ) .await? diff --git a/core/src/providers/togetherai.rs b/core/src/providers/togetherai.rs index 806f5b1bed34b..5bb0336d84d99 100644 --- a/core/src/providers/togetherai.rs +++ b/core/src/providers/togetherai.rs @@ -195,6 +195,7 @@ impl LLM for TogetherAILLM { }, None, None, + None, event_sender.clone(), ) .await? 
@@ -225,6 +226,7 @@ impl LLM for TogetherAILLM { }, None, None, + None, ) .await? }; diff --git a/front/lib/models/assistant/agent.ts b/front/lib/models/assistant/agent.ts index 64b95c7c985fb..207b0e47f933b 100644 --- a/front/lib/models/assistant/agent.ts +++ b/front/lib/models/assistant/agent.ts @@ -1,5 +1,6 @@ import type { AgentConfigurationScope, + AgentReasoningEffort, AgentStatus, GlobalAgentStatus, ModelIdType, @@ -34,6 +35,7 @@ export class AgentConfiguration extends BaseModel { declare providerId: ModelProviderIdType; declare modelId: ModelIdType; declare temperature: number; + declare reasoningEffort: AgentReasoningEffort | null; declare pictureUrl: string; @@ -107,6 +109,10 @@ AgentConfiguration.init( allowNull: false, defaultValue: 0.7, }, + reasoningEffort: { + type: DataTypes.STRING, + allowNull: true, + }, maxStepsPerRun: { type: DataTypes.INTEGER, allowNull: true, diff --git a/front/migrations/db/migration_136.sql b/front/migrations/db/migration_136.sql new file mode 100644 index 0000000000000..8a2eee91220e1 --- /dev/null +++ b/front/migrations/db/migration_136.sql @@ -0,0 +1,26 @@ +-- -- This migration is dependent on a backfill script +-- -- The backfill script is: +-- -- run psql with --set=backfilled=1 argument if you have run the script. + +CREATE OR REPLACE FUNCTION perform_migration(backfilled boolean DEFAULT false) +RETURNS VARCHAR AS $$ +BEGIN + IF NOT backfilled THEN + RAISE NOTICE 'The backfill script: is required before applying this migration. If you already did it, run psql with --set=backfilled=1 argument.'; + END IF; + +-- Migration created on Dec 23, 2024 +ALTER TABLE "public"."agent_configurations" ADD COLUMN "reasoningEffort" VARCHAR(255); + + RETURN 'success'; +END; +$$ LANGUAGE plpgsql; + +\if :{?backfilled} + SELECT perform_migration(:'backfilled'::boolean); +\else + \echo '!! Migration was NOT applied !!' + \echo 'The backfill script: is required before applying this migration. 
If you already did it, run psql with --set=backfilled=1 argument.' +\endif + +DROP FUNCTION perform_migration(boolean); diff --git a/types/src/front/assistant/agent.ts b/types/src/front/assistant/agent.ts index 62630dbfa245f..33ae3f2143884 100644 --- a/types/src/front/assistant/agent.ts +++ b/types/src/front/assistant/agent.ts @@ -176,10 +176,13 @@ export type AgentUsageType = { export type AgentRecentAuthors = readonly string[]; +export type AgentReasoningEffort = "low" | "medium" | "high"; + export type AgentModelConfigurationType = { providerId: ModelProviderIdType; modelId: ModelIdType; temperature: number; + reasoningEffort?: AgentReasoningEffort; }; export type LightAgentConfigurationType = { @@ -228,6 +231,8 @@ export type LightAgentConfigurationType = { // // Example: [[1,2], [3,4]] means (1 OR 2) AND (3 OR 4) requestedGroupIds: string[][]; + + reasoningEffort?: AgentReasoningEffort; }; export type AgentConfigurationType = LightAgentConfigurationType & {