language_models: Add support for tool use to LM Studio provider (#30589)

Closes #30004

**Quick demo:**


https://github.com/user-attachments/assets/0ac93851-81d7-4128-a34b-1f3ae4bcff6d

**Additional notes:**

I've tried to stick to existing code in OpenAI provider as much as
possible without changing much to keep the diff small.

This PR is done in collaboration with @yagil from LM Studio. We agreed
upon the format in which LM Studio will return information about tool
use support for the model in the upcoming version. As of current stable
version nothing is going to change for the users, but once they update
to a newer LM Studio tool use gets automatically enabled for them. I
think this is much better UX then defaulting to true right now.


Release Notes:

- Added support for tool calls to LM Studio provider

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Fedor Nezhivoi 2025-05-26 18:54:17 +07:00 committed by GitHub
parent 6363fdab88
commit 998542b048
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 320 additions and 120 deletions

View file

@ -383,7 +383,9 @@ impl AssistantSettingsContent {
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(&model, None, None)),
default_model: Some(lmstudio::Model::new(
&model, None, None, false,
)),
api_url,
});
}

View file

@ -1,10 +1,13 @@
use anyhow::{Result, anyhow};
use collections::HashMap;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolChoice,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
StopReason, WrappedTextContent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@ -12,12 +15,14 @@ use language_model::{
LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model,
stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;
@ -40,12 +45,10 @@ pub struct LmStudioSettings {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
pub name: String,
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The model's context window size.
pub max_tokens: usize,
pub supports_tool_calls: bool,
}
pub struct LmStudioLanguageModelProvider {
@ -77,7 +80,14 @@ impl State {
let mut models: Vec<lmstudio::Model> = models
.into_iter()
.filter(|model| model.r#type != ModelType::Embeddings)
.map(|model| lmstudio::Model::new(&model.id, None, None))
.map(|model| {
lmstudio::Model::new(
&model.id,
None,
None,
model.capabilities.supports_tool_calls(),
)
})
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
@ -156,12 +166,16 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
IconName::AiLmStudio
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.provided_models(cx).into_iter().next()
fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
// We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
// In a constrained environment where user might not have enough resources it'll be a bad UX to select something
// to load by default.
None
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.default_model(cx)
fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
// See explanation for default_model.
None
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@ -184,6 +198,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
supports_tool_calls: model.supports_tool_calls,
},
);
}
@ -237,31 +252,117 @@ pub struct LmStudioLanguageModel {
impl LmStudioLanguageModel {
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
let mut messages = Vec::new();
for message in request.messages {
for content in message.content {
match content {
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
.push(match message.role {
Role::User => ChatMessage::User { content: text },
Role::Assistant => ChatMessage::Assistant {
content: Some(text),
tool_calls: Vec::new(),
},
Role::System => ChatMessage::System { content: text },
}),
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {}
MessageContent::ToolUse(tool_use) => {
let tool_call = lmstudio::ToolCall {
id: tool_use.id.to_string(),
content: lmstudio::ToolCallContent::Function {
function: lmstudio::FunctionContent {
name: tool_use.name.to_string(),
arguments: serde_json::to_string(&tool_use.input)
.unwrap_or_default(),
},
},
};
if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
messages.last_mut()
{
tool_calls.push(tool_call);
} else {
messages.push(lmstudio::ChatMessage::Assistant {
content: None,
tool_calls: vec![tool_call],
});
}
}
MessageContent::ToolResult(tool_result) => {
match &tool_result.content {
LanguageModelToolResultContent::Text(text)
| LanguageModelToolResultContent::WrappedText(WrappedTextContent {
text,
..
}) => {
messages.push(lmstudio::ChatMessage::Tool {
content: text.to_string(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
LanguageModelToolResultContent::Image(_) => {
// no support for images for now
}
};
}
}
}
}
ChatCompletionRequest {
model: self.model.name.clone(),
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.string_contents(),
},
Role::Assistant => ChatMessage::Assistant {
content: Some(msg.string_contents()),
tool_calls: None,
},
Role::System => ChatMessage::System {
content: msg.string_contents(),
},
})
.collect(),
messages,
stream: true,
max_tokens: Some(-1),
stop: Some(request.stop),
temperature: request.temperature.or(Some(0.0)),
tools: vec![],
// In LM Studio you can configure specific settings you'd like to use for your model.
// For example Qwen3 is recommended to be used with 0.7 temperature.
// It would be a bad UX to silently override these settings from Zed, so we pass no temperature as a default.
temperature: request.temperature.or(None),
tools: request
.tools
.into_iter()
.map(|tool| lmstudio::ToolDefinition::Function {
function: lmstudio::FunctionDefinition {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
})
.collect(),
tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
}),
}
}
fn stream_completion(
&self,
request: ChatCompletionRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
let response = request.await?;
Ok(response)
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
impl LanguageModel for LmStudioLanguageModel {
@ -282,17 +383,22 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn supports_tools(&self) -> bool {
false
self.model.supports_tool_calls()
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
self.supports_tools()
&& match choice {
LanguageModelToolChoice::Auto => true,
LanguageModelToolChoice::Any => true,
LanguageModelToolChoice::None => true,
}
}
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
}
fn telemetry_id(&self) -> String {
format!("lmstudio/{}", self.model.id())
}
@ -328,85 +434,126 @@ impl LanguageModel for LmStudioLanguageModel {
>,
> {
let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
// Create a stream mapper to handle content across multiple deltas
let stream_mapper = LmStudioStreamMapper::new();
let stream = response
.map(move |response| {
response.and_then(|fragment| stream_mapper.process_fragment(fragment))
})
.filter_map(|result| async move {
match result {
Ok(Some(content)) => Some(Ok(content)),
Ok(None) => None,
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
});
let completions = self.stream_completion(request, cx);
async move {
Ok(future
.await?
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
let mapper = LmStudioEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed()
}
}
// This will be more useful when we implement tool calling. Currently keeping it empty.
struct LmStudioStreamMapper {}
struct LmStudioEventMapper {
tool_calls_by_index: HashMap<usize, RawToolCall>,
}
impl LmStudioStreamMapper {
impl LmStudioEventMapper {
fn new() -> Self {
Self {}
Self {
tool_calls_by_index: HashMap::default(),
}
}
fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
// Most of the time, there will be only one choice
let Some(choice) = fragment.choices.first() else {
return Ok(None);
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
})
})
}
pub fn map_event(
&mut self,
event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices"
)))];
};
// Extract the delta content
if let Ok(delta) =
serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
{
if let Some(content) = delta.content {
if !content.is_empty() {
return Ok(Some(content));
let mut events = Vec::new();
if let Some(content) = choice.delta.content {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
if let Some(tool_calls) = choice.delta.tool_calls {
for tool_call in tool_calls {
let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
if let Some(tool_id) = tool_call.id {
entry.id = tool_id;
}
if let Some(function) = tool_call.function {
if let Some(name) = function.name {
// At the time of writing this code LM Studio (0.3.15) is incompatible with the OpenAI API:
// 1. It sends function name in the first chunk
// 2. It sends empty string in the function name field in all subsequent chunks for arguments
// According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
// function name field should be sent only inside the first chunk.
if !name.is_empty() {
entry.name = name;
}
}
if let Some(arguments) = function.arguments {
entry.arguments.push_str(&arguments);
}
}
}
}
// If there's a finish_reason, we're done
if choice.finish_reason.is_some() {
return Ok(None);
match choice.finish_reason.as_deref() {
Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
Some("tool_calls") => {
events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
match serde_json::Value::from_str(&tool_call.arguments) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments,
},
)),
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
}),
}
}));
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
Some(stop_reason) => {
log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
None => {}
}
Ok(None)
events
}
}
#[derive(Default)]
struct RawToolCall {
id: String,
name: String,
arguments: String,
}
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,

