Add image input support for OpenAI models (#30639)
Release Notes: - Added input image support for OpenAI models
This commit is contained in:
parent
68afe4fdda
commit
dd6594621f
5 changed files with 162 additions and 43 deletions
|
@ -543,7 +543,7 @@ pub enum RequestContent {
|
|||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolResultContent {
|
||||
JustText(String),
|
||||
Plain(String),
|
||||
Multipart(Vec<ToolResultPart>),
|
||||
}
|
||||
|
||||
|
|
|
@ -217,7 +217,7 @@ pub enum ChatMessage {
|
|||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatMessageContent {
|
||||
OnlyText(String),
|
||||
Plain(String),
|
||||
Multipart(Vec<ChatMessagePart>),
|
||||
}
|
||||
|
||||
|
@ -230,7 +230,7 @@ impl ChatMessageContent {
|
|||
impl From<Vec<ChatMessagePart>> for ChatMessageContent {
|
||||
fn from(mut parts: Vec<ChatMessagePart>) -> Self {
|
||||
if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
|
||||
ChatMessageContent::OnlyText(std::mem::take(text))
|
||||
ChatMessageContent::Plain(std::mem::take(text))
|
||||
} else {
|
||||
ChatMessageContent::Multipart(parts)
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ impl From<Vec<ChatMessagePart>> for ChatMessageContent {
|
|||
|
||||
impl From<String> for ChatMessageContent {
|
||||
fn from(text: String) -> Self {
|
||||
ChatMessageContent::OnlyText(text)
|
||||
ChatMessageContent::Plain(text)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -589,7 +589,7 @@ pub fn into_anthropic(
|
|||
is_error: tool_result.is_error,
|
||||
content: match tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
ToolResultContent::JustText(text.to_string())
|
||||
ToolResultContent::Plain(text.to_string())
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
ToolResultContent::Multipart(vec![ToolResultPart::Image {
|
||||
|
|
|
@ -15,7 +15,7 @@ use language_model::{
|
|||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason,
|
||||
};
|
||||
use open_ai::{Model, ResponseStreamEvent, stream_completion};
|
||||
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
@ -362,17 +362,26 @@ pub fn into_open_ai(
|
|||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
|
||||
.push(match message.role {
|
||||
Role::User => open_ai::RequestMessage::User { content: text },
|
||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||
content: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => open_ai::RequestMessage::System { content: text },
|
||||
}),
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
add_message_content_part(
|
||||
open_ai::MessagePart::Text { text: text },
|
||||
message.role,
|
||||
&mut messages,
|
||||
)
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::Image(image) => {
|
||||
add_message_content_part(
|
||||
open_ai::MessagePart::Image {
|
||||
image_url: ImageUrl {
|
||||
url: image.to_base64_url(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
message.role,
|
||||
&mut messages,
|
||||
);
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = open_ai::ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
|
@ -391,22 +400,30 @@ pub fn into_open_ai(
|
|||
tool_calls.push(tool_call);
|
||||
} else {
|
||||
messages.push(open_ai::RequestMessage::Assistant {
|
||||
content: None,
|
||||
content: open_ai::MessageContent::empty(),
|
||||
tool_calls: vec![tool_call],
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string(),
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
// TODO: Open AI image support
|
||||
"[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
vec![open_ai::MessagePart::Text {
|
||||
text: text.to_string(),
|
||||
}]
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
vec![open_ai::MessagePart::Image {
|
||||
image_url: ImageUrl {
|
||||
url: image.to_base64_url(),
|
||||
detail: None,
|
||||
},
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(open_ai::RequestMessage::Tool {
|
||||
content,
|
||||
content: content.into(),
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
});
|
||||
}
|
||||
|
@ -446,6 +463,34 @@ pub fn into_open_ai(
|
|||
}
|
||||
}
|
||||
|
||||
fn add_message_content_part(
|
||||
new_part: open_ai::MessagePart,
|
||||
role: Role,
|
||||
messages: &mut Vec<open_ai::RequestMessage>,
|
||||
) {
|
||||
match (role, messages.last_mut()) {
|
||||
(Role::User, Some(open_ai::RequestMessage::User { content }))
|
||||
| (Role::Assistant, Some(open_ai::RequestMessage::Assistant { content, .. }))
|
||||
| (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
|
||||
content.push_part(new_part);
|
||||
}
|
||||
_ => {
|
||||
messages.push(match role {
|
||||
Role::User => open_ai::RequestMessage::User {
|
||||
content: open_ai::MessageContent::empty(),
|
||||
},
|
||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||
content: open_ai::MessageContent::empty(),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => open_ai::RequestMessage::System {
|
||||
content: open_ai::MessageContent::empty(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenAiEventMapper {
|
||||
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
||||
}
|
||||
|
|
|
@ -278,22 +278,75 @@ pub struct FunctionDefinition {
|
|||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum RequestMessage {
|
||||
Assistant {
|
||||
content: Option<String>,
|
||||
content: MessageContent,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
content: MessageContent,
|
||||
},
|
||||
Tool {
|
||||
content: String,
|
||||
content: MessageContent,
|
||||
tool_call_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
Plain(String),
|
||||
Multipart(Vec<MessagePart>),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn empty() -> Self {
|
||||
MessageContent::Multipart(vec![])
|
||||
}
|
||||
|
||||
pub fn push_part(&mut self, part: MessagePart) {
|
||||
match self {
|
||||
MessageContent::Plain(text) => {
|
||||
*self =
|
||||
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
|
||||
}
|
||||
MessageContent::Multipart(parts) if parts.is_empty() => match part {
|
||||
MessagePart::Text { text } => *self = MessageContent::Plain(text),
|
||||
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
|
||||
},
|
||||
MessageContent::Multipart(parts) => parts.push(part),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<MessagePart>> for MessageContent {
|
||||
fn from(mut parts: Vec<MessagePart>) -> Self {
|
||||
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
|
||||
MessageContent::Plain(std::mem::take(text))
|
||||
} else {
|
||||
MessageContent::Multipart(parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagePart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
Image { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
|
@ -509,24 +562,45 @@ fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
|
|||
choices: response
|
||||
.choices
|
||||
.into_iter()
|
||||
.map(|choice| ChoiceDelta {
|
||||
index: choice.index,
|
||||
delta: ResponseMessageDelta {
|
||||
role: Some(match choice.message {
|
||||
RequestMessage::Assistant { .. } => Role::Assistant,
|
||||
RequestMessage::User { .. } => Role::User,
|
||||
RequestMessage::System { .. } => Role::System,
|
||||
RequestMessage::Tool { .. } => Role::Tool,
|
||||
}),
|
||||
content: match choice.message {
|
||||
RequestMessage::Assistant { content, .. } => content,
|
||||
RequestMessage::User { content } => Some(content),
|
||||
RequestMessage::System { content } => Some(content),
|
||||
RequestMessage::Tool { content, .. } => Some(content),
|
||||
.map(|choice| {
|
||||
let content = match &choice.message {
|
||||
RequestMessage::Assistant { content, .. } => content,
|
||||
RequestMessage::User { content } => content,
|
||||
RequestMessage::System { content } => content,
|
||||
RequestMessage::Tool { content, .. } => content,
|
||||
};
|
||||
|
||||
let mut text_content = String::new();
|
||||
match content {
|
||||
MessageContent::Plain(text) => text_content.push_str(&text),
|
||||
MessageContent::Multipart(parts) => {
|
||||
for part in parts {
|
||||
match part {
|
||||
MessagePart::Text { text } => text_content.push_str(&text),
|
||||
MessagePart::Image { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ChoiceDelta {
|
||||
index: choice.index,
|
||||
delta: ResponseMessageDelta {
|
||||
role: Some(match choice.message {
|
||||
RequestMessage::Assistant { .. } => Role::Assistant,
|
||||
RequestMessage::User { .. } => Role::User,
|
||||
RequestMessage::System { .. } => Role::System,
|
||||
RequestMessage::Tool { .. } => Role::Tool,
|
||||
}),
|
||||
content: if text_content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_content)
|
||||
},
|
||||
tool_calls: None,
|
||||
},
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: choice.finish_reason,
|
||||
finish_reason: choice.finish_reason,
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
usage: Some(response.usage),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue