diff --git a/core/src/blocks/chat.rs b/core/src/blocks/chat.rs index 6bce6912ad3f..082dc088f59e 100644 --- a/core/src/blocks/chat.rs +++ b/core/src/blocks/chat.rs @@ -277,15 +277,24 @@ impl Block for Chat { } && event_sender.is_some(); let extras = match config { - Some(v) => match v.get("openai_user") { - Some(v) => match v { - Value::String(s) => Some(json!({ - "openai_user": s.clone(), - })), - _ => None, - }, - None => None, - }, + Some(v) => { + let mut extras = json!({}); + + if let Some(Value::String(s)) = v.get("openai_user") { + extras["openai_user"] = json!(s.clone()); + } + if let Some(Value::String(s)) = v.get("openai_organization_id") { + extras["openai_organization_id"] = json!(s.clone()); + } + if let Some(Value::String(s)) = v.get("response_format") { + extras["response_format"] = json!(s.clone()); + } + + match extras.as_object().unwrap().keys().len() { + 0 => None, + _ => Some(extras), + } + } None => None, }; diff --git a/core/src/blocks/llm.rs b/core/src/blocks/llm.rs index 120a7257e372..6b76b5d5e584 100644 --- a/core/src/blocks/llm.rs +++ b/core/src/blocks/llm.rs @@ -393,35 +393,26 @@ impl Block for LLM { None => false, } && event_sender.is_some(); - let mut extras_map = HashMap::new(); - if let Some(v) = config { - if let Some(openai_user) = v.get("openai_user") { - match openai_user { - Value::String(s) => { - extras_map.insert("openai_user", s); - } - _ => Err(anyhow!( - "Invalid `openai_user` in configuration for llm block `{}`", - name - ))?, + let extras = match config { + Some(v) => { + let mut extras = json!({}); + + if let Some(Value::String(s)) = v.get("openai_user") { + extras["openai_user"] = json!(s.clone()); } - } - if let Some(openai_organization_id) = v.get("openai_organization_id") { - match openai_organization_id { - Value::String(s) => { - extras_map.insert("openai_organization_id", s); - } - _ => Err(anyhow!( - "Invalid `openai_organization_id` in configuration for llm block `{}`", - name - ))?, + if let Some(Value::String(s)) = v.get("openai_organization_id") { + extras["openai_organization_id"] = json!(s.clone()); + } + if let Some(Value::String(s)) = v.get("response_format") { + extras["response_format"] = json!(s.clone()); } - } - } - let extras = match extras_map.len() { - 0 => None, - _ => Some(json!(extras_map)), + match extras.as_object().unwrap().keys().len() { + 0 => None, + _ => Some(extras), + } + } + None => None, }; // if model_id starts with gpt-3.5-turbo or gpt-4 use the chat interface (except for diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index 0c85114e01fb..c4b84547419c 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -483,6 +483,20 @@ impl LLM for AzureOpenAILLM { } } + let (openai_user, response_format) = match &extras { + None => (None, None), + Some(v) => ( + match v.get("openai_user") { + Some(Value::String(u)) => Some(u.to_string()), + _ => None, + }, + match v.get("response_format") { + Some(Value::String(f)) => Some(f.to_string()), + _ => None, + }, + ), + }; + let c = match event_sender { Some(_) => { streamed_chat_completion( @@ -509,13 +523,8 @@ impl LLM for AzureOpenAILLM { Some(f) => f, None => 0.0, }, - match &extras { - Some(e) => match e.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - None => None, - }, + response_format, + openai_user, event_sender, ) .await? @@ -545,13 +554,8 @@ impl LLM for AzureOpenAILLM { Some(f) => f, None => 0.0, }, - match &extras { - Some(e) => match e.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - None => None, - }, + response_format, + openai_user, ) .await? } diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index fd485f117a4e..5eb8fda70391 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -540,6 +540,7 @@ pub async fn streamed_chat_completion( max_tokens: Option, presence_penalty: f32, frequency_penalty: f32, + response_format: Option, user: Option, event_sender: Option>, ) -> Result { @@ -609,6 +610,11 @@ pub async fn streamed_chat_completion( if function_call.is_some() { body["function_call"] = function_call.unwrap(); } + if response_format.is_some() { + body["response_format"] = json!({ + "type": response_format.unwrap(), + }); + } // println!("BODY: {}", body.to_string()); @@ -899,6 +905,7 @@ pub async fn chat_completion( max_tokens: Option, presence_penalty: f32, frequency_penalty: f32, + response_format: Option, user: Option, ) -> Result { let https = HttpsConnector::new(); @@ -935,6 +942,11 @@ pub async fn chat_completion( if model_id.is_some() { body["model"] = json!(model_id); } + if response_format.is_some() { + body["response_format"] = json!({ + "type": response_format.unwrap(), + }); + } if functions.len() > 0 { body["functions"] = json!(functions); } @@ -942,7 +954,6 @@ pub async fn chat_completion( body["function_call"] = function_call.unwrap(); } - // println!("BODY: {}", body.to_string()); let mut req_builder = Request::builder() .method(Method::POST) .uri(uri) @@ -1398,18 +1409,30 @@ impl LLM for OpenAILLM { } } + let (openai_org_id, openai_user, response_format) = match &extras { + None => (None, None, None), + Some(v) => ( + match v.get("openai_organization_id") { + Some(Value::String(o)) => Some(o.to_string()), + _ => None, + }, + match v.get("openai_user") { + Some(Value::String(u)) => Some(u.to_string()), + _ => None, + }, + match v.get("response_format") { + Some(Value::String(f)) => Some(f.to_string()), + _ => None, + }, + ), + }; + let c = match event_sender { Some(_) => { streamed_chat_completion( self.chat_uri()?, self.api_key.clone().unwrap(), - match &extras { - Some(e) => match e.get("openai_organization_id") { - Some(Value::String(o)) => Some(o.to_string()), - _ => None, - }, - None => None, - }, + openai_org_id, Some(self.id.clone()), messages, functions, @@ -1430,13 +1453,8 @@ impl LLM for OpenAILLM { Some(f) => f, None => 0.0, }, - match &extras { - Some(e) => match e.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - None => None, - }, + response_format, + openai_user, event_sender, ) .await? @@ -1445,13 +1463,7 @@ impl LLM for OpenAILLM { chat_completion( self.chat_uri()?, self.api_key.clone().unwrap(), - match &extras { - Some(e) => match e.get("openai_organization_id") { - Some(Value::String(o)) => Some(o.to_string()), - _ => None, - }, - None => None, - }, + openai_org_id, Some(self.id.clone()), messages, functions, @@ -1472,13 +1484,8 @@ impl LLM for OpenAILLM { Some(f) => f, None => 0.0, }, - match &extras { - Some(e) => match e.get("openai_user") { - Some(Value::String(u)) => Some(u.to_string()), - _ => None, - }, - None => None, - }, + response_format, + openai_user, ) .await? } diff --git a/front/lib/api/assistant/agent.ts b/front/lib/api/assistant/agent.ts index 5bb906d93f5b..8a9713df688c 100644 --- a/front/lib/api/assistant/agent.ts +++ b/front/lib/api/assistant/agent.ts @@ -71,8 +71,8 @@ export async function generateActionInputs( modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId, } : { - providerId: GPT_4_TURBO_MODEL_CONFIG.providerId, - modelId: GPT_4_TURBO_MODEL_CONFIG.modelId, + providerId: GPT_4_32K_MODEL_CONFIG.providerId, + modelId: GPT_4_32K_MODEL_CONFIG.modelId, }; const contextSize = isFree