language_models: Add support for images to Mistral models (#32154)
Tested with following models. Hallucinates with whites outline images like white lined zed logo but works fine with zed black outlined logo: Pixtral 12B (pixtral-12b-latest) Pixtral Large (pixtral-large-latest) Mistral Medium (mistral-medium-latest) Mistral Small (mistral-small-latest) After this PR, almost all of the zed's llm provider who support images are now supported. Only remaining one is LMStudio. Hopefully we will get that one as well soon. Release Notes: - Add support for images to mistral models --------- Signed-off-by: Umesh Yadav <git@umesh.dev> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de> Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
This commit is contained in:
parent
4ac7935589
commit
0bc9478b46
3 changed files with 257 additions and 92 deletions
|
@ -18,6 +18,8 @@ use language_model::{
|
|||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
|
@ -27,9 +29,6 @@ use util::ResultExt;
|
|||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
const PROVIDER_ID: &str = "mistral";
|
||||
const PROVIDER_NAME: &str = "Mistral";
|
||||
|
||||
|
@ -48,6 +47,7 @@ pub struct AvailableModel {
|
|||
pub max_output_tokens: Option<u32>,
|
||||
pub max_completion_tokens: Option<u32>,
|
||||
pub supports_tools: Option<bool>,
|
||||
pub supports_images: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct MistralLanguageModelProvider {
|
||||
|
@ -215,6 +215,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
|||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_images,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -314,7 +315,7 @@ impl LanguageModel for MistralLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
self.model.supports_images()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
|
@ -389,18 +390,47 @@ pub fn into_mistral(
|
|||
let stream = true;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
for message in &request.messages {
|
||||
match message.role {
|
||||
Role::User => {
|
||||
let mut message_content = mistral::MessageContent::empty();
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
|
||||
.push(match message.role {
|
||||
Role::User => mistral::RequestMessage::User { content: text },
|
||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||
content: Some(text),
|
||||
MessageContent::Text(text) => {
|
||||
message_content
|
||||
.push_part(mistral::MessagePart::Text { text: text.clone() });
|
||||
}
|
||||
MessageContent::Image(image_content) => {
|
||||
message_content.push_part(mistral::MessagePart::ImageUrl {
|
||||
image_url: image_content.to_base64_url(),
|
||||
});
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
message_content
|
||||
.push_part(mistral::MessagePart::Text { text: text.clone() });
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => {
|
||||
// Tool content is not supported in User messages for Mistral
|
||||
}
|
||||
}
|
||||
}
|
||||
if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
|
||||
{
|
||||
messages.push(mistral::RequestMessage::User {
|
||||
content: message_content,
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => mistral::RequestMessage::System { content: text },
|
||||
}),
|
||||
});
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
|
@ -426,11 +456,38 @@ pub fn into_mistral(
|
|||
});
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
MessageContent::ToolResult(_) => {
|
||||
// Tool results are not supported in Assistant messages
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::System => {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
messages.push(mistral::RequestMessage::System {
|
||||
content: text.clone(),
|
||||
});
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_)
|
||||
| MessageContent::ToolUse(_)
|
||||
| MessageContent::ToolResult(_) => {
|
||||
// Images and tools are not supported in System messages
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for message in &request.messages {
|
||||
for content in &message.content {
|
||||
if let MessageContent::ToolResult(tool_result) = content {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string(),
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
// TODO: Mistral image support
|
||||
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
|
||||
}
|
||||
};
|
||||
|
@ -442,7 +499,6 @@ pub fn into_mistral(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The Mistral API requires that tool messages be followed by assistant messages,
|
||||
// not user messages. When we have a tool->user sequence in the conversation,
|
||||
|
@ -819,62 +875,88 @@ impl Render for ConfigurationView {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use language_model;
|
||||
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
|
||||
|
||||
#[test]
|
||||
fn test_into_mistral_conversion() {
|
||||
let request = language_model::LanguageModelRequest {
|
||||
fn test_into_mistral_basic_conversion() {
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: language_model::Role::System,
|
||||
content: vec![language_model::MessageContent::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)],
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![MessageContent::Text("System prompt".into())],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![language_model::MessageContent::Text(
|
||||
"Hello, how are you?".to_string(),
|
||||
)],
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello".into())],
|
||||
cache: false,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
tools: Vec::new(),
|
||||
temperature: Some(0.5),
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
stop: Vec::new(),
|
||||
stop: vec![],
|
||||
};
|
||||
|
||||
let model_name = "mistral-medium-latest".to_string();
|
||||
let max_output_tokens = Some(1000);
|
||||
let mistral_request = into_mistral(request, model_name, max_output_tokens);
|
||||
|
||||
assert_eq!(mistral_request.model, "mistral-medium-latest");
|
||||
assert_eq!(mistral_request.temperature, Some(0.7));
|
||||
assert_eq!(mistral_request.max_tokens, Some(1000));
|
||||
assert!(mistral_request.stream);
|
||||
assert!(mistral_request.tools.is_empty());
|
||||
assert!(mistral_request.tool_choice.is_none());
|
||||
let mistral_request = into_mistral(request, "mistral-small-latest".into(), None);
|
||||
|
||||
assert_eq!(mistral_request.model, "mistral-small-latest");
|
||||
assert_eq!(mistral_request.temperature, Some(0.5));
|
||||
assert_eq!(mistral_request.messages.len(), 2);
|
||||
|
||||
match &mistral_request.messages[0] {
|
||||
mistral::RequestMessage::System { content } => {
|
||||
assert_eq!(content, "You are a helpful assistant.");
|
||||
}
|
||||
_ => panic!("Expected System message"),
|
||||
assert!(mistral_request.stream);
|
||||
}
|
||||
|
||||
match &mistral_request.messages[1] {
|
||||
mistral::RequestMessage::User { content } => {
|
||||
assert_eq!(content, "Hello, how are you?");
|
||||
#[test]
|
||||
fn test_into_mistral_with_image() {
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![
|
||||
MessageContent::Text("What's in this image?".into()),
|
||||
MessageContent::Image(LanguageModelImage {
|
||||
source: "base64data".into(),
|
||||
size: Default::default(),
|
||||
}),
|
||||
],
|
||||
cache: false,
|
||||
}],
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
temperature: None,
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
stop: vec![],
|
||||
};
|
||||
|
||||
let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None);
|
||||
|
||||
assert_eq!(mistral_request.messages.len(), 1);
|
||||
assert!(matches!(
|
||||
&mistral_request.messages[0],
|
||||
mistral::RequestMessage::User {
|
||||
content: mistral::MessageContent::Multipart { .. }
|
||||
}
|
||||
_ => panic!("Expected User message"),
|
||||
));
|
||||
|
||||
if let mistral::RequestMessage::User {
|
||||
content: mistral::MessageContent::Multipart { content },
|
||||
} = &mistral_request.messages[0]
|
||||
{
|
||||
assert_eq!(content.len(), 2);
|
||||
assert!(matches!(
|
||||
&content[0],
|
||||
mistral::MessagePart::Text { text } if text == "What's in this image?"
|
||||
));
|
||||
assert!(matches!(
|
||||
&content[1],
|
||||
mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,6 +60,10 @@ pub enum Model {
|
|||
OpenCodestralMamba,
|
||||
#[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
|
||||
DevstralSmallLatest,
|
||||
#[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
|
||||
Pixtral12BLatest,
|
||||
#[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
|
||||
PixtralLargeLatest,
|
||||
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
|
@ -70,6 +74,7 @@ pub enum Model {
|
|||
max_output_tokens: Option<u32>,
|
||||
max_completion_tokens: Option<u32>,
|
||||
supports_tools: Option<bool>,
|
||||
supports_images: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -86,6 +91,9 @@ impl Model {
|
|||
"mistral-small-latest" => Ok(Self::MistralSmallLatest),
|
||||
"open-mistral-nemo" => Ok(Self::OpenMistralNemo),
|
||||
"open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
|
||||
"devstral-small-latest" => Ok(Self::DevstralSmallLatest),
|
||||
"pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
|
||||
"pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
|
||||
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
|
||||
}
|
||||
}
|
||||
|
@ -99,6 +107,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => "open-mistral-nemo",
|
||||
Self::OpenCodestralMamba => "open-codestral-mamba",
|
||||
Self::DevstralSmallLatest => "devstral-small-latest",
|
||||
Self::Pixtral12BLatest => "pixtral-12b-latest",
|
||||
Self::PixtralLargeLatest => "pixtral-large-latest",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
@ -112,6 +122,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => "open-mistral-nemo",
|
||||
Self::OpenCodestralMamba => "open-codestral-mamba",
|
||||
Self::DevstralSmallLatest => "devstral-small-latest",
|
||||
Self::Pixtral12BLatest => "pixtral-12b-latest",
|
||||
Self::PixtralLargeLatest => "pixtral-large-latest",
|
||||
Self::Custom {
|
||||
name, display_name, ..
|
||||
} => display_name.as_ref().unwrap_or(name),
|
||||
|
@ -127,6 +139,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => 131000,
|
||||
Self::OpenCodestralMamba => 256000,
|
||||
Self::DevstralSmallLatest => 262144,
|
||||
Self::Pixtral12BLatest => 128000,
|
||||
Self::PixtralLargeLatest => 128000,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
|
@ -148,10 +162,29 @@ impl Model {
|
|||
| Self::MistralSmallLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba
|
||||
| Self::DevstralSmallLatest => true,
|
||||
| Self::DevstralSmallLatest
|
||||
| Self::Pixtral12BLatest
|
||||
| Self::PixtralLargeLatest => true,
|
||||
Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_images(&self) -> bool {
|
||||
match self {
|
||||
Self::Pixtral12BLatest
|
||||
| Self::PixtralLargeLatest
|
||||
| Self::MistralMediumLatest
|
||||
| Self::MistralSmallLatest => true,
|
||||
Self::CodestralLatest
|
||||
| Self::MistralLargeLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba
|
||||
| Self::DevstralSmallLatest => false,
|
||||
Self::Custom {
|
||||
supports_images, ..
|
||||
} => supports_images.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -231,7 +264,8 @@ pub enum RequestMessage {
|
|||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
#[serde(flatten)]
|
||||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
|
@ -242,6 +276,54 @@ pub enum RequestMessage {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
#[serde(rename = "content")]
|
||||
Plain { content: String },
|
||||
#[serde(rename = "content")]
|
||||
Multipart { content: Vec<MessagePart> },
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn empty() -> Self {
|
||||
Self::Plain {
|
||||
content: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_part(&mut self, part: MessagePart) {
|
||||
match self {
|
||||
Self::Plain { content } => match part {
|
||||
MessagePart::Text { text } => {
|
||||
content.push_str(&text);
|
||||
}
|
||||
part => {
|
||||
let mut parts = if content.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
vec![MessagePart::Text {
|
||||
text: content.clone(),
|
||||
}]
|
||||
};
|
||||
parts.push(part);
|
||||
*self = Self::Multipart { content: parts };
|
||||
}
|
||||
},
|
||||
Self::Multipart { content } => {
|
||||
content.push(part);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
|
|
|
@ -302,7 +302,8 @@ The Zed Assistant comes pre-configured with several Mistral models (codestral-la
|
|||
"max_tokens": 32000,
|
||||
"max_output_tokens": 4096,
|
||||
"max_completion_tokens": 1024,
|
||||
"supports_tools": true
|
||||
"supports_tools": true,
|
||||
"supports_images": false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue