enh: support response_format extra config for OAI chat models (#2753)
* enh: support response_format extra config for OAI chat models

* fixes

* use GPT4 32k for inputs generation

---------

Co-authored-by: Henry Fontanier <henry@dust.tt>
fontanierh and Henry Fontanier authored Dec 4, 2023
1 parent c8cbdd4 commit d1a84ca
Showing 5 changed files with 91 additions and 80 deletions.
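
For context: with this change, the config of a `chat` or `llm` block can carry a `response_format` string alongside the existing `openai_user` key, and the string is forwarded to OpenAI as a nested `response_format: {"type": ...}` object (see the provider hunks below). A minimal sketch of such a config; the key names are the ones the diffs read, and the values are hypothetical:

```rust
use serde_json::json;

fn main() {
    // Hypothetical block config; these are the three keys the updated
    // extraction code inspects. The values are made up for illustration.
    let config = json!({
        "openai_user": "workspace-1234",
        "openai_organization_id": "org-abc123",
        "response_format": "json_object",
    });
    println!("{config:#}");
}
```
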
27 changes: 18 additions & 9 deletions core/src/blocks/chat.rs
@@ -277,15 +277,24 @@ impl Block for Chat {
         } && event_sender.is_some();
 
         let extras = match config {
-            Some(v) => match v.get("openai_user") {
-                Some(v) => match v {
-                    Value::String(s) => Some(json!({
-                        "openai_user": s.clone(),
-                    })),
-                    _ => None,
-                },
-                None => None,
-            },
+            Some(v) => {
+                let mut extras = json!({});
+
+                if let Some(Value::String(s)) = v.get("openai_user") {
+                    extras["openai_user"] = json!(s.clone());
+                }
+                if let Some(Value::String(s)) = v.get("openai_organization_id") {
+                    extras["openai_organization_id"] = json!(s.clone());
+                }
+                if let Some(Value::String(s)) = v.get("response_format") {
+                    extras["response_format"] = json!(s.clone());
+                }
+
+                match extras.as_object().unwrap().keys().len() {
+                    0 => None,
+                    _ => Some(extras),
+                }
+            }
             None => None,
         };

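The `chat` hunk above and the `llm` hunk below converge on the same extraction pattern: copy recognized string-valued config keys into an `extras` JSON object, yielding `None` when nothing matched. Non-string values are now silently skipped (the old `llm` block returned an error for them). A self-contained sketch of the pattern, assuming `serde_json`; the loop over key names is a condensation, not the literal block code:

```rust
use serde_json::{json, Value};

// Condensed form of the new block logic: collect recognized string-valued
// config keys into `extras`, or produce None if none are present.
fn extract_extras(config: Option<&Value>) -> Option<Value> {
    let v = config?;
    let mut extras = json!({});
    for key in ["openai_user", "openai_organization_id", "response_format"] {
        if let Some(Value::String(s)) = v.get(key) {
            extras[key] = json!(s.clone());
        }
    }
    match extras.as_object().unwrap().len() {
        0 => None,
        _ => Some(extras),
    }
}

fn main() {
    // Non-string values are ignored rather than raising an error.
    let config = json!({ "response_format": "json_object", "openai_user": 42 });
    assert_eq!(
        extract_extras(Some(&config)),
        Some(json!({ "response_format": "json_object" }))
    );
    assert_eq!(extract_extras(None), None);
}
```
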
43 changes: 17 additions & 26 deletions core/src/blocks/llm.rs
@@ -393,35 +393,26 @@ impl Block for LLM {
             None => false,
         } && event_sender.is_some();
 
-        let mut extras_map = HashMap::new();
-        if let Some(v) = config {
-            if let Some(openai_user) = v.get("openai_user") {
-                match openai_user {
-                    Value::String(s) => {
-                        extras_map.insert("openai_user", s);
-                    }
-                    _ => Err(anyhow!(
-                        "Invalid `openai_user` in configuration for llm block `{}`",
-                        name
-                    ))?,
+        let extras = match config {
+            Some(v) => {
+                let mut extras = json!({});
+
+                if let Some(Value::String(s)) = v.get("openai_user") {
+                    extras["openai_user"] = json!(s.clone());
                 }
-            }
-            if let Some(openai_organization_id) = v.get("openai_organization_id") {
-                match openai_organization_id {
-                    Value::String(s) => {
-                        extras_map.insert("openai_organization_id", s);
-                    }
-                    _ => Err(anyhow!(
-                        "Invalid `openai_organization_id` in configuration for llm block `{}`",
-                        name
-                    ))?,
+                if let Some(Value::String(s)) = v.get("openai_organization_id") {
+                    extras["openai_organization_id"] = json!(s.clone());
                 }
+                if let Some(Value::String(s)) = v.get("response_format") {
+                    extras["response_format"] = json!(s.clone());
+                }
-            }
-        }
 
-        let extras = match extras_map.len() {
-            0 => None,
-            _ => Some(json!(extras_map)),
+                match extras.as_object().unwrap().keys().len() {
+                    0 => None,
+                    _ => Some(extras),
+                }
+            }
+            None => None,
         };
 
         // if model_id starts with gpt-3.5-turbo or gpt-4 use the chat interface (except for
32 changes: 18 additions & 14 deletions core/src/providers/azure_openai.rs
@@ -483,6 +483,20 @@ impl LLM for AzureOpenAILLM {
             }
         }
 
+        let (openai_user, response_format) = match &extras {
+            None => (None, None),
+            Some(v) => (
+                match v.get("openai_user") {
+                    Some(Value::String(u)) => Some(u.to_string()),
+                    _ => None,
+                },
+                match v.get("response_format") {
+                    Some(Value::String(f)) => Some(f.to_string()),
+                    _ => None,
+                },
+            ),
+        };
+
         let c = match event_sender {
             Some(_) => {
                 streamed_chat_completion(
@@ -509,13 +523,8 @@
                         Some(f) => f,
                         None => 0.0,
                     },
-                    match &extras {
-                        Some(e) => match e.get("openai_user") {
-                            Some(Value::String(u)) => Some(u.to_string()),
-                            _ => None,
-                        },
-                        None => None,
-                    },
+                    response_format,
+                    openai_user,
                     event_sender,
                 )
                 .await?
@@ -545,13 +554,8 @@
                         Some(f) => f,
                         None => 0.0,
                     },
-                    match &extras {
-                        Some(e) => match e.get("openai_user") {
-                            Some(Value::String(u)) => Some(u.to_string()),
-                            _ => None,
-                        },
-                        None => None,
-                    },
+                    response_format,
+                    openai_user,
                 )
                 .await?
             }
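Note that the Azure path above only pulls `openai_user` and `response_format` out of `extras`; unlike the OpenAI provider below, it does not read `openai_organization_id`. A runnable sketch of the same destructuring, with hypothetical values:

```rust
use serde_json::{json, Value};

fn main() {
    let extras = Some(json!({ "openai_user": "u-1", "response_format": "json_object" }));

    // Same shape as the Azure hunk above: a missing or non-string entry
    // simply becomes None.
    let (openai_user, response_format) = match &extras {
        None => (None, None),
        Some(v) => (
            match v.get("openai_user") {
                Some(Value::String(u)) => Some(u.to_string()),
                _ => None,
            },
            match v.get("response_format") {
                Some(Value::String(f)) => Some(f.to_string()),
                _ => None,
            },
        ),
    };

    assert_eq!(openai_user.as_deref(), Some("u-1"));
    assert_eq!(response_format.as_deref(), Some("json_object"));
}
```
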
65 changes: 36 additions & 29 deletions core/src/providers/openai.rs
@@ -540,6 +540,7 @@ pub async fn streamed_chat_completion(
     max_tokens: Option<i32>,
     presence_penalty: f32,
     frequency_penalty: f32,
+    response_format: Option<String>,
     user: Option<String>,
     event_sender: Option<UnboundedSender<Value>>,
 ) -> Result<ChatCompletion> {
@@ -609,6 +610,11 @@
     if function_call.is_some() {
         body["function_call"] = function_call.unwrap();
     }
+    if response_format.is_some() {
+        body["response_format"] = json!({
+            "type": response_format.unwrap(),
+        });
+    }
 
     // println!("BODY: {}", body.to_string());
 
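For illustration: the flat config string becomes OpenAI's nested object form on the wire. A sketch of the construction used in both completion functions, with hypothetical field values (the real body carries the full parameter set):

```rust
use serde_json::json;

fn main() {
    let response_format: Option<String> = Some("json_object".to_string());

    // Same construction as the hunk above.
    let mut body = json!({
        "model": "gpt-4-1106-preview", // hypothetical model id
        "messages": [{ "role": "user", "content": "Reply in JSON." }],
    });
    if response_format.is_some() {
        body["response_format"] = json!({
            "type": response_format.unwrap(),
        });
    }

    assert_eq!(body["response_format"]["type"], "json_object");
}
```
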
@@ -899,6 +905,7 @@ pub async fn chat_completion(
     max_tokens: Option<i32>,
     presence_penalty: f32,
     frequency_penalty: f32,
+    response_format: Option<String>,
     user: Option<String>,
 ) -> Result<ChatCompletion> {
     let https = HttpsConnector::new();
@@ -935,14 +942,18 @@
     if model_id.is_some() {
         body["model"] = json!(model_id);
     }
+    if response_format.is_some() {
+        body["response_format"] = json!({
+            "type": response_format.unwrap(),
+        });
+    }
     if functions.len() > 0 {
         body["functions"] = json!(functions);
     }
     if function_call.is_some() {
         body["function_call"] = function_call.unwrap();
     }
-
     // println!("BODY: {}", body.to_string());
     let mut req_builder = Request::builder()
         .method(Method::POST)
         .uri(uri)
@@ -1398,18 +1409,30 @@ impl LLM for OpenAILLM {
             }
         }
 
+        let (openai_org_id, openai_user, response_format) = match &extras {
+            None => (None, None, None),
+            Some(v) => (
+                match v.get("openai_organization_id") {
+                    Some(Value::String(o)) => Some(o.to_string()),
+                    _ => None,
+                },
+                match v.get("openai_user") {
+                    Some(Value::String(u)) => Some(u.to_string()),
+                    _ => None,
+                },
+                match v.get("response_format") {
+                    Some(Value::String(f)) => Some(f.to_string()),
+                    _ => None,
+                },
+            ),
+        };
+
         let c = match event_sender {
             Some(_) => {
                 streamed_chat_completion(
                     self.chat_uri()?,
                     self.api_key.clone().unwrap(),
-                    match &extras {
-                        Some(e) => match e.get("openai_organization_id") {
-                            Some(Value::String(o)) => Some(o.to_string()),
-                            _ => None,
-                        },
-                        None => None,
-                    },
+                    openai_org_id,
                     Some(self.id.clone()),
                     messages,
                     functions,
@@ -1430,13 +1453,8 @@
                         Some(f) => f,
                         None => 0.0,
                     },
-                    match &extras {
-                        Some(e) => match e.get("openai_user") {
-                            Some(Value::String(u)) => Some(u.to_string()),
-                            _ => None,
-                        },
-                        None => None,
-                    },
+                    response_format,
+                    openai_user,
                     event_sender,
                 )
                 .await?
@@ -1445,13 +1463,7 @@
                 chat_completion(
                     self.chat_uri()?,
                     self.api_key.clone().unwrap(),
-                    match &extras {
-                        Some(e) => match e.get("openai_organization_id") {
-                            Some(Value::String(o)) => Some(o.to_string()),
-                            _ => None,
-                        },
-                        None => None,
-                    },
+                    openai_org_id,
                     Some(self.id.clone()),
                     messages,
                     functions,
@@ -1472,13 +1484,8 @@
                         Some(f) => f,
                         None => 0.0,
                     },
-                    match &extras {
-                        Some(e) => match e.get("openai_user") {
-                            Some(Value::String(u)) => Some(u.to_string()),
-                            _ => None,
-                        },
-                        None => None,
-                    },
+                    response_format,
+                    openai_user,
                 )
                 .await?
             }
4 changes: 2 additions & 2 deletions front/lib/api/assistant/agent.ts
@@ -71,8 +71,8 @@ export async function generateActionInputs(
         modelId: GPT_3_5_TURBO_MODEL_CONFIG.modelId,
       }
     : {
-        providerId: GPT_4_TURBO_MODEL_CONFIG.providerId,
-        modelId: GPT_4_TURBO_MODEL_CONFIG.modelId,
+        providerId: GPT_4_32K_MODEL_CONFIG.providerId,
+        modelId: GPT_4_32K_MODEL_CONFIG.modelId,
       };
 
   const contextSize = isFree
