diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 47e4ca319e..c843b736a0 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -400,7 +400,7 @@ pub fn into_open_ai( tool_calls.push(tool_call); } else { messages.push(open_ai::RequestMessage::Assistant { - content: open_ai::MessageContent::empty(), + content: None, tool_calls: vec![tool_call], }); } @@ -474,7 +474,13 @@ fn add_message_content_part( ) { match (role, messages.last_mut()) { (Role::User, Some(open_ai::RequestMessage::User { content })) - | (Role::Assistant, Some(open_ai::RequestMessage::Assistant { content, .. })) + | ( + Role::Assistant, + Some(open_ai::RequestMessage::Assistant { + content: Some(content), + .. + }), + ) | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => { content.push_part(new_part); } @@ -484,7 +490,7 @@ fn add_message_content_part( content: open_ai::MessageContent::from(vec![new_part]), }, Role::Assistant => open_ai::RequestMessage::Assistant { - content: open_ai::MessageContent::from(vec![new_part]), + content: Some(open_ai::MessageContent::from(vec![new_part])), tool_calls: Vec::new(), }, Role::System => open_ai::RequestMessage::System { diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 486e7ea40b..9a56ab538d 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -278,7 +278,7 @@ pub struct FunctionDefinition { #[serde(tag = "role", rename_all = "lowercase")] pub enum RequestMessage { Assistant { - content: MessageContent, + content: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, @@ -562,16 +562,16 @@ fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent { .into_iter() .map(|choice| { let content = match &choice.message { - RequestMessage::Assistant { content, .. } => content, - RequestMessage::User { content } => content, - RequestMessage::System { content } => content, - RequestMessage::Tool { content, .. } => content, + RequestMessage::Assistant { content, .. } => content.as_ref(), + RequestMessage::User { content } => Some(content), + RequestMessage::System { content } => Some(content), + RequestMessage::Tool { content, .. } => Some(content), }; let mut text_content = String::new(); match content { - MessageContent::Plain(text) => text_content.push_str(&text), - MessageContent::Multipart(parts) => { + Some(MessageContent::Plain(text)) => text_content.push_str(&text), + Some(MessageContent::Multipart(parts)) => { for part in parts { match part { MessagePart::Text { text } => text_content.push_str(&text), @@ -579,6 +579,7 @@ fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent { } } } + None => {} }; ChoiceDelta {