diff --git a/Cargo.lock b/Cargo.lock
index 504cb2a573..09f58daabd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -546,6 +546,7 @@ dependencies = [
  "language_model",
  "lmstudio",
  "log",
+ "mistral",
  "ollama",
  "open_ai",
  "paths",
diff --git a/crates/assistant_settings/Cargo.toml b/crates/assistant_settings/Cargo.toml
index 8a8316fae0..c46ea64630 100644
--- a/crates/assistant_settings/Cargo.toml
+++ b/crates/assistant_settings/Cargo.toml
@@ -23,6 +23,7 @@ log.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 deepseek = { workspace = true, features = ["schemars"] }
+mistral = { workspace = true, features = ["schemars"] }
 schemars.workspace = true
 serde.workspace = true
 settings.workspace = true
diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs
index ad9c1e6d62..f7fd1a1ead 100644
--- a/crates/assistant_settings/src/assistant_settings.rs
+++ b/crates/assistant_settings/src/assistant_settings.rs
@@ -10,6 +10,7 @@ use deepseek::Model as DeepseekModel;
 use gpui::{App, Pixels, SharedString};
 use language_model::{CloudModel, LanguageModel};
 use lmstudio::Model as LmStudioModel;
+use mistral::Model as MistralModel;
 use ollama::Model as OllamaModel;
 use schemars::{JsonSchema, schema::Schema};
 use serde::{Deserialize, Serialize};
@@ -71,6 +72,11 @@ pub enum AssistantProviderContentV1 {
         default_model: Option<DeepseekModel>,
         api_url: Option<String>,
     },
+    #[serde(rename = "mistral")]
+    Mistral {
+        default_model: Option<MistralModel>,
+        api_url: Option<String>,
+    },
 }
 
 #[derive(Default, Clone, Debug)]
@@ -249,6 +255,12 @@ impl AssistantSettingsContent {
                             model: model.id().to_string(),
                         })
                     }
+                    AssistantProviderContentV1::Mistral { default_model, .. } => {
+                        default_model.map(|model| LanguageModelSelection {
+                            provider: "mistral".into(),
+                            model: model.id().to_string(),
+                        })
+                    }
                 }),
                 inline_assistant_model: None,
                 commit_message_model: None,
@@ -700,6 +712,7 @@ impl JsonSchema for LanguageModelProviderSetting {
                 "zed.dev".into(),
                 "copilot_chat".into(),
                 "deepseek".into(),
+                "mistral".into(),
             ]),
             ..Default::default()
         }
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 5143767e9e..93317d1a51 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -11,13 +12,13 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, RateLimiter, Role,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
-
-use futures::stream::BoxStream;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::str::FromStr;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
@@ -26,6 +27,9 @@ use util::ResultExt;
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
+use std::collections::HashMap;
+use std::pin::Pin;
+
 const PROVIDER_ID: &str = "mistral";
 const PROVIDER_NAME: &str = "Mistral";
@@ -43,6 +47,7 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     pub max_output_tokens: Option<u32>,
     pub max_completion_tokens: Option<u32>,
+    pub supports_tools: Option<bool>,
 }
 
 pub struct MistralLanguageModelProvider {
@@ -209,6 +214,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
                     max_completion_tokens: model.max_completion_tokens,
+                    supports_tools: model.supports_tools,
                 },
             );
         }
@@ -300,14 +306,14 @@ impl LanguageModel for MistralLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
-    }
-
-    fn supports_images(&self) -> bool {
-        false
+        self.model.supports_tools()
     }
 
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        self.model.supports_tools()
+    }
+
+    fn supports_images(&self) -> bool {
         false
     }
 
