language_models: Add support for images to Mistral models (#32154)
Tested with the following models. They hallucinate on white-outlined images (such as the white-lined Zed logo) but work fine with the black-outlined Zed logo: Pixtral 12B (pixtral-12b-latest), Pixtral Large (pixtral-large-latest), Mistral Medium (mistral-medium-latest), Mistral Small (mistral-small-latest). After this PR, almost all of Zed's LLM providers that support images are now supported. The only remaining one is LMStudio. Hopefully we will get that one as well soon. Release Notes: - Add support for images to Mistral models --------- Signed-off-by: Umesh Yadav <git@umesh.dev> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de> Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
This commit is contained in:
parent
4ac7935589
commit
0bc9478b46
3 changed files with 257 additions and 92 deletions
|
@ -18,6 +18,8 @@ use language_model::{
|
|||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
|
@ -27,9 +29,6 @@ use util::ResultExt;
|
|||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
const PROVIDER_ID: &str = "mistral";
|
||||
const PROVIDER_NAME: &str = "Mistral";
|
||||
|
||||
|
@ -48,6 +47,7 @@ pub struct AvailableModel {
|
|||
pub max_output_tokens: Option<u32>,
|
||||
pub max_completion_tokens: Option<u32>,
|
||||
pub supports_tools: Option<bool>,
|
||||
pub supports_images: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct MistralLanguageModelProvider {
|
||||
|
@ -215,6 +215,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
|||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_images,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -314,7 +315,7 @@ impl LanguageModel for MistralLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
self.model.supports_images()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
|
@ -389,58 +390,113 @@ pub fn into_mistral(
|
|||
let stream = true;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
|
||||
.push(match message.role {
|
||||
Role::User => mistral::RequestMessage::User { content: text },
|
||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||
content: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => mistral::RequestMessage::System { content: text },
|
||||
}),
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = mistral::ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
content: mistral::ToolCallContent::Function {
|
||||
function: mistral::FunctionContent {
|
||||
name: tool_use.name.to_string(),
|
||||
arguments: serde_json::to_string(&tool_use.input)
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
|
||||
messages.last_mut()
|
||||
{
|
||||
tool_calls.push(tool_call);
|
||||
} else {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: None,
|
||||
tool_calls: vec![tool_call],
|
||||
});
|
||||
for message in &request.messages {
|
||||
match message.role {
|
||||
Role::User => {
|
||||
let mut message_content = mistral::MessageContent::empty();
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
message_content
|
||||
.push_part(mistral::MessagePart::Text { text: text.clone() });
|
||||
}
|
||||
MessageContent::Image(image_content) => {
|
||||
message_content.push_part(mistral::MessagePart::ImageUrl {
|
||||
image_url: image_content.to_base64_url(),
|
||||
});
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
message_content
|
||||
.push_part(mistral::MessagePart::Text { text: text.clone() });
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => {
|
||||
// Tool content is not supported in User messages for Mistral
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string(),
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
// TODO: Mistral image support
|
||||
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(mistral::RequestMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
|
||||
{
|
||||
messages.push(mistral::RequestMessage::User {
|
||||
content: message_content,
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = mistral::ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
content: mistral::ToolCallContent::Function {
|
||||
function: mistral::FunctionContent {
|
||||
name: tool_use.name.to_string(),
|
||||
arguments: serde_json::to_string(&tool_use.input)
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
|
||||
messages.last_mut()
|
||||
{
|
||||
tool_calls.push(tool_call);
|
||||
} else {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: None,
|
||||
tool_calls: vec![tool_call],
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResult(_) => {
|
||||
// Tool results are not supported in Assistant messages
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::System => {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
messages.push(mistral::RequestMessage::System {
|
||||
content: text.clone(),
|
||||
});
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_)
|
||||
| MessageContent::ToolUse(_)
|
||||
| MessageContent::ToolResult(_) => {
|
||||
// Images and tools are not supported in System messages
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for message in &request.messages {
|
||||
for content in &message.content {
|
||||
if let MessageContent::ToolResult(tool_result) = content {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string(),
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(mistral::RequestMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -819,62 +875,88 @@ impl Render for ConfigurationView {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use language_model;
|
||||
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
|
||||
|
||||
#[test]
|
||||
fn test_into_mistral_conversion() {
|
||||
let request = language_model::LanguageModelRequest {
|
||||
fn test_into_mistral_basic_conversion() {
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: language_model::Role::System,
|
||||
content: vec![language_model::MessageContent::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)],
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![MessageContent::Text("System prompt".into())],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![language_model::MessageContent::Text(
|
||||
"Hello, how are you?".to_string(),
|
||||
)],
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello".into())],
|
||||
cache: false,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
tools: Vec::new(),
|
||||
temperature: Some(0.5),
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
stop: Vec::new(),
|
||||
stop: vec![],
|
||||
};
|
||||
|
||||
let model_name = "mistral-medium-latest".to_string();
|
||||
let max_output_tokens = Some(1000);
|
||||
let mistral_request = into_mistral(request, model_name, max_output_tokens);
|
||||
|
||||
assert_eq!(mistral_request.model, "mistral-medium-latest");
|
||||
assert_eq!(mistral_request.temperature, Some(0.7));
|
||||
assert_eq!(mistral_request.max_tokens, Some(1000));
|
||||
assert!(mistral_request.stream);
|
||||
assert!(mistral_request.tools.is_empty());
|
||||
assert!(mistral_request.tool_choice.is_none());
|
||||
let mistral_request = into_mistral(request, "mistral-small-latest".into(), None);
|
||||
|
||||
assert_eq!(mistral_request.model, "mistral-small-latest");
|
||||
assert_eq!(mistral_request.temperature, Some(0.5));
|
||||
assert_eq!(mistral_request.messages.len(), 2);
|
||||
assert!(mistral_request.stream);
|
||||
}
|
||||
|
||||
match &mistral_request.messages[0] {
|
||||
mistral::RequestMessage::System { content } => {
|
||||
assert_eq!(content, "You are a helpful assistant.");
|
||||
}
|
||||
_ => panic!("Expected System message"),
|
||||
}
|
||||
#[test]
|
||||
fn test_into_mistral_with_image() {
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![
|
||||
MessageContent::Text("What's in this image?".into()),
|
||||
MessageContent::Image(LanguageModelImage {
|
||||
source: "base64data".into(),
|
||||
size: Default::default(),
|
||||
}),
|
||||
],
|
||||
cache: false,
|
||||
}],
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
temperature: None,
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
stop: vec![],
|
||||
};
|
||||
|
||||
match &mistral_request.messages[1] {
|
||||
mistral::RequestMessage::User { content } => {
|
||||
assert_eq!(content, "Hello, how are you?");
|
||||
let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None);
|
||||
|
||||
assert_eq!(mistral_request.messages.len(), 1);
|
||||
assert!(matches!(
|
||||
&mistral_request.messages[0],
|
||||
mistral::RequestMessage::User {
|
||||
content: mistral::MessageContent::Multipart { .. }
|
||||
}
|
||||
_ => panic!("Expected User message"),
|
||||
));
|
||||
|
||||
if let mistral::RequestMessage::User {
|
||||
content: mistral::MessageContent::Multipart { content },
|
||||
} = &mistral_request.messages[0]
|
||||
{
|
||||
assert_eq!(content.len(), 2);
|
||||
assert!(matches!(
|
||||
&content[0],
|
||||
mistral::MessagePart::Text { text } if text == "What's in this image?"
|
||||
));
|
||||
assert!(matches!(
|
||||
&content[1],
|
||||
mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,6 +60,10 @@ pub enum Model {
|
|||
OpenCodestralMamba,
|
||||
#[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
|
||||
DevstralSmallLatest,
|
||||
#[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
|
||||
Pixtral12BLatest,
|
||||
#[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
|
||||
PixtralLargeLatest,
|
||||
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
|
@ -70,6 +74,7 @@ pub enum Model {
|
|||
max_output_tokens: Option<u32>,
|
||||
max_completion_tokens: Option<u32>,
|
||||
supports_tools: Option<bool>,
|
||||
supports_images: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -86,6 +91,9 @@ impl Model {
|
|||
"mistral-small-latest" => Ok(Self::MistralSmallLatest),
|
||||
"open-mistral-nemo" => Ok(Self::OpenMistralNemo),
|
||||
"open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
|
||||
"devstral-small-latest" => Ok(Self::DevstralSmallLatest),
|
||||
"pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
|
||||
"pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
|
||||
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
|
||||
}
|
||||
}
|
||||
|
@ -99,6 +107,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => "open-mistral-nemo",
|
||||
Self::OpenCodestralMamba => "open-codestral-mamba",
|
||||
Self::DevstralSmallLatest => "devstral-small-latest",
|
||||
Self::Pixtral12BLatest => "pixtral-12b-latest",
|
||||
Self::PixtralLargeLatest => "pixtral-large-latest",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
@ -112,6 +122,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => "open-mistral-nemo",
|
||||
Self::OpenCodestralMamba => "open-codestral-mamba",
|
||||
Self::DevstralSmallLatest => "devstral-small-latest",
|
||||
Self::Pixtral12BLatest => "pixtral-12b-latest",
|
||||
Self::PixtralLargeLatest => "pixtral-large-latest",
|
||||
Self::Custom {
|
||||
name, display_name, ..
|
||||
} => display_name.as_ref().unwrap_or(name),
|
||||
|
@ -127,6 +139,8 @@ impl Model {
|
|||
Self::OpenMistralNemo => 131000,
|
||||
Self::OpenCodestralMamba => 256000,
|
||||
Self::DevstralSmallLatest => 262144,
|
||||
Self::Pixtral12BLatest => 128000,
|
||||
Self::PixtralLargeLatest => 128000,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
|
@ -148,10 +162,29 @@ impl Model {
|
|||
| Self::MistralSmallLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba
|
||||
| Self::DevstralSmallLatest => true,
|
||||
| Self::DevstralSmallLatest
|
||||
| Self::Pixtral12BLatest
|
||||
| Self::PixtralLargeLatest => true,
|
||||
Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_images(&self) -> bool {
|
||||
match self {
|
||||
Self::Pixtral12BLatest
|
||||
| Self::PixtralLargeLatest
|
||||
| Self::MistralMediumLatest
|
||||
| Self::MistralSmallLatest => true,
|
||||
Self::CodestralLatest
|
||||
| Self::MistralLargeLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba
|
||||
| Self::DevstralSmallLatest => false,
|
||||
Self::Custom {
|
||||
supports_images, ..
|
||||
} => supports_images.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -231,7 +264,8 @@ pub enum RequestMessage {
|
|||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
#[serde(flatten)]
|
||||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
|
@ -242,6 +276,54 @@ pub enum RequestMessage {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
#[serde(rename = "content")]
|
||||
Plain { content: String },
|
||||
#[serde(rename = "content")]
|
||||
Multipart { content: Vec<MessagePart> },
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn empty() -> Self {
|
||||
Self::Plain {
|
||||
content: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_part(&mut self, part: MessagePart) {
|
||||
match self {
|
||||
Self::Plain { content } => match part {
|
||||
MessagePart::Text { text } => {
|
||||
content.push_str(&text);
|
||||
}
|
||||
part => {
|
||||
let mut parts = if content.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
vec![MessagePart::Text {
|
||||
text: content.clone(),
|
||||
}]
|
||||
};
|
||||
parts.push(part);
|
||||
*self = Self::Multipart { content: parts };
|
||||
}
|
||||
},
|
||||
Self::Multipart { content } => {
|
||||
content.push(part);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text { text: String },
|
||||
ImageUrl { image_url: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
|
|
|
@ -302,7 +302,8 @@ The Zed Assistant comes pre-configured with several Mistral models (codestral-la
|
|||
"max_tokens": 32000,
|
||||
"max_output_tokens": 4096,
|
||||
"max_completion_tokens": 1024,
|
||||
"supports_tools": true
|
||||
"supports_tools": true,
|
||||
"supports_images": false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -374,10 +375,10 @@ The `supports_tools` option controls whether or not the model will use additiona
|
|||
If the model is tagged with `tools` in the Ollama catalog this option should be supplied, and built in profiles `Ask` and `Write` can be used.
|
||||
If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with value `true`; however be aware that only the `Minimal` built in profile will work.
|
||||
|
||||
The `supports_thinking` option controls whether or not the model will perform an explicit “thinking” (reasoning) pass before producing its final answer.
|
||||
The `supports_thinking` option controls whether or not the model will perform an explicit “thinking” (reasoning) pass before producing its final answer.
|
||||
If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in Zed.
|
||||
|
||||
The `supports_images` option enables the model’s vision capabilities, allowing it to process images included in the conversation context.
|
||||
The `supports_images` option enables the model’s vision capabilities, allowing it to process images included in the conversation context.
|
||||
If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed.
|
||||
|
||||
### OpenAI {#openai}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue