language_models: Add support for tool use to LM Studio provider (#30589)
Closes #30004

**Quick demo:** https://github.com/user-attachments/assets/0ac93851-81d7-4128-a34b-1f3ae4bcff6d

**Additional notes:** I've tried to stick to the existing code in the OpenAI provider as much as possible, without changing much, to keep the diff small. This PR was done in collaboration with @yagil from LM Studio. We agreed on the format in which LM Studio will report tool use support for a model in an upcoming version. With the current stable version nothing changes for users, but once they update to a newer LM Studio, tool use is enabled for them automatically. I think this is much better UX than defaulting to true right now.

Release Notes:

- Added support for tool calls to the LM Studio provider

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
parent 6363fdab88
commit 998542b048
3 changed files with 320 additions and 120 deletions
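For context, the diff below gates tool use on a per-model capability flag reported by LM Studio's model listing (see `model.capabilities.supports_tool_calls()` in the `impl State` hunk). The following is a minimal, hypothetical sketch of how such a flag could be deserialized on the client side; the `capabilities` name mirrors the diff, but the wire format and the `"tool_use"` string are assumptions for illustration, not the agreed-upon schema:

```rust
// Hypothetical sketch, not the agreed-upon LM Studio schema: field names and
// the "tool_use" capability string are assumptions for illustration only.
// Requires the `serde` (with derive) and `serde_json` crates.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct ModelEntry {
    id: String,
    // Older servers omit the field entirely; default to "no capabilities".
    #[serde(default)]
    capabilities: Capabilities,
}

#[derive(Deserialize, Debug, Default)]
struct Capabilities(Vec<String>);

impl Capabilities {
    // Mirrors the `supports_tool_calls()` accessor used in the diff below.
    fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|capability| capability.as_str() == "tool_use")
    }
}

fn main() {
    // A model entry that advertises tool calling...
    let with_tools: ModelEntry =
        serde_json::from_str(r#"{ "id": "qwen2.5-7b-instruct", "capabilities": ["tool_use"] }"#)
            .unwrap();
    // ...and one from an older server that omits the field.
    let without_tools: ModelEntry = serde_json::from_str(r#"{ "id": "phi-4" }"#).unwrap();

    assert!(with_tools.capabilities.supports_tool_calls());
    assert!(!without_tools.capabilities.supports_tool_calls());
    println!("{} -> tools, {} -> no tools", with_tools.id, without_tools.id);
}
```

Under this scheme, models whose listing lacks the flag simply keep tool use disabled, which matches the behavior described above for current stable LM Studio versions.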
@@ -1,10 +1,13 @@
 use anyhow::{Result, anyhow};
+use collections::HashMap;
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelToolChoice,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    StopReason, WrappedTextContent,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -12,12 +15,14 @@ use language_model::{
     LanguageModelRequest, RateLimiter, Role,
 };
 use lmstudio::{
-    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
+    ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model,
     stream_chat_completion,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::{collections::BTreeMap, sync::Arc};
 use ui::{ButtonLike, Indicator, List, prelude::*};
 use util::ResultExt;
@@ -40,12 +45,10 @@ pub struct LmStudioSettings {
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 pub struct AvailableModel {
     /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
     pub name: String,
     /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
     pub display_name: Option<String>,
     /// The model's context window size.
     pub max_tokens: usize,
+    pub supports_tool_calls: bool,
 }
 
 pub struct LmStudioLanguageModelProvider {
@@ -77,7 +80,14 @@ impl State {
             let mut models: Vec<lmstudio::Model> = models
                 .into_iter()
                 .filter(|model| model.r#type != ModelType::Embeddings)
-                .map(|model| lmstudio::Model::new(&model.id, None, None))
+                .map(|model| {
+                    lmstudio::Model::new(
+                        &model.id,
+                        None,
+                        None,
+                        model.capabilities.supports_tool_calls(),
+                    )
+                })
                 .collect();
 
             models.sort_by(|a, b| a.name.cmp(&b.name));
@@ -156,12 +166,16 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
         IconName::AiLmStudio
     }
 
-    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        self.provided_models(cx).into_iter().next()
+    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
+        // We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
+        // In a constrained environment where user might not have enough resources it'll be a bad UX to select something
+        // to load by default.
+        None
     }
 
-    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        self.default_model(cx)
+    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
+        // See explanation for default_model.
+        None
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -184,6 +198,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
                     name: model.name.clone(),
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
+                    supports_tool_calls: model.supports_tool_calls,
                 },
             );
         }
@@ -237,31 +252,117 @@ pub struct LmStudioLanguageModel {
 
 impl LmStudioLanguageModel {
     fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+        let mut messages = Vec::new();
+
+        for message in request.messages {
+            for content in message.content {
+                match content {
+                    MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                        .push(match message.role {
+                            Role::User => ChatMessage::User { content: text },
+                            Role::Assistant => ChatMessage::Assistant {
+                                content: Some(text),
+                                tool_calls: Vec::new(),
+                            },
+                            Role::System => ChatMessage::System { content: text },
+                        }),
+                    MessageContent::RedactedThinking(_) => {}
+                    MessageContent::Image(_) => {}
+                    MessageContent::ToolUse(tool_use) => {
+                        let tool_call = lmstudio::ToolCall {
+                            id: tool_use.id.to_string(),
+                            content: lmstudio::ToolCallContent::Function {
+                                function: lmstudio::FunctionContent {
+                                    name: tool_use.name.to_string(),
+                                    arguments: serde_json::to_string(&tool_use.input)
+                                        .unwrap_or_default(),
+                                },
+                            },
+                        };
+
+                        if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
+                            messages.last_mut()
+                        {
+                            tool_calls.push(tool_call);
+                        } else {
+                            messages.push(lmstudio::ChatMessage::Assistant {
+                                content: None,
+                                tool_calls: vec![tool_call],
+                            });
+                        }
+                    }
+                    MessageContent::ToolResult(tool_result) => {
+                        match &tool_result.content {
+                            LanguageModelToolResultContent::Text(text)
+                            | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
+                                text,
+                                ..
+                            }) => {
+                                messages.push(lmstudio::ChatMessage::Tool {
+                                    content: text.to_string(),
+                                    tool_call_id: tool_result.tool_use_id.to_string(),
+                                });
+                            }
+                            LanguageModelToolResultContent::Image(_) => {
+                                // no support for images for now
+                            }
+                        };
+                    }
+                }
+            }
+        }
+
         ChatCompletionRequest {
             model: self.model.name.clone(),
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => ChatMessage::User {
-                        content: msg.string_contents(),
-                    },
-                    Role::Assistant => ChatMessage::Assistant {
-                        content: Some(msg.string_contents()),
-                        tool_calls: None,
-                    },
-                    Role::System => ChatMessage::System {
-                        content: msg.string_contents(),
-                    },
-                })
-                .collect(),
+            messages,
             stream: true,
             max_tokens: Some(-1),
             stop: Some(request.stop),
-            temperature: request.temperature.or(Some(0.0)),
-            tools: vec![],
+            // In LM Studio you can configure specific settings you'd like to use for your model.
+            // For example Qwen3 is recommended to be used with 0.7 temperature.
+            // It would be a bad UX to silently override these settings from Zed, so we pass no temperature as a default.
+            temperature: request.temperature.or(None),
+            tools: request
+                .tools
+                .into_iter()
+                .map(|tool| lmstudio::ToolDefinition::Function {
+                    function: lmstudio::FunctionDefinition {
+                        name: tool.name,
+                        description: Some(tool.description),
+                        parameters: Some(tool.input_schema),
+                    },
+                })
+                .collect(),
+            tool_choice: request.tool_choice.map(|choice| match choice {
+                LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
+                LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
+                LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
+            }),
         }
     }
+
+    fn stream_completion(
+        &self,
+        request: ChatCompletionRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
+    {
+        let http_client = self.http_client.clone();
+        let Ok(api_url) = cx.update(|cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+            settings.api_url.clone()
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        let future = self.request_limiter.stream(async move {
+            let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
+            let response = request.await?;
+            Ok(response)
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
+    }
 }
 
 impl LanguageModel for LmStudioLanguageModel {
@@ -282,17 +383,22 @@ impl LanguageModel for LmStudioLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        self.model.supports_tool_calls()
     }
 
+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        self.supports_tools()
+            && match choice {
+                LanguageModelToolChoice::Auto => true,
+                LanguageModelToolChoice::Any => true,
+                LanguageModelToolChoice::None => true,
+            }
+    }
+
     fn supports_images(&self) -> bool {
         false
     }
 
-    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
-        false
-    }
-
     fn telemetry_id(&self) -> String {
         format!("lmstudio/{}", self.model.id())
     }
@@ -328,85 +434,126 @@ impl LanguageModel for LmStudioLanguageModel {
         >,
     > {
         let request = self.to_lmstudio_request(request);
-
-        let http_client = self.http_client.clone();
-        let Ok(api_url) = cx.update(|cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
-            settings.api_url.clone()
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        let future = self.request_limiter.stream(async move {
-            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
-
-            // Create a stream mapper to handle content across multiple deltas
-            let stream_mapper = LmStudioStreamMapper::new();
-
-            let stream = response
-                .map(move |response| {
-                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
-                })
-                .filter_map(|result| async move {
-                    match result {
-                        Ok(Some(content)) => Some(Ok(content)),
-                        Ok(None) => None,
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-
-            Ok(stream)
-        });
-
+        let completions = self.stream_completion(request, cx);
         async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(LanguageModelCompletionEvent::Text)
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = LmStudioEventMapper::new();
+            Ok(mapper.map_stream(completions.await?).boxed())
         }
         .boxed()
     }
 }
 
-// This will be more useful when we implement tool calling. Currently keeping it empty.
-struct LmStudioStreamMapper {}
+struct LmStudioEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
 
-impl LmStudioStreamMapper {
+impl LmStudioEventMapper {
     fn new() -> Self {
-        Self {}
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
     }
 
-    fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
-        // Most of the time, there will be only one choice
-        let Some(choice) = fragment.choices.first() else {
-            return Ok(None);
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: ResponseStreamEvent,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.into_iter().next() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
         };
 
-        // Extract the delta content
-        if let Ok(delta) =
-            serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
-        {
-            if let Some(content) = delta.content {
-                if !content.is_empty() {
-                    return Ok(Some(content));
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
         }
 
+        if let Some(tool_calls) = choice.delta.tool_calls {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function {
+                    if let Some(name) = function.name {
+                        // At the time of writing this code LM Studio (0.3.15) is incompatible with the OpenAI API:
+                        // 1. It sends function name in the first chunk
+                        // 2. It sends empty string in the function name field in all subsequent chunks for arguments
+                        // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
+                        // function name field should be sent only inside the first chunk.
+                        if !name.is_empty() {
+                            entry.name = name;
+                        }
+                    }
+
+                    if let Some(arguments) = function.arguments {
+                        entry.arguments.push_str(&arguments);
+                    }
+                }
+            }
+        }
 
-        // If there's a finish_reason, we're done
-        if choice.finish_reason.is_some() {
-            return Ok(None);
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.into(),
+                                name: tool_call.name.into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments,
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
         }
 
-        Ok(None)
+        events
     }
 }
 
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 struct ConfigurationView {
     state: gpui::Entity<State>,
     loading_models_task: Option<Task<()>>,