language_models: Add tool use support for Mistral models (#29994)

Closes https://github.com/zed-industries/zed/issues/29855

Implement tool use handling in the Mistral provider, including mapping tool
call events and updating request construction. Add support for
`tool_choice` and `parallel_tool_calls` in Mistral API requests.
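
As a quick illustration of the request shape this enables, here is a standalone sketch that mirrors the serde attributes on the new `mistral::ToolChoice` enum (it is not code from this PR, just a demonstration of how the variants serialize):

```rust
use serde::Serialize;

// Standalone mirror of the `ToolChoice` enum added to the mistral crate in
// this PR, shown only to illustrate how `rename_all = "snake_case"` shapes
// the request body.
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum ToolChoice {
    Auto,
    Required,
    None,
    Any,
}

fn main() {
    // Unit variants serialize as lowercase strings, matching the values the
    // Mistral chat completions API accepts for `tool_choice`.
    assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), r#""auto""#);
    assert_eq!(serde_json::to_string(&ToolChoice::Any).unwrap(), r#""any""#);
}
```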

This works fine with all the existing models. I didn't touch anything else,
but as future work, models should be fetched from Mistral's models API, and
tool call support, parallel tool calls, etc. should be derived from the model
data in that response (see the sketch below).
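
For reference, here is a rough sketch of what that follow-up could look like: deserializing the models-list response and deriving tool support from it. The endpoint (`GET /v1/models`) and field names such as `capabilities.function_calling` are assumptions about Mistral's API, not code from this PR:

```rust
use serde::Deserialize;

// Hypothetical shape of a Mistral models-list response; field names are
// assumptions and should be verified against the live API before use.
#[derive(Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    function_calling: bool,
}

#[derive(Deserialize)]
struct ApiModel {
    id: String,
    capabilities: Option<ModelCapabilities>,
    max_context_length: Option<usize>,
}

#[derive(Deserialize)]
struct ListModelsResponse {
    data: Vec<ApiModel>,
}

// Derive per-model tool-call support from the API response instead of
// hardcoding it per model variant.
fn supports_tools(model: &ApiModel) -> bool {
    model
        .capabilities
        .as_ref()
        .map_or(false, |caps| caps.function_calling)
}
```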

<img width="547" alt="Screenshot 2025-05-06 at 4 52 37 PM"
src="https://github.com/user-attachments/assets/4c08b544-1174-40cc-a40d-522989953448"
/>

Tasks:

- [x] Add tool call support
- [x] Auto-fetch models using the Mistral API
- [x] Add tests for the mistral crates
- [x] Fix the Mistral configuration for LLM providers

Release Notes:

- agent: Add tool call support for existing Mistral models

---------

Co-authored-by: Peter Tripp <peter@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Umesh Yadav 2025-05-19 22:06:59 +05:30 committed by GitHub
parent 26a8cac0d8
commit 926f377c6c
6 changed files with 347 additions and 50 deletions

Cargo.lock

@@ -546,6 +546,7 @@ dependencies = [
  "language_model",
  "lmstudio",
  "log",
+ "mistral",
  "ollama",
  "open_ai",
  "paths",


@@ -23,6 +23,7 @@ log.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 deepseek = { workspace = true, features = ["schemars"] }
+mistral = { workspace = true, features = ["schemars"] }
 schemars.workspace = true
 serde.workspace = true
 settings.workspace = true


@@ -10,6 +10,7 @@ use deepseek::Model as DeepseekModel;
 use gpui::{App, Pixels, SharedString};
 use language_model::{CloudModel, LanguageModel};
 use lmstudio::Model as LmStudioModel;
+use mistral::Model as MistralModel;
 use ollama::Model as OllamaModel;
 use schemars::{JsonSchema, schema::Schema};
 use serde::{Deserialize, Serialize};
@@ -71,6 +72,11 @@ pub enum AssistantProviderContentV1 {
         default_model: Option<DeepseekModel>,
         api_url: Option<String>,
     },
+    #[serde(rename = "mistral")]
+    Mistral {
+        default_model: Option<MistralModel>,
+        api_url: Option<String>,
+    },
 }

 #[derive(Default, Clone, Debug)]
@@ -249,6 +255,12 @@ impl AssistantSettingsContent {
                             model: model.id().to_string(),
                         })
                     }
+                    AssistantProviderContentV1::Mistral { default_model, .. } => {
+                        default_model.map(|model| LanguageModelSelection {
+                            provider: "mistral".into(),
+                            model: model.id().to_string(),
+                        })
+                    }
                 }),
                 inline_assistant_model: None,
                 commit_message_model: None,
@@ -700,6 +712,7 @@ impl JsonSchema for LanguageModelProviderSetting {
                 "zed.dev".into(),
                 "copilot_chat".into(),
                 "deepseek".into(),
+                "mistral".into(),
             ]),
             ..Default::default()
         }


@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -11,13 +12,13 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, RateLimiter, Role,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
-use futures::stream::BoxStream;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::str::FromStr;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
@@ -26,6 +27,9 @@ use util::ResultExt;

 use crate::{AllLanguageModelSettings, ui::InstructionListItem};

+use std::collections::HashMap;
+use std::pin::Pin;
+
 const PROVIDER_ID: &str = "mistral";
 const PROVIDER_NAME: &str = "Mistral";
@@ -43,6 +47,7 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     pub max_output_tokens: Option<u32>,
     pub max_completion_tokens: Option<u32>,
+    pub supports_tools: Option<bool>,
 }

 pub struct MistralLanguageModelProvider {
@@ -209,6 +214,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
                     max_completion_tokens: model.max_completion_tokens,
+                    supports_tools: model.supports_tools,
                 },
             );
         }
@@ -300,14 +306,14 @@ impl LanguageModel for MistralLanguageModel {
     }

     fn supports_tools(&self) -> bool {
-        false
-    }
-
-    fn supports_images(&self) -> bool {
-        false
+        self.model.supports_tools()
     }

     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        self.model.supports_tools()
+    }
+
+    fn supports_images(&self) -> bool {
         false
     }
@@ -368,26 +374,8 @@ impl LanguageModel for MistralLanguageModel {
         async move {
             let stream = stream.await?;
-            Ok(stream
-                .map(|result| {
-                    result
-                        .and_then(|response| {
-                            response
-                                .choices
-                                .first()
-                                .ok_or_else(|| anyhow!("Empty response"))
-                                .map(|choice| {
-                                    choice
-                                        .delta
-                                        .content
-                                        .clone()
-                                        .unwrap_or_default()
-                                        .map(LanguageModelCompletionEvent::Text)
-                                })
-                        })
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = MistralEventMapper::new();
+            Ok(mapper.map_stream(stream).boxed())
         }
         .boxed()
     }
@@ -398,33 +386,87 @@ pub fn into_mistral(
     model: String,
     max_output_tokens: Option<u32>,
 ) -> mistral::Request {
-    let len = request.messages.len();
-    let merged_messages =
-        request
-            .messages
-            .into_iter()
-            .fold(Vec::with_capacity(len), |mut acc, msg| {
-                let role = msg.role;
-                let content = msg.string_contents();
-                acc.push(match role {
-                    Role::User => mistral::RequestMessage::User { content },
-                    Role::Assistant => mistral::RequestMessage::Assistant {
-                        content: Some(content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => mistral::RequestMessage::System { content },
-                });
-                acc
-            });
+    let stream = true;
+
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                    .push(match message.role {
+                        Role::User => mistral::RequestMessage::User { content: text },
+                        Role::Assistant => mistral::RequestMessage::Assistant {
+                            content: Some(text),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => mistral::RequestMessage::System { content: text },
+                    }),
+                MessageContent::RedactedThinking(_) => {}
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = mistral::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: mistral::ToolCallContent::Function {
+                            function: mistral::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
+                        messages.last_mut()
+                    {
+                        tool_calls.push(tool_call);
+                    } else {
+                        messages.push(mistral::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    let content = match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => text.to_string(),
+                        LanguageModelToolResultContent::Image(_) => {
+                            // TODO: Mistral image support
+                            "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
+                        }
+                    };
+
+                    messages.push(mistral::RequestMessage::Tool {
+                        content,
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }

     mistral::Request {
         model,
-        messages: merged_messages,
-        stream: true,
+        messages,
+        stream,
         max_tokens: max_output_tokens,
         temperature: request.temperature,
         response_format: None,
+        tool_choice: match request.tool_choice {
+            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
+                Some(mistral::ToolChoice::Auto)
+            }
+            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
+                Some(mistral::ToolChoice::Any)
+            }
+            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
+            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
+            _ => None,
+        },
+        parallel_tool_calls: if !request.tools.is_empty() {
+            Some(false)
+        } else {
+            None
+        },
         tools: request
             .tools
             .into_iter()
@@ -439,6 +481,127 @@ pub fn into_mistral(
     }
 }

+pub struct MistralEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl MistralEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
+    ) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: mistral::StreamResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        if let Some(finish_reason) = choice.finish_reason.as_deref() {
+            match finish_reason {
+                "stop" => {
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+                "tool_calls" => {
+                    events.extend(self.process_tool_calls());
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+                }
+                unexpected => {
+                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+            }
+        }
+
+        events
+    }
+
+    fn process_tool_calls(
+        &mut self,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let mut results = Vec::new();
+
+        for (_, tool_call) in self.tool_calls_by_index.drain() {
+            if tool_call.id.is_empty() || tool_call.name.is_empty() {
+                results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+                    "Received incomplete tool call: missing id or name"
+                ))));
+                continue;
+            }
+
+            match serde_json::Value::from_str(&tool_call.arguments) {
+                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                    LanguageModelToolUse {
+                        id: tool_call.id.into(),
+                        name: tool_call.name.into(),
+                        is_input_complete: true,
+                        input,
+                        raw_input: tool_call.arguments,
+                    },
+                ))),
+                Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
+                    id: tool_call.id.into(),
+                    tool_name: tool_call.name.into(),
+                    raw_input: tool_call.arguments.into(),
+                    json_parse_error: error.to_string(),
+                })),
+            }
+        }
+
+        results
+    }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
@@ -623,3 +786,65 @@ impl Render for ConfigurationView {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use language_model;
+
+    #[test]
+    fn test_into_mistral_conversion() {
+        let request = language_model::LanguageModelRequest {
+            messages: vec![
+                language_model::LanguageModelRequestMessage {
+                    role: language_model::Role::System,
+                    content: vec![language_model::MessageContent::Text(
+                        "You are a helpful assistant.".to_string(),
+                    )],
+                    cache: false,
+                },
+                language_model::LanguageModelRequestMessage {
+                    role: language_model::Role::User,
+                    content: vec![language_model::MessageContent::Text(
+                        "Hello, how are you?".to_string(),
+                    )],
+                    cache: false,
+                },
+            ],
+            temperature: Some(0.7),
+            tools: Vec::new(),
+            tool_choice: None,
+            thread_id: None,
+            prompt_id: None,
+            mode: None,
+            stop: Vec::new(),
+        };
+
+        let model_name = "mistral-medium-latest".to_string();
+        let max_output_tokens = Some(1000);
+        let mistral_request = into_mistral(request, model_name, max_output_tokens);
+
+        assert_eq!(mistral_request.model, "mistral-medium-latest");
+        assert_eq!(mistral_request.temperature, Some(0.7));
+        assert_eq!(mistral_request.max_tokens, Some(1000));
+        assert!(mistral_request.stream);
+        assert!(mistral_request.tools.is_empty());
+        assert!(mistral_request.tool_choice.is_none());
+
+        assert_eq!(mistral_request.messages.len(), 2);
+        match &mistral_request.messages[0] {
+            mistral::RequestMessage::System { content } => {
+                assert_eq!(content, "You are a helpful assistant.");
+            }
+            _ => panic!("Expected System message"),
+        }
+        match &mistral_request.messages[1] {
+            mistral::RequestMessage::User { content } => {
+                assert_eq!(content, "Hello, how are you?");
+            }
+            _ => panic!("Expected User message"),
+        }
+    }
+}


@@ -67,6 +67,7 @@ pub enum Model {
         max_tokens: usize,
         max_output_tokens: Option<u32>,
         max_completion_tokens: Option<u32>,
+        supports_tools: Option<bool>,
     },
 }
@@ -133,6 +134,18 @@ impl Model {
             _ => None,
         }
     }
+
+    pub fn supports_tools(&self) -> bool {
+        match self {
+            Self::CodestralLatest
+            | Self::MistralLargeLatest
+            | Self::MistralMediumLatest
+            | Self::MistralSmallLatest
+            | Self::OpenMistralNemo
+            | Self::OpenCodestralMamba => true,
+            Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
+        }
+    }
 }

 #[derive(Debug, Serialize, Deserialize)]
@@ -146,6 +159,10 @@ pub struct Request {
     pub temperature: Option<f32>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub response_format: Option<ResponseFormat>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
 }
@@ -190,12 +207,13 @@ pub enum Prediction {
 }

 #[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
+#[serde(rename_all = "snake_case")]
 pub enum ToolChoice {
     Auto,
     Required,
     None,
-    Other(ToolDefinition),
+    Any,
+    Function(ToolDefinition),
 }

 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]


@@ -14,6 +14,7 @@ Here's an overview of the supported providers and tool call support:
 | [Anthropic](#anthropic) | ✅ |
 | [GitHub Copilot Chat](#github-copilot-chat) | In Some Cases |
 | [Google AI](#google-ai) | ✅ |
+| [Mistral](#mistral) | ✅ |
 | [Ollama](#ollama) | ✅ |
 | [OpenAI](#openai) | ✅ |
 | [DeepSeek](#deepseek) | 🚫 |
@@ -128,6 +129,44 @@ By default Zed will use `stable` versions of models, but you can use specific ve
 Custom models will be listed in the model dropdown in the Agent Panel.

+### Mistral {#mistral}
+
+> 🔨 Supports tool use
+
+1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/)
+2. Open the configuration view (`assistant: show configuration`) and navigate to the Mistral section
+3. Enter your Mistral API key
+
+The Mistral API key will be saved in your keychain.
+
+Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined.
+
+#### Mistral Custom Models {#mistral-custom-models}
+
+The Zed Assistant comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). All the default models support tool use. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`:
+
+```json
+{
+  "language_models": {
+    "mistral": {
+      "api_url": "https://api.mistral.ai/v1",
+      "available_models": [
+        {
+          "name": "mistral-tiny-latest",
+          "display_name": "Mistral Tiny",
+          "max_tokens": 32000,
+          "max_output_tokens": 4096,
+          "max_completion_tokens": 1024,
+          "supports_tools": true
+        }
+      ]
+    }
+  }
+}
+```
+
+Custom models will be listed in the model dropdown in the assistant panel.
+
 ### Ollama {#ollama}

 > ✅ Supports tool use