Skip to content

Commit

Permalink
Add system message test for chat completions and refactor ResponseFor…
Browse files Browse the repository at this point in the history
…mat field
  • Loading branch information
Chloe199719 committed Jan 28, 2025
1 parent 6ac35b5 commit 4de3295
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
22 changes: 20 additions & 2 deletions src/client/chat_completions/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ impl DeepSeekClient {
#[cfg(test)]
mod tests {
use super::*;
use crate::client::chat_completions::request::{Message, Model, Temperature};
use crate::{
client::chat_completions::request::{Message, Model, Temperature},
request::ResponseFormat,
};

#[tokio::test]
#[ignore]
Expand Down Expand Up @@ -106,7 +109,22 @@ mod tests {
let response = client.chat_completions(request).await.unwrap();
assert!(!response.choices.is_empty());
}

#[tokio::test]
#[ignore]
async fn test_chat_completions_with_system_message_json() {
dotenvy::dotenv().ok();
let client = DeepSeekClient::default().unwrap();
let request = RequestBody::new_messages(vec![
Message::new_system_message("You are a helpful assistant.".to_string()),
Message::new_user_message("What is 2+2?".to_string()),
Message::new_user_message(r#"```json { "output" : <value>} "#.to_string()),
])
.with_response_format(ResponseFormat::new(
crate::request::ResponseFormatType::Json,
));
let response = client.chat_completions(request).await.unwrap();
assert!(!response.choices.is_empty());
}
#[tokio::test]
#[ignore]
async fn test_chat_completions_with_reasoner_model() {
Expand Down
15 changes: 9 additions & 6 deletions src/client/chat_completions/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,21 @@ impl fmt::Display for PresencePenalty {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]

pub struct ResponseFormat {
pub response_format: ResponseFormatType,
#[serde(rename = "type")]
pub type_: ResponseFormatType,
}
impl Default for ResponseFormat {
fn default() -> Self {
ResponseFormat {
response_format: ResponseFormatType::Json,
type_: ResponseFormatType::Json,
}
}
}
impl ResponseFormat {
pub fn new(response_format: ResponseFormatType) -> Self {
ResponseFormat { response_format }
ResponseFormat {
type_: response_format,
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -501,15 +504,15 @@ mod tests {
let text_format = ResponseFormat::new(ResponseFormatType::Text);
let default_format = ResponseFormat::default();

match default_format.response_format {
match default_format.type_ {
ResponseFormatType::Json => assert!(true),
_ => assert!(false, "Default should be Json"),
}
match json_format.response_format {
match json_format.type_ {
ResponseFormatType::Json => assert!(true),
_ => assert!(false, "Expected Json"),
}
match text_format.response_format {
match text_format.type_ {
ResponseFormatType::Text => assert!(true),
_ => assert!(false, "Expected Text"),
}
Expand Down

0 comments on commit 4de3295

Please sign in to comment.