diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 87c23e49e7..6debead977 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -18,6 +18,8 @@ use language_model::{ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; +use std::collections::HashMap; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use strum::IntoEnumIterator; @@ -27,9 +29,6 @@ use util::ResultExt; use crate::{AllLanguageModelSettings, ui::InstructionListItem}; -use std::collections::HashMap; -use std::pin::Pin; - const PROVIDER_ID: &str = "mistral"; const PROVIDER_NAME: &str = "Mistral"; @@ -48,6 +47,7 @@ pub struct AvailableModel { pub max_output_tokens: Option, pub max_completion_tokens: Option, pub supports_tools: Option, + pub supports_images: Option, } pub struct MistralLanguageModelProvider { @@ -215,6 +215,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider { max_output_tokens: model.max_output_tokens, max_completion_tokens: model.max_completion_tokens, supports_tools: model.supports_tools, + supports_images: model.supports_images, }, ); } @@ -314,7 +315,7 @@ impl LanguageModel for MistralLanguageModel { } fn supports_images(&self) -> bool { - false + self.model.supports_images() } fn telemetry_id(&self) -> String { @@ -389,58 +390,113 @@ pub fn into_mistral( let stream = true; let mut messages = Vec::new(); - for message in request.messages { - for content in message.content { - match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages - .push(match message.role { - Role::User => mistral::RequestMessage::User { content: text }, - Role::Assistant => mistral::RequestMessage::Assistant { - content: Some(text), - tool_calls: Vec::new(), - }, - Role::System => mistral::RequestMessage::System { content: text }, - }), - MessageContent::RedactedThinking(_) => {} - MessageContent::Image(_) => {} - MessageContent::ToolUse(tool_use) => { - let tool_call = mistral::ToolCall { - id: tool_use.id.to_string(), - content: mistral::ToolCallContent::Function { - function: mistral::FunctionContent { - name: tool_use.name.to_string(), - arguments: serde_json::to_string(&tool_use.input) - .unwrap_or_default(), - }, - }, - }; - - if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) = - messages.last_mut() - { - tool_calls.push(tool_call); - } else { - messages.push(mistral::RequestMessage::Assistant { - content: None, - tool_calls: vec![tool_call], - }); + for message in &request.messages { + match message.role { + Role::User => { + let mut message_content = mistral::MessageContent::empty(); + for content in &message.content { + match content { + MessageContent::Text(text) => { + message_content + .push_part(mistral::MessagePart::Text { text: text.clone() }); + } + MessageContent::Image(image_content) => { + message_content.push_part(mistral::MessagePart::ImageUrl { + image_url: image_content.to_base64_url(), + }); + } + MessageContent::Thinking { text, .. } => { + message_content + .push_part(mistral::MessagePart::Text { text: text.clone() }); + } + MessageContent::RedactedThinking(_) => {} + MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => { + // Tool content is not supported in User messages for Mistral + } } } - MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => text.to_string(), - LanguageModelToolResultContent::Image(_) => { - // TODO: Mistral image support - "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() - } - }; - - messages.push(mistral::RequestMessage::Tool { - content, - tool_call_id: tool_result.tool_use_id.to_string(), + if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty()) + { + messages.push(mistral::RequestMessage::User { + content: message_content, }); } } + Role::Assistant => { + for content in &message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + messages.push(mistral::RequestMessage::Assistant { + content: Some(text.clone()), + tool_calls: Vec::new(), + }); + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(_) => {} + MessageContent::ToolUse(tool_use) => { + let tool_call = mistral::ToolCall { + id: tool_use.id.to_string(), + content: mistral::ToolCallContent::Function { + function: mistral::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(mistral::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(_) => { + // Tool results are not supported in Assistant messages + } + } + } + } + Role::System => { + for content in &message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + messages.push(mistral::RequestMessage::System { + content: text.clone(), + }); + } + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(_) + | MessageContent::ToolUse(_) + | MessageContent::ToolResult(_) => { + // Images and tools are not supported in System messages + } + } + } + } + } + } + + for message in &request.messages { + for content in &message.content { + if let MessageContent::ToolResult(tool_result) = content { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string(), + LanguageModelToolResultContent::Image(_) => { + "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() + } + }; + + messages.push(mistral::RequestMessage::Tool { + content, + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } } } @@ -819,62 +875,88 @@ impl Render for ConfigurationView { #[cfg(test)] mod tests { use super::*; - use language_model; + use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; #[test] - fn test_into_mistral_conversion() { - let request = language_model::LanguageModelRequest { + fn test_into_mistral_basic_conversion() { + let request = LanguageModelRequest { messages: vec![ - language_model::LanguageModelRequestMessage { - role: language_model::Role::System, - content: vec![language_model::MessageContent::Text( - "You are a helpful assistant.".to_string(), - )], + LanguageModelRequestMessage { + role: Role::System, + content: vec![MessageContent::Text("System prompt".into())], cache: false, }, - language_model::LanguageModelRequestMessage { - role: language_model::Role::User, - content: vec![language_model::MessageContent::Text( - "Hello, how are you?".to_string(), - )], + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("Hello".into())], cache: false, }, ], - temperature: Some(0.7), - tools: Vec::new(), + temperature: Some(0.5), + tools: vec![], tool_choice: None, thread_id: None, prompt_id: None, intent: None, mode: None, - stop: Vec::new(), + stop: vec![], }; - let model_name = "mistral-medium-latest".to_string(); - let max_output_tokens = Some(1000); - let mistral_request = into_mistral(request, model_name, max_output_tokens); - - assert_eq!(mistral_request.model, "mistral-medium-latest"); - assert_eq!(mistral_request.temperature, Some(0.7)); - assert_eq!(mistral_request.max_tokens, Some(1000)); - assert!(mistral_request.stream); - assert!(mistral_request.tools.is_empty()); - assert!(mistral_request.tool_choice.is_none()); + let mistral_request = into_mistral(request, "mistral-small-latest".into(), None); + assert_eq!(mistral_request.model, "mistral-small-latest"); + assert_eq!(mistral_request.temperature, Some(0.5)); assert_eq!(mistral_request.messages.len(), 2); + assert!(mistral_request.stream); + } - match &mistral_request.messages[0] { - mistral::RequestMessage::System { content } => { - assert_eq!(content, "You are a helpful assistant."); - } - _ => panic!("Expected System message"), - } + #[test] + fn test_into_mistral_with_image() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("What's in this image?".into()), + MessageContent::Image(LanguageModelImage { + source: "base64data".into(), + size: Default::default(), + }), + ], + cache: false, + }], + tools: vec![], + tool_choice: None, + temperature: None, + thread_id: None, + prompt_id: None, + intent: None, + mode: None, + stop: vec![], + }; - match &mistral_request.messages[1] { - mistral::RequestMessage::User { content } => { - assert_eq!(content, "Hello, how are you?"); + let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None); + + assert_eq!(mistral_request.messages.len(), 1); + assert!(matches!( + &mistral_request.messages[0], + mistral::RequestMessage::User { + content: mistral::MessageContent::Multipart { .. } } - _ => panic!("Expected User message"), + )); + + if let mistral::RequestMessage::User { + content: mistral::MessageContent::Multipart { content }, + } = &mistral_request.messages[0] + { + assert_eq!(content.len(), 2); + assert!(matches!( + &content[0], + mistral::MessagePart::Text { text } if text == "What's in this image?" + )); + assert!(matches!( + &content[1], + mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,") + )); } } } diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index e2103dcae8..7ad3b1c294 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -60,6 +60,10 @@ pub enum Model { OpenCodestralMamba, #[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")] DevstralSmallLatest, + #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")] + Pixtral12BLatest, + #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")] + PixtralLargeLatest, #[serde(rename = "custom")] Custom { @@ -70,6 +74,7 @@ pub enum Model { max_output_tokens: Option, max_completion_tokens: Option, supports_tools: Option, + supports_images: Option, }, } @@ -86,6 +91,9 @@ impl Model { "mistral-small-latest" => Ok(Self::MistralSmallLatest), "open-mistral-nemo" => Ok(Self::OpenMistralNemo), "open-codestral-mamba" => Ok(Self::OpenCodestralMamba), + "devstral-small-latest" => Ok(Self::DevstralSmallLatest), + "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest), + "pixtral-large-latest" => Ok(Self::PixtralLargeLatest), invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"), } } @@ -99,6 +107,8 @@ impl Model { Self::OpenMistralNemo => "open-mistral-nemo", Self::OpenCodestralMamba => "open-codestral-mamba", Self::DevstralSmallLatest => "devstral-small-latest", + Self::Pixtral12BLatest => "pixtral-12b-latest", + Self::PixtralLargeLatest => "pixtral-large-latest", Self::Custom { name, .. } => name, } } @@ -112,6 +122,8 @@ impl Model { Self::OpenMistralNemo => "open-mistral-nemo", Self::OpenCodestralMamba => "open-codestral-mamba", Self::DevstralSmallLatest => "devstral-small-latest", + Self::Pixtral12BLatest => "pixtral-12b-latest", + Self::PixtralLargeLatest => "pixtral-large-latest", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name), @@ -127,6 +139,8 @@ impl Model { Self::OpenMistralNemo => 131000, Self::OpenCodestralMamba => 256000, Self::DevstralSmallLatest => 262144, + Self::Pixtral12BLatest => 128000, + Self::PixtralLargeLatest => 128000, Self::Custom { max_tokens, .. } => *max_tokens, } } @@ -148,10 +162,29 @@ impl Model { | Self::MistralSmallLatest | Self::OpenMistralNemo | Self::OpenCodestralMamba - | Self::DevstralSmallLatest => true, + | Self::DevstralSmallLatest + | Self::Pixtral12BLatest + | Self::PixtralLargeLatest => true, Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false), } } + + pub fn supports_images(&self) -> bool { + match self { + Self::Pixtral12BLatest + | Self::PixtralLargeLatest + | Self::MistralMediumLatest + | Self::MistralSmallLatest => true, + Self::CodestralLatest + | Self::MistralLargeLatest + | Self::OpenMistralNemo + | Self::OpenCodestralMamba + | Self::DevstralSmallLatest => false, + Self::Custom { + supports_images, .. + } => supports_images.unwrap_or(false), + } + } } #[derive(Debug, Serialize, Deserialize)] @@ -231,7 +264,8 @@ pub enum RequestMessage { tool_calls: Vec, }, User { - content: String, + #[serde(flatten)] + content: MessageContent, }, System { content: String, @@ -242,6 +276,54 @@ pub enum RequestMessage { }, } +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + #[serde(rename = "content")] + Plain { content: String }, + #[serde(rename = "content")] + Multipart { content: Vec }, +} + +impl MessageContent { + pub fn empty() -> Self { + Self::Plain { + content: String::new(), + } + } + + pub fn push_part(&mut self, part: MessagePart) { + match self { + Self::Plain { content } => match part { + MessagePart::Text { text } => { + content.push_str(&text); + } + part => { + let mut parts = if content.is_empty() { + Vec::new() + } else { + vec![MessagePart::Text { + text: content.clone(), + }] + }; + parts.push(part); + *self = Self::Multipart { content: parts }; + } + }, + Self::Multipart { content } => { + content.push(part); + } + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessagePart { + Text { text: String }, + ImageUrl { image_url: String }, +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md index 5180818cf0..50a792d05a 100644 --- a/docs/src/ai/configuration.md +++ b/docs/src/ai/configuration.md @@ -302,7 +302,8 @@ The Zed Assistant comes pre-configured with several Mistral models (codestral-la "max_tokens": 32000, "max_output_tokens": 4096, "max_completion_tokens": 1024, - "supports_tools": true + "supports_tools": true, + "supports_images": false } ] } @@ -374,10 +375,10 @@ The `supports_tools` option controls whether or not the model will use additiona If the model is tagged with `tools` in the Ollama catalog this option should be supplied, and built in profiles `Ask` and `Write` can be used. If the model is not tagged with `tools` in the Ollama catalog, this option can still be supplied with value `true`; however be aware that only the `Minimal` built in profile will work. -The `supports_thinking` option controls whether or not the model will perform an explicit “thinking” (reasoning) pass before producing its final answer. +The `supports_thinking` option controls whether or not the model will perform an explicit “thinking” (reasoning) pass before producing its final answer. If the model is tagged with `thinking` in the Ollama catalog, set this option and you can use it in zed. -The `supports_images` option enables the model’s vision capabilities, allowing it to process images included in the conversation context. +The `supports_images` option enables the model’s vision capabilities, allowing it to process images included in the conversation context. If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in zed. ### OpenAI {#openai}