@@ -368,26 +374,8 @@ impl LanguageModel for MistralLanguageModel {
 
         async move {
             let stream = stream.await?;
-            Ok(stream
-                .map(|result| {
-                    result
-                        .and_then(|response| {
-                            response
-                                .choices
-                                .first()
-                                .ok_or_else(|| anyhow!("Empty response"))
-                                .map(|choice| {
-                                    choice
-                                        .delta
-                                        .content
-                                        .clone()
-                                        .unwrap_or_default()
-                                        .map(LanguageModelCompletionEvent::Text)
-                                })
-                        })
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = MistralEventMapper::new();
+            Ok(mapper.map_stream(stream).boxed())
         }
         .boxed()
     }
@@ -398,33 +386,87 @@ pub fn into_mistral(
     model: String,
     max_output_tokens: Option<u32>,
 ) -> mistral::Request {
-    let len = request.messages.len();
-    let merged_messages =
-        request
-            .messages
-            .into_iter()
-            .fold(Vec::with_capacity(len), |mut acc, msg| {
-                let role = msg.role;
-                let content = msg.string_contents();
+    let stream = true;
 
-                acc.push(match role {
-                    Role::User => mistral::RequestMessage::User { content },
-                    Role::Assistant => mistral::RequestMessage::Assistant {
-                        content: Some(content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => mistral::RequestMessage::System { content },
-                });
-                acc
-            });
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                    .push(match message.role {
+                        Role::User => mistral::RequestMessage::User { content: text },
+                        Role::Assistant => mistral::RequestMessage::Assistant {
+                            content: Some(text),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => mistral::RequestMessage::System { content: text },
+                    }),
+                MessageContent::RedactedThinking(_) => {}
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = mistral::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: mistral::ToolCallContent::Function {
+                            function: mistral::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
+                        messages.last_mut()
+                    {
+                        tool_calls.push(tool_call);
+                    } else {
+                        messages.push(mistral::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    let content = match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => text.to_string(),
+                        LanguageModelToolResultContent::Image(_) => {
+                            // TODO: Mistral image support
+                            "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
+                        }
+                    };
+
+                    messages.push(mistral::RequestMessage::Tool {
+                        content,
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }
 
     mistral::Request {
         model,
-        messages: merged_messages,
-        stream: true,
+        messages,
+        stream,
         max_tokens: max_output_tokens,
         temperature: request.temperature,
         response_format: None,
+        tool_choice: match request.tool_choice {
+            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
+                Some(mistral::ToolChoice::Auto)
+            }
+            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
+                Some(mistral::ToolChoice::Any)
+            }
+            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
+            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
+            _ => None,
+        },
+        parallel_tool_calls: if !request.tools.is_empty() {
+            Some(false)
+        } else {
+            None
+        },
         tools: request
             .tools
             .into_iter()
@@ -439,6 +481,127 @@ pub fn into_mistral(
     }
 }
 
+pub struct MistralEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl MistralEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
+    ) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: mistral::StreamResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        if let Some(finish_reason) = choice.finish_reason.as_deref() {
+            match finish_reason {
+                "stop" => {
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+                "tool_calls" => {
+                    events.extend(self.process_tool_calls());
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+                }
+                unexpected => {
+                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+            }
+        }
+
+        events
+    }
+
+    fn process_tool_calls(
+        &mut self,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let mut results = Vec::new();
+
+        for (_, tool_call) in self.tool_calls_by_index.drain() {
+            if tool_call.id.is_empty() || tool_call.name.is_empty() {
+                results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+                    "Received incomplete tool call: missing id or name"
+                ))));
+                continue;
+            }
+
+            match serde_json::Value::from_str(&tool_call.arguments) {
+                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                    LanguageModelToolUse {
+                        id: tool_call.id.into(),
+                        name: tool_call.name.into(),
+                        is_input_complete: true,
+                        input,
+                        raw_input: tool_call.arguments,
+                    },
+                ))),
+                Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
+                    id: tool_call.id.into(),
+                    tool_name: tool_call.name.into(),
+                    raw_input: tool_call.arguments.into(),
+                    json_parse_error: error.to_string(),
+                })),
+            }
+        }
+
+        results
+    }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
@@ -623,3 +786,65 @@ impl Render for ConfigurationView {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use language_model;
+
+    #[test]
+    fn test_into_mistral_conversion() {
+        let request = language_model::LanguageModelRequest {
+            messages: vec![
+                language_model::LanguageModelRequestMessage {
+                    role: language_model::Role::System,
+                    content: vec![language_model::MessageContent::Text(
+                        "You are a helpful assistant.".to_string(),
+                    )],
+                    cache: false,
+                },
+                language_model::LanguageModelRequestMessage {
+                    role: language_model::Role::User,
+                    content: vec![language_model::MessageContent::Text(
+                        "Hello, how are you?".to_string(),
+                    )],
+                    cache: false,
+                },
+            ],
+            temperature: Some(0.7),
+            tools: Vec::new(),
+            tool_choice: None,
+            thread_id: None,
+            prompt_id: None,
+            mode: None,
+            stop: Vec::new(),
+        };
+
+        let model_name = "mistral-medium-latest".to_string();
+        let max_output_tokens = Some(1000);
+        let mistral_request = into_mistral(request, model_name, max_output_tokens);
+
+        assert_eq!(mistral_request.model, "mistral-medium-latest");
+        assert_eq!(mistral_request.temperature, Some(0.7));
+        assert_eq!(mistral_request.max_tokens, Some(1000));
+        assert!(mistral_request.stream);
+        assert!(mistral_request.tools.is_empty());
+        assert!(mistral_request.tool_choice.is_none());
+
+        assert_eq!(mistral_request.messages.len(), 2);
+
+        match &mistral_request.messages[0] {
+            mistral::RequestMessage::System { content } => {
+                assert_eq!(content, "You are a helpful assistant.");
+            }
+            _ => panic!("Expected System message"),
+        }
+
+        match &mistral_request.messages[1] {
+            mistral::RequestMessage::User { content } => {
+                assert_eq!(content, "Hello, how are you?");
+            }
+            _ => panic!("Expected User message"),
+        }
+    }
+}
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
index 3dbe3a5d88..1e2667233c 100644
--- a/crates/mistral/src/mistral.rs
+++ b/crates/mistral/src/mistral.rs
@@ -67,6 +67,7 @@ pub enum Model {
         max_tokens: usize,
         max_output_tokens: Option<u32>,
         max_completion_tokens: Option<u32>,
+        supports_tools: Option<bool>,
     },
 }
 
