Add image input support for OpenAI models and use one name (`Plain`) for the
text-only content variant across the Anthropic, Copilot Chat, and OpenAI crates.

Review notes: line structure and generic type arguments were lost when this patch
was extracted and have been reconstructed from usage; confirm the context lines
against the tree before applying. A new message now starts with the part that
caused it to be created, where previously that part was dropped.

diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index b323b595ba..3b324cd11b 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -543,7 +543,7 @@ pub enum RequestContent {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum ToolResultContent {
-    JustText(String),
+    Plain(String),
     Multipart(Vec<ToolResultPart>),
 }
 
diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 2ac6bfe5a7..fe46ddebce 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -217,7 +217,7 @@ pub enum ChatMessage {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum ChatMessageContent {
-    OnlyText(String),
+    Plain(String),
     Multipart(Vec<ChatMessagePart>),
 }
 
@@ -230,7 +230,7 @@ impl ChatMessageContent {
 impl From<Vec<ChatMessagePart>> for ChatMessageContent {
     fn from(mut parts: Vec<ChatMessagePart>) -> Self {
         if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
-            ChatMessageContent::OnlyText(std::mem::take(text))
+            ChatMessageContent::Plain(std::mem::take(text))
         } else {
             ChatMessageContent::Multipart(parts)
         }
@@ -239,7 +239,7 @@ impl From<Vec<ChatMessagePart>> for ChatMessageContent {
 
 impl From<String> for ChatMessageContent {
     fn from(text: String) -> Self {
-        ChatMessageContent::OnlyText(text)
+        ChatMessageContent::Plain(text)
     }
 }
 
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index eccde976d3..a87d730093 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -589,7 +589,7 @@ pub fn into_anthropic(
                                 is_error: tool_result.is_error,
                                 content: match tool_result.content {
                                     LanguageModelToolResultContent::Text(text) => {
-                                        ToolResultContent::JustText(text.to_string())
+                                        ToolResultContent::Plain(text.to_string())
                                     }
                                     LanguageModelToolResultContent::Image(image) => {
                                         ToolResultContent::Multipart(vec![ToolResultPart::Image {
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index b19b4653b1..369c81e650 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -15,7 +15,7 @@ use language_model::{
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason,
 };
-use open_ai::{Model, ResponseStreamEvent, stream_completion};
+use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -362,17 +362,26 @@ pub fn into_open_ai(
     for message in request.messages {
         for content in message.content {
             match content {
-                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
-                    .push(match message.role {
-                        Role::User => open_ai::RequestMessage::User { content: text },
-                        Role::Assistant => open_ai::RequestMessage::Assistant {
-                            content: Some(text),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => open_ai::RequestMessage::System { content: text },
-                    }),
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+                    add_message_content_part(
+                        open_ai::MessagePart::Text { text },
+                        message.role,
+                        &mut messages,
+                    )
+                }
                 MessageContent::RedactedThinking(_) => {}
-                MessageContent::Image(_) => {}
+                MessageContent::Image(image) => {
+                    add_message_content_part(
+                        open_ai::MessagePart::Image {
+                            image_url: ImageUrl {
+                                url: image.to_base64_url(),
+                                detail: None,
+                            },
+                        },
+                        message.role,
+                        &mut messages,
+                    );
+                }
                 MessageContent::ToolUse(tool_use) => {
                     let tool_call = open_ai::ToolCall {
                         id: tool_use.id.to_string(),
@@ -391,22 +400,30 @@ pub fn into_open_ai(
                         tool_calls.push(tool_call);
                     } else {
                         messages.push(open_ai::RequestMessage::Assistant {
-                            content: None,
+                            content: open_ai::MessageContent::empty(),
                             tool_calls: vec![tool_call],
                         });
                     }
                 }
                 MessageContent::ToolResult(tool_result) => {
                     let content = match &tool_result.content {
-                        LanguageModelToolResultContent::Text(text) => text.to_string(),
-                        LanguageModelToolResultContent::Image(_) => {
-                            // TODO: Open AI image support
-                            "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
+                        LanguageModelToolResultContent::Text(text) => {
+                            vec![open_ai::MessagePart::Text {
+                                text: text.to_string(),
+                            }]
+                        }
+                        LanguageModelToolResultContent::Image(image) => {
+                            vec![open_ai::MessagePart::Image {
+                                image_url: ImageUrl {
+                                    url: image.to_base64_url(),
+                                    detail: None,
+                                },
+                            }]
                         }
                     };
 
                     messages.push(open_ai::RequestMessage::Tool {
-                        content,
+                        content: content.into(),
                         tool_call_id: tool_result.tool_use_id.to_string(),
                     });
                 }
@@ -446,6 +463,41 @@ pub fn into_open_ai(
     }
 }
 
+/// Appends `new_part` to the request being built.
+///
+/// When the most recent message in `messages` is a user, assistant, or system
+/// message matching `role`, the part is merged into that message's content.
+/// Otherwise a new message for `role` is pushed with `new_part` as its content.
+fn add_message_content_part(
+    new_part: open_ai::MessagePart,
+    role: Role,
+    messages: &mut Vec<open_ai::RequestMessage>,
+) {
+    match (role, messages.last_mut()) {
+        (Role::User, Some(open_ai::RequestMessage::User { content }))
+        | (Role::Assistant, Some(open_ai::RequestMessage::Assistant { content, .. }))
+        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
+            content.push_part(new_part);
+        }
+        _ => {
+            // Seed the new message with the part; starting from empty content
+            // would silently drop it.
+            messages.push(match role {
+                Role::User => open_ai::RequestMessage::User {
+                    content: open_ai::MessageContent::from(vec![new_part]),
+                },
+                Role::Assistant => open_ai::RequestMessage::Assistant {
+                    content: open_ai::MessageContent::from(vec![new_part]),
+                    tool_calls: Vec::new(),
+                },
+                Role::System => open_ai::RequestMessage::System {
+                    content: open_ai::MessageContent::from(vec![new_part]),
+                },
+            });
+        }
+    }
+}
+
 pub struct OpenAiEventMapper {
     tool_calls_by_index: HashMap<usize, RawToolCall>,
 }
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 9faac29cac..59e26ee347 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -278,22 +278,88 @@ pub struct FunctionDefinition {
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum RequestMessage {
     Assistant {
-        content: Option<String>,
+        content: MessageContent,
         #[serde(default, skip_serializing_if = "Vec::is_empty")]
         tool_calls: Vec<ToolCall>,
     },
     User {
-        content: String,
+        content: MessageContent,
     },
     System {
-        content: String,
+        content: MessageContent,
     },
     Tool {
-        content: String,
+        content: MessageContent,
         tool_call_id: String,
     },
 }
 
+/// The content of a chat message: either a plain string or a list of parts.
+///
+/// Serialized untagged, so `Plain` is written as a JSON string and `Multipart`
+/// as a JSON array of parts.
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+    Plain(String),
+    Multipart(Vec<MessagePart>),
+}
+
+impl MessageContent {
+    /// Returns content that holds no parts.
+    pub fn empty() -> Self {
+        MessageContent::Multipart(vec![])
+    }
+
+    /// Appends `part` to this content.
+    ///
+    /// Empty content that receives a text part becomes `Plain`; `Plain` content
+    /// that receives another part is promoted to `Multipart`.
+    pub fn push_part(&mut self, part: MessagePart) {
+        match self {
+            MessageContent::Plain(text) => {
+                let text = std::mem::take(text);
+                *self = MessageContent::Multipart(vec![MessagePart::Text { text }, part]);
+            }
+            MessageContent::Multipart(parts) if parts.is_empty() => match part {
+                MessagePart::Text { text } => *self = MessageContent::Plain(text),
+                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
+            },
+            MessageContent::Multipart(parts) => parts.push(part),
+        }
+    }
+}
+
+impl From<Vec<MessagePart>> for MessageContent {
+    /// Collapses a lone text part into `Plain`; anything else stays `Multipart`.
+    fn from(mut parts: Vec<MessagePart>) -> Self {
+        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
+            MessageContent::Plain(std::mem::take(text))
+        } else {
+            MessageContent::Multipart(parts)
+        }
+    }
+}
+
+/// One part of multipart message content, tagged by a `type` field when serialized.
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+#[serde(tag = "type")]
+pub enum MessagePart {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image_url")]
+    Image { image_url: ImageUrl },
+}
+
+/// Image reference carried by [`MessagePart::Image`]; `detail` is omitted from
+/// the serialized form when it is `None`.
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+pub struct ImageUrl {
+    pub url: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detail: Option<String>,
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct ToolCall {
     pub id: String,
@@ -509,24 +575,46 @@ fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
         choices: response
             .choices
             .into_iter()
-            .map(|choice| ChoiceDelta {
-                index: choice.index,
-                delta: ResponseMessageDelta {
-                    role: Some(match choice.message {
-                        RequestMessage::Assistant { .. } => Role::Assistant,
-                        RequestMessage::User { .. } => Role::User,
-                        RequestMessage::System { .. } => Role::System,
-                        RequestMessage::Tool { .. } => Role::Tool,
-                    }),
-                    content: match choice.message {
-                        RequestMessage::Assistant { content, .. } => content,
-                        RequestMessage::User { content } => Some(content),
-                        RequestMessage::System { content } => Some(content),
-                        RequestMessage::Tool { content, .. } => Some(content),
+            .map(|choice| {
+                let content = match &choice.message {
+                    RequestMessage::Assistant { content, .. } => content,
+                    RequestMessage::User { content } => content,
+                    RequestMessage::System { content } => content,
+                    RequestMessage::Tool { content, .. } => content,
+                };
+
+                // Concatenate the text parts; image parts are skipped.
+                let mut text_content = String::new();
+                match content {
+                    MessageContent::Plain(text) => text_content.push_str(text),
+                    MessageContent::Multipart(parts) => {
+                        for part in parts {
+                            match part {
+                                MessagePart::Text { text } => text_content.push_str(text),
+                                MessagePart::Image { .. } => {}
+                            }
+                        }
+                    }
+                };
+
+                ChoiceDelta {
+                    index: choice.index,
+                    delta: ResponseMessageDelta {
+                        role: Some(match choice.message {
+                            RequestMessage::Assistant { .. } => Role::Assistant,
+                            RequestMessage::User { .. } => Role::User,
+                            RequestMessage::System { .. } => Role::System,
+                            RequestMessage::Tool { .. } => Role::Tool,
+                        }),
+                        content: if text_content.is_empty() {
+                            None
+                        } else {
+                            Some(text_content)
+                        },
+                        tool_calls: None,
                     },
-                    tool_calls: None,
-                },
-                finish_reason: choice.finish_reason,
+                    finish_reason: choice.finish_reason,
+                }
             })
             .collect(),
         usage: Some(response.usage),