diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 4d48b6606a..0ae026189c 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -383,7 +383,9 @@ impl AssistantSettingsContent { _ => None, }; settings.provider = Some(AssistantProviderContentV1::LmStudio { - default_model: Some(lmstudio::Model::new(&model, None, None)), + default_model: Some(lmstudio::Model::new( + &model, None, None, false, + )), api_url, }); } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 509816272c..c2147cd442 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -1,10 +1,13 @@ use anyhow::{Result, anyhow}; +use collections::HashMap; +use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelToolChoice, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + StopReason, WrappedTextContent, }; use language_model::{ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, @@ -12,12 +15,14 @@ use language_model::{ LanguageModelRequest, RateLimiter, Role, }; use lmstudio::{ - ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model, + ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model, stream_chat_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr; use std::{collections::BTreeMap, sync::Arc}; use ui::{ButtonLike, Indicator, List, prelude::*}; use util::ResultExt; @@ -40,12 +45,10 @@ pub struct LmStudioSettings { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] pub struct AvailableModel { - /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc pub name: String, - /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel. pub display_name: Option, - /// The model's context window size. pub max_tokens: usize, + pub supports_tool_calls: bool, } pub struct LmStudioLanguageModelProvider { @@ -77,7 +80,14 @@ impl State { let mut models: Vec = models .into_iter() .filter(|model| model.r#type != ModelType::Embeddings) - .map(|model| lmstudio::Model::new(&model.id, None, None)) + .map(|model| { + lmstudio::Model::new( + &model.id, + None, + None, + model.capabilities.supports_tool_calls(), + ) + }) .collect(); models.sort_by(|a, b| a.name.cmp(&b.name)); @@ -156,12 +166,16 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { IconName::AiLmStudio } - fn default_model(&self, cx: &App) -> Option> { - self.provided_models(cx).into_iter().next() + fn default_model(&self, _: &App) -> Option> { + // We shouldn't try to select default model, because it might lead to a load call for an unloaded model. + // In a constrained environment where user might not have enough resources it'll be a bad UX to select something + // to load by default. 
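A hedged sketch of the refresh path above: each entry returned by LM Studio's /api/v0/models listing becomes a selectable lmstudio::Model, and nothing is loaded until the user explicitly picks one. The id and capability values below are illustrative, not part of the diff.

// Illustrative only; uses the Model::new signature and defaults from this diff.
let model = lmstudio::Model::new(
    "qwen2.5-coder-7b", // id as reported by the /api/v0/models endpoint
    None,               // display_name: falls back to the id
    None,               // max_tokens: falls back to the 2048 default
    true,               // the entry's capabilities contained "tool_use"
);
assert_eq!(model.max_token_count(), 2048);
assert!(model.supports_tool_calls());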
+ None } - fn default_fast_model(&self, cx: &App) -> Option> { - self.default_model(cx) + fn default_fast_model(&self, _: &App) -> Option> { + // See explanation for default_model. + None } fn provided_models(&self, cx: &App) -> Vec> { @@ -184,6 +198,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { name: model.name.clone(), display_name: model.display_name.clone(), max_tokens: model.max_tokens, + supports_tool_calls: model.supports_tool_calls, }, ); } @@ -237,31 +252,117 @@ pub struct LmStudioLanguageModel { impl LmStudioLanguageModel { fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest { + let mut messages = Vec::new(); + + for message in request.messages { + for content in message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages + .push(match message.role { + Role::User => ChatMessage::User { content: text }, + Role::Assistant => ChatMessage::Assistant { + content: Some(text), + tool_calls: Vec::new(), + }, + Role::System => ChatMessage::System { content: text }, + }), + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(_) => {} + MessageContent::ToolUse(tool_use) => { + let tool_call = lmstudio::ToolCall { + id: tool_use.id.to_string(), + content: lmstudio::ToolCallContent::Function { + function: lmstudio::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(lmstudio::ChatMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(tool_result) => { + match &tool_result.content { + LanguageModelToolResultContent::Text(text) + | LanguageModelToolResultContent::WrappedText(WrappedTextContent { + text, + .. + }) => { + messages.push(lmstudio::ChatMessage::Tool { + content: text.to_string(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + LanguageModelToolResultContent::Image(_) => { + // no support for images for now + } + }; + } + } + } + } + ChatCompletionRequest { model: self.model.name.clone(), - messages: request - .messages - .into_iter() - .map(|msg| match msg.role { - Role::User => ChatMessage::User { - content: msg.string_contents(), - }, - Role::Assistant => ChatMessage::Assistant { - content: Some(msg.string_contents()), - tool_calls: None, - }, - Role::System => ChatMessage::System { - content: msg.string_contents(), - }, - }) - .collect(), + messages, stream: true, max_tokens: Some(-1), stop: Some(request.stop), - temperature: request.temperature.or(Some(0.0)), - tools: vec![], + // In LM Studio you can configure specific settings you'd like to use for your model. + // For example Qwen3 is recommended to be used with 0.7 temperature. + // It would be a bad UX to silently override these settings from Zed, so we pass no temperature as a default. 
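As a reference for the conversion loop above, here is a hedged sketch of the two messages produced for one tool-call round trip; the id, arguments, and result text are invented.

// Illustrative only; these are the shapes built by to_lmstudio_request,
// not additional lines of the diff.
let assistant = lmstudio::ChatMessage::Assistant {
    content: None,
    tool_calls: vec![lmstudio::ToolCall {
        id: "call_1".to_string(),
        content: lmstudio::ToolCallContent::Function {
            function: lmstudio::FunctionContent {
                name: "get_weather".to_string(),
                arguments: r#"{"city":"Berlin"}"#.to_string(),
            },
        },
    }],
};
let tool_result = lmstudio::ChatMessage::Tool {
    content: "overcast, 12 degrees".to_string(),
    tool_call_id: "call_1".to_string(),
};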
+ temperature: request.temperature.or(None), + tools: request + .tools + .into_iter() + .map(|tool| lmstudio::ToolDefinition::Function { + function: lmstudio::FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto, + LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required, + LanguageModelToolChoice::None => lmstudio::ToolChoice::None, + }), } } + + fn stream_completion( + &self, + request: ChatCompletionRequest, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> + { + let http_client = self.http_client.clone(); + let Ok(api_url) = cx.update(|cx| { + let settings = &AllLanguageModelSettings::get_global(cx).lmstudio; + settings.api_url.clone() + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let request = stream_chat_completion(http_client.as_ref(), &api_url, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } } impl LanguageModel for LmStudioLanguageModel { @@ -282,17 +383,22 @@ impl LanguageModel for LmStudioLanguageModel { } fn supports_tools(&self) -> bool { - false + self.model.supports_tool_calls() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.supports_tools() + && match choice { + LanguageModelToolChoice::Auto => true, + LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::None => true, + } } fn supports_images(&self) -> bool { false } - fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { - false - } - fn telemetry_id(&self) -> String { format!("lmstudio/{}", self.model.id()) } @@ -328,85 +434,126 @@ impl LanguageModel for LmStudioLanguageModel { >, > { let request = self.to_lmstudio_request(request); - - let http_client = self.http_client.clone(); - let Ok(api_url) = cx.update(|cx| { - let settings = &AllLanguageModelSettings::get_global(cx).lmstudio; - settings.api_url.clone() - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); - }; - - let future = self.request_limiter.stream(async move { - let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?; - - // Create a stream mapper to handle content across multiple deltas - let stream_mapper = LmStudioStreamMapper::new(); - - let stream = response - .map(move |response| { - response.and_then(|fragment| stream_mapper.process_fragment(fragment)) - }) - .filter_map(|result| async move { - match result { - Ok(Some(content)) => Some(Ok(content)), - Ok(None) => None, - Err(error) => Some(Err(error)), - } - }) - .boxed(); - - Ok(stream) - }); - + let completions = self.stream_completion(request, cx); async move { - Ok(future - .await? - .map(|result| { - result - .map(LanguageModelCompletionEvent::Text) - .map_err(LanguageModelCompletionError::Other) - }) - .boxed()) + let mapper = LmStudioEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) } .boxed() } } -// This will be more useful when we implement tool calling. Currently keeping it empty. 
-struct LmStudioStreamMapper {} +struct LmStudioEventMapper { + tool_calls_by_index: HashMap, +} -impl LmStudioStreamMapper { +impl LmStudioEventMapper { fn new() -> Self { - Self {} + Self { + tool_calls_by_index: HashMap::default(), + } } - fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result> { - // Most of the time, there will be only one choice - let Some(choice) = fragment.choices.first() else { - return Ok(None); + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let Some(choice) = event.choices.into_iter().next() else { + return vec![Err(LanguageModelCompletionError::Other(anyhow!( + "Response contained no choices" + )))]; }; - // Extract the delta content - if let Ok(delta) = - serde_json::from_value::(choice.delta.clone()) - { - if let Some(content) = delta.content { - if !content.is_empty() { - return Ok(Some(content)); + let mut events = Vec::new(); + if let Some(content) = choice.delta.content { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + + if let Some(tool_calls) = choice.delta.tool_calls { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function { + if let Some(name) = function.name { + // At the time of writing this code LM Studio (0.3.15) is incompatible with the OpenAI API: + // 1. It sends function name in the first chunk + // 2. It sends empty string in the function name field in all subsequent chunks for arguments + // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming + // function name field should be sent only inside the first chunk. 
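A hedged sketch of how that chunk sequence folds into a single RawToolCall (defined later in this diff); the chunk contents are invented, and the point is only that empty follow-up names are ignored while argument fragments are concatenated.

// Illustrative only, not part of the diff.
let mut entry = RawToolCall::default();
for (name, fragment) in [("get_weather", ""), ("", "{\"city\":"), ("", "\"Berlin\"}")] {
    if !name.is_empty() {
        entry.name = name.to_string();
    }
    entry.arguments.push_str(fragment);
}
assert_eq!(entry.name, "get_weather");
assert_eq!(entry.arguments, r#"{"city":"Berlin"}"#);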
+ if !name.is_empty() { + entry.name = name; + } + } + + if let Some(arguments) = function.arguments { + entry.arguments.push_str(&arguments); + } } } } - // If there's a finish_reason, we're done - if choice.finish_reason.is_some() { - return Ok(None); + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match serde_json::Value::from_str(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.into(), + name: tool_call.name.into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments, + }, + )), + Err(error) => Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} } - Ok(None) + events } } +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, +} + struct ConfigurationView { state: gpui::Entity, loading_models_task: Option>, diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index 5fd192c7c7..e82eef5e4b 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -2,7 +2,7 @@ use anyhow::{Context as _, Result}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, value::RawValue}; +use serde_json::Value; use std::{convert::TryFrom, sync::Arc, time::Duration}; pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0"; @@ -47,14 +47,21 @@ pub struct Model { pub name: String, pub display_name: Option, pub max_tokens: usize, + pub supports_tool_calls: bool, } impl Model { - pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option) -> Self { + pub fn new( + name: &str, + display_name: Option<&str>, + max_tokens: Option, + supports_tool_calls: bool, + ) -> Self { Self { name: name.to_owned(), display_name: display_name.map(|s| s.to_owned()), max_tokens: max_tokens.unwrap_or(2048), + supports_tool_calls, } } @@ -69,15 +76,43 @@ impl Model { pub fn max_token_count(&self) -> usize { self.max_tokens } + + pub fn supports_tool_calls(&self) -> bool { + self.supports_tool_calls + } } + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Auto, + Required, + None, + Other(ToolDefinition), +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolDefinition { + #[allow(dead_code)] + Function { function: FunctionDefinition }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: Option, + pub parameters: Option, +} + #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "role", rename_all = "lowercase")] pub enum ChatMessage { Assistant { #[serde(default)] content: Option, - #[serde(default)] - tool_calls: Option>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] 
+ tool_calls: Vec, }, User { content: String, @@ -85,31 +120,29 @@ pub enum ChatMessage { System { content: String, }, -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "lowercase")] -pub enum LmStudioToolCall { - Function(LmStudioFunctionCall), -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct LmStudioFunctionCall { - pub name: String, - pub arguments: Box, + Tool { + content: String, + tool_call_id: String, + }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LmStudioFunctionTool { - pub name: String, - pub description: Option, - pub parameters: Option, +pub struct ToolCall { + pub id: String, + #[serde(flatten)] + pub content: ToolCallContent, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "type", rename_all = "lowercase")] -pub enum LmStudioTool { - Function { function: LmStudioFunctionTool }, +pub enum ToolCallContent { + Function { function: FunctionContent }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionContent { + pub name: String, + pub arguments: String, } #[derive(Serialize, Debug)] @@ -117,10 +150,16 @@ pub struct ChatCompletionRequest { pub model: String, pub messages: Vec, pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, - pub tools: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -135,8 +174,7 @@ pub struct ChatResponse { #[derive(Serialize, Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, - #[serde(default)] - pub delta: serde_json::Value, + pub delta: ResponseMessageDelta, pub finish_reason: Option, } @@ -164,6 +202,16 @@ pub struct Usage { pub total_tokens: u32, } +#[derive(Debug, Default, Clone, Deserialize, PartialEq)] +#[serde(transparent)] +pub struct Capabilities(Vec); + +impl Capabilities { + pub fn supports_tool_calls(&self) -> bool { + self.0.iter().any(|cap| cap == "tool_use") + } +} + #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ResponseStreamResult { @@ -175,16 +223,17 @@ pub enum ResponseStreamResult { pub struct ResponseStreamEvent { pub created: u32, pub model: String, + pub object: String, pub choices: Vec, pub usage: Option, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct ListModelsResponse { pub data: Vec, } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[derive(Clone, Debug, Deserialize, PartialEq)] pub struct ModelEntry { pub id: String, pub object: String, @@ -196,6 +245,8 @@ pub struct ModelEntry { pub state: ModelState, pub max_context_length: Option, pub loaded_context_length: Option, + #[serde(default)] + pub capabilities: Capabilities, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] @@ -265,7 +316,7 @@ pub async fn stream_chat_completion( client: &dyn HttpClient, api_url: &str, request: ChatCompletionRequest, -) -> Result>> { +) -> Result>> { let uri = format!("{api_url}/chat/completions"); let request_builder = http::Request::builder() .method(Method::POST)
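For completeness, a hedged example of the capabilities field this change starts reading from the /api/v0/models response, assuming the transparent wrapper holds plain strings as the "tool_use" comparison above implies; the JSON payload is illustrative.

// Illustrative only: a "tool_use" entry in a model's capabilities is what
// enables tool calling for that model in this provider.
let caps: Capabilities = serde_json::from_str(r#"["tool_use"]"#).unwrap();
assert!(caps.supports_tool_calls());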