From 926f377c6cea604446d843c5e7f385219feb062a Mon Sep 17 00:00:00 2001
From: Umesh Yadav <23421535+imumesh18@users.noreply.github.com>
Date: Mon, 19 May 2025 22:06:59 +0530
Subject: [PATCH] language_models: Add tool use support for Mistral models
(#29994)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Closes https://github.com/zed-industries/zed/issues/29855
Implement tool use handling in Mistral provider, including mapping tool
call events and updating request construction. Add support for
tool_choice and parallel_tool_calls in Mistral API requests.
This works fine with all the existing models. Didn't touched anything
else but for future. Fetching models using their models api, deducting
tool call support, parallel tool calls etc should be done from model
data from api response.
Tasks:
- [x] Add tool call support
- [x] Auto Fetch models using mistral api
- [x] Add tests for mistral crates.
- [x] Fix mistral configurations for llm providers.
Release Notes:
- agent: Add tool call support for existing mistral models
---------
Co-authored-by: Peter Tripp
Co-authored-by: Bennet Bo Fenner
---
Cargo.lock | 1 +
crates/assistant_settings/Cargo.toml | 1 +
.../src/assistant_settings.rs | 13 +
.../language_models/src/provider/mistral.rs | 321 +++++++++++++++---
crates/mistral/src/mistral.rs | 22 +-
docs/src/ai/configuration.md | 39 +++
6 files changed, 347 insertions(+), 50 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 504cb2a573..09f58daabd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -546,6 +546,7 @@ dependencies = [
"language_model",
"lmstudio",
"log",
+ "mistral",
"ollama",
"open_ai",
"paths",
diff --git a/crates/assistant_settings/Cargo.toml b/crates/assistant_settings/Cargo.toml
index 8a8316fae0..c46ea64630 100644
--- a/crates/assistant_settings/Cargo.toml
+++ b/crates/assistant_settings/Cargo.toml
@@ -23,6 +23,7 @@ log.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
deepseek = { workspace = true, features = ["schemars"] }
+mistral = { workspace = true, features = ["schemars"] }
schemars.workspace = true
serde.workspace = true
settings.workspace = true
diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs
index ad9c1e6d62..f7fd1a1ead 100644
--- a/crates/assistant_settings/src/assistant_settings.rs
+++ b/crates/assistant_settings/src/assistant_settings.rs
@@ -10,6 +10,7 @@ use deepseek::Model as DeepseekModel;
use gpui::{App, Pixels, SharedString};
use language_model::{CloudModel, LanguageModel};
use lmstudio::Model as LmStudioModel;
+use mistral::Model as MistralModel;
use ollama::Model as OllamaModel;
use schemars::{JsonSchema, schema::Schema};
use serde::{Deserialize, Serialize};
@@ -71,6 +72,11 @@ pub enum AssistantProviderContentV1 {
default_model: Option,
api_url: Option,
},
+ #[serde(rename = "mistral")]
+ Mistral {
+ default_model: Option,
+ api_url: Option,
+ },
}
#[derive(Default, Clone, Debug)]
@@ -249,6 +255,12 @@ impl AssistantSettingsContent {
model: model.id().to_string(),
})
}
+ AssistantProviderContentV1::Mistral { default_model, .. } => {
+ default_model.map(|model| LanguageModelSelection {
+ provider: "mistral".into(),
+ model: model.id().to_string(),
+ })
+ }
}),
inline_assistant_model: None,
commit_message_model: None,
@@ -700,6 +712,7 @@ impl JsonSchema for LanguageModelProviderSetting {
"zed.dev".into(),
"copilot_chat".into(),
"deepseek".into(),
+ "mistral".into(),
]),
..Default::default()
}
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index 5143767e9e..93317d1a51 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
+use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -11,13 +12,13 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, RateLimiter, Role,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+ RateLimiter, Role, StopReason,
};
-
-use futures::stream::BoxStream;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::str::FromStr;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
@@ -26,6 +27,9 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+use std::collections::HashMap;
+use std::pin::Pin;
+
const PROVIDER_ID: &str = "mistral";
const PROVIDER_NAME: &str = "Mistral";
@@ -43,6 +47,7 @@ pub struct AvailableModel {
pub max_tokens: usize,
pub max_output_tokens: Option,
pub max_completion_tokens: Option,
+ pub supports_tools: Option,
}
pub struct MistralLanguageModelProvider {
@@ -209,6 +214,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
+ supports_tools: model.supports_tools,
},
);
}
@@ -300,14 +306,14 @@ impl LanguageModel for MistralLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
- }
-
- fn supports_images(&self) -> bool {
- false
+ self.model.supports_tools()
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+ self.model.supports_tools()
+ }
+
+ fn supports_images(&self) -> bool {
false
}
@@ -368,26 +374,8 @@ impl LanguageModel for MistralLanguageModel {
async move {
let stream = stream.await?;
- Ok(stream
- .map(|result| {
- result
- .and_then(|response| {
- response
- .choices
- .first()
- .ok_or_else(|| anyhow!("Empty response"))
- .map(|choice| {
- choice
- .delta
- .content
- .clone()
- .unwrap_or_default()
- .map(LanguageModelCompletionEvent::Text)
- })
- })
- .map_err(LanguageModelCompletionError::Other)
- })
- .boxed())
+ let mapper = MistralEventMapper::new();
+ Ok(mapper.map_stream(stream).boxed())
}
.boxed()
}
@@ -398,33 +386,87 @@ pub fn into_mistral(
model: String,
max_output_tokens: Option,
) -> mistral::Request {
- let len = request.messages.len();
- let merged_messages =
- request
- .messages
- .into_iter()
- .fold(Vec::with_capacity(len), |mut acc, msg| {
- let role = msg.role;
- let content = msg.string_contents();
+ let stream = true;
- acc.push(match role {
- Role::User => mistral::RequestMessage::User { content },
- Role::Assistant => mistral::RequestMessage::Assistant {
- content: Some(content),
- tool_calls: Vec::new(),
- },
- Role::System => mistral::RequestMessage::System { content },
- });
- acc
- });
+ let mut messages = Vec::new();
+ for message in request.messages {
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+ .push(match message.role {
+ Role::User => mistral::RequestMessage::User { content: text },
+ Role::Assistant => mistral::RequestMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ },
+ Role::System => mistral::RequestMessage::System { content: text },
+ }),
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(_) => {}
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = mistral::ToolCall {
+ id: tool_use.id.to_string(),
+ content: mistral::ToolCallContent::Function {
+ function: mistral::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
+ if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
+ messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ } else {
+ messages.push(mistral::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
+ }
+ }
+ MessageContent::ToolResult(tool_result) => {
+ let content = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => text.to_string(),
+ LanguageModelToolResultContent::Image(_) => {
+ // TODO: Mistral image support
+ "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
+ }
+ };
+
+ messages.push(mistral::RequestMessage::Tool {
+ content,
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ }
+ }
+ }
mistral::Request {
model,
- messages: merged_messages,
- stream: true,
+ messages,
+ stream,
max_tokens: max_output_tokens,
temperature: request.temperature,
response_format: None,
+ tool_choice: match request.tool_choice {
+ Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
+ Some(mistral::ToolChoice::Auto)
+ }
+ Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
+ Some(mistral::ToolChoice::Any)
+ }
+ Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
+ _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
+ _ => None,
+ },
+ parallel_tool_calls: if !request.tools.is_empty() {
+ Some(false)
+ } else {
+ None
+ },
tools: request
.tools
.into_iter()
@@ -439,6 +481,127 @@ pub fn into_mistral(
}
}
+pub struct MistralEventMapper {
+ tool_calls_by_index: HashMap,
+}
+
+impl MistralEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin>>>,
+ ) -> impl futures::Stream- >
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: mistral::StreamResponse,
+ ) -> Vec> {
+ let Some(choice) = event.choices.first() else {
+ return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ "Response contained no choices"
+ )))];
+ };
+
+ let mut events = Vec::new();
+ if let Some(content) = choice.delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
+ if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+ }
+ }
+ }
+
+ if let Some(finish_reason) = choice.finish_reason.as_deref() {
+ match finish_reason {
+ "stop" => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ "tool_calls" => {
+ events.extend(self.process_tool_calls());
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ unexpected => {
+ log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ }
+ }
+
+ events
+ }
+
+ fn process_tool_calls(
+ &mut self,
+ ) -> Vec> {
+ let mut results = Vec::new();
+
+ for (_, tool_call) in self.tool_calls_by_index.drain() {
+ if tool_call.id.is_empty() || tool_call.name.is_empty() {
+ results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+ "Received incomplete tool call: missing id or name"
+ ))));
+ continue;
+ }
+
+ match serde_json::Value::from_str(&tool_call.arguments) {
+ Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.into(),
+ name: tool_call.name.into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments,
+ },
+ ))),
+ Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ })),
+ }
+ }
+
+ results
+ }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+}
+
struct ConfigurationView {
api_key_editor: Entity,
state: gpui::Entity,
@@ -623,3 +786,65 @@ impl Render for ConfigurationView {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use language_model;
+
+ #[test]
+ fn test_into_mistral_conversion() {
+ let request = language_model::LanguageModelRequest {
+ messages: vec![
+ language_model::LanguageModelRequestMessage {
+ role: language_model::Role::System,
+ content: vec![language_model::MessageContent::Text(
+ "You are a helpful assistant.".to_string(),
+ )],
+ cache: false,
+ },
+ language_model::LanguageModelRequestMessage {
+ role: language_model::Role::User,
+ content: vec![language_model::MessageContent::Text(
+ "Hello, how are you?".to_string(),
+ )],
+ cache: false,
+ },
+ ],
+ temperature: Some(0.7),
+ tools: Vec::new(),
+ tool_choice: None,
+ thread_id: None,
+ prompt_id: None,
+ mode: None,
+ stop: Vec::new(),
+ };
+
+ let model_name = "mistral-medium-latest".to_string();
+ let max_output_tokens = Some(1000);
+ let mistral_request = into_mistral(request, model_name, max_output_tokens);
+
+ assert_eq!(mistral_request.model, "mistral-medium-latest");
+ assert_eq!(mistral_request.temperature, Some(0.7));
+ assert_eq!(mistral_request.max_tokens, Some(1000));
+ assert!(mistral_request.stream);
+ assert!(mistral_request.tools.is_empty());
+ assert!(mistral_request.tool_choice.is_none());
+
+ assert_eq!(mistral_request.messages.len(), 2);
+
+ match &mistral_request.messages[0] {
+ mistral::RequestMessage::System { content } => {
+ assert_eq!(content, "You are a helpful assistant.");
+ }
+ _ => panic!("Expected System message"),
+ }
+
+ match &mistral_request.messages[1] {
+ mistral::RequestMessage::User { content } => {
+ assert_eq!(content, "Hello, how are you?");
+ }
+ _ => panic!("Expected User message"),
+ }
+ }
+}
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
index 3dbe3a5d88..1e2667233c 100644
--- a/crates/mistral/src/mistral.rs
+++ b/crates/mistral/src/mistral.rs
@@ -67,6 +67,7 @@ pub enum Model {
max_tokens: usize,
max_output_tokens: Option,
max_completion_tokens: Option,
+ supports_tools: Option,
},
}
@@ -133,6 +134,18 @@ impl Model {
_ => None,
}
}
+
+ pub fn supports_tools(&self) -> bool {
+ match self {
+ Self::CodestralLatest
+ | Self::MistralLargeLatest
+ | Self::MistralMediumLatest
+ | Self::MistralSmallLatest
+ | Self::OpenMistralNemo
+ | Self::OpenCodestralMamba => true,
+ Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
+ }
+ }
}
#[derive(Debug, Serialize, Deserialize)]
@@ -146,6 +159,10 @@ pub struct Request {
pub temperature: Option,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_format: Option,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub parallel_tool_calls: Option,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec,
}
@@ -190,12 +207,13 @@ pub enum Prediction {
}
#[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
+#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
Auto,
Required,
None,
- Other(ToolDefinition),
+ Any,
+ Function(ToolDefinition),
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md
index b6b23e2c6d..08eb55d410 100644
--- a/docs/src/ai/configuration.md
+++ b/docs/src/ai/configuration.md
@@ -14,6 +14,7 @@ Here's an overview of the supported providers and tool call support:
| [Anthropic](#anthropic) | ✅ |
| [GitHub Copilot Chat](#github-copilot-chat) | In Some Cases |
| [Google AI](#google-ai) | ✅ |
+| [Mistral](#mistral) | ✅ |
| [Ollama](#ollama) | ✅ |
| [OpenAI](#openai) | ✅ |
| [DeepSeek](#deepseek) | 🚫 |
@@ -128,6 +129,44 @@ By default Zed will use `stable` versions of models, but you can use specific ve
Custom models will be listed in the model dropdown in the Agent Panel.
+### Mistral {#mistral}
+
+> 🔨Supports tool use
+
+1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/)
+2. Open the configuration view (`assistant: show configuration`) and navigate to the Mistral section
+3. Enter your Mistral API key
+
+The Mistral API key will be saved in your keychain.
+
+Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined.
+
+#### Mistral Custom Models {#mistral-custom-models}
+
+The Zed Assistant comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). All the default models support tool use. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`:
+
+```json
+{
+ "language_models": {
+ "mistral": {
+ "api_url": "https://api.mistral.ai/v1",
+ "available_models": [
+ {
+ "name": "mistral-tiny-latest",
+ "display_name": "Mistral Tiny",
+ "max_tokens": 32000,
+ "max_output_tokens": 4096,
+ "max_completion_tokens": 1024,
+ "supports_tools": true
+ }
+ ]
+ }
+ }
+}
+```
+
+Custom models will be listed in the model dropdown in the assistant panel.
+
### Ollama {#ollama}
> ✅ Supports tool use