diff --git a/Cargo.lock b/Cargo.lock index 2b3d7b2691..aa3a910390 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,26 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "acp_tools" +version = "0.1.0" +dependencies = [ + "agent-client-protocol", + "collections", + "gpui", + "language", + "markdown", + "project", + "serde", + "serde_json", + "settings", + "theme", + "ui", + "util", + "workspace", + "workspace-hack", +] + [[package]] name = "action_log" version = "0.1.0" @@ -171,11 +191,12 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.30" +version = "0.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f792e009ba59b137ee1db560bc37e567887ad4b5af6f32181d381fff690e2d4" +checksum = "289eb34ee17213dadcca47eedadd386a5e7678094095414e475965d1bcca2860" dependencies = [ "anyhow", + "async-broadcast", "futures 0.3.31", "log", "parking_lot", @@ -264,10 +285,10 @@ name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", + "acp_tools", "action_log", "agent-client-protocol", "agent_settings", - "agentic-coding-protocol", "anyhow", "client", "collections", @@ -421,24 +442,6 @@ dependencies = [ "zed_actions", ] -[[package]] -name = "agentic-coding-protocol" -version = "0.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e6ae951b36fa2f8d9dd6e1af6da2fcaba13d7c866cf6a9e65deda9dc6c5fe4" -dependencies = [ - "anyhow", - "chrono", - "derive_more 2.0.1", - "futures 0.3.31", - "log", - "parking_lot", - "schemars", - "semver", - "serde", - "serde_json", -] - [[package]] name = "ahash" version = "0.7.8" @@ -854,7 +857,7 @@ dependencies = [ "anyhow", "async-trait", "collections", - "derive_more 0.99.19", + "derive_more", "extension", "futures 0.3.31", "gpui", @@ -917,7 +920,7 @@ dependencies = [ "clock", "collections", "ctor", - "derive_more 0.99.19", + "derive_more", "gpui", "icons", "indoc", @@ -954,7 +957,7 @@ dependencies = [ "cloud_llm_client", "collections", "component", - "derive_more 0.99.19", + "derive_more", "diffy", "editor", "feature_flags", @@ -3067,7 +3070,7 @@ dependencies = [ "cocoa 0.26.0", "collections", "credentials_provider", - "derive_more 0.99.19", + "derive_more", "feature_flags", "fs", "futures 0.3.31", @@ -3499,7 +3502,7 @@ name = "command_palette_hooks" version = "0.1.0" dependencies = [ "collections", - "derive_more 0.99.19", + "derive_more", "gpui", "workspace-hack", ] @@ -4662,27 +4665,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "derive_more" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" -dependencies = [ - "derive_more-impl", -] - -[[package]] -name = "derive_more-impl" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.101", - "unicode-xid", -] - [[package]] name = "derive_refineable" version = "0.1.0" @@ -6419,7 +6401,7 @@ dependencies = [ "askpass", "async-trait", "collections", - "derive_more 0.99.19", + "derive_more", "futures 0.3.31", "git2", "gpui", @@ -7449,7 +7431,7 @@ dependencies = [ "core-video", "cosmic-text", "ctor", - "derive_more 0.99.19", + "derive_more", "embed-resource", "env_logger 0.11.8", "etagere", @@ -7974,7 +7956,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes 1.10.1", - "derive_more 0.99.19", + "derive_more", "futures 0.3.31", "http 1.3.1", "http-body 1.0.1", @@ -14377,12 +14359,10 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984" dependencies = [ - "chrono", "dyn-clone", "indexmap", "ref-cast", "schemars_derive", - "semver", "serde", "serde_json", ] @@ -16466,7 +16446,7 @@ version = "0.1.0" dependencies = [ "anyhow", "collections", - "derive_more 0.99.19", + "derive_more", "fs", "futures 0.3.31", "gpui", @@ -19807,7 +19787,6 @@ dependencies = [ "any_vec", "anyhow", "async-recursion", - "bincode", "call", "client", "clock", @@ -19826,6 +19805,7 @@ dependencies = [ "node_runtime", "parking_lot", "postage", + "pretty_assertions", "project", "remote", "schemars", @@ -19981,7 +19961,6 @@ dependencies = [ "rustix 1.0.7", "rustls 0.23.26", "rustls-webpki 0.103.1", - "schemars", "scopeguard", "sea-orm", "sea-query-binder", @@ -20417,6 +20396,7 @@ dependencies = [ name = "zed" version = "0.202.0" dependencies = [ + "acp_tools", "activity_indicator", "agent", "agent_servers", diff --git a/Cargo.toml b/Cargo.toml index 84de9b30ad..6ec243a9b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "crates/acp_tools", "crates/acp_thread", "crates/action_log", "crates/activity_indicator", @@ -227,6 +228,7 @@ edition = "2024" # Workspace member crates # +acp_tools = { path = "crates/acp_tools" } acp_thread = { path = "crates/acp_thread" } action_log = { path = "crates/action_log" } agent = { path = "crates/agent" } @@ -424,8 +426,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.30" +agent-client-protocol = "0.0.31" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/Procfile.web b/Procfile.web new file mode 100644 index 0000000000..8140555144 --- /dev/null +++ b/Procfile.web @@ -0,0 +1,2 @@ +postgrest_llm: postgrest crates/collab/postgrest_llm.conf +website: cd ../zed.dev; npm run dev -- --port=3000 diff --git a/assets/icons/attach.svg b/assets/icons/attach.svg new file mode 100644 index 0000000000..f923a3c7c8 --- /dev/null +++ b/assets/icons/attach.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/icons/pencil_unavailable.svg b/assets/icons/pencil_unavailable.svg new file mode 100644 index 0000000000..4241d766ac --- /dev/null +++ b/assets/icons/pencil_unavailable.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/assets/icons/tool_think.svg b/assets/icons/tool_think.svg index efd5908a90..773f5e7fa7 100644 --- a/assets/icons/tool_think.svg +++ b/assets/icons/tool_think.svg @@ -1,3 +1,3 @@ - + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 955e68f5a9..e84f4834af 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -16,7 +16,6 @@ "up": "menu::SelectPrevious", "enter": "menu::Confirm", "ctrl-enter": "menu::SecondaryConfirm", - "ctrl-escape": "menu::Cancel", "ctrl-c": "menu::Cancel", "escape": "menu::Cancel", "alt-shift-enter": "menu::Restart", @@ -856,7 +855,7 @@ "ctrl-backspace": ["project_panel::Delete", { "skip_prompt": false }], "ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-ctrl-r": "project_panel::RevealInFileManager", - "ctrl-shift-enter": "project_panel::OpenWithSystem", + "ctrl-shift-enter": "workspace::OpenWithSystem", "alt-d": "project_panel::CompareMarkedFiles", "shift-find": "project_panel::NewSearchInDirectory", "ctrl-alt-shift-f": "project_panel::NewSearchInDirectory", @@ -1195,9 +1194,16 @@ "ctrl-1": "onboarding::ActivateBasicsPage", "ctrl-2": "onboarding::ActivateEditingPage", "ctrl-3": "onboarding::ActivateAISetupPage", - "ctrl-escape": "onboarding::Finish", - "alt-tab": "onboarding::SignIn", + "ctrl-enter": "onboarding::Finish", + "alt-shift-l": "onboarding::SignIn", "alt-shift-a": "onboarding::OpenAccount" } + }, + { + "context": "InvalidBuffer", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-enter": "workspace::OpenWithSystem" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 8b18299a91..e72f4174ff 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -915,7 +915,7 @@ "cmd-backspace": ["project_panel::Trash", { "skip_prompt": true }], "cmd-delete": ["project_panel::Delete", { "skip_prompt": false }], "alt-cmd-r": "project_panel::RevealInFileManager", - "ctrl-shift-enter": "project_panel::OpenWithSystem", + "ctrl-shift-enter": "workspace::OpenWithSystem", "alt-d": "project_panel::CompareMarkedFiles", "cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }], "cmd-alt-shift-f": "project_panel::NewSearchInDirectory", @@ -1301,5 +1301,12 @@ "alt-tab": "onboarding::SignIn", "alt-shift-a": "onboarding::OpenAccount" } + }, + { + "context": "InvalidBuffer", + "use_key_equivalents": true, + "bindings": { + "ctrl-shift-enter": "workspace::OpenWithSystem" + } } ] diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index be6d34a134..62e50b3c8c 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -819,7 +819,7 @@ "v": "project_panel::OpenPermanent", "p": "project_panel::Open", "x": "project_panel::RevealInFileManager", - "s": "project_panel::OpenWithSystem", + "s": "workspace::OpenWithSystem", "z d": "project_panel::CompareMarkedFiles", "] c": "project_panel::SelectNextGitEntry", "[ c": "project_panel::SelectPrevGitEntry", diff --git a/assets/settings/default.json b/assets/settings/default.json index c290baf003..014b483250 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1503,6 +1503,11 @@ // // Default: fallback "words": "fallback", + // Minimum number of characters required to automatically trigger word-based completions. + // Before that value, it's still possible to trigger the words-based completion manually with the corresponding editor command. + // + // Default: 3 + "words_min_length": 3, // Whether to fetch LSP completions or not. // // Default: true @@ -1642,9 +1647,6 @@ "use_on_type_format": false, "allow_rewrap": "anywhere", "soft_wrap": "editor_width", - "completions": { - "words": "disabled" - }, "prettier": { "allowed": true } @@ -1658,9 +1660,6 @@ } }, "Plain Text": { - "completions": { - "words": "disabled" - }, "allow_rewrap": "anywhere" }, "Python": { diff --git a/assets/themes/ayu/ayu.json b/assets/themes/ayu/ayu.json index f9f8720729..0ffbb9f61e 100644 --- a/assets/themes/ayu/ayu.json +++ b/assets/themes/ayu/ayu.json @@ -93,7 +93,7 @@ "terminal.ansi.bright_cyan": "#4c806fff", "terminal.ansi.dim_cyan": "#cbf2e4ff", "terminal.ansi.white": "#bfbdb6ff", - "terminal.ansi.bright_white": "#bfbdb6ff", + "terminal.ansi.bright_white": "#fafafaff", "terminal.ansi.dim_white": "#787876ff", "link_text.hover": "#5ac1feff", "conflict": "#feb454ff", @@ -479,7 +479,7 @@ "terminal.ansi.bright_cyan": "#ace0cbff", "terminal.ansi.dim_cyan": "#2a5f4aff", "terminal.ansi.white": "#fcfcfcff", - "terminal.ansi.bright_white": "#fcfcfcff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#bcbec0ff", "link_text.hover": "#3b9ee5ff", "conflict": "#f1ad49ff", @@ -865,7 +865,7 @@ "terminal.ansi.bright_cyan": "#4c806fff", "terminal.ansi.dim_cyan": "#cbf2e4ff", "terminal.ansi.white": "#cccac2ff", - "terminal.ansi.bright_white": "#cccac2ff", + "terminal.ansi.bright_white": "#fafafaff", "terminal.ansi.dim_white": "#898a8aff", "link_text.hover": "#72cffeff", "conflict": "#fecf72ff", diff --git a/assets/themes/gruvbox/gruvbox.json b/assets/themes/gruvbox/gruvbox.json index 459825c733..f0f0358b76 100644 --- a/assets/themes/gruvbox/gruvbox.json +++ b/assets/themes/gruvbox/gruvbox.json @@ -94,7 +94,7 @@ "terminal.ansi.bright_cyan": "#45603eff", "terminal.ansi.dim_cyan": "#c7dfbdff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#83a598ff", "version_control.added": "#b7bb26ff", @@ -494,7 +494,7 @@ "terminal.ansi.bright_cyan": "#45603eff", "terminal.ansi.dim_cyan": "#c7dfbdff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#83a598ff", "version_control.added": "#b7bb26ff", @@ -894,7 +894,7 @@ "terminal.ansi.bright_cyan": "#45603eff", "terminal.ansi.dim_cyan": "#c7dfbdff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#83a598ff", "version_control.added": "#b7bb26ff", @@ -1294,7 +1294,7 @@ "terminal.ansi.bright_cyan": "#9fbca8ff", "terminal.ansi.dim_cyan": "#253e2eff", "terminal.ansi.white": "#fbf1c7ff", - "terminal.ansi.bright_white": "#fbf1c7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#0b6678ff", "version_control.added": "#797410ff", @@ -1694,7 +1694,7 @@ "terminal.ansi.bright_cyan": "#9fbca8ff", "terminal.ansi.dim_cyan": "#253e2eff", "terminal.ansi.white": "#f9f5d7ff", - "terminal.ansi.bright_white": "#f9f5d7ff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#0b6678ff", "version_control.added": "#797410ff", @@ -2094,7 +2094,7 @@ "terminal.ansi.bright_cyan": "#9fbca8ff", "terminal.ansi.dim_cyan": "#253e2eff", "terminal.ansi.white": "#f2e5bcff", - "terminal.ansi.bright_white": "#f2e5bcff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#b0a189ff", "link_text.hover": "#0b6678ff", "version_control.added": "#797410ff", diff --git a/assets/themes/one/one.json b/assets/themes/one/one.json index 23ebbcc67e..33f6d3c622 100644 --- a/assets/themes/one/one.json +++ b/assets/themes/one/one.json @@ -93,7 +93,7 @@ "terminal.ansi.bright_cyan": "#3a565bff", "terminal.ansi.dim_cyan": "#b9d9dfff", "terminal.ansi.white": "#dce0e5ff", - "terminal.ansi.bright_white": "#dce0e5ff", + "terminal.ansi.bright_white": "#fafafaff", "terminal.ansi.dim_white": "#575d65ff", "link_text.hover": "#74ade8ff", "version_control.added": "#27a657ff", @@ -468,7 +468,7 @@ "terminal.bright_foreground": "#242529ff", "terminal.dim_foreground": "#fafafaff", "terminal.ansi.black": "#242529ff", - "terminal.ansi.bright_black": "#242529ff", + "terminal.ansi.bright_black": "#747579ff", "terminal.ansi.dim_black": "#97979aff", "terminal.ansi.red": "#d36151ff", "terminal.ansi.bright_red": "#f0b0a4ff", @@ -489,7 +489,7 @@ "terminal.ansi.bright_cyan": "#a3bedaff", "terminal.ansi.dim_cyan": "#254058ff", "terminal.ansi.white": "#fafafaff", - "terminal.ansi.bright_white": "#fafafaff", + "terminal.ansi.bright_white": "#ffffffff", "terminal.ansi.dim_white": "#aaaaaaff", "link_text.hover": "#5c78e2ff", "version_control.added": "#27a657ff", diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs index cde5806aaf..601b630a6e 100644 --- a/crates/acp_thread/src/mention.rs +++ b/crates/acp_thread/src/mention.rs @@ -5,7 +5,7 @@ use prompt_store::{PromptId, UserPromptId}; use serde::{Deserialize, Serialize}; use std::{ fmt, - ops::Range, + ops::RangeInclusive, path::{Path, PathBuf}, str::FromStr, }; @@ -17,13 +17,14 @@ pub enum MentionUri { File { abs_path: PathBuf, }, + PastedImage, Directory { abs_path: PathBuf, }, Symbol { - path: PathBuf, + abs_path: PathBuf, name: String, - line_range: Range, + line_range: RangeInclusive, }, Thread { id: acp::SessionId, @@ -38,8 +39,9 @@ pub enum MentionUri { name: String, }, Selection { - path: PathBuf, - line_range: Range, + #[serde(default, skip_serializing_if = "Option::is_none")] + abs_path: Option, + line_range: RangeInclusive, }, Fetch { url: Url, @@ -48,36 +50,44 @@ pub enum MentionUri { impl MentionUri { pub fn parse(input: &str) -> Result { + fn parse_line_range(fragment: &str) -> Result> { + let range = fragment + .strip_prefix("L") + .context("Line range must start with \"L\"")?; + let (start, end) = range + .split_once(":") + .context("Line range must use colon as separator")?; + let range = start + .parse::() + .context("Parsing line range start")? + .checked_sub(1) + .context("Line numbers should be 1-based")? + ..=end + .parse::() + .context("Parsing line range end")? + .checked_sub(1) + .context("Line numbers should be 1-based")?; + Ok(range) + } + let url = url::Url::parse(input)?; let path = url.path(); match url.scheme() { "file" => { let path = url.to_file_path().ok().context("Extracting file path")?; if let Some(fragment) = url.fragment() { - let range = fragment - .strip_prefix("L") - .context("Line range must start with \"L\"")?; - let (start, end) = range - .split_once(":") - .context("Line range must use colon as separator")?; - let line_range = start - .parse::() - .context("Parsing line range start")? - .checked_sub(1) - .context("Line numbers should be 1-based")? - ..end - .parse::() - .context("Parsing line range end")? - .checked_sub(1) - .context("Line numbers should be 1-based")?; + let line_range = parse_line_range(fragment)?; if let Some(name) = single_query_param(&url, "symbol")? { Ok(Self::Symbol { name, - path, + abs_path: path, line_range, }) } else { - Ok(Self::Selection { path, line_range }) + Ok(Self::Selection { + abs_path: Some(path), + line_range, + }) } } else if input.ends_with("/") { Ok(Self::Directory { abs_path: path }) @@ -105,6 +115,17 @@ impl MentionUri { id: rule_id.into(), name, }) + } else if path.starts_with("/agent/pasted-image") { + Ok(Self::PastedImage) + } else if path.starts_with("/agent/untitled-buffer") { + let fragment = url + .fragment() + .context("Missing fragment for untitled buffer selection")?; + let line_range = parse_line_range(fragment)?; + Ok(Self::Selection { + abs_path: None, + line_range, + }) } else { bail!("invalid zed url: {:?}", input); } @@ -121,13 +142,16 @@ impl MentionUri { .unwrap_or_default() .to_string_lossy() .into_owned(), + MentionUri::PastedImage => "Image".to_string(), MentionUri::Symbol { name, .. } => name.clone(), MentionUri::Thread { name, .. } => name.clone(), MentionUri::TextThread { name, .. } => name.clone(), MentionUri::Rule { name, .. } => name.clone(), MentionUri::Selection { - path, line_range, .. - } => selection_name(path, line_range), + abs_path: path, + line_range, + .. + } => selection_name(path.as_deref(), line_range), MentionUri::Fetch { url } => url.to_string(), } } @@ -137,6 +161,7 @@ impl MentionUri { MentionUri::File { abs_path } => { FileIcons::get_icon(abs_path, cx).unwrap_or_else(|| IconName::File.path().into()) } + MentionUri::PastedImage => IconName::Image.path().into(), MentionUri::Directory { abs_path } => { let name = abs_path .file_name() @@ -164,29 +189,40 @@ impl MentionUri { MentionUri::File { abs_path } => { Url::from_file_path(abs_path).expect("mention path should be absolute") } + MentionUri::PastedImage => Url::parse("zed:///agent/pasted-image").unwrap(), MentionUri::Directory { abs_path } => { Url::from_directory_path(abs_path).expect("mention path should be absolute") } MentionUri::Symbol { - path, + abs_path, name, line_range, } => { - let mut url = Url::from_file_path(path).expect("mention path should be absolute"); + let mut url = + Url::from_file_path(abs_path).expect("mention path should be absolute"); url.query_pairs_mut().append_pair("symbol", name); url.set_fragment(Some(&format!( "L{}:{}", - line_range.start + 1, - line_range.end + 1 + line_range.start() + 1, + line_range.end() + 1 ))); url } - MentionUri::Selection { path, line_range } => { - let mut url = Url::from_file_path(path).expect("mention path should be absolute"); + MentionUri::Selection { + abs_path: path, + line_range, + } => { + let mut url = if let Some(path) = path { + Url::from_file_path(path).expect("mention path should be absolute") + } else { + let mut url = Url::parse("zed:///").unwrap(); + url.set_path("/agent/untitled-buffer"); + url + }; url.set_fragment(Some(&format!( "L{}:{}", - line_range.start + 1, - line_range.end + 1 + line_range.start() + 1, + line_range.end() + 1 ))); url } @@ -198,7 +234,10 @@ impl MentionUri { } MentionUri::TextThread { path, name } => { let mut url = Url::parse("zed:///").unwrap(); - url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy())); + url.set_path(&format!( + "/agent/text-thread/{}", + path.to_string_lossy().trim_start_matches('/') + )); url.query_pairs_mut().append_pair("name", name); url } @@ -244,12 +283,14 @@ fn single_query_param(url: &Url, name: &'static str) -> Result> { } } -pub fn selection_name(path: &Path, line_range: &Range) -> String { +pub fn selection_name(path: Option<&Path>, line_range: &RangeInclusive) -> String { format!( "{} ({}:{})", - path.file_name().unwrap_or_default().display(), - line_range.start + 1, - line_range.end + 1 + path.and_then(|path| path.file_name()) + .unwrap_or("Untitled".as_ref()) + .display(), + *line_range.start() + 1, + *line_range.end() + 1 ) } @@ -309,14 +350,14 @@ mod tests { let parsed = MentionUri::parse(symbol_uri).unwrap(); match &parsed { MentionUri::Symbol { - path, + abs_path: path, name, line_range, } => { assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs")); assert_eq!(name, "MySymbol"); - assert_eq!(line_range.start, 9); - assert_eq!(line_range.end, 19); + assert_eq!(line_range.start(), &9); + assert_eq!(line_range.end(), &19); } _ => panic!("Expected Symbol variant"), } @@ -328,16 +369,39 @@ mod tests { let selection_uri = uri!("file:///path/to/file.rs#L5:15"); let parsed = MentionUri::parse(selection_uri).unwrap(); match &parsed { - MentionUri::Selection { path, line_range } => { - assert_eq!(path.to_str().unwrap(), path!("/path/to/file.rs")); - assert_eq!(line_range.start, 4); - assert_eq!(line_range.end, 14); + MentionUri::Selection { + abs_path: path, + line_range, + } => { + assert_eq!( + path.as_ref().unwrap().to_str().unwrap(), + path!("/path/to/file.rs") + ); + assert_eq!(line_range.start(), &4); + assert_eq!(line_range.end(), &14); } _ => panic!("Expected Selection variant"), } assert_eq!(parsed.to_uri().to_string(), selection_uri); } + #[test] + fn test_parse_untitled_selection_uri() { + let selection_uri = uri!("zed:///agent/untitled-buffer#L1:10"); + let parsed = MentionUri::parse(selection_uri).unwrap(); + match &parsed { + MentionUri::Selection { + abs_path: None, + line_range, + } => { + assert_eq!(line_range.start(), &0); + assert_eq!(line_range.end(), &9); + } + _ => panic!("Expected Selection variant without path"), + } + assert_eq!(parsed.to_uri().to_string(), selection_uri); + } + #[test] fn test_parse_thread_uri() { let thread_uri = "zed:///agent/thread/session123?name=Thread+name"; diff --git a/crates/acp_tools/Cargo.toml b/crates/acp_tools/Cargo.toml new file mode 100644 index 0000000000..7a6d8c21a0 --- /dev/null +++ b/crates/acp_tools/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "acp_tools" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + + +[lints] +workspace = true + +[lib] +path = "src/acp_tools.rs" +doctest = false + +[dependencies] +agent-client-protocol.workspace = true +collections.workspace = true +gpui.workspace = true +language.workspace= true +markdown.workspace = true +project.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +theme.workspace = true +ui.workspace = true +util.workspace = true +workspace-hack.workspace = true +workspace.workspace = true diff --git a/crates/acp_tools/LICENSE-GPL b/crates/acp_tools/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/acp_tools/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/acp_tools/src/acp_tools.rs b/crates/acp_tools/src/acp_tools.rs new file mode 100644 index 0000000000..ca5e57e85a --- /dev/null +++ b/crates/acp_tools/src/acp_tools.rs @@ -0,0 +1,494 @@ +use std::{ + cell::RefCell, + collections::HashSet, + fmt::Display, + rc::{Rc, Weak}, + sync::Arc, +}; + +use agent_client_protocol as acp; +use collections::HashMap; +use gpui::{ + App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState, + StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*, +}; +use language::LanguageRegistry; +use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle}; +use project::Project; +use settings::Settings; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt as _; +use workspace::{Item, Workspace}; + +actions!(acp, [OpenDebugTools]); + +pub fn init(cx: &mut App) { + cx.observe_new( + |workspace: &mut Workspace, _window, _cx: &mut Context| { + workspace.register_action(|workspace, _: &OpenDebugTools, window, cx| { + let acp_tools = + Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx))); + workspace.add_item_to_active_pane(acp_tools, None, true, window, cx); + }); + }, + ) + .detach(); +} + +struct GlobalAcpConnectionRegistry(Entity); + +impl Global for GlobalAcpConnectionRegistry {} + +#[derive(Default)] +pub struct AcpConnectionRegistry { + active_connection: RefCell>, +} + +struct ActiveConnection { + server_name: &'static str, + connection: Weak, +} + +impl AcpConnectionRegistry { + pub fn default_global(cx: &mut App) -> Entity { + if cx.has_global::() { + cx.global::().0.clone() + } else { + let registry = cx.new(|_cx| AcpConnectionRegistry::default()); + cx.set_global(GlobalAcpConnectionRegistry(registry.clone())); + registry + } + } + + pub fn set_active_connection( + &self, + server_name: &'static str, + connection: &Rc, + cx: &mut Context, + ) { + self.active_connection.replace(Some(ActiveConnection { + server_name, + connection: Rc::downgrade(connection), + })); + cx.notify(); + } +} + +struct AcpTools { + project: Entity, + focus_handle: FocusHandle, + expanded: HashSet, + watched_connection: Option, + connection_registry: Entity, + _subscription: Subscription, +} + +struct WatchedConnection { + server_name: &'static str, + messages: Vec, + list_state: ListState, + connection: Weak, + incoming_request_methods: HashMap>, + outgoing_request_methods: HashMap>, + _task: Task<()>, +} + +impl AcpTools { + fn new(project: Entity, cx: &mut Context) -> Self { + let connection_registry = AcpConnectionRegistry::default_global(cx); + + let subscription = cx.observe(&connection_registry, |this, _, cx| { + this.update_connection(cx); + cx.notify(); + }); + + let mut this = Self { + project, + focus_handle: cx.focus_handle(), + expanded: HashSet::default(), + watched_connection: None, + connection_registry, + _subscription: subscription, + }; + this.update_connection(cx); + this + } + + fn update_connection(&mut self, cx: &mut Context) { + let active_connection = self.connection_registry.read(cx).active_connection.borrow(); + let Some(active_connection) = active_connection.as_ref() else { + return; + }; + + if let Some(watched_connection) = self.watched_connection.as_ref() { + if Weak::ptr_eq( + &watched_connection.connection, + &active_connection.connection, + ) { + return; + } + } + + if let Some(connection) = active_connection.connection.upgrade() { + let mut receiver = connection.subscribe(); + let task = cx.spawn(async move |this, cx| { + while let Ok(message) = receiver.recv().await { + this.update(cx, |this, cx| { + this.push_stream_message(message, cx); + }) + .ok(); + } + }); + + self.watched_connection = Some(WatchedConnection { + server_name: active_connection.server_name, + messages: vec![], + list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)), + connection: active_connection.connection.clone(), + incoming_request_methods: HashMap::default(), + outgoing_request_methods: HashMap::default(), + _task: task, + }); + } + } + + fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context) { + let Some(connection) = self.watched_connection.as_mut() else { + return; + }; + let language_registry = self.project.read(cx).languages().clone(); + let index = connection.messages.len(); + + let (request_id, method, message_type, params) = match stream_message.message { + acp::StreamMessageContent::Request { id, method, params } => { + let method_map = match stream_message.direction { + acp::StreamMessageDirection::Incoming => { + &mut connection.incoming_request_methods + } + acp::StreamMessageDirection::Outgoing => { + &mut connection.outgoing_request_methods + } + }; + + method_map.insert(id, method.clone()); + (Some(id), method.into(), MessageType::Request, Ok(params)) + } + acp::StreamMessageContent::Response { id, result } => { + let method_map = match stream_message.direction { + acp::StreamMessageDirection::Incoming => { + &mut connection.outgoing_request_methods + } + acp::StreamMessageDirection::Outgoing => { + &mut connection.incoming_request_methods + } + }; + + if let Some(method) = method_map.remove(&id) { + (Some(id), method.into(), MessageType::Response, result) + } else { + ( + Some(id), + "[unrecognized response]".into(), + MessageType::Response, + result, + ) + } + } + acp::StreamMessageContent::Notification { method, params } => { + (None, method.into(), MessageType::Notification, Ok(params)) + } + }; + + let message = WatchedConnectionMessage { + name: method, + message_type, + request_id, + direction: stream_message.direction, + collapsed_params_md: match params.as_ref() { + Ok(params) => params + .as_ref() + .map(|params| collapsed_params_md(params, &language_registry, cx)), + Err(err) => { + if let Ok(err) = &serde_json::to_value(err) { + Some(collapsed_params_md(&err, &language_registry, cx)) + } else { + None + } + } + }, + + expanded_params_md: None, + params, + }; + + connection.messages.push(message); + connection.list_state.splice(index..index, 1); + cx.notify(); + } + + fn render_message( + &mut self, + index: usize, + window: &mut Window, + cx: &mut Context, + ) -> AnyElement { + let Some(connection) = self.watched_connection.as_ref() else { + return Empty.into_any(); + }; + + let Some(message) = connection.messages.get(index) else { + return Empty.into_any(); + }; + + let base_size = TextSize::Editor.rems(cx); + + let theme_settings = ThemeSettings::get_global(cx); + let text_style = window.text_style(); + + let colors = cx.theme().colors(); + let expanded = self.expanded.contains(&index); + + v_flex() + .w_full() + .px_4() + .py_3() + .border_color(colors.border) + .border_b_1() + .gap_2() + .items_start() + .font_buffer(cx) + .text_size(base_size) + .id(index) + .group("message") + .hover(|this| this.bg(colors.element_background.opacity(0.5))) + .on_click(cx.listener(move |this, _, _, cx| { + if this.expanded.contains(&index) { + this.expanded.remove(&index); + } else { + this.expanded.insert(index); + let Some(connection) = &mut this.watched_connection else { + return; + }; + let Some(message) = connection.messages.get_mut(index) else { + return; + }; + message.expanded(this.project.read(cx).languages().clone(), cx); + connection.list_state.scroll_to_reveal_item(index); + } + cx.notify() + })) + .child( + h_flex() + .w_full() + .gap_2() + .items_center() + .flex_shrink_0() + .child(match message.direction { + acp::StreamMessageDirection::Incoming => { + ui::Icon::new(ui::IconName::ArrowDown).color(Color::Error) + } + acp::StreamMessageDirection::Outgoing => { + ui::Icon::new(ui::IconName::ArrowUp).color(Color::Success) + } + }) + .child( + Label::new(message.name.clone()) + .buffer_font(cx) + .color(Color::Muted), + ) + .child(div().flex_1()) + .child( + div() + .child(ui::Chip::new(message.message_type.to_string())) + .visible_on_hover("message"), + ) + .children( + message + .request_id + .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))), + ), + ) + // I'm aware using markdown is a hack. Trying to get something working for the demo. + // Will clean up soon! + .when_some( + if expanded { + message.expanded_params_md.clone() + } else { + message.collapsed_params_md.clone() + }, + |this, params| { + this.child( + div().pl_6().w_full().child( + MarkdownElement::new( + params, + MarkdownStyle { + base_text_style: text_style, + selection_background_color: colors.element_selection_background, + syntax: cx.theme().syntax().clone(), + code_block_overflow_x_scroll: true, + code_block: StyleRefinement { + text: Some(TextStyleRefinement { + font_family: Some( + theme_settings.buffer_font.family.clone(), + ), + font_size: Some((base_size * 0.8).into()), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + }, + ) + .code_block_renderer( + CodeBlockRenderer::Default { + copy_button: false, + copy_button_on_hover: expanded, + border: false, + }, + ), + ), + ) + }, + ) + .into_any() + } +} + +struct WatchedConnectionMessage { + name: SharedString, + request_id: Option, + direction: acp::StreamMessageDirection, + message_type: MessageType, + params: Result, acp::Error>, + collapsed_params_md: Option>, + expanded_params_md: Option>, +} + +impl WatchedConnectionMessage { + fn expanded(&mut self, language_registry: Arc, cx: &mut App) { + let params_md = match &self.params { + Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)), + Err(err) => { + if let Some(err) = &serde_json::to_value(err).log_err() { + Some(expanded_params_md(&err, &language_registry, cx)) + } else { + None + } + } + _ => None, + }; + self.expanded_params_md = params_md; + } +} + +fn collapsed_params_md( + params: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Entity { + let params_json = serde_json::to_string(params).unwrap_or_default(); + let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4); + + for ch in params_json.chars() { + match ch { + '{' => spaced_out_json.push_str("{ "), + '}' => spaced_out_json.push_str(" }"), + ':' => spaced_out_json.push_str(": "), + ',' => spaced_out_json.push_str(", "), + c => spaced_out_json.push(c), + } + } + + let params_md = format!("```json\n{}\n```", spaced_out_json); + cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx)) +} + +fn expanded_params_md( + params: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Entity { + let params_json = serde_json::to_string_pretty(params).unwrap_or_default(); + let params_md = format!("```json\n{}\n```", params_json); + cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx)) +} + +enum MessageType { + Request, + Response, + Notification, +} + +impl Display for MessageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageType::Request => write!(f, "Request"), + MessageType::Response => write!(f, "Response"), + MessageType::Notification => write!(f, "Notification"), + } + } +} + +enum AcpToolsEvent {} + +impl EventEmitter for AcpTools {} + +impl Item for AcpTools { + type Event = AcpToolsEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString { + format!( + "ACP: {}", + self.watched_connection + .as_ref() + .map_or("Disconnected", |connection| connection.server_name) + ) + .into() + } + + fn tab_icon(&self, _window: &Window, _cx: &App) -> Option { + Some(ui::Icon::new(IconName::Thread)) + } +} + +impl Focusable for AcpTools { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for AcpTools { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .track_focus(&self.focus_handle) + .size_full() + .bg(cx.theme().colors().editor_background) + .child(match self.watched_connection.as_ref() { + Some(connection) => { + if connection.messages.is_empty() { + h_flex() + .size_full() + .justify_center() + .items_center() + .child("No messages recorded yet") + .into_any() + } else { + list( + connection.list_state.clone(), + cx.processor(Self::render_message), + ) + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .flex_grow() + .into_any() + } + } + None => h_flex() + .size_full() + .justify_center() + .items_center() + .child("No active connection") + .into_any(), + }) + } +} diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 45e551dbdf..cba2457566 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -893,8 +893,19 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { + let connection = if *ZED_STATELESS { Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else if cfg!(any(feature = "test-support", test)) { + // rust stores the name of the test on the current thread. + // We use this to automatically create a database that will + // be shared within the test (for the test_retrieve_old_thread) + // but not with concurrent tests. + let thread = std::thread::current(); + let test_name = thread.name(); + Connection::open_memory(Some(&format!( + "THREAD_FALLBACK_{}", + test_name.unwrap_or_default() + ))) } else { Connection::open_file(&sqlite_path.to_string_lossy()) }; diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index 1b88955a24..e7d31c0c7a 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -266,8 +266,19 @@ impl ThreadsDatabase { } pub fn new(executor: BackgroundExecutor) -> Result { - let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { + let connection = if *ZED_STATELESS { Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else if cfg!(any(feature = "test-support", test)) { + // rust stores the name of the test on the current thread. + // We use this to automatically create a database that will + // be shared within the test (for the test_retrieve_old_thread) + // but not with concurrent tests. + let thread = std::thread::current(); + let test_name = thread.name(); + Connection::open_memory(Some(&format!( + "THREAD_FALLBACK_{}", + test_name.unwrap_or_default() + ))) } else { let threads_dir = paths::data_dir().join("threads"); std::fs::create_dir_all(&threads_dir)?; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index ac5aa95c04..4ce467d6fd 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -23,11 +23,11 @@ impl NativeAgentServer { impl AgentServer for NativeAgentServer { fn name(&self) -> &'static str { - "Native Agent" + "Zed Agent" } fn empty_state_headline(&self) -> &'static str { - "Welcome to the Agent Panel" + self.name() } fn empty_state_message(&self) -> &'static str { diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 09048488c8..60b3198081 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; +use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use fs::{FakeFs, Fs}; -use futures::{StreamExt, channel::mpsc::UnboundedReceiver}; +use futures::{ + StreamExt, + channel::{ + mpsc::{self, UnboundedReceiver}, + oneshot, + }, +}; use gpui::{ App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, }; use indoc::indoc; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage, - LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason, - fake_provider::FakeLanguageModel, + LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat, + LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel, }; use pretty_assertions::assert_eq; -use project::Project; +use project::{ + Project, context_server_store::ContextServerStore, project_settings::ProjectSettings, +}; use prompt_store::ProjectContext; use reqwest_client::ReqwestClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; -use settings::SettingsStore; +use settings::{Settings, SettingsStore}; use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; @@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) { assert_eq!(tool_names, vec![InfiniteTool::name()]); } +#[gpui::test] +async fn test_mcp_tools(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + context_server_store, + fs, + .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + // Override profiles and wait for settings to be loaded. + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "profiles": { + "test": { + "name": "Test Profile", + "enable_all_context_servers": true, + "tools": { + EchoTool::name(): true, + } + }, + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + thread.update(cx, |thread, _| { + thread.set_profile(AgentProfileId("test".into())) + }); + + let mut mcp_tool_calls = setup_context_server( + "test_server", + vec![context_server::types::Tool { + name: "echo".into(), + description: None, + input_schema: serde_json::to_value( + EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), + ) + .unwrap(), + output_schema: None, + annotations: None, + }], + &context_server_store, + cx, + ); + + let events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hey"], cx).unwrap() + }); + cx.run_until_parked(); + + // Simulate the model calling the MCP tool. + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!(tool_names_for_completion(&completion), vec!["echo"]); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_1".into(), + name: "echo".into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap(); + assert_eq!(tool_call_params.name, "echo"); + assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"}))); + tool_call_response + .send(context_server::types::CallToolResponse { + content: vec![context_server::types::ToolResponseContent::Text { + text: "test".into(), + }], + is_error: None, + meta: None, + structured_content: None, + }) + .unwrap(); + cx.run_until_parked(); + + assert_eq!(tool_names_for_completion(&completion), vec!["echo"]); + fake_model.send_last_completion_stream_text_chunk("Done!"); + fake_model.end_last_completion_stream(); + events.collect::>().await; + + // Send again after adding the echo tool, ensuring the name collision is resolved. + let events = thread.update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["Go"], cx).unwrap() + }); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + tool_names_for_completion(&completion), + vec!["echo", "test_server_echo"] + ); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_2".into(), + name: "test_server_echo".into(), + raw_input: json!({"text": "mcp"}).to_string(), + input: json!({"text": "mcp"}), + is_input_complete: true, + }, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_3".into(), + name: "echo".into(), + raw_input: json!({"text": "native"}).to_string(), + input: json!({"text": "native"}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap(); + assert_eq!(tool_call_params.name, "echo"); + assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"}))); + tool_call_response + .send(context_server::types::CallToolResponse { + content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }], + is_error: None, + meta: None, + structured_content: None, + }) + .unwrap(); + cx.run_until_parked(); + + // Ensure the tool results were inserted with the correct names. + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages.last().unwrap().content, + vec![ + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: "tool_3".into(), + tool_name: "echo".into(), + is_error: false, + content: "native".into(), + output: Some("native".into()), + },), + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: "tool_2".into(), + tool_name: "test_server_echo".into(), + is_error: false, + content: "mcp".into(), + output: Some("mcp".into()), + },), + ] + ); + fake_model.end_last_completion_stream(); + events.collect::>().await; +} + +#[gpui::test] +async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + context_server_store, + fs, + .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + // Set up a profile with all tools enabled + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "profiles": { + "test": { + "name": "Test Profile", + "enable_all_context_servers": true, + "tools": { + EchoTool::name(): true, + DelayTool::name(): true, + WordListTool::name(): true, + ToolRequiringPermission::name(): true, + InfiniteTool::name(): true, + } + }, + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + + thread.update(cx, |thread, _| { + thread.set_profile(AgentProfileId("test".into())); + thread.add_tool(EchoTool); + thread.add_tool(DelayTool); + thread.add_tool(WordListTool); + thread.add_tool(ToolRequiringPermission); + thread.add_tool(InfiniteTool); + }); + + // Set up multiple context servers with some overlapping tool names + let _server1_calls = setup_context_server( + "xxx", + vec![ + context_server::types::Tool { + name: "echo".into(), // Conflicts with native EchoTool + description: None, + input_schema: serde_json::to_value( + EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), + ) + .unwrap(), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "unique_tool_1".into(), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + ], + &context_server_store, + cx, + ); + + let _server2_calls = setup_context_server( + "yyy", + vec![ + context_server::types::Tool { + name: "echo".into(), // Also conflicts with native EchoTool + description: None, + input_schema: serde_json::to_value( + EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), + ) + .unwrap(), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "unique_tool_2".into(), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + ], + &context_server_store, + cx, + ); + let _server3_calls = setup_context_server( + "zzz", + vec![ + context_server::types::Tool { + name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + context_server::types::Tool { + name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }, + ], + &context_server_store, + cx, + ); + + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Go"], cx) + }) + .unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + tool_names_for_completion(&completion), + vec![ + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + "delay", + "echo", + "infinite", + "tool_requiring_permission", + "unique_tool_1", + "unique_tool_2", + "word_list", + "xxx_echo", + "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "yyy_echo", + "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + ] + ); +} + #[gpui::test] #[cfg_attr(not(feature = "e2e"), ignore)] async fn test_cancellation(cx: &mut TestAppContext) { @@ -1806,6 +2143,7 @@ struct ThreadTest { model: Arc, thread: Entity, project_context: Entity, + context_server_store: Entity, fs: Arc, } @@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { WordListTool::name(): true, ToolRequiringPermission::name(): true, InfiniteTool::name(): true, + ThinkingTool::name(): true, } } } @@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { .await; let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); let thread = cx.new(|cx| { Thread::new( project, @@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { model, thread, project_context, + context_server_store, fs, } } @@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc, cx: &mut App) { }) .detach(); } + +fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec { + completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect() +} + +fn setup_context_server( + name: &'static str, + tools: Vec, + context_server_store: &Entity, + cx: &mut TestAppContext, +) -> mpsc::UnboundedReceiver<( + context_server::types::CallToolParams, + oneshot::Sender, +)> { + cx.update(|cx| { + let mut settings = ProjectSettings::get_global(cx).clone(); + settings.context_servers.insert( + name.into(), + project::project_settings::ContextServerSettings::Custom { + enabled: true, + command: ContextServerCommand { + path: "somebinary".into(), + args: Vec::new(), + env: None, + }, + }, + ); + ProjectSettings::override_global(settings, cx); + }); + + let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded(); + let fake_transport = context_server::test::create_fake_transport(name, cx.executor()) + .on_request::(move |_params| async move { + context_server::types::InitializeResponse { + protocol_version: context_server::types::ProtocolVersion( + context_server::types::LATEST_PROTOCOL_VERSION.to_string(), + ), + server_info: context_server::types::Implementation { + name: name.into(), + version: "1.0.0".to_string(), + }, + capabilities: context_server::types::ServerCapabilities { + tools: Some(context_server::types::ToolsCapabilities { + list_changed: Some(true), + }), + ..Default::default() + }, + meta: None, + } + }) + .on_request::(move |_params| { + let tools = tools.clone(); + async move { + context_server::types::ListToolsResponse { + tools, + next_cursor: None, + meta: None, + } + } + }) + .on_request::(move |params| { + let mcp_tool_calls_tx = mcp_tool_calls_tx.clone(); + async move { + let (response_tx, response_rx) = oneshot::channel(); + mcp_tool_calls_tx + .unbounded_send((params, response_tx)) + .unwrap(); + response_rx.await.unwrap() + } + }); + context_server_store.update(cx, |store, cx| { + store.start_server( + Arc::new(ContextServer::new( + ContextServerId(name.into()), + Arc::new(fake_transport), + )), + cx, + ); + }); + cx.run_until_parked(); + mcp_tool_calls_rx +} diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index af18afa055..6d616f73fc 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -9,15 +9,15 @@ use action_log::ActionLog; use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; use agent_settings::{ - AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT, - SUMMARIZE_THREAD_PROMPT, + AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode, + SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT, }; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; -use collections::{HashMap, IndexMap}; +use collections::{HashMap, HashSet, IndexMap}; use fs::Fs; use futures::{ FutureExt, @@ -45,17 +45,19 @@ use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; +use std::fmt::Write; use std::{ collections::BTreeMap, + ops::RangeInclusive, path::Path, sync::Arc, time::{Duration, Instant}, }; -use std::{fmt::Write, ops::Range}; -use util::{ResultExt, markdown::MarkdownCodeBlock}; +use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock}; use uuid::Uuid; const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; +pub const MAX_TOOL_NAME_LENGTH: usize = 64; /// The ID of the user prompt that initiated a request. /// @@ -186,6 +188,7 @@ impl UserMessage { const OPEN_FILES_TAG: &str = ""; const OPEN_DIRECTORIES_TAG: &str = ""; const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_SELECTIONS_TAG: &str = ""; const OPEN_THREADS_TAG: &str = ""; const OPEN_FETCH_TAG: &str = ""; const OPEN_RULES_TAG: &str = @@ -194,6 +197,7 @@ impl UserMessage { let mut file_context = OPEN_FILES_TAG.to_string(); let mut directory_context = OPEN_DIRECTORIES_TAG.to_string(); let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut selection_context = OPEN_SELECTIONS_TAG.to_string(); let mut thread_context = OPEN_THREADS_TAG.to_string(); let mut fetch_context = OPEN_FETCH_TAG.to_string(); let mut rules_context = OPEN_RULES_TAG.to_string(); @@ -210,7 +214,7 @@ impl UserMessage { match uri { MentionUri::File { abs_path } => { write!( - &mut symbol_context, + &mut file_context, "\n{}", MarkdownCodeBlock { tag: &codeblock_tag(abs_path, None), @@ -219,17 +223,19 @@ impl UserMessage { ) .ok(); } + MentionUri::PastedImage => { + debug_panic!("pasted image URI should not be used in mention content") + } MentionUri::Directory { .. } => { write!(&mut directory_context, "\n{}\n", content).ok(); } MentionUri::Symbol { - path, line_range, .. - } - | MentionUri::Selection { - path, line_range, .. + abs_path: path, + line_range, + .. } => { write!( - &mut rules_context, + &mut symbol_context, "\n{}", MarkdownCodeBlock { tag: &codeblock_tag(path, Some(line_range)), @@ -238,6 +244,24 @@ impl UserMessage { ) .ok(); } + MentionUri::Selection { + abs_path: path, + line_range, + .. + } => { + write!( + &mut selection_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag( + path.as_deref().unwrap_or("Untitled".as_ref()), + Some(line_range) + ), + text: content + } + ) + .ok(); + } MentionUri::Thread { .. } => { write!(&mut thread_context, "\n{}\n", content).ok(); } @@ -290,6 +314,13 @@ impl UserMessage { .push(language_model::MessageContent::Text(symbol_context)); } + if selection_context.len() > OPEN_SELECTIONS_TAG.len() { + selection_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(selection_context)); + } + if thread_context.len() > OPEN_THREADS_TAG.len() { thread_context.push_str("\n"); message @@ -325,7 +356,7 @@ impl UserMessage { } } -fn codeblock_tag(full_path: &Path, line_range: Option<&Range>) -> String { +fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive>) -> String { let mut result = String::new(); if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { @@ -335,10 +366,10 @@ fn codeblock_tag(full_path: &Path, line_range: Option<&Range>) -> String { let _ = write!(result, "{}", full_path.display()); if let Some(range) = line_range { - if range.start == range.end { - let _ = write!(result, ":{}", range.start + 1); + if range.start() == range.end() { + let _ = write!(result, ":{}", range.start() + 1); } else { - let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1); + let _ = write!(result, ":{}-{}", range.start() + 1, range.end() + 1); } } @@ -627,7 +658,20 @@ impl Thread { stream: &ThreadEventStream, cx: &mut Context, ) { - let Some(tool) = self.tools.get(tool_use.name.as_ref()) else { + let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| { + self.context_server_registry + .read(cx) + .servers() + .find_map(|(_, tools)| { + if let Some(tool) = tools.get(tool_use.name.as_ref()) { + Some(tool.clone()) + } else { + None + } + }) + }); + + let Some(tool) = tool else { stream .0 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { @@ -1079,6 +1123,10 @@ impl Thread { self.cancel(cx); let model = self.model.clone().context("No language model configured")?; + let profile = AgentSettings::get_global(cx) + .profiles + .get(&self.profile_id) + .context("Profile not found")?; let (events_tx, events_rx) = mpsc::unbounded::>(); let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); @@ -1086,6 +1134,7 @@ impl Thread { self.summary = None; self.running_turn = Some(RunningTurn { event_stream: event_stream.clone(), + tools: self.enabled_tools(profile, &model, cx), _task: cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); @@ -1417,7 +1466,7 @@ impl Thread { ) -> Option> { cx.notify(); - let tool = self.tools.get(tool_use.name.as_ref()).cloned(); + let tool = self.tool(tool_use.name.as_ref()); let mut title = SharedString::from(&tool_use.name); let mut kind = acp::ToolKind::Other; if let Some(tool) = tool.as_ref() { @@ -1727,6 +1776,21 @@ impl Thread { cx: &mut App, ) -> Result { let model = self.model().context("No language model configured")?; + let tools = if let Some(turn) = self.running_turn.as_ref() { + turn.tools + .iter() + .filter_map(|(tool_name, tool)| { + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name.to_string(), + description: tool.description().to_string(), + input_schema: tool.input_schema(model.tool_input_format()).log_err()?, + }) + }) + .collect::>() + } else { + Vec::new() + }; log::debug!("Building completion request"); log::debug!("Completion intent: {:?}", completion_intent); @@ -1734,23 +1798,6 @@ impl Thread { let messages = self.build_request_messages(cx); log::info!("Request will include {} messages", messages.len()); - - let tools = if let Some(tools) = self.tools(cx).log_err() { - tools - .filter_map(|tool| { - let tool_name = tool.name().to_string(); - log::trace!("Including tool: {}", tool_name); - Some(LanguageModelRequestTool { - name: tool_name, - description: tool.description().to_string(), - input_schema: tool.input_schema(model.tool_input_format()).log_err()?, - }) - }) - .collect() - } else { - Vec::new() - }; - log::info!("Request includes {} tools", tools.len()); let request = LanguageModelRequest { @@ -1770,37 +1817,76 @@ impl Thread { Ok(request) } - fn tools<'a>(&'a self, cx: &'a App) -> Result>> { - let model = self.model().context("No language model configured")?; + fn enabled_tools( + &self, + profile: &AgentProfileSettings, + model: &Arc, + cx: &App, + ) -> BTreeMap> { + fn truncate(tool_name: &SharedString) -> SharedString { + if tool_name.len() > MAX_TOOL_NAME_LENGTH { + let mut truncated = tool_name.to_string(); + truncated.truncate(MAX_TOOL_NAME_LENGTH); + truncated.into() + } else { + tool_name.clone() + } + } - let profile = AgentSettings::get_global(cx) - .profiles - .get(&self.profile_id) - .context("profile not found")?; - let provider_id = model.provider_id(); - - Ok(self + let mut tools = self .tools .iter() - .filter(move |(_, tool)| tool.supported_provider(&provider_id)) .filter_map(|(tool_name, tool)| { - if profile.is_tool_enabled(tool_name) { - Some(tool) + if tool.supported_provider(&model.provider_id()) + && profile.is_tool_enabled(tool_name) + { + Some((truncate(tool_name), tool.clone())) } else { None } }) - .chain(self.context_server_registry.read(cx).servers().flat_map( - |(server_id, tools)| { - tools.iter().filter_map(|(tool_name, tool)| { - if profile.is_context_server_tool_enabled(&server_id.0, tool_name) { - Some(tool) - } else { - None - } - }) - }, - ))) + .collect::>(); + + let mut context_server_tools = Vec::new(); + let mut seen_tools = tools.keys().cloned().collect::>(); + let mut duplicate_tool_names = HashSet::default(); + for (server_id, server_tools) in self.context_server_registry.read(cx).servers() { + for (tool_name, tool) in server_tools { + if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) { + let tool_name = truncate(tool_name); + if !seen_tools.insert(tool_name.clone()) { + duplicate_tool_names.insert(tool_name.clone()); + } + context_server_tools.push((server_id.clone(), tool_name, tool.clone())); + } + } + } + + // When there are duplicate tool names, disambiguate by prefixing them + // with the server ID. In the rare case there isn't enough space for the + // disambiguated tool name, keep only the last tool with this name. + for (server_id, tool_name, tool) in context_server_tools { + if duplicate_tool_names.contains(&tool_name) { + let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len()); + if available >= 2 { + let mut disambiguated = server_id.0.to_string(); + disambiguated.truncate(available - 1); + disambiguated.push('_'); + disambiguated.push_str(&tool_name); + tools.insert(disambiguated.into(), tool.clone()); + } else { + tools.insert(tool_name, tool.clone()); + } + } else { + tools.insert(tool_name, tool.clone()); + } + } + + tools + } + + fn tool(&self, name: &str) -> Option> { + self.running_turn.as_ref()?.tools.get(name).cloned() } fn build_request_messages(&self, cx: &App) -> Vec { @@ -1965,6 +2051,8 @@ struct RunningTurn { /// The current event stream for the running turn. Used to report a final /// cancellation event if we cancel the turn. event_stream: ThreadEventStream, + /// The tools that were enabled for this turn. + tools: BTreeMap>, } impl RunningTurn { diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 60dd796463..9f90f3a78a 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -17,11 +17,11 @@ path = "src/agent_servers.rs" doctest = false [dependencies] +acp_tools.workspace = true acp_thread.workspace = true action_log.workspace = true agent-client-protocol.workspace = true agent_settings.workspace = true -agentic-coding-protocol.workspace = true anyhow.workspace = true client = { workspace = true, optional = true } collections.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 1cfb1fcabf..a99a401431 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -1,34 +1,391 @@ -use std::{path::Path, rc::Rc}; - use crate::AgentServerCommand; use acp_thread::AgentConnection; -use anyhow::Result; -use gpui::AsyncApp; +use acp_tools::AcpConnectionRegistry; +use action_log::ActionLog; +use agent_client_protocol::{self as acp, Agent as _, ErrorCode}; +use anyhow::anyhow; +use collections::HashMap; +use futures::AsyncBufReadExt as _; +use futures::channel::oneshot; +use futures::io::BufReader; +use project::Project; +use serde::Deserialize; +use std::{any::Any, cell::RefCell}; +use std::{path::Path, rc::Rc}; use thiserror::Error; -mod v0; -mod v1; +use anyhow::{Context as _, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use acp_thread::{AcpThread, AuthRequired, LoadError}; #[derive(Debug, Error)] #[error("Unsupported version")] pub struct UnsupportedVersion; +pub struct AcpConnection { + server_name: &'static str, + connection: Rc, + sessions: Rc>>, + auth_methods: Vec, + prompt_capabilities: acp::PromptCapabilities, + _io_task: Task>, +} + +pub struct AcpSession { + thread: WeakEntity, + suppress_abort_err: bool, +} + pub async fn connect( server_name: &'static str, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, ) -> Result> { - let conn = v1::AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await; + let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?; + Ok(Rc::new(conn) as _) +} - match conn { - Ok(conn) => Ok(Rc::new(conn) as _), - Err(err) if err.is::() => { - // Consider re-using initialize response and subprocess when adding another version here - let conn: Rc = - Rc::new(v0::AcpConnection::stdio(server_name, command, root_dir, cx).await?); - Ok(conn) +const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; + +impl AcpConnection { + pub async fn stdio( + server_name: &'static str, + command: AgentServerCommand, + root_dir: &Path, + cx: &mut AsyncApp, + ) -> Result { + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true) + .spawn()?; + + let stdout = child.stdout.take().context("Failed to take stdout")?; + let stdin = child.stdin.take().context("Failed to take stdin")?; + let stderr = child.stderr.take().context("Failed to take stderr")?; + log::trace!("Spawned (pid: {})", child.id()); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); + } + }); + + let io_task = cx.background_spawn(io_task); + + cx.background_spawn(async move { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + while let Ok(n) = stderr.read_line(&mut line).await + && n > 0 + { + log::warn!("agent stderr: {}", &line); + line.clear(); + } + }) + .detach(); + + cx.spawn({ + let sessions = sessions.clone(); + async move |cx| { + let status = child.status().await?; + + for session in sessions.borrow().values() { + session + .thread + .update(cx, |thread, cx| { + thread.emit_load_error(LoadError::Exited { status }, cx) + }) + .ok(); + } + + anyhow::Ok(()) + } + }) + .detach(); + + let connection = Rc::new(connection); + + cx.update(|cx| { + AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| { + registry.set_active_connection(server_name, &connection, cx) + }); + })?; + + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + }, + }) + .await?; + + if response.protocol_version < MINIMUM_SUPPORTED_VERSION { + return Err(UnsupportedVersion.into()); } - Err(err) => Err(err), + + Ok(Self { + auth_methods: response.auth_methods, + connection, + server_name, + sessions, + prompt_capabilities: response.agent_capabilities.prompt_capabilities, + _io_task: io_task, + }) + } +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut App, + ) -> Task>> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let response = conn + .new_session(acp::NewSessionRequest { + mcp_servers: vec![], + cwd, + }) + .await + .map_err(|err| { + if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + let mut error = AuthRequired::new(); + + if err.message != acp::ErrorCode::AUTH_REQUIRED.message { + error = error.with_description(err.message); + } + + anyhow!(error) + } else { + anyhow!(err) + } + })?; + + let session_id = response.session_id; + let action_log = cx.new(|_| ActionLog::new(project.clone()))?; + let thread = cx.new(|_cx| { + AcpThread::new( + self.server_name, + self.clone(), + project, + action_log, + session_id.clone(), + ) + })?; + + let session = AcpSession { + thread: thread.downgrade(), + suppress_abort_err: false, + }; + sessions.borrow_mut().insert(session_id, session); + + Ok(thread) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor().spawn(async move { + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), + }) + .await?; + + Ok(result) + }) + } + + fn prompt( + &self, + _id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let conn = self.connection.clone(); + let sessions = self.sessions.clone(); + let session_id = params.session_id.clone(); + cx.foreground_executor().spawn(async move { + let result = conn.prompt(params).await; + + let mut suppress_abort_err = false; + + if let Some(session) = sessions.borrow_mut().get_mut(&session_id) { + suppress_abort_err = session.suppress_abort_err; + session.suppress_abort_err = false; + } + + match result { + Ok(response) => Ok(response), + Err(err) => { + if err.code != ErrorCode::INTERNAL_ERROR.code { + anyhow::bail!(err) + } + + let Some(data) = &err.data else { + anyhow::bail!(err) + }; + + // Temporary workaround until the following PR is generally available: + // https://github.com/google-gemini/gemini-cli/pull/6656 + + #[derive(Deserialize)] + #[serde(deny_unknown_fields)] + struct ErrorDetails { + details: Box, + } + + match serde_json::from_value(data.clone()) { + Ok(ErrorDetails { details }) => { + if suppress_abort_err && details.contains("This operation was aborted") + { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled, + }) + } else { + Err(anyhow!(details)) + } + } + Err(_) => Err(anyhow!(err)), + } + } + } + }) + } + + fn prompt_capabilities(&self) -> acp::PromptCapabilities { + self.prompt_capabilities + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) { + session.suppress_abort_err = true; + } + let conn = self.connection.clone(); + let params = acp::CancelNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancel(params).await }) + .detach(); + } + + fn into_any(self: Rc) -> Rc { + self + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for ClientDelegate { + async fn request_permission( + &self, + arguments: acp::RequestPermissionRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let rx = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) + })?; + + let result = rx?.await; + + let outcome = match result { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + }; + + Ok(acp::RequestPermissionResponse { outcome }) + } + + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })?; + + task.await?; + + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let task = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + })?; + + let content = task.await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + session.thread.update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) } } diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index 29f389547d..1945ad2483 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -1,3 +1,4 @@ +use acp_tools::AcpConnectionRegistry; use action_log::ActionLog; use agent_client_protocol::{self as acp, Agent as _, ErrorCode}; use anyhow::anyhow; @@ -101,6 +102,14 @@ impl AcpConnection { }) .detach(); + let connection = Rc::new(connection); + + cx.update(|cx| { + AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| { + registry.set_active_connection(server_name, &connection, cx) + }); + })?; + let response = connection .initialize(acp::InitializeRequest { protocol_version: acp::VERSION, @@ -119,7 +128,7 @@ impl AcpConnection { Ok(Self { auth_methods: response.auth_methods, - connection: connection.into(), + connection, server_name, sessions, prompt_capabilities: response.agent_capabilities.prompt_capabilities, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index d6ccabb130..ef666974f1 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -44,7 +44,7 @@ pub struct ClaudeCode; impl AgentServer for ClaudeCode { fn name(&self) -> &'static str { - "Welcome to Claude Code" + "Claude Code" } fn empty_state_headline(&self) -> &'static str { diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 3b892e7931..29120fff6e 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -23,7 +23,7 @@ impl AgentServer for Gemini { } fn empty_state_headline(&self) -> &'static str { - "Welcome to Gemini CLI" + self.name() } fn empty_state_message(&self) -> &'static str { diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index 22a9ea6773..5b40967069 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -247,9 +247,9 @@ impl ContextPickerCompletionProvider { let abs_path = project.read(cx).absolute_path(&symbol.path, cx)?; let uri = MentionUri::Symbol { - path: abs_path, + abs_path, name: symbol.name.clone(), - line_range: symbol.range.start.0.row..symbol.range.end.0.row, + line_range: symbol.range.start.0.row..=symbol.range.end.0.row, }; let new_text = format!("{} ", uri.as_link()); let new_text_len = new_text.len(); diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 7d73ebeb19..115008cf52 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -6,7 +6,7 @@ use acp_thread::{MentionUri, selection_name}; use agent_client_protocol as acp; use agent_servers::AgentServer; use agent2::HistoryStore; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use assistant_slash_commands::codeblock_fence_for_path; use collections::{HashMap, HashSet}; use editor::{ @@ -17,8 +17,8 @@ use editor::{ display_map::{Crease, CreaseId, FoldId}, }; use futures::{ - FutureExt as _, TryFutureExt as _, - future::{Shared, join_all, try_join_all}, + FutureExt as _, + future::{Shared, join_all}, }; use gpui::{ AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, @@ -28,14 +28,14 @@ use gpui::{ use language::{Buffer, Language}; use language_model::LanguageModelImage; use project::{CompletionIntent, Project, ProjectItem, ProjectPath, Worktree}; -use prompt_store::PromptStore; +use prompt_store::{PromptId, PromptStore}; use rope::Point; use settings::Settings; use std::{ cell::Cell, ffi::OsStr, fmt::Write, - ops::Range, + ops::{Range, RangeInclusive}, path::{Path, PathBuf}, rc::Rc, sync::Arc, @@ -49,12 +49,8 @@ use ui::{ Render, SelectableButton, SharedString, Styled, TextSize, TintColor, Toggleable, Window, div, h_flex, px, }; -use url::Url; -use util::ResultExt; -use workspace::{ - Toast, Workspace, - notifications::{NotificationId, NotifyResultExt as _}, -}; +use util::{ResultExt, debug_panic}; +use workspace::{Workspace, notifications::NotifyResultExt as _}; use zed_actions::agent::Chat; const PARSE_SLASH_COMMAND_DEBOUNCE: Duration = Duration::from_millis(50); @@ -219,9 +215,9 @@ impl MessageEditor { pub fn mentions(&self) -> HashSet { self.mention_set - .uri_by_crease_id + .mentions .values() - .cloned() + .map(|(uri, _)| uri.clone()) .collect() } @@ -246,132 +242,168 @@ impl MessageEditor { else { return Task::ready(()); }; + let end_anchor = snapshot + .buffer_snapshot + .anchor_before(start_anchor.to_offset(&snapshot.buffer_snapshot) + content_len + 1); - if let MentionUri::File { abs_path, .. } = &mention_uri { - let extension = abs_path - .extension() - .and_then(OsStr::to_str) - .unwrap_or_default(); - - if Img::extensions().contains(&extension) && !extension.contains("svg") { - if !self.prompt_capabilities.get().image { - struct ImagesNotAllowed; - - let end_anchor = snapshot.buffer_snapshot.anchor_before( - start_anchor.to_offset(&snapshot.buffer_snapshot) + content_len + 1, - ); - - self.editor.update(cx, |editor, cx| { - // Remove mention - editor.edit([((start_anchor..end_anchor), "")], cx); - }); - - self.workspace - .update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "This agent does not support images yet", - ) - .autohide(), - cx, - ); - }) - .ok(); - return Task::ready(()); - } - - let project = self.project.clone(); - let Some(project_path) = project - .read(cx) - .project_path_for_absolute_path(abs_path, cx) - else { - return Task::ready(()); - }; - let image = cx - .spawn(async move |_, cx| { - let image = project - .update(cx, |project, cx| project.open_image(project_path, cx)) - .map_err(|e| e.to_string())? - .await - .map_err(|e| e.to_string())?; - image - .read_with(cx, |image, _cx| image.image.clone()) - .map_err(|e| e.to_string()) - }) - .shared(); - let Some(crease_id) = insert_crease_for_image( - *excerpt_id, - start, - content_len, - Some(abs_path.as_path().into()), - image.clone(), - self.editor.clone(), - window, - cx, - ) else { - return Task::ready(()); - }; - return self.confirm_mention_for_image( - crease_id, - start_anchor, - Some(abs_path.clone()), - image, - window, - cx, - ); - } - } - - let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - *excerpt_id, - start, - content_len, - crease_text, - mention_uri.icon_path(cx), - self.editor.clone(), - window, - cx, - ) else { + let crease_id = if let MentionUri::File { abs_path } = &mention_uri + && let Some(extension) = abs_path.extension() + && let Some(extension) = extension.to_str() + && Img::extensions().contains(&extension) + && !extension.contains("svg") + { + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + log::error!("project path not found"); + return Task::ready(()); + }; + let image = self + .project + .update(cx, |project, cx| project.open_image(project_path, cx)); + let image = cx + .spawn(async move |_, cx| { + let image = image.await.map_err(|e| e.to_string())?; + let image = image + .update(cx, |image, _| image.image.clone()) + .map_err(|e| e.to_string())?; + Ok(image) + }) + .shared(); + insert_crease_for_image( + *excerpt_id, + start, + content_len, + Some(abs_path.as_path().into()), + image, + self.editor.clone(), + window, + cx, + ) + } else { + crate::context_picker::insert_crease_for_mention( + *excerpt_id, + start, + content_len, + crease_text, + mention_uri.icon_path(cx), + self.editor.clone(), + window, + cx, + ) + }; + let Some(crease_id) = crease_id else { return Task::ready(()); }; - match mention_uri { - MentionUri::Fetch { url } => { - self.confirm_mention_for_fetch(crease_id, start_anchor, url, window, cx) + let task = match mention_uri.clone() { + MentionUri::Fetch { url } => self.confirm_mention_for_fetch(url, cx), + MentionUri::Directory { abs_path } => self.confirm_mention_for_directory(abs_path, cx), + MentionUri::Thread { id, .. } => self.confirm_mention_for_thread(id, cx), + MentionUri::TextThread { path, .. } => self.confirm_mention_for_text_thread(path, cx), + MentionUri::File { abs_path } => self.confirm_mention_for_file(abs_path, cx), + MentionUri::Symbol { + abs_path, + line_range, + .. + } => self.confirm_mention_for_symbol(abs_path, line_range, cx), + MentionUri::Rule { id, .. } => self.confirm_mention_for_rule(id, cx), + MentionUri::PastedImage => { + debug_panic!("pasted image URI should not be included in completions"); + Task::ready(Err(anyhow!( + "pasted imaged URI should not be included in completions" + ))) } - MentionUri::Directory { abs_path } => { - self.confirm_mention_for_directory(crease_id, start_anchor, abs_path, window, cx) + MentionUri::Selection { .. } => { + // Handled elsewhere + debug_panic!("unexpected selection URI"); + Task::ready(Err(anyhow!("unexpected selection URI"))) } - MentionUri::Thread { id, name } => { - self.confirm_mention_for_thread(crease_id, start_anchor, id, name, window, cx) + }; + let task = cx + .spawn(async move |_, _| task.await.map_err(|e| e.to_string())) + .shared(); + self.mention_set + .mentions + .insert(crease_id, (mention_uri, task.clone())); + + // Notify the user if we failed to load the mentioned context + cx.spawn_in(window, async move |this, cx| { + if task.await.notify_async_err(cx).is_none() { + this.update(cx, |this, cx| { + this.editor.update(cx, |editor, cx| { + // Remove mention + editor.edit([(start_anchor..end_anchor, "")], cx); + }); + this.mention_set.mentions.remove(&crease_id); + }) + .ok(); } - MentionUri::TextThread { path, name } => self.confirm_mention_for_text_thread( - crease_id, - start_anchor, - path, - name, - window, - cx, - ), - MentionUri::File { .. } - | MentionUri::Symbol { .. } - | MentionUri::Rule { .. } - | MentionUri::Selection { .. } => { - self.mention_set.insert_uri(crease_id, mention_uri.clone()); - Task::ready(()) + }) + } + + fn confirm_mention_for_file( + &mut self, + abs_path: PathBuf, + cx: &mut Context, + ) -> Task> { + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return Task::ready(Err(anyhow!("project path not found"))); + }; + let extension = abs_path + .extension() + .and_then(OsStr::to_str) + .unwrap_or_default(); + + if Img::extensions().contains(&extension) && !extension.contains("svg") { + if !self.prompt_capabilities.get().image { + return Task::ready(Err(anyhow!("This agent does not support images yet"))); } + let task = self + .project + .update(cx, |project, cx| project.open_image(project_path, cx)); + return cx.spawn(async move |_, cx| { + let image = task.await?; + let image = image.update(cx, |image, _| image.image.clone())?; + let format = image.format; + let image = cx + .update(|cx| LanguageModelImage::from_image(image, cx))? + .await; + if let Some(image) = image { + Ok(Mention::Image(MentionImage { + data: image.source, + format, + })) + } else { + Err(anyhow!("Failed to convert image")) + } + }); } + + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)); + cx.spawn(async move |_, cx| { + let buffer = buffer.await?; + let mention = buffer.update(cx, |buffer, cx| Mention::Text { + content: buffer.text(), + tracked_buffers: vec![cx.entity()], + })?; + anyhow::Ok(mention) + }) } fn confirm_mention_for_directory( &mut self, - crease_id: CreaseId, - anchor: Anchor, abs_path: PathBuf, - window: &mut Window, cx: &mut Context, - ) -> Task<()> { + ) -> Task> { fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<(Arc, PathBuf)> { let mut files = Vec::new(); @@ -386,24 +418,21 @@ impl MessageEditor { files } - let uri = MentionUri::Directory { - abs_path: abs_path.clone(), - }; let Some(project_path) = self .project .read(cx) .project_path_for_absolute_path(&abs_path, cx) else { - return Task::ready(()); + return Task::ready(Err(anyhow!("project path not found"))); }; let Some(entry) = self.project.read(cx).entry_for_path(&project_path, cx) else { - return Task::ready(()); + return Task::ready(Err(anyhow!("project entry not found"))); }; let Some(worktree) = self.project.read(cx).worktree_for_entry(entry.id, cx) else { - return Task::ready(()); + return Task::ready(Err(anyhow!("worktree not found"))); }; let project = self.project.clone(); - let task = cx.spawn(async move |_, cx| { + cx.spawn(async move |_, cx| { let directory_path = entry.path.clone(); let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id())?; @@ -453,89 +482,83 @@ impl MessageEditor { ((rel_path, full_path, rope), buffer) }) .unzip(); - (render_directory_contents(contents), tracked_buffers) + Mention::Text { + content: render_directory_contents(contents), + tracked_buffers, + } }) .await; anyhow::Ok(contents) - }); - let task = cx - .spawn(async move |_, _| task.await.map_err(|e| e.to_string())) - .shared(); - - self.mention_set - .directories - .insert(abs_path.clone(), task.clone()); - - let editor = self.editor.clone(); - cx.spawn_in(window, async move |this, cx| { - if task.await.notify_async_err(cx).is_some() { - this.update(cx, |this, _| { - this.mention_set.insert_uri(crease_id, uri); - }) - .ok(); - } else { - editor - .update(cx, |editor, cx| { - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting(vec![anchor..anchor], true, cx); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - this.update(cx, |this, _cx| { - this.mention_set.directories.remove(&abs_path); - }) - .ok(); - } }) } fn confirm_mention_for_fetch( &mut self, - crease_id: CreaseId, - anchor: Anchor, url: url::Url, - window: &mut Window, cx: &mut Context, - ) -> Task<()> { - let Some(http_client) = self + ) -> Task> { + let http_client = match self .workspace - .update(cx, |workspace, _cx| workspace.client().http_client()) - .ok() - else { - return Task::ready(()); + .update(cx, |workspace, _| workspace.client().http_client()) + { + Ok(http_client) => http_client, + Err(e) => return Task::ready(Err(e)), }; - - let url_string = url.to_string(); - let fetch = cx - .background_executor() - .spawn(async move { - fetch_url_content(http_client, url_string) - .map_err(|e| e.to_string()) - .await + cx.background_executor().spawn(async move { + let content = fetch_url_content(http_client, url.to_string()).await?; + Ok(Mention::Text { + content, + tracked_buffers: Vec::new(), }) - .shared(); - self.mention_set - .add_fetch_result(url.clone(), fetch.clone()); + }) + } - cx.spawn_in(window, async move |this, cx| { - let fetch = fetch.await.notify_async_err(cx); - this.update(cx, |this, cx| { - if fetch.is_some() { - this.mention_set - .insert_uri(crease_id, MentionUri::Fetch { url }); - } else { - // Remove crease if we failed to fetch - this.editor.update(cx, |editor, cx| { - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting(vec![anchor..anchor], true, cx); - }); - editor.remove_creases([crease_id], cx); - }); - this.mention_set.fetch_results.remove(&url); + fn confirm_mention_for_symbol( + &mut self, + abs_path: PathBuf, + line_range: RangeInclusive, + cx: &mut Context, + ) -> Task> { + let Some(project_path) = self + .project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return Task::ready(Err(anyhow!("project path not found"))); + }; + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)); + cx.spawn(async move |_, cx| { + let buffer = buffer.await?; + let mention = buffer.update(cx, |buffer, cx| { + let start = Point::new(*line_range.start(), 0).min(buffer.max_point()); + let end = Point::new(*line_range.end() + 1, 0).min(buffer.max_point()); + let content = buffer.text_for_range(start..end).collect(); + Mention::Text { + content, + tracked_buffers: vec![cx.entity()], } + })?; + anyhow::Ok(mention) + }) + } + + fn confirm_mention_for_rule( + &mut self, + id: PromptId, + cx: &mut Context, + ) -> Task> { + let Some(prompt_store) = self.prompt_store.clone() else { + return Task::ready(Err(anyhow!("missing prompt store"))); + }; + let prompt = prompt_store.read(cx).load(id, cx); + cx.spawn(async move |_, _| { + let prompt = prompt.await?; + Ok(Mention::Text { + content: prompt, + tracked_buffers: Vec::new(), }) - .ok(); }) } @@ -560,24 +583,24 @@ impl MessageEditor { let range = snapshot.anchor_after(offset + range_to_fold.start) ..snapshot.anchor_after(offset + range_to_fold.end); - // TODO support selections from buffers with no path - let Some(project_path) = buffer.read(cx).project_path(cx) else { - continue; - }; - let Some(abs_path) = self.project.read(cx).absolute_path(&project_path, cx) else { - continue; - }; + let abs_path = buffer + .read(cx) + .project_path(cx) + .and_then(|project_path| self.project.read(cx).absolute_path(&project_path, cx)); let snapshot = buffer.read(cx).snapshot(); + let text = snapshot + .text_for_range(selection_range.clone()) + .collect::(); let point_range = selection_range.to_point(&snapshot); - let line_range = point_range.start.row..point_range.end.row; + let line_range = point_range.start.row..=point_range.end.row; let uri = MentionUri::Selection { - path: abs_path.clone(), + abs_path: abs_path.clone(), line_range: line_range.clone(), }; let crease = crate::context_picker::crease_for_mention( - selection_name(&abs_path, &line_range).into(), + selection_name(abs_path.as_deref(), &line_range).into(), uri.icon_path(cx), range, self.editor.downgrade(), @@ -589,132 +612,69 @@ impl MessageEditor { crease_ids.first().copied().unwrap() }); - self.mention_set.insert_uri(crease_id, uri); + self.mention_set.mentions.insert( + crease_id, + ( + uri, + Task::ready(Ok(Mention::Text { + content: text, + tracked_buffers: vec![buffer], + })) + .shared(), + ), + ); } } fn confirm_mention_for_thread( &mut self, - crease_id: CreaseId, - anchor: Anchor, id: acp::SessionId, - name: String, - window: &mut Window, cx: &mut Context, - ) -> Task<()> { - let uri = MentionUri::Thread { - id: id.clone(), - name, - }; + ) -> Task> { let server = Rc::new(agent2::NativeAgentServer::new( self.project.read(cx).fs().clone(), self.history_store.clone(), )); let connection = server.connect(Path::new(""), &self.project, cx); - let load_summary = cx.spawn({ - let id = id.clone(); - async move |_, cx| { - let agent = connection.await?; - let agent = agent.downcast::().unwrap(); - let summary = agent - .0 - .update(cx, |agent, cx| agent.thread_summary(id, cx))? - .await?; - anyhow::Ok(summary) - } - }); - let task = cx - .spawn(async move |_, _| load_summary.await.map_err(|e| format!("{e}"))) - .shared(); - - self.mention_set.insert_thread(id.clone(), task.clone()); - self.mention_set.insert_uri(crease_id, uri); - - let editor = self.editor.clone(); - cx.spawn_in(window, async move |this, cx| { - if task.await.notify_async_err(cx).is_none() { - editor - .update(cx, |editor, cx| { - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting(vec![anchor..anchor], true, cx); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - this.update(cx, |this, _| { - this.mention_set.thread_summaries.remove(&id); - this.mention_set.uri_by_crease_id.remove(&crease_id); - }) - .ok(); - } + cx.spawn(async move |_, cx| { + let agent = connection.await?; + let agent = agent.downcast::().unwrap(); + let summary = agent + .0 + .update(cx, |agent, cx| agent.thread_summary(id, cx))? + .await?; + anyhow::Ok(Mention::Text { + content: summary.to_string(), + tracked_buffers: Vec::new(), + }) }) } fn confirm_mention_for_text_thread( &mut self, - crease_id: CreaseId, - anchor: Anchor, path: PathBuf, - name: String, - window: &mut Window, cx: &mut Context, - ) -> Task<()> { - let uri = MentionUri::TextThread { - path: path.clone(), - name, - }; + ) -> Task> { let context = self.history_store.update(cx, |text_thread_store, cx| { text_thread_store.load_text_thread(path.as_path().into(), cx) }); - let task = cx - .spawn(async move |_, cx| { - let context = context.await.map_err(|e| e.to_string())?; - let xml = context - .update(cx, |context, cx| context.to_xml(cx)) - .map_err(|e| e.to_string())?; - Ok(xml) + cx.spawn(async move |_, cx| { + let context = context.await?; + let xml = context.update(cx, |context, cx| context.to_xml(cx))?; + Ok(Mention::Text { + content: xml, + tracked_buffers: Vec::new(), }) - .shared(); - - self.mention_set - .insert_text_thread(path.clone(), task.clone()); - - let editor = self.editor.clone(); - cx.spawn_in(window, async move |this, cx| { - if task.await.notify_async_err(cx).is_some() { - this.update(cx, |this, _| { - this.mention_set.insert_uri(crease_id, uri); - }) - .ok(); - } else { - editor - .update(cx, |editor, cx| { - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting(vec![anchor..anchor], true, cx); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - this.update(cx, |this, _| { - this.mention_set.text_thread_summaries.remove(&path); - }) - .ok(); - } }) } pub fn contents( &self, - window: &mut Window, cx: &mut Context, ) -> Task, Vec>)>> { - let contents = self.mention_set.contents( - &self.project, - self.prompt_store.as_ref(), - &self.prompt_capabilities.get(), - window, - cx, - ); + let contents = self + .mention_set + .contents(&self.prompt_capabilities.get(), cx); let editor = self.editor.clone(); let prevent_slash_commands = self.prevent_slash_commands; @@ -729,7 +689,7 @@ impl MessageEditor { editor.display_map.update(cx, |map, cx| { let snapshot = map.snapshot(cx); for (crease_id, crease) in snapshot.crease_snapshot.creases() { - let Some(mention) = contents.get(&crease_id) else { + let Some((uri, mention)) = contents.get(&crease_id) else { continue; }; @@ -747,7 +707,6 @@ impl MessageEditor { } let chunk = match mention { Mention::Text { - uri, content, tracked_buffers, } => { @@ -764,17 +723,25 @@ impl MessageEditor { }) } Mention::Image(mention_image) => { + let uri = match uri { + MentionUri::File { .. } => Some(uri.to_uri().to_string()), + MentionUri::PastedImage => None, + other => { + debug_panic!( + "unexpected mention uri for image: {:?}", + other + ); + None + } + }; acp::ContentBlock::Image(acp::ImageContent { annotations: None, data: mention_image.data.to_string(), mime_type: mention_image.format.mime_type().into(), - uri: mention_image - .abs_path - .as_ref() - .map(|path| format!("file://{}", path.display())), + uri, }) } - Mention::UriOnly(uri) => { + Mention::UriOnly => { acp::ContentBlock::ResourceLink(acp::ResourceLink { name: uri.name(), uri: uri.to_uri().to_string(), @@ -813,7 +780,13 @@ impl MessageEditor { pub fn clear(&mut self, window: &mut Window, cx: &mut Context) { self.editor.update(cx, |editor, cx| { editor.clear(window, cx); - editor.remove_creases(self.mention_set.drain(), cx) + editor.remove_creases( + self.mention_set + .mentions + .drain() + .map(|(crease_id, _)| crease_id), + cx, + ) }); } @@ -853,7 +826,7 @@ impl MessageEditor { } cx.stop_propagation(); - let replacement_text = "image"; + let replacement_text = MentionUri::PastedImage.as_link().to_string(); for image in images { let (excerpt_id, text_anchor, multibuffer_anchor) = self.editor.update(cx, |message_editor, cx| { @@ -876,24 +849,62 @@ impl MessageEditor { }); let content_len = replacement_text.len(); - let Some(anchor) = multibuffer_anchor else { - return; + let Some(start_anchor) = multibuffer_anchor else { + continue; }; - let task = Task::ready(Ok(Arc::new(image))).shared(); + let end_anchor = self.editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + snapshot.anchor_before(start_anchor.to_offset(&snapshot) + content_len) + }); + let image = Arc::new(image); let Some(crease_id) = insert_crease_for_image( excerpt_id, text_anchor, content_len, None.clone(), - task.clone(), + Task::ready(Ok(image.clone())).shared(), self.editor.clone(), window, cx, ) else { - return; + continue; }; - self.confirm_mention_for_image(crease_id, anchor, None, task, window, cx) - .detach(); + let task = cx + .spawn_in(window, { + async move |_, cx| { + let format = image.format; + let image = cx + .update(|_, cx| LanguageModelImage::from_image(image, cx)) + .map_err(|e| e.to_string())? + .await; + if let Some(image) = image { + Ok(Mention::Image(MentionImage { + data: image.source, + format, + })) + } else { + Err("Failed to convert image".into()) + } + } + }) + .shared(); + + self.mention_set + .mentions + .insert(crease_id, (MentionUri::PastedImage, task.clone())); + + cx.spawn_in(window, async move |this, cx| { + if task.await.notify_async_err(cx).is_none() { + this.update(cx, |this, cx| { + this.editor.update(cx, |editor, cx| { + editor.edit([(start_anchor..end_anchor, "")], cx); + }); + this.mention_set.mentions.remove(&crease_id); + }) + .ok(); + } + }) + .detach(); } } @@ -995,67 +1006,6 @@ impl MessageEditor { }) } - fn confirm_mention_for_image( - &mut self, - crease_id: CreaseId, - anchor: Anchor, - abs_path: Option, - image: Shared, String>>>, - window: &mut Window, - cx: &mut Context, - ) -> Task<()> { - let editor = self.editor.clone(); - let task = cx - .spawn_in(window, { - let abs_path = abs_path.clone(); - async move |_, cx| { - let image = image.await?; - let format = image.format; - let image = cx - .update(|_, cx| LanguageModelImage::from_image(image, cx)) - .map_err(|e| e.to_string())? - .await; - if let Some(image) = image { - Ok(MentionImage { - abs_path, - data: image.source, - format, - }) - } else { - Err("Failed to convert image".into()) - } - } - }) - .shared(); - - self.mention_set.insert_image(crease_id, task.clone()); - - cx.spawn_in(window, async move |this, cx| { - if task.await.notify_async_err(cx).is_some() { - if let Some(abs_path) = abs_path.clone() { - this.update(cx, |this, _cx| { - this.mention_set - .insert_uri(crease_id, MentionUri::File { abs_path }); - }) - .ok(); - } - } else { - editor - .update(cx, |editor, cx| { - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting(vec![anchor..anchor], true, cx); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - this.update(cx, |this, _cx| { - this.mention_set.images.remove(&crease_id); - }) - .ok(); - } - }) - } - pub fn set_mode(&mut self, mode: EditorMode, cx: &mut Context) { self.editor.update(cx, |editor, cx| { editor.set_mode(mode); @@ -1073,7 +1023,6 @@ impl MessageEditor { let mut text = String::new(); let mut mentions = Vec::new(); - let mut images = Vec::new(); for chunk in message { match chunk { @@ -1084,26 +1033,58 @@ impl MessageEditor { resource: acp::EmbeddedResourceResource::TextResourceContents(resource), .. }) => { - if let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() { - let start = text.len(); - write!(&mut text, "{}", mention_uri.as_link()).ok(); - let end = text.len(); - mentions.push((start..end, mention_uri, resource.text)); - } + let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() else { + continue; + }; + let start = text.len(); + write!(&mut text, "{}", mention_uri.as_link()).ok(); + let end = text.len(); + mentions.push(( + start..end, + mention_uri, + Mention::Text { + content: resource.text, + tracked_buffers: Vec::new(), + }, + )); } acp::ContentBlock::ResourceLink(resource) => { if let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() { let start = text.len(); write!(&mut text, "{}", mention_uri.as_link()).ok(); let end = text.len(); - mentions.push((start..end, mention_uri, resource.uri)); + mentions.push((start..end, mention_uri, Mention::UriOnly)); } } - acp::ContentBlock::Image(content) => { + acp::ContentBlock::Image(acp::ImageContent { + uri, + data, + mime_type, + annotations: _, + }) => { + let mention_uri = if let Some(uri) = uri { + MentionUri::parse(&uri) + } else { + Ok(MentionUri::PastedImage) + }; + let Some(mention_uri) = mention_uri.log_err() else { + continue; + }; + let Some(format) = ImageFormat::from_mime_type(&mime_type) else { + log::error!("failed to parse MIME type for image: {mime_type:?}"); + continue; + }; let start = text.len(); - text.push_str("image"); + write!(&mut text, "{}", mention_uri.as_link()).ok(); let end = text.len(); - images.push((start..end, content)); + mentions.push(( + start..end, + mention_uri, + Mention::Image(MentionImage { + data: data.into(), + format, + }), + )); } acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => {} } @@ -1114,9 +1095,9 @@ impl MessageEditor { editor.buffer().read(cx).snapshot(cx) }); - for (range, mention_uri, text) in mentions { + for (range, mention_uri, mention) in mentions { let anchor = snapshot.anchor_before(range.start); - let crease_id = crate::context_picker::insert_crease_for_mention( + let Some(crease_id) = crate::context_picker::insert_crease_for_mention( anchor.excerpt_id, anchor.text_anchor, range.end - range.start, @@ -1125,77 +1106,14 @@ impl MessageEditor { self.editor.clone(), window, cx, - ); - - if let Some(crease_id) = crease_id { - self.mention_set.insert_uri(crease_id, mention_uri.clone()); - } - - match mention_uri { - MentionUri::Thread { id, .. } => { - self.mention_set - .insert_thread(id, Task::ready(Ok(text.into())).shared()); - } - MentionUri::TextThread { path, .. } => { - self.mention_set - .insert_text_thread(path, Task::ready(Ok(text)).shared()); - } - MentionUri::Fetch { url } => { - self.mention_set - .add_fetch_result(url, Task::ready(Ok(text)).shared()); - } - MentionUri::Directory { abs_path } => { - let task = Task::ready(Ok((text, Vec::new()))).shared(); - self.mention_set.directories.insert(abs_path, task); - } - MentionUri::File { .. } - | MentionUri::Symbol { .. } - | MentionUri::Rule { .. } - | MentionUri::Selection { .. } => {} - } - } - for (range, content) in images { - let Some(format) = ImageFormat::from_mime_type(&content.mime_type) else { + ) else { continue; }; - let anchor = snapshot.anchor_before(range.start); - let abs_path = content - .uri - .as_ref() - .and_then(|uri| uri.strip_prefix("file://").map(|s| Path::new(s).into())); - let name = content - .uri - .as_ref() - .and_then(|uri| { - uri.strip_prefix("file://") - .and_then(|path| Path::new(path).file_name()) - }) - .map(|name| name.to_string_lossy().to_string()) - .unwrap_or("Image".to_owned()); - let crease_id = crate::context_picker::insert_crease_for_mention( - anchor.excerpt_id, - anchor.text_anchor, - range.end - range.start, - name.into(), - IconName::Image.path().into(), - self.editor.clone(), - window, - cx, + self.mention_set.mentions.insert( + crease_id, + (mention_uri.clone(), Task::ready(Ok(mention)).shared()), ); - let data: SharedString = content.data.to_string().into(); - - if let Some(crease_id) = crease_id { - self.mention_set.insert_image( - crease_id, - Task::ready(Ok(MentionImage { - abs_path, - data, - format, - })) - .shared(), - ); - } } cx.notify(); } @@ -1425,289 +1343,60 @@ impl Render for ImageHover { } } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub enum Mention { Text { - uri: MentionUri, content: String, tracked_buffers: Vec>, }, Image(MentionImage), - UriOnly(MentionUri), + UriOnly, } #[derive(Clone, Debug, Eq, PartialEq)] pub struct MentionImage { - pub abs_path: Option, pub data: SharedString, pub format: ImageFormat, } #[derive(Default)] pub struct MentionSet { - uri_by_crease_id: HashMap, - fetch_results: HashMap>>>, - images: HashMap>>>, - thread_summaries: HashMap>>>, - text_thread_summaries: HashMap>>>, - directories: HashMap>), String>>>>, + mentions: HashMap>>)>, } impl MentionSet { - pub fn insert_uri(&mut self, crease_id: CreaseId, uri: MentionUri) { - self.uri_by_crease_id.insert(crease_id, uri); - } - - pub fn add_fetch_result(&mut self, url: Url, content: Shared>>) { - self.fetch_results.insert(url, content); - } - - pub fn insert_image( - &mut self, - crease_id: CreaseId, - task: Shared>>, - ) { - self.images.insert(crease_id, task); - } - - fn insert_thread( - &mut self, - id: acp::SessionId, - task: Shared>>, - ) { - self.thread_summaries.insert(id, task); - } - - fn insert_text_thread(&mut self, path: PathBuf, task: Shared>>) { - self.text_thread_summaries.insert(path, task); - } - - pub fn contents( + fn contents( &self, - project: &Entity, - prompt_store: Option<&Entity>, prompt_capabilities: &acp::PromptCapabilities, - _window: &mut Window, cx: &mut App, - ) -> Task>> { + ) -> Task>> { if !prompt_capabilities.embedded_context { let mentions = self - .uri_by_crease_id + .mentions .iter() - .map(|(crease_id, uri)| (*crease_id, Mention::UriOnly(uri.clone()))) + .map(|(crease_id, (uri, _))| (*crease_id, (uri.clone(), Mention::UriOnly))) .collect(); return Task::ready(Ok(mentions)); } - let mut processed_image_creases = HashSet::default(); - - let mut contents = self - .uri_by_crease_id - .iter() - .map(|(&crease_id, uri)| { - match uri { - MentionUri::File { abs_path, .. } => { - let uri = uri.clone(); - let abs_path = abs_path.to_path_buf(); - - if let Some(task) = self.images.get(&crease_id).cloned() { - processed_image_creases.insert(crease_id); - return cx.spawn(async move |_| { - let image = task.await.map_err(|e| anyhow!("{e}"))?; - anyhow::Ok((crease_id, Mention::Image(image))) - }); - } - - let buffer_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(abs_path, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_buffer(path, cx)) - }); - cx.spawn(async move |cx| { - let buffer = buffer_task?.await?; - let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; - - anyhow::Ok(( - crease_id, - Mention::Text { - uri, - content, - tracked_buffers: vec![buffer], - }, - )) - }) - } - MentionUri::Directory { abs_path } => { - let Some(content) = self.directories.get(abs_path).cloned() else { - return Task::ready(Err(anyhow!("missing directory load task"))); - }; - let uri = uri.clone(); - cx.spawn(async move |_| { - let (content, tracked_buffers) = - content.await.map_err(|e| anyhow::anyhow!("{e}"))?; - Ok(( - crease_id, - Mention::Text { - uri, - content, - tracked_buffers, - }, - )) - }) - } - MentionUri::Symbol { - path, line_range, .. - } - | MentionUri::Selection { - path, line_range, .. - } => { - let uri = uri.clone(); - let path_buf = path.clone(); - let line_range = line_range.clone(); - - let buffer_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(&path_buf, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_buffer(path, cx)) - }); - - cx.spawn(async move |cx| { - let buffer = buffer_task?.await?; - let content = buffer.read_with(cx, |buffer, _cx| { - buffer - .text_for_range( - Point::new(line_range.start, 0) - ..Point::new( - line_range.end, - buffer.line_len(line_range.end), - ), - ) - .collect() - })?; - - anyhow::Ok(( - crease_id, - Mention::Text { - uri, - content, - tracked_buffers: vec![buffer], - }, - )) - }) - } - MentionUri::Thread { id, .. } => { - let Some(content) = self.thread_summaries.get(id).cloned() else { - return Task::ready(Err(anyhow!("missing thread summary"))); - }; - let uri = uri.clone(); - cx.spawn(async move |_| { - Ok(( - crease_id, - Mention::Text { - uri, - content: content - .await - .map_err(|e| anyhow::anyhow!("{e}"))? - .to_string(), - tracked_buffers: Vec::new(), - }, - )) - }) - } - MentionUri::TextThread { path, .. } => { - let Some(content) = self.text_thread_summaries.get(path).cloned() else { - return Task::ready(Err(anyhow!("missing text thread summary"))); - }; - let uri = uri.clone(); - cx.spawn(async move |_| { - Ok(( - crease_id, - Mention::Text { - uri, - content: content.await.map_err(|e| anyhow::anyhow!("{e}"))?, - tracked_buffers: Vec::new(), - }, - )) - }) - } - MentionUri::Rule { id: prompt_id, .. } => { - let Some(prompt_store) = prompt_store else { - return Task::ready(Err(anyhow!("missing prompt store"))); - }; - let text_task = prompt_store.read(cx).load(*prompt_id, cx); - let uri = uri.clone(); - cx.spawn(async move |_| { - // TODO: report load errors instead of just logging - let text = text_task.await?; - anyhow::Ok(( - crease_id, - Mention::Text { - uri, - content: text, - tracked_buffers: Vec::new(), - }, - )) - }) - } - MentionUri::Fetch { url } => { - let Some(content) = self.fetch_results.get(url).cloned() else { - return Task::ready(Err(anyhow!("missing fetch result"))); - }; - let uri = uri.clone(); - cx.spawn(async move |_| { - Ok(( - crease_id, - Mention::Text { - uri, - content: content.await.map_err(|e| anyhow::anyhow!("{e}"))?, - tracked_buffers: Vec::new(), - }, - )) - }) - } - } - }) - .collect::>(); - - // Handle images that didn't have a mention URI (because they were added by the paste handler). - contents.extend(self.images.iter().filter_map(|(crease_id, image)| { - if processed_image_creases.contains(crease_id) { - return None; - } - let crease_id = *crease_id; - let image = image.clone(); - Some(cx.spawn(async move |_| { - Ok(( - crease_id, - Mention::Image(image.await.map_err(|e| anyhow::anyhow!("{e}"))?), - )) - })) - })); - + let mentions = self.mentions.clone(); cx.spawn(async move |_cx| { - let contents = try_join_all(contents).await?.into_iter().collect(); - anyhow::Ok(contents) + let mut contents = HashMap::default(); + for (crease_id, (mention_uri, task)) in mentions { + contents.insert( + crease_id, + (mention_uri, task.await.map_err(|e| anyhow!("{e}"))?), + ); + } + Ok(contents) }) } - pub fn drain(&mut self) -> impl Iterator { - self.fetch_results.clear(); - self.thread_summaries.clear(); - self.text_thread_summaries.clear(); - self.directories.clear(); - self.uri_by_crease_id - .drain() - .map(|(id, _)| id) - .chain(self.images.drain().map(|(id, _)| id)) - } - - pub fn remove_invalid(&mut self, snapshot: EditorSnapshot) { + fn remove_invalid(&mut self, snapshot: EditorSnapshot) { for (crease_id, crease) in snapshot.crease_snapshot.creases() { if !crease.range().start.is_valid(&snapshot.buffer_snapshot) { - self.uri_by_crease_id.remove(&crease_id); + self.mentions.remove(&crease_id); } } } @@ -1969,9 +1658,7 @@ mod tests { }); let (content, _) = message_editor - .update_in(cx, |message_editor, window, cx| { - message_editor.contents(window, cx) - }) + .update(cx, |message_editor, cx| message_editor.contents(cx)) .await .unwrap(); @@ -2038,7 +1725,8 @@ mod tests { "six.txt": "6", "seven.txt": "7", "eight.txt": "8", - } + }, + "x.png": "", }), ) .await; @@ -2222,14 +1910,10 @@ mod tests { }; let contents = message_editor - .update_in(&mut cx, |message_editor, window, cx| { - message_editor.mention_set().contents( - &project, - None, - &all_prompt_capabilities, - window, - cx, - ) + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) }) .await .unwrap() @@ -2237,7 +1921,7 @@ mod tests { .collect::>(); { - let [Mention::Text { content, uri, .. }] = contents.as_slice() else { + let [(uri, Mention::Text { content, .. })] = contents.as_slice() else { panic!("Unexpected mentions"); }; pretty_assertions::assert_eq!(content, "1"); @@ -2245,14 +1929,10 @@ mod tests { } let contents = message_editor - .update_in(&mut cx, |message_editor, window, cx| { - message_editor.mention_set().contents( - &project, - None, - &acp::PromptCapabilities::default(), - window, - cx, - ) + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&acp::PromptCapabilities::default(), cx) }) .await .unwrap() @@ -2260,7 +1940,7 @@ mod tests { .collect::>(); { - let [Mention::UriOnly(uri)] = contents.as_slice() else { + let [(uri, Mention::UriOnly)] = contents.as_slice() else { panic!("Unexpected mentions"); }; pretty_assertions::assert_eq!(uri, &url_one.parse::().unwrap()); @@ -2300,14 +1980,10 @@ mod tests { cx.run_until_parked(); let contents = message_editor - .update_in(&mut cx, |message_editor, window, cx| { - message_editor.mention_set().contents( - &project, - None, - &all_prompt_capabilities, - window, - cx, - ) + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) }) .await .unwrap() @@ -2317,7 +1993,7 @@ mod tests { let url_eight = uri!("file:///dir/b/eight.txt"); { - let [_, Mention::Text { content, uri, .. }] = contents.as_slice() else { + let [_, (uri, Mention::Text { content, .. })] = contents.as_slice() else { panic!("Unexpected mentions"); }; pretty_assertions::assert_eq!(content, "8"); @@ -2414,14 +2090,10 @@ mod tests { }); let contents = message_editor - .update_in(&mut cx, |message_editor, window, cx| { - message_editor.mention_set().contents( - &project, - None, - &all_prompt_capabilities, - window, - cx, - ) + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) }) .await .unwrap() @@ -2429,7 +2101,7 @@ mod tests { .collect::>(); { - let [_, _, Mention::Text { content, uri, .. }] = contents.as_slice() else { + let [_, _, (uri, Mention::Text { content, .. })] = contents.as_slice() else { panic!("Unexpected mentions"); }; pretty_assertions::assert_eq!(content, "1"); @@ -2444,11 +2116,85 @@ mod tests { cx.run_until_parked(); editor.read_with(&cx, |editor, cx| { - assert_eq!( - editor.text(cx), - format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") - ); - }); + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") + ); + }); + + // Try to mention an "image" file that will fail to load + cx.simulate_input("@file x.png"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) @file x.png") + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), &["x.png dir/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + // Getting the message contents fails + message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .expect_err("Should fail to load x.png"); + + cx.run_until_parked(); + + // Mention was removed + editor.read_with(&cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") + ); + }); + + // Once more + cx.simulate_input("@file x.png"); + + editor.update(&mut cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) @file x.png") + ); + assert!(editor.has_visible_completions_menu()); + assert_eq!(current_completion_labels(editor), &["x.png dir/"]); + }); + + editor.update_in(&mut cx, |editor, window, cx| { + editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); + }); + + // This time don't immediately get the contents, just let the confirmed completion settle + cx.run_until_parked(); + + // Mention was removed + editor.read_with(&cx, |editor, cx| { + assert_eq!( + editor.text(cx), + format!("Lorem [@one.txt]({url_one}) Ipsum [@eight.txt]({url_eight}) [@MySymbol]({url_one}?symbol=MySymbol#L1:1) ") + ); + }); + + // Now getting the contents succeeds, because the invalid mention was removed + let contents = message_editor + .update(&mut cx, |message_editor, cx| { + message_editor + .mention_set() + .contents(&all_prompt_capabilities, cx) + }) + .await + .unwrap(); + assert_eq!(contents.len(), 3); } fn fold_ranges(editor: &Editor, cx: &mut App) -> Vec> { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index dae89b3283..3ad1234e22 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -57,7 +57,9 @@ use crate::agent_diff::AgentDiff; use crate::profile_selector::{ProfileProvider, ProfileSelector}; use crate::ui::preview::UsageCallout; -use crate::ui::{AgentNotification, AgentNotificationEvent, BurnModeTooltip}; +use crate::ui::{ + AgentNotification, AgentNotificationEvent, BurnModeTooltip, UnavailableEditingTooltip, +}; use crate::{ AgentDiffPane, AgentPanel, ContinueThread, ContinueWithBurnMode, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, OpenHistory, RejectAll, ToggleBurnMode, ToggleProfileSelector, @@ -256,6 +258,7 @@ pub struct AcpThreadView { hovered_recent_history_item: Option, entry_view_state: Entity, message_editor: Entity, + focus_handle: FocusHandle, model_selector: Option>, profile_selector: Option>, notifications: Vec>, @@ -271,6 +274,7 @@ pub struct AcpThreadView { edits_expanded: bool, plan_expanded: bool, editor_expanded: bool, + terminal_expanded: bool, editing_message: Option, prompt_capabilities: Rc>, _cancel_task: Option>, @@ -310,6 +314,13 @@ impl AcpThreadView { ) -> Self { let prompt_capabilities = Rc::new(Cell::new(acp::PromptCapabilities::default())); let prevent_slash_commands = agent.clone().downcast::().is_some(); + + let placeholder = if agent.name() == "Zed Agent" { + format!("Message the {} — @ to include context", agent.name()) + } else { + format!("Message {} — @ to include context", agent.name()) + }; + let message_editor = cx.new(|cx| { let mut editor = MessageEditor::new( workspace.clone(), @@ -317,7 +328,7 @@ impl AcpThreadView { history_store.clone(), prompt_store.clone(), prompt_capabilities.clone(), - "Message the agent — @ to include context", + placeholder, prevent_slash_commands, editor::EditorMode::AutoHeight { min_lines: MIN_EDITOR_LINES, @@ -374,11 +385,13 @@ impl AcpThreadView { edits_expanded: false, plan_expanded: false, editor_expanded: false, + terminal_expanded: true, history_store, hovered_recent_history_item: None, prompt_capabilities, _subscriptions: subscriptions, _cancel_task: None, + focus_handle: cx.focus_handle(), } } @@ -402,8 +415,12 @@ impl AcpThreadView { let connection = match connect_task.await { Ok(connection) => connection, Err(err) => { - this.update(cx, |this, cx| { - this.handle_load_error(err, cx); + this.update_in(cx, |this, window, cx| { + if err.downcast_ref::().is_some() { + this.handle_load_error(err, window, cx); + } else { + this.handle_thread_error(err, cx); + } cx.notify(); }) .log_err(); @@ -520,6 +537,7 @@ impl AcpThreadView { title_editor, _subscriptions: subscriptions, }; + this.message_editor.focus_handle(cx).focus(window); this.profile_selector = this.as_native_thread(cx).map(|thread| { cx.new(|cx| { @@ -535,7 +553,7 @@ impl AcpThreadView { cx.notify(); } Err(err) => { - this.handle_load_error(err, cx); + this.handle_load_error(err, window, cx); } }; }) @@ -604,17 +622,28 @@ impl AcpThreadView { .map(|desc| cx.new(|cx| Markdown::new(desc.into(), None, None, cx))), _subscription: subscription, }; + if this.message_editor.focus_handle(cx).is_focused(window) { + this.focus_handle.focus(window) + } cx.notify(); }) .ok(); } - fn handle_load_error(&mut self, err: anyhow::Error, cx: &mut Context) { + fn handle_load_error( + &mut self, + err: anyhow::Error, + window: &mut Window, + cx: &mut Context, + ) { if let Some(load_err) = err.downcast_ref::() { self.thread_state = ThreadState::LoadError(load_err.clone()); } else { self.thread_state = ThreadState::LoadError(LoadError::Other(err.to_string().into())) } + if self.message_editor.focus_handle(cx).is_focused(window) { + self.focus_handle.focus(window) + } cx.notify(); } @@ -631,12 +660,11 @@ impl AcpThreadView { } } - pub fn title(&self, cx: &App) -> SharedString { + pub fn title(&self) -> SharedString { match &self.thread_state { - ThreadState::Ready { thread, .. } => thread.read(cx).title(), + ThreadState::Ready { .. } | ThreadState::Unauthenticated { .. } => "New Thread".into(), ThreadState::Loading { .. } => "Loading…".into(), ThreadState::LoadError(_) => "Failed to load".into(), - ThreadState::Unauthenticated { .. } => "Authentication Required".into(), } } @@ -809,7 +837,7 @@ impl AcpThreadView { let contents = self .message_editor - .update(cx, |message_editor, cx| message_editor.contents(window, cx)); + .update(cx, |message_editor, cx| message_editor.contents(cx)); self.send_impl(contents, window, cx) } @@ -822,7 +850,7 @@ impl AcpThreadView { let contents = self .message_editor - .update(cx, |message_editor, cx| message_editor.contents(window, cx)); + .update(cx, |message_editor, cx| message_editor.contents(cx)); cx.spawn_in(window, async move |this, cx| { cancelled.await; @@ -930,8 +958,7 @@ impl AcpThreadView { return; }; - let contents = - message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx)); + let contents = message_editor.update(cx, |message_editor, cx| message_editor.contents(cx)); let task = cx.foreground_executor().spawn(async move { rewind.await?; @@ -1067,6 +1094,9 @@ impl AcpThreadView { AcpThreadEvent::LoadError(error) => { self.thread_retry_status.take(); self.thread_state = ThreadState::LoadError(error.clone()); + if self.message_editor.focus_handle(cx).is_focused(window) { + self.focus_handle.focus(window) + } } AcpThreadEvent::TitleUpdated => { let title = thread.read(cx).title(); @@ -1239,6 +1269,8 @@ impl AcpThreadView { None }; + let agent_name = self.agent.name(); + v_flex() .id(("user_message", entry_ix)) .pt_2() @@ -1292,42 +1324,61 @@ impl AcpThreadView { .text_xs() .child(editor.clone().into_any_element()), ) - .when(editing && editor_focus, |this| - this.child( - h_flex() - .absolute() - .top_neg_3p5() - .right_3() - .gap_1() - .rounded_sm() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().editor_background) - .overflow_hidden() - .child( - IconButton::new("cancel", IconName::Close) - .icon_color(Color::Error) - .icon_size(IconSize::XSmall) - .on_click(cx.listener(Self::cancel_editing)) - ) - .child( - IconButton::new("regenerate", IconName::Return) - .icon_color(Color::Muted) - .icon_size(IconSize::XSmall) - .tooltip(Tooltip::text( - "Editing will restart the thread from this point." - )) - .on_click(cx.listener({ - let editor = editor.clone(); - move |this, _, window, cx| { - this.regenerate( - entry_ix, &editor, window, cx, - ); - } - })), - ) - ) - ), + .when(editor_focus, |this| { + let base_container = h_flex() + .absolute() + .top_neg_3p5() + .right_3() + .gap_1() + .rounded_sm() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) + .overflow_hidden(); + + if message.id.is_some() { + this.child( + base_container + .child( + IconButton::new("cancel", IconName::Close) + .icon_color(Color::Error) + .icon_size(IconSize::XSmall) + .on_click(cx.listener(Self::cancel_editing)) + ) + .child( + IconButton::new("regenerate", IconName::Return) + .icon_color(Color::Muted) + .icon_size(IconSize::XSmall) + .tooltip(Tooltip::text( + "Editing will restart the thread from this point." + )) + .on_click(cx.listener({ + let editor = editor.clone(); + move |this, _, window, cx| { + this.regenerate( + entry_ix, &editor, window, cx, + ); + } + })), + ) + ) + } else { + this.child( + base_container + .border_dashed() + .child( + IconButton::new("editing_unavailable", IconName::PencilUnavailable) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .style(ButtonStyle::Transparent) + .tooltip(move |_window, cx| { + cx.new(|_| UnavailableEditingTooltip::new(agent_name.into())) + .into() + }) + ) + ) + } + }), ) .into_any() } @@ -1618,39 +1669,14 @@ impl AcpThreadView { let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix)); let card_header_id = SharedString::from("inner-tool-call-header"); - let status_icon = match &tool_call.status { - ToolCallStatus::Pending - | ToolCallStatus::WaitingForConfirmation { .. } - | ToolCallStatus::Completed => None, - ToolCallStatus::InProgress => Some( - div() - .absolute() - .right_2() - .child( - Icon::new(IconName::ArrowCircle) - .color(Color::Muted) - .size(IconSize::Small) - .with_animation( - "running", - Animation::new(Duration::from_secs(3)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ), - ) - .into_any(), - ), - ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Failed => Some( - div() - .absolute() - .right_2() - .child( - Icon::new(IconName::Close) - .color(Color::Error) - .size(IconSize::Small), - ) - .into_any_element(), - ), + let in_progress = match &tool_call.status { + ToolCallStatus::InProgress => true, + _ => false, + }; + + let failed_or_canceled = match &tool_call.status { + ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Failed => true, + _ => false, }; let failed_tool_call = matches!( @@ -1665,16 +1691,17 @@ impl AcpThreadView { matches!(tool_call.kind, acp::ToolKind::Edit) || tool_call.diffs().next().is_some(); let use_card_layout = needs_confirmation || is_edit; - let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; + let is_collapsible = !tool_call.content.is_empty() && !use_card_layout; - let is_open = needs_confirmation || self.expanded_tool_calls.contains(&tool_call.id); + let is_open = + needs_confirmation || is_edit || self.expanded_tool_calls.contains(&tool_call.id); let gradient_overlay = |color: Hsla| { div() .absolute() .top_0() .right_0() - .w_12() + .w_16() .h_full() .bg(linear_gradient( 90., @@ -1814,6 +1841,7 @@ impl AcpThreadView { .w_full() .max_w_full() .ml_1p5() + .overflow_hidden() .child(h_flex().pr_8().child(self.render_markdown( tool_call.label.clone(), default_markdown_style(false, true, window, cx), @@ -1833,7 +1861,33 @@ impl AcpThreadView { .into_any() }), ) - .children(status_icon), + .when(in_progress && use_card_layout, |this| { + this.child( + div().absolute().right_2().child( + Icon::new(IconName::ArrowCircle) + .color(Color::Muted) + .size(IconSize::Small) + .with_animation( + "running", + Animation::new(Duration::from_secs(3)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate(percentage( + delta, + ))) + }, + ), + ), + ) + }) + .when(failed_or_canceled, |this| { + this.child( + div().absolute().right_2().child( + Icon::new(IconName::Close) + .color(Color::Error) + .size(IconSize::Small), + ), + ) + }), ) .children(tool_output_display) } @@ -1883,13 +1937,10 @@ impl AcpThreadView { .text_color(cx.theme().colors().text_muted) .child(self.render_markdown(markdown, default_markdown_style(false, false, window, cx))) .child( - Button::new(button_id, "Collapse") + IconButton::new(button_id, IconName::ChevronUp) .full_width() .style(ButtonStyle::Outlined) - .label_size(LabelSize::Small) - .icon(IconName::ChevronUp) .icon_color(Color::Muted) - .icon_position(IconPosition::Start) .on_click(cx.listener({ move |this: &mut Self, _, _, cx: &mut Context| { this.expanded_tool_calls.remove(&tool_call_id); @@ -2113,8 +2164,6 @@ impl AcpThreadView { .map(|path| format!("{}", path.display())) .unwrap_or_else(|| "current directory".to_string()); - let is_expanded = self.expanded_tool_calls.contains(&tool_call.id); - let header = h_flex() .id(SharedString::from(format!( "terminal-tool-header-{}", @@ -2248,19 +2297,12 @@ impl AcpThreadView { "terminal-tool-disclosure-{}", terminal.entity_id() )), - is_expanded, + self.terminal_expanded, ) .opened_icon(IconName::ChevronUp) .closed_icon(IconName::ChevronDown) - .on_click(cx.listener({ - let id = tool_call.id.clone(); - move |this, _event, _window, _cx| { - if is_expanded { - this.expanded_tool_calls.remove(&id); - } else { - this.expanded_tool_calls.insert(id.clone()); - } - } + .on_click(cx.listener(move |this, _event, _window, _cx| { + this.terminal_expanded = !this.terminal_expanded; })), ); @@ -2269,7 +2311,7 @@ impl AcpThreadView { .read(cx) .entry(entry_ix) .and_then(|entry| entry.terminal(terminal)); - let show_output = is_expanded && terminal_view.is_some(); + let show_output = self.terminal_expanded && terminal_view.is_some(); v_flex() .mb_2() @@ -2317,33 +2359,6 @@ impl AcpThreadView { .into_any() } - fn render_agent_logo(&self) -> AnyElement { - Icon::new(self.agent.logo()) - .color(Color::Muted) - .size(IconSize::XLarge) - .into_any_element() - } - - fn render_error_agent_logo(&self) -> AnyElement { - let logo = Icon::new(self.agent.logo()) - .color(Color::Muted) - .size(IconSize::XLarge) - .into_any_element(); - - h_flex() - .relative() - .justify_center() - .child(div().opacity(0.3).child(logo)) - .child( - h_flex() - .absolute() - .right_1() - .bottom_0() - .child(Icon::new(IconName::XCircleFilled).color(Color::Error)), - ) - .into_any_element() - } - fn render_rules_item(&self, cx: &Context) -> Option { let project_context = self .as_native_thread(cx)? @@ -2391,39 +2406,32 @@ impl AcpThreadView { return None; } + let has_both = user_rules_text.is_some() && rules_file_text.is_some(); + Some( - v_flex() + h_flex() .px_2p5() - .gap_1() + .pb_1() + .child( + Icon::new(IconName::Attach) + .size(IconSize::XSmall) + .color(Color::Disabled), + ) .when_some(user_rules_text, |parent, user_rules_text| { parent.child( h_flex() - .group("user-rules") .id("user-rules") - .w_full() - .child( - Icon::new(IconName::Reader) - .size(IconSize::XSmall) - .color(Color::Disabled), - ) + .ml_1() + .mr_1p5() .child( Label::new(user_rules_text) .size(LabelSize::XSmall) .color(Color::Muted) .truncate() - .buffer_font(cx) - .ml_1p5() - .mr_0p5(), - ) - .child( - IconButton::new("open-prompt-library", IconName::ArrowUpRight) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(Color::Ignored) - .visible_on_hover("user-rules") - // TODO: Figure out a way to pass focus handle here so we can display the `OpenRulesLibrary` keybinding - .tooltip(Tooltip::text("View User Rules")), + .buffer_font(cx), ) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .tooltip(Tooltip::text("View User Rules")) .on_click(move |_event, window, cx| { window.dispatch_action( Box::new(OpenRulesLibrary { @@ -2434,33 +2442,20 @@ impl AcpThreadView { }), ) }) + .when(has_both, |this| this.child(Divider::vertical())) .when_some(rules_file_text, |parent, rules_file_text| { parent.child( h_flex() - .group("project-rules") .id("project-rules") - .w_full() - .child( - Icon::new(IconName::Reader) - .size(IconSize::XSmall) - .color(Color::Disabled), - ) + .ml_1p5() .child( Label::new(rules_file_text) .size(LabelSize::XSmall) .color(Color::Muted) - .buffer_font(cx) - .ml_1p5() - .mr_0p5(), - ) - .child( - IconButton::new("open-rule", IconName::ArrowUpRight) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(Color::Ignored) - .visible_on_hover("project-rules") - .tooltip(Tooltip::text("View Project Rules")), + .buffer_font(cx), ) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .tooltip(Tooltip::text("View Project Rules")) .on_click(cx.listener(Self::handle_open_rules)), ) }) @@ -2492,8 +2487,7 @@ impl AcpThreadView { ) } - fn render_empty_state(&self, window: &mut Window, cx: &mut Context) -> AnyElement { - let loading = matches!(&self.thread_state, ThreadState::Loading { .. }); + fn render_recent_history(&self, window: &mut Window, cx: &mut Context) -> AnyElement { let render_history = self .agent .clone() @@ -2505,38 +2499,6 @@ impl AcpThreadView { v_flex() .size_full() - .when(!render_history, |this| { - this.child( - v_flex() - .size_full() - .items_center() - .justify_center() - .child(if loading { - h_flex() - .justify_center() - .child(self.render_agent_logo()) - .with_animation( - "pulsating_icon", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.4, 1.0)), - |icon, delta| icon.opacity(delta), - ) - .into_any() - } else { - self.render_agent_logo().into_any_element() - }) - .child(h_flex().mt_4().mb_2().justify_center().child(if loading { - div() - .child(LoadingLabel::new("").size(LabelSize::Large)) - .into_any_element() - } else { - Headline::new(self.agent.empty_state_headline()) - .size(HeadlineSize::Medium) - .into_any_element() - })), - ) - }) .when(render_history, |this| { let recent_history: Vec<_> = self.history_store.update(cx, |history_store, _| { history_store.entries().take(3).collect() @@ -2611,196 +2573,136 @@ impl AcpThreadView { window: &mut Window, cx: &Context, ) -> Div { - v_flex() - .p_2() - .gap_2() - .flex_1() - .items_center() - .justify_center() - .child( - v_flex() - .items_center() - .justify_center() - .child(self.render_error_agent_logo()) - .child( - h_flex().mt_4().mb_1().justify_center().child( - Headline::new("Authentication Required").size(HeadlineSize::Medium), - ), - ) - .into_any(), - ) - .children(description.map(|desc| { - div().text_ui(cx).text_center().child(self.render_markdown( - desc.clone(), - default_markdown_style(false, false, window, cx), - )) - })) - .children( - configuration_view - .cloned() - .map(|view| div().px_4().w_full().max_w_128().child(view)), - ) - .when( - configuration_view.is_none() - && description.is_none() - && pending_auth_method.is_none(), - |el| { - el.child( - div() - .text_ui(cx) - .text_center() - .px_4() - .w_full() - .max_w_128() - .child(Label::new("Authentication required")), - ) - }, - ) - .when_some(pending_auth_method, |el, _| { - let spinner_icon = div() - .px_0p5() - .id("generating") - .tooltip(Tooltip::text("Generating Changes…")) - .child( - Icon::new(IconName::ArrowCircle) - .size(IconSize::Small) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| { - icon.transform(Transformation::rotate(percentage(delta))) - }, - ) - .into_any_element(), - ) - .into_any(); - el.child( + let show_description = + configuration_view.is_none() && description.is_none() && pending_auth_method.is_none(); + + v_flex().flex_1().size_full().justify_end().child( + v_flex() + .p_2() + .pr_3() + .w_full() + .gap_1() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().status().warning.opacity(0.04)) + .child( h_flex() - .text_ui(cx) - .text_center() - .justify_center() - .gap_2() - .px_4() - .w_full() - .max_w_128() - .child(Label::new("Authenticating...")) - .child(spinner_icon), + .gap_1p5() + .child( + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::Small), + ) + .child(Label::new("Authentication Required").size(LabelSize::Small)), ) - }) - .child( - h_flex() - .mt_1p5() - .gap_1() - .flex_wrap() - .justify_center() - .children(connection.auth_methods().iter().enumerate().rev().map( - |(ix, method)| { - Button::new( - SharedString::from(method.id.0.clone()), - method.name.clone(), + .children(description.map(|desc| { + div().text_ui(cx).child(self.render_markdown( + desc.clone(), + default_markdown_style(false, false, window, cx), + )) + })) + .children( + configuration_view + .cloned() + .map(|view| div().w_full().child(view)), + ) + .when( + show_description, + |el| { + el.child( + Label::new(format!( + "You are not currently authenticated with {}. Please choose one of the following options:", + self.agent.name() + )) + .size(LabelSize::Small) + .color(Color::Muted) + .mb_1() + .ml_5(), + ) + }, + ) + .when_some(pending_auth_method, |el, _| { + el.child( + h_flex() + .py_4() + .w_full() + .justify_center() + .gap_1() + .child( + Icon::new(IconName::ArrowCircle) + .size(IconSize::Small) + .color(Color::Muted) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(Transformation::rotate(percentage( + delta, + ))) + }, + ) + .into_any_element(), ) - .style(ButtonStyle::Outlined) - .when(ix == 0, |el| { - el.style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .child(Label::new("Authenticating…").size(LabelSize::Small)), + ) + }) + .when(!connection.auth_methods().is_empty(), |this| { + this.child( + h_flex() + .justify_end() + .flex_wrap() + .gap_1() + .when(!show_description, |this| { + this.border_t_1() + .mt_1() + .pt_2() + .border_color(cx.theme().colors().border.opacity(0.8)) }) - .size(ButtonSize::Medium) - .label_size(LabelSize::Small) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) - }) - }) - }, - )), - ) + .children( + connection + .auth_methods() + .iter() + .enumerate() + .rev() + .map(|(ix, method)| { + Button::new( + SharedString::from(method.id.0.clone()), + method.name.clone(), + ) + .when(ix == 0, |el| { + el.style(ButtonStyle::Tinted(ui::TintColor::Warning)) + }) + .label_size(LabelSize::Small) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + ), + ) + }) + + ) } fn render_load_error(&self, e: &LoadError, cx: &Context) -> AnyElement { - let mut container = v_flex() - .items_center() - .justify_center() - .child(self.render_error_agent_logo()) - .child( - v_flex() - .mt_4() - .mb_2() - .gap_0p5() - .text_center() - .items_center() - .child(Headline::new("Failed to launch").size(HeadlineSize::Medium)) - .child( - Label::new(e.to_string()) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ); - - if let LoadError::Unsupported { - upgrade_message, - upgrade_command, - .. - } = &e - { - let upgrade_message = upgrade_message.clone(); - let upgrade_command = upgrade_command.clone(); - container = container.child( - Button::new("upgrade", upgrade_message) - .tooltip(Tooltip::text(upgrade_command.clone())) - .on_click(cx.listener(move |this, _, window, cx| { - let task = this - .workspace - .update(cx, |workspace, cx| { - let project = workspace.project().read(cx); - let cwd = project.first_project_directory(cx); - let shell = project.terminal_settings(&cwd, cx).shell.clone(); - let spawn_in_terminal = task::SpawnInTerminal { - id: task::TaskId("upgrade".to_string()), - full_label: upgrade_command.clone(), - label: upgrade_command.clone(), - command: Some(upgrade_command.clone()), - args: Vec::new(), - command_label: upgrade_command.clone(), - cwd, - env: Default::default(), - use_new_terminal: true, - allow_concurrent_runs: true, - reveal: Default::default(), - reveal_target: Default::default(), - hide: Default::default(), - shell, - show_summary: true, - show_command: true, - show_rerun: false, - }; - workspace.spawn_in_terminal(spawn_in_terminal, window, cx) - }) - .ok(); - let Some(task) = task else { return }; - cx.spawn_in(window, async move |this, cx| { - if let Some(Ok(_)) = task.await { - this.update_in(cx, |this, window, cx| { - this.reset(window, cx); - }) - .ok(); - } - }) - .detach() - })), - ); - } else if let LoadError::NotInstalled { - install_message, - install_command, - .. - } = e - { - let install_message = install_message.clone(); - let install_command = install_command.clone(); - container = container.child( - Button::new("install", install_message) - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .size(ButtonSize::Medium) + let (message, action_slot) = match e { + LoadError::NotInstalled { + error_message, + install_message, + install_command, + } => { + let install_command = install_command.clone(); + let button = Button::new("install", install_message) .tooltip(Tooltip::text(install_command.clone())) + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .icon(IconName::Download) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) .on_click(cx.listener(move |this, _, window, cx| { let task = this .workspace @@ -2840,11 +2742,81 @@ impl AcpThreadView { } }) .detach() - })), - ); - } + })); - container.into_any() + (error_message.clone(), Some(button.into_any_element())) + } + LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command, + } => { + let upgrade_command = upgrade_command.clone(); + let button = Button::new("upgrade", upgrade_message) + .tooltip(Tooltip::text(upgrade_command.clone())) + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .icon(IconName::Download) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(cx.listener(move |this, _, window, cx| { + let task = this + .workspace + .update(cx, |workspace, cx| { + let project = workspace.project().read(cx); + let cwd = project.first_project_directory(cx); + let shell = project.terminal_settings(&cwd, cx).shell.clone(); + let spawn_in_terminal = task::SpawnInTerminal { + id: task::TaskId("upgrade".to_string()), + full_label: upgrade_command.clone(), + label: upgrade_command.clone(), + command: Some(upgrade_command.clone()), + args: Vec::new(), + command_label: upgrade_command.clone(), + cwd, + env: Default::default(), + use_new_terminal: true, + allow_concurrent_runs: true, + reveal: Default::default(), + reveal_target: Default::default(), + hide: Default::default(), + shell, + show_summary: true, + show_command: true, + show_rerun: false, + }; + workspace.spawn_in_terminal(spawn_in_terminal, window, cx) + }) + .ok(); + let Some(task) = task else { return }; + cx.spawn_in(window, async move |this, cx| { + if let Some(Ok(_)) = task.await { + this.update_in(cx, |this, window, cx| { + this.reset(window, cx); + }) + .ok(); + } + }) + .detach() + })); + + (error_message.clone(), Some(button.into_any_element())) + } + LoadError::Exited { .. } => ("Server exited with status {status}".into(), None), + LoadError::Other(msg) => ( + msg.into(), + Some(self.create_copy_button(msg.to_string()).into_any_element()), + ), + }; + + Callout::new() + .severity(Severity::Error) + .icon(IconName::XCircleFilled) + .title("Failed to Launch") + .description(message) + .actions_slot(div().children(action_slot)) + .into_any_element() } fn render_activity_bar( @@ -3335,6 +3307,19 @@ impl AcpThreadView { (IconName::Maximize, "Expand Message Editor") }; + let backdrop = div() + .size_full() + .absolute() + .inset_0() + .bg(cx.theme().colors().panel_background) + .opacity(0.8) + .block_mouse_except_scroll(); + + let enable_editor = match self.thread_state { + ThreadState::Loading { .. } | ThreadState::Ready { .. } => true, + ThreadState::Unauthenticated { .. } | ThreadState::LoadError(..) => false, + }; + v_flex() .on_action(cx.listener(Self::expand_message_editor)) .on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| { @@ -3410,6 +3395,7 @@ impl AcpThreadView { .child(self.render_send_button(cx)), ), ) + .when(!enable_editor, |this| this.child(backdrop)) .into_any() } @@ -3662,6 +3648,7 @@ impl AcpThreadView { .open_path(path, None, true, window, cx) .detach_and_log_err(cx); } + MentionUri::PastedImage => {} MentionUri::Directory { abs_path } => { let project = workspace.project(); let Some(entry) = project.update(cx, |project, cx| { @@ -3676,9 +3663,14 @@ impl AcpThreadView { }); } MentionUri::Symbol { - path, line_range, .. + abs_path: path, + line_range, + .. } - | MentionUri::Selection { path, line_range } => { + | MentionUri::Selection { + abs_path: Some(path), + line_range, + } => { let project = workspace.project(); let Some((path, _)) = project.update(cx, |project, cx| { let path = project.find_project_path(path, cx)?; @@ -3694,8 +3686,8 @@ impl AcpThreadView { let Some(editor) = item.await?.downcast::() else { return Ok(()); }; - let range = - Point::new(line_range.start, 0)..Point::new(line_range.start, 0); + let range = Point::new(*line_range.start(), 0) + ..Point::new(*line_range.start(), 0); editor .update_in(cx, |editor, window, cx| { editor.change_selections( @@ -3710,6 +3702,7 @@ impl AcpThreadView { }) .detach_and_log_err(cx); } + MentionUri::Selection { abs_path: None, .. } => {} MentionUri::Thread { id, name } => { if let Some(panel) = workspace.panel::(cx) { panel.update(cx, |panel, cx| { @@ -3912,18 +3905,19 @@ impl AcpThreadView { return; } - let title = self.title(cx); + // TODO: Change this once we have title summarization for external agents. + let title = self.agent.name(); match AgentSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { if let Some(primary) = cx.primary_display() { - self.pop_up(icon, caption.into(), title, window, primary, cx); + self.pop_up(icon, caption.into(), title.into(), window, primary, cx); } } NotifyWhenAgentWaiting::AllScreens => { let caption = caption.into(); for screen in cx.displays() { - self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx); + self.pop_up(icon, caption.clone(), title.into(), window, screen, cx); } } NotifyWhenAgentWaiting::Never => { @@ -4057,8 +4051,10 @@ impl AcpThreadView { .group("thread-controls-container") .w_full() .mr_1() + .pt_1() .pb_2() .px(RESPONSE_PADDING_X) + .gap_px() .opacity(0.4) .hover(|style| style.opacity(1.)) .flex_wrap() @@ -4420,6 +4416,7 @@ impl AcpThreadView { Callout::new() .severity(Severity::Error) .title("Error") + .icon(IconName::XCircle) .description(error.clone()) .actions_slot(self.create_copy_button(error.to_string())) .dismiss_action(self.dismiss_error_button(cx)) @@ -4431,6 +4428,7 @@ impl AcpThreadView { Callout::new() .severity(Severity::Error) + .icon(IconName::XCircle) .title("Free Usage Exceeded") .description(ERROR_MESSAGE) .actions_slot( @@ -4450,6 +4448,7 @@ impl AcpThreadView { Callout::new() .severity(Severity::Error) .title("Authentication Required") + .icon(IconName::XCircle) .description(error.clone()) .actions_slot( h_flex() @@ -4475,6 +4474,7 @@ impl AcpThreadView { Callout::new() .severity(Severity::Error) .title("Model Prompt Limit Reached") + .icon(IconName::XCircle) .description(error_message) .actions_slot( h_flex() @@ -4645,7 +4645,14 @@ impl AcpThreadView { impl Focusable for AcpThreadView { fn focus_handle(&self, cx: &App) -> FocusHandle { - self.message_editor.focus_handle(cx) + match self.thread_state { + ThreadState::Loading { .. } | ThreadState::Ready { .. } => { + self.message_editor.focus_handle(cx) + } + ThreadState::LoadError(_) | ThreadState::Unauthenticated { .. } => { + self.focus_handle.clone() + } + } } } @@ -4661,6 +4668,7 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::toggle_burn_mode)) .on_action(cx.listener(Self::keep_all)) .on_action(cx.listener(Self::reject_all)) + .track_focus(&self.focus_handle) .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { ThreadState::Unauthenticated { @@ -4677,14 +4685,14 @@ impl Render for AcpThreadView { window, cx, ), - ThreadState::Loading { .. } => { - v_flex().flex_1().child(self.render_empty_state(window, cx)) - } - ThreadState::LoadError(e) => v_flex() - .p_2() + ThreadState::Loading { .. } => v_flex() .flex_1() + .child(self.render_recent_history(window, cx)), + ThreadState::LoadError(e) => v_flex() + .flex_1() + .size_full() .items_center() - .justify_center() + .justify_end() .child(self.render_load_error(e, cx)), ThreadState::Ready { thread, .. } => { let thread_clone = thread.clone(); @@ -4721,7 +4729,7 @@ impl Render for AcpThreadView { }, ) } else { - this.child(self.render_empty_state(window, cx)) + this.child(self.render_recent_history(window, cx)) } }) } diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 2cad913295..e0cecad6e2 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -1595,11 +1595,6 @@ impl ActiveThread { return; }; - if model.provider.must_accept_terms(cx) { - cx.notify(); - return; - } - let edited_text = state.editor.read(cx).text(cx); let creases = state.editor.update(cx, extract_message_creases); diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 00e48efdac..f33f0ba032 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -93,14 +93,6 @@ impl AgentConfiguration { let scroll_handle = ScrollHandle::new(); let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); - let mut expanded_provider_configurations = HashMap::default(); - if LanguageModelRegistry::read_global(cx) - .provider(&ZED_CLOUD_PROVIDER_ID) - .is_some_and(|cloud_provider| cloud_provider.must_accept_terms(cx)) - { - expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true); - } - let mut this = Self { fs, language_registry, @@ -109,7 +101,7 @@ impl AgentConfiguration { configuration_views_by_provider: HashMap::default(), context_server_store, expanded_context_server_tools: HashMap::default(), - expanded_provider_configurations, + expanded_provider_configurations: HashMap::default(), tools, _registry_subscription: registry_subscription, scroll_handle, diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index d2ff6aa4f3..0e611d0db9 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -54,9 +54,7 @@ use gpui::{ Pixels, Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between, }; use language::LanguageRegistry; -use language_model::{ - ConfigurationError, ConfiguredModel, LanguageModelProviderTosView, LanguageModelRegistry, -}; +use language_model::{ConfigurationError, ConfiguredModel, LanguageModelRegistry}; use project::{DisableAiSettings, Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use rules_library::{RulesLibrary, open_rules_library}; @@ -2041,9 +2039,11 @@ impl AgentPanel { match state { ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT) .truncate() + .color(Color::Muted) .into_any_element(), ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() + .color(Color::Muted) .into_any_element(), ThreadSummary::Ready(_) => div() .w_full() @@ -2097,7 +2097,8 @@ impl AgentPanel { .child(title_editor) .into_any_element() } else { - Label::new(thread_view.read(cx).title(cx)) + Label::new(thread_view.read(cx).title()) + .color(Color::Muted) .truncate() .into_any_element() } @@ -2111,6 +2112,7 @@ impl AgentPanel { match summary { ContextSummary::Pending => Label::new(ContextSummary::DEFAULT) + .color(Color::Muted) .truncate() .into_any_element(), ContextSummary::Content(summary) => { @@ -2122,6 +2124,7 @@ impl AgentPanel { } else { Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() + .color(Color::Muted) .into_any_element() } } @@ -3198,17 +3201,6 @@ impl AgentPanel { ConfigurationError::ModelNotFound | ConfigurationError::ProviderNotAuthenticated(_) | ConfigurationError::NoProvider => callout.into_any_element(), - ConfigurationError::ProviderPendingTermsAcceptance(provider) => { - Banner::new() - .severity(Severity::Warning) - .child(h_flex().w_full().children( - provider.render_accept_terms( - LanguageModelProviderTosView::ThreadEmptyState, - cx, - ), - )) - .into_any_element() - } } } diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index bed10e90a7..45e7529ec2 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -378,18 +378,13 @@ impl MessageEditor { } fn send_to_model(&mut self, window: &mut Window, cx: &mut Context) { - let Some(ConfiguredModel { model, provider }) = self + let Some(ConfiguredModel { model, .. }) = self .thread .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)) else { return; }; - if provider.must_accept_terms(cx) { - cx.notify(); - return; - } - let (user_message, user_message_creases) = self.editor.update(cx, |editor, cx| { let creases = extract_message_creases(editor, cx); let text = editor.text(cx); diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 9fbd90c4a6..edb672a872 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -190,7 +190,6 @@ pub struct TextThreadEditor { invoked_slash_command_creases: HashMap, _subscriptions: Vec, last_error: Option, - show_accept_terms: bool, pub(crate) slash_menu_handle: PopoverMenuHandle>, // dragged_file_worktrees is used to keep references to worktrees that were added @@ -289,7 +288,6 @@ impl TextThreadEditor { invoked_slash_command_creases: HashMap::default(), _subscriptions, last_error: None, - show_accept_terms: false, slash_menu_handle: Default::default(), dragged_file_worktrees: Vec::new(), language_model_selector: cx.new(|cx| { @@ -367,20 +365,7 @@ impl TextThreadEditor { } fn send_to_model(&mut self, window: &mut Window, cx: &mut Context) { - let provider = LanguageModelRegistry::read_global(cx) - .default_model() - .map(|default| default.provider); - if provider - .as_ref() - .is_some_and(|provider| provider.must_accept_terms(cx)) - { - self.show_accept_terms = true; - cx.notify(); - return; - } - self.last_error = None; - if let Some(user_message) = self.context.update(cx, |context, cx| context.assist(cx)) { let new_selection = { let cursor = user_message @@ -1930,7 +1915,6 @@ impl TextThreadEditor { ConfigurationError::NoProvider | ConfigurationError::ModelNotFound | ConfigurationError::ProviderNotAuthenticated(_) => true, - ConfigurationError::ProviderPendingTermsAcceptance(_) => self.show_accept_terms, } } diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index e27a224240..ada973cddf 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -4,9 +4,11 @@ mod context_pill; mod end_trial_upsell; mod onboarding_modal; pub mod preview; +mod unavailable_editing_tooltip; pub use agent_notification::*; pub use burn_mode_tooltip::*; pub use context_pill::*; pub use end_trial_upsell::*; pub use onboarding_modal::*; +pub use unavailable_editing_tooltip::*; diff --git a/crates/agent_ui/src/ui/preview/usage_callouts.rs b/crates/agent_ui/src/ui/preview/usage_callouts.rs index 29b12ea627..d4d037b976 100644 --- a/crates/agent_ui/src/ui/preview/usage_callouts.rs +++ b/crates/agent_ui/src/ui/preview/usage_callouts.rs @@ -86,23 +86,18 @@ impl RenderOnce for UsageCallout { (IconName::Warning, Severity::Warning) }; - div() - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - Callout::new() - .icon(icon) - .severity(severity) - .icon(icon) - .title(title) - .description(message) - .actions_slot( - Button::new("upgrade", button_text) - .label_size(LabelSize::Small) - .on_click(move |_, _, cx| { - cx.open_url(&url); - }), - ), + Callout::new() + .icon(icon) + .severity(severity) + .icon(icon) + .title(title) + .description(message) + .actions_slot( + Button::new("upgrade", button_text) + .label_size(LabelSize::Small) + .on_click(move |_, _, cx| { + cx.open_url(&url); + }), ) .into_any_element() } diff --git a/crates/agent_ui/src/ui/unavailable_editing_tooltip.rs b/crates/agent_ui/src/ui/unavailable_editing_tooltip.rs new file mode 100644 index 0000000000..78d4c64e0a --- /dev/null +++ b/crates/agent_ui/src/ui/unavailable_editing_tooltip.rs @@ -0,0 +1,29 @@ +use gpui::{Context, IntoElement, Render, Window}; +use ui::{prelude::*, tooltip_container}; + +pub struct UnavailableEditingTooltip { + agent_name: SharedString, +} + +impl UnavailableEditingTooltip { + pub fn new(agent_name: SharedString) -> Self { + Self { agent_name } + } +} + +impl Render for UnavailableEditingTooltip { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + tooltip_container(window, cx, |this, _, _| { + this.child(Label::new("Unavailable Editing")).child( + div().max_w_64().child( + Label::new(format!( + "Editing previous messages is not available for {} yet.", + self.agent_name + )) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + } +} diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index 717abebfd1..6d8ac64725 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use client::{Client, UserStore, zed_urls}; use gpui::{AnyElement, Entity, IntoElement, ParentElement}; -use ui::{Divider, RegisterComponent, TintColor, Tooltip, prelude::*}; +use ui::{Divider, RegisterComponent, Tooltip, prelude::*}; #[derive(PartialEq)] pub enum SignInStatus { @@ -43,12 +43,10 @@ impl From for SignInStatus { #[derive(RegisterComponent, IntoElement)] pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, - pub has_accepted_terms_of_service: bool, pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, - pub accept_terms_of_service: Arc, pub dismiss_onboarding: Option>, } @@ -64,17 +62,9 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: store.has_accepted_terms_of_service(), plan: store.plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, - accept_terms_of_service: Arc::new({ - let store = user_store.clone(); - move |_window, cx| { - let task = store.update(cx, |store, cx| store.accept_terms_of_service(cx)); - task.detach_and_log_err(cx); - } - }), sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); @@ -94,42 +84,6 @@ impl ZedAiOnboarding { self } - fn render_accept_terms_of_service(&self) -> AnyElement { - v_flex() - .gap_1() - .w_full() - .child(Headline::new("Accept Terms of Service")) - .child( - Label::new("We don’t sell your data, track you across the web, or compromise your privacy.") - .color(Color::Muted) - .mb_2(), - ) - .child( - Button::new("terms_of_service", "Review Terms of Service") - .full_width() - .style(ButtonStyle::Outlined) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .on_click(move |_, _window, cx| { - telemetry::event!("Review Terms of Service Clicked"); - cx.open_url(&zed_urls::terms_of_service(cx)) - }), - ) - .child( - Button::new("accept_terms", "Accept") - .full_width() - .style(ButtonStyle::Tinted(TintColor::Accent)) - .on_click({ - let callback = self.accept_terms_of_service.clone(); - move |_, window, cx| { - telemetry::event!("Terms of Service Accepted"); - (callback)(window, cx)} - }), - ) - .into_any_element() - } - fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement { let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn); let plan_definitions = PlanDefinitions; @@ -359,14 +313,10 @@ impl ZedAiOnboarding { impl RenderOnce for ZedAiOnboarding { fn render(self, _window: &mut ui::Window, cx: &mut App) -> impl IntoElement { if matches!(self.sign_in_status, SignInStatus::SignedIn) { - if self.has_accepted_terms_of_service { - match self.plan { - None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), - Some(Plan::ZedProTrial) => self.render_trial_state(cx), - Some(Plan::ZedPro) => self.render_pro_plan_state(cx), - } - } else { - self.render_accept_terms_of_service() + match self.plan { + None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), + Some(Plan::ZedProTrial) => self.render_trial_state(cx), + Some(Plan::ZedPro) => self.render_pro_plan_state(cx), } } else { self.render_sign_in_disclaimer(cx) @@ -390,18 +340,15 @@ impl Component for ZedAiOnboarding { fn preview(_window: &mut Window, _cx: &mut App) -> Option { fn onboarding( sign_in_status: SignInStatus, - has_accepted_terms_of_service: bool, plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { sign_in_status, - has_accepted_terms_of_service, plan, account_too_young, continue_with_zed_ai: Arc::new(|_, _| {}), sign_in: Arc::new(|_, _| {}), - accept_terms_of_service: Arc::new(|_, _| {}), dismiss_onboarding: None, } .into_any_element() @@ -415,27 +362,23 @@ impl Component for ZedAiOnboarding { .children(vec![ single_example( "Not Signed-in", - onboarding(SignInStatus::SignedOut, false, None, false), - ), - single_example( - "Not Accepted ToS", - onboarding(SignInStatus::SignedIn, false, None, false), + onboarding(SignInStatus::SignedOut, None, false), ), single_example( "Young Account", - onboarding(SignInStatus::SignedIn, true, None, true), + onboarding(SignInStatus::SignedIn, None, true), ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), + onboarding(SignInStatus::SignedIn, Some(Plan::ZedFree), false), ), single_example( "Pro Trial", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), + onboarding(SignInStatus::SignedIn, Some(Plan::ZedProTrial), false), ), single_example( "Pro Plan", - onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), + onboarding(SignInStatus::SignedIn, Some(Plan::ZedPro), false), ), ]) .into_any_element(), diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index f9b8a10610..2bbe7dd1b5 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -66,6 +66,8 @@ pub static IMPERSONATE_LOGIN: LazyLock> = LazyLock::new(|| { .and_then(|s| if s.is_empty() { None } else { Some(s) }) }); +pub static USE_WEB_LOGIN: LazyLock = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok()); + pub static ADMIN_API_TOKEN: LazyLock> = LazyLock::new(|| { std::env::var("ZED_ADMIN_API_TOKEN") .ok() @@ -1392,11 +1394,13 @@ impl Client { if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) { - eprintln!("authenticate as admin {login}, {token}"); + if !*USE_WEB_LOGIN { + eprintln!("authenticate as admin {login}, {token}"); - return this - .authenticate_as_admin(http, login.clone(), token.clone()) - .await; + return this + .authenticate_as_admin(http, login.clone(), token.clone()) + .await; + } } // Start an HTTP server to receive the redirect from Zed's sign-in page. diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 20f99e3944..d23eb37519 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,5 +1,5 @@ use super::{Client, Status, TypedEnvelope, proto}; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result}; use chrono::{DateTime, Utc}; use cloud_api_client::websocket_protocol::MessageToClient; use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; @@ -46,11 +46,6 @@ impl ProjectId { } } -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, -)] -pub struct DevServerProjectId(pub u64); - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ParticipantIndex(pub u32); @@ -116,7 +111,6 @@ pub struct UserStore { edit_prediction_usage: Option, plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -194,7 +188,6 @@ impl UserStore { plan_info: None, model_request_usage: None, edit_prediction_usage: None, - accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), participant_indices: Default::default(), @@ -271,7 +264,6 @@ impl UserStore { Status::SignedOut => { current_user_tx.send(None).await.ok(); this.update(cx, |this, cx| { - this.accepted_tos_at = None; cx.emit(Event::PrivateUserInfoUpdated); cx.notify(); this.clear_contacts() @@ -791,19 +783,6 @@ impl UserStore { .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); } - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { - None - } else { - response.user.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - response.user.accepted_tos_at - }; - - self.accepted_tos_at = Some(accepted_tos_at); self.model_request_usage = Some(ModelRequestUsage(RequestUsage { limit: response.plan.usage.model_requests.limit, amount: response.plan.usage.model_requests.used as i32, @@ -846,32 +825,6 @@ impl UserStore { self.current_user.clone() } - pub fn has_accepted_terms_of_service(&self) -> bool { - self.accepted_tos_at - .is_some_and(|accepted_tos_at| accepted_tos_at.is_some()) - } - - pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { - if self.current_user().is_none() { - return Task::ready(Err(anyhow!("no current user"))); - }; - - let client = self.client.clone(); - cx.spawn(async move |this, cx| -> anyhow::Result<()> { - let client = client.upgrade().context("client not found")?; - let response = client - .cloud_client() - .accept_terms_of_service() - .await - .context("error accepting tos")?; - this.update(cx, |this, cx| { - this.accepted_tos_at = Some(response.user.accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); - })?; - Ok(()) - }) - } - fn load_users( &self, request: impl RequestMessage, diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 205f3e2432..7fd96fcef0 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -115,34 +115,6 @@ impl CloudApiClient { })) } - pub async fn accept_terms_of_service(&self) -> Result { - let request = self.build_request( - Request::builder().method(Method::POST).uri( - self.http_client - .build_zed_cloud_url("/client/terms_of_service/accept", &[])? - .as_ref(), - ), - AsyncBody::default(), - )?; - - let mut response = self.http_client.send(request).await?; - - if !response.status().is_success() { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - anyhow::bail!( - "Failed to accept terms of service.\nStatus: {:?}\nBody: {body}", - response.status() - ) - } - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - Ok(serde_json::from_str(&body)?) - } - pub async fn create_llm_token( &self, system_id: Option, diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index d9fd8ffeb2..1e0c915bcb 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -970,7 +970,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T // the follow. workspace_b.update_in(cx_b, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -1073,7 +1073,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T // Client A cycles through some tabs. workspace_a.update_in(cx_a, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -1117,7 +1117,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T workspace_a.update_in(cx_a, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); @@ -1164,7 +1164,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T workspace_a.update_in(cx_a, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.activate_prev_item(true, window, cx); + pane.activate_previous_item(&Default::default(), window, cx); }); }); executor.run_until_parked(); diff --git a/crates/context_server/src/test.rs b/crates/context_server/src/test.rs index dedf589664..008542ab24 100644 --- a/crates/context_server/src/test.rs +++ b/crates/context_server/src/test.rs @@ -1,6 +1,6 @@ use anyhow::Context as _; use collections::HashMap; -use futures::{Stream, StreamExt as _, lock::Mutex}; +use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex}; use gpui::BackgroundExecutor; use std::{pin::Pin, sync::Arc}; @@ -14,9 +14,12 @@ pub fn create_fake_transport( executor: BackgroundExecutor, ) -> FakeTransport { let name = name.into(); - FakeTransport::new(executor).on_request::(move |_params| { - create_initialize_response(name.clone()) - }) + FakeTransport::new(executor).on_request::( + move |_params| { + let name = name.clone(); + async move { create_initialize_response(name.clone()) } + }, + ) } fn create_initialize_response(server_name: String) -> InitializeResponse { @@ -32,8 +35,10 @@ fn create_initialize_response(server_name: String) -> InitializeResponse { } pub struct FakeTransport { - request_handlers: - HashMap<&'static str, Arc serde_json::Value + Send + Sync>>, + request_handlers: HashMap< + &'static str, + Arc BoxFuture<'static, serde_json::Value>>, + >, tx: futures::channel::mpsc::UnboundedSender, rx: Arc>>, executor: BackgroundExecutor, @@ -50,18 +55,25 @@ impl FakeTransport { } } - pub fn on_request( + pub fn on_request( mut self, - handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, - ) -> Self { + handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut, + ) -> Self + where + T: crate::types::Request, + Fut: 'static + Send + Future, + { self.request_handlers.insert( T::METHOD, Arc::new(move |value| { - let params = value.get("params").expect("Missing parameters").clone(); + let params = value + .get("params") + .cloned() + .unwrap_or(serde_json::Value::Null); let params: T::Params = serde_json::from_value(params).expect("Invalid parameters received"); let response = handler(params); - serde_json::to_value(response).unwrap() + async move { serde_json::to_value(response.await).unwrap() }.boxed() }), ); self @@ -77,7 +89,7 @@ impl Transport for FakeTransport { if let Some(method) = msg.get("method") { let method = method.as_str().expect("Invalid method received"); if let Some(handler) = self.request_handlers.get(method) { - let payload = handler(msg); + let payload = handler(msg).await; let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 9308500ed4..52d75175e5 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -301,6 +301,7 @@ mod tests { init_test(cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -533,6 +534,7 @@ mod tests { init_test(cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, diff --git a/crates/diagnostics/src/diagnostics.rs b/crates/diagnostics/src/diagnostics.rs index 2e20118381..037e4fc0fd 100644 --- a/crates/diagnostics/src/diagnostics.rs +++ b/crates/diagnostics/src/diagnostics.rs @@ -438,7 +438,7 @@ impl ProjectDiagnosticsEditor { for buffer_path in diagnostics_sources.iter().cloned() { if cx .update(|cx| { - fetch_tasks.push(run_flycheck(project.clone(), buffer_path, cx)); + fetch_tasks.push(run_flycheck(project.clone(), Some(buffer_path), cx)); }) .is_err() { @@ -462,7 +462,7 @@ impl ProjectDiagnosticsEditor { .iter() .cloned() { - cancel_gasks.push(cancel_flycheck(self.project.clone(), buffer_path, cx)); + cancel_gasks.push(cancel_flycheck(self.project.clone(), Some(buffer_path), cx)); } self.cargo_diagnostics_fetch.cancel_task = Some(cx.background_spawn(async move { diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 964f202934..6b695af1ae 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -89,9 +89,6 @@ pub trait EditPredictionProvider: 'static + Sized { debounce: bool, cx: &mut Context, ); - fn needs_terms_acceptance(&self, _cx: &App) -> bool { - false - } fn cycle( &mut self, buffer: Entity, @@ -124,7 +121,6 @@ pub trait EditPredictionProviderHandle { fn data_collection_state(&self, cx: &App) -> DataCollectionState; fn usage(&self, cx: &App) -> Option; fn toggle_data_collection(&self, cx: &mut App); - fn needs_terms_acceptance(&self, cx: &App) -> bool; fn is_refreshing(&self, cx: &App) -> bool; fn refresh( &self, @@ -196,10 +192,6 @@ where self.read(cx).is_enabled(buffer, cursor_position, cx) } - fn needs_terms_acceptance(&self, cx: &App) -> bool { - self.read(cx).needs_terms_acceptance(cx) - } - fn is_refreshing(&self, cx: &App) -> bool { self.read(cx).is_refreshing() } diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 4f69af7ee4..0e3fe8cb1a 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -242,13 +242,9 @@ impl Render for EditPredictionButton { IconName::ZedPredictDisabled }; - if zeta::should_show_upsell_modal(&self.user_store, cx) { + if zeta::should_show_upsell_modal() { let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { - if self.user_store.read(cx).has_accepted_terms_of_service() { - "Choose a Plan" - } else { - "Accept the Terms of Service" - } + "Choose a Plan" } else { "Sign In" }; diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 2af8e6c0e4..29e009fdf8 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -253,7 +253,6 @@ pub type RenderDiffHunkControlsFn = Arc< enum ReportEditorEvent { Saved { auto_saved: bool }, EditorOpened, - ZetaTosClicked, Closed, } @@ -262,7 +261,6 @@ impl ReportEditorEvent { match self { Self::Saved { .. } => "Editor Saved", Self::EditorOpened => "Editor Opened", - Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked", Self::Closed => "Editor Closed", } } @@ -5576,6 +5574,11 @@ impl Editor { .as_ref() .is_none_or(|query| !query.chars().any(|c| c.is_digit(10))); + let omit_word_completions = match &query { + Some(query) => query.chars().count() < completion_settings.words_min_length, + None => completion_settings.words_min_length != 0, + }; + let (mut words, provider_responses) = match &provider { Some(provider) => { let provider_responses = provider.completions( @@ -5587,9 +5590,11 @@ impl Editor { cx, ); - let words = match completion_settings.words { - WordsCompletionMode::Disabled => Task::ready(BTreeMap::default()), - WordsCompletionMode::Enabled | WordsCompletionMode::Fallback => cx + let words = match (omit_word_completions, completion_settings.words) { + (true, _) | (_, WordsCompletionMode::Disabled) => { + Task::ready(BTreeMap::default()) + } + (false, WordsCompletionMode::Enabled | WordsCompletionMode::Fallback) => cx .background_spawn(async move { buffer_snapshot.words_in_range(WordsQuery { fuzzy_contents: None, @@ -5601,16 +5606,20 @@ impl Editor { (words, provider_responses) } - None => ( - cx.background_spawn(async move { - buffer_snapshot.words_in_range(WordsQuery { - fuzzy_contents: None, - range: word_search_range, - skip_digits, + None => { + let words = if omit_word_completions { + Task::ready(BTreeMap::default()) + } else { + cx.background_spawn(async move { + buffer_snapshot.words_in_range(WordsQuery { + fuzzy_contents: None, + range: word_search_range, + skip_digits, + }) }) - }), - Task::ready(Ok(Vec::new())), - ), + }; + (words, Task::ready(Ok(Vec::new()))) + } }; let snippet_sort_order = EditorSettings::get_global(cx).snippet_sort_order; @@ -9169,45 +9178,6 @@ impl Editor { let provider = self.edit_prediction_provider.as_ref()?; let provider_icon = Self::get_prediction_provider_icon_name(&self.edit_prediction_provider); - if provider.provider.needs_terms_acceptance(cx) { - return Some( - h_flex() - .min_w(min_width) - .flex_1() - .px_2() - .py_1() - .gap_3() - .elevation_2(cx) - .hover(|style| style.bg(cx.theme().colors().element_hover)) - .id("accept-terms") - .cursor_pointer() - .on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default()) - .on_click(cx.listener(|this, _event, window, cx| { - cx.stop_propagation(); - this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx); - window.dispatch_action( - zed_actions::OpenZedPredictOnboarding.boxed_clone(), - cx, - ); - })) - .child( - h_flex() - .flex_1() - .gap_2() - .child(Icon::new(provider_icon)) - .child(Label::new("Accept Terms of Service")) - .child(div().w_full()) - .child( - Icon::new(IconName::ArrowUpRight) - .color(Color::Muted) - .size(IconSize::Small), - ) - .into_any_element(), - ) - .into_any(), - ); - } - let is_refreshing = provider.provider.is_refreshing(cx); fn pending_completion_container(icon: IconName) -> Div { @@ -9809,6 +9779,9 @@ impl Editor { } pub fn backspace(&mut self, _: &Backspace, window: &mut Window, cx: &mut Context) { + if self.read_only(cx) { + return; + } self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); self.transact(window, cx, |this, window, cx| { this.select_autoclose_pair(window, cx); @@ -9902,6 +9875,9 @@ impl Editor { } pub fn delete(&mut self, _: &Delete, window: &mut Window, cx: &mut Context) { + if self.read_only(cx) { + return; + } self.hide_mouse_cursor(HideMouseCursorOrigin::TypingAction, cx); self.transact(window, cx, |this, window, cx| { this.change_selections(Default::default(), window, cx, |s| { diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 96261fdb2c..2cfdb92593 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -57,7 +57,9 @@ use util::{ use workspace::{ CloseActiveItem, CloseAllItems, CloseOtherItems, MoveItemToPaneInDirection, NavigationEntry, OpenOptions, ViewId, + invalid_buffer_view::InvalidBufferView, item::{FollowEvent, FollowableItem, Item, ItemHandle, SaveOptions}, + register_project_item, }; #[gpui::test] @@ -12237,6 +12239,7 @@ async fn test_completion_mode(cx: &mut TestAppContext) { settings.defaults.completions = Some(CompletionSettings { lsp_insert_mode, words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, }); @@ -12295,6 +12298,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) update_test_language_settings(&mut cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, // set the opposite here to ensure that the action is overriding the default behavior lsp_insert_mode: LspInsertMode::Insert, lsp: true, @@ -12331,6 +12335,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) update_test_language_settings(&mut cx, |settings| { settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, // set the opposite here to ensure that the action is overriding the default behavior lsp_insert_mode: LspInsertMode::Replace, lsp: true, @@ -13072,6 +13077,7 @@ async fn test_word_completion(cx: &mut TestAppContext) { init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Fallback, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 10, lsp_insert_mode: LspInsertMode::Insert, @@ -13168,6 +13174,7 @@ async fn test_word_completions_do_not_duplicate_lsp_ones(cx: &mut TestAppContext init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Enabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -13231,6 +13238,7 @@ async fn test_word_completions_continue_on_typing(cx: &mut TestAppContext) { init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Disabled, + words_min_length: 0, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -13304,6 +13312,7 @@ async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) { init_test(cx, |language_settings| { language_settings.defaults.completions = Some(CompletionSettings { words: WordsCompletionMode::Fallback, + words_min_length: 0, lsp: false, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::Insert, @@ -13361,6 +13370,56 @@ async fn test_word_completions_usually_skip_digits(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_word_completions_do_not_show_before_threshold(cx: &mut TestAppContext) { + init_test(cx, |language_settings| { + language_settings.defaults.completions = Some(CompletionSettings { + words: WordsCompletionMode::Enabled, + words_min_length: 3, + lsp: true, + lsp_fetch_timeout_ms: 0, + lsp_insert_mode: LspInsertMode::Insert, + }); + }); + + let mut cx = EditorLspTestContext::new_rust(lsp::ServerCapabilities::default(), cx).await; + cx.set_state(indoc! {"ˇ + wow + wowen + wowser + "}); + cx.simulate_keystroke("w"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!( + "expected completion menu to be hidden, as words completion threshold is not met" + ); + } + }); + + cx.simulate_keystroke("o"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if editor.context_menu.borrow_mut().is_some() { + panic!( + "expected completion menu to be hidden, as words completion threshold is not met still" + ); + } + }); + + cx.simulate_keystroke("w"); + cx.executor().run_until_parked(); + cx.update_editor(|editor, _, _| { + if let Some(CodeContextMenu::Completions(menu)) = editor.context_menu.borrow_mut().as_ref() + { + assert_eq!(completion_menu_entries(menu), &["wowen", "wowser"], "After word completion threshold is met, matching words should be shown, excluding the already typed word"); + } else { + panic!("expected completion menu to be open after the word completions threshold is met"); + } + }); +} + fn gen_text_edit(params: &CompletionParams, text: &str) -> Option { let position = || lsp::Position { line: params.text_document_position.position.line, @@ -22656,7 +22715,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) { .await .unwrap(); pane.update_in(cx, |pane, window, cx| { - pane.navigate_backward(window, cx); + pane.navigate_backward(&Default::default(), window, cx); }); cx.run_until_parked(); pane.update(cx, |pane, cx| { @@ -24243,7 +24302,7 @@ async fn test_document_colors(cx: &mut TestAppContext) { workspace .update(cx, |workspace, window, cx| { workspace.active_pane().update(cx, |pane, cx| { - pane.navigate_backward(window, cx); + pane.navigate_backward(&Default::default(), window, cx); }) }) .unwrap(); @@ -24291,6 +24350,41 @@ async fn test_newline_replacement_in_single_line(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_non_utf_8_opens(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + cx.update(|cx| { + register_project_item::(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root1", json!({})).await; + fs.insert_file("/root1/one.pdf", vec![0xff, 0xfe, 0xfd]) + .await; + + let project = Project::test(fs, ["/root1".as_ref()], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let worktree_id = project.update(cx, |project, cx| { + project.worktrees(cx).next().unwrap().read(cx).id() + }); + + let handle = workspace + .update_in(cx, |workspace, window, cx| { + let project_path = (worktree_id, "one.pdf"); + workspace.open_path(project_path, None, true, window, cx) + }) + .await + .unwrap(); + + assert_eq!( + handle.to_any().entity_type(), + TypeId::of::() + ); +} + #[track_caller] fn extract_color_inlays(editor: &Editor, cx: &App) -> Vec { editor diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 797b0d6634..32582ba941 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -74,6 +74,7 @@ use std::{ fmt::{self, Write}, iter, mem, ops::{Deref, Range}, + path::Path, rc::Rc, sync::Arc, time::{Duration, Instant}, @@ -3693,7 +3694,12 @@ impl EditorElement { }) .take(1), ) - .children(indicator) + .child( + h_flex() + .size(Pixels(12.0)) + .justify_center() + .children(indicator), + ) .child( h_flex() .cursor_pointer() @@ -3782,25 +3788,31 @@ impl EditorElement { && let Some(worktree) = project.read(cx).worktree_for_id(file.worktree_id(cx), cx) { + let worktree = worktree.read(cx); let relative_path = file.path(); - let entry_for_path = worktree.read(cx).entry_for_path(relative_path); - let abs_path = entry_for_path.and_then(|e| e.canonical_path.as_deref()); - let has_relative_path = - worktree.read(cx).root_entry().is_some_and(Entry::is_dir); + let entry_for_path = worktree.entry_for_path(relative_path); + let abs_path = entry_for_path.map(|e| { + e.canonical_path.as_deref().map_or_else( + || worktree.abs_path().join(relative_path), + Path::to_path_buf, + ) + }); + let has_relative_path = worktree.root_entry().is_some_and(Entry::is_dir); - let parent_abs_path = - abs_path.and_then(|abs_path| Some(abs_path.parent()?.to_path_buf())); + let parent_abs_path = abs_path + .as_ref() + .and_then(|abs_path| Some(abs_path.parent()?.to_path_buf())); let relative_path = has_relative_path .then_some(relative_path) .map(ToOwned::to_owned); let visible_in_project_panel = - relative_path.is_some() && worktree.read(cx).is_visible(); + relative_path.is_some() && worktree.is_visible(); let reveal_in_project_panel = entry_for_path .filter(|_| visible_in_project_panel) .map(|entry| entry.id); menu = menu - .when_some(abs_path.map(ToOwned::to_owned), |menu, abs_path| { + .when_some(abs_path, |menu, abs_path| { menu.entry( "Copy Path", Some(Box::new(zed_actions::workspace::CopyPath)), diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index afc5767de0..641e8a97ed 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -42,6 +42,7 @@ use ui::{IconDecorationKind, prelude::*}; use util::{ResultExt, TryFutureExt, paths::PathExt}; use workspace::{ CollaboratorId, ItemId, ItemNavHistory, ToolbarItemLocation, ViewId, Workspace, WorkspaceId, + invalid_buffer_view::InvalidBufferView, item::{FollowableItem, Item, ItemEvent, ProjectItem, SaveOptions}, searchable::{Direction, SearchEvent, SearchableItem, SearchableItemHandle}, }; @@ -1401,6 +1402,16 @@ impl ProjectItem for Editor { editor } + + fn for_broken_project_item( + abs_path: PathBuf, + is_local: bool, + e: &anyhow::Error, + window: &mut Window, + cx: &mut App, + ) -> Option { + Some(InvalidBufferView::new(abs_path, is_local, e, window, cx)) + } } fn clip_ranges<'a>( diff --git a/crates/editor/src/rust_analyzer_ext.rs b/crates/editor/src/rust_analyzer_ext.rs index e3d83ab160..cf74ee0a9e 100644 --- a/crates/editor/src/rust_analyzer_ext.rs +++ b/crates/editor/src/rust_analyzer_ext.rs @@ -26,6 +26,17 @@ fn is_rust_language(language: &Language) -> bool { } pub fn apply_related_actions(editor: &Entity, window: &mut Window, cx: &mut App) { + if editor.read(cx).project().is_some_and(|project| { + project + .read(cx) + .language_server_statuses(cx) + .any(|(_, status)| status.name == RUST_ANALYZER_NAME) + }) { + register_action(editor, window, cancel_flycheck_action); + register_action(editor, window, run_flycheck_action); + register_action(editor, window, clear_flycheck_action); + } + if editor .read(cx) .buffer() @@ -38,9 +49,6 @@ pub fn apply_related_actions(editor: &Entity, window: &mut Window, cx: & register_action(editor, window, go_to_parent_module); register_action(editor, window, expand_macro_recursively); register_action(editor, window, open_docs); - register_action(editor, window, cancel_flycheck_action); - register_action(editor, window, run_flycheck_action); - register_action(editor, window, clear_flycheck_action); } } @@ -309,7 +317,7 @@ fn cancel_flycheck_action( let Some(project) = &editor.project else { return; }; - let Some(buffer_id) = editor + let buffer_id = editor .selections .disjoint_anchors() .iter() @@ -321,10 +329,7 @@ fn cancel_flycheck_action( .read(cx) .entry_id(cx)?; project.path_for_entry(entry_id, cx) - }) - else { - return; - }; + }); cancel_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx); } @@ -337,7 +342,7 @@ fn run_flycheck_action( let Some(project) = &editor.project else { return; }; - let Some(buffer_id) = editor + let buffer_id = editor .selections .disjoint_anchors() .iter() @@ -349,10 +354,7 @@ fn run_flycheck_action( .read(cx) .entry_id(cx)?; project.path_for_entry(entry_id, cx) - }) - else { - return; - }; + }); run_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx); } @@ -365,7 +367,7 @@ fn clear_flycheck_action( let Some(project) = &editor.project else { return; }; - let Some(buffer_id) = editor + let buffer_id = editor .selections .disjoint_anchors() .iter() @@ -377,9 +379,6 @@ fn clear_flycheck_action( .read(cx) .entry_id(cx)?; project.path_for_entry(entry_id, cx) - }) - else { - return; - }; + }); clear_flycheck(project.clone(), buffer_id, cx).detach_and_log_err(cx); } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 38f02c2206..4fc6039fd7 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -34,6 +34,7 @@ pub enum IconName { ArrowRightLeft, ArrowUp, ArrowUpRight, + Attach, AudioOff, AudioOn, Backspace, @@ -164,6 +165,7 @@ pub enum IconName { PageDown, PageUp, Pencil, + PencilUnavailable, Person, Pin, PlayOutlined, diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 386ad19747..0f82d3997f 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -350,6 +350,12 @@ pub struct CompletionSettings { /// Default: `fallback` #[serde(default = "default_words_completion_mode")] pub words: WordsCompletionMode, + /// How many characters has to be in the completions query to automatically show the words-based completions. + /// Before that value, it's still possible to trigger the words-based completion manually with the corresponding editor command. + /// + /// Default: 3 + #[serde(default = "default_3")] + pub words_min_length: usize, /// Whether to fetch LSP completions or not. /// /// Default: true @@ -359,7 +365,7 @@ pub struct CompletionSettings { /// When set to 0, waits indefinitely. /// /// Default: 0 - #[serde(default = "default_lsp_fetch_timeout_ms")] + #[serde(default)] pub lsp_fetch_timeout_ms: u64, /// Controls how LSP completions are inserted. /// @@ -405,8 +411,8 @@ fn default_lsp_insert_mode() -> LspInsertMode { LspInsertMode::ReplaceSuffix } -fn default_lsp_fetch_timeout_ms() -> u64 { - 0 +fn default_3() -> usize { + 3 } /// The settings for a particular language. @@ -1468,6 +1474,7 @@ impl settings::Settings for AllLanguageSettings { } else { d.completions = Some(CompletionSettings { words: mode, + words_min_length: 3, lsp: true, lsp_fetch_timeout_ms: 0, lsp_insert_mode: LspInsertMode::ReplaceSuffix, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 158bebcbbf..e0a3866443 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -14,7 +14,7 @@ use client::Client; use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; +use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; use http_client::{StatusCode, http}; use icons::IconName; use parking_lot::Mutex; @@ -640,16 +640,6 @@ pub trait LanguageModelProvider: 'static { window: &mut Window, cx: &mut App, ) -> AnyView; - fn must_accept_terms(&self, _cx: &App) -> bool { - false - } - fn render_accept_terms( - &self, - _view: LanguageModelProviderTosView, - _cx: &mut App, - ) -> Option { - None - } fn reset_credentials(&self, cx: &mut App) -> Task>; } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index bcbb3404a8..c7693a64c7 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -24,9 +24,6 @@ pub enum ConfigurationError { ModelNotFound, #[error("{} LLM provider is not configured.", .0.name().0)] ProviderNotAuthenticated(Arc), - #[error("Using the {} LLM provider requires accepting the Terms of Service.", - .0.name().0)] - ProviderPendingTermsAcceptance(Arc), } impl std::fmt::Debug for ConfigurationError { @@ -37,9 +34,6 @@ impl std::fmt::Debug for ConfigurationError { Self::ProviderNotAuthenticated(provider) => { write!(f, "ProviderNotAuthenticated({})", provider.id()) } - Self::ProviderPendingTermsAcceptance(provider) => { - write!(f, "ProviderPendingTermsAcceptance({})", provider.id()) - } } } } @@ -198,12 +192,6 @@ impl LanguageModelRegistry { return Some(ConfigurationError::ProviderNotAuthenticated(model.provider)); } - if model.provider.must_accept_terms(cx) { - return Some(ConfigurationError::ProviderPendingTermsAcceptance( - model.provider, - )); - } - None } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 8e4b786935..fb6e2fb1e4 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -23,9 +23,9 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, - ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, + LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, + LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, + PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -118,7 +118,6 @@ pub struct State { llm_api_token: LlmApiToken, user_store: Entity, status: client::Status, - accept_terms_of_service_task: Option>>, models: Vec>, default_model: Option>, default_fast_model: Option>, @@ -142,7 +141,6 @@ impl State { llm_api_token: LlmApiToken::default(), user_store, status, - accept_terms_of_service_task: None, models: Vec::new(), default_model: None, default_fast_model: None, @@ -197,24 +195,6 @@ impl State { state.update(cx, |_, cx| cx.notify()) }) } - - fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store.read(cx).has_accepted_terms_of_service() - } - - fn accept_terms_of_service(&mut self, cx: &mut Context) { - let user_store = self.user_store.clone(); - self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| { - let _ = user_store - .update(cx, |store, cx| store.accept_terms_of_service(cx))? - .await; - this.update(cx, |this, cx| { - this.accept_terms_of_service_task = None; - cx.notify() - }) - })); - } - fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context) { let mut models = Vec::new(); @@ -384,7 +364,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) } fn authenticate(&self, _cx: &mut App) -> Task> { @@ -401,112 +381,11 @@ impl LanguageModelProvider for CloudLanguageModelProvider { .into() } - fn must_accept_terms(&self, cx: &App) -> bool { - !self.state.read(cx).has_accepted_terms_of_service(cx) - } - - fn render_accept_terms( - &self, - view: LanguageModelProviderTosView, - cx: &mut App, - ) -> Option { - let state = self.state.read(cx); - if state.has_accepted_terms_of_service(cx) { - return None; - } - Some( - render_accept_terms(view, state.accept_terms_of_service_task.is_some(), { - let state = self.state.clone(); - move |_window, cx| { - state.update(cx, |state, cx| state.accept_terms_of_service(cx)); - } - }) - .into_any_element(), - ) - } - fn reset_credentials(&self, _cx: &mut App) -> Task> { Task::ready(Ok(())) } } -fn render_accept_terms( - view_kind: LanguageModelProviderTosView, - accept_terms_of_service_in_progress: bool, - accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static, -) -> impl IntoElement { - let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart); - let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState); - - let terms_button = Button::new("terms_of_service", "Terms of Service") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .when(thread_empty_state, |this| this.label_size(LabelSize::Small)) - .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service")); - - let button_container = h_flex().child( - Button::new("accept_terms", "I accept the Terms of Service") - .when(!thread_empty_state, |this| { - this.full_width() - .style(ButtonStyle::Tinted(TintColor::Accent)) - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - }) - .when(thread_empty_state, |this| { - this.style(ButtonStyle::Tinted(TintColor::Warning)) - .label_size(LabelSize::Small) - }) - .disabled(accept_terms_of_service_in_progress) - .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)), - ); - - if thread_empty_state { - h_flex() - .w_full() - .flex_wrap() - .justify_between() - .child( - h_flex() - .child( - Label::new("To start using Zed AI, please read and accept the") - .size(LabelSize::Small), - ) - .child(terms_button), - ) - .child(button_container) - } else { - v_flex() - .w_full() - .gap_2() - .child( - h_flex() - .flex_wrap() - .when(thread_fresh_start, |this| this.justify_center()) - .child(Label::new( - "To start using Zed AI, please read and accept the", - )) - .child(terms_button), - ) - .child({ - match view_kind { - LanguageModelProviderTosView::TextThreadPopup => { - button_container.w_full().justify_end() - } - LanguageModelProviderTosView::Configuration => { - button_container.w_full().justify_start() - } - LanguageModelProviderTosView::ThreadFreshStart => { - button_container.w_full().justify_center() - } - LanguageModelProviderTosView::ThreadEmptyState => div().w_0(), - } - }) - } -} - pub struct CloudLanguageModel { id: LanguageModelId, model: Arc, @@ -1107,10 +986,7 @@ struct ZedAiConfiguration { plan: Option, subscription_period: Option<(DateTime, DateTime)>, eligible_for_trial: bool, - has_accepted_terms_of_service: bool, account_too_young: bool, - accept_terms_of_service_in_progress: bool, - accept_terms_of_service_callback: Arc, sign_in_callback: Arc, } @@ -1176,58 +1052,30 @@ impl RenderOnce for ZedAiConfiguration { ); } - v_flex() - .gap_2() - .w_full() - .when(!self.has_accepted_terms_of_service, |this| { - this.child(render_accept_terms( - LanguageModelProviderTosView::Configuration, - self.accept_terms_of_service_in_progress, - { - let callback = self.accept_terms_of_service_callback.clone(); - move |window, cx| (callback)(window, cx) - }, - )) - }) - .map(|this| { - if self.has_accepted_terms_of_service && self.account_too_young { - this.child(young_account_banner).child( - Button::new("upgrade", "Upgrade to Pro") - .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) - .full_width() - .on_click(|_, _, cx| { - cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)) - }), - ) - } else if self.has_accepted_terms_of_service { - this.text_sm() - .child(subscription_text) - .child(manage_subscription_buttons) - } else { - this - } - }) - .when(self.has_accepted_terms_of_service, |this| this) + v_flex().gap_2().w_full().map(|this| { + if self.account_too_young { + this.child(young_account_banner).child( + Button::new("upgrade", "Upgrade to Pro") + .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent)) + .full_width() + .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))), + ) + } else { + this.text_sm() + .child(subscription_text) + .child(manage_subscription_buttons) + } + }) } } struct ConfigurationView { state: Entity, - accept_terms_of_service_callback: Arc, sign_in_callback: Arc, } impl ConfigurationView { fn new(state: Entity) -> Self { - let accept_terms_of_service_callback = Arc::new({ - let state = state.clone(); - move |_window: &mut Window, cx: &mut App| { - state.update(cx, |state, cx| { - state.accept_terms_of_service(cx); - }); - } - }); - let sign_in_callback = Arc::new({ let state = state.clone(); move |_window: &mut Window, cx: &mut App| { @@ -1239,7 +1087,6 @@ impl ConfigurationView { Self { state, - accept_terms_of_service_callback, sign_in_callback, } } @@ -1255,10 +1102,7 @@ impl Render for ConfigurationView { plan: user_store.plan(), subscription_period: user_store.subscription_period(), eligible_for_trial: user_store.trial_started_at().is_none(), - has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), account_too_young: user_store.account_too_young(), - accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(), - accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(), sign_in_callback: self.sign_in_callback.clone(), } } @@ -1283,7 +1127,6 @@ impl Component for ZedAiConfiguration { plan: Option, eligible_for_trial: bool, account_too_young: bool, - has_accepted_terms_of_service: bool, ) -> AnyElement { ZedAiConfiguration { is_connected, @@ -1292,10 +1135,7 @@ impl Component for ZedAiConfiguration { .is_some() .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))), eligible_for_trial, - has_accepted_terms_of_service, account_too_young, - accept_terms_of_service_in_progress: false, - accept_terms_of_service_callback: Arc::new(|_, _| {}), sign_in_callback: Arc::new(|_, _| {}), } .into_any_element() @@ -1306,33 +1146,30 @@ impl Component for ZedAiConfiguration { .p_4() .gap_4() .children(vec![ - single_example( - "Not connected", - configuration(false, None, false, false, true), - ), + single_example("Not connected", configuration(false, None, false, false)), single_example( "Accept Terms of Service", - configuration(true, None, true, false, false), + configuration(true, None, true, false), ), single_example( "No Plan - Not eligible for trial", - configuration(true, None, false, false, true), + configuration(true, None, false, false), ), single_example( "No Plan - Eligible for trial", - configuration(true, None, true, false, true), + configuration(true, None, true, false), ), single_example( "Free Plan", - configuration(true, Some(Plan::ZedFree), true, false, true), + configuration(true, Some(Plan::ZedFree), true, false), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(Plan::ZedProTrial), true, false, true), + configuration(true, Some(Plan::ZedProTrial), true, false), ), single_example( "Zed Pro Plan", - configuration(true, Some(Plan::ZedPro), true, false, true), + configuration(true, Some(Plan::ZedPro), true, false), ), ]) .into_any_element(), diff --git a/crates/languages/src/rust/highlights.scm b/crates/languages/src/rust/highlights.scm index 1c46061827..9c02fbedaa 100644 --- a/crates/languages/src/rust/highlights.scm +++ b/crates/languages/src/rust/highlights.scm @@ -6,6 +6,9 @@ (self) @variable.special (field_identifier) @property +(shorthand_field_initializer + (identifier) @property) + (trait_item name: (type_identifier) @type.interface) (impl_item trait: (type_identifier) @type.interface) (abstract_type trait: (type_identifier) @type.interface) @@ -38,11 +41,20 @@ (identifier) @function.special (scoped_identifier name: (identifier) @function.special) - ]) + ] + "!" @function.special) (macro_definition name: (identifier) @function.special.definition) +(mod_item + name: (identifier) @module) + +(visibility_modifier [ + (crate) @keyword + (super) @keyword +]) + ; Identifier conventions ; Assume uppercase names are types/enum-constructors @@ -115,9 +127,7 @@ "where" "while" "yield" - (crate) (mutable_specifier) - (super) ] @keyword [ @@ -189,6 +199,7 @@ operator: "/" @operator (lifetime) @lifetime +(lifetime (identifier) @lifetime) (parameter (identifier) @variable.parameter) diff --git a/crates/onboarding/src/editing_page.rs b/crates/onboarding/src/editing_page.rs index 8fae695854..47dfd84894 100644 --- a/crates/onboarding/src/editing_page.rs +++ b/crates/onboarding/src/editing_page.rs @@ -606,7 +606,7 @@ fn render_popular_settings_section( cx: &mut App, ) -> impl IntoElement { const LIGATURE_TOOLTIP: &str = - "Font ligatures combine two characters into one. For example, turning =/= into ≠."; + "Font ligatures combine two characters into one. For example, turning != into ≠."; v_flex() .pt_6() diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index cc3a0a05bb..fb1fae3736 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -9029,13 +9029,22 @@ impl LspStore { lsp_store.update(&mut cx, |lsp_store, cx| { if let Some(server) = lsp_store.language_server_for_id(server_id) { let text_document = if envelope.payload.current_file_only { - let buffer_id = BufferId::new(envelope.payload.buffer_id)?; - lsp_store - .buffer_store() - .read(cx) - .get(buffer_id) - .and_then(|buffer| Some(buffer.read(cx).file()?.as_local()?.abs_path(cx))) - .map(|path| make_text_document_identifier(&path)) + let buffer_id = envelope + .payload + .buffer_id + .map(|id| BufferId::new(id)) + .transpose()?; + buffer_id + .and_then(|buffer_id| { + lsp_store + .buffer_store() + .read(cx) + .get(buffer_id) + .and_then(|buffer| { + Some(buffer.read(cx).file()?.as_local()?.abs_path(cx)) + }) + .map(|path| make_text_document_identifier(&path)) + }) .transpose()? } else { None diff --git a/crates/project/src/lsp_store/rust_analyzer_ext.rs b/crates/project/src/lsp_store/rust_analyzer_ext.rs index e5e6338d3c..54f63220b1 100644 --- a/crates/project/src/lsp_store/rust_analyzer_ext.rs +++ b/crates/project/src/lsp_store/rust_analyzer_ext.rs @@ -1,8 +1,8 @@ use ::serde::{Deserialize, Serialize}; use anyhow::Context as _; -use gpui::{App, Entity, Task, WeakEntity}; -use language::ServerHealth; -use lsp::{LanguageServer, LanguageServerName}; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; +use language::{Buffer, ServerHealth}; +use lsp::{LanguageServer, LanguageServerId, LanguageServerName}; use rpc::proto; use crate::{LspStore, LspStoreEvent, Project, ProjectPath, lsp_store}; @@ -83,31 +83,32 @@ pub fn register_notifications(lsp_store: WeakEntity, language_server: pub fn cancel_flycheck( project: Entity, - buffer_path: ProjectPath, + buffer_path: Option, cx: &mut App, ) -> Task> { let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); let lsp_store = project.read(cx).lsp_store(); - let buffer = project.update(cx, |project, cx| { - project.buffer_store().update(cx, |buffer_store, cx| { - buffer_store.open_buffer(buffer_path, cx) + let buffer = buffer_path.map(|buffer_path| { + project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + buffer_store.open_buffer(buffer_path, cx) + }) }) }); cx.spawn(async move |cx| { - let buffer = buffer.await?; - let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { - project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) - })? + let buffer = match buffer { + Some(buffer) => Some(buffer.await?), + None => None, + }; + let Some(rust_analyzer_server) = find_rust_analyzer_server(&project, buffer.as_ref(), cx) else { return Ok(()); }; - let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())?; if let Some((client, project_id)) = upstream_client { let request = proto::LspExtCancelFlycheck { project_id, - buffer_id, language_server_id: rust_analyzer_server.to_proto(), }; client @@ -130,28 +131,33 @@ pub fn cancel_flycheck( pub fn run_flycheck( project: Entity, - buffer_path: ProjectPath, + buffer_path: Option, cx: &mut App, ) -> Task> { let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); let lsp_store = project.read(cx).lsp_store(); - let buffer = project.update(cx, |project, cx| { - project.buffer_store().update(cx, |buffer_store, cx| { - buffer_store.open_buffer(buffer_path, cx) + let buffer = buffer_path.map(|buffer_path| { + project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + buffer_store.open_buffer(buffer_path, cx) + }) }) }); cx.spawn(async move |cx| { - let buffer = buffer.await?; - let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { - project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) - })? + let buffer = match buffer { + Some(buffer) => Some(buffer.await?), + None => None, + }; + let Some(rust_analyzer_server) = find_rust_analyzer_server(&project, buffer.as_ref(), cx) else { return Ok(()); }; - let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())?; if let Some((client, project_id)) = upstream_client { + let buffer_id = buffer + .map(|buffer| buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())) + .transpose()?; let request = proto::LspExtRunFlycheck { project_id, buffer_id, @@ -182,31 +188,32 @@ pub fn run_flycheck( pub fn clear_flycheck( project: Entity, - buffer_path: ProjectPath, + buffer_path: Option, cx: &mut App, ) -> Task> { let upstream_client = project.read(cx).lsp_store().read(cx).upstream_client(); let lsp_store = project.read(cx).lsp_store(); - let buffer = project.update(cx, |project, cx| { - project.buffer_store().update(cx, |buffer_store, cx| { - buffer_store.open_buffer(buffer_path, cx) + let buffer = buffer_path.map(|buffer_path| { + project.update(cx, |project, cx| { + project.buffer_store().update(cx, |buffer_store, cx| { + buffer_store.open_buffer(buffer_path, cx) + }) }) }); cx.spawn(async move |cx| { - let buffer = buffer.await?; - let Some(rust_analyzer_server) = project.read_with(cx, |project, cx| { - project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) - })? + let buffer = match buffer { + Some(buffer) => Some(buffer.await?), + None => None, + }; + let Some(rust_analyzer_server) = find_rust_analyzer_server(&project, buffer.as_ref(), cx) else { return Ok(()); }; - let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id().to_proto())?; if let Some((client, project_id)) = upstream_client { let request = proto::LspExtClearFlycheck { project_id, - buffer_id, language_server_id: rust_analyzer_server.to_proto(), }; client @@ -226,3 +233,40 @@ pub fn clear_flycheck( anyhow::Ok(()) }) } + +fn find_rust_analyzer_server( + project: &Entity, + buffer: Option<&Entity>, + cx: &mut AsyncApp, +) -> Option { + project + .read_with(cx, |project, cx| { + buffer + .and_then(|buffer| { + project.language_server_id_for_name(buffer.read(cx), &RUST_ANALYZER_NAME, cx) + }) + // If no rust-analyzer found for the current buffer (e.g. `settings.json`), fall back to the project lookup + // and use project's rust-analyzer if it's the only one. + .or_else(|| { + let rust_analyzer_servers = project + .lsp_store() + .read(cx) + .language_server_statuses + .iter() + .filter_map(|(server_id, server_status)| { + if server_status.name == RUST_ANALYZER_NAME { + Some(*server_id) + } else { + None + } + }) + .collect::>(); + if rust_analyzer_servers.len() == 1 { + rust_analyzer_servers.first().copied() + } else { + None + } + }) + }) + .ok()? +} diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 40e974dfba..aff734b770 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -69,6 +69,7 @@ use workspace::{ notifications::{DetachAndPromptErr, NotifyTaskExt}, }; use worktree::CreatedEntry; +use zed_actions::workspace::OpenWithSystem; const PROJECT_PANEL_KEY: &str = "ProjectPanel"; const NEW_ENTRY_ID: ProjectEntryId = ProjectEntryId::MAX; @@ -255,8 +256,6 @@ actions!( RevealInFileManager, /// Removes the selected folder from the project. RemoveFromProject, - /// Opens the selected file with the system's default application. - OpenWithSystem, /// Cuts the selected file or directory. Cut, /// Pastes the previously cut or copied item. diff --git a/crates/proto/proto/lsp.proto b/crates/proto/proto/lsp.proto index ac9c275aa2..473ef5c38c 100644 --- a/crates/proto/proto/lsp.proto +++ b/crates/proto/proto/lsp.proto @@ -834,21 +834,19 @@ message LspRunnable { message LspExtCancelFlycheck { uint64 project_id = 1; - uint64 buffer_id = 2; - uint64 language_server_id = 3; + uint64 language_server_id = 2; } message LspExtRunFlycheck { uint64 project_id = 1; - uint64 buffer_id = 2; + optional uint64 buffer_id = 2; uint64 language_server_id = 3; bool current_file_only = 4; } message LspExtClearFlycheck { uint64 project_id = 1; - uint64 buffer_id = 2; - uint64 language_server_id = 3; + uint64 language_server_id = 2; } message LspDiagnosticRelatedInformation { diff --git a/crates/recent_projects/src/disconnected_overlay.rs b/crates/recent_projects/src/disconnected_overlay.rs index dd4d788cfd..8ffe0ef07c 100644 --- a/crates/recent_projects/src/disconnected_overlay.rs +++ b/crates/recent_projects/src/disconnected_overlay.rs @@ -1,5 +1,3 @@ -use std::path::PathBuf; - use gpui::{ClickEvent, DismissEvent, EventEmitter, FocusHandle, Focusable, Render, WeakEntity}; use project::project_settings::ProjectSettings; use remote::SshConnectionOptions; @@ -103,17 +101,17 @@ impl DisconnectedOverlay { return; }; - let Some(ssh_project) = workspace.read(cx).serialized_ssh_project() else { - return; - }; - let Some(window_handle) = window.window_handle().downcast::() else { return; }; let app_state = workspace.read(cx).app_state().clone(); - - let paths = ssh_project.paths.iter().map(PathBuf::from).collect(); + let paths = workspace + .read(cx) + .root_paths(cx) + .iter() + .map(|path| path.to_path_buf()) + .collect(); cx.spawn_in(window, async move |_, cx| { open_ssh_project( diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 2093e96cae..fa57b588cd 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -19,15 +19,12 @@ use picker::{ pub use remote_servers::RemoteServerProjects; use settings::Settings; pub use ssh_connections::SshSettings; -use std::{ - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{path::Path, sync::Arc}; use ui::{KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*, tooltip_container}; use util::{ResultExt, paths::PathExt}; use workspace::{ - CloseIntent, HistoryManager, ModalView, OpenOptions, SerializedWorkspaceLocation, WORKSPACE_DB, - Workspace, WorkspaceId, with_active_or_new_workspace, + CloseIntent, HistoryManager, ModalView, OpenOptions, PathList, SerializedWorkspaceLocation, + WORKSPACE_DB, Workspace, WorkspaceId, with_active_or_new_workspace, }; use zed_actions::{OpenRecent, OpenRemote}; @@ -154,7 +151,7 @@ impl Render for RecentProjects { pub struct RecentProjectsDelegate { workspace: WeakEntity, - workspaces: Vec<(WorkspaceId, SerializedWorkspaceLocation)>, + workspaces: Vec<(WorkspaceId, SerializedWorkspaceLocation, PathList)>, selected_match_index: usize, matches: Vec, render_paths: bool, @@ -178,12 +175,15 @@ impl RecentProjectsDelegate { } } - pub fn set_workspaces(&mut self, workspaces: Vec<(WorkspaceId, SerializedWorkspaceLocation)>) { + pub fn set_workspaces( + &mut self, + workspaces: Vec<(WorkspaceId, SerializedWorkspaceLocation, PathList)>, + ) { self.workspaces = workspaces; self.has_any_non_local_projects = !self .workspaces .iter() - .all(|(_, location)| matches!(location, SerializedWorkspaceLocation::Local(_, _))); + .all(|(_, location, _)| matches!(location, SerializedWorkspaceLocation::Local)); } } impl EventEmitter for RecentProjectsDelegate {} @@ -236,15 +236,14 @@ impl PickerDelegate for RecentProjectsDelegate { .workspaces .iter() .enumerate() - .filter(|(_, (id, _))| !self.is_current_workspace(*id, cx)) - .map(|(id, (_, location))| { - let combined_string = location - .sorted_paths() + .filter(|(_, (id, _, _))| !self.is_current_workspace(*id, cx)) + .map(|(id, (_, _, paths))| { + let combined_string = paths + .paths() .iter() .map(|path| path.compact().to_string_lossy().into_owned()) .collect::>() .join(""); - StringMatchCandidate::new(id, &combined_string) }) .collect::>(); @@ -279,7 +278,7 @@ impl PickerDelegate for RecentProjectsDelegate { .get(self.selected_index()) .zip(self.workspace.upgrade()) { - let (candidate_workspace_id, candidate_workspace_location) = + let (candidate_workspace_id, candidate_workspace_location, candidate_workspace_paths) = &self.workspaces[selected_match.candidate_id]; let replace_current_window = if self.create_new_window { secondary @@ -292,8 +291,8 @@ impl PickerDelegate for RecentProjectsDelegate { Task::ready(Ok(())) } else { match candidate_workspace_location { - SerializedWorkspaceLocation::Local(paths, _) => { - let paths = paths.paths().to_vec(); + SerializedWorkspaceLocation::Local => { + let paths = candidate_workspace_paths.paths().to_vec(); if replace_current_window { cx.spawn_in(window, async move |workspace, cx| { let continue_replacing = workspace @@ -321,7 +320,7 @@ impl PickerDelegate for RecentProjectsDelegate { workspace.open_workspace_for_paths(false, paths, window, cx) } } - SerializedWorkspaceLocation::Ssh(ssh_project) => { + SerializedWorkspaceLocation::Ssh(connection) => { let app_state = workspace.app_state().clone(); let replace_window = if replace_current_window { @@ -337,12 +336,12 @@ impl PickerDelegate for RecentProjectsDelegate { let connection_options = SshSettings::get_global(cx) .connection_options_for( - ssh_project.host.clone(), - ssh_project.port, - ssh_project.user.clone(), + connection.host.clone(), + connection.port, + connection.user.clone(), ); - let paths = ssh_project.paths.iter().map(PathBuf::from).collect(); + let paths = candidate_workspace_paths.paths().to_vec(); cx.spawn_in(window, async move |_, cx| { open_ssh_project( @@ -383,12 +382,12 @@ impl PickerDelegate for RecentProjectsDelegate { ) -> Option { let hit = self.matches.get(ix)?; - let (_, location) = self.workspaces.get(hit.candidate_id)?; + let (_, location, paths) = self.workspaces.get(hit.candidate_id)?; let mut path_start_offset = 0; - let (match_labels, paths): (Vec<_>, Vec<_>) = location - .sorted_paths() + let (match_labels, paths): (Vec<_>, Vec<_>) = paths + .paths() .iter() .map(|p| p.compact()) .map(|path| { @@ -416,11 +415,9 @@ impl PickerDelegate for RecentProjectsDelegate { .gap_3() .when(self.has_any_non_local_projects, |this| { this.child(match location { - SerializedWorkspaceLocation::Local(_, _) => { - Icon::new(IconName::Screen) - .color(Color::Muted) - .into_any_element() - } + SerializedWorkspaceLocation::Local => Icon::new(IconName::Screen) + .color(Color::Muted) + .into_any_element(), SerializedWorkspaceLocation::Ssh(_) => Icon::new(IconName::Server) .color(Color::Muted) .into_any_element(), @@ -568,7 +565,7 @@ impl RecentProjectsDelegate { cx: &mut Context>, ) { if let Some(selected_match) = self.matches.get(ix) { - let (workspace_id, _) = self.workspaces[selected_match.candidate_id]; + let (workspace_id, _, _) = self.workspaces[selected_match.candidate_id]; cx.spawn_in(window, async move |this, cx| { let _ = WORKSPACE_DB.delete_workspace_by_id(workspace_id).await; let workspaces = WORKSPACE_DB @@ -707,7 +704,8 @@ mod tests { }]; delegate.set_workspaces(vec![( WorkspaceId::default(), - SerializedWorkspaceLocation::from_local_paths(vec![path!("/test/path/")]), + SerializedWorkspaceLocation::Local, + PathList::new(&[path!("/test/path")]), )]); }); }) diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index c02d0ad7e7..b9af528643 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -52,11 +52,6 @@ use util::{ paths::{PathStyle, RemotePathBuf}, }; -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, -)] -pub struct SshProjectId(pub u64); - #[derive(Clone)] pub struct SshSocket { connection_options: SshConnectionOptions, diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index c4ba9b5154..8ac12588af 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -3905,7 +3905,7 @@ pub mod tests { assert_eq!(workspace.active_pane(), &second_pane); second_pane.update(cx, |this, cx| { assert_eq!(this.active_item_index(), 1); - this.activate_prev_item(false, window, cx); + this.activate_previous_item(&Default::default(), window, cx); assert_eq!(this.active_item_index(), 0); }); workspace.activate_pane_in_direction(workspace::SplitDirection::Left, window, cx); @@ -3940,7 +3940,9 @@ pub mod tests { // Focus the second pane's non-search item window .update(cx, |_workspace, window, cx| { - second_pane.update(cx, |pane, cx| pane.activate_next_item(true, window, cx)); + second_pane.update(cx, |pane, cx| { + pane.activate_next_item(&Default::default(), window, cx) + }); }) .unwrap(); diff --git a/crates/ui/src/components/callout.rs b/crates/ui/src/components/callout.rs index 7ffeda881c..b1ead18ee7 100644 --- a/crates/ui/src/components/callout.rs +++ b/crates/ui/src/components/callout.rs @@ -132,6 +132,7 @@ impl RenderOnce for Callout { h_flex() .min_w_0() + .w_full() .p_2() .gap_2() .items_start() diff --git a/crates/workspace/Cargo.toml b/crates/workspace/Cargo.toml index 570657ba8f..869aa5322e 100644 --- a/crates/workspace/Cargo.toml +++ b/crates/workspace/Cargo.toml @@ -29,7 +29,6 @@ test-support = [ any_vec.workspace = true anyhow.workspace = true async-recursion.workspace = true -bincode.workspace = true call.workspace = true client.workspace = true clock.workspace = true @@ -80,5 +79,6 @@ project = { workspace = true, features = ["test-support"] } session = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } http_client = { workspace = true, features = ["test-support"] } +pretty_assertions.workspace = true tempfile.workspace = true zlog.workspace = true diff --git a/crates/workspace/src/history_manager.rs b/crates/workspace/src/history_manager.rs index a8387369f4..f68b58ff82 100644 --- a/crates/workspace/src/history_manager.rs +++ b/crates/workspace/src/history_manager.rs @@ -5,7 +5,9 @@ use smallvec::SmallVec; use ui::App; use util::{ResultExt, paths::PathExt}; -use crate::{NewWindow, SerializedWorkspaceLocation, WORKSPACE_DB, WorkspaceId}; +use crate::{ + NewWindow, SerializedWorkspaceLocation, WORKSPACE_DB, WorkspaceId, path_list::PathList, +}; pub fn init(cx: &mut App) { let manager = cx.new(|_| HistoryManager::new()); @@ -44,7 +46,13 @@ impl HistoryManager { .unwrap_or_default() .into_iter() .rev() - .map(|(id, location)| HistoryManagerEntry::new(id, &location)) + .filter_map(|(id, location, paths)| { + if matches!(location, SerializedWorkspaceLocation::Local) { + Some(HistoryManagerEntry::new(id, &paths)) + } else { + None + } + }) .collect::>(); this.update(cx, |this, cx| { this.history = recent_folders; @@ -118,9 +126,9 @@ impl HistoryManager { } impl HistoryManagerEntry { - pub fn new(id: WorkspaceId, location: &SerializedWorkspaceLocation) -> Self { - let path = location - .sorted_paths() + pub fn new(id: WorkspaceId, paths: &PathList) -> Self { + let path = paths + .paths() .iter() .map(|path| path.compact()) .collect::>(); diff --git a/crates/workspace/src/invalid_buffer_view.rs b/crates/workspace/src/invalid_buffer_view.rs new file mode 100644 index 0000000000..b017373474 --- /dev/null +++ b/crates/workspace/src/invalid_buffer_view.rs @@ -0,0 +1,111 @@ +use std::{path::PathBuf, sync::Arc}; + +use gpui::{EventEmitter, FocusHandle, Focusable}; +use ui::{ + App, Button, ButtonCommon, ButtonStyle, Clickable, Context, FluentBuilder, InteractiveElement, + KeyBinding, ParentElement, Render, SharedString, Styled as _, Window, h_flex, v_flex, +}; +use zed_actions::workspace::OpenWithSystem; + +use crate::Item; + +/// A view to display when a certain buffer fails to open. +pub struct InvalidBufferView { + /// Which path was attempted to open. + pub abs_path: Arc, + /// An error message, happened when opening the buffer. + pub error: SharedString, + is_local: bool, + focus_handle: FocusHandle, +} + +impl InvalidBufferView { + pub fn new( + abs_path: PathBuf, + is_local: bool, + e: &anyhow::Error, + _: &mut Window, + cx: &mut App, + ) -> Self { + Self { + is_local, + abs_path: Arc::new(abs_path), + error: format!("{e}").into(), + focus_handle: cx.focus_handle(), + } + } +} + +impl Item for InvalidBufferView { + type Event = (); + + fn tab_content_text(&self, mut detail: usize, _: &App) -> SharedString { + // Ensure we always render at least the filename. + detail += 1; + + let path = self.abs_path.as_path(); + + let mut prefix = path; + while detail > 0 { + if let Some(parent) = prefix.parent() { + prefix = parent; + detail -= 1; + } else { + break; + } + } + + let path = if detail > 0 { + path + } else { + path.strip_prefix(prefix).unwrap_or(path) + }; + + SharedString::new(path.to_string_lossy()) + } +} + +impl EventEmitter<()> for InvalidBufferView {} + +impl Focusable for InvalidBufferView { + fn focus_handle(&self, _: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for InvalidBufferView { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl gpui::IntoElement { + let abs_path = self.abs_path.clone(); + v_flex() + .size_full() + .track_focus(&self.focus_handle(cx)) + .flex_none() + .justify_center() + .overflow_hidden() + .key_context("InvalidBuffer") + .child( + h_flex().size_full().justify_center().child( + v_flex() + .justify_center() + .gap_2() + .child(h_flex().justify_center().child("Unsupported file type")) + .when(self.is_local, |contents| { + contents.child( + h_flex().justify_center().child( + Button::new("open-with-system", "Open in Default App") + .on_click(move |_, _, cx| { + cx.open_with_system(&abs_path); + }) + .style(ButtonStyle::Outlined) + .key_binding(KeyBinding::for_action( + &OpenWithSystem, + window, + cx, + )), + ), + ) + }), + ), + ) + } +} diff --git a/crates/workspace/src/item.rs b/crates/workspace/src/item.rs index 5a497398f9..3485fcca43 100644 --- a/crates/workspace/src/item.rs +++ b/crates/workspace/src/item.rs @@ -1,6 +1,7 @@ use crate::{ CollaboratorId, DelayedDebouncedEditAction, FollowableViewRegistry, ItemNavHistory, SerializableItemRegistry, ToolbarItemLocation, ViewId, Workspace, WorkspaceId, + invalid_buffer_view::InvalidBufferView, pane::{self, Pane}, persistence::model::ItemId, searchable::SearchableItemHandle, @@ -22,6 +23,7 @@ use std::{ any::{Any, TypeId}, cell::RefCell, ops::Range, + path::PathBuf, rc::Rc, sync::Arc, time::Duration, @@ -1161,6 +1163,22 @@ pub trait ProjectItem: Item { ) -> Self where Self: Sized; + + /// A fallback handler, which will be called after [`project::ProjectItem::try_open`] fails, + /// with the error from that failure as an argument. + /// Allows to open an item that can gracefully display and handle errors. + fn for_broken_project_item( + _abs_path: PathBuf, + _is_local: bool, + _e: &anyhow::Error, + _window: &mut Window, + _cx: &mut App, + ) -> Option + where + Self: Sized, + { + None + } } #[derive(Debug)] diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 23c8c0b185..fe8014d9f7 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -2,6 +2,7 @@ use crate::{ CloseWindow, NewFile, NewTerminal, OpenInTerminal, OpenOptions, OpenTerminal, OpenVisible, SplitDirection, ToggleFileFinder, ToggleProjectSymbols, ToggleZoom, Workspace, WorkspaceItemBuilder, + invalid_buffer_view::InvalidBufferView, item::{ ActivateOnClose, ClosePosition, Item, ItemHandle, ItemSettings, PreviewTabsSettings, ProjectItemKind, SaveOptions, ShowCloseButton, ShowDiagnostics, TabContentParams, @@ -513,7 +514,7 @@ impl Pane { } } - fn alternate_file(&mut self, window: &mut Window, cx: &mut Context) { + fn alternate_file(&mut self, _: &AlternateFile, window: &mut Window, cx: &mut Context) { let (_, alternative) = &self.alternate_file_items; if let Some(alternative) = alternative { let existing = self @@ -787,7 +788,7 @@ impl Pane { !self.nav_history.0.lock().forward_stack.is_empty() } - pub fn navigate_backward(&mut self, window: &mut Window, cx: &mut Context) { + pub fn navigate_backward(&mut self, _: &GoBack, window: &mut Window, cx: &mut Context) { if let Some(workspace) = self.workspace.upgrade() { let pane = cx.entity().downgrade(); window.defer(cx, move |window, cx| { @@ -798,7 +799,7 @@ impl Pane { } } - fn navigate_forward(&mut self, window: &mut Window, cx: &mut Context) { + fn navigate_forward(&mut self, _: &GoForward, window: &mut Window, cx: &mut Context) { if let Some(workspace) = self.workspace.upgrade() { let pane = cx.entity().downgrade(); window.defer(cx, move |window, cx| { @@ -897,19 +898,43 @@ impl Pane { } } } + + let set_up_existing_item = + |index: usize, pane: &mut Self, window: &mut Window, cx: &mut Context| { + // If the item is already open, and the item is a preview item + // and we are not allowing items to open as preview, mark the item as persistent. + if let Some(preview_item_id) = pane.preview_item_id + && let Some(tab) = pane.items.get(index) + && tab.item_id() == preview_item_id + && !allow_preview + { + pane.set_preview_item_id(None, cx); + } + if activate { + pane.activate_item(index, focus_item, focus_item, window, cx); + } + }; + let set_up_new_item = |new_item: Box, + destination_index: Option, + pane: &mut Self, + window: &mut Window, + cx: &mut Context| { + if allow_preview { + pane.set_preview_item_id(Some(new_item.item_id()), cx); + } + pane.add_item_inner( + new_item, + true, + focus_item, + activate, + destination_index, + window, + cx, + ); + }; + if let Some((index, existing_item)) = existing_item { - // If the item is already open, and the item is a preview item - // and we are not allowing items to open as preview, mark the item as persistent. - if let Some(preview_item_id) = self.preview_item_id - && let Some(tab) = self.items.get(index) - && tab.item_id() == preview_item_id - && !allow_preview - { - self.set_preview_item_id(None, cx); - } - if activate { - self.activate_item(index, focus_item, focus_item, window, cx); - } + set_up_existing_item(index, self, window, cx); existing_item } else { // If the item is being opened as preview and we have an existing preview tab, @@ -921,21 +946,46 @@ impl Pane { }; let new_item = build_item(self, window, cx); + // A special case that won't ever get a `project_entry_id` but has to be deduplicated nonetheless. + if let Some(invalid_buffer_view) = new_item.downcast::() { + let mut already_open_view = None; + let mut views_to_close = HashSet::default(); + for existing_error_view in self + .items_of_type::() + .filter(|item| item.read(cx).abs_path == invalid_buffer_view.read(cx).abs_path) + { + if already_open_view.is_none() + && existing_error_view.read(cx).error == invalid_buffer_view.read(cx).error + { + already_open_view = Some(existing_error_view); + } else { + views_to_close.insert(existing_error_view.item_id()); + } + } - if allow_preview { - self.set_preview_item_id(Some(new_item.item_id()), cx); + let resulting_item = match already_open_view { + Some(already_open_view) => { + if let Some(index) = self.index_for_item_id(already_open_view.item_id()) { + set_up_existing_item(index, self, window, cx); + } + Box::new(already_open_view) as Box<_> + } + None => { + set_up_new_item(new_item.clone(), destination_index, self, window, cx); + new_item + } + }; + + self.close_items(window, cx, SaveIntent::Skip, |existing_item| { + views_to_close.contains(&existing_item) + }) + .detach(); + + resulting_item + } else { + set_up_new_item(new_item.clone(), destination_index, self, window, cx); + new_item } - self.add_item_inner( - new_item.clone(), - true, - focus_item, - activate, - destination_index, - window, - cx, - ); - - new_item } } @@ -1233,9 +1283,9 @@ impl Pane { } } - pub fn activate_prev_item( + pub fn activate_previous_item( &mut self, - activate_pane: bool, + _: &ActivatePreviousItem, window: &mut Window, cx: &mut Context, ) { @@ -1245,12 +1295,12 @@ impl Pane { } else if !self.items.is_empty() { index = self.items.len() - 1; } - self.activate_item(index, activate_pane, activate_pane, window, cx); + self.activate_item(index, true, true, window, cx); } pub fn activate_next_item( &mut self, - activate_pane: bool, + _: &ActivateNextItem, window: &mut Window, cx: &mut Context, ) { @@ -1260,10 +1310,15 @@ impl Pane { } else { index = 0; } - self.activate_item(index, activate_pane, activate_pane, window, cx); + self.activate_item(index, true, true, window, cx); } - pub fn swap_item_left(&mut self, window: &mut Window, cx: &mut Context) { + pub fn swap_item_left( + &mut self, + _: &SwapItemLeft, + window: &mut Window, + cx: &mut Context, + ) { let index = self.active_item_index; if index == 0 { return; @@ -1273,9 +1328,14 @@ impl Pane { self.activate_item(index - 1, true, true, window, cx); } - pub fn swap_item_right(&mut self, window: &mut Window, cx: &mut Context) { + pub fn swap_item_right( + &mut self, + _: &SwapItemRight, + window: &mut Window, + cx: &mut Context, + ) { let index = self.active_item_index; - if index + 1 == self.items.len() { + if index + 1 >= self.items.len() { return; } @@ -1283,6 +1343,16 @@ impl Pane { self.activate_item(index + 1, true, true, window, cx); } + pub fn activate_last_item( + &mut self, + _: &ActivateLastItem, + window: &mut Window, + cx: &mut Context, + ) { + let index = self.items.len().saturating_sub(1); + self.activate_item(index, true, true, window, cx); + } + pub fn close_active_item( &mut self, action: &CloseActiveItem, @@ -2831,7 +2901,9 @@ impl Pane { .on_click({ let entity = cx.entity(); move |_, window, cx| { - entity.update(cx, |pane, cx| pane.navigate_backward(window, cx)) + entity.update(cx, |pane, cx| { + pane.navigate_backward(&Default::default(), window, cx) + }) } }) .disabled(!self.can_navigate_backward()) @@ -2846,7 +2918,11 @@ impl Pane { .icon_size(IconSize::Small) .on_click({ let entity = cx.entity(); - move |_, window, cx| entity.update(cx, |pane, cx| pane.navigate_forward(window, cx)) + move |_, window, cx| { + entity.update(cx, |pane, cx| { + pane.navigate_forward(&Default::default(), window, cx) + }) + } }) .disabled(!self.can_navigate_forward()) .tooltip({ @@ -3478,9 +3554,6 @@ impl Render for Pane { .size_full() .flex_none() .overflow_hidden() - .on_action(cx.listener(|pane, _: &AlternateFile, window, cx| { - pane.alternate_file(window, cx); - })) .on_action( cx.listener(|pane, _: &SplitLeft, _, cx| pane.split(SplitDirection::Left, cx)), ) @@ -3497,12 +3570,6 @@ impl Render for Pane { .on_action( cx.listener(|pane, _: &SplitDown, _, cx| pane.split(SplitDirection::Down, cx)), ) - .on_action( - cx.listener(|pane, _: &GoBack, window, cx| pane.navigate_backward(window, cx)), - ) - .on_action( - cx.listener(|pane, _: &GoForward, window, cx| pane.navigate_forward(window, cx)), - ) .on_action(cx.listener(|_, _: &JoinIntoNext, _, cx| { cx.emit(Event::JoinIntoNext); })) @@ -3510,6 +3577,8 @@ impl Render for Pane { cx.emit(Event::JoinAll); })) .on_action(cx.listener(Pane::toggle_zoom)) + .on_action(cx.listener(Self::navigate_backward)) + .on_action(cx.listener(Self::navigate_forward)) .on_action( cx.listener(|pane: &mut Pane, action: &ActivateItem, window, cx| { pane.activate_item( @@ -3521,33 +3590,14 @@ impl Render for Pane { ); }), ) - .on_action( - cx.listener(|pane: &mut Pane, _: &ActivateLastItem, window, cx| { - pane.activate_item(pane.items.len().saturating_sub(1), true, true, window, cx); - }), - ) - .on_action( - cx.listener(|pane: &mut Pane, _: &ActivatePreviousItem, window, cx| { - pane.activate_prev_item(true, window, cx); - }), - ) - .on_action( - cx.listener(|pane: &mut Pane, _: &ActivateNextItem, window, cx| { - pane.activate_next_item(true, window, cx); - }), - ) - .on_action( - cx.listener(|pane, _: &SwapItemLeft, window, cx| pane.swap_item_left(window, cx)), - ) - .on_action( - cx.listener(|pane, _: &SwapItemRight, window, cx| pane.swap_item_right(window, cx)), - ) - .on_action(cx.listener(|pane, action, window, cx| { - pane.toggle_pin_tab(action, window, cx); - })) - .on_action(cx.listener(|pane, action, window, cx| { - pane.unpin_all_tabs(action, window, cx); - })) + .on_action(cx.listener(Self::alternate_file)) + .on_action(cx.listener(Self::activate_last_item)) + .on_action(cx.listener(Self::activate_previous_item)) + .on_action(cx.listener(Self::activate_next_item)) + .on_action(cx.listener(Self::swap_item_left)) + .on_action(cx.listener(Self::swap_item_right)) + .on_action(cx.listener(Self::toggle_pin_tab)) + .on_action(cx.listener(Self::unpin_all_tabs)) .when(PreviewTabsSettings::get_global(cx).enabled, |this| { this.on_action(cx.listener(|pane: &mut Pane, _: &TogglePreviewTab, _, cx| { if let Some(active_item_id) = pane.active_item().map(|i| i.item_id()) { @@ -6402,6 +6452,57 @@ mod tests { .unwrap(); } + #[gpui::test] + async fn test_item_swapping_actions(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project, window, cx)); + + let pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + assert_item_labels(&pane, [], cx); + + // Test that these actions do not panic + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_right(&Default::default(), window, cx); + }); + + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_left(&Default::default(), window, cx); + }); + + add_labeled_item(&pane, "A", false, cx); + add_labeled_item(&pane, "B", false, cx); + add_labeled_item(&pane, "C", false, cx); + assert_item_labels(&pane, ["A", "B", "C*"], cx); + + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_right(&Default::default(), window, cx); + }); + assert_item_labels(&pane, ["A", "B", "C*"], cx); + + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_left(&Default::default(), window, cx); + }); + assert_item_labels(&pane, ["A", "C*", "B"], cx); + + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_left(&Default::default(), window, cx); + }); + assert_item_labels(&pane, ["C*", "A", "B"], cx); + + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_left(&Default::default(), window, cx); + }); + assert_item_labels(&pane, ["C*", "A", "B"], cx); + + pane.update_in(cx, |pane, window, cx| { + pane.swap_item_right(&Default::default(), window, cx); + }); + assert_item_labels(&pane, ["A", "C*", "B"], cx); + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/workspace/src/path_list.rs b/crates/workspace/src/path_list.rs new file mode 100644 index 0000000000..4f9ed42312 --- /dev/null +++ b/crates/workspace/src/path_list.rs @@ -0,0 +1,121 @@ +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; + +use util::paths::SanitizedPath; + +/// A list of absolute paths, in a specific order. +/// +/// The paths are stored in lexicographic order, so that they can be compared to +/// other path lists without regard to the order of the paths. +#[derive(Default, PartialEq, Eq, Debug, Clone)] +pub struct PathList { + paths: Arc<[PathBuf]>, + order: Arc<[usize]>, +} + +#[derive(Debug)] +pub struct SerializedPathList { + pub paths: String, + pub order: String, +} + +impl PathList { + pub fn new>(paths: &[P]) -> Self { + let mut indexed_paths: Vec<(usize, PathBuf)> = paths + .iter() + .enumerate() + .map(|(ix, path)| (ix, SanitizedPath::from(path).into())) + .collect(); + indexed_paths.sort_by(|(_, a), (_, b)| a.cmp(b)); + let order = indexed_paths.iter().map(|e| e.0).collect::>().into(); + let paths = indexed_paths + .into_iter() + .map(|e| e.1) + .collect::>() + .into(); + Self { order, paths } + } + + pub fn is_empty(&self) -> bool { + self.paths.is_empty() + } + + pub fn paths(&self) -> &[PathBuf] { + self.paths.as_ref() + } + + pub fn order(&self) -> &[usize] { + self.order.as_ref() + } + + pub fn is_lexicographically_ordered(&self) -> bool { + self.order.iter().enumerate().all(|(i, &j)| i == j) + } + + pub fn deserialize(serialized: &SerializedPathList) -> Self { + let mut paths: Vec = if serialized.paths.is_empty() { + Vec::new() + } else { + serde_json::from_str::>(&serialized.paths) + .unwrap_or(Vec::new()) + .into_iter() + .map(|s| SanitizedPath::from(s).into()) + .collect() + }; + + let mut order: Vec = serialized + .order + .split(',') + .filter_map(|s| s.parse().ok()) + .collect(); + + if !paths.is_sorted() || order.len() != paths.len() { + order = (0..paths.len()).collect(); + paths.sort(); + } + + Self { + paths: paths.into(), + order: order.into(), + } + } + + pub fn serialize(&self) -> SerializedPathList { + use std::fmt::Write as _; + + let paths = serde_json::to_string(&self.paths).unwrap_or_default(); + + let mut order = String::new(); + for ix in self.order.iter() { + if !order.is_empty() { + order.push(','); + } + write!(&mut order, "{}", *ix).unwrap(); + } + SerializedPathList { paths, order } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_list() { + let list1 = PathList::new(&["a/d", "a/c"]); + let list2 = PathList::new(&["a/c", "a/d"]); + + assert_eq!(list1.paths(), list2.paths()); + assert_ne!(list1, list2); + assert_eq!(list1.order(), &[1, 0]); + assert_eq!(list2.order(), &[0, 1]); + + let list1_deserialized = PathList::deserialize(&list1.serialize()); + assert_eq!(list1_deserialized, list1); + + let list2_deserialized = PathList::deserialize(&list2.serialize()); + assert_eq!(list2_deserialized, list2); + } +} diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index b2d1340a7b..39a1e08c93 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -9,15 +9,13 @@ use std::{ }; use anyhow::{Context as _, Result, bail}; -use client::DevServerProjectId; +use collections::HashMap; use db::{define_connection, query, sqlez::connection::Connection, sqlez_macros::sql}; use gpui::{Axis, Bounds, Task, WindowBounds, WindowId, point, size}; -use itertools::Itertools; use project::debugger::breakpoint_store::{BreakpointState, SourceBreakpoint}; use language::{LanguageName, Toolchain}; use project::WorktreeId; -use remote::ssh_session::SshProjectId; use sqlez::{ bindable::{Bind, Column, StaticColumnCount}, statement::{SqlType, Statement}, @@ -28,14 +26,17 @@ use ui::{App, px}; use util::{ResultExt, maybe}; use uuid::Uuid; -use crate::WorkspaceId; - -use model::{ - GroupId, ItemId, LocalPaths, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup, - SerializedSshProject, SerializedWorkspace, +use crate::{ + WorkspaceId, + path_list::{PathList, SerializedPathList}, }; -use self::model::{DockStructure, LocalPathsOrder, SerializedWorkspaceLocation}; +use model::{ + GroupId, ItemId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup, + SerializedSshConnection, SerializedWorkspace, SshConnectionId, +}; + +use self::model::{DockStructure, SerializedWorkspaceLocation}; #[derive(Copy, Clone, Debug, PartialEq)] pub(crate) struct SerializedAxis(pub(crate) gpui::Axis); @@ -275,70 +276,9 @@ impl sqlez::bindable::Bind for SerializedPixels { } define_connection! { - // Current schema shape using pseudo-rust syntax: - // - // workspaces( - // workspace_id: usize, // Primary key for workspaces - // local_paths: Bincode>, - // local_paths_order: Bincode>, - // dock_visible: bool, // Deprecated - // dock_anchor: DockAnchor, // Deprecated - // dock_pane: Option, // Deprecated - // left_sidebar_open: boolean, - // timestamp: String, // UTC YYYY-MM-DD HH:MM:SS - // window_state: String, // WindowBounds Discriminant - // window_x: Option, // WindowBounds::Fixed RectF x - // window_y: Option, // WindowBounds::Fixed RectF y - // window_width: Option, // WindowBounds::Fixed RectF width - // window_height: Option, // WindowBounds::Fixed RectF height - // display: Option, // Display id - // fullscreen: Option, // Is the window fullscreen? - // centered_layout: Option, // Is the Centered Layout mode activated? - // session_id: Option, // Session id - // window_id: Option, // Window Id - // ) - // - // pane_groups( - // group_id: usize, // Primary key for pane_groups - // workspace_id: usize, // References workspaces table - // parent_group_id: Option, // None indicates that this is the root node - // position: Option, // None indicates that this is the root node - // axis: Option, // 'Vertical', 'Horizontal' - // flexes: Option>, // A JSON array of floats - // ) - // - // panes( - // pane_id: usize, // Primary key for panes - // workspace_id: usize, // References workspaces table - // active: bool, - // ) - // - // center_panes( - // pane_id: usize, // Primary key for center_panes - // parent_group_id: Option, // References pane_groups. If none, this is the root - // position: Option, // None indicates this is the root - // ) - // - // CREATE TABLE items( - // item_id: usize, // This is the item's view id, so this is not unique - // workspace_id: usize, // References workspaces table - // pane_id: usize, // References panes table - // kind: String, // Indicates which view this connects to. This is the key in the item_deserializers global - // position: usize, // Position of the item in the parent pane. This is equivalent to panes' position column - // active: bool, // Indicates if this item is the active one in the pane - // preview: bool // Indicates if this item is a preview item - // ) - // - // CREATE TABLE breakpoints( - // workspace_id: usize Foreign Key, // References workspace table - // path: PathBuf, // The absolute path of the file that this breakpoint belongs to - // breakpoint_location: Vec, // A list of the locations of breakpoints - // kind: int, // The kind of breakpoint (standard, log) - // log_message: String, // log message for log breakpoints, otherwise it's Null - // ) pub static ref DB: WorkspaceDb<()> = &[ - sql!( + sql!( CREATE TABLE workspaces( workspace_id INTEGER PRIMARY KEY, workspace_location BLOB UNIQUE, @@ -555,7 +495,109 @@ define_connection! { SELECT * FROM toolchains; DROP TABLE toolchains; ALTER TABLE toolchains2 RENAME TO toolchains; - ) + ), + sql!( + CREATE TABLE ssh_connections ( + id INTEGER PRIMARY KEY, + host TEXT NOT NULL, + port INTEGER, + user TEXT + ); + + INSERT INTO ssh_connections (host, port, user) + SELECT DISTINCT host, port, user + FROM ssh_projects; + + CREATE TABLE workspaces_2( + workspace_id INTEGER PRIMARY KEY, + paths TEXT, + paths_order TEXT, + ssh_connection_id INTEGER REFERENCES ssh_connections(id), + timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL, + window_state TEXT, + window_x REAL, + window_y REAL, + window_width REAL, + window_height REAL, + display BLOB, + left_dock_visible INTEGER, + left_dock_active_panel TEXT, + right_dock_visible INTEGER, + right_dock_active_panel TEXT, + bottom_dock_visible INTEGER, + bottom_dock_active_panel TEXT, + left_dock_zoom INTEGER, + right_dock_zoom INTEGER, + bottom_dock_zoom INTEGER, + fullscreen INTEGER, + centered_layout INTEGER, + session_id TEXT, + window_id INTEGER + ) STRICT; + + INSERT + INTO workspaces_2 + SELECT + workspaces.workspace_id, + CASE + WHEN ssh_projects.id IS NOT NULL THEN ssh_projects.paths + ELSE + CASE + WHEN workspaces.local_paths_array IS NULL OR workspaces.local_paths_array = "" THEN + NULL + ELSE + json('[' || '"' || replace(workspaces.local_paths_array, ',', '"' || "," || '"') || '"' || ']') + END + END as paths, + + CASE + WHEN ssh_projects.id IS NOT NULL THEN "" + ELSE workspaces.local_paths_order_array + END as paths_order, + + CASE + WHEN ssh_projects.id IS NOT NULL THEN ( + SELECT ssh_connections.id + FROM ssh_connections + WHERE + ssh_connections.host IS ssh_projects.host AND + ssh_connections.port IS ssh_projects.port AND + ssh_connections.user IS ssh_projects.user + ) + ELSE NULL + END as ssh_connection_id, + + workspaces.timestamp, + workspaces.window_state, + workspaces.window_x, + workspaces.window_y, + workspaces.window_width, + workspaces.window_height, + workspaces.display, + workspaces.left_dock_visible, + workspaces.left_dock_active_panel, + workspaces.right_dock_visible, + workspaces.right_dock_active_panel, + workspaces.bottom_dock_visible, + workspaces.bottom_dock_active_panel, + workspaces.left_dock_zoom, + workspaces.right_dock_zoom, + workspaces.bottom_dock_zoom, + workspaces.fullscreen, + workspaces.centered_layout, + workspaces.session_id, + workspaces.window_id + FROM + workspaces LEFT JOIN + ssh_projects ON + workspaces.ssh_project_id = ssh_projects.id; + + DROP TABLE ssh_projects; + DROP TABLE workspaces; + ALTER TABLE workspaces_2 RENAME TO workspaces; + + CREATE UNIQUE INDEX ix_workspaces_location ON workspaces(ssh_connection_id, paths); + ), ]; } @@ -566,17 +608,33 @@ impl WorkspaceDb { pub(crate) fn workspace_for_roots>( &self, worktree_roots: &[P], + ) -> Option { + self.workspace_for_roots_internal(worktree_roots, None) + } + + pub(crate) fn ssh_workspace_for_roots>( + &self, + worktree_roots: &[P], + ssh_project_id: SshConnectionId, + ) -> Option { + self.workspace_for_roots_internal(worktree_roots, Some(ssh_project_id)) + } + + pub(crate) fn workspace_for_roots_internal>( + &self, + worktree_roots: &[P], + ssh_connection_id: Option, ) -> Option { // paths are sorted before db interactions to ensure that the order of the paths // doesn't affect the workspace selection for existing workspaces - let local_paths = LocalPaths::new(worktree_roots); + let root_paths = PathList::new(worktree_roots); // Note that we re-assign the workspace_id here in case it's empty // and we've grabbed the most recent workspace let ( workspace_id, - local_paths, - local_paths_order, + paths, + paths_order, window_bounds, display, centered_layout, @@ -584,8 +642,8 @@ impl WorkspaceDb { window_id, ): ( WorkspaceId, - Option, - Option, + String, + String, Option, Option, Option, @@ -595,8 +653,8 @@ impl WorkspaceDb { .select_row_bound(sql! { SELECT workspace_id, - local_paths, - local_paths_order, + paths, + paths_order, window_state, window_x, window_y, @@ -615,25 +673,31 @@ impl WorkspaceDb { bottom_dock_zoom, window_id FROM workspaces - WHERE local_paths = ? + WHERE + paths IS ? AND + ssh_connection_id IS ? + LIMIT 1 + }) + .map(|mut prepared_statement| { + (prepared_statement)(( + root_paths.serialize().paths, + ssh_connection_id.map(|id| id.0 as i32), + )) + .unwrap() }) - .and_then(|mut prepared_statement| (prepared_statement)(&local_paths)) .context("No workspaces found") .warn_on_err() .flatten()?; - let local_paths = local_paths?; - let location = match local_paths_order { - Some(order) => SerializedWorkspaceLocation::Local(local_paths, order), - None => { - let order = LocalPathsOrder::default_for_paths(&local_paths); - SerializedWorkspaceLocation::Local(local_paths, order) - } - }; + let paths = PathList::deserialize(&SerializedPathList { + paths, + order: paths_order, + }); Some(SerializedWorkspace { id: workspace_id, - location, + location: SerializedWorkspaceLocation::Local, + paths, center_group: self .get_center_pane_group(workspace_id) .context("Getting center group") @@ -648,63 +712,6 @@ impl WorkspaceDb { }) } - pub(crate) fn workspace_for_ssh_project( - &self, - ssh_project: &SerializedSshProject, - ) -> Option { - let (workspace_id, window_bounds, display, centered_layout, docks, window_id): ( - WorkspaceId, - Option, - Option, - Option, - DockStructure, - Option, - ) = self - .select_row_bound(sql! { - SELECT - workspace_id, - window_state, - window_x, - window_y, - window_width, - window_height, - display, - centered_layout, - left_dock_visible, - left_dock_active_panel, - left_dock_zoom, - right_dock_visible, - right_dock_active_panel, - right_dock_zoom, - bottom_dock_visible, - bottom_dock_active_panel, - bottom_dock_zoom, - window_id - FROM workspaces - WHERE ssh_project_id = ? - }) - .and_then(|mut prepared_statement| (prepared_statement)(ssh_project.id.0)) - .context("No workspaces found") - .warn_on_err() - .flatten()?; - - Some(SerializedWorkspace { - id: workspace_id, - location: SerializedWorkspaceLocation::Ssh(ssh_project.clone()), - center_group: self - .get_center_pane_group(workspace_id) - .context("Getting center group") - .log_err()?, - window_bounds, - centered_layout: centered_layout.unwrap_or(false), - breakpoints: self.breakpoints(workspace_id), - display, - docks, - session_id: None, - window_id, - }) - } - fn breakpoints(&self, workspace_id: WorkspaceId) -> BTreeMap, Vec> { let breakpoints: Result> = self .select_bound(sql! { @@ -754,16 +761,34 @@ impl WorkspaceDb { /// Saves a workspace using the worktree roots. Will garbage collect any workspaces /// that used this workspace previously pub(crate) async fn save_workspace(&self, workspace: SerializedWorkspace) { + let paths = workspace.paths.serialize(); log::debug!("Saving workspace at location: {:?}", workspace.location); self.write(move |conn| { conn.with_savepoint("update_worktrees", || { + let ssh_connection_id = match &workspace.location { + SerializedWorkspaceLocation::Local => None, + SerializedWorkspaceLocation::Ssh(connection) => { + Some(Self::get_or_create_ssh_connection_query( + conn, + connection.host.clone(), + connection.port, + connection.user.clone(), + )?.0) + } + }; + // Clear out panes and pane_groups conn.exec_bound(sql!( DELETE FROM pane_groups WHERE workspace_id = ?1; DELETE FROM panes WHERE workspace_id = ?1;))?(workspace.id) .context("Clearing old panes")?; - conn.exec_bound(sql!(DELETE FROM breakpoints WHERE workspace_id = ?1))?(workspace.id).context("Clearing old breakpoints")?; + conn.exec_bound( + sql!( + DELETE FROM breakpoints WHERE workspace_id = ?1; + DELETE FROM toolchains WHERE workspace_id = ?1; + ) + )?(workspace.id).context("Clearing old breakpoints")?; for (path, breakpoints) in workspace.breakpoints { for bp in breakpoints { @@ -790,115 +815,73 @@ impl WorkspaceDb { } } } - } + conn.exec_bound(sql!( + DELETE + FROM workspaces + WHERE + workspace_id != ?1 AND + paths IS ?2 AND + ssh_connection_id IS ?3 + ))?(( + workspace.id, + paths.paths.clone(), + ssh_connection_id, + )) + .context("clearing out old locations")?; - match workspace.location { - SerializedWorkspaceLocation::Local(local_paths, local_paths_order) => { - conn.exec_bound(sql!( - DELETE FROM toolchains WHERE workspace_id = ?1; - DELETE FROM workspaces WHERE local_paths = ? AND workspace_id != ? - ))?((&local_paths, workspace.id)) - .context("clearing out old locations")?; + // Upsert + let query = sql!( + INSERT INTO workspaces( + workspace_id, + paths, + paths_order, + ssh_connection_id, + left_dock_visible, + left_dock_active_panel, + left_dock_zoom, + right_dock_visible, + right_dock_active_panel, + right_dock_zoom, + bottom_dock_visible, + bottom_dock_active_panel, + bottom_dock_zoom, + session_id, + window_id, + timestamp + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, CURRENT_TIMESTAMP) + ON CONFLICT DO + UPDATE SET + paths = ?2, + paths_order = ?3, + ssh_connection_id = ?4, + left_dock_visible = ?5, + left_dock_active_panel = ?6, + left_dock_zoom = ?7, + right_dock_visible = ?8, + right_dock_active_panel = ?9, + right_dock_zoom = ?10, + bottom_dock_visible = ?11, + bottom_dock_active_panel = ?12, + bottom_dock_zoom = ?13, + session_id = ?14, + window_id = ?15, + timestamp = CURRENT_TIMESTAMP + ); + let mut prepared_query = conn.exec_bound(query)?; + let args = ( + workspace.id, + paths.paths.clone(), + paths.order.clone(), + ssh_connection_id, + workspace.docks, + workspace.session_id, + workspace.window_id, + ); - // Upsert - let query = sql!( - INSERT INTO workspaces( - workspace_id, - local_paths, - local_paths_order, - left_dock_visible, - left_dock_active_panel, - left_dock_zoom, - right_dock_visible, - right_dock_active_panel, - right_dock_zoom, - bottom_dock_visible, - bottom_dock_active_panel, - bottom_dock_zoom, - session_id, - window_id, - timestamp, - local_paths_array, - local_paths_order_array - ) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, CURRENT_TIMESTAMP, ?15, ?16) - ON CONFLICT DO - UPDATE SET - local_paths = ?2, - local_paths_order = ?3, - left_dock_visible = ?4, - left_dock_active_panel = ?5, - left_dock_zoom = ?6, - right_dock_visible = ?7, - right_dock_active_panel = ?8, - right_dock_zoom = ?9, - bottom_dock_visible = ?10, - bottom_dock_active_panel = ?11, - bottom_dock_zoom = ?12, - session_id = ?13, - window_id = ?14, - timestamp = CURRENT_TIMESTAMP, - local_paths_array = ?15, - local_paths_order_array = ?16 - ); - let mut prepared_query = conn.exec_bound(query)?; - let args = (workspace.id, &local_paths, &local_paths_order, workspace.docks, workspace.session_id, workspace.window_id, local_paths.paths().iter().map(|path| path.to_string_lossy().to_string()).join(","), local_paths_order.order().iter().map(|order| order.to_string()).join(",")); - - prepared_query(args).context("Updating workspace")?; - } - SerializedWorkspaceLocation::Ssh(ssh_project) => { - conn.exec_bound(sql!( - DELETE FROM toolchains WHERE workspace_id = ?1; - DELETE FROM workspaces WHERE ssh_project_id = ? AND workspace_id != ? - ))?((ssh_project.id.0, workspace.id)) - .context("clearing out old locations")?; - - // Upsert - conn.exec_bound(sql!( - INSERT INTO workspaces( - workspace_id, - ssh_project_id, - left_dock_visible, - left_dock_active_panel, - left_dock_zoom, - right_dock_visible, - right_dock_active_panel, - right_dock_zoom, - bottom_dock_visible, - bottom_dock_active_panel, - bottom_dock_zoom, - session_id, - window_id, - timestamp - ) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, CURRENT_TIMESTAMP) - ON CONFLICT DO - UPDATE SET - ssh_project_id = ?2, - left_dock_visible = ?3, - left_dock_active_panel = ?4, - left_dock_zoom = ?5, - right_dock_visible = ?6, - right_dock_active_panel = ?7, - right_dock_zoom = ?8, - bottom_dock_visible = ?9, - bottom_dock_active_panel = ?10, - bottom_dock_zoom = ?11, - session_id = ?12, - window_id = ?13, - timestamp = CURRENT_TIMESTAMP - ))?(( - workspace.id, - ssh_project.id.0, - workspace.docks, - workspace.session_id, - workspace.window_id - )) - .context("Updating workspace")?; - } - } + prepared_query(args).context("Updating workspace")?; // Save center pane group Self::save_pane_group(conn, workspace.id, &workspace.center_group, None) @@ -911,89 +894,95 @@ impl WorkspaceDb { .await; } - pub(crate) async fn get_or_create_ssh_project( + pub(crate) async fn get_or_create_ssh_connection( &self, host: String, port: Option, - paths: Vec, user: Option, - ) -> Result { - let paths = serde_json::to_string(&paths)?; - if let Some(project) = self - .get_ssh_project(host.clone(), port, paths.clone(), user.clone()) - .await? + ) -> Result { + self.write(move |conn| Self::get_or_create_ssh_connection_query(conn, host, port, user)) + .await + } + + fn get_or_create_ssh_connection_query( + this: &Connection, + host: String, + port: Option, + user: Option, + ) -> Result { + if let Some(id) = this.select_row_bound(sql!( + SELECT id FROM ssh_connections WHERE host IS ? AND port IS ? AND user IS ? LIMIT 1 + ))?((host.clone(), port, user.clone()))? { - Ok(project) + Ok(SshConnectionId(id)) } else { log::debug!("Inserting SSH project at host {host}"); - self.insert_ssh_project(host, port, paths, user) - .await? - .context("failed to insert ssh project") + let id = this.select_row_bound(sql!( + INSERT INTO ssh_connections ( + host, + port, + user + ) VALUES (?1, ?2, ?3) + RETURNING id + ))?((host, port, user))? + .context("failed to insert ssh project")?; + Ok(SshConnectionId(id)) } } - query! { - async fn get_ssh_project(host: String, port: Option, paths: String, user: Option) -> Result> { - SELECT id, host, port, paths, user - FROM ssh_projects - WHERE host IS ? AND port IS ? AND paths IS ? AND user IS ? - LIMIT 1 - } - } - - query! { - async fn insert_ssh_project(host: String, port: Option, paths: String, user: Option) -> Result> { - INSERT INTO ssh_projects( - host, - port, - paths, - user - ) VALUES (?1, ?2, ?3, ?4) - RETURNING id, host, port, paths, user - } - } - - query! { - pub async fn update_ssh_project_paths_query(ssh_project_id: u64, paths: String) -> Result> { - UPDATE ssh_projects - SET paths = ?2 - WHERE id = ?1 - RETURNING id, host, port, paths, user - } - } - - pub(crate) async fn update_ssh_project_paths( - &self, - ssh_project_id: SshProjectId, - new_paths: Vec, - ) -> Result { - let paths = serde_json::to_string(&new_paths)?; - self.update_ssh_project_paths_query(ssh_project_id.0, paths) - .await? - .context("failed to update ssh project paths") - } - query! { pub async fn next_id() -> Result { INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id } } + fn recent_workspaces(&self) -> Result)>> { + Ok(self + .recent_workspaces_query()? + .into_iter() + .map(|(id, paths, order, ssh_connection_id)| { + ( + id, + PathList::deserialize(&SerializedPathList { paths, order }), + ssh_connection_id, + ) + }) + .collect()) + } + query! { - fn recent_workspaces() -> Result)>> { - SELECT workspace_id, local_paths, local_paths_order, ssh_project_id + fn recent_workspaces_query() -> Result)>> { + SELECT workspace_id, paths, paths_order, ssh_connection_id FROM workspaces - WHERE local_paths IS NOT NULL - OR ssh_project_id IS NOT NULL + WHERE + paths IS NOT NULL OR + ssh_connection_id IS NOT NULL ORDER BY timestamp DESC } } + fn session_workspaces( + &self, + session_id: String, + ) -> Result, Option)>> { + Ok(self + .session_workspaces_query(session_id)? + .into_iter() + .map(|(paths, order, window_id, ssh_connection_id)| { + ( + PathList::deserialize(&SerializedPathList { paths, order }), + window_id, + ssh_connection_id.map(SshConnectionId), + ) + }) + .collect()) + } + query! { - fn session_workspaces(session_id: String) -> Result, Option)>> { - SELECT local_paths, local_paths_order, window_id, ssh_project_id + fn session_workspaces_query(session_id: String) -> Result, Option)>> { + SELECT paths, paths_order, window_id, ssh_connection_id FROM workspaces - WHERE session_id = ?1 AND dev_server_project_id IS NULL + WHERE session_id = ?1 ORDER BY timestamp DESC } } @@ -1013,17 +1002,39 @@ impl WorkspaceDb { } } - query! { - fn ssh_projects() -> Result> { - SELECT id, host, port, paths, user - FROM ssh_projects - } + fn ssh_connections(&self) -> Result> { + Ok(self + .ssh_connections_query()? + .into_iter() + .map(|(id, host, port, user)| { + ( + SshConnectionId(id), + SerializedSshConnection { host, port, user }, + ) + }) + .collect()) } query! { - fn ssh_project(id: u64) -> Result { - SELECT id, host, port, paths, user - FROM ssh_projects + pub fn ssh_connections_query() -> Result, Option)>> { + SELECT id, host, port, user + FROM ssh_connections + } + } + + pub(crate) fn ssh_connection(&self, id: SshConnectionId) -> Result { + let row = self.ssh_connection_query(id.0)?; + Ok(SerializedSshConnection { + host: row.0, + port: row.1, + user: row.2, + }) + } + + query! { + fn ssh_connection_query(id: u64) -> Result<(String, Option, Option)> { + SELECT host, port, user + FROM ssh_connections WHERE id = ? } } @@ -1037,7 +1048,7 @@ impl WorkspaceDb { display, window_state, window_x, window_y, window_width, window_height FROM workspaces - WHERE local_paths + WHERE paths IS NOT NULL ORDER BY timestamp DESC LIMIT 1 @@ -1054,46 +1065,33 @@ impl WorkspaceDb { } } - pub async fn delete_workspace_by_dev_server_project_id( - &self, - id: DevServerProjectId, - ) -> Result<()> { - self.write(move |conn| { - conn.exec_bound(sql!( - DELETE FROM dev_server_projects WHERE id = ? - ))?(id.0)?; - conn.exec_bound(sql!( - DELETE FROM toolchains WHERE workspace_id = ?1; - DELETE FROM workspaces - WHERE dev_server_project_id IS ? - ))?(id.0) - }) - .await - } - // Returns the recent locations which are still valid on disk and deletes ones which no longer // exist. pub async fn recent_workspaces_on_disk( &self, - ) -> Result> { + ) -> Result> { let mut result = Vec::new(); let mut delete_tasks = Vec::new(); - let ssh_projects = self.ssh_projects()?; + let ssh_connections = self.ssh_connections()?; - for (id, location, order, ssh_project_id) in self.recent_workspaces()? { - if let Some(ssh_project_id) = ssh_project_id.map(SshProjectId) { - if let Some(ssh_project) = ssh_projects.iter().find(|rp| rp.id == ssh_project_id) { - result.push((id, SerializedWorkspaceLocation::Ssh(ssh_project.clone()))); + for (id, paths, ssh_connection_id) in self.recent_workspaces()? { + if let Some(ssh_connection_id) = ssh_connection_id.map(SshConnectionId) { + if let Some(ssh_connection) = ssh_connections.get(&ssh_connection_id) { + result.push(( + id, + SerializedWorkspaceLocation::Ssh(ssh_connection.clone()), + paths, + )); } else { delete_tasks.push(self.delete_workspace_by_id(id)); } continue; } - if location.paths().iter().all(|path| path.exists()) - && location.paths().iter().any(|path| path.is_dir()) + if paths.paths().iter().all(|path| path.exists()) + && paths.paths().iter().any(|path| path.is_dir()) { - result.push((id, SerializedWorkspaceLocation::Local(location, order))); + result.push((id, SerializedWorkspaceLocation::Local, paths)); } else { delete_tasks.push(self.delete_workspace_by_id(id)); } @@ -1103,13 +1101,13 @@ impl WorkspaceDb { Ok(result) } - pub async fn last_workspace(&self) -> Result> { + pub async fn last_workspace(&self) -> Result> { Ok(self .recent_workspaces_on_disk() .await? .into_iter() .next() - .map(|(_, location)| location)) + .map(|(_, location, paths)| (location, paths))) } // Returns the locations of the workspaces that were still opened when the last @@ -1120,25 +1118,31 @@ impl WorkspaceDb { &self, last_session_id: &str, last_session_window_stack: Option>, - ) -> Result> { + ) -> Result> { let mut workspaces = Vec::new(); - for (location, order, window_id, ssh_project_id) in + for (paths, window_id, ssh_connection_id) in self.session_workspaces(last_session_id.to_owned())? { - if let Some(ssh_project_id) = ssh_project_id { - let location = SerializedWorkspaceLocation::Ssh(self.ssh_project(ssh_project_id)?); - workspaces.push((location, window_id.map(WindowId::from))); - } else if location.paths().iter().all(|path| path.exists()) - && location.paths().iter().any(|path| path.is_dir()) + if let Some(ssh_connection_id) = ssh_connection_id { + workspaces.push(( + SerializedWorkspaceLocation::Ssh(self.ssh_connection(ssh_connection_id)?), + paths, + window_id.map(WindowId::from), + )); + } else if paths.paths().iter().all(|path| path.exists()) + && paths.paths().iter().any(|path| path.is_dir()) { - let location = SerializedWorkspaceLocation::Local(location, order); - workspaces.push((location, window_id.map(WindowId::from))); + workspaces.push(( + SerializedWorkspaceLocation::Local, + paths, + window_id.map(WindowId::from), + )); } } if let Some(stack) = last_session_window_stack { - workspaces.sort_by_key(|(_, window_id)| { + workspaces.sort_by_key(|(_, _, window_id)| { window_id .and_then(|id| stack.iter().position(|&order_id| order_id == id)) .unwrap_or(usize::MAX) @@ -1147,7 +1151,7 @@ impl WorkspaceDb { Ok(workspaces .into_iter() - .map(|(paths, _)| paths) + .map(|(location, paths, _)| (location, paths)) .collect::>()) } @@ -1499,13 +1503,13 @@ pub fn delete_unloaded_items( #[cfg(test)] mod tests { - use std::thread; - use std::time::Duration; - use super::*; - use crate::persistence::model::SerializedWorkspace; - use crate::persistence::model::{SerializedItem, SerializedPane, SerializedPaneGroup}; + use crate::persistence::model::{ + SerializedItem, SerializedPane, SerializedPaneGroup, SerializedWorkspace, + }; use gpui; + use pretty_assertions::assert_eq; + use std::{thread, time::Duration}; #[gpui::test] async fn test_breakpoints() { @@ -1558,7 +1562,8 @@ mod tests { let workspace = SerializedWorkspace { id, - location: SerializedWorkspaceLocation::from_local_paths(["/tmp"]), + paths: PathList::new(&["/tmp"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -1711,7 +1716,8 @@ mod tests { let workspace = SerializedWorkspace { id, - location: SerializedWorkspaceLocation::from_local_paths(["/tmp"]), + paths: PathList::new(&["/tmp"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -1757,7 +1763,8 @@ mod tests { let workspace_without_breakpoint = SerializedWorkspace { id, - location: SerializedWorkspaceLocation::from_local_paths(["/tmp"]), + paths: PathList::new(&["/tmp"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -1851,7 +1858,8 @@ mod tests { let mut workspace_1 = SerializedWorkspace { id: WorkspaceId(1), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp", "/tmp2"]), + paths: PathList::new(&["/tmp", "/tmp2"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -1864,7 +1872,8 @@ mod tests { let workspace_2 = SerializedWorkspace { id: WorkspaceId(2), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp"]), + paths: PathList::new(&["/tmp"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -1893,7 +1902,7 @@ mod tests { }) .await; - workspace_1.location = SerializedWorkspaceLocation::from_local_paths(["/tmp", "/tmp3"]); + workspace_1.paths = PathList::new(&["/tmp", "/tmp3"]); db.save_workspace(workspace_1.clone()).await; db.save_workspace(workspace_1).await; db.save_workspace(workspace_2).await; @@ -1969,10 +1978,8 @@ mod tests { let workspace = SerializedWorkspace { id: WorkspaceId(5), - location: SerializedWorkspaceLocation::Local( - LocalPaths::new(["/tmp", "/tmp2"]), - LocalPathsOrder::new([1, 0]), - ), + paths: PathList::new(&["/tmp", "/tmp2"]), + location: SerializedWorkspaceLocation::Local, center_group, window_bounds: Default::default(), breakpoints: Default::default(), @@ -2004,10 +2011,8 @@ mod tests { let workspace_1 = SerializedWorkspace { id: WorkspaceId(1), - location: SerializedWorkspaceLocation::Local( - LocalPaths::new(["/tmp", "/tmp2"]), - LocalPathsOrder::new([0, 1]), - ), + paths: PathList::new(&["/tmp", "/tmp2"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), breakpoints: Default::default(), @@ -2020,7 +2025,8 @@ mod tests { let mut workspace_2 = SerializedWorkspace { id: WorkspaceId(2), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp"]), + paths: PathList::new(&["/tmp"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2049,7 +2055,7 @@ mod tests { assert_eq!(db.workspace_for_roots(&["/tmp3", "/tmp2", "/tmp4"]), None); // Test 'mutate' case of updating a pre-existing id - workspace_2.location = SerializedWorkspaceLocation::from_local_paths(["/tmp", "/tmp2"]); + workspace_2.paths = PathList::new(&["/tmp", "/tmp2"]); db.save_workspace(workspace_2.clone()).await; assert_eq!( @@ -2060,10 +2066,8 @@ mod tests { // Test other mechanism for mutating let mut workspace_3 = SerializedWorkspace { id: WorkspaceId(3), - location: SerializedWorkspaceLocation::Local( - LocalPaths::new(["/tmp", "/tmp2"]), - LocalPathsOrder::new([1, 0]), - ), + paths: PathList::new(&["/tmp2", "/tmp"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), breakpoints: Default::default(), @@ -2081,8 +2085,7 @@ mod tests { ); // Make sure that updating paths differently also works - workspace_3.location = - SerializedWorkspaceLocation::from_local_paths(["/tmp3", "/tmp4", "/tmp2"]); + workspace_3.paths = PathList::new(&["/tmp3", "/tmp4", "/tmp2"]); db.save_workspace(workspace_3.clone()).await; assert_eq!(db.workspace_for_roots(&["/tmp2", "tmp"]), None); assert_eq!( @@ -2100,7 +2103,8 @@ mod tests { let workspace_1 = SerializedWorkspace { id: WorkspaceId(1), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp1"]), + paths: PathList::new(&["/tmp1"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2113,7 +2117,8 @@ mod tests { let workspace_2 = SerializedWorkspace { id: WorkspaceId(2), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp2"]), + paths: PathList::new(&["/tmp2"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2126,7 +2131,8 @@ mod tests { let workspace_3 = SerializedWorkspace { id: WorkspaceId(3), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp3"]), + paths: PathList::new(&["/tmp3"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2139,7 +2145,8 @@ mod tests { let workspace_4 = SerializedWorkspace { id: WorkspaceId(4), - location: SerializedWorkspaceLocation::from_local_paths(["/tmp4"]), + paths: PathList::new(&["/tmp4"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2150,14 +2157,15 @@ mod tests { window_id: None, }; - let ssh_project = db - .get_or_create_ssh_project("my-host".to_string(), Some(1234), vec![], None) + let connection_id = db + .get_or_create_ssh_connection("my-host".to_string(), Some(1234), None) .await .unwrap(); let workspace_5 = SerializedWorkspace { id: WorkspaceId(5), - location: SerializedWorkspaceLocation::Ssh(ssh_project.clone()), + paths: PathList::default(), + location: SerializedWorkspaceLocation::Ssh(db.ssh_connection(connection_id).unwrap()), center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2170,10 +2178,8 @@ mod tests { let workspace_6 = SerializedWorkspace { id: WorkspaceId(6), - location: SerializedWorkspaceLocation::Local( - LocalPaths::new(["/tmp6a", "/tmp6b", "/tmp6c"]), - LocalPathsOrder::new([2, 1, 0]), - ), + paths: PathList::new(&["/tmp6a", "/tmp6b", "/tmp6c"]), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), breakpoints: Default::default(), @@ -2195,41 +2201,36 @@ mod tests { let locations = db.session_workspaces("session-id-1".to_owned()).unwrap(); assert_eq!(locations.len(), 2); - assert_eq!(locations[0].0, LocalPaths::new(["/tmp2"])); - assert_eq!(locations[0].1, LocalPathsOrder::new([0])); - assert_eq!(locations[0].2, Some(20)); - assert_eq!(locations[1].0, LocalPaths::new(["/tmp1"])); - assert_eq!(locations[1].1, LocalPathsOrder::new([0])); - assert_eq!(locations[1].2, Some(10)); + assert_eq!(locations[0].0, PathList::new(&["/tmp2"])); + assert_eq!(locations[0].1, Some(20)); + assert_eq!(locations[1].0, PathList::new(&["/tmp1"])); + assert_eq!(locations[1].1, Some(10)); let locations = db.session_workspaces("session-id-2".to_owned()).unwrap(); assert_eq!(locations.len(), 2); - let empty_paths: Vec<&str> = Vec::new(); - assert_eq!(locations[0].0, LocalPaths::new(empty_paths.iter())); - assert_eq!(locations[0].1, LocalPathsOrder::new([])); - assert_eq!(locations[0].2, Some(50)); - assert_eq!(locations[0].3, Some(ssh_project.id.0)); - assert_eq!(locations[1].0, LocalPaths::new(["/tmp3"])); - assert_eq!(locations[1].1, LocalPathsOrder::new([0])); - assert_eq!(locations[1].2, Some(30)); + assert_eq!(locations[0].0, PathList::default()); + assert_eq!(locations[0].1, Some(50)); + assert_eq!(locations[0].2, Some(connection_id)); + assert_eq!(locations[1].0, PathList::new(&["/tmp3"])); + assert_eq!(locations[1].1, Some(30)); let locations = db.session_workspaces("session-id-3".to_owned()).unwrap(); assert_eq!(locations.len(), 1); assert_eq!( locations[0].0, - LocalPaths::new(["/tmp6a", "/tmp6b", "/tmp6c"]), + PathList::new(&["/tmp6a", "/tmp6b", "/tmp6c"]), ); - assert_eq!(locations[0].1, LocalPathsOrder::new([2, 1, 0])); - assert_eq!(locations[0].2, Some(60)); + assert_eq!(locations[0].1, Some(60)); } fn default_workspace>( - workspace_id: &[P], + paths: &[P], center_group: &SerializedPaneGroup, ) -> SerializedWorkspace { SerializedWorkspace { id: WorkspaceId(4), - location: SerializedWorkspaceLocation::from_local_paths(workspace_id), + paths: PathList::new(paths), + location: SerializedWorkspaceLocation::Local, center_group: center_group.clone(), window_bounds: Default::default(), display: Default::default(), @@ -2252,30 +2253,18 @@ mod tests { WorkspaceDb::open_test_db("test_serializing_workspaces_last_session_workspaces").await; let workspaces = [ - (1, vec![dir1.path()], vec![0], 9), - (2, vec![dir2.path()], vec![0], 5), - (3, vec![dir3.path()], vec![0], 8), - (4, vec![dir4.path()], vec![0], 2), - ( - 5, - vec![dir1.path(), dir2.path(), dir3.path()], - vec![0, 1, 2], - 3, - ), - ( - 6, - vec![dir2.path(), dir3.path(), dir4.path()], - vec![2, 1, 0], - 4, - ), + (1, vec![dir1.path()], 9), + (2, vec![dir2.path()], 5), + (3, vec![dir3.path()], 8), + (4, vec![dir4.path()], 2), + (5, vec![dir1.path(), dir2.path(), dir3.path()], 3), + (6, vec![dir4.path(), dir3.path(), dir2.path()], 4), ] .into_iter() - .map(|(id, locations, order, window_id)| SerializedWorkspace { + .map(|(id, paths, window_id)| SerializedWorkspace { id: WorkspaceId(id), - location: SerializedWorkspaceLocation::Local( - LocalPaths::new(locations), - LocalPathsOrder::new(order), - ), + paths: PathList::new(paths.as_slice()), + location: SerializedWorkspaceLocation::Local, center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2300,39 +2289,37 @@ mod tests { WindowId::from(4), // Bottom ])); - let have = db + let locations = db .last_session_workspace_locations("one-session", stack) .unwrap(); - assert_eq!(have.len(), 6); assert_eq!( - have[0], - SerializedWorkspaceLocation::from_local_paths(&[dir4.path()]) - ); - assert_eq!( - have[1], - SerializedWorkspaceLocation::from_local_paths([dir3.path()]) - ); - assert_eq!( - have[2], - SerializedWorkspaceLocation::from_local_paths([dir2.path()]) - ); - assert_eq!( - have[3], - SerializedWorkspaceLocation::from_local_paths([dir1.path()]) - ); - assert_eq!( - have[4], - SerializedWorkspaceLocation::Local( - LocalPaths::new([dir1.path(), dir2.path(), dir3.path()]), - LocalPathsOrder::new([0, 1, 2]), - ), - ); - assert_eq!( - have[5], - SerializedWorkspaceLocation::Local( - LocalPaths::new([dir2.path(), dir3.path(), dir4.path()]), - LocalPathsOrder::new([2, 1, 0]), - ), + locations, + [ + ( + SerializedWorkspaceLocation::Local, + PathList::new(&[dir4.path()]) + ), + ( + SerializedWorkspaceLocation::Local, + PathList::new(&[dir3.path()]) + ), + ( + SerializedWorkspaceLocation::Local, + PathList::new(&[dir2.path()]) + ), + ( + SerializedWorkspaceLocation::Local, + PathList::new(&[dir1.path()]) + ), + ( + SerializedWorkspaceLocation::Local, + PathList::new(&[dir1.path(), dir2.path(), dir3.path()]) + ), + ( + SerializedWorkspaceLocation::Local, + PathList::new(&[dir4.path(), dir3.path(), dir2.path()]) + ), + ] ); } @@ -2343,7 +2330,7 @@ mod tests { ) .await; - let ssh_projects = [ + let ssh_connections = [ ("host-1", "my-user-1"), ("host-2", "my-user-2"), ("host-3", "my-user-3"), @@ -2351,24 +2338,30 @@ mod tests { ] .into_iter() .map(|(host, user)| async { - db.get_or_create_ssh_project(host.to_string(), None, vec![], Some(user.to_string())) + db.get_or_create_ssh_connection(host.to_string(), None, Some(user.to_string())) .await - .unwrap() + .unwrap(); + SerializedSshConnection { + host: host.into(), + port: None, + user: Some(user.into()), + } }) .collect::>(); - let ssh_projects = futures::future::join_all(ssh_projects).await; + let ssh_connections = futures::future::join_all(ssh_connections).await; let workspaces = [ - (1, ssh_projects[0].clone(), 9), - (2, ssh_projects[1].clone(), 5), - (3, ssh_projects[2].clone(), 8), - (4, ssh_projects[3].clone(), 2), + (1, ssh_connections[0].clone(), 9), + (2, ssh_connections[1].clone(), 5), + (3, ssh_connections[2].clone(), 8), + (4, ssh_connections[3].clone(), 2), ] .into_iter() - .map(|(id, ssh_project, window_id)| SerializedWorkspace { + .map(|(id, ssh_connection, window_id)| SerializedWorkspace { id: WorkspaceId(id), - location: SerializedWorkspaceLocation::Ssh(ssh_project), + paths: PathList::default(), + location: SerializedWorkspaceLocation::Ssh(ssh_connection), center_group: Default::default(), window_bounds: Default::default(), display: Default::default(), @@ -2397,19 +2390,31 @@ mod tests { assert_eq!(have.len(), 4); assert_eq!( have[0], - SerializedWorkspaceLocation::Ssh(ssh_projects[3].clone()) + ( + SerializedWorkspaceLocation::Ssh(ssh_connections[3].clone()), + PathList::default() + ) ); assert_eq!( have[1], - SerializedWorkspaceLocation::Ssh(ssh_projects[2].clone()) + ( + SerializedWorkspaceLocation::Ssh(ssh_connections[2].clone()), + PathList::default() + ) ); assert_eq!( have[2], - SerializedWorkspaceLocation::Ssh(ssh_projects[1].clone()) + ( + SerializedWorkspaceLocation::Ssh(ssh_connections[1].clone()), + PathList::default() + ) ); assert_eq!( have[3], - SerializedWorkspaceLocation::Ssh(ssh_projects[0].clone()) + ( + SerializedWorkspaceLocation::Ssh(ssh_connections[0].clone()), + PathList::default() + ) ); } @@ -2417,116 +2422,110 @@ mod tests { async fn test_get_or_create_ssh_project() { let db = WorkspaceDb::open_test_db("test_get_or_create_ssh_project").await; - let (host, port, paths, user) = ( - "example.com".to_string(), - Some(22_u16), - vec!["/home/user".to_string(), "/etc/nginx".to_string()], - Some("user".to_string()), - ); + let host = "example.com".to_string(); + let port = Some(22_u16); + let user = Some("user".to_string()); - let project = db - .get_or_create_ssh_project(host.clone(), port, paths.clone(), user.clone()) + let connection_id = db + .get_or_create_ssh_connection(host.clone(), port, user.clone()) .await .unwrap(); - assert_eq!(project.host, host); - assert_eq!(project.paths, paths); - assert_eq!(project.user, user); - // Test that calling the function again with the same parameters returns the same project - let same_project = db - .get_or_create_ssh_project(host.clone(), port, paths.clone(), user.clone()) + let same_connection = db + .get_or_create_ssh_connection(host.clone(), port, user.clone()) .await .unwrap(); - assert_eq!(project.id, same_project.id); + assert_eq!(connection_id, same_connection); // Test with different parameters - let (host2, paths2, user2) = ( - "otherexample.com".to_string(), - vec!["/home/otheruser".to_string()], - Some("otheruser".to_string()), - ); + let host2 = "otherexample.com".to_string(); + let port2 = None; + let user2 = Some("otheruser".to_string()); - let different_project = db - .get_or_create_ssh_project(host2.clone(), None, paths2.clone(), user2.clone()) + let different_connection = db + .get_or_create_ssh_connection(host2.clone(), port2, user2.clone()) .await .unwrap(); - assert_ne!(project.id, different_project.id); - assert_eq!(different_project.host, host2); - assert_eq!(different_project.paths, paths2); - assert_eq!(different_project.user, user2); + assert_ne!(connection_id, different_connection); } #[gpui::test] async fn test_get_or_create_ssh_project_with_null_user() { let db = WorkspaceDb::open_test_db("test_get_or_create_ssh_project_with_null_user").await; - let (host, port, paths, user) = ( - "example.com".to_string(), - None, - vec!["/home/user".to_string()], - None, - ); + let (host, port, user) = ("example.com".to_string(), None, None); - let project = db - .get_or_create_ssh_project(host.clone(), port, paths.clone(), None) + let connection_id = db + .get_or_create_ssh_connection(host.clone(), port, None) .await .unwrap(); - assert_eq!(project.host, host); - assert_eq!(project.paths, paths); - assert_eq!(project.user, None); - - // Test that calling the function again with the same parameters returns the same project - let same_project = db - .get_or_create_ssh_project(host.clone(), port, paths.clone(), user.clone()) + let same_connection_id = db + .get_or_create_ssh_connection(host.clone(), port, user.clone()) .await .unwrap(); - assert_eq!(project.id, same_project.id); + assert_eq!(connection_id, same_connection_id); } #[gpui::test] - async fn test_get_ssh_projects() { - let db = WorkspaceDb::open_test_db("test_get_ssh_projects").await; + async fn test_get_ssh_connections() { + let db = WorkspaceDb::open_test_db("test_get_ssh_connections").await; - let projects = vec![ - ( - "example.com".to_string(), - None, - vec!["/home/user".to_string()], - None, - ), + let connections = [ + ("example.com".to_string(), None, None), ( "anotherexample.com".to_string(), Some(123_u16), - vec!["/home/user2".to_string()], Some("user2".to_string()), ), - ( - "yetanother.com".to_string(), - Some(345_u16), - vec!["/home/user3".to_string(), "/proc/1234/exe".to_string()], - None, - ), + ("yetanother.com".to_string(), Some(345_u16), None), ]; - for (host, port, paths, user) in projects.iter() { - let project = db - .get_or_create_ssh_project(host.clone(), *port, paths.clone(), user.clone()) - .await - .unwrap(); - - assert_eq!(&project.host, host); - assert_eq!(&project.port, port); - assert_eq!(&project.paths, paths); - assert_eq!(&project.user, user); + let mut ids = Vec::new(); + for (host, port, user) in connections.iter() { + ids.push( + db.get_or_create_ssh_connection(host.clone(), *port, user.clone()) + .await + .unwrap(), + ); } - let stored_projects = db.ssh_projects().unwrap(); - assert_eq!(stored_projects.len(), projects.len()); + let stored_projects = db.ssh_connections().unwrap(); + assert_eq!( + stored_projects, + [ + ( + ids[0], + SerializedSshConnection { + host: "example.com".into(), + port: None, + user: None, + } + ), + ( + ids[1], + SerializedSshConnection { + host: "anotherexample.com".into(), + port: Some(123), + user: Some("user2".into()), + } + ), + ( + ids[2], + SerializedSshConnection { + host: "yetanother.com".into(), + port: Some(345), + user: None, + } + ), + ] + .into_iter() + .collect::>(), + ); } #[gpui::test] @@ -2659,56 +2658,4 @@ mod tests { assert_eq!(workspace.center_group, new_workspace.center_group); } - - #[gpui::test] - async fn test_update_ssh_project_paths() { - zlog::init_test(); - - let db = WorkspaceDb::open_test_db("test_update_ssh_project_paths").await; - - let (host, port, initial_paths, user) = ( - "example.com".to_string(), - Some(22_u16), - vec!["/home/user".to_string(), "/etc/nginx".to_string()], - Some("user".to_string()), - ); - - let project = db - .get_or_create_ssh_project(host.clone(), port, initial_paths.clone(), user.clone()) - .await - .unwrap(); - - assert_eq!(project.host, host); - assert_eq!(project.paths, initial_paths); - assert_eq!(project.user, user); - - let new_paths = vec![ - "/home/user".to_string(), - "/etc/nginx".to_string(), - "/var/log".to_string(), - "/opt/app".to_string(), - ]; - - let updated_project = db - .update_ssh_project_paths(project.id, new_paths.clone()) - .await - .unwrap(); - - assert_eq!(updated_project.id, project.id); - assert_eq!(updated_project.paths, new_paths); - - let retrieved_project = db - .get_ssh_project( - host.clone(), - port, - serde_json::to_string(&new_paths).unwrap(), - user.clone(), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!(retrieved_project.id, project.id); - assert_eq!(retrieved_project.paths, new_paths); - } } diff --git a/crates/workspace/src/persistence/model.rs b/crates/workspace/src/persistence/model.rs index 15a54ac62f..04757d0495 100644 --- a/crates/workspace/src/persistence/model.rs +++ b/crates/workspace/src/persistence/model.rs @@ -1,256 +1,48 @@ use super::{SerializedAxis, SerializedWindowBounds}; use crate::{ Member, Pane, PaneAxis, SerializableItemRegistry, Workspace, WorkspaceId, item::ItemHandle, + path_list::PathList, }; -use anyhow::{Context as _, Result}; +use anyhow::Result; use async_recursion::async_recursion; use db::sqlez::{ bindable::{Bind, Column, StaticColumnCount}, statement::Statement, }; use gpui::{AsyncWindowContext, Entity, WeakEntity}; -use itertools::Itertools as _; + use project::{Project, debugger::breakpoint_store::SourceBreakpoint}; -use remote::ssh_session::SshProjectId; use serde::{Deserialize, Serialize}; use std::{ collections::BTreeMap, path::{Path, PathBuf}, sync::Arc, }; -use util::{ResultExt, paths::SanitizedPath}; +use util::ResultExt; use uuid::Uuid; +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, +)] +pub(crate) struct SshConnectionId(pub u64); + #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] -pub struct SerializedSshProject { - pub id: SshProjectId, +pub struct SerializedSshConnection { pub host: String, pub port: Option, - pub paths: Vec, pub user: Option, } -impl SerializedSshProject { - pub fn ssh_urls(&self) -> Vec { - self.paths - .iter() - .map(|path| { - let mut result = String::new(); - if let Some(user) = &self.user { - result.push_str(user); - result.push('@'); - } - result.push_str(&self.host); - if let Some(port) = &self.port { - result.push(':'); - result.push_str(&port.to_string()); - } - result.push_str(path); - PathBuf::from(result) - }) - .collect() - } -} - -impl StaticColumnCount for SerializedSshProject { - fn column_count() -> usize { - 5 - } -} - -impl Bind for &SerializedSshProject { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - let next_index = statement.bind(&self.id.0, start_index)?; - let next_index = statement.bind(&self.host, next_index)?; - let next_index = statement.bind(&self.port, next_index)?; - let raw_paths = serde_json::to_string(&self.paths)?; - let next_index = statement.bind(&raw_paths, next_index)?; - statement.bind(&self.user, next_index) - } -} - -impl Column for SerializedSshProject { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let id = statement.column_int64(start_index)?; - let host = statement.column_text(start_index + 1)?.to_string(); - let (port, _) = Option::::column(statement, start_index + 2)?; - let raw_paths = statement.column_text(start_index + 3)?.to_string(); - let paths: Vec = serde_json::from_str(&raw_paths)?; - - let (user, _) = Option::::column(statement, start_index + 4)?; - - Ok(( - Self { - id: SshProjectId(id as u64), - host, - port, - paths, - user, - }, - start_index + 5, - )) - } -} - -#[derive(Debug, PartialEq, Clone)] -pub struct LocalPaths(Arc>); - -impl LocalPaths { - pub fn new>(paths: impl IntoIterator) -> Self { - let mut paths: Vec = paths - .into_iter() - .map(|p| SanitizedPath::from(p).into()) - .collect(); - // Ensure all future `zed workspace1 workspace2` and `zed workspace2 workspace1` calls are using the same workspace. - // The actual workspace order is stored in the `LocalPathsOrder` struct. - paths.sort(); - Self(Arc::new(paths)) - } - - pub fn paths(&self) -> &Arc> { - &self.0 - } -} - -impl StaticColumnCount for LocalPaths {} -impl Bind for &LocalPaths { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind(&bincode::serialize(&self.0)?, start_index) - } -} - -impl Column for LocalPaths { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let path_blob = statement.column_blob(start_index)?; - let paths: Arc> = if path_blob.is_empty() { - Default::default() - } else { - bincode::deserialize(path_blob).context("Bincode deserialization of paths failed")? - }; - - Ok((Self(paths), start_index + 1)) - } -} - -#[derive(Debug, PartialEq, Clone)] -pub struct LocalPathsOrder(Vec); - -impl LocalPathsOrder { - pub fn new(order: impl IntoIterator) -> Self { - Self(order.into_iter().collect()) - } - - pub fn order(&self) -> &[usize] { - self.0.as_slice() - } - - pub fn default_for_paths(paths: &LocalPaths) -> Self { - Self::new(0..paths.0.len()) - } -} - -impl StaticColumnCount for LocalPathsOrder {} -impl Bind for &LocalPathsOrder { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind(&bincode::serialize(&self.0)?, start_index) - } -} - -impl Column for LocalPathsOrder { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let order_blob = statement.column_blob(start_index)?; - let order = if order_blob.is_empty() { - Vec::new() - } else { - bincode::deserialize(order_blob).context("deserializing workspace root order")? - }; - - Ok((Self(order), start_index + 1)) - } -} - #[derive(Debug, PartialEq, Clone)] pub enum SerializedWorkspaceLocation { - Local(LocalPaths, LocalPathsOrder), - Ssh(SerializedSshProject), + Local, + Ssh(SerializedSshConnection), } impl SerializedWorkspaceLocation { - /// Create a new `SerializedWorkspaceLocation` from a list of local paths. - /// - /// The paths will be sorted and the order will be stored in the `LocalPathsOrder` struct. - /// - /// # Examples - /// - /// ``` - /// use std::path::Path; - /// use zed_workspace::SerializedWorkspaceLocation; - /// - /// let location = SerializedWorkspaceLocation::from_local_paths(vec![ - /// Path::new("path/to/workspace1"), - /// Path::new("path/to/workspace2"), - /// ]); - /// assert_eq!(location, SerializedWorkspaceLocation::Local( - /// LocalPaths::new(vec![ - /// Path::new("path/to/workspace1"), - /// Path::new("path/to/workspace2"), - /// ]), - /// LocalPathsOrder::new(vec![0, 1]), - /// )); - /// ``` - /// - /// ``` - /// use std::path::Path; - /// use zed_workspace::SerializedWorkspaceLocation; - /// - /// let location = SerializedWorkspaceLocation::from_local_paths(vec![ - /// Path::new("path/to/workspace2"), - /// Path::new("path/to/workspace1"), - /// ]); - /// - /// assert_eq!(location, SerializedWorkspaceLocation::Local( - /// LocalPaths::new(vec![ - /// Path::new("path/to/workspace1"), - /// Path::new("path/to/workspace2"), - /// ]), - /// LocalPathsOrder::new(vec![1, 0]), - /// )); - /// ``` - pub fn from_local_paths>(paths: impl IntoIterator) -> Self { - let mut indexed_paths: Vec<_> = paths - .into_iter() - .map(|p| p.as_ref().to_path_buf()) - .enumerate() - .collect(); - - indexed_paths.sort_by(|(_, a), (_, b)| a.cmp(b)); - - let sorted_paths: Vec<_> = indexed_paths.iter().map(|(_, path)| path.clone()).collect(); - let order: Vec<_> = indexed_paths.iter().map(|(index, _)| *index).collect(); - - Self::Local(LocalPaths::new(sorted_paths), LocalPathsOrder::new(order)) - } - /// Get sorted paths pub fn sorted_paths(&self) -> Arc> { - match self { - SerializedWorkspaceLocation::Local(paths, order) => { - if order.order().is_empty() { - paths.paths().clone() - } else { - Arc::new( - order - .order() - .iter() - .zip(paths.paths().iter()) - .sorted_by_key(|(i, _)| **i) - .map(|(_, p)| p.clone()) - .collect(), - ) - } - } - SerializedWorkspaceLocation::Ssh(ssh_project) => Arc::new(ssh_project.ssh_urls()), - } + unimplemented!() } } @@ -258,6 +50,7 @@ impl SerializedWorkspaceLocation { pub(crate) struct SerializedWorkspace { pub(crate) id: WorkspaceId, pub(crate) location: SerializedWorkspaceLocation, + pub(crate) paths: PathList, pub(crate) center_group: SerializedPaneGroup, pub(crate) window_bounds: Option, pub(crate) centered_layout: bool, @@ -581,80 +374,3 @@ impl Column for SerializedItem { )) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_serialize_local_paths() { - let paths = vec!["b", "a", "c"]; - let serialized = SerializedWorkspaceLocation::from_local_paths(paths); - - assert_eq!( - serialized, - SerializedWorkspaceLocation::Local( - LocalPaths::new(vec!["a", "b", "c"]), - LocalPathsOrder::new(vec![1, 0, 2]) - ) - ); - } - - #[test] - fn test_sorted_paths() { - let paths = vec!["b", "a", "c"]; - let serialized = SerializedWorkspaceLocation::from_local_paths(paths); - assert_eq!( - serialized.sorted_paths(), - Arc::new(vec![ - PathBuf::from("b"), - PathBuf::from("a"), - PathBuf::from("c"), - ]) - ); - - let paths = Arc::new(vec![ - PathBuf::from("a"), - PathBuf::from("b"), - PathBuf::from("c"), - ]); - let order = vec![2, 0, 1]; - let serialized = - SerializedWorkspaceLocation::Local(LocalPaths(paths), LocalPathsOrder(order)); - assert_eq!( - serialized.sorted_paths(), - Arc::new(vec![ - PathBuf::from("b"), - PathBuf::from("c"), - PathBuf::from("a"), - ]) - ); - - let paths = Arc::new(vec![ - PathBuf::from("a"), - PathBuf::from("b"), - PathBuf::from("c"), - ]); - let order = vec![]; - let serialized = - SerializedWorkspaceLocation::Local(LocalPaths(paths.clone()), LocalPathsOrder(order)); - assert_eq!(serialized.sorted_paths(), paths); - - let urls = ["/a", "/b", "/c"]; - let serialized = SerializedWorkspaceLocation::Ssh(SerializedSshProject { - id: SshProjectId(0), - host: "host".to_string(), - port: Some(22), - paths: urls.iter().map(|s| s.to_string()).collect(), - user: Some("user".to_string()), - }); - assert_eq!( - serialized.sorted_paths(), - Arc::new( - urls.iter() - .map(|p| PathBuf::from(format!("user@host:22{}", p))) - .collect() - ) - ); - } -} diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 499e4f4619..bf58786d67 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1,10 +1,12 @@ pub mod dock; pub mod history_manager; +pub mod invalid_buffer_view; pub mod item; mod modal_layer; pub mod notifications; pub mod pane; pub mod pane_group; +mod path_list; mod persistence; pub mod searchable; pub mod shared_screen; @@ -17,6 +19,7 @@ mod workspace_settings; pub use crate::notifications::NotificationFrame; pub use dock::Panel; +pub use path_list::PathList; pub use toast_layer::{ToastAction, ToastLayer, ToastView}; use anyhow::{Context as _, Result, anyhow}; @@ -61,13 +64,10 @@ use notifications::{ }; pub use pane::*; pub use pane_group::*; -use persistence::{ - DB, SerializedWindowBounds, - model::{SerializedSshProject, SerializedWorkspace}, -}; +use persistence::{DB, SerializedWindowBounds, model::SerializedWorkspace}; pub use persistence::{ DB as WORKSPACE_DB, WorkspaceDb, delete_unloaded_items, - model::{ItemId, LocalPaths, SerializedWorkspaceLocation}, + model::{ItemId, SerializedSshConnection, SerializedWorkspaceLocation}, }; use postage::stream::Stream; use project::{ @@ -612,21 +612,49 @@ impl ProjectItemRegistry { ); self.build_project_item_for_path_fns .push(|project, project_path, window, cx| { + let project_path = project_path.clone(); + let abs_path = project.read(cx).absolute_path(&project_path, cx); + let is_local = project.read(cx).is_local(); let project_item = - ::try_open(project, project_path, cx)?; + ::try_open(project, &project_path, cx)?; let project = project.clone(); - Some(window.spawn(cx, async move |cx| { - let project_item = project_item.await?; - let project_entry_id: Option = - project_item.read_with(cx, project::ProjectItem::entry_id)?; - let build_workspace_item = Box::new( - |pane: &mut Pane, window: &mut Window, cx: &mut Context| { - Box::new(cx.new(|cx| { - T::for_project_item(project, Some(pane), project_item, window, cx) - })) as Box + Some(window.spawn(cx, async move |cx| match project_item.await { + Ok(project_item) => { + let project_item = project_item; + let project_entry_id: Option = + project_item.read_with(cx, project::ProjectItem::entry_id)?; + let build_workspace_item = Box::new( + |pane: &mut Pane, window: &mut Window, cx: &mut Context| { + Box::new(cx.new(|cx| { + T::for_project_item( + project, + Some(pane), + project_item, + window, + cx, + ) + })) as Box + }, + ) as Box<_>; + Ok((project_entry_id, build_workspace_item)) + } + Err(e) => match abs_path { + Some(abs_path) => match cx.update(|window, cx| { + T::for_broken_project_item(abs_path, is_local, &e, window, cx) + })? { + Some(broken_project_item_view) => { + let build_workspace_item = Box::new( + move |_: &mut Pane, _: &mut Window, cx: &mut Context| { + cx.new(|_| broken_project_item_view).boxed_clone() + }, + ) + as Box<_>; + Ok((None, build_workspace_item)) + } + None => Err(e)?, }, - ) as Box<_>; - Ok((project_entry_id, build_workspace_item)) + None => Err(e)?, + }, })) }); } @@ -1013,7 +1041,7 @@ pub enum OpenVisible { enum WorkspaceLocation { // Valid local paths or SSH project to serialize - Location(SerializedWorkspaceLocation), + Location(SerializedWorkspaceLocation, PathList), // No valid location found hence clear session id DetachFromSession, // No valid location found to serialize @@ -1097,7 +1125,6 @@ pub struct Workspace { terminal_provider: Option>, debugger_provider: Option>, serializable_items_tx: UnboundedSender>, - serialized_ssh_project: Option, _items_serializer: Task>, session_id: Option, scheduled_tasks: Vec>, @@ -1146,8 +1173,6 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); - this.update_ssh_paths(cx); - this.serialize_ssh_paths(window, cx); this.serialize_workspace(window, cx); // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. this.update_history(cx); @@ -1432,7 +1457,7 @@ impl Workspace { serializable_items_tx, _items_serializer, session_id: Some(session_id), - serialized_ssh_project: None, + scheduled_tasks: Vec::new(), } } @@ -1472,20 +1497,9 @@ impl Workspace { let serialized_workspace = persistence::DB.workspace_for_roots(paths_to_open.as_slice()); - let workspace_location = serialized_workspace - .as_ref() - .map(|ws| &ws.location) - .and_then(|loc| match loc { - SerializedWorkspaceLocation::Local(_, order) => { - Some((loc.sorted_paths(), order.order())) - } - _ => None, - }); - - if let Some((paths, order)) = workspace_location { - paths_to_open = paths.iter().cloned().collect(); - - if order.iter().enumerate().any(|(i, &j)| i != j) { + if let Some(paths) = serialized_workspace.as_ref().map(|ws| &ws.paths) { + paths_to_open = paths.paths().to_vec(); + if !paths.is_lexicographically_ordered() { project_handle .update(cx, |project, cx| { project.set_worktrees_reordered(true, cx); @@ -2005,14 +2019,6 @@ impl Workspace { self.debugger_provider.clone() } - pub fn serialized_ssh_project(&self) -> Option { - self.serialized_ssh_project.clone() - } - - pub fn set_serialized_ssh_project(&mut self, serialized_ssh_project: SerializedSshProject) { - self.serialized_ssh_project = Some(serialized_ssh_project); - } - pub fn prompt_for_open_path( &mut self, path_prompt_options: PathPromptOptions, @@ -2249,27 +2255,43 @@ impl Workspace { })?; if let Some(active_call) = active_call - && close_intent != CloseIntent::Quit && workspace_count == 1 && active_call.read_with(cx, |call, _| call.room().is_some())? { - let answer = cx.update(|window, cx| { - window.prompt( - PromptLevel::Warning, - "Do you want to leave the current call?", - None, - &["Close window and hang up", "Cancel"], - cx, - ) - })?; + if close_intent == CloseIntent::CloseWindow { + let answer = cx.update(|window, cx| { + window.prompt( + PromptLevel::Warning, + "Do you want to leave the current call?", + None, + &["Close window and hang up", "Cancel"], + cx, + ) + })?; - if answer.await.log_err() == Some(1) { - return anyhow::Ok(false); - } else { - active_call - .update(cx, |call, cx| call.hang_up(cx))? - .await - .log_err(); + if answer.await.log_err() == Some(1) { + return anyhow::Ok(false); + } else { + active_call + .update(cx, |call, cx| call.hang_up(cx))? + .await + .log_err(); + } + } + if close_intent == CloseIntent::ReplaceWindow { + _ = active_call.update(cx, |this, cx| { + let workspace = cx + .windows() + .iter() + .filter_map(|window| window.downcast::()) + .next() + .unwrap(); + let project = workspace.read(cx)?.project.clone(); + if project.read(cx).is_shared() { + this.unshare_project(project, cx)?; + } + Ok::<_, anyhow::Error>(()) + })?; } } @@ -3363,9 +3385,8 @@ impl Workspace { window: &mut Window, cx: &mut App, ) -> Task, WorkspaceItemBuilder)>> { - let project = self.project().clone(); let registry = cx.default_global::().clone(); - registry.open_path(&project, &path, window, cx) + registry.open_path(self.project(), &path, window, cx) } pub fn find_project_item( @@ -5044,59 +5065,12 @@ impl Workspace { self.session_id.clone() } - fn local_paths(&self, cx: &App) -> Option>> { + pub fn root_paths(&self, cx: &App) -> Vec> { let project = self.project().read(cx); - - if project.is_local() { - Some( - project - .visible_worktrees(cx) - .map(|worktree| worktree.read(cx).abs_path()) - .collect::>(), - ) - } else { - None - } - } - - fn update_ssh_paths(&mut self, cx: &App) { - let project = self.project().read(cx); - if !project.is_local() { - let paths: Vec = project - .visible_worktrees(cx) - .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string()) - .collect(); - if let Some(ssh_project) = &mut self.serialized_ssh_project { - ssh_project.paths = paths; - } - } - } - - fn serialize_ssh_paths(&mut self, window: &mut Window, cx: &mut Context) { - if self._schedule_serialize_ssh_paths.is_none() { - self._schedule_serialize_ssh_paths = - Some(cx.spawn_in(window, async move |this, cx| { - cx.background_executor() - .timer(SERIALIZATION_THROTTLE_TIME) - .await; - this.update_in(cx, |this, window, cx| { - let task = if let Some(ssh_project) = &this.serialized_ssh_project { - let ssh_project_id = ssh_project.id; - let ssh_project_paths = ssh_project.paths.clone(); - window.spawn(cx, async move |_| { - persistence::DB - .update_ssh_project_paths(ssh_project_id, ssh_project_paths) - .await - }) - } else { - Task::ready(Err(anyhow::anyhow!("No SSH project to serialize"))) - }; - task.detach(); - this._schedule_serialize_ssh_paths.take(); - }) - .log_err(); - })); - } + project + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path()) + .collect::>() } fn remove_panes(&mut self, member: Member, window: &mut Window, cx: &mut Context) { @@ -5269,7 +5243,7 @@ impl Workspace { } match self.serialize_workspace_location(cx) { - WorkspaceLocation::Location(location) => { + WorkspaceLocation::Location(location, paths) => { let breakpoints = self.project.update(cx, |project, cx| { project .breakpoint_store() @@ -5283,6 +5257,7 @@ impl Workspace { let serialized_workspace = SerializedWorkspace { id: database_id, location, + paths, center_group, window_bounds, display: Default::default(), @@ -5308,13 +5283,19 @@ impl Workspace { } fn serialize_workspace_location(&self, cx: &App) -> WorkspaceLocation { - if let Some(ssh_project) = &self.serialized_ssh_project { - WorkspaceLocation::Location(SerializedWorkspaceLocation::Ssh(ssh_project.clone())) - } else if let Some(local_paths) = self.local_paths(cx) { - if !local_paths.is_empty() { - WorkspaceLocation::Location(SerializedWorkspaceLocation::from_local_paths( - local_paths, - )) + let paths = PathList::new(&self.root_paths(cx)); + if let Some(connection) = self.project.read(cx).ssh_connection_options(cx) { + WorkspaceLocation::Location( + SerializedWorkspaceLocation::Ssh(SerializedSshConnection { + host: connection.host, + port: connection.port, + user: connection.username, + }), + paths, + ) + } else if self.project.read(cx).is_local() { + if !paths.is_empty() { + WorkspaceLocation::Location(SerializedWorkspaceLocation::Local, paths) } else { WorkspaceLocation::DetachFromSession } @@ -5327,13 +5308,13 @@ impl Workspace { let Some(id) = self.database_id() else { return; }; - let location = match self.serialize_workspace_location(cx) { - WorkspaceLocation::Location(location) => location, - _ => return, - }; + if !self.project.read(cx).is_local() { + return; + } if let Some(manager) = HistoryManager::global(cx) { + let paths = PathList::new(&self.root_paths(cx)); manager.update(cx, |this, cx| { - this.update_history(id, HistoryManagerEntry::new(id, &location), cx); + this.update_history(id, HistoryManagerEntry::new(id, &paths), cx); }); } } @@ -6799,14 +6780,14 @@ impl WorkspaceHandle for Entity { } } -pub async fn last_opened_workspace_location() -> Option { +pub async fn last_opened_workspace_location() -> Option<(SerializedWorkspaceLocation, PathList)> { DB.last_workspace().await.log_err().flatten() } pub fn last_session_workspace_locations( last_session_id: &str, last_session_window_stack: Option>, -) -> Option> { +) -> Option> { DB.last_session_workspace_locations(last_session_id, last_session_window_stack) .log_err() } @@ -7309,7 +7290,7 @@ pub fn open_ssh_project_with_new_connection( cx: &mut App, ) -> Task> { cx.spawn(async move |cx| { - let (serialized_ssh_project, workspace_id, serialized_workspace) = + let (workspace_id, serialized_workspace) = serialize_ssh_project(connection_options.clone(), paths.clone(), cx).await?; let session = match cx @@ -7343,7 +7324,6 @@ pub fn open_ssh_project_with_new_connection( open_ssh_project_inner( project, paths, - serialized_ssh_project, workspace_id, serialized_workspace, app_state, @@ -7363,13 +7343,12 @@ pub fn open_ssh_project_with_existing_connection( cx: &mut AsyncApp, ) -> Task> { cx.spawn(async move |cx| { - let (serialized_ssh_project, workspace_id, serialized_workspace) = + let (workspace_id, serialized_workspace) = serialize_ssh_project(connection_options.clone(), paths.clone(), cx).await?; open_ssh_project_inner( project, paths, - serialized_ssh_project, workspace_id, serialized_workspace, app_state, @@ -7383,7 +7362,6 @@ pub fn open_ssh_project_with_existing_connection( async fn open_ssh_project_inner( project: Entity, paths: Vec, - serialized_ssh_project: SerializedSshProject, workspace_id: WorkspaceId, serialized_workspace: Option, app_state: Arc, @@ -7436,7 +7414,6 @@ async fn open_ssh_project_inner( let mut workspace = Workspace::new(Some(workspace_id), project, app_state.clone(), window, cx); - workspace.set_serialized_ssh_project(serialized_ssh_project); workspace.update_history(cx); if let Some(ref serialized) = serialized_workspace { @@ -7473,28 +7450,18 @@ fn serialize_ssh_project( connection_options: SshConnectionOptions, paths: Vec, cx: &AsyncApp, -) -> Task< - Result<( - SerializedSshProject, - WorkspaceId, - Option, - )>, -> { +) -> Task)>> { cx.background_spawn(async move { - let serialized_ssh_project = persistence::DB - .get_or_create_ssh_project( + let ssh_connection_id = persistence::DB + .get_or_create_ssh_connection( connection_options.host.clone(), connection_options.port, - paths - .iter() - .map(|path| path.to_string_lossy().to_string()) - .collect::>(), connection_options.username.clone(), ) .await?; let serialized_workspace = - persistence::DB.workspace_for_ssh_project(&serialized_ssh_project); + persistence::DB.ssh_workspace_for_roots(&paths, ssh_connection_id); let workspace_id = if let Some(workspace_id) = serialized_workspace.as_ref().map(|workspace| workspace.id) @@ -7504,7 +7471,7 @@ fn serialize_ssh_project( persistence::DB.next_id().await? }; - Ok((serialized_ssh_project, workspace_id, serialized_workspace)) + Ok((workspace_id, serialized_workspace)) }) } @@ -8051,18 +8018,15 @@ pub fn ssh_workspace_position_from_db( paths_to_open: &[PathBuf], cx: &App, ) -> Task> { - let paths = paths_to_open - .iter() - .map(|path| path.to_string_lossy().to_string()) - .collect::>(); + let paths = paths_to_open.to_vec(); cx.background_spawn(async move { - let serialized_ssh_project = persistence::DB - .get_or_create_ssh_project(host, port, paths, user) + let ssh_connection_id = persistence::DB + .get_or_create_ssh_connection(host, port, user) .await .context("fetching serialized ssh project")?; let serialized_workspace = - persistence::DB.workspace_for_ssh_project(&serialized_ssh_project); + persistence::DB.ssh_workspace_for_roots(&paths, ssh_connection_id); let (window_bounds, display) = if let Some(bounds) = window_bounds_env_override() { (Some(WindowBounds::Windowed(bounds)), None) diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index c61e23f0a1..6f4ead9ebb 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -20,6 +20,7 @@ path = "src/main.rs" [dependencies] activity_indicator.workspace = true +acp_tools.workspace = true agent.workspace = true agent_ui.workspace = true agent_settings.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 8beefd5891..e99c8b564b 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -47,8 +47,8 @@ use theme::{ use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use workspace::{ - AppState, SerializedWorkspaceLocation, Toast, Workspace, WorkspaceSettings, WorkspaceStore, - notifications::NotificationId, + AppState, PathList, SerializedWorkspaceLocation, Toast, Workspace, WorkspaceSettings, + WorkspaceStore, notifications::NotificationId, }; use zed::{ OpenListener, OpenRequest, RawOpenRequest, app_menus, build_window_options, @@ -566,6 +566,7 @@ pub fn main() { language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); agent_settings::init(cx); agent_servers::init(cx); + acp_tools::init(cx); web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); snippet_provider::init(cx); @@ -948,15 +949,14 @@ async fn restore_or_create_workspace(app_state: Arc, cx: &mut AsyncApp if let Some(locations) = restorable_workspace_locations(cx, &app_state).await { let mut tasks = Vec::new(); - for location in locations { + for (location, paths) in locations { match location { - SerializedWorkspaceLocation::Local(location, _) => { + SerializedWorkspaceLocation::Local => { let app_state = app_state.clone(); - let paths = location.paths().to_vec(); let task = cx.spawn(async move |cx| { let open_task = cx.update(|cx| { workspace::open_paths( - &paths, + &paths.paths(), app_state, workspace::OpenOptions::default(), cx, @@ -978,7 +978,7 @@ async fn restore_or_create_workspace(app_state: Arc, cx: &mut AsyncApp match connection_options { Ok(connection_options) => recent_projects::open_ssh_project( connection_options, - ssh.paths.into_iter().map(PathBuf::from).collect(), + paths.paths().into_iter().map(PathBuf::from).collect(), app_state, workspace::OpenOptions::default(), cx, @@ -1069,7 +1069,7 @@ async fn restore_or_create_workspace(app_state: Arc, cx: &mut AsyncApp pub(crate) async fn restorable_workspace_locations( cx: &mut AsyncApp, app_state: &Arc, -) -> Option> { +) -> Option> { let mut restore_behavior = cx .update(|cx| WorkspaceSettings::get(None, cx).restore_on_startup) .ok()?; diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 638e1dca0e..1b9657dcc6 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -4434,6 +4434,7 @@ mod tests { assert_eq!(actions_without_namespace, Vec::<&str>::new()); let expected_namespaces = vec![ + "acp", "activity_indicator", "agent", #[cfg(not(target_os = "macos"))] diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index a9abd9bc74..bc2d757fd1 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -75,13 +75,10 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store.read(cx).has_accepted_terms_of_service(); - telemetry::event!( "Edit Prediction Provider Changed", from = provider, to = new_provider, - zed_ai_tos_accepted = tos_accepted, ); provider = new_provider; @@ -92,28 +89,6 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { user_store.clone(), cx, ); - - if !tos_accepted { - match provider { - EditPredictionProvider::Zed => { - let Some(window) = cx.active_window() else { - return; - }; - - window - .update(cx, |_, window, cx| { - window.dispatch_action( - Box::new(zed_actions::OpenZedPredictOnboarding), - cx, - ); - }) - .ok(); - } - EditPredictionProvider::None - | EditPredictionProvider::Copilot - | EditPredictionProvider::Supermaven => {} - } - } } } }) diff --git a/crates/zed/src/zed/open_listener.rs b/crates/zed/src/zed/open_listener.rs index 827c7754fa..2194fb7af5 100644 --- a/crates/zed/src/zed/open_listener.rs +++ b/crates/zed/src/zed/open_listener.rs @@ -26,6 +26,7 @@ use std::thread; use std::time::Duration; use util::ResultExt; use util::paths::PathWithPosition; +use workspace::PathList; use workspace::item::ItemHandle; use workspace::{AppState, OpenOptions, SerializedWorkspaceLocation, Workspace}; @@ -361,12 +362,14 @@ async fn open_workspaces( if open_new_workspace == Some(true) { Vec::new() } else { - let locations = restorable_workspace_locations(cx, &app_state).await; - locations.unwrap_or_default() + restorable_workspace_locations(cx, &app_state) + .await + .unwrap_or_default() } } else { - vec![SerializedWorkspaceLocation::from_local_paths( - paths.into_iter().map(PathBuf::from), + vec![( + SerializedWorkspaceLocation::Local, + PathList::new(&paths.into_iter().map(PathBuf::from).collect::>()), )] }; @@ -394,9 +397,9 @@ async fn open_workspaces( // If there are paths to open, open a workspace for each grouping of paths let mut errored = false; - for location in grouped_locations { + for (location, workspace_paths) in grouped_locations { match location { - SerializedWorkspaceLocation::Local(workspace_paths, _) => { + SerializedWorkspaceLocation::Local => { let workspace_paths = workspace_paths .paths() .iter() @@ -429,7 +432,7 @@ async fn open_workspaces( cx.spawn(async move |cx| { open_ssh_project( connection_options, - ssh.paths.into_iter().map(PathBuf::from).collect(), + workspace_paths.paths().to_vec(), app_state, OpenOptions::default(), cx, diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index 9455369e9a..069abc0a12 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -156,7 +156,10 @@ pub mod workspace { #[action(deprecated_aliases = ["editor::CopyPath", "outline_panel::CopyPath", "project_panel::CopyPath"])] CopyPath, #[action(deprecated_aliases = ["editor::CopyRelativePath", "outline_panel::CopyRelativePath", "project_panel::CopyRelativePath"])] - CopyRelativePath + CopyRelativePath, + /// Opens the selected file with the system's default application. + #[action(deprecated_aliases = ["project_panel::OpenWithSystem"])] + OpenWithSystem, ] ); } diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 916699d29b..7b14d12796 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -118,12 +118,8 @@ impl Dismissable for ZedPredictUpsell { } } -pub fn should_show_upsell_modal(user_store: &Entity, cx: &App) -> bool { - if user_store.read(cx).has_accepted_terms_of_service() { - !ZedPredictUpsell::dismissed() - } else { - true - } +pub fn should_show_upsell_modal() -> bool { + !ZedPredictUpsell::dismissed() } #[derive(Clone)] @@ -1547,16 +1543,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { ) -> bool { true } - - fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self - .zeta - .read(cx) - .user_store - .read(cx) - .has_accepted_terms_of_service() - } - fn is_refreshing(&self) -> bool { !self.pending_completions.is_empty() } @@ -1569,10 +1555,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { _debounce: bool, cx: &mut Context, ) { - if self.needs_terms_acceptance(cx) { - return; - } - if self.zeta.read(cx).update_required { return; } diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index 696370e310..fb139db6e4 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -2425,6 +2425,7 @@ Examples: { "completions": { "words": "fallback", + "words_min_length": 3, "lsp": true, "lsp_fetch_timeout_ms": 0, "lsp_insert_mode": "replace_suffix" @@ -2444,6 +2445,17 @@ Examples: 2. `fallback` - Only if LSP response errors or times out, use document's words to show completions 3. `disabled` - Never fetch or complete document's words for completions (word-based completions can still be queried via a separate action) +### Min Words Query Length + +- Description: Minimum number of characters required to automatically trigger word-based completions. + Before that value, it's still possible to trigger the words-based completion manually with the corresponding editor command. +- Setting: `words_min_length` +- Default: `3` + +**Options** + +Positive integer values + ### LSP - Description: Whether to fetch LSP completions or not. diff --git a/script/squawk b/script/squawk index 8489206f14..497fcff089 100755 --- a/script/squawk +++ b/script/squawk @@ -15,13 +15,11 @@ SQUAWK_VERSION=0.26.0 SQUAWK_BIN="./target/squawk-$SQUAWK_VERSION" SQUAWK_ARGS="--assume-in-transaction --config script/lib/squawk.toml" -if [ ! -f "$SQUAWK_BIN" ]; then - pkgutil --pkg-info com.apple.pkg.RosettaUpdateAuto || /usr/sbin/softwareupdate --install-rosetta --agree-to-license - # When bootstrapping a brand new CI machine, the `target` directory may not exist yet. - mkdir -p "./target" - curl -L -o "$SQUAWK_BIN" "https://github.com/sbdchd/squawk/releases/download/v$SQUAWK_VERSION/squawk-darwin-x86_64" - chmod +x "$SQUAWK_BIN" -fi +pkgutil --pkg-info com.apple.pkg.RosettaUpdateAuto || /usr/sbin/softwareupdate --install-rosetta --agree-to-license +# When bootstrapping a brand new CI machine, the `target` directory may not exist yet. +mkdir -p "./target" +curl -L -o "$SQUAWK_BIN" "https://github.com/sbdchd/squawk/releases/download/v$SQUAWK_VERSION/squawk-darwin-x86_64" +chmod +x "$SQUAWK_BIN" if [ -n "$SQUAWK_GITHUB_TOKEN" ]; then export SQUAWK_GITHUB_REPO_OWNER=$(echo $GITHUB_REPOSITORY | awk -F/ '{print $1}') diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 054e757056..2f9a963abc 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -54,6 +54,7 @@ digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde", "use_std"] } euclid = { version = "0.22" } event-listener = { version = "5" } +event-listener-strategy = { version = "0.5" } flate2 = { version = "1", features = ["zlib-rs"] } form_urlencoded = { version = "1" } futures = { version = "0.3", features = ["io-compat"] } @@ -108,7 +109,6 @@ rustc-hash = { version = "1" } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", default-features = false, features = ["fs", "net", "std"] } rustls = { version = "0.23", features = ["ring"] } rustls-webpki = { version = "0.103", default-features = false, features = ["aws-lc-rs", "ring", "std"] } -schemars = { version = "1", features = ["chrono04", "indexmap2", "semver1"] } sea-orm = { version = "1", features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"] } sea-query-binder = { version = "0.7", default-features = false, features = ["postgres-array", "sqlx-postgres", "sqlx-sqlite", "with-bigdecimal", "with-chrono", "with-json", "with-rust_decimal", "with-time", "with-uuid"] } semver = { version = "1", features = ["serde"] } @@ -183,6 +183,7 @@ digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde", "use_std"] } euclid = { version = "0.22" } event-listener = { version = "5" } +event-listener-strategy = { version = "0.5" } flate2 = { version = "1", features = ["zlib-rs"] } form_urlencoded = { version = "1" } futures = { version = "0.3", features = ["io-compat"] } @@ -242,7 +243,6 @@ rustc-hash = { version = "1" } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", default-features = false, features = ["fs", "net", "std"] } rustls = { version = "0.23", features = ["ring"] } rustls-webpki = { version = "0.103", default-features = false, features = ["aws-lc-rs", "ring", "std"] } -schemars = { version = "1", features = ["chrono04", "indexmap2", "semver1"] } sea-orm = { version = "1", features = ["runtime-tokio-rustls", "sqlx-postgres", "sqlx-sqlite"] } sea-query-binder = { version = "0.7", default-features = false, features = ["postgres-array", "sqlx-postgres", "sqlx-sqlite", "with-bigdecimal", "with-chrono", "with-json", "with-rust_decimal", "with-time", "with-uuid"] } semver = { version = "1", features = ["serde"] } @@ -403,7 +403,6 @@ bytemuck = { version = "1", default-features = false, features = ["min_const_gen cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] } codespan-reporting = { version = "0.12" } crypto-common = { version = "0.1", default-features = false, features = ["rand_core", "std"] } -event-listener-strategy = { version = "0.5" } flume = { version = "0.11" } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } @@ -444,7 +443,6 @@ bytemuck = { version = "1", default-features = false, features = ["min_const_gen cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] } codespan-reporting = { version = "0.12" } crypto-common = { version = "0.1", default-features = false, features = ["rand_core", "std"] } -event-listener-strategy = { version = "0.5" } flume = { version = "0.11" } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } @@ -483,7 +481,6 @@ bytemuck = { version = "1", default-features = false, features = ["min_const_gen cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] } codespan-reporting = { version = "0.12" } crypto-common = { version = "0.1", default-features = false, features = ["rand_core", "std"] } -event-listener-strategy = { version = "0.5" } flume = { version = "0.11" } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } @@ -524,7 +521,6 @@ bytemuck = { version = "1", default-features = false, features = ["min_const_gen cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] } codespan-reporting = { version = "0.12" } crypto-common = { version = "0.1", default-features = false, features = ["rand_core", "std"] } -event-listener-strategy = { version = "0.5" } flume = { version = "0.11" } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } @@ -610,7 +606,6 @@ bytemuck = { version = "1", default-features = false, features = ["min_const_gen cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] } codespan-reporting = { version = "0.12" } crypto-common = { version = "0.1", default-features = false, features = ["rand_core", "std"] } -event-listener-strategy = { version = "0.5" } flume = { version = "0.11" } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } @@ -651,7 +646,6 @@ bytemuck = { version = "1", default-features = false, features = ["min_const_gen cipher = { version = "0.4", default-features = false, features = ["block-padding", "rand_core", "zeroize"] } codespan-reporting = { version = "0.12" } crypto-common = { version = "0.1", default-features = false, features = ["rand_core", "std"] } -event-listener-strategy = { version = "0.5" } flume = { version = "0.11" } foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] }