diff --git a/Cargo.lock b/Cargo.lock
index 0a544dcf65..1d236bfb9d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -406,6 +406,7 @@ dependencies = [
  "language_model_selector",
  "language_models",
  "languages",
+ "lmstudio",
  "log",
  "lsp",
  "markdown",
@@ -483,6 +484,7 @@ dependencies = [
  "language_model",
  "language_model_selector",
  "language_models",
+ "lmstudio",
  "log",
  "lsp",
  "markdown",
@@ -6682,6 +6684,7 @@ dependencies = [
  "gpui",
  "http_client",
  "image",
+ "lmstudio",
  "log",
  "ollama",
  "open_ai",
@@ -6727,6 +6730,7 @@ dependencies = [
  "gpui",
  "http_client",
  "language_model",
+ "lmstudio",
  "menu",
  "ollama",
  "open_ai",
@@ -7195,6 +7199,18 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "lmstudio"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.12"
diff --git a/Cargo.toml b/Cargo.toml
index b5787abd42..365f892031 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -69,6 +69,7 @@ members = [
     "crates/livekit_client",
     "crates/livekit_client_macos",
     "crates/livekit_server",
+    "crates/lmstudio",
     "crates/lsp",
     "crates/markdown",
     "crates/markdown_preview",
@@ -255,6 +256,7 @@ languages = { path = "crates/languages" }
 livekit_client = { path = "crates/livekit_client" }
 livekit_client_macos = { path = "crates/livekit_client_macos" }
 livekit_server = { path = "crates/livekit_server" }
+lmstudio = { path = "crates/lmstudio" }
 lsp = { path = "crates/lsp" }
 markdown = { path = "crates/markdown" }
 markdown_preview = { path = "crates/markdown_preview" }
@@ -614,6 +616,7 @@ image_viewer = { codegen-units = 1 }
 inline_completion_button = { codegen-units = 1 }
 install_cli = { codegen-units = 1 }
 journal = { codegen-units = 1 }
+lmstudio = { codegen-units = 1 }
 menu = { codegen-units = 1 }
 notifications = { codegen-units = 1 }
 ollama = { codegen-units = 1 }
diff --git a/assets/icons/ai_lm_studio.svg b/assets/icons/ai_lm_studio.svg
new file mode 100644
index 0000000000..0b455f48a7
--- /dev/null
+++ b/assets/icons/ai_lm_studio.svg
@@ -0,0 +1,33 @@
+[33 lines of SVG markup for the LM Studio icon (internal title "Artboard"); the vector path data did not survive extraction and is not reproduced here]
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 62f4fd19b5..9967d4004b 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -1146,6 +1146,9 @@
     "openai": {
       "version": "1",
       "api_url": "https://api.openai.com/v1"
+    },
+    "lmstudio": {
+      "api_url": "http://localhost:1234/api/v0"
     }
   },
   // Zed's Prettier integration settings.
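The default.json hunk above seeds the new provider with LM Studio's local-server default. A user-level override in settings.json would look roughly like the sketch below, assuming the `lmstudio` block sits under the same `language_models` namespace as the adjacent `openai` entry (the parent key is outside this hunk's context):

```json
{
  "language_models": {
    "lmstudio": {
      "api_url": "http://localhost:1234/api/v0"
    }
  }
}
```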
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
index 456cf9dcd6..63a2afa71a 100644
--- a/crates/assistant/Cargo.toml
+++ b/crates/assistant/Cargo.toml
@@ -52,6 +52,7 @@ language.workspace = true
 language_model.workspace = true
 language_model_selector.workspace = true
 language_models.workspace = true
+lmstudio = { workspace = true, features = ["schemars"] }
 log.workspace = true
 lsp.workspace = true
 markdown.workspace = true
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 89dbbd0fff..879838552b 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -5,6 +5,7 @@ use anthropic::Model as AnthropicModel;
 use feature_flags::FeatureFlagAppExt;
 use gpui::{AppContext, Pixels};
 use language_model::{CloudModel, LanguageModel};
+use lmstudio::Model as LmStudioModel;
 use ollama::Model as OllamaModel;
 use schemars::{schema::Schema, JsonSchema};
 use serde::{Deserialize, Serialize};
@@ -40,6 +41,10 @@ pub enum AssistantProviderContentV1 {
         default_model: Option<OllamaModel>,
         api_url: Option<String>,
     },
+    LmStudio {
+        default_model: Option<LmStudioModel>,
+        api_url: Option<String>,
+    },
 }
@@ -137,6 +142,12 @@ impl AssistantSettingsContent {
                             model: model.id().to_string(),
                         })
                     }
+                    AssistantProviderContentV1::LmStudio { default_model, .. } => {
+                        default_model.map(|model| LanguageModelSelection {
+                            provider: "lmstudio".to_string(),
+                            model: model.id().to_string(),
+                        })
+                    }
                 }),
                 inline_alternatives: None,
                 enable_experimental_live_diffs: None,
@@ -214,6 +225,18 @@ impl AssistantSettingsContent {
                             api_url,
                         });
                     }
+                    "lmstudio" => {
+                        let api_url = match &settings.provider {
+                            Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
+                                api_url.clone()
+                            }
+                            _ => None,
+                        };
+                        settings.provider = Some(AssistantProviderContentV1::LmStudio {
+                            default_model: Some(lmstudio::Model::new(&model, None, None)),
+                            api_url,
+                        });
+                    }
                     "openai" => {
                         let (api_url, available_models) = match &settings.provider {
                             Some(AssistantProviderContentV1::OpenAi {
@@ -313,6 +336,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
         "anthropic".into(),
         "google".into(),
         "ollama".into(),
+        "lmstudio".into(),
         "openai".into(),
         "zed.dev".into(),
         "copilot_chat".into(),
@@ -355,7 +379,7 @@ pub struct AssistantSettingsContentV1 {
     default_height: Option<f32>,
     /// The provider of the assistant service.
     ///
-    /// This can be "openai", "anthropic", "ollama", "zed.dev"
+    /// This can be "openai", "anthropic", "ollama", "lmstudio", "zed.dev"
     /// each with their respective default models and configurations.
     provider: Option<AssistantProviderContentV1>,
 }
diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml
index 85220bc56f..d8d20c0c61 100644
--- a/crates/assistant2/Cargo.toml
+++ b/crates/assistant2/Cargo.toml
@@ -46,6 +46,7 @@ markdown.workspace = true
 menu.workspace = true
 multi_buffer.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
+lmstudio = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 ordered-float.workspace = true
 parking_lot.workspace = true
diff --git a/crates/assistant2/src/assistant_settings.rs b/crates/assistant2/src/assistant_settings.rs
index 1d6d1cd850..65d317f1e6 100644
--- a/crates/assistant2/src/assistant_settings.rs
+++ b/crates/assistant2/src/assistant_settings.rs
@@ -4,6 +4,7 @@ use ::open_ai::Model as OpenAiModel;
 use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
 use language_model::{CloudModel, LanguageModel};
+use lmstudio::Model as LmStudioModel;
 use ollama::Model as OllamaModel;
 use schemars::{schema::Schema, JsonSchema};
 use serde::{Deserialize, Serialize};
@@ -39,6 +40,11 @@ pub enum AssistantProviderContentV1 {
         default_model: Option<OllamaModel>,
         api_url: Option<String>,
     },
+    #[serde(rename = "lmstudio")]
+    LmStudio {
+        default_model: Option<LmStudioModel>,
+        api_url: Option<String>,
+    },
 }
@@ -130,6 +136,12 @@ impl AssistantSettingsContent {
                             model: model.id().to_string(),
                         })
                     }
+                    AssistantProviderContentV1::LmStudio { default_model, .. } => {
+                        default_model.map(|model| LanguageModelSelection {
+                            provider: "lmstudio".to_string(),
+                            model: model.id().to_string(),
+                        })
+                    }
                 }),
                 inline_alternatives: None,
                 enable_experimental_live_diffs: None,
@@ -207,6 +219,18 @@ impl AssistantSettingsContent {
                             api_url,
                         });
                     }
+                    "lmstudio" => {
+                        let api_url = match &settings.provider {
+                            Some(AssistantProviderContentV1::LmStudio {
+                                api_url, ..
+                            }) => {
+                                api_url.clone()
+                            }
+                            _ => None,
+                        };
+                        settings.provider = Some(AssistantProviderContentV1::LmStudio {
+                            default_model: Some(lmstudio::Model::new(&model, None, None)),
+                            api_url,
+                        });
+                    }
                     "openai" => {
                         let (api_url, available_models) = match &settings.provider {
                             Some(AssistantProviderContentV1::OpenAi {
@@ -305,6 +329,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
         enum_values: Some(vec![
             "anthropic".into(),
             "google".into(),
+            "lmstudio".into(),
             "ollama".into(),
             "openai".into(),
             "zed.dev".into(),
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 0fc54d509d..02b0bee939 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -28,6 +28,7 @@ image.workspace = true
 log.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
+lmstudio = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
 proto.workspace = true
 schemars.workspace = true
diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs
index 7b5ac88dea..12aaed3ab2 100644
--- a/crates/language_model/src/model/mod.rs
+++ b/crates/language_model/src/model/mod.rs
@@ -2,5 +2,6 @@ pub mod cloud_model;
 
 pub use anthropic::Model as AnthropicModel;
 pub use cloud_model::*;
+pub use lmstudio::Model as LmStudioModel;
 pub use ollama::Model as OllamaModel;
 pub use open_ai::Model as OpenAiModel;
diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs
index 2205ac52dc..17366fa3ec 100644
--- a/crates/language_model/src/role.rs
+++ b/crates/language_model/src/role.rs
@@ -65,3 +65,13 @@ impl From<Role> for open_ai::Role {
     }
 }
+
+impl From<Role> for lmstudio::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => lmstudio::Role::User,
+            Role::Assistant => lmstudio::Role::Assistant,
+            Role::System => lmstudio::Role::System,
+        }
+    }
+}
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index 00d948bd2d..2918041712 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -27,6 +27,7 @@ http_client.workspace = true
 language_model.workspace = true
 menu.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
+lmstudio = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 project.workspace = true
 proto.workspace = true
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 6d618d1ec5..604c1a7ce4 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -15,6 +15,7 @@ pub use crate::provider::cloud::LlmApiToken;
 pub use crate::provider::cloud::RefreshLlmTokenListener;
 use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 use crate::provider::google::GoogleLanguageModelProvider;
+use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 use crate::provider::ollama::OllamaLanguageModelProvider;
 use crate::provider::open_ai::OpenAiLanguageModelProvider;
 pub use crate::settings::*;
@@ -55,6 +56,10 @@ fn register_language_model_providers(
         OllamaLanguageModelProvider::new(client.http_client(), cx),
         cx,
     );
+    registry.register_provider(
+        LmStudioLanguageModelProvider::new(client.http_client(), cx),
+        cx,
+    );
     registry.register_provider(
         GoogleLanguageModelProvider::new(client.http_client(), cx),
         cx,
diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs
index fb79b12e4d..09fb975fc6 100644
--- a/crates/language_models/src/provider.rs
+++ b/crates/language_models/src/provider.rs
@@ -2,5 +2,6 @@ pub mod anthropic;
 pub mod cloud;
 pub mod copilot_chat;
 pub mod google;
+pub mod lmstudio;
 pub mod ollama;
 pub mod open_ai;
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
new file mode 100644
index 0000000000..61f4ce3270
--- /dev/null
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -0,0 +1,518 @@
+use anyhow::{anyhow, Result};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
+use http_client::HttpClient;
+use language_model::LanguageModelCompletionEvent;
+use language_model::{
+    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelRequest, RateLimiter, Role,
+};
+use lmstudio::{
+    get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
+    ModelType,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsStore};
+use std::{collections::BTreeMap, sync::Arc};
+use ui::{prelude::*, ButtonLike, Indicator};
+use util::ResultExt;
+
+use crate::AllLanguageModelSettings;
+
+const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
+const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
+const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
+
+const PROVIDER_ID: &str = "lmstudio";
+const PROVIDER_NAME: &str = "LM Studio";
+
+#[derive(Default, Debug, Clone, PartialEq)]
+pub struct LmStudioSettings {
+    pub api_url: String,
+    pub available_models: Vec<AvailableModel>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
+    pub name: String,
+    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
+    pub display_name: Option<String>,
+    /// The model's context window size.
+    pub max_tokens: usize,
+}
+
+pub struct LmStudioLanguageModelProvider {
+    http_client: Arc<dyn HttpClient>,
+    state: gpui::Model<State>,
+}
+
+pub struct State {
+    http_client: Arc<dyn HttpClient>,
+    available_models: Vec<lmstudio::Model>,
+    fetch_model_task: Option<Task<Result<()>>>,
+    _subscription: Subscription,
+}
+
+impl State {
+    fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+        let http_client = self.http_client.clone();
+        let api_url = settings.api_url.clone();
+
+        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
+        cx.spawn(|this, mut cx| async move {
+            let models = get_models(http_client.as_ref(), &api_url, None).await?;
+
+            let mut models: Vec<lmstudio::Model> = models
+                .into_iter()
+                .filter(|model| model.r#type != ModelType::Embeddings)
+                .map(|model| lmstudio::Model::new(&model.id, None, None))
+                .collect();
+
+            models.sort_by(|a, b| a.name.cmp(&b.name));
+
+            this.update(&mut cx, |this, cx| {
+                this.available_models = models;
+                cx.notify();
+            })
+        })
+    }
+
+    fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
+        let task = self.fetch_models(cx);
+        self.fetch_model_task.replace(task);
+    }
+
+    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+}
+
+impl LmStudioLanguageModelProvider {
+    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
+        let this = Self {
+            http_client: http_client.clone(),
+            state: cx.new_model(|cx| {
+                let subscription = cx.observe_global::<SettingsStore>({
+                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
+                    move |this: &mut State, cx| {
+                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+                        if &settings != new_settings {
+                            settings = new_settings.clone();
+                            this.restart_fetch_models_task(cx);
+                            cx.notify();
+                        }
+                    }
+                });
+
+                State {
+                    http_client,
+                    available_models: Default::default(),
+                    fetch_model_task: None,
+                    _subscription: subscription,
+                }
+            }),
+        };
+        this.state
+            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
+        this
+    }
+}
+
+impl LanguageModelProviderState for LmStudioLanguageModelProvider {
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
+    }
+}
+
+impl LanguageModelProvider for LmStudioLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
+    fn name(&self) -> LanguageModelProviderName {
+        LanguageModelProviderName(PROVIDER_NAME.into())
+    }
+
+    fn icon(&self) -> IconName {
+        IconName::AiLmStudio
+    }
+
+    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
+        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
+
+        // Add models from the LM Studio API
+        for model in self.state.read(cx).available_models.iter() {
+            models.insert(model.name.clone(), model.clone());
+        }
+
+        // Override with available models from settings
+        for model in AllLanguageModelSettings::get_global(cx)
+            .lmstudio
+            .available_models
+            .iter()
+        {
+            models.insert(
+                model.name.clone(),
+                lmstudio::Model {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                },
+            );
+        }
+
+        models
+            .into_values()
+            .map(|model| {
+                Arc::new(LmStudioLanguageModel {
+                    id: LanguageModelId::from(model.name.clone()),
+                    model: model.clone(),
+                    http_client: self.http_client.clone(),
+                    request_limiter: RateLimiter::new(4),
+                }) as Arc<dyn LanguageModel>
+            })
+            .collect()
+    }
+
+    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
+        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+        let http_client = self.http_client.clone();
+        let api_url = settings.api_url.clone();
+        let id = model.id().0.to_string();
+        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
+            .detach_and_log_err(cx);
+    }
+
+    fn is_authenticated(&self, cx: &AppContext) -> bool {
+        self.state.read(cx).is_authenticated()
+    }
+
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.authenticate(cx))
+    }
+
+    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
+        let state = self.state.clone();
+        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
+    }
+
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.fetch_models(cx))
+    }
+}
+
+pub struct LmStudioLanguageModel {
+    id: LanguageModelId,
+    model: lmstudio::Model,
+    http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
+}
+
+impl LmStudioLanguageModel {
+    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+        ChatCompletionRequest {
+            model: self.model.name.clone(),
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => ChatMessage::User {
+                        content: msg.string_contents(),
+                    },
+                    Role::Assistant => ChatMessage::Assistant {
+                        content: Some(msg.string_contents()),
+                        tool_calls: None,
+                    },
+                    Role::System => ChatMessage::System {
+                        content: msg.string_contents(),
+                    },
+                })
+                .collect(),
+            stream: true,
+            max_tokens: Some(-1),
+            stop: Some(request.stop),
+            temperature: request.temperature.or(Some(0.0)),
+            tools: vec![],
+        }
+    }
+}
+
+impl LanguageModel for LmStudioLanguageModel {
+    fn id(&self) -> LanguageModelId {
+        self.id.clone()
+    }
+
+    fn name(&self) -> LanguageModelName {
+        LanguageModelName::from(self.model.display_name().to_string())
+    }
+
+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
+    fn provider_name(&self) -> LanguageModelProviderName {
+        LanguageModelProviderName(PROVIDER_NAME.into())
+    }
+
+    fn telemetry_id(&self) -> String {
+        format!("lmstudio/{}", self.model.id())
+    }
+
+    fn max_token_count(&self) -> usize {
+        self.model.max_token_count()
+    }
+
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // Endpoint for this is coming soon.
+        // In the meantime, a hacky estimation:
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.string_contents().split_whitespace().count())
+            .sum::<usize>();
+
+        let estimated_tokens = (token_count as f64 * 0.75) as usize;
+        async move { Ok(estimated_tokens) }.boxed()
+    }
+
+    fn stream_completion(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
+        let request = self.to_lmstudio_request(request);
+
+        let http_client = self.http_client.clone();
+        let Ok(api_url) = cx.update(|cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+            settings.api_url.clone()
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        let future = self.request_limiter.stream(async move {
+            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(fragment) => {
+                            // Skip empty deltas
+                            if fragment.choices[0].delta.is_object()
+                                && fragment.choices[0].delta.as_object().unwrap().is_empty()
+                            {
+                                return None;
+                            }
+
+                            // Try to parse the delta as ChatMessage
+                            if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
+                                fragment.choices[0].delta.clone(),
+                            ) {
+                                let content = match chat_message {
+                                    ChatMessage::User { content } => content,
+                                    ChatMessage::Assistant { content, .. } => {
+                                        content.unwrap_or_default()
+                                    }
+                                    ChatMessage::System { content } => content,
+                                };
+                                if !content.is_empty() {
+                                    Some(Ok(content))
+                                } else {
+                                    None
+                                }
+                            } else {
+                                None
+                            }
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        });
+
+        async move {
+            Ok(future
+                .await?
+                .map(|result| result.map(LanguageModelCompletionEvent::Text))
+                .boxed())
+        }
+        .boxed()
+    }
+
+    fn use_any_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _tool_name: String,
+        _tool_description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        async move { Ok(futures::stream::empty().boxed()) }.boxed()
+    }
+}
+
+struct ConfigurationView {
+    state: gpui::Model<State>,
+    loading_models_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
+        let loading_models_task = Some(cx.spawn({
+            let state = state.clone();
+            |this, mut cx| async move {
+                if let Some(task) = state
+                    .update(&mut cx, |state, cx| state.authenticate(cx))
+                    .log_err()
+                {
+                    task.await.log_err();
+                }
+                this.update(&mut cx, |this, cx| {
+                    this.loading_models_task = None;
+                    cx.notify();
+                })
+                .log_err();
+            }
+        }));
+
+        Self {
+            state,
+            loading_models_task,
+        }
+    }
+
+    fn retry_connection(&self, cx: &mut WindowContext) {
+        self.state
+            .update(cx, |state, cx| state.fetch_models(cx))
+            .detach_and_log_err(cx);
+    }
+}
+
+impl Render for ConfigurationView {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let is_authenticated = self.state.read(cx).is_authenticated();
+
+        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
+        let lmstudio_reqs =
+            "To use LM Studio as a provider for Zed assistant, it needs to be running with at least one model downloaded.";
+
+        let mut inline_code_bg = cx.theme().colors().editor_background;
+        inline_code_bg.fade_out(0.5);
+
+        if self.loading_models_task.is_some() {
+            div().child(Label::new("Loading models...")).into_any()
+        } else {
+            v_flex()
+                .size_full()
+                .gap_3()
+                .child(
+                    v_flex()
+                        .size_full()
+                        .gap_2()
+                        .p_1()
+                        .child(Label::new(lmstudio_intro))
+                        .child(Label::new(lmstudio_reqs))
+                        .child(
+                            h_flex()
+                                .gap_0p5()
+                                .child(Label::new("To get your first model, try running "))
+                                .child(
+                                    div()
+                                        .bg(inline_code_bg)
+                                        .px_1p5()
+                                        .rounded_md()
+                                        .child(Label::new("lms get qwen2.5-coder-7b")),
+                                ),
+                        ),
+                )
+                .child(
+                    h_flex()
+                        .w_full()
+                        .pt_2()
+                        .justify_between()
+                        .gap_2()
+                        .child(
+                            h_flex()
+                                .w_full()
+                                .gap_2()
+                                .map(|this| {
+                                    if is_authenticated {
+                                        this.child(
+                                            Button::new("lmstudio-site", "LM Studio")
+                                                .style(ButtonStyle::Subtle)
+                                                .icon(IconName::ExternalLink)
+                                                .icon_size(IconSize::XSmall)
+                                                .icon_color(Color::Muted)
+                                                .on_click(move |_, cx| cx.open_url(LMSTUDIO_SITE))
+                                                .into_any_element(),
+                                        )
+                                    } else {
+                                        this.child(
+                                            Button::new(
+                                                "download_lmstudio_button",
+                                                "Download LM Studio",
+                                            )
+                                            .style(ButtonStyle::Subtle)
+                                            .icon(IconName::ExternalLink)
+                                            .icon_size(IconSize::XSmall)
+                                            .icon_color(Color::Muted)
+                                            .on_click(move |_, cx| {
+                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
+                                            })
+                                            .into_any_element(),
+                                        )
+                                    }
+                                })
+                                .child(
+                                    Button::new("view-models", "Model Catalog")
+                                        .style(ButtonStyle::Subtle)
+                                        .icon(IconName::ExternalLink)
+                                        .icon_size(IconSize::XSmall)
+                                        .icon_color(Color::Muted)
+                                        .on_click(move |_, cx| cx.open_url(LMSTUDIO_CATALOG_URL)),
+                                ),
+                        )
+                        .child(if is_authenticated {
+                            // This is only a button to ensure the spacing is correct
+                            // it should stay disabled
+                            ButtonLike::new("connected")
+                                .disabled(true)
+                                // Since this won't ever be clickable, we can use the arrow cursor
+                                .cursor_style(gpui::CursorStyle::Arrow)
+                                .child(
+                                    h_flex()
+                                        .gap_2()
+                                        .child(Indicator::dot().color(Color::Success))
+                                        .child(Label::new("Connected"))
+                                        .into_any_element(),
+                                )
+                                .into_any_element()
+                        } else {
+                            Button::new("retry_lmstudio_models", "Connect")
+                                .icon_position(IconPosition::Start)
+                                .icon(IconName::ArrowCircle)
+                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
+                                .into_any_element()
+                        }),
+                )
+                .into_any()
+        }
+    }
+}
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index c8ec9f7369..2c44a3983b 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -14,6 +14,7 @@ use crate::provider::{
     cloud::{self, ZedDotDevSettings},
     copilot_chat::CopilotChatSettings,
     google::GoogleSettings,
+    lmstudio::LmStudioSettings,
     ollama::OllamaSettings,
     open_ai::OpenAiSettings,
 };
@@ -59,12 +60,14 @@ pub struct AllLanguageModelSettings {
     pub zed_dot_dev: ZedDotDevSettings,
     pub google: GoogleSettings,
     pub copilot_chat: CopilotChatSettings,
+    pub lmstudio: LmStudioSettings,
 }
 
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct AllLanguageModelSettingsContent {
     pub anthropic: Option<AnthropicSettingsContent>,
     pub ollama: Option<OllamaSettingsContent>,
+    pub lmstudio: Option<LmStudioSettingsContent>,
     pub openai: Option<OpenAiSettingsContent>,
     #[serde(rename = "zed.dev")]
     pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
@@ -153,6 +156,12 @@ pub struct OllamaSettingsContent {
     pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
 }
 
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct LmStudioSettingsContent {
+    pub api_url: Option<String>,
+    pub available_models: Option<Vec<provider::lmstudio::AvailableModel>>,
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 #[serde(untagged)]
 pub enum OpenAiSettingsContent {
@@ -278,6 +287,18 @@ impl settings::Settings for AllLanguageModelSettings {
             ollama.as_ref().and_then(|s| s.available_models.clone()),
         );
 
+        // LM Studio
+        let lmstudio = value.lmstudio.clone();
+
+        merge(
+            &mut settings.lmstudio.api_url,
+            value.lmstudio.as_ref().and_then(|s| s.api_url.clone()),
+        );
+        merge(
+            &mut settings.lmstudio.available_models,
+            lmstudio.as_ref().and_then(|s| s.available_models.clone()),
+        );
+
         // OpenAI
         let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
             Some((content, upgraded)) => (Some(content), upgraded),
diff --git a/crates/lmstudio/Cargo.toml b/crates/lmstudio/Cargo.toml
new file mode 100644
index 0000000000..8027830ddc
--- /dev/null
+++ b/crates/lmstudio/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "lmstudio"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/lmstudio.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http_client.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
diff --git a/crates/lmstudio/LICENSE-GPL b/crates/lmstudio/LICENSE-GPL
new file mode 120000
index 0000000000..89e542f750
--- /dev/null
+++ b/crates/lmstudio/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs
new file mode 100644
index 0000000000..2b4771cdee
--- /dev/null
+++ b/crates/lmstudio/src/lmstudio.rs
@@ -0,0 +1,369 @@
+use anyhow::{anyhow, Context, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use http_client::{http, AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use serde_json::{value::RawValue, Value};
+use std::{convert::TryFrom, sync::Arc, time::Duration};
+
+pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+    Tool,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+        match value.as_str() {
+            "user" => Ok(Self::User),
+            "assistant" => Ok(Self::Assistant),
+            "system" => Ok(Self::System),
+            "tool" => Ok(Self::Tool),
+            _ => Err(anyhow!("invalid role '{value}'")),
+        }
+    }
+}
+
+impl From<Role> for String {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => "user".to_owned(),
+            Role::Assistant => "assistant".to_owned(),
+            Role::System => "system".to_owned(),
+            Role::Tool => "tool".to_owned(),
+        }
+    }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct Model {
+    pub name: String,
+    pub display_name: Option<String>,
+    pub max_tokens: usize,
+}
+
+impl Model {
+    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
+        Self {
+            name: name.to_owned(),
+            display_name: display_name.map(|s| s.to_owned()),
+            max_tokens: max_tokens.unwrap_or(2048),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        &self.name
+    }
+
+    pub fn display_name(&self) -> &str {
+        self.display_name.as_ref().unwrap_or(&self.name)
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        self.max_tokens
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+    Assistant {
+        #[serde(default)]
+        content: Option<String>,
+        #[serde(default)]
+        tool_calls: Option<Vec<LmStudioToolCall>>,
+    },
+    User {
+        content: String,
+    },
+    System {
+        content: String,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "lowercase")]
+pub enum LmStudioToolCall {
+    Function(LmStudioFunctionCall),
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct LmStudioFunctionCall {
+    pub name: String,
+    pub arguments: Box<RawValue>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LmStudioFunctionTool {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum LmStudioTool {
+    Function { function: LmStudioFunctionTool },
+}
+
+#[derive(Serialize, Debug)]
+pub struct ChatCompletionRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    pub stream: bool,
+    pub max_tokens: Option<i64>,
+    pub stop: Option<Vec<String>>,
+    pub temperature: Option<f32>,
+    pub tools: Vec<LmStudioTool>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ChatResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChoiceDelta>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ChoiceDelta {
+    pub index: u32,
+    #[serde(default)]
+    pub delta: serde_json::Value,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+    pub index: usize,
+    pub id: Option<String>,
+
+    // There is also an optional `type` field that would determine if a
+    // function is there. Sometimes this streams in with the `function` before
+    // it streams in the `type`
+    pub function: Option<FunctionChunk>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(untagged)]
+pub enum ResponseStreamResult {
+    Ok(ResponseStreamEvent),
+    Err { error: String },
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ResponseStreamEvent {
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChoiceDelta>,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ListModelsResponse {
+    pub data: Vec<ModelEntry>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct ModelEntry {
+    pub id: String,
+    pub object: String,
+    pub r#type: ModelType,
+    pub publisher: String,
+    pub arch: Option<String>,
+    pub compatibility_type: CompatibilityType,
+    pub quantization: String,
+    pub state: ModelState,
+    pub max_context_length: Option<u32>,
+    pub loaded_context_length: Option<u32>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum ModelType {
+    Llm,
+    Embeddings,
+    Vlm,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "kebab-case")]
+pub enum ModelState {
+    Loaded,
+    Loading,
+    NotLoaded,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum CompatibilityType {
+    Gguf,
+    Mlx,
+}
+
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    request: ChatCompletionRequest,
+) -> Result<ChatResponse> {
+    let uri = format!("{api_url}/chat/completions");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let response_message: ChatResponse =
+            serde_json::from_slice(&body)?;
+        Ok(response_message)
+    } else {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ))
+    }
+}
+
+pub async fn stream_chat_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    request: ChatCompletionRequest,
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+    let uri = format!("{api_url}/chat/completions");
+    let request_builder = http::Request::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        let line = line.strip_prefix("data: ")?;
+                        if line == "[DONE]" {
+                            None
+                        } else {
+                            let result = serde_json::from_str(&line)
+                                .context("Unable to parse chat completions response");
+                            if let Err(ref e) = result {
+                                eprintln!("Error parsing line: {e}\nLine content: '{line}'");
+                            }
+                            Some(result)
+                        }
+                    }
+                    Err(e) => {
+                        eprintln!("Error reading line: {e}");
+                        Some(Err(e.into()))
+                    }
+                }
+            })
+            .boxed())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to LM Studio API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
+
+pub async fn get_models(
+    client: &dyn HttpClient,
+    api_url: &str,
+    _: Option<Duration>,
+) -> Result<Vec<ModelEntry>> {
+    let uri = format!("{api_url}/models");
+    let request_builder = HttpRequest::builder()
+        .method(Method::GET)
+        .uri(uri)
+        .header("Accept", "application/json");
+
+    let request = request_builder.body(AsyncBody::default())?;
+
+    let mut response = client.send(request).await?;
+
+    let mut body = String::new();
+    response.body_mut().read_to_string(&mut body).await?;
+
+    if response.status().is_success() {
+        let response: ListModelsResponse =
+            serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
+        Ok(response.data)
+    } else {
+        Err(anyhow!(
+            "Failed to connect to LM Studio API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
+
+/// Sends an empty request to LM Studio to trigger loading the model
+pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
+    let uri = format!("{api_url}/completions");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serde_json::to_string(
+            &serde_json::json!({
+                "model": model,
+                "messages": [],
+                "stream": false,
+                "max_tokens": 0,
+            }),
+        )?))?;
+
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        Ok(())
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to LM Studio API: {} {}",
+            response.status(),
+            body,
+        ))
+    }
+}
diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index 1e1e0f0be7..6246347b9e 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -1,8 +1,10 @@
 mod cloud;
+mod lmstudio;
 mod ollama;
 mod open_ai;
 
 pub use cloud::*;
+pub use lmstudio::*;
 pub use ollama::*;
 pub use open_ai::*;
 use sha2::{Digest, Sha256};
diff --git a/crates/semantic_index/src/embedding/lmstudio.rs b/crates/semantic_index/src/embedding/lmstudio.rs
new file mode 100644
index 0000000000..56f70ed4b1
--- /dev/null
+++ b/crates/semantic_index/src/embedding/lmstudio.rs
@@ -0,0 +1,70 @@
+use anyhow::{Context as _, Result};
+use futures::{future::BoxFuture, AsyncReadExt as _, FutureExt};
+use http_client::HttpClient;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+
+pub enum LmStudioEmbeddingModel {
+    NomicEmbedText,
+}
+
+pub struct LmStudioEmbeddingProvider {
+    client: Arc<dyn HttpClient>,
+    model: LmStudioEmbeddingModel,
+}
+
+#[derive(Serialize)]
+struct LmStudioEmbeddingRequest {
+    model: String,
+    prompt: String,
+}
+
+#[derive(Deserialize)]
+struct LmStudioEmbeddingResponse {
+    embedding: Vec<f32>,
+}
+
+impl LmStudioEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, model: LmStudioEmbeddingModel) -> Self {
+        Self { client, model }
+    }
+}
+
+impl EmbeddingProvider for LmStudioEmbeddingProvider {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+        let model = match self.model {
+            LmStudioEmbeddingModel::NomicEmbedText => "nomic-embed-text",
+        };
+
+        futures::future::try_join_all(texts.iter().map(|to_embed| {
+            let request = LmStudioEmbeddingRequest {
+                model: model.to_string(),
+                prompt: to_embed.text.to_string(),
+            };
+
+            let request = serde_json::to_string(&request).unwrap();
+
+            async {
+                let response = self
+                    .client
+                    .post_json("http://localhost:1234/api/v0/embeddings", request.into())
+                    .await?;
+
+                let mut body = String::new();
+                response.into_body().read_to_string(&mut body).await?;
+
+                let response: LmStudioEmbeddingResponse =
+                    serde_json::from_str(&body).context("Unable to parse response")?;
+
+                Ok(Embedding::new(response.embedding))
+            }
+        }))
+        .boxed()
+    }
+
+    fn batch_size(&self) -> usize {
+        256
+    }
+}
diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs
index b27edf4d82..cec2e0df66 100644
--- a/crates/ui/src/components/icon.rs
+++ b/crates/ui/src/components/icon.rs
@@ -116,6 +116,7 @@ pub enum IconName {
     AiAnthropic,
     AiAnthropicHosted,
     AiGoogle,
+    AiLmStudio,
     AiOllama,
     AiOpenAi,
     AiZed,
diff --git a/docs/src/assistant/assistant.md b/docs/src/assistant/assistant.md
index 94144882f0..e92efe79fe 100644
--- a/docs/src/assistant/assistant.md
+++ b/docs/src/assistant/assistant.md
@@ -8,7 +8,7 @@ This section covers various aspects of the Assistant:
 
 - [Inline Assistant](./inline-assistant.md): Discover how to use the Assistant to power inline transformations directly within your code editor and terminal.
 
-- [Providers & Configuration](./configuration.md): Configure the Assistant, and set up different language model providers like Anthropic, OpenAI, Ollama, Google Gemini, and GitHub Copilot Chat.
+- [Providers & Configuration](./configuration.md): Configure the Assistant, and set up different language model providers like Anthropic, OpenAI, Ollama, LM Studio, Google Gemini, and GitHub Copilot Chat.
 
 - [Introducing Contexts](./contexts.md): Learn about contexts (similar to conversations), and learn how they power your interactions between you, your project, and the assistant/model.
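The `LmStudioEmbeddingRequest` and `LmStudioEmbeddingResponse` structs in the embedding provider above imply a wire exchange with `POST http://localhost:1234/api/v0/embeddings` of roughly this shape (illustrative values; the flat response envelope is an assumption inferred from the struct fields, not captured from a live server):

```json
{
  "model": "nomic-embed-text",
  "prompt": "The quick brown fox"
}
```

with the provider deserializing a response like:

```json
{
  "embedding": [0.0123, -0.0456, 0.0789]
}
```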
diff --git a/docs/src/assistant/configuration.md b/docs/src/assistant/configuration.md
index 8e558007bf..038145c325 100644
--- a/docs/src/assistant/configuration.md
+++ b/docs/src/assistant/configuration.md
@@ -10,6 +10,7 @@ The following providers are supported:
 - [Google AI](#google-ai) [^1]
 - [Ollama](#ollama)
 - [OpenAI](#openai)
+- [LM Studio](#lmstudio)
 
 To configure different providers, run `assistant: show configuration` in the command palette, or click on the hamburger menu at the top-right of the assistant panel and select "Configure".
 
@@ -236,6 +237,25 @@ Example configuration for using X.ai Grok with Zed:
 }
 ```
 
+### LM Studio {#lmstudio}
+
+1. Download and install the latest version of LM Studio from https://lmstudio.ai/download
+2. In the app, press ⌘/Ctrl + Shift + M and download at least one model, e.g. qwen2.5-coder-7b.
+
+   You can also get models via the LM Studio CLI:
+
+   ```sh
+   lms get qwen2.5-coder-7b
+   ```
+
+3. Make sure the LM Studio API server is running by executing:
+
+   ```sh
+   lms server start
+   ```
+
+Tip: Set [LM Studio as a login item](https://lmstudio.ai/docs/advanced/headless#run-the-llm-service-on-machine-login) to automate running the LM Studio server.
+
 #### Custom endpoints {#custom-endpoint}
 
 You can use a custom API endpoint for different providers, as long as it's compatible with the provider's API structure.
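Tying the docs back to the settings plumbing added earlier: `LmStudioSettingsContent` accepts both `api_url` and `available_models`, and each model entry follows the `AvailableModel` schema (`name`, `display_name`, `max_tokens`). A hand-written settings.json entry combining the two might look like the sketch below; the `language_models` parent key and the 32768 context size are illustrative assumptions, not values taken from this patch:

```json
{
  "language_models": {
    "lmstudio": {
      "api_url": "http://localhost:1234/api/v0",
      "available_models": [
        {
          "name": "qwen2.5-coder-7b",
          "display_name": "Qwen 2.5 Coder 7B",
          "max_tokens": 32768
        }
      ]
    }
  }
}
```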