Add image input support for OpenAI models (#30639)

Release Notes:

- Added input image support for OpenAI models
This commit is contained in:
Agus Zubiaga 2025-05-13 17:32:42 +02:00 committed by GitHub
parent 68afe4fdda
commit dd6594621f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 162 additions and 43 deletions

View file

@ -543,7 +543,7 @@ pub enum RequestContent {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum ToolResultContent { pub enum ToolResultContent {
JustText(String), Plain(String),
Multipart(Vec<ToolResultPart>), Multipart(Vec<ToolResultPart>),
} }

View file

@ -217,7 +217,7 @@ pub enum ChatMessage {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum ChatMessageContent { pub enum ChatMessageContent {
OnlyText(String), Plain(String),
Multipart(Vec<ChatMessagePart>), Multipart(Vec<ChatMessagePart>),
} }
@ -230,7 +230,7 @@ impl ChatMessageContent {
impl From<Vec<ChatMessagePart>> for ChatMessageContent { impl From<Vec<ChatMessagePart>> for ChatMessageContent {
fn from(mut parts: Vec<ChatMessagePart>) -> Self { fn from(mut parts: Vec<ChatMessagePart>) -> Self {
if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() { if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
ChatMessageContent::OnlyText(std::mem::take(text)) ChatMessageContent::Plain(std::mem::take(text))
} else { } else {
ChatMessageContent::Multipart(parts) ChatMessageContent::Multipart(parts)
} }
@ -239,7 +239,7 @@ impl From<Vec<ChatMessagePart>> for ChatMessageContent {
impl From<String> for ChatMessageContent { impl From<String> for ChatMessageContent {
fn from(text: String) -> Self { fn from(text: String) -> Self {
ChatMessageContent::OnlyText(text) ChatMessageContent::Plain(text)
} }
} }

View file

@ -589,7 +589,7 @@ pub fn into_anthropic(
is_error: tool_result.is_error, is_error: tool_result.is_error,
content: match tool_result.content { content: match tool_result.content {
LanguageModelToolResultContent::Text(text) => { LanguageModelToolResultContent::Text(text) => {
ToolResultContent::JustText(text.to_string()) ToolResultContent::Plain(text.to_string())
} }
LanguageModelToolResultContent::Image(image) => { LanguageModelToolResultContent::Image(image) => {
ToolResultContent::Multipart(vec![ToolResultPart::Image { ToolResultContent::Multipart(vec![ToolResultPart::Image {

View file

@ -15,7 +15,7 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, RateLimiter, Role, StopReason,
}; };
use open_ai::{Model, ResponseStreamEvent, stream_completion}; use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
@ -362,17 +362,26 @@ pub fn into_open_ai(
for message in request.messages { for message in request.messages {
for content in message.content { for content in message.content {
match content { match content {
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
.push(match message.role { add_message_content_part(
Role::User => open_ai::RequestMessage::User { content: text }, open_ai::MessagePart::Text { text: text },
Role::Assistant => open_ai::RequestMessage::Assistant { message.role,
content: Some(text), &mut messages,
tool_calls: Vec::new(), )
}, }
Role::System => open_ai::RequestMessage::System { content: text },
}),
MessageContent::RedactedThinking(_) => {} MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {} MessageContent::Image(image) => {
add_message_content_part(
open_ai::MessagePart::Image {
image_url: ImageUrl {
url: image.to_base64_url(),
detail: None,
},
},
message.role,
&mut messages,
);
}
MessageContent::ToolUse(tool_use) => { MessageContent::ToolUse(tool_use) => {
let tool_call = open_ai::ToolCall { let tool_call = open_ai::ToolCall {
id: tool_use.id.to_string(), id: tool_use.id.to_string(),
@ -391,22 +400,30 @@ pub fn into_open_ai(
tool_calls.push(tool_call); tool_calls.push(tool_call);
} else { } else {
messages.push(open_ai::RequestMessage::Assistant { messages.push(open_ai::RequestMessage::Assistant {
content: None, content: open_ai::MessageContent::empty(),
tool_calls: vec![tool_call], tool_calls: vec![tool_call],
}); });
} }
} }
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => {
let content = match &tool_result.content { let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string(), LanguageModelToolResultContent::Text(text) => {
LanguageModelToolResultContent::Image(_) => { vec![open_ai::MessagePart::Text {
// TODO: Open AI image support text: text.to_string(),
"[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string() }]
}
LanguageModelToolResultContent::Image(image) => {
vec![open_ai::MessagePart::Image {
image_url: ImageUrl {
url: image.to_base64_url(),
detail: None,
},
}]
} }
}; };
messages.push(open_ai::RequestMessage::Tool { messages.push(open_ai::RequestMessage::Tool {
content, content: content.into(),
tool_call_id: tool_result.tool_use_id.to_string(), tool_call_id: tool_result.tool_use_id.to_string(),
}); });
} }
@ -446,6 +463,34 @@ pub fn into_open_ai(
} }
} }
/// Appends `new_part` to the trailing message in `messages` when that
/// message belongs to `role`; otherwise starts a new message of the
/// matching kind seeded with `new_part`.
///
/// This lets consecutive content parts of the same role (text and
/// images) coalesce into a single OpenAI request message.
fn add_message_content_part(
    new_part: open_ai::MessagePart,
    role: Role,
    messages: &mut Vec<open_ai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        // The last message already has this role: merge the part into it.
        (Role::User, Some(open_ai::RequestMessage::User { content }))
        | (Role::Assistant, Some(open_ai::RequestMessage::Assistant { content, .. }))
        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        // Otherwise open a fresh message containing just this part.
        // NOTE(fix): previously an *empty* message was pushed here and
        // `new_part` was dropped, losing the first part of every new
        // message; the new message must carry the part.
        _ => {
            let content = open_ai::MessageContent::Multipart(vec![new_part]);
            messages.push(match role {
                Role::User => open_ai::RequestMessage::User { content },
                Role::Assistant => open_ai::RequestMessage::Assistant {
                    content,
                    tool_calls: Vec::new(),
                },
                Role::System => open_ai::RequestMessage::System { content },
            });
        }
    }
}
pub struct OpenAiEventMapper { pub struct OpenAiEventMapper {
tool_calls_by_index: HashMap<usize, RawToolCall>, tool_calls_by_index: HashMap<usize, RawToolCall>,
} }

View file

@ -278,22 +278,75 @@ pub struct FunctionDefinition {
#[serde(tag = "role", rename_all = "lowercase")] #[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage { pub enum RequestMessage {
Assistant { Assistant {
content: Option<String>, content: MessageContent,
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
}, },
User { User {
content: String, content: MessageContent,
}, },
System { System {
content: String, content: MessageContent,
}, },
Tool { Tool {
content: String, content: MessageContent,
tool_call_id: String, tool_call_id: String,
}, },
} }
/// The `content` field of a chat request message. The OpenAI API accepts
/// it either as a bare JSON string or as an array of typed parts;
/// `#[serde(untagged)]` serializes each variant in exactly that shape.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text body, serialized as a single JSON string.
    Plain(String),
    /// List of parts (text and/or images), serialized as a JSON array.
    Multipart(Vec<MessagePart>),
}
impl MessageContent {
    /// Creates empty multipart content, intended to be filled in via
    /// [`MessageContent::push_part`].
    pub fn empty() -> Self {
        MessageContent::Multipart(vec![])
    }

    /// Appends `part`, normalizing the representation as it goes: a lone
    /// text part is stored as `Plain` (so it serializes as a bare string),
    /// anything else becomes or extends `Multipart`.
    pub fn push_part(&mut self, part: MessagePart) {
        match self {
            // Promote the existing plain text to a text part, then append.
            MessageContent::Plain(text) => {
                // Move the text out with `mem::take` instead of cloning;
                // `*self` is overwritten immediately afterwards anyway.
                let text = std::mem::take(text);
                *self = MessageContent::Multipart(vec![MessagePart::Text { text }, part]);
            }
            // First part pushed into empty content: collapse a lone text
            // part to `Plain`, otherwise start the multipart list.
            MessageContent::Multipart(parts) if parts.is_empty() => match part {
                MessagePart::Text { text } => *self = MessageContent::Plain(text),
                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
            },
            // Already multipart with content: just append.
            MessageContent::Multipart(parts) => parts.push(part),
        }
    }
}
impl From<Vec<MessagePart>> for MessageContent {
    /// Converts a part list into content, collapsing exactly one text part
    /// into the `Plain` form (the same normalization `push_part` applies).
    fn from(mut parts: Vec<MessagePart>) -> Self {
        match parts.as_mut_slice() {
            [MessagePart::Text { text }] => MessageContent::Plain(std::mem::take(text)),
            _ => MessageContent::Multipart(parts),
        }
    }
}
/// One element of a multipart message body. The `type` tag selects the
/// variant, matching OpenAI's content-part wire format.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    /// Text fragment: `{"type": "text", "text": "..."}`.
    #[serde(rename = "text")]
    Text { text: String },
    /// Image reference: `{"type": "image_url", "image_url": {...}}`.
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
/// Payload of an `image_url` message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    // Image location; call sites in this change pass a base64 data URL
    // (`image.to_base64_url()`).
    pub url: String,
    // Optional detail setting, omitted from the JSON when `None`.
    // NOTE(review): presumably OpenAI's "low"/"high"/"auto" detail hint —
    // confirm against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall { pub struct ToolCall {
pub id: String, pub id: String,
@ -509,24 +562,45 @@ fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
choices: response choices: response
.choices .choices
.into_iter() .into_iter()
.map(|choice| ChoiceDelta { .map(|choice| {
index: choice.index, let content = match &choice.message {
delta: ResponseMessageDelta { RequestMessage::Assistant { content, .. } => content,
role: Some(match choice.message { RequestMessage::User { content } => content,
RequestMessage::Assistant { .. } => Role::Assistant, RequestMessage::System { content } => content,
RequestMessage::User { .. } => Role::User, RequestMessage::Tool { content, .. } => content,
RequestMessage::System { .. } => Role::System, };
RequestMessage::Tool { .. } => Role::Tool,
}), let mut text_content = String::new();
content: match choice.message { match content {
RequestMessage::Assistant { content, .. } => content, MessageContent::Plain(text) => text_content.push_str(&text),
RequestMessage::User { content } => Some(content), MessageContent::Multipart(parts) => {
RequestMessage::System { content } => Some(content), for part in parts {
RequestMessage::Tool { content, .. } => Some(content), match part {
MessagePart::Text { text } => text_content.push_str(&text),
MessagePart::Image { .. } => {}
}
}
}
};
ChoiceDelta {
index: choice.index,
delta: ResponseMessageDelta {
role: Some(match choice.message {
RequestMessage::Assistant { .. } => Role::Assistant,
RequestMessage::User { .. } => Role::User,
RequestMessage::System { .. } => Role::System,
RequestMessage::Tool { .. } => Role::Tool,
}),
content: if text_content.is_empty() {
None
} else {
Some(text_content)
},
tool_calls: None,
}, },
tool_calls: None, finish_reason: choice.finish_reason,
}, }
finish_reason: choice.finish_reason,
}) })
.collect(), .collect(),
usage: Some(response.usage), usage: Some(response.usage),