From 4de3295a07ab57603f2b7284bacccaace7214af4 Mon Sep 17 00:00:00 2001 From: Chloe Pratas Date: Tue, 28 Jan 2025 10:27:14 +0100 Subject: [PATCH] Add system message test for chat completions and refactor ResponseFormat field --- .../chat_completions/chat_completions.rs | 22 +++++++++++++++++-- src/client/chat_completions/request.rs | 15 ++++++++----- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/client/chat_completions/chat_completions.rs b/src/client/chat_completions/chat_completions.rs index 4bafc8d..dfe007e 100644 --- a/src/client/chat_completions/chat_completions.rs +++ b/src/client/chat_completions/chat_completions.rs @@ -69,7 +69,10 @@ impl DeepSeekClient { #[cfg(test)] mod tests { use super::*; - use crate::client::chat_completions::request::{Message, Model, Temperature}; + use crate::{ + client::chat_completions::request::{Message, Model, Temperature}, + request::ResponseFormat, + }; #[tokio::test] #[ignore] @@ -106,7 +109,22 @@ mod tests { let response = client.chat_completions(request).await.unwrap(); assert!(!response.choices.is_empty()); } - + #[tokio::test] + #[ignore] + async fn test_chat_completions_with_system_message_json() { + dotenvy::dotenv().ok(); + let client = DeepSeekClient::default().unwrap(); + let request = RequestBody::new_messages(vec![ + Message::new_system_message("You are a helpful assistant.".to_string()), + Message::new_user_message("What is 2+2?".to_string()), + Message::new_user_message(r#"```json { "output" : } "#.to_string()), + ]) + .with_response_format(ResponseFormat::new( + crate::request::ResponseFormatType::Json, + )); + let response = client.chat_completions(request).await.unwrap(); + assert!(!response.choices.is_empty()); + } #[tokio::test] #[ignore] async fn test_chat_completions_with_reasoner_model() { diff --git a/src/client/chat_completions/request.rs b/src/client/chat_completions/request.rs index 5f7f645..7f2ce81 100644 --- a/src/client/chat_completions/request.rs +++ b/src/client/chat_completions/request.rs @@ -273,18 +273,21 @@ impl fmt::Display for PresencePenalty { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct ResponseFormat { - pub response_format: ResponseFormatType, + #[serde(rename = "type")] + pub type_: ResponseFormatType, } impl Default for ResponseFormat { fn default() -> Self { ResponseFormat { - response_format: ResponseFormatType::Json, + type_: ResponseFormatType::Json, } } } impl ResponseFormat { pub fn new(response_format: ResponseFormatType) -> Self { - ResponseFormat { response_format } + ResponseFormat { + type_: response_format, + } } } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] @@ -501,15 +504,15 @@ mod tests { let text_format = ResponseFormat::new(ResponseFormatType::Text); let default_format = ResponseFormat::default(); - match default_format.response_format { + match default_format.type_ { ResponseFormatType::Json => assert!(true), _ => assert!(false, "Default should be Json"), } - match json_format.response_format { + match json_format.type_ { ResponseFormatType::Json => assert!(true), _ => assert!(false, "Expected Json"), } - match text_format.response_format { + match text_format.type_ { ResponseFormatType::Text => assert!(true), _ => assert!(false, "Expected Text"), }