language_models: Add tool use support for Mistral models (#29994)
Closes https://github.com/zed-industries/zed/issues/29855

Implements tool use handling in the Mistral provider, including mapping tool call events and updating request construction. Adds support for `tool_choice` and `parallel_tool_calls` in Mistral API requests. This works fine with all the existing models. Nothing else was touched, but in the future, fetching models via the Mistral models API and deducing tool call support, parallel tool call support, etc. should be done from the model data in the API response.

<img width="547" alt="Screenshot 2025-05-06 at 4 52 37 PM" src="https://github.com/user-attachments/assets/4c08b544-1174-40cc-a40d-522989953448" />

Tasks:

- [x] Add tool call support
- [x] Auto-fetch models using the Mistral API
- [x] Add tests for the mistral crate
- [x] Fix Mistral configuration for LLM providers

Release Notes:

- agent: Add tool call support for existing Mistral models

---------

Co-authored-by: Peter Tripp <peter@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
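For context, a request built by the updated `into_mistral` path would look roughly like this on the wire. This is an illustrative sketch, not output captured from Zed: the field names mirror the `mistral::Request` struct in this diff and the Mistral chat completions API, while the model choice, the hypothetical `get_weather` tool, and all values are made up for the example.

```json
{
  "model": "mistral-medium-latest",
  "stream": true,
  "max_tokens": 4096,
  "temperature": 0.7,
  "messages": [
    { "role": "user", "content": "What's the weather like in Paris today?" }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
          "type": "object",
          "properties": { "city": { "type": "string" } },
          "required": ["city"]
        }
      }
    }
  ],
  "tool_choice": "auto",
  "parallel_tool_calls": false
}
```

When the model answers with a tool call, the streamed deltas are accumulated by `MistralEventMapper` and emitted as a single `ToolUse` event once the `tool_calls` finish reason arrives.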
Parent: 26a8cac0d8
Commit: 926f377c6c

6 changed files with 347 additions and 50 deletions
Cargo.lock (generated)
@@ -546,6 +546,7 @@ dependencies = [
"language_model",
"lmstudio",
"log",
"mistral",
"ollama",
"open_ai",
"paths",

@@ -23,6 +23,7 @@ log.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
deepseek = { workspace = true, features = ["schemars"] }
mistral = { workspace = true, features = ["schemars"] }
schemars.workspace = true
serde.workspace = true
settings.workspace = true

@@ -10,6 +10,7 @@ use deepseek::Model as DeepseekModel;
use gpui::{App, Pixels, SharedString};
use language_model::{CloudModel, LanguageModel};
use lmstudio::Model as LmStudioModel;
use mistral::Model as MistralModel;
use ollama::Model as OllamaModel;
use schemars::{JsonSchema, schema::Schema};
use serde::{Deserialize, Serialize};

@@ -71,6 +72,11 @@ pub enum AssistantProviderContentV1 {
default_model: Option<DeepseekModel>,
api_url: Option<String>,
},
#[serde(rename = "mistral")]
Mistral {
default_model: Option<MistralModel>,
api_url: Option<String>,
},
}
#[derive(Default, Clone, Debug)]

@@ -249,6 +255,12 @@ impl AssistantSettingsContent {
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Mistral { default_model, .. } => {
default_model.map(|model| LanguageModelSelection {
provider: "mistral".into(),
model: model.id().to_string(),
})
}
}),
inline_assistant_model: None,
commit_message_model: None,

@@ -700,6 +712,7 @@ impl JsonSchema for LanguageModelProviderSetting {
"zed.dev".into(),
"copilot_chat".into(),
"deepseek".into(),
"mistral".into(),
]),
..Default::default()
}

@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,

@@ -11,13 +12,13 @@ use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
};
use futures::stream::BoxStream;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::str::FromStr;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;

@@ -26,6 +27,9 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
use std::collections::HashMap;
use std::pin::Pin;
const PROVIDER_ID: &str = "mistral";
const PROVIDER_NAME: &str = "Mistral";

@@ -43,6 +47,7 @@ pub struct AvailableModel {
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
pub supports_tools: Option<bool>,
}
pub struct MistralLanguageModelProvider {

@@ -209,6 +214,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
supports_tools: model.supports_tools,
},
);
}

@@ -300,14 +306,14 @@ impl LanguageModel for MistralLanguageModel {
}
fn supports_tools(&self) -> bool {
false
}
fn supports_images(&self) -> bool {
false
self.model.supports_tools()
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
self.model.supports_tools()
}
fn supports_images(&self) -> bool {
false
}

@@ -368,26 +374,8 @@ impl LanguageModel for MistralLanguageModel {
async move {
let stream = stream.await?;
Ok(stream
.map(|result| {
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
let mapper = MistralEventMapper::new();
Ok(mapper.map_stream(stream).boxed())
}
.boxed()
}

@@ -398,33 +386,87 @@ pub fn into_mistral(
model: String,
max_output_tokens: Option<u32>,
) -> mistral::Request {
let len = request.messages.len();
let merged_messages =
request
.messages
.into_iter()
.fold(Vec::with_capacity(len), |mut acc, msg| {
let role = msg.role;
let content = msg.string_contents();
let stream = true;
acc.push(match role {
Role::User => mistral::RequestMessage::User { content },
let mut messages = Vec::new();
for message in request.messages {
for content in message.content {
match content {
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
.push(match message.role {
Role::User => mistral::RequestMessage::User { content: text },
Role::Assistant => mistral::RequestMessage::Assistant {
content: Some(content),
content: Some(text),
tool_calls: Vec::new(),
},
Role::System => mistral::RequestMessage::System { content },
Role::System => mistral::RequestMessage::System { content: text },
}),
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {}
MessageContent::ToolUse(tool_use) => {
let tool_call = mistral::ToolCall {
id: tool_use.id.to_string(),
content: mistral::ToolCallContent::Function {
function: mistral::FunctionContent {
name: tool_use.name.to_string(),
arguments: serde_json::to_string(&tool_use.input)
.unwrap_or_default(),
},
},
};
if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
messages.last_mut()
{
tool_calls.push(tool_call);
} else {
messages.push(mistral::RequestMessage::Assistant {
content: None,
tool_calls: vec![tool_call],
});
acc
}
}
MessageContent::ToolResult(tool_result) => {
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string(),
LanguageModelToolResultContent::Image(_) => {
// TODO: Mistral image support
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
}
};
messages.push(mistral::RequestMessage::Tool {
content,
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
}
}
}
mistral::Request {
model,
messages: merged_messages,
stream: true,
messages,
stream,
max_tokens: max_output_tokens,
temperature: request.temperature,
response_format: None,
tool_choice: match request.tool_choice {
Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
Some(mistral::ToolChoice::Auto)
}
Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
Some(mistral::ToolChoice::Any)
}
Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
_ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
_ => None,
},
parallel_tool_calls: if !request.tools.is_empty() {
Some(false)
} else {
None
},
tools: request
.tools
.into_iter()

@@ -439,6 +481,127 @@ pub fn into_mistral(
}
}
pub struct MistralEventMapper {
tool_calls_by_index: HashMap<usize, RawToolCall>,
}
impl MistralEventMapper {
pub fn new() -> Self {
Self {
tool_calls_by_index: HashMap::default(),
}
}
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
})
})
}
pub fn map_event(
&mut self,
event: mistral::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices"
)))];
};
let mut events = Vec::new();
if let Some(content) = choice.delta.content.clone() {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
for tool_call in tool_calls {
let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
if let Some(tool_id) = tool_call.id.clone() {
entry.id = tool_id;
}
if let Some(function) = tool_call.function.as_ref() {
if let Some(name) = function.name.clone() {
entry.name = name;
}
if let Some(arguments) = function.arguments.clone() {
entry.arguments.push_str(&arguments);
}
}
}
}
if let Some(finish_reason) = choice.finish_reason.as_deref() {
match finish_reason {
"stop" => {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
"tool_calls" => {
events.extend(self.process_tool_calls());
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
unexpected => {
log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
}
}
events
}
fn process_tool_calls(
&mut self,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let mut results = Vec::new();
for (_, tool_call) in self.tool_calls_by_index.drain() {
if tool_call.id.is_empty() || tool_call.name.is_empty() {
results.push(Err(LanguageModelCompletionError::Other(anyhow!(
"Received incomplete tool call: missing id or name"
))));
continue;
}
match serde_json::Value::from_str(&tool_call.arguments) {
Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments,
},
))),
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})),
}
}
results
}
}
#[derive(Default)]
struct RawToolCall {
id: String,
name: String,
arguments: String,
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,

@@ -623,3 +786,65 @@ impl Render for ConfigurationView {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use language_model;
#[test]
fn test_into_mistral_conversion() {
let request = language_model::LanguageModelRequest {
messages: vec![
language_model::LanguageModelRequestMessage {
role: language_model::Role::System,
content: vec![language_model::MessageContent::Text(
"You are a helpful assistant.".to_string(),
)],
cache: false,
},
language_model::LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![language_model::MessageContent::Text(
"Hello, how are you?".to_string(),
)],
cache: false,
},
],
temperature: Some(0.7),
tools: Vec::new(),
tool_choice: None,
thread_id: None,
prompt_id: None,
mode: None,
stop: Vec::new(),
};
let model_name = "mistral-medium-latest".to_string();
let max_output_tokens = Some(1000);
let mistral_request = into_mistral(request, model_name, max_output_tokens);
assert_eq!(mistral_request.model, "mistral-medium-latest");
assert_eq!(mistral_request.temperature, Some(0.7));
assert_eq!(mistral_request.max_tokens, Some(1000));
assert!(mistral_request.stream);
assert!(mistral_request.tools.is_empty());
assert!(mistral_request.tool_choice.is_none());
assert_eq!(mistral_request.messages.len(), 2);
match &mistral_request.messages[0] {
mistral::RequestMessage::System { content } => {
assert_eq!(content, "You are a helpful assistant.");
}
_ => panic!("Expected System message"),
}
match &mistral_request.messages[1] {
mistral::RequestMessage::User { content } => {
assert_eq!(content, "Hello, how are you?");
}
_ => panic!("Expected User message"),
}
}
}

@@ -67,6 +67,7 @@ pub enum Model {
max_tokens: usize,
max_output_tokens: Option<u32>,
max_completion_tokens: Option<u32>,
supports_tools: Option<bool>,
},
}

@@ -133,6 +134,18 @@ impl Model {
_ => None,
}
}
pub fn supports_tools(&self) -> bool {
match self {
Self::CodestralLatest
| Self::MistralLargeLatest
| Self::MistralMediumLatest
| Self::MistralSmallLatest
| Self::OpenMistralNemo
| Self::OpenCodestralMamba => true,
Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
}
}
}
#[derive(Debug, Serialize, Deserialize)]

@@ -146,6 +159,10 @@ pub struct Request {
pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}

@@ -190,12 +207,13 @@ pub enum Prediction {
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
Auto,
Required,
None,
Other(ToolDefinition),
Any,
Function(ToolDefinition),
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

@@ -14,6 +14,7 @@ Here's an overview of the supported providers and tool call support:
| [Anthropic](#anthropic) | ✅ |
| [GitHub Copilot Chat](#github-copilot-chat) | In Some Cases |
| [Google AI](#google-ai) | ✅ |
| [Mistral](#mistral) | ✅ |
| [Ollama](#ollama) | ✅ |
| [OpenAI](#openai) | ✅ |
| [DeepSeek](#deepseek) | 🚫 |

@@ -128,6 +129,44 @@ By default Zed will use `stable` versions of models, but you can use specific ve

Custom models will be listed in the model dropdown in the Agent Panel.

### Mistral {#mistral}

> 🔨 Supports tool use

1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/)
2. Open the configuration view (`assistant: show configuration`) and navigate to the Mistral section
3. Enter your Mistral API key

The Mistral API key will be saved in your keychain.

Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined.

#### Mistral Custom Models {#mistral-custom-models}

The Zed Assistant comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). All the default models support tool use. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`:

```json
{
  "language_models": {
    "mistral": {
      "api_url": "https://api.mistral.ai/v1",
      "available_models": [
        {
          "name": "mistral-tiny-latest",
          "display_name": "Mistral Tiny",
          "max_tokens": 32000,
          "max_output_tokens": 4096,
          "max_completion_tokens": 1024,
          "supports_tools": true
        }
      ]
    }
  }
}
```

Custom models will be listed in the model dropdown in the assistant panel.

### Ollama {#ollama}

> ✅ Supports tool use