language_models: Add tool use support for Mistral models (#29994)
Closes https://github.com/zed-industries/zed/issues/29855

Implements tool use handling in the Mistral provider, including mapping tool call events and updating request construction. Adds support for `tool_choice` and `parallel_tool_calls` in Mistral API requests. This works with all of the existing models. Nothing else was touched, but in the future, models should be fetched via Mistral's models API, with tool call support, parallel tool call support, etc. deduced from the model data in the API response.

<img width="547" alt="Screenshot 2025-05-06 at 4 52 37 PM" src="https://github.com/user-attachments/assets/4c08b544-1174-40cc-a40d-522989953448" />

Tasks:

- [x] Add tool call support
- [x] Auto-fetch models using the Mistral API
- [x] Add tests for the mistral crate
- [x] Fix Mistral configuration for LLM providers

Release Notes:

- agent: Add tool call support for existing Mistral models

Co-authored-by: Peter Tripp <peter@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
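For reviewers, a minimal sketch (not part of the change) of how `into_mistral` resolves the new tool parameters when tools are present but no explicit choice is set. The `LanguageModelRequestTool` literal and its fields are assumptions about the `language_model` request type, not taken from this diff:

```rust
// Sketch: tools present, no explicit tool_choice.
let request = language_model::LanguageModelRequest {
    messages: Vec::new(),
    temperature: None,
    // Assumed shape of the tool type; only `request.tools.is_empty()`
    // matters for the tool_choice logic.
    tools: vec![language_model::LanguageModelRequestTool {
        name: "get_weather".to_string(),
        description: "Look up the weather for a city".to_string(),
        input_schema: serde_json::json!({
            "type": "object",
            "properties": { "city": { "type": "string" } }
        }),
    }],
    tool_choice: None,
    thread_id: None,
    prompt_id: None,
    mode: None,
    stop: Vec::new(),
};

let mistral_request = into_mistral(request, "mistral-medium-latest".into(), None);

// With tools and no explicit choice, the request falls through to Auto,
// and parallel tool calls stay disabled until per-model data is available.
assert!(matches!(
    mistral_request.tool_choice,
    Some(mistral::ToolChoice::Auto)
));
assert_eq!(mistral_request.parallel_tool_calls, Some(false));
```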
This commit is contained in:
parent 26a8cac0d8
commit 926f377c6c

6 changed files with 347 additions and 50 deletions
@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -11,13 +12,13 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, RateLimiter, Role,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
-use futures::stream::BoxStream;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::str::FromStr;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
@@ -26,6 +27,9 @@ use util::ResultExt;
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
+use std::collections::HashMap;
+use std::pin::Pin;
+
 const PROVIDER_ID: &str = "mistral";
 const PROVIDER_NAME: &str = "Mistral";
 
@@ -43,6 +47,7 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     pub max_output_tokens: Option<u32>,
     pub max_completion_tokens: Option<u32>,
+    pub supports_tools: Option<bool>,
 }
 
 pub struct MistralLanguageModelProvider {
@@ -209,6 +214,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
                 max_tokens: model.max_tokens,
                 max_output_tokens: model.max_output_tokens,
                 max_completion_tokens: model.max_completion_tokens,
+                supports_tools: model.supports_tools,
             },
         );
     }
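With `supports_tools` threaded through `AvailableModel` and into the registry, a user-configured model can opt into tool calls explicitly. A hedged example of such an entry; the `name` and `display_name` fields are assumed from the existing struct definition, which this diff leaves unchanged:

```rust
// Sketch only: `name` and `display_name` are not shown in this diff and are
// assumed; the numeric limits are placeholder values.
let model = AvailableModel {
    name: "mistral-large-latest".to_string(),
    display_name: Some("Mistral Large".to_string()),
    max_tokens: 131072,
    max_output_tokens: Some(4096),
    max_completion_tokens: None,
    supports_tools: Some(true), // opts this model into tool calls
};
```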
@@ -300,14 +306,14 @@ impl LanguageModel for MistralLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
-    }
-
-    fn supports_images(&self) -> bool {
-        false
+        self.model.supports_tools()
+    }
+
+    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+        self.model.supports_tools()
     }
 
     fn supports_images(&self) -> bool {
         false
     }
 
@@ -368,26 +374,8 @@ impl LanguageModel for MistralLanguageModel {
 
         async move {
             let stream = stream.await?;
-            Ok(stream
-                .map(|result| {
-                    result
-                        .and_then(|response| {
-                            response
-                                .choices
-                                .first()
-                                .ok_or_else(|| anyhow!("Empty response"))
-                                .map(|choice| {
-                                    choice
-                                        .delta
-                                        .content
-                                        .clone()
-                                        .unwrap_or_default()
-                                })
-                                .map(LanguageModelCompletionEvent::Text)
-                        })
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = MistralEventMapper::new();
+            Ok(mapper.map_stream(stream).boxed())
         }
         .boxed()
     }
@@ -398,33 +386,87 @@ pub fn into_mistral(
     model: String,
     max_output_tokens: Option<u32>,
 ) -> mistral::Request {
-    let len = request.messages.len();
-    let merged_messages =
-        request
-            .messages
-            .into_iter()
-            .fold(Vec::with_capacity(len), |mut acc, msg| {
-                let role = msg.role;
-                let content = msg.string_contents();
+    let stream = true;
 
-                acc.push(match role {
-                    Role::User => mistral::RequestMessage::User { content },
-                    Role::Assistant => mistral::RequestMessage::Assistant {
-                        content: Some(content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => mistral::RequestMessage::System { content },
-                });
-                acc
-            });
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                    .push(match message.role {
+                        Role::User => mistral::RequestMessage::User { content: text },
+                        Role::Assistant => mistral::RequestMessage::Assistant {
+                            content: Some(text),
+                            tool_calls: Vec::new(),
+                        },
+                        Role::System => mistral::RequestMessage::System { content: text },
+                    }),
+                MessageContent::RedactedThinking(_) => {}
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = mistral::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: mistral::ToolCallContent::Function {
+                            function: mistral::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
+                        messages.last_mut()
+                    {
+                        tool_calls.push(tool_call);
+                    } else {
+                        messages.push(mistral::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    let content = match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => text.to_string(),
+                        LanguageModelToolResultContent::Image(_) => {
+                            // TODO: Mistral image support
+                            "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
+                        }
+                    };
+
+                    messages.push(mistral::RequestMessage::Tool {
+                        content,
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }
 
     mistral::Request {
         model,
-        messages: merged_messages,
-        stream: true,
+        messages,
+        stream,
         max_tokens: max_output_tokens,
         temperature: request.temperature,
         response_format: None,
+        tool_choice: match request.tool_choice {
+            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
+                Some(mistral::ToolChoice::Auto)
+            }
+            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
+                Some(mistral::ToolChoice::Any)
+            }
+            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
+            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
+            _ => None,
+        },
+        parallel_tool_calls: if !request.tools.is_empty() {
+            Some(false)
+        } else {
+            None
+        },
         tools: request
             .tools
             .into_iter()
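One subtlety in the conversion above: consecutive `ToolUse` parts attach to the preceding assistant message instead of creating one assistant message per call. A self-contained sketch of that grouping rule, using only types that appear in this diff (the literal values are invented):

```rust
// An assistant message is already in the transcript...
let mut messages = vec![mistral::RequestMessage::Assistant {
    content: Some("Let me check.".to_string()),
    tool_calls: Vec::new(),
}];

// ...so a subsequent ToolUse part is appended to it rather than
// pushed as a separate assistant message.
let tool_call = mistral::ToolCall {
    id: "call_1".to_string(),
    content: mistral::ToolCallContent::Function {
        function: mistral::FunctionContent {
            name: "get_weather".to_string(),
            arguments: r#"{"city":"Berlin"}"#.to_string(),
        },
    },
};

if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) = messages.last_mut() {
    tool_calls.push(tool_call);
}

// The tool's answer then travels back as a `tool` role message keyed by id.
messages.push(mistral::RequestMessage::Tool {
    content: "12°C, cloudy".to_string(),
    tool_call_id: "call_1".to_string(),
});
```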
@@ -439,6 +481,127 @@ pub fn into_mistral(
     }
 }
 
+pub struct MistralEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl MistralEventMapper {
+    pub fn new() -> Self {
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
+    ) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: mistral::StreamResponse,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.first() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
+        };
+
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content.clone() {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id.clone() {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function.as_ref() {
+                    if let Some(name) = function.name.clone() {
+                        entry.name = name;
+                    }
+
+                    if let Some(arguments) = function.arguments.clone() {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
+
+        if let Some(finish_reason) = choice.finish_reason.as_deref() {
+            match finish_reason {
+                "stop" => {
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+                "tool_calls" => {
+                    events.extend(self.process_tool_calls());
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+                }
+                unexpected => {
+                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+            }
+        }
+
+        events
+    }
+
+    fn process_tool_calls(
+        &mut self,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let mut results = Vec::new();
+
+        for (_, tool_call) in self.tool_calls_by_index.drain() {
+            if tool_call.id.is_empty() || tool_call.name.is_empty() {
+                results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+                    "Received incomplete tool call: missing id or name"
+                ))));
+                continue;
+            }
+
+            match serde_json::Value::from_str(&tool_call.arguments) {
+                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                    LanguageModelToolUse {
+                        id: tool_call.id.into(),
+                        name: tool_call.name.into(),
+                        is_input_complete: true,
+                        input,
+                        raw_input: tool_call.arguments,
+                    },
+                ))),
+                Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
+                    id: tool_call.id.into(),
+                    tool_name: tool_call.name.into(),
+                    raw_input: tool_call.arguments.into(),
+                    json_parse_error: error.to_string(),
+                })),
+            }
+        }
+
+        results
+    }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
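The mapper buffers tool-call fragments by stream index and parses the accumulated arguments only when the stream reports `finish_reason == "tool_calls"`. A hedged sketch of that accumulation contract, with invented delta contents:

```rust
use std::collections::HashMap;
use std::str::FromStr;

// RawToolCall is the private accumulator defined in this diff.
let mut tool_calls_by_index: HashMap<usize, RawToolCall> = HashMap::default();

// The first delta typically carries the id, the name, and the start of
// the JSON arguments...
let entry = tool_calls_by_index.entry(0).or_default();
entry.id = "call_1".to_string();
entry.name = "get_weather".to_string();
entry.arguments.push_str(r#"{"city":"#);

// ...and later deltas for the same index complete the arguments.
let entry = tool_calls_by_index.entry(0).or_default();
entry.arguments.push_str(r#""Berlin"}"#);

// Only at finish_reason == "tool_calls" do the fragments get parsed.
let raw = &tool_calls_by_index[&0];
let input = serde_json::Value::from_str(&raw.arguments).expect("fragments now form valid JSON");
assert_eq!(input["city"], "Berlin");
```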
@@ -623,3 +786,65 @@ impl Render for ConfigurationView {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use language_model;
+
+    #[test]
+    fn test_into_mistral_conversion() {
+        let request = language_model::LanguageModelRequest {
+            messages: vec![
+                language_model::LanguageModelRequestMessage {
+                    role: language_model::Role::System,
+                    content: vec![language_model::MessageContent::Text(
+                        "You are a helpful assistant.".to_string(),
+                    )],
+                    cache: false,
+                },
+                language_model::LanguageModelRequestMessage {
+                    role: language_model::Role::User,
+                    content: vec![language_model::MessageContent::Text(
+                        "Hello, how are you?".to_string(),
+                    )],
+                    cache: false,
+                },
+            ],
+            temperature: Some(0.7),
+            tools: Vec::new(),
+            tool_choice: None,
+            thread_id: None,
+            prompt_id: None,
+            mode: None,
+            stop: Vec::new(),
+        };
+
+        let model_name = "mistral-medium-latest".to_string();
+        let max_output_tokens = Some(1000);
+        let mistral_request = into_mistral(request, model_name, max_output_tokens);
+
+        assert_eq!(mistral_request.model, "mistral-medium-latest");
+        assert_eq!(mistral_request.temperature, Some(0.7));
+        assert_eq!(mistral_request.max_tokens, Some(1000));
+        assert!(mistral_request.stream);
+        assert!(mistral_request.tools.is_empty());
+        assert!(mistral_request.tool_choice.is_none());
+
+        assert_eq!(mistral_request.messages.len(), 2);
+
+        match &mistral_request.messages[0] {
+            mistral::RequestMessage::System { content } => {
+                assert_eq!(content, "You are a helpful assistant.");
+            }
+            _ => panic!("Expected System message"),
+        }
+
+        match &mistral_request.messages[1] {
+            mistral::RequestMessage::User { content } => {
+                assert_eq!(content, "Hello, how are you?");
+            }
+            _ => panic!("Expected User message"),
+        }
+    }
+}