From 6563330239a142b7e0c2cb02abf7aa5da1571c57 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Fri, 3 May 2024 12:50:42 -0700 Subject: [PATCH] Supermaven (#10788) Adds a supermaven provider for completions. There are various other refactors amidst this branch, primarily to make copilot no longer a dependency of project as well as show LSP Logs for global LSPs like copilot properly. This feature is not enabled by default. We're going to seek to refine it in the coming weeks. Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra Co-authored-by: Nathan Sobo Co-authored-by: Max Co-authored-by: Max Brunsfeld --- Cargo.lock | 97 +++- Cargo.toml | 8 +- assets/icons/supermaven.svg | 8 + assets/icons/supermaven_disabled.svg | 15 + assets/icons/supermaven_error.svg | 11 + assets/icons/supermaven_init.svg | 11 + assets/settings/default.json | 4 +- crates/anthropic/src/anthropic.rs | 4 +- crates/collab/Cargo.toml | 1 + crates/collab/k8s/collab.template.yml | 5 + crates/collab/src/completion.rs | 2 + crates/collab/src/lib.rs | 1 + crates/collab/src/rpc.rs | 74 ++- crates/collab/src/tests/test_server.rs | 1 + crates/copilot/Cargo.toml | 10 + crates/copilot/src/copilot.rs | 23 +- .../src/copilot_completion_provider.rs | 77 +-- crates/{copilot_ui => copilot}/src/sign_in.rs | 4 +- crates/copilot_ui/src/copilot_button.rs | 403 -------------- crates/copilot_ui/src/copilot_ui.rs | 7 - crates/editor/src/editor.rs | 44 +- .../editor/src/inline_completion_provider.rs | 10 +- crates/google_ai/src/google_ai.rs | 6 +- .../Cargo.toml | 7 +- .../LICENSE-GPL | 0 .../src/inline_completion_button.rs | 510 ++++++++++++++++++ crates/language/src/language_settings.rs | 116 ++-- crates/language_tools/Cargo.toml | 2 +- crates/language_tools/src/lsp_log.rs | 323 ++++++----- crates/project/Cargo.toml | 1 - crates/project/src/project.rs | 79 +-- crates/rpc/proto/zed.proto | 13 +- crates/rpc/src/proto.rs | 3 + crates/supermaven/Cargo.toml | 41 ++ crates/supermaven/src/messages.rs | 152 ++++++ crates/supermaven/src/supermaven.rs | 345 ++++++++++++ .../src/supermaven_completion_provider.rs | 131 +++++ crates/supermaven_api/Cargo.toml | 21 + crates/supermaven_api/src/supermaven_api.rs | 291 ++++++++++ crates/ui/src/components/icon.rs | 8 + crates/util/src/paths.rs | 1 + crates/welcome/Cargo.toml | 2 +- crates/welcome/src/welcome.rs | 3 +- crates/zed/Cargo.toml | 3 +- crates/zed/src/main.rs | 57 +- crates/zed/src/zed.rs | 8 +- .../zed/src/zed/inline_completion_registry.rs | 126 +++++ 47 files changed, 2242 insertions(+), 827 deletions(-) create mode 100644 assets/icons/supermaven.svg create mode 100644 assets/icons/supermaven_disabled.svg create mode 100644 assets/icons/supermaven_error.svg create mode 100644 assets/icons/supermaven_init.svg create mode 100644 crates/collab/src/completion.rs rename crates/{copilot_ui => copilot}/src/copilot_completion_provider.rs (94%) rename crates/{copilot_ui => copilot}/src/sign_in.rs (98%) delete mode 100644 crates/copilot_ui/src/copilot_button.rs delete mode 100644 crates/copilot_ui/src/copilot_ui.rs rename crates/{copilot_ui => inline_completion_button}/Cargo.toml (88%) rename crates/{copilot_ui => inline_completion_button}/LICENSE-GPL (100%) create mode 100644 crates/inline_completion_button/src/inline_completion_button.rs create mode 100644 crates/supermaven/Cargo.toml create mode 100644 crates/supermaven/src/messages.rs create mode 100644 crates/supermaven/src/supermaven.rs create mode 100644 crates/supermaven/src/supermaven_completion_provider.rs create mode 100644 crates/supermaven_api/Cargo.toml create mode 100644 crates/supermaven_api/src/supermaven_api.rs create mode 100644 crates/zed/src/zed/inline_completion_registry.rs diff --git a/Cargo.lock b/Cargo.lock index 3417a62d82..894b827c4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2316,6 +2316,7 @@ dependencies = [ "sha2 0.10.7", "sqlx", "subtle", + "supermaven_api", "telemetry_events", "text", "theme", @@ -2512,30 +2513,10 @@ dependencies = [ "async-compression", "async-std", "async-tar", + "client", "clock", "collections", "command_palette_hooks", - "fs", - "futures 0.3.28", - "gpui", - "language", - "lsp", - "node_runtime", - "parking_lot", - "rpc", - "serde", - "settings", - "smol", - "util", -] - -[[package]] -name = "copilot_ui" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "copilot", "editor", "fs", "futures 0.3.28", @@ -2544,14 +2525,18 @@ dependencies = [ "language", "lsp", "menu", + "node_runtime", + "parking_lot", "project", + "rpc", + "serde", "serde_json", "settings", + "smol", "theme", "ui", "util", "workspace", - "zed_actions", ] [[package]] @@ -5143,6 +5128,30 @@ dependencies = [ "syn 2.0.59", ] +[[package]] +name = "inline_completion_button" +version = "0.1.0" +dependencies = [ + "anyhow", + "copilot", + "editor", + "fs", + "futures 0.3.28", + "gpui", + "indoc", + "language", + "lsp", + "project", + "serde_json", + "settings", + "supermaven", + "theme", + "ui", + "util", + "workspace", + "zed_actions", +] + [[package]] name = "inotify" version = "0.9.6" @@ -5548,6 +5557,7 @@ dependencies = [ "anyhow", "client", "collections", + "copilot", "editor", "env_logger", "futures 0.3.28", @@ -7422,7 +7432,6 @@ dependencies = [ "client", "clock", "collections", - "copilot", "env_logger", "fs", "futures 0.3.28", @@ -9594,6 +9603,43 @@ dependencies = [ "rayon", ] +[[package]] +name = "supermaven" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "collections", + "editor", + "env_logger", + "futures 0.3.28", + "gpui", + "language", + "log", + "postage", + "project", + "serde", + "serde_json", + "settings", + "smol", + "supermaven_api", + "theme", + "ui", + "util", +] + +[[package]] +name = "supermaven_api" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "serde", + "serde_json", + "smol", + "util", +] + [[package]] name = "sval" version = "2.8.0" @@ -11798,12 +11844,12 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "copilot_ui", "db", "editor", "extensions_ui", "fuzzy", "gpui", + "inline_completion_button", "install_cli", "picker", "project", @@ -12683,7 +12729,6 @@ dependencies = [ "collections", "command_palette", "copilot", - "copilot_ui", "db", "dev_server_projects", "diagnostics", @@ -12700,6 +12745,7 @@ dependencies = [ "gpui", "headless", "image_viewer", + "inline_completion_button", "install_cli", "isahc", "journal", @@ -12730,6 +12776,7 @@ dependencies = [ "settings", "simplelog", "smol", + "supermaven", "tab_switcher", "task", "tasks_ui", diff --git a/Cargo.toml b/Cargo.toml index ca0e5f35bd..67ce732b61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ members = [ "crates/command_palette", "crates/command_palette_hooks", "crates/copilot", - "crates/copilot_ui", "crates/db", "crates/diagnostics", "crates/editor", @@ -42,6 +41,7 @@ members = [ "crates/gpui_macros", "crates/headless", "crates/image_viewer", + "crates/inline_completion_button", "crates/install_cli", "crates/journal", "crates/language", @@ -86,6 +86,8 @@ members = [ "crates/storybook", "crates/sum_tree", "crates/tab_switcher", + "crates/supermaven", + "crates/supermaven_api", "crates/terminal", "crates/terminal_view", "crates/text", @@ -159,7 +161,6 @@ color = { path = "crates/color" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } copilot = { path = "crates/copilot" } -copilot_ui = { path = "crates/copilot_ui" } db = { path = "crates/db" } diagnostics = { path = "crates/diagnostics" } editor = { path = "crates/editor" } @@ -180,6 +181,7 @@ gpui_macros = { path = "crates/gpui_macros" } headless = { path = "crates/headless" } install_cli = { path = "crates/install_cli" } image_viewer = { path = "crates/image_viewer" } +inline_completion_button = { path = "crates/inline_completion_button" } journal = { path = "crates/journal" } language = { path = "crates/language" } language_selector = { path = "crates/language_selector" } @@ -220,6 +222,8 @@ settings = { path = "crates/settings" } snippet = { path = "crates/snippet" } sqlez = { path = "crates/sqlez" } sqlez_macros = { path = "crates/sqlez_macros" } +supermaven = { path = "crates/supermaven" } +supermaven_api = { path = "crates/supermaven_api"} story = { path = "crates/story" } storybook = { path = "crates/storybook" } sum_tree = { path = "crates/sum_tree" } diff --git a/assets/icons/supermaven.svg b/assets/icons/supermaven.svg new file mode 100644 index 0000000000..19837fbf56 --- /dev/null +++ b/assets/icons/supermaven.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/icons/supermaven_disabled.svg b/assets/icons/supermaven_disabled.svg new file mode 100644 index 0000000000..39ff8a6122 --- /dev/null +++ b/assets/icons/supermaven_disabled.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/assets/icons/supermaven_error.svg b/assets/icons/supermaven_error.svg new file mode 100644 index 0000000000..669322b97d --- /dev/null +++ b/assets/icons/supermaven_error.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/icons/supermaven_init.svg b/assets/icons/supermaven_init.svg new file mode 100644 index 0000000000..b919d5559b --- /dev/null +++ b/assets/icons/supermaven_init.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index c8560d7f15..de6da01f87 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -12,8 +12,8 @@ "base_keymap": "VSCode", // Features that can be globally enabled or disabled "features": { - // Show Copilot icon in status bar - "copilot": true + // Which inline completion provider to use. + "inline_completion_provider": "copilot" }, // The name of a font to use for rendering text in the editor "buffer_font_family": "Zed Mono", diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index a96a23b166..aeaae1f34d 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] @@ -141,7 +141,7 @@ pub enum TextDelta { } pub async fn stream_completion( - client: &dyn HttpClient, + client: Arc, api_url: &str, api_key: &str, request: Request, diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index e8dbcf851a..5e719739ae 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -39,6 +39,7 @@ live_kit_server.workspace = true log.workspace = true nanoid.workspace = true open_ai.workspace = true +supermaven_api.workspace = true parking_lot.workspace = true prometheus = "0.13" prost.workspace = true diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 8bd6a71514..271b146b0b 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -172,6 +172,11 @@ spec: secretKeyRef: name: slack key: panics_webhook + - name: SUPERMAVEN_ADMIN_API_KEY + valueFrom: + secretKeyRef: + name: supermaven + key: api_key - name: INVITE_LINK_PREFIX value: ${INVITE_LINK_PREFIX} - name: RUST_BACKTRACE diff --git a/crates/collab/src/completion.rs b/crates/collab/src/completion.rs new file mode 100644 index 0000000000..dd1f4b3be6 --- /dev/null +++ b/crates/collab/src/completion.rs @@ -0,0 +1,2 @@ +use anyhow::{anyhow, Result}; +use rpc::proto; diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 925d192fc0..ae83fccb98 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -138,6 +138,7 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, + pub supermaven_admin_api_key: Option>, } impl Config { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e4a83a4338..59f811f0b5 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -34,6 +34,7 @@ pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL}; use sha2::Digest; +use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ channel::oneshot, @@ -148,7 +149,8 @@ struct Session { peer: Arc, connection_pool: Arc>, live_kit_client: Option>, - http_client: IsahcHttpClient, + supermaven_client: Option>, + http_client: Arc, rate_limiter: Arc, _executor: Executor, } @@ -189,6 +191,14 @@ impl Session { } } + fn is_staff(&self) -> bool { + match &self.principal { + Principal::User(user) => user.admin, + Principal::Impersonated { .. } => true, + Principal::DevServer(_) => false, + } + } + fn dev_server_id(&self) -> Option { match &self.principal { Principal::User(_) | Principal::Impersonated { .. } => None, @@ -233,6 +243,14 @@ impl UserSession { pub fn user_id(&self) -> UserId { self.0.user_id().unwrap() } + + pub fn email(&self) -> Option { + match &self.0.principal { + Principal::User(user) => user.email_address.clone(), + Principal::Impersonated { user, .. } => user.email_address.clone(), + Principal::DevServer(..) => None, + } + } } impl Deref for UserSession { @@ -561,6 +579,7 @@ impl Server { .add_request_handler(user_handler(get_private_user_info)) .add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_buffer_version)) + .add_request_handler(user_handler(get_supermaven_api_key)) .add_streaming_request_handler({ let app_state = app_state.clone(); move |request, response, session| { @@ -938,13 +957,22 @@ impl Server { tracing::info!("connection opened"); let http_client = match IsahcHttpClient::new() { - Ok(http_client) => http_client, + Ok(http_client) => Arc::new(http_client), Err(error) => { tracing::error!(?error, "failed to create HTTP client"); return; } }; + let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() { + Some(Arc::new(SupermavenAdminApi::new( + supermaven_admin_api_key.to_string(), + http_client.clone(), + ))) + } else { + None + }; + let session = Session { principal: principal.clone(), connection_id, @@ -955,6 +983,7 @@ impl Server { http_client, rate_limiter: this.app_state.rate_limiter.clone(), _executor: executor.clone(), + supermaven_client, }; if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await { @@ -4210,7 +4239,7 @@ async fn complete_with_open_ai( api_key: Arc, ) -> Result<()> { let mut completion_stream = open_ai::stream_completion( - &session.http_client, + session.http_client.as_ref(), OPEN_AI_API_URL, &api_key, crate::ai::language_model_request_to_open_ai(request)?, @@ -4274,7 +4303,7 @@ async fn complete_with_google_ai( api_key: Arc, ) -> Result<()> { let mut stream = google_ai::stream_generate_content( - &session.http_client, + session.http_client.clone(), google_ai::API_URL, api_key.as_ref(), crate::ai::language_model_request_to_google_ai(request)?, @@ -4358,7 +4387,7 @@ async fn complete_with_anthropic( .collect(); let mut stream = anthropic::stream_completion( - &session.http_client, + session.http_client.clone(), "https://api.anthropic.com", &api_key, anthropic::Request { @@ -4482,7 +4511,7 @@ async fn count_tokens_with_language_model( let api_key = google_ai_api_key .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; let tokens_response = google_ai::count_tokens( - &session.http_client, + session.http_client.as_ref(), google_ai::API_URL, &api_key, crate::ai::count_tokens_request_to_google_ai(request)?, @@ -4530,7 +4559,7 @@ async fn compute_embeddings( let embeddings = match request.model.as_str() { "openai/text-embedding-3-small" => { open_ai::embed( - &session.http_client, + session.http_client.as_ref(), OPEN_AI_API_URL, &api_key, OpenAiEmbeddingModel::TextEmbedding3Small, @@ -4602,6 +4631,37 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<() } } +/// Get a Supermaven API key for the user +async fn get_supermaven_api_key( + _request: proto::GetSupermavenApiKey, + response: Response, + session: UserSession, +) -> Result<()> { + let user_id: String = session.user_id().to_string(); + if !session.is_staff() { + return Err(anyhow!("supermaven not enabled for this account"))?; + } + + let email = session + .email() + .ok_or_else(|| anyhow!("user must have an email"))?; + + let supermaven_admin_api = session + .supermaven_client + .as_ref() + .ok_or_else(|| anyhow!("supermaven not configured"))?; + + let result = supermaven_admin_api + .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email }) + .await?; + + response.send(proto::GetSupermavenApiKeyResponse { + api_key: result.api_key, + })?; + + Ok(()) +} + /// Start receiving chat updates for a channel async fn join_channel_chat( request: proto::JoinChannelChat, diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 2fec21f76e..3a456a328e 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -655,6 +655,7 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, + supermaven_admin_api_key: None, }, }) } diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 609bd0e3a8..3f38a81f5b 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -27,28 +27,38 @@ anyhow.workspace = true async-compression.workspace = true async-tar.workspace = true collections.workspace = true +client.workspace = true command_palette_hooks.workspace = true +editor.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true lsp.workspace = true +menu.workspace = true node_runtime.workspace = true parking_lot.workspace = true +project.workspace = true serde.workspace = true settings.workspace = true smol.workspace = true +ui.workspace = true util.workspace = true +workspace.workspace = true [target.'cfg(windows)'.dependencies] async-std = { version = "1.12.0", features = ["unstable"] } [dev-dependencies] clock.workspace = true +indoc.workspace = true +serde_json.workspace = true collections = { workspace = true, features = ["test-support"] } fs = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } lsp = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } rpc = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 99f94b5511..577f335d2a 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,4 +1,7 @@ +mod copilot_completion_provider; pub mod request; +mod sign_in; + use anyhow::{anyhow, Context as _, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; @@ -10,9 +13,9 @@ use gpui::{ ModelContext, Task, WeakModel, }; use language::{ - language_settings::{all_language_settings, language_settings}, - point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, - LanguageServerName, PointUtf16, ToPointUtf16, + language_settings::{all_language_settings, language_settings, InlineCompletionProvider}, + point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, + ToPointUtf16, }; use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId}; use node_runtime::NodeRuntime; @@ -32,6 +35,9 @@ use util::{ fs::remove_matching, github::latest_github_release, http::HttpClient, maybe, paths, ResultExt, }; +pub use copilot_completion_provider::CopilotCompletionProvider; +pub use sign_in::CopilotCodeVerification; + actions!( copilot, [ @@ -144,7 +150,6 @@ impl CopilotServer { } struct RunningCopilotServer { - name: LanguageServerName, lsp: Arc, sign_in_status: SignInStatus, registered_buffers: HashMap, @@ -354,7 +359,9 @@ impl Copilot { let server_id = self.server_id; let http = self.http.clone(); let node_runtime = self.node_runtime.clone(); - if all_language_settings(None, cx).copilot_enabled(None, None) { + if all_language_settings(None, cx).inline_completions.provider + == InlineCompletionProvider::Copilot + { if matches!(self.server, CopilotServer::Disabled) { let start_task = cx .spawn(move |this, cx| { @@ -393,7 +400,6 @@ impl Copilot { http: http.clone(), node_runtime, server: CopilotServer::Running(RunningCopilotServer { - name: LanguageServerName(Arc::from("copilot")), lsp: Arc::new(server), sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), @@ -467,7 +473,6 @@ impl Copilot { match server { Ok((server, status)) => { this.server = CopilotServer::Running(RunningCopilotServer { - name: LanguageServerName(Arc::from("copilot")), lsp: server, sign_in_status: SignInStatus::SignedOut, registered_buffers: Default::default(), @@ -607,9 +612,9 @@ impl Copilot { cx.background_executor().spawn(start_task) } - pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc)> { + pub fn language_server(&self) -> Option<&Arc> { if let CopilotServer::Running(server) = &self.server { - Some((&server.name, &server.lsp)) + Some(&server.lsp) } else { None } diff --git a/crates/copilot_ui/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs similarity index 94% rename from crates/copilot_ui/src/copilot_completion_provider.rs rename to crates/copilot/src/copilot_completion_provider.rs index c6226c7bb1..970145a10f 100644 --- a/crates/copilot_ui/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -1,10 +1,12 @@ +use crate::{Completion, Copilot}; use anyhow::Result; use client::telemetry::Telemetry; -use copilot::Copilot; use editor::{Direction, InlineCompletionProvider}; use gpui::{AppContext, EntityId, Model, ModelContext, Task}; -use language::language_settings::AllLanguageSettings; -use language::{language_settings::all_language_settings, Buffer, OffsetRangeExt, ToOffset}; +use language::{ + language_settings::{all_language_settings, AllLanguageSettings}, + Buffer, OffsetRangeExt, ToOffset, +}; use settings::Settings; use std::{path::Path, sync::Arc, time::Duration}; @@ -13,7 +15,7 @@ pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); pub struct CopilotCompletionProvider { cycled: bool, buffer_id: Option, - completions: Vec, + completions: Vec, active_completion_index: usize, file_extension: Option, pending_refresh: Task>, @@ -42,11 +44,11 @@ impl CopilotCompletionProvider { self } - fn active_completion(&self) -> Option<&copilot::Completion> { + fn active_completion(&self) -> Option<&Completion> { self.completions.get(self.active_completion_index) } - fn push_completion(&mut self, new_completion: copilot::Completion) { + fn push_completion(&mut self, new_completion: Completion) { for completion in &self.completions { if completion.text == new_completion.text && completion.range == new_completion.range { return; @@ -71,7 +73,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider { let file = buffer.file(); let language = buffer.language_at(cursor_position); let settings = all_language_settings(file, cx); - settings.copilot_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) + settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) } fn refresh( @@ -196,7 +198,10 @@ impl InlineCompletionProvider for CopilotCompletionProvider { fn discard(&mut self, cx: &mut ModelContext) { let settings = AllLanguageSettings::get_global(cx); - if !settings.copilot.feature_enabled { + + let copilot_enabled = settings.inline_completions_enabled(None, None); + + if !copilot_enabled { return; } @@ -298,7 +303,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); // When inserting, ensure autocompletion is favored over Copilot suggestions. cx.set_state(indoc! {" @@ -318,7 +325,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -360,7 +367,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -393,7 +400,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -426,7 +433,7 @@ mod tests { // After debouncing, new Copilot completions should be requested. handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot2".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 5)), ..Default::default() @@ -503,7 +510,7 @@ mod tests { }); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: " let x = 4;".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)), ..Default::default() @@ -553,7 +560,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); // Setup the editor with a completion request. cx.set_state(indoc! {" @@ -573,7 +582,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -615,7 +624,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.123. copilot\n 456".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -675,7 +684,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); cx.set_state(indoc! {" one @@ -685,7 +696,7 @@ mod tests { handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "two.foo()".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)), ..Default::default() @@ -756,13 +767,13 @@ mod tests { let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); editor .update(cx, |editor, cx| { - editor.set_inline_completion_provider(copilot_provider, cx) + editor.set_inline_completion_provider(Some(copilot_provider), cx) }) .unwrap(); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "b = 2 + a".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 5)), ..Default::default() @@ -788,7 +799,7 @@ mod tests { handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "d = 4 + c".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 6)), ..Default::default() @@ -833,7 +844,7 @@ mod tests { async fn test_copilot_disabled_globs(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx, |settings| { settings - .copilot + .inline_completions .get_or_insert(Default::default()) .disabled_globs = Some(vec![".env*".to_string()]); }); @@ -888,15 +899,15 @@ mod tests { let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); editor .update(cx, |editor, cx| { - editor.set_inline_completion_provider(copilot_provider, cx) + editor.set_inline_completion_provider(Some(copilot_provider), cx) }) .unwrap(); let mut copilot_requests = copilot_lsp - .handle_request::( + .handle_request::( move |_params, _cx| async move { - Ok(copilot::request::GetCompletionsResult { - completions: vec![copilot::request::Completion { + Ok(crate::request::GetCompletionsResult { + completions: vec![crate::request::Completion { text: "next line".into(), range: lsp::Range::new( lsp::Position::new(1, 0), @@ -931,21 +942,21 @@ mod tests { fn handle_copilot_completion_request( lsp: &lsp::FakeLanguageServer, - completions: Vec, - completions_cycling: Vec, + completions: Vec, + completions_cycling: Vec, ) { - lsp.handle_request::(move |_params, _cx| { + lsp.handle_request::(move |_params, _cx| { let completions = completions.clone(); async move { - Ok(copilot::request::GetCompletionsResult { + Ok(crate::request::GetCompletionsResult { completions: completions.clone(), }) } }); - lsp.handle_request::(move |_params, _cx| { + lsp.handle_request::(move |_params, _cx| { let completions_cycling = completions_cycling.clone(); async move { - Ok(copilot::request::GetCompletionsResult { + Ok(crate::request::GetCompletionsResult { completions: completions_cycling.clone(), }) } diff --git a/crates/copilot_ui/src/sign_in.rs b/crates/copilot/src/sign_in.rs similarity index 98% rename from crates/copilot_ui/src/sign_in.rs rename to crates/copilot/src/sign_in.rs index 396b2367f9..abf7252fef 100644 --- a/crates/copilot_ui/src/sign_in.rs +++ b/crates/copilot/src/sign_in.rs @@ -1,4 +1,4 @@ -use copilot::{request::PromptUserDeviceFlow, Copilot, Status}; +use crate::{request::PromptUserDeviceFlow, Copilot, Status}; use gpui::{ div, svg, AppContext, ClipboardItem, DismissEvent, Element, EventEmitter, FocusHandle, FocusableView, InteractiveElement, IntoElement, Model, MouseDownEvent, ParentElement, Render, @@ -26,7 +26,7 @@ impl EventEmitter for CopilotCodeVerification {} impl ModalView for CopilotCodeVerification {} impl CopilotCodeVerification { - pub(crate) fn new(copilot: &Model, cx: &mut ViewContext) -> Self { + pub fn new(copilot: &Model, cx: &mut ViewContext) -> Self { let status = copilot.read(cx).status(); Self { status, diff --git a/crates/copilot_ui/src/copilot_button.rs b/crates/copilot_ui/src/copilot_button.rs deleted file mode 100644 index b228a10839..0000000000 --- a/crates/copilot_ui/src/copilot_button.rs +++ /dev/null @@ -1,403 +0,0 @@ -use crate::sign_in::CopilotCodeVerification; -use anyhow::Result; -use copilot::{Copilot, SignOut, Status}; -use editor::{scroll::Autoscroll, Editor}; -use fs::Fs; -use gpui::{ - div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, - Render, Subscription, View, ViewContext, WeakView, WindowContext, -}; -use language::{ - language_settings::{self, all_language_settings, AllLanguageSettings}, - File, Language, -}; -use settings::{update_settings_file, Settings, SettingsStore}; -use std::{path::Path, sync::Arc}; -use util::{paths, ResultExt}; -use workspace::notifications::NotificationId; -use workspace::{ - create_and_open_local_file, - item::ItemHandle, - ui::{ - popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip, - }, - StatusItemView, Toast, Workspace, -}; -use zed_actions::OpenBrowser; - -const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; - -struct CopilotStartingToast; - -struct CopilotErrorToast; - -pub struct CopilotButton { - editor_subscription: Option<(Subscription, usize)>, - editor_enabled: Option, - language: Option>, - file: Option>, - fs: Arc, -} - -impl Render for CopilotButton { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let all_language_settings = all_language_settings(None, cx); - if !all_language_settings.copilot.feature_enabled { - return div(); - } - - let Some(copilot) = Copilot::global(cx) else { - return div(); - }; - let status = copilot.read(cx).status(); - - let enabled = self - .editor_enabled - .unwrap_or_else(|| all_language_settings.copilot_enabled(None, None)); - - let icon = match status { - Status::Error(_) => IconName::CopilotError, - Status::Authorized => { - if enabled { - IconName::Copilot - } else { - IconName::CopilotDisabled - } - } - _ => IconName::CopilotInit, - }; - - if let Status::Error(e) = status { - return div().child( - IconButton::new("copilot-error", icon) - .icon_size(IconSize::Small) - .on_click(cx.listener(move |_, _, cx| { - if let Some(workspace) = cx.window_handle().downcast::() { - workspace - .update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - format!("Copilot can't be started: {}", e), - ) - .on_click( - "Reinstall Copilot", - |cx| { - if let Some(copilot) = Copilot::global(cx) { - copilot - .update(cx, |copilot, cx| { - copilot.reinstall(cx) - }) - .detach(); - } - }, - ), - cx, - ); - }) - .ok(); - } - })) - .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), - ); - } - let this = cx.view().clone(); - - div().child( - popover_menu("copilot") - .menu(move |cx| match status { - Status::Authorized => { - Some(this.update(cx, |this, cx| this.build_copilot_menu(cx))) - } - _ => Some(this.update(cx, |this, cx| this.build_copilot_start_menu(cx))), - }) - .anchor(AnchorCorner::BottomRight) - .trigger( - IconButton::new("copilot-icon", icon) - .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), - ), - ) - } -} - -impl CopilotButton { - pub fn new(fs: Arc, cx: &mut ViewContext) -> Self { - if let Some(copilot) = Copilot::global(cx) { - cx.observe(&copilot, |_, _, cx| cx.notify()).detach() - } - - cx.observe_global::(move |_, cx| cx.notify()) - .detach(); - - Self { - editor_subscription: None, - editor_enabled: None, - language: None, - file: None, - fs, - } - } - - pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext) -> View { - let fs = self.fs.clone(); - ContextMenu::build(cx, |menu, _| { - menu.entry("Sign In", None, initiate_sign_in).entry( - "Disable Copilot", - None, - move |cx| hide_copilot(fs.clone(), cx), - ) - }) - } - - pub fn build_copilot_menu(&mut self, cx: &mut ViewContext) -> View { - let fs = self.fs.clone(); - - ContextMenu::build(cx, move |mut menu, cx| { - if let Some(language) = self.language.clone() { - let fs = fs.clone(); - let language_enabled = - language_settings::language_settings(Some(&language), None, cx) - .show_copilot_suggestions; - - menu = menu.entry( - format!( - "{} Suggestions for {}", - if language_enabled { "Hide" } else { "Show" }, - language.name() - ), - None, - move |cx| toggle_copilot_for_language(language.clone(), fs.clone(), cx), - ); - } - - let settings = AllLanguageSettings::get_global(cx); - - if let Some(file) = &self.file { - let path = file.path().clone(); - let path_enabled = settings.copilot_enabled_for_path(&path); - - menu = menu.entry( - format!( - "{} Suggestions for This Path", - if path_enabled { "Hide" } else { "Show" } - ), - None, - move |cx| { - if let Some(workspace) = cx.window_handle().downcast::() { - if let Ok(workspace) = workspace.root_view(cx) { - let workspace = workspace.downgrade(); - cx.spawn(|cx| { - configure_disabled_globs( - workspace, - path_enabled.then_some(path.clone()), - cx, - ) - }) - .detach_and_log_err(cx); - } - } - }, - ); - } - - let globally_enabled = settings.copilot_enabled(None, None); - menu.entry( - if globally_enabled { - "Hide Suggestions for All Files" - } else { - "Show Suggestions for All Files" - }, - None, - move |cx| toggle_copilot_globally(fs.clone(), cx), - ) - .separator() - .link( - "Copilot Settings", - OpenBrowser { - url: COPILOT_SETTINGS_URL.to_string(), - } - .boxed_clone(), - ) - .action("Sign Out", SignOut.boxed_clone()) - }) - } - - pub fn update_enabled(&mut self, editor: View, cx: &mut ViewContext) { - let editor = editor.read(cx); - let snapshot = editor.buffer().read(cx).snapshot(cx); - let suggestion_anchor = editor.selections.newest_anchor().start; - let language = snapshot.language_at(suggestion_anchor); - let file = snapshot.file_at(suggestion_anchor).cloned(); - self.editor_enabled = { - let file = file.as_ref(); - Some( - file.map(|file| !file.is_private()).unwrap_or(true) - && all_language_settings(file, cx) - .copilot_enabled(language, file.map(|file| file.path().as_ref())), - ) - }; - self.language = language.cloned(); - self.file = file; - - cx.notify() - } -} - -impl StatusItemView for CopilotButton { - fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext) { - if let Some(editor) = item.and_then(|item| item.act_as::(cx)) { - self.editor_subscription = Some(( - cx.observe(&editor, Self::update_enabled), - editor.entity_id().as_u64() as usize, - )); - self.update_enabled(editor, cx); - } else { - self.language = None; - self.editor_subscription = None; - self.editor_enabled = None; - } - cx.notify(); - } -} - -async fn configure_disabled_globs( - workspace: WeakView, - path_to_disable: Option>, - mut cx: AsyncWindowContext, -) -> Result<()> { - let settings_editor = workspace - .update(&mut cx, |_, cx| { - create_and_open_local_file(&paths::SETTINGS, cx, || { - settings::initial_user_settings_content().as_ref().into() - }) - })? - .await? - .downcast::() - .unwrap(); - - settings_editor.downgrade().update(&mut cx, |item, cx| { - let text = item.buffer().read(cx).snapshot(cx).text(); - - let settings = cx.global::(); - let edits = settings.edits_for_update::(&text, |file| { - let copilot = file.copilot.get_or_insert_with(Default::default); - let globs = copilot.disabled_globs.get_or_insert_with(|| { - settings - .get::(None) - .copilot - .disabled_globs - .iter() - .map(|glob| glob.glob().to_string()) - .collect() - }); - - if let Some(path_to_disable) = &path_to_disable { - globs.push(path_to_disable.to_string_lossy().into_owned()); - } else { - globs.clear(); - } - }); - - if !edits.is_empty() { - item.change_selections(Some(Autoscroll::newest()), cx, |selections| { - selections.select_ranges(edits.iter().map(|e| e.0.clone())); - }); - - // When *enabling* a path, don't actually perform an edit, just select the range. - if path_to_disable.is_some() { - item.edit(edits.iter().cloned(), cx); - } - } - })?; - - anyhow::Ok(()) -} - -fn toggle_copilot_globally(fs: Arc, cx: &mut AppContext) { - let show_copilot_suggestions = all_language_settings(None, cx).copilot_enabled(None, None); - update_settings_file::(fs, cx, move |file| { - file.defaults.show_copilot_suggestions = Some(!show_copilot_suggestions) - }); -} - -fn toggle_copilot_for_language(language: Arc, fs: Arc, cx: &mut AppContext) { - let show_copilot_suggestions = - all_language_settings(None, cx).copilot_enabled(Some(&language), None); - update_settings_file::(fs, cx, move |file| { - file.languages - .entry(language.name()) - .or_default() - .show_copilot_suggestions = Some(!show_copilot_suggestions); - }); -} - -fn hide_copilot(fs: Arc, cx: &mut AppContext) { - update_settings_file::(fs, cx, move |file| { - file.features.get_or_insert(Default::default()).copilot = Some(false); - }); -} - -pub fn initiate_sign_in(cx: &mut WindowContext) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - let status = copilot.read(cx).status(); - let Some(workspace) = cx.window_handle().downcast::() else { - return; - }; - match status { - Status::Starting { task } => { - let Some(workspace) = cx.window_handle().downcast::() else { - return; - }; - - let Ok(workspace) = workspace.update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "Copilot is starting...", - ), - cx, - ); - workspace.weak_handle() - }) else { - return; - }; - - cx.spawn(|mut cx| async move { - task.await; - if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() { - workspace - .update(&mut cx, |workspace, cx| match copilot.read(cx).status() { - Status::Authorized => workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "Copilot has started!", - ), - cx, - ), - _ => { - workspace.dismiss_toast( - &NotificationId::unique::(), - cx, - ); - copilot - .update(cx, |copilot, cx| copilot.sign_in(cx)) - .detach_and_log_err(cx); - } - }) - .log_err(); - } - }) - .detach(); - } - _ => { - copilot.update(cx, |this, cx| this.sign_in(cx)).detach(); - workspace - .update(cx, |this, cx| { - this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx)); - }) - .ok(); - } - } -} diff --git a/crates/copilot_ui/src/copilot_ui.rs b/crates/copilot_ui/src/copilot_ui.rs deleted file mode 100644 index 63bd03102f..0000000000 --- a/crates/copilot_ui/src/copilot_ui.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod copilot_button; -mod copilot_completion_provider; -mod sign_in; - -pub use copilot_button::*; -pub use copilot_completion_provider::*; -pub use sign_in::*; diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index ef5c97a592..bbd215b23b 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1757,19 +1757,22 @@ impl Editor { self.completion_provider = Some(hub); } - pub fn set_inline_completion_provider( + pub fn set_inline_completion_provider( &mut self, - provider: Model, + provider: Option>, cx: &mut ViewContext, - ) { - self.inline_completion_provider = Some(RegisteredInlineCompletionProvider { - _subscription: cx.observe(&provider, |this, _, cx| { - if this.focus_handle.is_focused(cx) { - this.update_visible_inline_completion(cx); - } - }), - provider: Arc::new(provider), - }); + ) where + T: InlineCompletionProvider, + { + self.inline_completion_provider = + provider.map(|provider| RegisteredInlineCompletionProvider { + _subscription: cx.observe(&provider, |this, _, cx| { + if this.focus_handle.is_focused(cx) { + this.update_visible_inline_completion(cx); + } + }), + provider: Arc::new(provider), + }); self.refresh_inline_completion(false, cx); } @@ -2676,7 +2679,7 @@ impl Editor { } drop(snapshot); - let had_active_copilot_completion = this.has_active_inline_completion(cx); + let had_active_inline_completion = this.has_active_inline_completion(cx); this.change_selections(Some(Autoscroll::fit()), cx, |s| s.select(new_selections)); if brace_inserted { @@ -2692,7 +2695,7 @@ impl Editor { } } - if had_active_copilot_completion { + if had_active_inline_completion { this.refresh_inline_completion(true, cx); if !this.has_active_inline_completion(cx) { this.trigger_completion_on_input(&text, cx); @@ -4005,7 +4008,7 @@ impl Editor { if !self.show_inline_completions || !provider.is_enabled(&buffer, cursor_buffer_position, cx) { - self.clear_inline_completion(cx); + self.discard_inline_completion(cx); return None; } @@ -4207,13 +4210,6 @@ impl Editor { self.discard_inline_completion(cx); } - fn clear_inline_completion(&mut self, cx: &mut ViewContext) { - if let Some(old_completion) = self.active_inline_completion.take() { - self.splice_inlays(vec![old_completion.id], Vec::new(), cx); - } - self.discard_inline_completion(cx); - } - fn inline_completion_provider(&self) -> Option> { Some(self.inline_completion_provider.as_ref()?.provider.clone()) } @@ -9947,12 +9943,14 @@ impl Editor { .raw_user_settings() .get("vim_mode") == Some(&serde_json::Value::Bool(true)); - let copilot_enabled = all_language_settings(file, cx).copilot_enabled(None, None); + + let copilot_enabled = all_language_settings(file, cx).inline_completions.provider + == language::language_settings::InlineCompletionProvider::Copilot; let copilot_enabled_for_language = self .buffer .read(cx) .settings_at(0, cx) - .show_copilot_suggestions; + .show_inline_completions; let telemetry = project.read(cx).client().telemetry().clone(); telemetry.report_editor_event( diff --git a/crates/editor/src/inline_completion_provider.rs b/crates/editor/src/inline_completion_provider.rs index 31edf80623..2fb2cb608f 100644 --- a/crates/editor/src/inline_completion_provider.rs +++ b/crates/editor/src/inline_completion_provider.rs @@ -25,11 +25,11 @@ pub trait InlineCompletionProvider: 'static + Sized { ); fn accept(&mut self, cx: &mut ModelContext); fn discard(&mut self, cx: &mut ModelContext); - fn active_completion_text( - &self, + fn active_completion_text<'a>( + &'a self, buffer: &Model, cursor_position: language::Anchor, - cx: &AppContext, + cx: &'a AppContext, ) -> Option<&str>; } @@ -57,7 +57,7 @@ pub trait InlineCompletionProviderHandle { fn accept(&self, cx: &mut AppContext); fn discard(&self, cx: &mut AppContext); fn active_completion_text<'a>( - &self, + &'a self, buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, @@ -110,7 +110,7 @@ where } fn active_completion_text<'a>( - &self, + &'a self, buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 4fe461981f..53ed5894de 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; @@ -5,8 +7,8 @@ use util::http::HttpClient; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; -pub async fn stream_generate_content( - client: &T, +pub async fn stream_generate_content( + client: Arc, api_url: &str, api_key: &str, request: GenerateContentRequest, diff --git a/crates/copilot_ui/Cargo.toml b/crates/inline_completion_button/Cargo.toml similarity index 88% rename from crates/copilot_ui/Cargo.toml rename to crates/inline_completion_button/Cargo.toml index 4bf3240aab..48acdb3ae1 100644 --- a/crates/copilot_ui/Cargo.toml +++ b/crates/inline_completion_button/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "copilot_ui" +name = "inline_completion_button" version = "0.1.0" edition = "2021" publish = false @@ -9,19 +9,18 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/copilot_ui.rs" +path = "src/inline_completion_button.rs" doctest = false [dependencies] anyhow.workspace = true -client.workspace = true copilot.workspace = true editor.workspace = true fs.workspace = true gpui.workspace = true language.workspace = true -menu.workspace = true settings.workspace = true +supermaven.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/copilot_ui/LICENSE-GPL b/crates/inline_completion_button/LICENSE-GPL similarity index 100% rename from crates/copilot_ui/LICENSE-GPL rename to crates/inline_completion_button/LICENSE-GPL diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs new file mode 100644 index 0000000000..86f6945ac1 --- /dev/null +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -0,0 +1,510 @@ +use anyhow::Result; +use copilot::{Copilot, CopilotCodeVerification, Status}; +use editor::{scroll::Autoscroll, Editor}; +use fs::Fs; +use gpui::{ + div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, + Render, Subscription, View, ViewContext, WeakView, WindowContext, +}; +use language::{ + language_settings::{ + self, all_language_settings, AllLanguageSettings, InlineCompletionProvider, + }, + File, Language, +}; +use settings::{update_settings_file, Settings, SettingsStore}; +use std::{path::Path, sync::Arc}; +use supermaven::{AccountStatus, Supermaven}; +use util::{paths, ResultExt}; +use workspace::{ + create_and_open_local_file, + item::ItemHandle, + notifications::NotificationId, + ui::{ + popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip, + }, + StatusItemView, Toast, Workspace, +}; +use zed_actions::OpenBrowser; + +const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; + +struct CopilotStartingToast; + +struct CopilotErrorToast; + +pub struct InlineCompletionButton { + editor_subscription: Option<(Subscription, usize)>, + editor_enabled: Option, + language: Option>, + file: Option>, + fs: Arc, +} + +enum SupermavenButtonStatus { + Ready, + Errored(String), + NeedsActivation(String), + Initializing, +} + +impl Render for InlineCompletionButton { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let all_language_settings = all_language_settings(None, cx); + + match all_language_settings.inline_completions.provider { + InlineCompletionProvider::None => return div(), + + InlineCompletionProvider::Copilot => { + let Some(copilot) = Copilot::global(cx) else { + return div(); + }; + let status = copilot.read(cx).status(); + + let enabled = self.editor_enabled.unwrap_or_else(|| { + all_language_settings.inline_completions_enabled(None, None) + }); + + let icon = match status { + Status::Error(_) => IconName::CopilotError, + Status::Authorized => { + if enabled { + IconName::Copilot + } else { + IconName::CopilotDisabled + } + } + _ => IconName::CopilotInit, + }; + + if let Status::Error(e) = status { + return div().child( + IconButton::new("copilot-error", icon) + .icon_size(IconSize::Small) + .on_click(cx.listener(move |_, _, cx| { + if let Some(workspace) = cx.window_handle().downcast::() + { + workspace + .update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + format!("Copilot can't be started: {}", e), + ) + .on_click("Reinstall Copilot", |cx| { + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| { + copilot.reinstall(cx) + }) + .detach(); + } + }), + cx, + ); + }) + .ok(); + } + })) + .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), + ); + } + let this = cx.view().clone(); + + div().child( + popover_menu("copilot") + .menu(move |cx| { + Some(match status { + Status::Authorized => { + this.update(cx, |this, cx| this.build_copilot_context_menu(cx)) + } + _ => this.update(cx, |this, cx| this.build_copilot_start_menu(cx)), + }) + }) + .anchor(AnchorCorner::BottomRight) + .trigger( + IconButton::new("copilot-icon", icon) + .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), + ), + ) + } + + InlineCompletionProvider::Supermaven => { + let Some(supermaven) = Supermaven::global(cx) else { + return div(); + }; + + let supermaven = supermaven.read(cx); + + let status = match supermaven { + Supermaven::Starting => SupermavenButtonStatus::Initializing, + Supermaven::FailedDownload { error } => { + SupermavenButtonStatus::Errored(error.to_string()) + } + Supermaven::Spawned(agent) => { + let account_status = agent.account_status.clone(); + match account_status { + AccountStatus::NeedsActivation { activate_url } => { + SupermavenButtonStatus::NeedsActivation(activate_url.clone()) + } + AccountStatus::Unknown => SupermavenButtonStatus::Initializing, + AccountStatus::Ready => SupermavenButtonStatus::Ready, + } + } + Supermaven::Error { error } => { + SupermavenButtonStatus::Errored(error.to_string()) + } + }; + + let icon = status.to_icon(); + let tooltip_text = status.to_tooltip(); + let this = cx.view().clone(); + + return div().child( + popover_menu("supermaven") + .menu(move |cx| match &status { + SupermavenButtonStatus::NeedsActivation(activate_url) => { + Some(ContextMenu::build(cx, |menu, _| { + let activate_url = activate_url.clone(); + menu.entry("Sign In", None, move |cx| { + cx.open_url(activate_url.as_str()) + }) + })) + } + SupermavenButtonStatus::Ready => Some( + this.update(cx, |this, cx| this.build_supermaven_context_menu(cx)), + ), + _ => None, + }) + .anchor(AnchorCorner::BottomRight) + .trigger( + IconButton::new("supermaven-icon", icon) + .tooltip(move |cx| Tooltip::text(tooltip_text.clone(), cx)), + ), + ); + } + } + } +} + +impl InlineCompletionButton { + pub fn new(fs: Arc, cx: &mut ViewContext) -> Self { + if let Some(copilot) = Copilot::global(cx) { + cx.observe(&copilot, |_, _, cx| cx.notify()).detach() + } + + cx.observe_global::(move |_, cx| cx.notify()) + .detach(); + + Self { + editor_subscription: None, + editor_enabled: None, + language: None, + file: None, + fs, + } + } + + pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext) -> View { + let fs = self.fs.clone(); + ContextMenu::build(cx, |menu, _| { + menu.entry("Sign In", None, initiate_sign_in).entry( + "Disable Copilot", + None, + move |cx| hide_copilot(fs.clone(), cx), + ) + }) + } + + pub fn build_language_settings_menu( + &self, + mut menu: ContextMenu, + cx: &mut WindowContext, + ) -> ContextMenu { + let fs = self.fs.clone(); + + if let Some(language) = self.language.clone() { + let fs = fs.clone(); + let language_enabled = language_settings::language_settings(Some(&language), None, cx) + .show_inline_completions; + + menu = menu.entry( + format!( + "{} Inline Completions for {}", + if language_enabled { "Hide" } else { "Show" }, + language.name() + ), + None, + move |cx| toggle_inline_completions_for_language(language.clone(), fs.clone(), cx), + ); + } + + let settings = AllLanguageSettings::get_global(cx); + + if let Some(file) = &self.file { + let path = file.path().clone(); + let path_enabled = settings.inline_completions_enabled_for_path(&path); + + menu = menu.entry( + format!( + "{} Inline Completions for This Path", + if path_enabled { "Hide" } else { "Show" } + ), + None, + move |cx| { + if let Some(workspace) = cx.window_handle().downcast::() { + if let Ok(workspace) = workspace.root_view(cx) { + let workspace = workspace.downgrade(); + cx.spawn(|cx| { + configure_disabled_globs( + workspace, + path_enabled.then_some(path.clone()), + cx, + ) + }) + .detach_and_log_err(cx); + } + } + }, + ); + } + + let globally_enabled = settings.inline_completions_enabled(None, None); + menu.entry( + if globally_enabled { + "Hide Inline Completions for All Files" + } else { + "Show Inline Completions for All Files" + }, + None, + move |cx| toggle_inline_completions_globally(fs.clone(), cx), + ) + } + + fn build_copilot_context_menu(&self, cx: &mut ViewContext) -> View { + ContextMenu::build(cx, |menu, cx| { + self.build_language_settings_menu(menu, cx) + .separator() + .link( + "Copilot Settings", + OpenBrowser { + url: COPILOT_SETTINGS_URL.to_string(), + } + .boxed_clone(), + ) + .action("Sign Out", copilot::SignOut.boxed_clone()) + }) + } + + fn build_supermaven_context_menu(&self, cx: &mut ViewContext) -> View { + ContextMenu::build(cx, |menu, cx| { + self.build_language_settings_menu(menu, cx).separator() + }) + } + + pub fn update_enabled(&mut self, editor: View, cx: &mut ViewContext) { + let editor = editor.read(cx); + let snapshot = editor.buffer().read(cx).snapshot(cx); + let suggestion_anchor = editor.selections.newest_anchor().start; + let language = snapshot.language_at(suggestion_anchor); + let file = snapshot.file_at(suggestion_anchor).cloned(); + self.editor_enabled = { + let file = file.as_ref(); + Some( + file.map(|file| !file.is_private()).unwrap_or(true) + && all_language_settings(file, cx).inline_completions_enabled( + language, + file.map(|file| file.path().as_ref()), + ), + ) + }; + self.language = language.cloned(); + self.file = file; + + cx.notify() + } +} + +impl StatusItemView for InlineCompletionButton { + fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext) { + if let Some(editor) = item.and_then(|item| item.act_as::(cx)) { + self.editor_subscription = Some(( + cx.observe(&editor, Self::update_enabled), + editor.entity_id().as_u64() as usize, + )); + self.update_enabled(editor, cx); + } else { + self.language = None; + self.editor_subscription = None; + self.editor_enabled = None; + } + cx.notify(); + } +} + +impl SupermavenButtonStatus { + fn to_icon(&self) -> IconName { + match self { + SupermavenButtonStatus::Ready => IconName::Supermaven, + SupermavenButtonStatus::Errored(_) => IconName::SupermavenError, + SupermavenButtonStatus::NeedsActivation(_) => IconName::SupermavenInit, + SupermavenButtonStatus::Initializing => IconName::SupermavenInit, + } + } + + fn to_tooltip(&self) -> String { + match self { + SupermavenButtonStatus::Ready => "Supermaven is ready".to_string(), + SupermavenButtonStatus::Errored(error) => format!("Supermaven error: {}", error), + SupermavenButtonStatus::NeedsActivation(_) => "Supermaven needs activation".to_string(), + SupermavenButtonStatus::Initializing => "Supermaven initializing".to_string(), + } + } +} + +async fn configure_disabled_globs( + workspace: WeakView, + path_to_disable: Option>, + mut cx: AsyncWindowContext, +) -> Result<()> { + let settings_editor = workspace + .update(&mut cx, |_, cx| { + create_and_open_local_file(&paths::SETTINGS, cx, || { + settings::initial_user_settings_content().as_ref().into() + }) + })? + .await? + .downcast::() + .unwrap(); + + settings_editor.downgrade().update(&mut cx, |item, cx| { + let text = item.buffer().read(cx).snapshot(cx).text(); + + let settings = cx.global::(); + let edits = settings.edits_for_update::(&text, |file| { + let copilot = file.inline_completions.get_or_insert_with(Default::default); + let globs = copilot.disabled_globs.get_or_insert_with(|| { + settings + .get::(None) + .inline_completions + .disabled_globs + .iter() + .map(|glob| glob.glob().to_string()) + .collect() + }); + + if let Some(path_to_disable) = &path_to_disable { + globs.push(path_to_disable.to_string_lossy().into_owned()); + } else { + globs.clear(); + } + }); + + if !edits.is_empty() { + item.change_selections(Some(Autoscroll::newest()), cx, |selections| { + selections.select_ranges(edits.iter().map(|e| e.0.clone())); + }); + + // When *enabling* a path, don't actually perform an edit, just select the range. + if path_to_disable.is_some() { + item.edit(edits.iter().cloned(), cx); + } + } + })?; + + anyhow::Ok(()) +} + +fn toggle_inline_completions_globally(fs: Arc, cx: &mut AppContext) { + let show_inline_completions = + all_language_settings(None, cx).inline_completions_enabled(None, None); + update_settings_file::(fs, cx, move |file| { + file.defaults.show_inline_completions = Some(!show_inline_completions) + }); +} + +fn toggle_inline_completions_for_language( + language: Arc, + fs: Arc, + cx: &mut AppContext, +) { + let show_inline_completions = + all_language_settings(None, cx).inline_completions_enabled(Some(&language), None); + update_settings_file::(fs, cx, move |file| { + file.languages + .entry(language.name()) + .or_default() + .show_inline_completions = Some(!show_inline_completions); + }); +} + +fn hide_copilot(fs: Arc, cx: &mut AppContext) { + update_settings_file::(fs, cx, move |file| { + file.features.get_or_insert(Default::default()).copilot = Some(false); + }); +} + +pub fn initiate_sign_in(cx: &mut WindowContext) { + let Some(copilot) = Copilot::global(cx) else { + return; + }; + let status = copilot.read(cx).status(); + let Some(workspace) = cx.window_handle().downcast::() else { + return; + }; + match status { + Status::Starting { task } => { + let Some(workspace) = cx.window_handle().downcast::() else { + return; + }; + + let Ok(workspace) = workspace.update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + "Copilot is starting...", + ), + cx, + ); + workspace.weak_handle() + }) else { + return; + }; + + cx.spawn(|mut cx| async move { + task.await; + if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() { + workspace + .update(&mut cx, |workspace, cx| match copilot.read(cx).status() { + Status::Authorized => workspace.show_toast( + Toast::new( + NotificationId::unique::(), + "Copilot has started!", + ), + cx, + ), + _ => { + workspace.dismiss_toast( + &NotificationId::unique::(), + cx, + ); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .detach_and_log_err(cx); + } + }) + .log_err(); + } + }) + .detach(); + } + _ => { + copilot.update(cx, |this, cx| this.sign_in(cx)).detach(); + workspace + .update(cx, |this, cx| { + this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx)); + }) + .ok(); + } + } +} diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index bea5344be2..537816b983 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -51,8 +51,8 @@ pub fn all_language_settings<'a>( /// The settings for all languages. #[derive(Debug, Clone)] pub struct AllLanguageSettings { - /// The settings for GitHub Copilot. - pub copilot: CopilotSettings, + /// The inline completion settings. + pub inline_completions: InlineCompletionSettings, defaults: LanguageSettings, languages: HashMap, LanguageSettings>, pub(crate) file_types: HashMap, Vec>, @@ -101,9 +101,9 @@ pub struct LanguageSettings { /// - `"!"` - A language server ID prefixed with a `!` will be disabled. /// - `"..."` - A placeholder to refer to the **rest** of the registered language servers for this language. pub language_servers: Vec>, - /// Controls whether Copilot provides suggestion immediately (true) - /// or waits for a `copilot::Toggle` (false). - pub show_copilot_suggestions: bool, + /// Controls whether inline completions are shown immediately (true) + /// or manually by triggering `editor::ShowInlineCompletion` (false). + pub show_inline_completions: bool, /// Whether to show tabs and spaces in the editor. pub show_whitespaces: ShowWhitespaceSetting, /// Whether to start a new line with a comment when a previous line is a comment as well. @@ -165,12 +165,23 @@ impl LanguageSettings { } } -/// The settings for [GitHub Copilot](https://github.com/features/copilot). +/// The provider that supplies inline completions. +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum InlineCompletionProvider { + None, + #[default] + Copilot, + Supermaven, +} + +/// The settings for inline completions, such as [GitHub Copilot](https://github.com/features/copilot) +/// or [Supermaven](https://supermaven.com). #[derive(Clone, Debug, Default)] -pub struct CopilotSettings { - /// Whether Copilot is enabled. - pub feature_enabled: bool, - /// A list of globs representing files that Copilot should be disabled for. +pub struct InlineCompletionSettings { + /// The provider that supplies inline completions. + pub provider: InlineCompletionProvider, + /// A list of globs representing files that inline completions should be disabled for. pub disabled_globs: Vec, } @@ -180,9 +191,9 @@ pub struct AllLanguageSettingsContent { /// The settings for enabling/disabling features. #[serde(default)] pub features: Option, - /// The settings for GitHub Copilot. - #[serde(default)] - pub copilot: Option, + /// The inline completion settings. + #[serde(default, alias = "copilot")] + pub inline_completions: Option, /// The default language settings. #[serde(flatten)] pub defaults: LanguageSettingsContent, @@ -277,12 +288,12 @@ pub struct LanguageSettingsContent { /// Default: ["..."] #[serde(default)] pub language_servers: Option>>, - /// Controls whether Copilot provides suggestion immediately (true) - /// or waits for a `copilot::Toggle` (false). + /// Controls whether inline completions are shown immediately (true) + /// or manually by triggering `editor::ShowInlineCompletion` (false). /// /// Default: true - #[serde(default)] - pub show_copilot_suggestions: Option, + #[serde(default, alias = "show_copilot_suggestions")] + pub show_inline_completions: Option, /// Whether to show tabs and spaces in the editor. #[serde(default)] pub show_whitespaces: Option, @@ -314,10 +325,10 @@ pub struct LanguageSettingsContent { pub code_actions_on_format: Option>, } -/// The contents of the GitHub Copilot settings. -#[derive(Clone, Debug, PartialEq, Default, Serialize, Deserialize, JsonSchema)] -pub struct CopilotSettingsContent { - /// A list of globs representing files that Copilot should be disabled for. +/// The contents of the inline completion settings. +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] +pub struct InlineCompletionSettingsContent { + /// A list of globs representing files that inline completions should be disabled for. #[serde(default)] pub disabled_globs: Option>, } @@ -328,6 +339,8 @@ pub struct CopilotSettingsContent { pub struct FeaturesContent { /// Whether the GitHub Copilot feature is enabled. pub copilot: Option, + /// Determines which inline completion provider to use. + pub inline_completion_provider: Option, } /// Controls the soft-wrapping behavior in the editor. @@ -475,29 +488,29 @@ impl AllLanguageSettings { &self.defaults } - /// Returns whether GitHub Copilot is enabled for the given path. - pub fn copilot_enabled_for_path(&self, path: &Path) -> bool { + /// Returns whether inline completions are enabled for the given path. + pub fn inline_completions_enabled_for_path(&self, path: &Path) -> bool { !self - .copilot + .inline_completions .disabled_globs .iter() .any(|glob| glob.is_match(path)) } - /// Returns whether GitHub Copilot is enabled for the given language and path. - pub fn copilot_enabled(&self, language: Option<&Arc>, path: Option<&Path>) -> bool { - if !self.copilot.feature_enabled { - return false; - } - + /// Returns whether inline completions are enabled for the given language and path. + pub fn inline_completions_enabled( + &self, + language: Option<&Arc>, + path: Option<&Path>, + ) -> bool { if let Some(path) = path { - if !self.copilot_enabled_for_path(path) { + if !self.inline_completions_enabled_for_path(path) { return false; } } self.language(language.map(|l| l.name()).as_deref()) - .show_copilot_suggestions + .show_inline_completions } } @@ -551,13 +564,13 @@ impl settings::Settings for AllLanguageSettings { languages.insert(language_name.clone(), language_settings); } - let mut copilot_enabled = default_value + let mut copilot_enabled = default_value.features.as_ref().and_then(|f| f.copilot); + let mut inline_completion_provider = default_value .features .as_ref() - .and_then(|f| f.copilot) - .ok_or_else(Self::missing_default)?; - let mut copilot_globs = default_value - .copilot + .and_then(|f| f.inline_completion_provider); + let mut completion_globs = default_value + .inline_completions .as_ref() .and_then(|c| c.disabled_globs.as_ref()) .ok_or_else(Self::missing_default)?; @@ -565,14 +578,21 @@ impl settings::Settings for AllLanguageSettings { let mut file_types: HashMap, Vec> = HashMap::default(); for user_settings in sources.customizations() { if let Some(copilot) = user_settings.features.as_ref().and_then(|f| f.copilot) { - copilot_enabled = copilot; + copilot_enabled = Some(copilot); + } + if let Some(provider) = user_settings + .features + .as_ref() + .and_then(|f| f.inline_completion_provider) + { + inline_completion_provider = Some(provider); } if let Some(globs) = user_settings - .copilot + .inline_completions .as_ref() .and_then(|f| f.disabled_globs.as_ref()) { - copilot_globs = globs; + completion_globs = globs; } // A user's global settings override the default global settings and @@ -601,9 +621,15 @@ impl settings::Settings for AllLanguageSettings { } Ok(Self { - copilot: CopilotSettings { - feature_enabled: copilot_enabled, - disabled_globs: copilot_globs + inline_completions: InlineCompletionSettings { + provider: if let Some(provider) = inline_completion_provider { + provider + } else if copilot_enabled.unwrap_or(true) { + InlineCompletionProvider::Copilot + } else { + InlineCompletionProvider::None + }, + disabled_globs: completion_globs .iter() .filter_map(|g| Some(globset::Glob::new(g).ok()?.compile_matcher())) .collect(), @@ -714,8 +740,8 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent ); merge(&mut settings.language_servers, src.language_servers.clone()); merge( - &mut settings.show_copilot_suggestions, - src.show_copilot_suggestions, + &mut settings.show_inline_completions, + src.show_inline_completions, ); merge(&mut settings.show_whitespaces, src.show_whitespaces); merge( diff --git a/crates/language_tools/Cargo.toml b/crates/language_tools/Cargo.toml index 6d0a1199b3..d85f5a6e52 100644 --- a/crates/language_tools/Cargo.toml +++ b/crates/language_tools/Cargo.toml @@ -15,6 +15,7 @@ doctest = false [dependencies] anyhow.workspace = true collections.workspace = true +copilot.workspace = true editor.workspace = true futures.workspace = true gpui.workspace = true @@ -26,7 +27,6 @@ settings.workspace = true theme.workspace = true tree-sitter.workspace = true ui.workspace = true -util.workspace = true workspace.workspace = true [dev-dependencies] diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index a35d8b33e5..28a27aac60 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -1,4 +1,5 @@ use collections::{HashMap, VecDeque}; +use copilot::Copilot; use editor::{actions::MoveToEnd, Editor, EditorEvent}; use futures::{channel::mpsc, StreamExt}; use gpui::{ @@ -7,11 +8,10 @@ use gpui::{ View, ViewContext, VisualContext, WeakModel, WindowContext, }; use language::{LanguageServerId, LanguageServerName}; -use lsp::IoKind; +use lsp::{IoKind, LanguageServer}; use project::{search::SearchQuery, Project}; use std::{borrow::Cow, sync::Arc}; use ui::{popover_menu, prelude::*, Button, Checkbox, ContextMenu, Label, Selection}; -use util::maybe; use workspace::{ item::{Item, ItemHandle, TabContentParams}, searchable::{SearchEvent, SearchableItem, SearchableItemHandle}, @@ -24,17 +24,21 @@ const MAX_STORED_LOG_ENTRIES: usize = 2000; pub struct LogStore { projects: HashMap, ProjectState>, - io_tx: mpsc::UnboundedSender<(WeakModel, LanguageServerId, IoKind, String)>, + language_servers: HashMap, + copilot_log_subscription: Option, + _copilot_subscription: Option, + io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>, } struct ProjectState { - servers: HashMap, _subscriptions: [gpui::Subscription; 2], } struct LanguageServerState { + name: LanguageServerName, log_messages: VecDeque, rpc_state: Option, + project: Option>, _io_logs_subscription: Option, _lsp_logs_subscription: Option, } @@ -109,15 +113,55 @@ pub fn init(cx: &mut AppContext) { impl LogStore { pub fn new(cx: &mut ModelContext) -> Self { let (io_tx, mut io_rx) = mpsc::unbounded(); + + let copilot_subscription = Copilot::global(cx).map(|copilot| { + let copilot = &copilot; + cx.subscribe( + copilot, + |this, copilot, copilot_event, cx| match copilot_event { + copilot::Event::CopilotLanguageServerStarted => { + if let Some(server) = copilot.read(cx).language_server() { + let server_id = server.server_id(); + let weak_this = cx.weak_model(); + this.copilot_log_subscription = + Some(server.on_notification::( + move |params, mut cx| { + weak_this + .update(&mut cx, |this, cx| { + this.add_language_server_log( + server_id, + ¶ms.message, + cx, + ); + }) + .ok(); + }, + )); + this.add_language_server( + None, + LanguageServerName(Arc::from("copilot")), + server.clone(), + cx, + ); + } + } + }, + ) + }); + let this = Self { + copilot_log_subscription: None, + _copilot_subscription: copilot_subscription, projects: HashMap::default(), + language_servers: HashMap::default(), io_tx, }; + cx.spawn(|this, mut cx| async move { - while let Some((project, server_id, io_kind, message)) = io_rx.next().await { + while let Some((server_id, io_kind, message)) = io_rx.next().await { if let Some(this) = this.upgrade() { this.update(&mut cx, |this, cx| { - this.on_io(project, server_id, io_kind, &message, cx); + this.on_io(server_id, io_kind, &message, cx); })?; } } @@ -132,20 +176,32 @@ impl LogStore { self.projects.insert( project.downgrade(), ProjectState { - servers: HashMap::default(), _subscriptions: [ cx.observe_release(project, move |this, _, _| { this.projects.remove(&weak_project); + this.language_servers + .retain(|_, state| state.project.as_ref() != Some(&weak_project)); }), cx.subscribe(project, |this, project, event, cx| match event { project::Event::LanguageServerAdded(id) => { - this.add_language_server(&project, *id, cx); + let read_project = project.read(cx); + if let Some((server, adapter)) = read_project + .language_server_for_id(*id) + .zip(read_project.language_server_adapter_for_id(*id)) + { + this.add_language_server( + Some(&project.downgrade()), + adapter.name.clone(), + server, + cx, + ); + } } project::Event::LanguageServerRemoved(id) => { - this.remove_language_server(&project, *id, cx); + this.remove_language_server(*id, cx); } project::Event::LanguageServerLog(id, message) => { - this.add_language_server_log(&project, *id, message, cx); + this.add_language_server_log(*id, message, cx); } _ => {} }), @@ -154,74 +210,69 @@ impl LogStore { ); } + fn get_language_server_state( + &mut self, + id: LanguageServerId, + ) -> Option<&mut LanguageServerState> { + self.language_servers.get_mut(&id) + } + fn add_language_server( &mut self, - project: &Model, - id: LanguageServerId, + project: Option<&WeakModel>, + name: LanguageServerName, + server: Arc, cx: &mut ModelContext, ) -> Option<&mut LanguageServerState> { - let project_state = self.projects.get_mut(&project.downgrade())?; - let server_state = project_state.servers.entry(id).or_insert_with(|| { - cx.notify(); - LanguageServerState { - rpc_state: None, - log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), - _io_logs_subscription: None, - _lsp_logs_subscription: None, - } - }); + let server_state = self + .language_servers + .entry(server.server_id()) + .or_insert_with(|| { + cx.notify(); + LanguageServerState { + name, + rpc_state: None, + project: project.cloned(), + log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), + _io_logs_subscription: None, + _lsp_logs_subscription: None, + } + }); - let server = project.read(cx).language_server_for_id(id); - if let Some(server) = server.as_deref() { - if server.has_notification_handler::() { - // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. - return Some(server_state); - } + if server.has_notification_handler::() { + // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. + return Some(server_state); } - let weak_project = project.downgrade(); let io_tx = self.io_tx.clone(); - server_state._io_logs_subscription = server.as_ref().map(|server| { - server.on_io(move |io_kind, message| { - io_tx - .unbounded_send((weak_project.clone(), id, io_kind, message.to_string())) - .ok(); - }) - }); + let server_id = server.server_id(); + server_state._io_logs_subscription = Some(server.on_io(move |io_kind, message| { + io_tx + .unbounded_send((server_id, io_kind, message.to_string())) + .ok(); + })); let this = cx.handle().downgrade(); - let weak_project = project.downgrade(); - server_state._lsp_logs_subscription = server.map(|server| { - let server_id = server.server_id(); - server.on_notification::({ + server_state._lsp_logs_subscription = + Some(server.on_notification::({ move |params, mut cx| { - if let Some((project, this)) = weak_project.upgrade().zip(this.upgrade()) { + if let Some(this) = this.upgrade() { this.update(&mut cx, |this, cx| { - this.add_language_server_log(&project, server_id, ¶ms.message, cx); + this.add_language_server_log(server_id, ¶ms.message, cx); }) .ok(); } } - }) - }); + })); Some(server_state) } fn add_language_server_log( &mut self, - project: &Model, id: LanguageServerId, message: &str, cx: &mut ModelContext, ) -> Option<()> { - let language_server_state = match self - .projects - .get_mut(&project.downgrade())? - .servers - .get_mut(&id) - { - Some(existing_state) => existing_state, - None => self.add_language_server(&project, id, cx)?, - }; + let language_server_state = self.get_language_server_state(id)?; let log_lines = &mut language_server_state.log_messages; while log_lines.len() >= MAX_STORED_LOG_ENTRIES { @@ -238,38 +289,43 @@ impl LogStore { Some(()) } - fn remove_language_server( - &mut self, - project: &Model, - id: LanguageServerId, - cx: &mut ModelContext, - ) -> Option<()> { - let project_state = self.projects.get_mut(&project.downgrade())?; - project_state.servers.remove(&id); + fn remove_language_server(&mut self, id: LanguageServerId, cx: &mut ModelContext) { + self.language_servers.remove(&id); cx.notify(); - Some(()) } - fn server_logs( - &self, - project: &Model, - server_id: LanguageServerId, - ) -> Option<&VecDeque> { - let weak_project = project.downgrade(); - let project_state = self.projects.get(&weak_project)?; - let server_state = project_state.servers.get(&server_id)?; - Some(&server_state.log_messages) + fn server_logs(&self, server_id: LanguageServerId) -> Option<&VecDeque> { + Some(&self.language_servers.get(&server_id)?.log_messages) + } + + fn server_ids_for_project<'a>( + &'a self, + project: &'a WeakModel, + ) -> impl Iterator + 'a { + [].into_iter() + .chain(self.language_servers.iter().filter_map(|(id, state)| { + if state.project.as_ref() == Some(project) { + return Some(*id); + } else { + None + } + })) + .chain(self.language_servers.iter().filter_map(|(id, state)| { + if state.project.is_none() { + return Some(*id); + } else { + None + } + })) } fn enable_rpc_trace_for_language_server( &mut self, - project: &Model, server_id: LanguageServerId, ) -> Option<&mut LanguageServerRpcState> { - let weak_project = project.downgrade(); - let project_state = self.projects.get_mut(&weak_project)?; - let server_state = project_state.servers.get_mut(&server_id)?; - let rpc_state = server_state + let rpc_state = self + .language_servers + .get_mut(&server_id)? .rpc_state .get_or_insert_with(|| LanguageServerRpcState { rpc_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), @@ -280,20 +336,14 @@ impl LogStore { pub fn disable_rpc_trace_for_language_server( &mut self, - project: &Model, server_id: LanguageServerId, - _: &mut ModelContext, ) -> Option<()> { - let project = project.downgrade(); - let project_state = self.projects.get_mut(&project)?; - let server_state = project_state.servers.get_mut(&server_id)?; - server_state.rpc_state.take(); + self.language_servers.get_mut(&server_id)?.rpc_state.take(); Some(()) } fn on_io( &mut self, - project: WeakModel, language_server_id: LanguageServerId, io_kind: IoKind, message: &str, @@ -303,18 +353,14 @@ impl LogStore { IoKind::StdOut => true, IoKind::StdIn => false, IoKind::StdErr => { - let project = project.upgrade()?; let message = format!("stderr: {}", message.trim()); - self.add_language_server_log(&project, language_server_id, &message, cx); + self.add_language_server_log(language_server_id, &message, cx); return Some(()); } }; let state = self - .projects - .get_mut(&project)? - .servers - .get_mut(&language_server_id)? + .get_language_server_state(language_server_id)? .rpc_state .as_mut()?; let kind = if is_received { @@ -360,42 +406,40 @@ impl LspLogView { ) -> Self { let server_id = log_store .read(cx) - .projects - .get(&project.downgrade()) - .and_then(|project| project.servers.keys().copied().next()); - let model_changes_subscription = cx.observe(&log_store, |this, store, cx| { - maybe!({ - let project_state = store.read(cx).projects.get(&this.project.downgrade())?; - if let Some(current_lsp) = this.current_server_id { - if !project_state.servers.contains_key(¤t_lsp) { - if let Some(server) = project_state.servers.iter().next() { - if this.is_showing_rpc_trace { - this.show_rpc_trace_for_server(*server.0, cx) - } else { - this.show_logs_for_server(*server.0, cx) - } - } else { - this.current_server_id = None; - this.editor.update(cx, |editor, cx| { - editor.set_read_only(false); - editor.clear(cx); - editor.set_read_only(true); - }); - cx.notify(); - } - } - } else { - if let Some(server) = project_state.servers.iter().next() { + .language_servers + .iter() + .find(|(_, server)| server.project == Some(project.downgrade())) + .map(|(id, _)| *id); + + let weak_project = project.downgrade(); + let model_changes_subscription = cx.observe(&log_store, move |this, store, cx| { + let first_server_id_for_project = + store.read(cx).server_ids_for_project(&weak_project).next(); + if let Some(current_lsp) = this.current_server_id { + if !store.read(cx).language_servers.contains_key(¤t_lsp) { + if let Some(server_id) = first_server_id_for_project { if this.is_showing_rpc_trace { - this.show_rpc_trace_for_server(*server.0, cx) + this.show_rpc_trace_for_server(server_id, cx) } else { - this.show_logs_for_server(*server.0, cx) + this.show_logs_for_server(server_id, cx) } + } else { + this.current_server_id = None; + this.editor.update(cx, |editor, cx| { + editor.set_read_only(false); + editor.clear(cx); + editor.set_read_only(true); + }); + cx.notify(); } } - - Some(()) - }); + } else if let Some(server_id) = first_server_id_for_project { + if this.is_showing_rpc_trace { + this.show_rpc_trace_for_server(server_id, cx) + } else { + this.show_logs_for_server(server_id, cx) + } + } cx.notify(); }); @@ -477,14 +521,14 @@ impl LspLogView { pub(crate) fn menu_items<'a>(&'a self, cx: &'a AppContext) -> Option> { let log_store = self.log_store.read(cx); - let state = log_store.projects.get(&self.project.downgrade())?; + let mut rows = self .project .read(cx) .language_servers() .filter_map(|(server_id, language_server_name, worktree_id)| { let worktree = self.project.read(cx).worktree_for_id(worktree_id, cx)?; - let state = state.servers.get(&server_id)?; + let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, server_name: language_server_name, @@ -501,7 +545,7 @@ impl LspLogView { .read(cx) .supplementary_language_servers() .filter_map(|(&server_id, (name, _))| { - let state = state.servers.get(&server_id)?; + let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, server_name: name.clone(), @@ -514,6 +558,27 @@ impl LspLogView { }) }), ) + .chain( + log_store + .language_servers + .iter() + .filter_map(|(server_id, state)| { + if state.project.is_none() { + Some(LogMenuItem { + server_id: *server_id, + server_name: state.name.clone(), + worktree_root_name: "supplementary".to_string(), + rpc_trace_enabled: state.rpc_state.is_some(), + rpc_trace_selected: self.is_showing_rpc_trace + && self.current_server_id == Some(*server_id), + logs_selected: !self.is_showing_rpc_trace + && self.current_server_id == Some(*server_id), + }) + } else { + None + } + }), + ) .collect::>(); rows.sort_by_key(|row| row.server_id); rows.dedup_by_key(|row| row.server_id); @@ -524,7 +589,7 @@ impl LspLogView { let log_contents = self .log_store .read(cx) - .server_logs(&self.project, server_id) + .server_logs(server_id) .map(log_contents); if let Some(log_contents) = log_contents { self.current_server_id = Some(server_id); @@ -544,7 +609,7 @@ impl LspLogView { ) { let rpc_log = self.log_store.update(cx, |log_store, _| { log_store - .enable_rpc_trace_for_language_server(&self.project, server_id) + .enable_rpc_trace_for_language_server(server_id) .map(|state| log_contents(&state.rpc_messages)) }); if let Some(rpc_log) = rpc_log { @@ -585,11 +650,11 @@ impl LspLogView { enabled: bool, cx: &mut ViewContext, ) { - self.log_store.update(cx, |log_store, cx| { + self.log_store.update(cx, |log_store, _| { if enabled { - log_store.enable_rpc_trace_for_language_server(&self.project, server_id); + log_store.enable_rpc_trace_for_language_server(server_id); } else { - log_store.disable_rpc_trace_for_language_server(&self.project, server_id, cx); + log_store.disable_rpc_trace_for_language_server(server_id); } }); if !enabled && Some(server_id) == self.current_server_id { diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 30766a7b6f..1d943bc080 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -30,7 +30,6 @@ async-trait.workspace = true client.workspace = true clock.workspace = true collections.workspace = true -copilot.workspace = true fs.workspace = true futures.workspace = true fuzzy.workspace = true diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 733a06172b..28c6182016 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -20,7 +20,6 @@ use client::{ }; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque}; -use copilot::Copilot; use debounced_delay::DebouncedDelay; use futures::{ channel::{ @@ -200,8 +199,6 @@ pub struct Project { _maintain_buffer_languages: Task<()>, _maintain_workspace_config: Task>, terminals: Terminals, - copilot_lsp_subscription: Option, - copilot_log_subscription: Option, current_lsp_settings: HashMap, LspSettings>, node: Option>, default_prettier: DefaultPrettier, @@ -685,8 +682,6 @@ impl Project { let (tx, rx) = mpsc::unbounded(); cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx)) .detach(); - let copilot_lsp_subscription = - Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx)); let tasks = Inventory::new(cx); Self { @@ -735,8 +730,6 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, - copilot_lsp_subscription, - copilot_log_subscription: None, current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(), node: Some(node), default_prettier: DefaultPrettier::default(), @@ -823,8 +816,6 @@ impl Project { let (tx, rx) = mpsc::unbounded(); cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx)) .detach(); - let copilot_lsp_subscription = - Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx)); let mut this = Self { worktrees: Vec::new(), buffer_ordered_messages_tx: tx, @@ -891,8 +882,6 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, - copilot_lsp_subscription, - copilot_log_subscription: None, current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(), node: None, default_prettier: DefaultPrettier::default(), @@ -1184,17 +1173,6 @@ impl Project { self.restart_language_servers(worktree, language, cx); } - if self.copilot_lsp_subscription.is_none() { - if let Some(copilot) = Copilot::global(cx) { - for buffer in self.opened_buffers.values() { - if let Some(buffer) = buffer.upgrade() { - self.register_buffer_with_copilot(&buffer, cx); - } - } - self.copilot_lsp_subscription = Some(subscribe_for_copilot_events(&copilot, cx)); - } - } - cx.notify(); } @@ -2351,7 +2329,7 @@ impl Project { self.detect_language_for_buffer(buffer, cx); self.register_buffer_with_language_servers(buffer, cx); - self.register_buffer_with_copilot(buffer, cx); + // self.register_buffer_with_copilot(buffer, cx); cx.observe_release(buffer, |this, buffer, cx| { if let Some(file) = File::from_dyn(buffer.file()) { if file.is_local() { @@ -2500,15 +2478,15 @@ impl Project { }); } - fn register_buffer_with_copilot( - &self, - buffer_handle: &Model, - cx: &mut ModelContext, - ) { - if let Some(copilot) = Copilot::global(cx) { - copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); - } - } + // fn register_buffer_with_copilot( + // &self, + // buffer_handle: &Model, + // cx: &mut ModelContext, + // ) { + // if let Some(copilot) = Copilot::global(cx) { + // copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); + // } + // } async fn send_buffer_ordered_messages( this: WeakModel, @@ -10475,43 +10453,6 @@ async fn search_ignored_entry( } } -fn subscribe_for_copilot_events( - copilot: &Model, - cx: &mut ModelContext<'_, Project>, -) -> gpui::Subscription { - cx.subscribe( - copilot, - |project, copilot, copilot_event, cx| match copilot_event { - copilot::Event::CopilotLanguageServerStarted => { - match copilot.read(cx).language_server() { - Some((name, copilot_server)) => { - // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. - if !copilot_server.has_notification_handler::() { - let new_server_id = copilot_server.server_id(); - let weak_project = cx.weak_model(); - let copilot_log_subscription = copilot_server - .on_notification::( - move |params, mut cx| { - weak_project.update(&mut cx, |_, cx| { - cx.emit(Event::LanguageServerLog( - new_server_id, - params.message, - )); - }).ok(); - }, - ); - project.supplementary_language_servers.insert(new_server_id, (name.clone(), Arc::clone(copilot_server))); - project.copilot_log_subscription = Some(copilot_log_subscription); - cx.emit(Event::LanguageServerAdded(new_server_id)); - } - } - None => debug_panic!("Received Copilot language server started event, but no language server is running"), - } - } - }, - ) -} - fn glob_literal_prefix(glob: &str) -> &str { let mut literal_end = 0; for (i, part) in glob.split(path::MAIN_SEPARATOR).enumerate() { diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 3dfa9508dc..5f8af8e1f0 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -207,7 +207,7 @@ message Envelope { GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; - ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max + ComputeEmbeddingsResponse compute_embeddings_response = 192; UpdateChannelMessage update_channel_message = 170; ChannelMessageUpdate channel_message_update = 171; @@ -238,7 +238,10 @@ message Envelope { ValidateDevServerProjectRequest validate_dev_server_project_request = 194; DeleteDevServer delete_dev_server = 195; OpenNewBuffer open_new_buffer = 196; - DeleteDevServerProject delete_dev_server_project = 197; // Current max + DeleteDevServerProject delete_dev_server_project = 197; + + GetSupermavenApiKey get_supermaven_api_key = 198; + GetSupermavenApiKeyResponse get_supermaven_api_key_response = 199; // current max } reserved 158 to 161; @@ -2084,3 +2087,9 @@ message LspResponse { GetCodeActionsResponse get_code_actions_response = 2; } } + +message GetSupermavenApiKey {} + +message GetSupermavenApiKeyResponse { + string api_key = 1; +} diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 966a24ead9..d011f1d1d2 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -201,6 +201,8 @@ messages!( (GetProjectSymbolsResponse, Background), (GetReferences, Background), (GetReferencesResponse, Background), + (GetSupermavenApiKey, Background), + (GetSupermavenApiKeyResponse, Background), (GetTypeDefinition, Background), (GetTypeDefinitionResponse, Background), (GetImplementation, Background), @@ -360,6 +362,7 @@ request_messages!( (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), (GetReferences, GetReferencesResponse), + (GetSupermavenApiKey, GetSupermavenApiKeyResponse), (GetTypeDefinition, GetTypeDefinitionResponse), (GetUsers, UsersResponse), (IncomingCall, Ack), diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml new file mode 100644 index 0000000000..4abbcd4a43 --- /dev/null +++ b/crates/supermaven/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "supermaven" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/supermaven.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +gpui.workspace = true +futures.workspace = true +language.workspace = true +log.workspace = true +postage.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +supermaven_api.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true + +[dev-dependencies] +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } diff --git a/crates/supermaven/src/messages.rs b/crates/supermaven/src/messages.rs new file mode 100644 index 0000000000..9082e00d60 --- /dev/null +++ b/crates/supermaven/src/messages.rs @@ -0,0 +1,152 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SetApiKey { + pub api_key: String, +} + +// Outbound messages +#[derive(Debug, Serialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum OutboundMessage { + SetApiKey(SetApiKey), + StateUpdate(StateUpdateMessage), + #[allow(dead_code)] + UseFreeVersion, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct StateUpdateMessage { + pub new_id: String, + pub updates: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum StateUpdate { + FileUpdate(FileUpdateMessage), + CursorUpdate(CursorPositionUpdateMessage), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct FileUpdateMessage { + pub path: String, + pub content: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct CursorPositionUpdateMessage { + pub path: String, + pub offset: usize, +} + +// Inbound messages coming in on stdout + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ResponseItem { + // A completion + Text { text: String }, + // Vestigial message type from old versions -- safe to ignore + Del { text: String }, + // Be able to delete whitespace prior to the cursor, likely for the rest of the completion + Dedent { text: String }, + // When the completion is over + End, + // Got the closing parentheses and shouldn't show any more after + Barrier, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenResponse { + pub state_id: String, + pub items: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenMetadataMessage { + pub dust_strings: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenTaskUpdateMessage { + pub task: String, + pub status: TaskStatus, + pub percent_complete: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TaskStatus { + InProgress, + Complete, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenActiveRepoMessage { + pub repo_simple_name: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum SupermavenPopupAction { + OpenUrl { label: String, url: String }, + NoOp { label: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct SupermavenPopupMessage { + pub message: String, + pub actions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "camelCase")] +pub struct ActivationRequest { + pub activate_url: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenSetMessage { + pub key: String, + pub value: serde_json::Value, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum ServiceTier { + FreeNoLicense, + #[serde(other)] + Unknown, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum SupermavenMessage { + Response(SupermavenResponse), + Metadata(SupermavenMetadataMessage), + Apology { + message: Option, + }, + ActivationRequest(ActivationRequest), + ActivationSuccess, + Passthrough { + passthrough: Box, + }, + Popup(SupermavenPopupMessage), + TaskStatus(SupermavenTaskUpdateMessage), + ActiveRepo(SupermavenActiveRepoMessage), + ServiceTier { + service_tier: ServiceTier, + }, + + Set(SupermavenSetMessage), + #[serde(other)] + Unknown, +} diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs new file mode 100644 index 0000000000..c432116357 --- /dev/null +++ b/crates/supermaven/src/supermaven.rs @@ -0,0 +1,345 @@ +mod messages; +mod supermaven_completion_provider; + +pub use supermaven_completion_provider::*; + +use anyhow::{Context as _, Result}; +#[allow(unused_imports)] +use client::{proto, Client}; +use collections::BTreeMap; + +use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt}; +use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel}; +use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset}; +use messages::*; +use postage::watch; +use serde::{Deserialize, Serialize}; +use settings::SettingsStore; +use smol::{ + io::AsyncWriteExt, + process::{Child, ChildStdin, ChildStdout, Command}, +}; +use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc}; +use ui::prelude::*; +use util::ResultExt; + +pub fn init(client: Arc, cx: &mut AppContext) { + let supermaven = cx.new_model(|_| Supermaven::Starting); + Supermaven::set_global(supermaven.clone(), cx); + + let mut provider = all_language_settings(None, cx).inline_completions.provider; + if provider == language::language_settings::InlineCompletionProvider::Supermaven { + supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx)); + } + + cx.observe_global::(move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + if provider == language::language_settings::InlineCompletionProvider::Supermaven { + supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx)); + } else { + supermaven.update(cx, |supermaven, _cx| supermaven.stop()); + } + } + }) + .detach(); +} + +pub enum Supermaven { + Starting, + FailedDownload { error: anyhow::Error }, + Spawned(SupermavenAgent), + Error { error: anyhow::Error }, +} + +#[derive(Clone)] +pub enum AccountStatus { + Unknown, + NeedsActivation { activate_url: String }, + Ready, +} + +#[derive(Clone)] +struct SupermavenGlobal(Model); + +impl Global for SupermavenGlobal {} + +impl Supermaven { + pub fn global(cx: &AppContext) -> Option> { + cx.try_global::() + .map(|model| model.0.clone()) + } + + pub fn set_global(supermaven: Model, cx: &mut AppContext) { + cx.set_global(SupermavenGlobal(supermaven)); + } + + pub fn start(&mut self, client: Arc, cx: &mut ModelContext) { + if let Self::Starting = self { + cx.spawn(|this, mut cx| async move { + let binary_path = + supermaven_api::get_supermaven_agent_path(client.http_client()).await?; + + this.update(&mut cx, |this, cx| { + if let Self::Starting = this { + *this = + Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?); + } + anyhow::Ok(()) + }) + }) + .detach_and_log_err(cx) + } + } + + pub fn stop(&mut self) { + *self = Self::Starting; + } + + pub fn is_enabled(&self) -> bool { + matches!(self, Self::Spawned { .. }) + } + + pub fn complete( + &mut self, + buffer: &Model, + cursor_position: Anchor, + cx: &AppContext, + ) -> Option { + if let Self::Spawned(agent) = self { + let buffer_id = buffer.entity_id(); + let buffer = buffer.read(cx); + let path = buffer + .file() + .and_then(|file| Some(file.as_local()?.abs_path(cx))) + .unwrap_or_else(|| PathBuf::from("untitled")) + .to_string_lossy() + .to_string(); + let content = buffer.text(); + let offset = cursor_position.to_offset(buffer); + let state_id = agent.next_state_id; + agent.next_state_id.0 += 1; + + let (updates_tx, mut updates_rx) = watch::channel(); + postage::stream::Stream::try_recv(&mut updates_rx).unwrap(); + + agent.states.insert( + state_id, + SupermavenCompletionState { + buffer_id, + range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer), + completion: Vec::new(), + text: String::new(), + updates_tx, + }, + ); + let _ = agent + .outgoing_tx + .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage { + new_id: state_id.0.to_string(), + updates: vec![ + StateUpdate::FileUpdate(FileUpdateMessage { + path: path.clone(), + content, + }), + StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }), + ], + })); + + Some(SupermavenCompletion { + id: state_id, + updates: updates_rx, + }) + } else { + None + } + } + + pub fn completion( + &self, + id: SupermavenCompletionStateId, + ) -> Option<&SupermavenCompletionState> { + if let Self::Spawned(agent) = self { + agent.states.get(&id) + } else { + None + } + } +} + +pub struct SupermavenAgent { + _process: Child, + next_state_id: SupermavenCompletionStateId, + states: BTreeMap, + outgoing_tx: mpsc::UnboundedSender, + _handle_outgoing_messages: Task>, + _handle_incoming_messages: Task>, + pub account_status: AccountStatus, + service_tier: Option, + #[allow(dead_code)] + client: Arc, +} + +impl SupermavenAgent { + fn new( + binary_path: PathBuf, + client: Arc, + cx: &mut ModelContext, + ) -> Result { + let mut process = Command::new(&binary_path) + .arg("stdio") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true) + .spawn() + .context("failed to start the binary")?; + + let stdin = process + .stdin + .take() + .context("failed to get stdin for process")?; + let stdout = process + .stdout + .take() + .context("failed to get stdout for process")?; + + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + + cx.spawn({ + let client = client.clone(); + let outgoing_tx = outgoing_tx.clone(); + move |this, mut cx| async move { + let mut status = client.status(); + while let Some(status) = status.next().await { + if status.is_connected() { + let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key; + outgoing_tx + .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key })) + .ok(); + this.update(&mut cx, |this, cx| { + if let Supermaven::Spawned(this) = this { + this.account_status = AccountStatus::Ready; + cx.notify(); + } + })?; + break; + } + } + return anyhow::Ok(()); + } + }) + .detach(); + + Ok(Self { + _process: process, + next_state_id: SupermavenCompletionStateId::default(), + states: BTreeMap::default(), + outgoing_tx, + _handle_outgoing_messages: cx + .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)), + _handle_incoming_messages: cx + .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)), + account_status: AccountStatus::Unknown, + service_tier: None, + client, + }) + } + + async fn handle_outgoing_messages( + mut outgoing: mpsc::UnboundedReceiver, + mut stdin: ChildStdin, + ) -> Result<()> { + while let Some(message) = outgoing.next().await { + let bytes = serde_json::to_vec(&message)?; + stdin.write_all(&bytes).await?; + stdin.write_all(&[b'\n']).await?; + } + Ok(()) + } + + async fn handle_incoming_messages( + this: WeakModel, + stdout: ChildStdout, + mut cx: AsyncAppContext, + ) -> Result<()> { + const MESSAGE_PREFIX: &str = "SM-MESSAGE "; + + let stdout = BufReader::new(stdout); + let mut lines = stdout.lines(); + while let Some(line) = lines.next().await { + let Some(line) = line.context("failed to read line from stdout").log_err() else { + continue; + }; + let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else { + continue; + }; + let Some(message) = serde_json::from_str::(&line) + .with_context(|| format!("failed to deserialize line from stdout: {:?}", line)) + .log_err() + else { + continue; + }; + + this.update(&mut cx, |this, _cx| { + if let Supermaven::Spawned(this) = this { + this.handle_message(message); + } + Task::ready(anyhow::Ok(())) + })? + .await?; + } + + Ok(()) + } + + fn handle_message(&mut self, message: SupermavenMessage) { + match message { + SupermavenMessage::ActivationRequest(request) => { + self.account_status = match request.activate_url { + Some(activate_url) => AccountStatus::NeedsActivation { + activate_url: activate_url.clone(), + }, + None => AccountStatus::Ready, + }; + } + SupermavenMessage::ServiceTier { service_tier } => { + self.service_tier = Some(service_tier); + } + SupermavenMessage::Response(response) => { + let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap()); + if let Some(state) = self.states.get_mut(&state_id) { + for item in &response.items { + if let ResponseItem::Text { text } = item { + state.text.push_str(text); + } + } + state.completion.extend(response.items); + *state.updates_tx.borrow_mut() = (); + } + } + SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough), + _ => { + log::warn!("unhandled message: {:?}", message); + } + } + } +} + +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +pub struct SupermavenCompletionStateId(usize); + +#[allow(dead_code)] +pub struct SupermavenCompletionState { + buffer_id: EntityId, + range: Range, + completion: Vec, + text: String, + updates_tx: watch::Sender<()>, +} + +pub struct SupermavenCompletion { + pub id: SupermavenCompletionStateId, + pub updates: watch::Receiver<()>, +} diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs new file mode 100644 index 0000000000..8dc06bfac0 --- /dev/null +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -0,0 +1,131 @@ +use crate::{Supermaven, SupermavenCompletionStateId}; +use anyhow::Result; +use editor::{Direction, InlineCompletionProvider}; +use futures::StreamExt as _; +use gpui::{AppContext, Model, ModelContext, Task}; +use language::{ + language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset, +}; +use std::time::Duration; + +pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); + +pub struct SupermavenCompletionProvider { + supermaven: Model, + completion_id: Option, + pending_refresh: Task>, +} + +impl SupermavenCompletionProvider { + pub fn new(supermaven: Model) -> Self { + Self { + supermaven, + completion_id: None, + pending_refresh: Task::ready(Ok(())), + } + } +} + +impl InlineCompletionProvider for SupermavenCompletionProvider { + fn is_enabled(&self, buffer: &Model, cursor_position: Anchor, cx: &AppContext) -> bool { + if !self.supermaven.read(cx).is_enabled() { + return false; + } + + let buffer = buffer.read(cx); + let file = buffer.file(); + let language = buffer.language_at(cursor_position); + let settings = all_language_settings(file, cx); + settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) + } + + fn refresh( + &mut self, + buffer_handle: Model, + cursor_position: Anchor, + debounce: bool, + cx: &mut ModelContext, + ) { + let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| { + supermaven.complete(&buffer_handle, cursor_position, cx) + }) else { + return; + }; + + self.pending_refresh = cx.spawn(|this, mut cx| async move { + if debounce { + cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; + } + + while let Some(()) = completion.updates.next().await { + this.update(&mut cx, |this, cx| { + this.completion_id = Some(completion.id); + cx.notify(); + })?; + } + Ok(()) + }); + } + + fn cycle( + &mut self, + _buffer: Model, + _cursor_position: Anchor, + _direction: Direction, + _cx: &mut ModelContext, + ) { + // todo!("cycling") + } + + fn accept(&mut self, _cx: &mut ModelContext) { + self.pending_refresh = Task::ready(Ok(())); + self.completion_id = None; + } + + fn discard(&mut self, _cx: &mut ModelContext) { + self.pending_refresh = Task::ready(Ok(())); + self.completion_id = None; + } + + fn active_completion_text<'a>( + &'a self, + buffer: &Model, + cursor_position: Anchor, + cx: &'a AppContext, + ) -> Option<&'a str> { + let completion_id = self.completion_id?; + let buffer = buffer.read(cx); + let cursor_offset = cursor_position.to_offset(buffer); + let completion = self.supermaven.read(cx).completion(completion_id)?; + + let mut completion_range = completion.range.to_offset(buffer); + + let prefix_len = common_prefix( + buffer.chars_for_range(completion_range.clone()), + completion.text.chars(), + ); + completion_range.start += prefix_len; + let suffix_len = common_prefix( + buffer.reversed_chars_for_range(completion_range.clone()), + completion.text[prefix_len..].chars().rev(), + ); + completion_range.end = completion_range.end.saturating_sub(suffix_len); + + let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len]; + if completion_range.is_empty() + && completion_range.start == cursor_offset + && !completion_text.trim().is_empty() + { + Some(completion_text) + } else { + None + } + } +} + +fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { + a.zip(b) + .take_while(|(a, b)| a == b) + .map(|(a, _)| a.len_utf8()) + .sum() +} diff --git a/crates/supermaven_api/Cargo.toml b/crates/supermaven_api/Cargo.toml new file mode 100644 index 0000000000..69b6965283 --- /dev/null +++ b/crates/supermaven_api/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "supermaven_api" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/supermaven_api.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +util.workspace = true diff --git a/crates/supermaven_api/src/supermaven_api.rs b/crates/supermaven_api/src/supermaven_api.rs new file mode 100644 index 0000000000..9d55bc5413 --- /dev/null +++ b/crates/supermaven_api/src/supermaven_api.rs @@ -0,0 +1,291 @@ +use anyhow::{anyhow, Context, Result}; +use futures::io::BufReader; +use futures::{AsyncReadExt, Future}; +use serde::{Deserialize, Serialize}; +use smol::fs::{self, File}; +use smol::stream::StreamExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use util::http::{AsyncBody, HttpClient, Request as HttpRequest}; +use util::paths::SUPERMAVEN_DIR; + +#[derive(Serialize)] +pub struct GetExternalUserRequest { + pub id: String, +} + +#[derive(Serialize)] +pub struct CreateExternalUserRequest { + pub id: String, + pub email: String, +} + +#[derive(Serialize)] +pub struct DeleteExternalUserRequest { + pub id: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateExternalUserResponse { + pub api_key: String, +} + +#[derive(Deserialize)] +pub struct SupermavenApiError { + pub message: String, +} + +pub struct SupermavenBinary {} + +pub struct SupermavenAdminApi { + admin_api_key: String, + api_url: String, + http_client: Arc, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenDownloadResponse { + pub download_url: String, + pub version: u64, + pub sha256_hash: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenUser { + id: String, + email: String, + api_key: String, +} + +impl SupermavenAdminApi { + pub fn new(admin_api_key: String, http_client: Arc) -> Self { + Self { + admin_api_key, + api_url: "https://supermaven.com/api/".to_string(), + http_client, + } + } + + pub async fn try_get_user( + &self, + request: GetExternalUserRequest, + ) -> Result> { + let uri = format!("{}external-user/{}", &self.api_url, &request.id); + + let request = HttpRequest::get(&uri).header("Authorization", self.admin_api_key.clone()); + + let mut response = self + .http_client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to get Supermaven API Key".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + if error.message == "User not found" { + return Ok(None); + } else { + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + } else if response.status().is_server_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + let body_str = std::str::from_utf8(&body)?; + + Ok(Some( + serde_json::from_str::(body_str) + .with_context(|| "Unable to parse Supermaven user response".to_string())?, + )) + } + + pub async fn try_create_user( + &self, + request: CreateExternalUserRequest, + ) -> Result { + let uri = format!("{}external-user", &self.api_url); + + let request = HttpRequest::post(&uri) + .header("Authorization", self.admin_api_key.clone()) + .body(AsyncBody::from(serde_json::to_vec(&request)?))?; + + let mut response = self + .http_client + .send(request) + .await + .with_context(|| "Unable to create Supermaven API Key".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + if !response.status().is_success() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + serde_json::from_str::(body_str) + .with_context(|| "Unable to parse Supermaven API Key response".to_string()) + } + + pub async fn try_delete_user(&self, request: DeleteExternalUserRequest) -> Result<()> { + let uri = format!("{}external-user/{}", &self.api_url, &request.id); + + let request = HttpRequest::delete(&uri).header("Authorization", self.admin_api_key.clone()); + + let mut response = self + .http_client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to delete Supermaven User".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + if error.message == "User not found" { + return Ok(()); + } else { + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + } else if response.status().is_server_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + Ok(()) + } + + pub async fn try_get_or_create_user( + &self, + request: CreateExternalUserRequest, + ) -> Result { + let get_user_request = GetExternalUserRequest { + id: request.id.clone(), + }; + + match self.try_get_user(get_user_request).await? { + None => self.try_create_user(request).await, + Some(SupermavenUser { api_key, .. }) => Ok(CreateExternalUserResponse { api_key }), + } + } +} + +pub async fn latest_release( + client: Arc, + platform: &str, + arch: &str, +) -> Result { + let uri = format!( + "https://supermaven.com/api/download-path?platform={}&arch={}", + platform, arch + ); + + // Download is not authenticated + let request = HttpRequest::get(&uri); + + let mut response = client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to acquire Supermaven Agent".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() || response.status().is_server_error() { + let body_str = std::str::from_utf8(&body)?; + let error: SupermavenApiError = serde_json::from_str(body_str)?; + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + + serde_json::from_slice::(&body) + .with_context(|| "Unable to parse Supermaven Agent response".to_string()) +} + +pub fn version_path(version: u64) -> PathBuf { + SUPERMAVEN_DIR.join(format!("sm-agent-{}", version)) +} + +pub async fn has_version(version_path: &Path) -> bool { + fs::metadata(version_path) + .await + .map_or(false, |m| m.is_file()) +} + +pub fn get_supermaven_agent_path( + client: Arc, +) -> impl Future> { + async move { + fs::create_dir_all(&*SUPERMAVEN_DIR) + .await + .with_context(|| { + format!( + "Could not create Supermaven Agent Directory at {:?}", + &*SUPERMAVEN_DIR + ) + })?; + + let platform = match std::env::consts::OS { + "macos" => "darwin", + "windows" => "windows", + "linux" => "linux", + _ => return Err(anyhow!("unsupported platform")), + }; + + let arch = match std::env::consts::ARCH { + "x86_64" => "amd64", + "aarch64" => "arm64", + _ => return Err(anyhow!("unsupported architecture")), + }; + + let download_info = latest_release(client.clone(), platform, arch).await?; + + let binary_path = version_path(download_info.version); + + if has_version(&binary_path).await { + return Ok(binary_path); + } + + let request = HttpRequest::get(&download_info.download_url); + + let mut response = client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to download Supermaven Agent".to_string())?; + + let mut file = File::create(&binary_path) + .await + .with_context(|| format!("Unable to create file at {:?}", binary_path))?; + + futures::io::copy(BufReader::new(response.body_mut()), &mut file) + .await + .with_context(|| format!("Unable to write binary to file at {:?}", binary_path))?; + + #[cfg(not(windows))] + { + file.set_permissions(::from_mode( + 0o755, + )) + .await?; + } + + let mut old_binary_paths = fs::read_dir(&*SUPERMAVEN_DIR).await?; + while let Some(old_binary_path) = old_binary_paths.next().await { + let old_binary_path = old_binary_path?; + if old_binary_path.path() != binary_path { + fs::remove_file(old_binary_path.path()).await?; + } + } + + Ok(binary_path) + } +} diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index bc05a8f3d3..9c9e05d6b6 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -155,6 +155,10 @@ pub enum IconName { Space, Split, Spinner, + Supermaven, + SupermavenDisabled, + SupermavenError, + SupermavenInit, Tab, Terminal, Trash, @@ -261,6 +265,10 @@ impl IconName { IconName::Space => "icons/space.svg", IconName::Split => "icons/split.svg", IconName::Spinner => "icons/spinner.svg", + IconName::Supermaven => "icons/supermaven.svg", + IconName::SupermavenDisabled => "icons/supermaven_disabled.svg", + IconName::SupermavenError => "icons/supermaven_error.svg", + IconName::SupermavenInit => "icons/supermaven_init.svg", IconName::Tab => "icons/tab.svg", IconName::Terminal => "icons/terminal.svg", IconName::Trash => "icons/trash.svg", diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 205ea72f0a..feb7c19535 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -52,6 +52,7 @@ lazy_static::lazy_static! { pub static ref EXTENSIONS_DIR: PathBuf = SUPPORT_DIR.join("extensions"); pub static ref LANGUAGES_DIR: PathBuf = SUPPORT_DIR.join("languages"); pub static ref COPILOT_DIR: PathBuf = SUPPORT_DIR.join("copilot"); + pub static ref SUPERMAVEN_DIR: PathBuf = SUPPORT_DIR.join("supermaven"); pub static ref DEFAULT_PRETTIER_DIR: PathBuf = SUPPORT_DIR.join("prettier"); pub static ref DB_DIR: PathBuf = SUPPORT_DIR.join("db"); pub static ref CRASHES_DIR: Option = cfg!(target_os = "macos") diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index c18a09673f..e747072cde 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -17,7 +17,7 @@ test-support = [] [dependencies] anyhow.workspace = true client.workspace = true -copilot_ui.workspace = true +inline_completion_button.workspace = true db.workspace = true extensions_ui.workspace = true fuzzy.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index e6a2a53f2e..3ae07cda68 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -2,7 +2,6 @@ mod base_keymap_picker; mod base_keymap_setting; use client::{telemetry::Telemetry, TelemetrySettings}; -use copilot_ui; use db::kvp::KEY_VALUE_STORE; use gpui::{ svg, AnyElement, AppContext, EventEmitter, FocusHandle, FocusableView, InteractiveElement, @@ -143,7 +142,7 @@ impl Render for WelcomePage { this.telemetry.report_app_event( "welcome page: sign in to copilot".to_string(), ); - copilot_ui::initiate_sign_in(cx); + inline_completion_button::initiate_sign_in(cx); })), ) .child( diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 9a9f40020a..a8130fe5df 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -35,7 +35,6 @@ collab_ui.workspace = true collections.workspace = true command_palette.workspace = true copilot.workspace = true -copilot_ui.workspace = true db.workspace = true diagnostics.workspace = true editor.workspace = true @@ -51,6 +50,7 @@ go_to_line.workspace = true gpui.workspace = true headless.workspace = true image_viewer.workspace = true +inline_completion_button.workspace = true install_cli.workspace = true isahc.workspace = true journal.workspace = true @@ -83,6 +83,7 @@ settings.workspace = true simplelog = "0.9" smol.workspace = true tab_switcher.workspace = true +supermaven.workspace = true task.workspace = true tasks_ui.workspace = true telemetry_events.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 9850a2f603..3b2e96965e 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -9,16 +9,14 @@ mod zed; use anyhow::{anyhow, Context as _, Result}; use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{parse_zed_link, telemetry::Telemetry, Client, DevServerToken, UserStore}; +use client::{parse_zed_link, Client, DevServerToken, UserStore}; use collab_ui::channel_view::ChannelView; -use copilot::Copilot; -use copilot_ui::CopilotCompletionProvider; use db::kvp::KEY_VALUE_STORE; -use editor::{Editor, EditorMode}; +use editor::Editor; use env_logger::Builder; use fs::RealFs; use futures::{future, StreamExt}; -use gpui::{App, AppContext, AsyncAppContext, Context, Task, ViewContext, VisualContext}; +use gpui::{App, AppContext, AsyncAppContext, Context, Task, VisualContext}; use image_viewer; use language::LanguageRegistry; use log::LevelFilter; @@ -55,6 +53,8 @@ use zed::{ OpenListener, OpenRequest, }; +use crate::zed::inline_completion_registry; + #[cfg(feature = "mimalloc")] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -270,17 +270,20 @@ fn init_ui(args: Args) { editor::init(cx); image_viewer::init(cx); diagnostics::init(cx); + + // Initialize each completion provider. Settings are used for toggling between them. copilot::init( copilot_language_server_id, client.http_client(), node_runtime.clone(), cx, ); + supermaven::init(client.clone(), cx); assistant::init(client.clone(), cx); assistant2::init(client.clone(), cx); - init_inline_completion_provider(client.telemetry().clone(), cx); + inline_completion_registry::init(client.telemetry().clone(), cx); extension::init( fs.clone(), @@ -888,45 +891,3 @@ fn watch_file_types(fs: Arc, cx: &mut AppContext) { #[cfg(not(debug_assertions))] fn watch_file_types(_fs: Arc, _cx: &mut AppContext) {} - -fn init_inline_completion_provider(telemetry: Arc, cx: &mut AppContext) { - if let Some(copilot) = Copilot::global(cx) { - cx.observe_new_views(move |editor: &mut Editor, cx: &mut ViewContext| { - if editor.mode() == EditorMode::Full { - // We renamed some of these actions to not be copilot-specific, but that - // would have not been backwards-compatible. So here we are re-registering - // the actions with the old names to not break people's keymaps. - editor - .register_action(cx.listener( - |editor, _: &copilot::Suggest, cx: &mut ViewContext| { - editor.show_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, _: &copilot::NextSuggestion, cx: &mut ViewContext| { - editor.next_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext| { - editor.previous_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, - _: &editor::actions::AcceptPartialCopilotSuggestion, - cx: &mut ViewContext| { - editor.accept_partial_inline_completion(&Default::default(), cx); - }, - )); - - let provider = cx.new_model(|_| { - CopilotCompletionProvider::new(copilot.clone()) - .with_telemetry(telemetry.clone()) - }); - editor.set_inline_completion_provider(provider, cx) - } - }) - .detach(); - } -} diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 6c0f155ce2..14cc9febd2 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1,4 +1,5 @@ mod app_menus; +pub mod inline_completion_registry; mod only_instance; mod open_listener; @@ -127,7 +128,10 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }) .detach(); - let copilot = cx.new_view(|cx| copilot_ui::CopilotButton::new(app_state.fs.clone(), cx)); + let inline_completion_button = cx.new_view(|cx| { + inline_completion_button::InlineCompletionButton::new(app_state.fs.clone(), cx) + }); + let diagnostic_summary = cx.new_view(|cx| diagnostics::items::DiagnosticIndicator::new(workspace, cx)); let activity_indicator = @@ -140,7 +144,7 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { workspace.status_bar().update(cx, |status_bar, cx| { status_bar.add_left_item(diagnostic_summary, cx); status_bar.add_left_item(activity_indicator, cx); - status_bar.add_right_item(copilot, cx); + status_bar.add_right_item(inline_completion_button, cx); status_bar.add_right_item(active_buffer_language, cx); status_bar.add_right_item(vim_mode_indicator, cx); status_bar.add_right_item(cursor_position, cx); diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs new file mode 100644 index 0000000000..7ea50322a3 --- /dev/null +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -0,0 +1,126 @@ +use std::{cell::RefCell, rc::Rc, sync::Arc}; + +use client::telemetry::Telemetry; +use collections::HashMap; +use copilot::{Copilot, CopilotCompletionProvider}; +use editor::{Editor, EditorMode}; +use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView}; +use language::language_settings::all_language_settings; +use settings::SettingsStore; +use supermaven::{Supermaven, SupermavenCompletionProvider}; + +pub fn init(telemetry: Arc, cx: &mut AppContext) { + let editors: Rc, AnyWindowHandle>>> = Rc::default(); + cx.observe_new_views({ + let editors = editors.clone(); + let telemetry = telemetry.clone(); + move |editor: &mut Editor, cx: &mut ViewContext| { + if editor.mode() != EditorMode::Full { + return; + } + + register_backward_compatible_actions(editor, cx); + + let editor_handle = cx.view().downgrade(); + cx.on_release({ + let editor_handle = editor_handle.clone(); + let editors = editors.clone(); + move |_, _, _| { + editors.borrow_mut().remove(&editor_handle); + } + }) + .detach(); + editors + .borrow_mut() + .insert(editor_handle, cx.window_handle()); + let provider = all_language_settings(None, cx).inline_completions.provider; + assign_inline_completion_provider(editor, provider, &telemetry, cx); + } + }) + .detach(); + + let mut provider = all_language_settings(None, cx).inline_completions.provider; + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &telemetry, cx); + }) + }); + } + + cx.observe_global::(move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &telemetry, cx); + }) + }); + } + } + }) + .detach(); +} + +fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut ViewContext) { + // We renamed some of these actions to not be copilot-specific, but that + // would have not been backwards-compatible. So here we are re-registering + // the actions with the old names to not break people's keymaps. + editor + .register_action(cx.listener( + |editor, _: &copilot::Suggest, cx: &mut ViewContext| { + editor.show_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, _: &copilot::NextSuggestion, cx: &mut ViewContext| { + editor.next_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext| { + editor.previous_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, + _: &editor::actions::AcceptPartialCopilotSuggestion, + cx: &mut ViewContext| { + editor.accept_partial_inline_completion(&Default::default(), cx); + }, + )); +} + +fn assign_inline_completion_provider( + editor: &mut Editor, + provider: language::language_settings::InlineCompletionProvider, + telemetry: &Arc, + cx: &mut ViewContext, +) { + match provider { + language::language_settings::InlineCompletionProvider::None => {} + language::language_settings::InlineCompletionProvider::Copilot => { + if let Some(copilot) = Copilot::global(cx) { + if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + if buffer.read(cx).file().is_some() { + copilot.update(cx, |copilot, cx| { + copilot.register_buffer(&buffer, cx); + }); + } + } + let provider = cx.new_model(|_| { + CopilotCompletionProvider::new(copilot).with_telemetry(telemetry.clone()) + }); + editor.set_inline_completion_provider(Some(provider), cx); + } + } + language::language_settings::InlineCompletionProvider::Supermaven => { + if let Some(supermaven) = Supermaven::global(cx) { + let provider = cx.new_model(|_| SupermavenCompletionProvider::new(supermaven)); + editor.set_inline_completion_provider(Some(provider), cx); + } + } + } +}