@@ -133,6 +134,18 @@ impl Model {
             _ => None,
         }
     }
+
+    pub fn supports_tools(&self) -> bool {
+        match self {
+            Self::CodestralLatest
+            | Self::MistralLargeLatest
+            | Self::MistralMediumLatest
+            | Self::MistralSmallLatest
+            | Self::OpenMistralNemo
+            | Self::OpenCodestralMamba => true,
+            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
+        }
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -146,6 +159,10 @@ pub struct Request {
     pub temperature: Option<f32>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub response_format: Option<ResponseFormat>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
 }
@@ -190,12 +207,13 @@ pub enum Prediction {
 }
 
 #[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
+#[serde(rename_all = "snake_case")]
 pub enum ToolChoice {
     Auto,
     Required,
     None,
-    Other(ToolDefinition),
+    Any,
+    Function(ToolDefinition),
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md
index b6b23e2c6d..08eb55d410 100644
--- a/docs/src/ai/configuration.md
+++ b/docs/src/ai/configuration.md
@@ -14,6 +14,7 @@ Here's an overview of the supported providers and tool call support:
 | [Anthropic](#anthropic)                     | ✅            |
 | [GitHub Copilot Chat](#github-copilot-chat) | In Some Cases |
 | [Google AI](#google-ai)                     | ✅            |
+| [Mistral](#mistral)                         | ✅            |
 | [Ollama](#ollama)                           | ✅            |
 | [OpenAI](#openai)                           | ✅            |
 | [DeepSeek](#deepseek)                       | 🚫            |
@@ -128,6 +129,44 @@ By default Zed will use `stable` versions of models, but you can use specific ve
 
 Custom models will be listed in the model dropdown in the Agent Panel.
 
+### Mistral {#mistral}
+
+> ✅ Supports tool use
+
+1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/)
+2. Open the configuration view (`assistant: show configuration`) and navigate to the Mistral section
+3. Enter your Mistral API key
+
+The Mistral API key will be saved in your keychain.
+
+Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined.
+
+#### Mistral Custom Models {#mistral-custom-models}
+
+The Zed Assistant comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). All of the default models support tool use. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "mistral": {
+      "api_url": "https://api.mistral.ai/v1",
+      "available_models": [
+        {
+          "name": "mistral-tiny-latest",
+          "display_name": "Mistral Tiny",
+          "max_tokens": 32000,
+          "max_output_tokens": 4096,
+          "max_completion_tokens": 1024,
+          "supports_tools": true
+        }
+      ]
+    }
+  }
+}
+```
+
+Custom models will be listed in the model dropdown in the Agent Panel.
+
 ### Ollama {#ollama}
 
 > ✅ Supports tool use