View file

@ -2,7 +2,7 @@ use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::{Value, value::RawValue};
use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
@ -47,14 +47,21 @@ pub struct Model {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub supports_tool_calls: bool,
}
impl Model {
pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
pub fn new(
name: &str,
display_name: Option<&str>,
max_tokens: Option<usize>,
supports_tool_calls: bool,
) -> Self {
Self {
name: name.to_owned(),
display_name: display_name.map(|s| s.to_owned()),
max_tokens: max_tokens.unwrap_or(2048),
supports_tool_calls,
}
}
@ -69,15 +76,43 @@ impl Model {
pub fn max_token_count(&self) -> usize {
self.max_tokens
}
pub fn supports_tool_calls(&self) -> bool {
self.supports_tool_calls
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Auto,
Required,
None,
Other(ToolDefinition),
}
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
#[allow(dead_code)]
Function { function: FunctionDefinition },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<LmStudioToolCall>>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
User {
content: String,
@ -85,31 +120,29 @@ pub enum ChatMessage {
System {
content: String,
},
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum LmStudioToolCall {
Function(LmStudioFunctionCall),
}
#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioFunctionCall {
pub name: String,
pub arguments: Box<RawValue>,
Tool {
content: String,
tool_call_id: String,
},
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LmStudioFunctionTool {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Value>,
pub struct ToolCall {
pub id: String,
#[serde(flatten)]
pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum LmStudioTool {
Function { function: LmStudioFunctionTool },
pub enum ToolCallContent {
Function { function: FunctionContent },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
pub name: String,
pub arguments: String,
}
#[derive(Serialize, Debug)]
@ -117,10 +150,16 @@ pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
pub tools: Vec<LmStudioTool>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
#[derive(Serialize, Deserialize, Debug)]
@ -135,8 +174,7 @@ pub struct ChatResponse {
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
#[serde(default)]
pub delta: serde_json::Value,
pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}
@ -164,6 +202,16 @@ pub struct Usage {
pub total_tokens: u32,
}
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);
impl Capabilities {
pub fn supports_tool_calls(&self) -> bool {
self.0.iter().any(|cap| cap == "tool_use")
}
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
@ -175,16 +223,17 @@ pub enum ResponseStreamResult {
pub struct ResponseStreamEvent {
pub created: u32,
pub model: String,
pub object: String,
pub choices: Vec<ChoiceDelta>,
pub usage: Option<Usage>,
}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct ListModelsResponse {
pub data: Vec<ModelEntry>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
pub id: String,
pub object: String,
@ -196,6 +245,8 @@ pub struct ModelEntry {
pub state: ModelState,
pub max_context_length: Option<u32>,
pub loaded_context_length: Option<u32>,
#[serde(default)]
pub capabilities: Capabilities,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
@ -265,7 +316,7 @@ pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,
request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ChatResponse>>> {
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
let uri = format!("{api_url}/chat/completions");
let request_builder = http::Request::builder()
.method(Method::POST)