diff --git a/Cargo.lock b/Cargo.lock index 7b9bf9ab57..b12f9e667b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1262,26 +1262,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "async-stripe" -version = "0.40.0" -source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735" -dependencies = [ - "chrono", - "futures-util", - "http-types", - "hyper 0.14.32", - "hyper-rustls 0.24.2", - "serde", - "serde_json", - "serde_path_to_error", - "serde_qs 0.10.1", - "smart-default 0.6.0", - "smol_str 0.1.24", - "thiserror 1.0.69", - "tokio", -] - [[package]] name = "async-tar" version = "0.5.0" @@ -2083,12 +2063,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.21.7" @@ -3281,7 +3255,6 @@ dependencies = [ "anyhow", "assistant_context", "assistant_slash_command", - "async-stripe", "async-trait", "async-tungstenite", "audio", @@ -3308,7 +3281,6 @@ dependencies = [ "dap_adapters", "dashmap 6.1.0", "debugger_ui", - "derive_more 0.99.19", "editor", "envy", "extension", @@ -3324,7 +3296,6 @@ dependencies = [ "http_client", "hyper 0.14.32", "indoc", - "jsonwebtoken", "language", "language_model", "livekit_api", @@ -3370,7 +3341,6 @@ dependencies = [ "telemetry_events", "text", "theme", - "thiserror 2.0.12", "time", "tokio", "toml 0.8.20", @@ -3872,7 +3842,7 @@ dependencies = [ "rustc-hash 1.1.0", "rustybuzz 0.14.1", "self_cell", - "smol_str 0.2.2", + "smol_str", "swash", "sys-locale", "ttf-parser 0.21.1", @@ -6376,17 +6346,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.15" @@ -7990,27 +7949,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" -[[package]] -name = "http-types" -version = "2.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" -dependencies = [ - "anyhow", - "async-channel 1.9.0", - "base64 0.13.1", - "futures-lite 1.13.0", - "http 0.2.12", - "infer", - "pin-project-lite", - "rand 0.7.3", - "serde", - "serde_json", - "serde_qs 0.8.5", - "serde_urlencoded", - "url", -] - [[package]] name = "http_client" version = "0.1.0" @@ -8489,12 +8427,6 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" -[[package]] -name = "infer" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" - [[package]] name = "inherent" version = "1.0.12" @@ -10271,7 +10203,7 @@ dependencies = [ "num-traits", "range-map", "scroll", - "smart-default 0.7.1", + "smart-default", ] [[package]] @@ -13145,19 +13077,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -13179,16 +13098,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", -] - [[package]] name = "rand_chacha" version = "0.3.1" @@ -13209,15 +13118,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", -] - [[package]] name = "rand_core" version = "0.6.4" @@ -13236,15 +13136,6 @@ dependencies = [ "getrandom 0.3.2", ] -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "range-map" version = "0.2.0" @@ -14899,28 +14790,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_qs" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - -[[package]] -name = "serde_qs" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - [[package]] name = "serde_repr" version = "0.1.20" @@ -15297,17 +15166,6 @@ dependencies = [ "serde", ] -[[package]] -name = "smart-default" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "smart-default" version = "0.7.1" @@ -15336,15 +15194,6 @@ dependencies = [ "futures-lite 2.6.0", ] -[[package]] -name = "smol_str" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" -dependencies = [ - "serde", -] - [[package]] name = "smol_str" version = "0.2.2" @@ -18194,12 +18043,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index baa4ee7f4e..644b6c0f40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -667,20 +667,6 @@ workspace-hack = "0.1.0" yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" } zstd = "0.11" -[workspace.dependencies.async-stripe] -git = "https://github.com/zed-industries/async-stripe" -rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735" -default-features = false -features = [ - "runtime-tokio-hyper-rustls", - "billing", - "checkout", - "events", - # The features below are only enabled to get the `events` feature to build. - "chrono", - "connect", -] - [workspace.dependencies.windows] version = "0.61" features = [ diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 4995ddb9df..2ef94a3cbe 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -109,7 +109,7 @@ pub enum AgentThreadEntry { } impl AgentThreadEntry { - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), @@ -117,6 +117,14 @@ impl AgentThreadEntry { } } + pub fn user_message(&self) -> Option<&UserMessage> { + if let AgentThreadEntry::UserMessage(message) = self { + Some(message) + } else { + None + } + } + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index d9557c5d00..fd38ba1f7f 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -309,7 +309,7 @@ pub struct AgentSettingsContent { /// /// Default: true expand_terminal_card: Option, - /// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel. + /// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel. /// /// Default: false use_modifier_to_send: Option, diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index adcfab85b1..1a9861d13a 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,38 +1,31 @@ -use std::ffi::OsStr; use std::ops::Range; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use acp_thread::{MentionUri, selection_name}; +use acp_thread::MentionUri; use anyhow::{Context as _, Result, anyhow}; use collections::{HashMap, HashSet}; use editor::display_map::CreaseId; -use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _}; +use editor::{CompletionProvider, Editor, ExcerptId}; use futures::future::{Shared, try_join_all}; -use futures::{FutureExt, TryFutureExt}; use fuzzy::{StringMatch, StringMatchCandidate}; -use gpui::{App, Entity, ImageFormat, Img, Task, WeakEntity}; -use http_client::HttpClientWithUrl; -use itertools::Itertools as _; +use gpui::{App, Entity, ImageFormat, Task, WeakEntity}; use language::{Buffer, CodeLabel, HighlightId}; -use language_model::LanguageModelImage; use lsp::CompletionContext; -use parking_lot::Mutex; use project::{ Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId, }; use prompt_store::PromptStore; use rope::Point; -use text::{Anchor, OffsetRangeExt as _, ToPoint as _}; +use text::{Anchor, ToPoint as _}; use ui::prelude::*; use url::Url; use workspace::Workspace; -use workspace::notifications::NotifyResultExt; use agent::thread_store::{TextThreadStore, ThreadStore}; -use crate::context_picker::fetch_context_picker::fetch_url_content; +use crate::acp::message_editor::MessageEditor; use crate::context_picker::file_context_picker::{FileMatch, search_files}; use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; use crate::context_picker::symbol_context_picker::SymbolMatch; @@ -47,14 +40,14 @@ use crate::context_picker::{ #[derive(Clone, Debug, Eq, PartialEq)] pub struct MentionImage { - pub abs_path: Option>, + pub abs_path: Option, pub data: SharedString, pub format: ImageFormat, } #[derive(Default)] pub struct MentionSet { - uri_by_crease_id: HashMap, + pub(crate) uri_by_crease_id: HashMap, fetch_results: HashMap>>>, images: HashMap>>>, } @@ -84,11 +77,6 @@ impl MentionSet { .chain(self.images.drain().map(|(id, _)| id)) } - pub fn clear(&mut self) { - self.fetch_results.clear(); - self.uri_by_crease_id.clear(); - } - pub fn contents( &self, project: Entity, @@ -97,6 +85,8 @@ impl MentionSet { window: &mut Window, cx: &mut App, ) -> Task>> { + let mut processed_image_creases = HashSet::default(); + let mut contents = self .uri_by_crease_id .iter() @@ -106,59 +96,27 @@ impl MentionSet { // TODO directories let uri = uri.clone(); let abs_path = abs_path.to_path_buf(); - let extension = abs_path.extension().and_then(OsStr::to_str).unwrap_or(""); - if Img::extensions().contains(&extension) && !extension.contains("svg") { - let open_image_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(&abs_path, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_image(path, cx)) + if let Some(task) = self.images.get(&crease_id).cloned() { + processed_image_creases.insert(crease_id); + return cx.spawn(async move |_| { + let image = task.await.map_err(|e| anyhow!("{e}"))?; + anyhow::Ok((crease_id, Mention::Image(image))) }); - - cx.spawn(async move |cx| { - let image_item = open_image_task?.await?; - let (data, format) = image_item.update(cx, |image_item, cx| { - let format = image_item.image.format; - ( - LanguageModelImage::from_image( - image_item.image.clone(), - cx, - ), - format, - ) - })?; - let data = cx.spawn(async move |_| { - if let Some(data) = data.await { - Ok(data.source) - } else { - anyhow::bail!("Failed to convert image") - } - }); - - anyhow::Ok(( - crease_id, - Mention::Image(MentionImage { - abs_path: Some(abs_path.as_path().into()), - data: data.await?, - format, - }), - )) - }) - } else { - let buffer_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(abs_path, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_buffer(path, cx)) - }); - cx.spawn(async move |cx| { - let buffer = buffer_task?.await?; - let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; - - anyhow::Ok((crease_id, Mention::Text { uri, content })) - }) } + + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(abs_path, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + anyhow::Ok((crease_id, Mention::Text { uri, content })) + }) } MentionUri::Symbol { path, line_range, .. @@ -252,15 +210,19 @@ impl MentionSet { }) .collect::>(); - contents.extend(self.images.iter().map(|(crease_id, image)| { + // Handle images that didn't have a mention URI (because they were added by the paste handler). + contents.extend(self.images.iter().filter_map(|(crease_id, image)| { + if processed_image_creases.contains(crease_id) { + return None; + } let crease_id = *crease_id; let image = image.clone(); - cx.spawn(async move |_| { + Some(cx.spawn(async move |_| { Ok(( crease_id, Mention::Image(image.await.map_err(|e| anyhow::anyhow!("{e}"))?), )) - }) + })) })); cx.spawn(async move |_cx| { @@ -488,36 +450,31 @@ fn search( } pub struct ContextPickerCompletionProvider { - mention_set: Arc>, workspace: WeakEntity, thread_store: WeakEntity, text_thread_store: WeakEntity, - editor: WeakEntity, + message_editor: WeakEntity, } impl ContextPickerCompletionProvider { pub fn new( - mention_set: Arc>, workspace: WeakEntity, thread_store: WeakEntity, text_thread_store: WeakEntity, - editor: WeakEntity, + message_editor: WeakEntity, ) -> Self { Self { - mention_set, workspace, thread_store, text_thread_store, - editor, + message_editor, } } fn completion_for_entry( entry: ContextPickerEntry, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, workspace: &Entity, cx: &mut App, ) -> Option { @@ -538,88 +495,39 @@ impl ContextPickerCompletionProvider { ContextPickerEntry::Action(action) => { let (new_text, on_action) = match action { ContextPickerAction::AddSelections => { - let selections = selection_ranges(workspace, cx); - const PLACEHOLDER: &str = "selection "; + let selections = selection_ranges(workspace, cx) + .into_iter() + .enumerate() + .map(|(ix, (buffer, range))| { + ( + buffer, + range, + (PLACEHOLDER.len() * ix)..(PLACEHOLDER.len() * (ix + 1) - 1), + ) + }) + .collect::>(); - let new_text = std::iter::repeat(PLACEHOLDER) - .take(selections.len()) - .chain(std::iter::once("")) - .join(" "); + let new_text: String = PLACEHOLDER.repeat(selections.len()); let callback = Arc::new({ - let mention_set = mention_set.clone(); - let selections = selections.clone(); + let source_range = source_range.clone(); move |_, window: &mut Window, cx: &mut App| { - let editor = editor.clone(); - let mention_set = mention_set.clone(); let selections = selections.clone(); + let message_editor = message_editor.clone(); + let source_range = source_range.clone(); window.defer(cx, move |window, cx| { - let mut current_offset = 0; - - for (buffer, selection_range) in selections { - let snapshot = - editor.read(cx).buffer().read(cx).snapshot(cx); - let Some(start) = snapshot - .anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; - - let offset = start.to_offset(&snapshot) + current_offset; - let text_len = PLACEHOLDER.len() - 1; - - let range = snapshot.anchor_after(offset) - ..snapshot.anchor_after(offset + text_len); - - let path = buffer - .read(cx) - .file() - .map_or(PathBuf::from("untitled"), |file| { - file.path().to_path_buf() - }); - - let point_range = snapshot - .as_singleton() - .map(|(_, _, snapshot)| { - selection_range.to_point(&snapshot) - }) - .unwrap_or_default(); - let line_range = point_range.start.row..point_range.end.row; - - let uri = MentionUri::Selection { - path: path.clone(), - line_range: line_range.clone(), - }; - let crease = crate::context_picker::crease_for_mention( - selection_name(&path, &line_range).into(), - uri.icon_path(cx), - range, - editor.downgrade(), - ); - - let [crease_id]: [_; 1] = - editor.update(cx, |editor, cx| { - let crease_ids = - editor.insert_creases(vec![crease.clone()], cx); - editor.fold_creases( - vec![crease], - false, - window, - cx, - ); - crease_ids.try_into().unwrap() - }); - - mention_set.lock().insert_uri( - crease_id, - MentionUri::Selection { path, line_range }, - ); - - current_offset += text_len + 1; - } + message_editor + .update(cx, |message_editor, cx| { + message_editor.confirm_mention_for_selection( + source_range, + selections, + window, + cx, + ) + }) + .ok(); }); - false } }); @@ -647,11 +555,9 @@ impl ContextPickerCompletionProvider { fn completion_for_thread( thread_entry: ThreadContextEntry, - excerpt_id: ExcerptId, source_range: Range, recent: bool, - editor: Entity, - mention_set: Arc>, + editor: WeakEntity, cx: &mut App, ) -> Completion { let uri = match &thread_entry { @@ -683,13 +589,10 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_for_completion.clone()), confirm: Some(confirm_completion_callback( - uri.icon_path(cx), thread_entry.title().clone(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set, + editor, uri, )), } @@ -697,10 +600,8 @@ impl ContextPickerCompletionProvider { fn completion_for_rules( rule: RulesContextEntry, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + editor: WeakEntity, cx: &mut App, ) -> Completion { let uri = MentionUri::Rule { @@ -719,13 +620,10 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_path.clone()), confirm: Some(confirm_completion_callback( - icon_path, rule.title.clone(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set, + editor, uri, )), } @@ -736,10 +634,8 @@ impl ContextPickerCompletionProvider { path_prefix: &str, is_recent: bool, is_directory: bool, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, project: Entity, cx: &mut App, ) -> Option { @@ -777,13 +673,10 @@ impl ContextPickerCompletionProvider { icon_path: Some(completion_icon_path), insert_text_mode: None, confirm: Some(confirm_completion_callback( - crease_icon_path, file_name, - excerpt_id, source_range.start, new_text_len - 1, - editor, - mention_set.clone(), + message_editor, file_uri, )), }) @@ -791,10 +684,8 @@ impl ContextPickerCompletionProvider { fn completion_for_symbol( symbol: Symbol, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, workspace: Entity, cx: &mut App, ) -> Option { @@ -820,13 +711,10 @@ impl ContextPickerCompletionProvider { icon_path: Some(icon_path.clone()), insert_text_mode: None, confirm: Some(confirm_completion_callback( - icon_path, symbol.name.clone().into(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set.clone(), + message_editor, uri, )), }) @@ -835,116 +723,32 @@ impl ContextPickerCompletionProvider { fn completion_for_fetch( source_range: Range, url_to_fetch: SharedString, - excerpt_id: ExcerptId, - editor: Entity, - mention_set: Arc>, - http_client: Arc, + message_editor: WeakEntity, cx: &mut App, ) -> Option { let new_text = format!("@fetch {} ", url_to_fetch.clone()); - let new_text_len = new_text.len(); + let url_to_fetch = url::Url::parse(url_to_fetch.as_ref()) + .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) + .ok()?; let mention_uri = MentionUri::Fetch { - url: url::Url::parse(url_to_fetch.as_ref()) - .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) - .ok()?, + url: url_to_fetch.clone(), }; let icon_path = mention_uri.icon_path(cx); Some(Completion { replace_range: source_range.clone(), - new_text, + new_text: new_text.clone(), label: CodeLabel::plain(url_to_fetch.to_string(), None), documentation: None, source: project::CompletionSource::Custom, icon_path: Some(icon_path.clone()), insert_text_mode: None, - confirm: Some({ - let start = source_range.start; - let content_len = new_text_len - 1; - let editor = editor.clone(); - let url_to_fetch = url_to_fetch.clone(); - let source_range = source_range.clone(); - let icon_path = icon_path.clone(); - let mention_uri = mention_uri.clone(); - Arc::new(move |_, window, cx| { - let Some(url) = url::Url::parse(url_to_fetch.as_ref()) - .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) - .notify_app_err(cx) - else { - return false; - }; - - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let http_client = http_client.clone(); - let source_range = source_range.clone(); - let icon_path = icon_path.clone(); - let mention_uri = mention_uri.clone(); - window.defer(cx, move |window, cx| { - let url = url.clone(); - - let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - url.to_string().into(), - icon_path, - editor.clone(), - window, - cx, - ) else { - return; - }; - - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let http_client = http_client.clone(); - let source_range = source_range.clone(); - - let url_string = url.to_string(); - let fetch = cx - .background_executor() - .spawn(async move { - fetch_url_content(http_client, url_string) - .map_err(|e| e.to_string()) - .await - }) - .shared(); - mention_set.lock().add_fetch_result(url, fetch.clone()); - - window - .spawn(cx, async move |cx| { - if fetch.await.notify_async_err(cx).is_some() { - mention_set - .lock() - .insert_uri(crease_id, mention_uri.clone()); - } else { - // Remove crease if we failed to fetch - editor - .update(cx, |editor, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let Some(anchor) = snapshot - .anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting( - vec![anchor..anchor], - true, - cx, - ); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - } - Some(()) - }) - .detach(); - }); - false - }) - }), + confirm: Some(confirm_completion_callback( + url_to_fetch.to_string().into(), + source_range.start, + new_text.len() - 1, + message_editor, + mention_uri, + )), }) } } @@ -968,7 +772,7 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: impl CompletionProvider for ContextPickerCompletionProvider { fn completions( &self, - excerpt_id: ExcerptId, + _excerpt_id: ExcerptId, buffer: &Entity, buffer_position: Anchor, _trigger: CompletionContext, @@ -992,39 +796,24 @@ impl CompletionProvider for ContextPickerCompletionProvider { }; let project = workspace.read(cx).project().clone(); - let http_client = workspace.read(cx).client().http_client(); let snapshot = buffer.read(cx).snapshot(); let source_range = snapshot.anchor_before(state.source_range.start) ..snapshot.anchor_after(state.source_range.end); let thread_store = self.thread_store.clone(); let text_thread_store = self.text_thread_store.clone(); - let editor = self.editor.clone(); + let editor = self.message_editor.clone(); + let Ok((exclude_paths, exclude_threads)) = + self.message_editor.update(cx, |message_editor, _cx| { + message_editor.mentioned_path_and_threads() + }) + else { + return Task::ready(Ok(Vec::new())); + }; let MentionCompletion { mode, argument, .. } = state; let query = argument.unwrap_or_else(|| "".to_string()); - let (exclude_paths, exclude_threads) = { - let mention_set = self.mention_set.lock(); - - let mut excluded_paths = HashSet::default(); - let mut excluded_threads = HashSet::default(); - - for uri in mention_set.uri_by_crease_id.values() { - match uri { - MentionUri::File { abs_path, .. } => { - excluded_paths.insert(abs_path.clone()); - } - MentionUri::Thread { id, .. } => { - excluded_threads.insert(id.clone()); - } - _ => {} - } - } - - (excluded_paths, excluded_threads) - }; - let recent_entries = recent_context_picker_entries( Some(thread_store.clone()), Some(text_thread_store.clone()), @@ -1051,13 +840,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx, ); - let mention_set = self.mention_set.clone(); - cx.spawn(async move |_, cx| { let matches = search_task.await; - let Some(editor) = editor.upgrade() else { - return Ok(Vec::new()); - }; let completions = cx.update(|cx| { matches @@ -1074,10 +858,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { &mat.path_prefix, is_recent, mat.is_dir, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), project.clone(), cx, ) @@ -1085,10 +867,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( symbol, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), workspace.clone(), cx, ), @@ -1097,39 +877,30 @@ impl CompletionProvider for ContextPickerCompletionProvider { thread, is_recent, .. }) => Some(Self::completion_for_thread( thread, - excerpt_id, source_range.clone(), is_recent, editor.clone(), - mention_set.clone(), cx, )), Match::Rules(user_rules) => Some(Self::completion_for_rules( user_rules, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), cx, )), Match::Fetch(url) => Self::completion_for_fetch( source_range.clone(), url, - excerpt_id, editor.clone(), - mention_set.clone(), - http_client.clone(), cx, ), Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( entry, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), &workspace, cx, ), @@ -1182,36 +953,30 @@ impl CompletionProvider for ContextPickerCompletionProvider { } fn confirm_completion_callback( - crease_icon_path: SharedString, crease_text: SharedString, - excerpt_id: ExcerptId, start: Anchor, content_len: usize, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, mention_uri: MentionUri, ) -> Arc bool + Send + Sync> { Arc::new(move |_, window, cx| { + let message_editor = message_editor.clone(); let crease_text = crease_text.clone(); - let crease_icon_path = crease_icon_path.clone(); - let editor = editor.clone(); - let mention_set = mention_set.clone(); let mention_uri = mention_uri.clone(); window.defer(cx, move |window, cx| { - if let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - crease_text.clone(), - crease_icon_path, - editor.clone(), - window, - cx, - ) { - mention_set - .lock() - .insert_uri(crease_id, mention_uri.clone()); - } + message_editor + .clone() + .update(cx, |message_editor, cx| { + message_editor.confirm_completion( + crease_text, + start, + content_len, + mention_uri, + window, + cx, + ) + }) + .ok(); }); false }) @@ -1279,13 +1044,13 @@ impl MentionCompletion { #[cfg(test)] mod tests { use super::*; - use editor::AnchorRangeExt; + use editor::{AnchorRangeExt, EditorMode}; use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext}; use project::{Project, ProjectPath}; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{ops::Deref, path::Path, rc::Rc}; + use std::{ops::Deref, path::Path}; use util::path; use workspace::{AppState, Item}; @@ -1359,9 +1124,9 @@ mod tests { assert_eq!(MentionCompletion::try_parse("test@", 0), None); } - struct AtMentionEditor(Entity); + struct MessageEditorItem(Entity); - impl Item for AtMentionEditor { + impl Item for MessageEditorItem { type Event = (); fn include_in_nav_history() -> bool { @@ -1373,15 +1138,15 @@ mod tests { } } - impl EventEmitter<()> for AtMentionEditor {} + impl EventEmitter<()> for MessageEditorItem {} - impl Focusable for AtMentionEditor { + impl Focusable for MessageEditorItem { fn focus_handle(&self, cx: &App) -> FocusHandle { self.0.read(cx).focus_handle(cx).clone() } } - impl Render for AtMentionEditor { + impl Render for MessageEditorItem { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { self.0.clone().into_any_element() } @@ -1467,19 +1232,28 @@ mod tests { opened_editors.push(buffer); } - let editor = workspace.update_in(&mut cx, |workspace, window, cx| { - let editor = cx.new(|cx| { - Editor::new( - editor::EditorMode::full(), - multi_buffer::MultiBuffer::build_simple("", cx), - None, + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| { + let workspace_handle = cx.weak_entity(); + let message_editor = cx.new(|cx| { + MessageEditor::new( + workspace_handle, + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + EditorMode::AutoHeight { + max_lines: None, + min_lines: 1, + }, window, cx, ) }); workspace.active_pane().update(cx, |pane, cx| { pane.add_item( - Box::new(cx.new(|_| AtMentionEditor(editor.clone()))), + Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))), true, true, None, @@ -1487,24 +1261,9 @@ mod tests { cx, ); }); - editor - }); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); - - let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); - let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); - - let editor_entity = editor.downgrade(); - editor.update_in(&mut cx, |editor, window, cx| { - window.focus(&editor.focus_handle(cx)); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace.downgrade(), - thread_store.downgrade(), - text_thread_store.downgrade(), - editor_entity, - )))); + message_editor.read(cx).focus_handle(cx).focus(window); + let editor = message_editor.read(cx).editor().clone(); + (message_editor, editor) }); cx.simulate_input("Lorem "); @@ -1573,9 +1332,9 @@ mod tests { ); }); - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( project.clone(), thread_store.clone(), text_thread_store.clone(), @@ -1641,9 +1400,9 @@ mod tests { cx.run_until_parked(); - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( project.clone(), thread_store.clone(), text_thread_store.clone(), @@ -1765,9 +1524,9 @@ mod tests { editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); }); - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( project.clone(), thread_store, text_thread_store, diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 2f5f855e90..e99d1f6323 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -1,45 +1,141 @@ -use std::{collections::HashMap, ops::Range}; +use std::ops::Range; -use acp_thread::AcpThread; -use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer}; +use acp_thread::{AcpThread, AgentThreadEntry}; +use agent::{TextThreadStore, ThreadStore}; +use collections::HashMap; +use editor::{Editor, EditorMode, MinimapVisibility}; use gpui::{ - AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window, + AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement, + WeakEntity, Window, }; use language::language_settings::SoftWrap; +use project::Project; use settings::Settings as _; use terminal_view::TerminalView; use theme::ThemeSettings; -use ui::TextSize; +use ui::{Context, TextSize}; use workspace::Workspace; -#[derive(Default)] +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; + pub struct EntryViewState { + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, entries: Vec, } impl EntryViewState { + pub fn new( + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + ) -> Self { + Self { + workspace, + project, + thread_store, + text_thread_store, + entries: Vec::new(), + } + } + pub fn entry(&self, index: usize) -> Option<&Entry> { self.entries.get(index) } pub fn sync_entry( &mut self, - workspace: WeakEntity, - thread: Entity, index: usize, + thread: &Entity, window: &mut Window, - cx: &mut App, + cx: &mut Context, ) { - debug_assert!(index <= self.entries.len()); - let entry = if let Some(entry) = self.entries.get_mut(index) { - entry - } else { - self.entries.push(Entry::default()); - self.entries.last_mut().unwrap() + let Some(thread_entry) = thread.read(cx).entries().get(index) else { + return; }; - entry.sync_diff_multibuffers(&thread, index, window, cx); - entry.sync_terminals(&workspace, &thread, index, window, cx); + match thread_entry { + AgentThreadEntry::UserMessage(message) => { + let has_id = message.id.is_some(); + let chunks = message.chunks.clone(); + let message_editor = cx.new(|cx| { + let mut editor = MessageEditor::new( + self.workspace.clone(), + self.project.clone(), + self.thread_store.clone(), + self.text_thread_store.clone(), + editor::EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ); + if !has_id { + editor.set_read_only(true, cx); + } + editor.set_message(chunks, window, cx); + editor + }); + cx.subscribe(&message_editor, move |_, editor, event, cx| { + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::MessageEditorEvent(editor, *event), + }) + }) + .detach(); + self.set_entry(index, Entry::UserMessage(message_editor)); + } + AgentThreadEntry::ToolCall(tool_call) => { + let terminals = tool_call.terminals().cloned().collect::>(); + let diffs = tool_call.diffs().cloned().collect::>(); + + let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) { + views + } else { + self.set_entry(index, Entry::empty()); + let Some(Entry::Content(views)) = self.entries.get_mut(index) else { + unreachable!() + }; + views + }; + + for terminal in terminals { + views.entry(terminal.entity_id()).or_insert_with(|| { + create_terminal( + self.workspace.clone(), + self.project.clone(), + terminal.clone(), + window, + cx, + ) + .into_any() + }); + } + + for diff in diffs { + views + .entry(diff.entity_id()) + .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any()); + } + } + AgentThreadEntry::AssistantMessage(_) => { + if index == self.entries.len() { + self.entries.push(Entry::empty()) + } + } + }; + } + + fn set_entry(&mut self, index: usize, entry: Entry) { + if index == self.entries.len() { + self.entries.push(entry); + } else { + self.entries[index] = entry; + } } pub fn remove(&mut self, range: Range) { @@ -48,26 +144,51 @@ impl EntryViewState { pub fn settings_changed(&mut self, cx: &mut App) { for entry in self.entries.iter() { - for view in entry.views.values() { - if let Ok(diff_editor) = view.clone().downcast::() { - diff_editor.update(cx, |diff_editor, cx| { - diff_editor - .set_text_style_refinement(diff_editor_text_style_refinement(cx)); - cx.notify(); - }) + match entry { + Entry::UserMessage { .. } => {} + Entry::Content(response_views) => { + for view in response_views.values() { + if let Ok(diff_editor) = view.clone().downcast::() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor.set_text_style_refinement( + diff_editor_text_style_refinement(cx), + ); + cx.notify(); + }) + } + } } } } } } -pub struct Entry { - views: HashMap, +impl EventEmitter for EntryViewState {} + +pub struct EntryViewEvent { + pub entry_index: usize, + pub view_event: ViewEvent, +} + +pub enum ViewEvent { + MessageEditorEvent(Entity, MessageEditorEvent), +} + +pub enum Entry { + UserMessage(Entity), + Content(HashMap), } impl Entry { - pub fn editor_for_diff(&self, diff: &Entity) -> Option> { - self.views + pub fn message_editor(&self) -> Option<&Entity> { + match self { + Self::UserMessage(editor) => Some(editor), + Entry::Content(_) => None, + } + } + + pub fn editor_for_diff(&self, diff: &Entity) -> Option> { + self.content_map()? .get(&diff.entity_id()) .cloned() .map(|entity| entity.downcast::().unwrap()) @@ -77,118 +198,88 @@ impl Entry { &self, terminal: &Entity, ) -> Option> { - self.views + self.content_map()? .get(&terminal.entity_id()) .cloned() .map(|entity| entity.downcast::().unwrap()) } - fn sync_diff_multibuffers( - &mut self, - thread: &Entity, - index: usize, - window: &mut Window, - cx: &mut App, - ) { - let Some(entry) = thread.read(cx).entries().get(index) else { - return; - }; - - let multibuffers = entry - .diffs() - .map(|diff| diff.read(cx).multibuffer().clone()); - - let multibuffers = multibuffers.collect::>(); - - for multibuffer in multibuffers { - if self.views.contains_key(&multibuffer.entity_id()) { - return; - } - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - editor - }); - - let entity_id = multibuffer.entity_id(); - self.views.insert(entity_id, editor.into_any()); + fn content_map(&self) -> Option<&HashMap> { + match self { + Self::Content(map) => Some(map), + _ => None, } } - fn sync_terminals( - &mut self, - workspace: &WeakEntity, - thread: &Entity, - index: usize, - window: &mut Window, - cx: &mut App, - ) { - let Some(entry) = thread.read(cx).entries().get(index) else { - return; - }; - - let terminals = entry - .terminals() - .map(|terminal| terminal.clone()) - .collect::>(); - - for terminal in terminals { - if self.views.contains_key(&terminal.entity_id()) { - return; - } - - let Some(strong_workspace) = workspace.upgrade() else { - return; - }; - - let terminal_view = cx.new(|cx| { - let mut view = TerminalView::new( - terminal.read(cx).inner().clone(), - workspace.clone(), - None, - strong_workspace.read(cx).project().downgrade(), - window, - cx, - ); - view.set_embedded_mode(Some(1000), cx); - view - }); - - let entity_id = terminal.entity_id(); - self.views.insert(entity_id, terminal_view.into_any()); - } + fn empty() -> Self { + Self::Content(HashMap::default()) } #[cfg(test)] - pub fn len(&self) -> usize { - self.views.len() + pub fn has_content(&self) -> bool { + match self { + Self::Content(map) => !map.is_empty(), + Self::UserMessage(_) => false, + } } } +fn create_terminal( + workspace: WeakEntity, + project: Entity, + terminal: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut view = TerminalView::new( + terminal.read(cx).inner().clone(), + workspace.clone(), + None, + project.downgrade(), + window, + cx, + ); + view.set_embedded_mode(Some(1000), cx); + view + }) +} + +fn create_editor_diff( + diff: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + diff.read(cx).multibuffer().clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + editor + }) +} + fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { TextStyleRefinement { font_size: Some( @@ -201,26 +292,20 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { } } -impl Default for Entry { - fn default() -> Self { - Self { - // Avoid allocating in the heap by default - views: HashMap::with_capacity(0), - } - } -} - #[cfg(test)] mod tests { use std::{path::Path, rc::Rc}; use acp_thread::{AgentConnection, StubAgentConnection}; + use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; use editor::{EditorSettings, RowInfo}; use fs::FakeFs; - use gpui::{SemanticVersion, TestAppContext}; + use gpui::{AppContext as _, SemanticVersion, TestAppContext}; + + use crate::acp::entry_view_state::EntryViewState; use multi_buffer::MultiBufferRow; use pretty_assertions::assert_matches; use project::Project; @@ -230,8 +315,6 @@ mod tests { use util::path; use workspace::Workspace; - use crate::acp::entry_view_state::EntryViewState; - #[gpui::test] async fn test_diff_sync(cx: &mut TestAppContext) { init_test(cx); @@ -269,7 +352,7 @@ mod tests { .update(|_, cx| { connection .clone() - .new_thread(project, Path::new(path!("/project")), cx) + .new_thread(project.clone(), Path::new(path!("/project")), cx) }) .await .unwrap(); @@ -279,12 +362,23 @@ mod tests { connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) }); - let mut view_state = EntryViewState::default(); - cx.update(|window, cx| { - view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx); + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let view_state = cx.new(|_cx| { + EntryViewState::new( + workspace.downgrade(), + project.clone(), + thread_store, + text_thread_store, + ) }); - let multibuffer = thread.read_with(cx, |thread, cx| { + view_state.update_in(cx, |view_state, window, cx| { + view_state.sync_entry(0, &thread, window, cx) + }); + + let diff = thread.read_with(cx, |thread, _cx| { thread .entries() .get(0) @@ -292,15 +386,14 @@ mod tests { .diffs() .next() .unwrap() - .read(cx) - .multibuffer() .clone() }); cx.run_until_parked(); - let entry = view_state.entry(0).unwrap(); - let diff_editor = entry.editor_for_diff(&multibuffer).unwrap(); + let diff_editor = view_state.read_with(cx, |view_state, _cx| { + view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap() + }); assert_eq!( diff_editor.read_with(cx, |editor, cx| editor.text(cx)), "hi world\nhello world" diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 8d512948dd..a4d74db266 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -1,61 +1,63 @@ -use crate::acp::completion_provider::ContextPickerCompletionProvider; -use crate::acp::completion_provider::MentionImage; -use crate::acp::completion_provider::MentionSet; -use acp_thread::MentionUri; -use agent::TextThreadStore; -use agent::ThreadStore; +use crate::{ + acp::completion_provider::{ContextPickerCompletionProvider, MentionImage, MentionSet}, + context_picker::fetch_context_picker::fetch_url_content, +}; +use acp_thread::{MentionUri, selection_name}; +use agent::{TextThreadStore, ThreadId, ThreadStore}; use agent_client_protocol as acp; use anyhow::Result; use collections::HashSet; -use editor::ExcerptId; -use editor::actions::Paste; -use editor::display_map::CreaseId; use editor::{ - AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MultiBuffer, + Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, + EditorMode, EditorStyle, ExcerptId, FoldPlaceholder, MultiBuffer, ToOffset, + actions::Paste, + display_map::{Crease, CreaseId, FoldId}, }; -use futures::FutureExt as _; -use gpui::ClipboardEntry; -use gpui::Image; -use gpui::ImageFormat; +use futures::{FutureExt as _, TryFutureExt as _}; use gpui::{ - AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity, + AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, Image, + ImageFormat, Img, Task, TextStyle, WeakEntity, }; -use language::Buffer; -use language::Language; +use language::{Buffer, Language}; use language_model::LanguageModelImage; -use parking_lot::Mutex; use project::{CompletionIntent, Project}; use settings::Settings; -use std::fmt::Write; -use std::path::Path; -use std::rc::Rc; -use std::sync::Arc; +use std::{ + ffi::OsStr, + fmt::Write, + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use text::OffsetRangeExt; use theme::ThemeSettings; -use ui::IconName; -use ui::SharedString; use ui::{ - ActiveTheme, App, InteractiveElement, IntoElement, ParentElement, Render, Styled, TextSize, - Window, div, + ActiveTheme, AnyElement, App, ButtonCommon, ButtonLike, ButtonStyle, Color, Icon, IconName, + IconSize, InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ParentElement, + Render, SelectableButton, SharedString, Styled, TextSize, TintColor, Toggleable, Window, div, + h_flex, }; use util::ResultExt; -use workspace::Workspace; -use workspace::notifications::NotifyResultExt as _; +use workspace::{Workspace, notifications::NotifyResultExt as _}; use zed_actions::agent::Chat; use super::completion_provider::Mention; pub struct MessageEditor { + mention_set: MentionSet, editor: Entity, project: Entity, + workspace: WeakEntity, thread_store: Entity, text_thread_store: Entity, - mention_set: Arc>, } +#[derive(Clone, Copy)] pub enum MessageEditorEvent { Send, Cancel, + Focus, } impl EventEmitter for MessageEditor {} @@ -77,8 +79,13 @@ impl MessageEditor { }, None, ); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); + let completion_provider = ContextPickerCompletionProvider::new( + workspace.clone(), + thread_store.downgrade(), + text_thread_store.downgrade(), + cx.weak_entity(), + ); + let mention_set = MentionSet::default(); let editor = cx.new(|cx| { let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); @@ -88,13 +95,7 @@ impl MessageEditor { editor.set_show_indent_guides(false, cx); editor.set_soft_wrap(); editor.set_use_modal_editing(true); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace, - thread_store.downgrade(), - text_thread_store.downgrade(), - cx.weak_entity(), - )))); + editor.set_completion_provider(Some(Rc::new(completion_provider))); editor.set_context_menu_options(ContextMenuOptions { min_entries_visible: 12, max_entries_visible: 12, @@ -103,25 +104,254 @@ impl MessageEditor { editor }); + cx.on_focus(&editor.focus_handle(cx), window, |_, _, cx| { + cx.emit(MessageEditorEvent::Focus) + }) + .detach(); + Self { editor, project, mention_set, thread_store, text_thread_store, + workspace, } } + #[cfg(test)] + pub(crate) fn editor(&self) -> &Entity { + &self.editor + } + + #[cfg(test)] + pub(crate) fn mention_set(&mut self) -> &mut MentionSet { + &mut self.mention_set + } + pub fn is_empty(&self, cx: &App) -> bool { self.editor.read(cx).is_empty(cx) } + pub fn mentioned_path_and_threads(&self) -> (HashSet, HashSet) { + let mut excluded_paths = HashSet::default(); + let mut excluded_threads = HashSet::default(); + + for uri in self.mention_set.uri_by_crease_id.values() { + match uri { + MentionUri::File { abs_path, .. } => { + excluded_paths.insert(abs_path.clone()); + } + MentionUri::Thread { id, .. } => { + excluded_threads.insert(id.clone()); + } + _ => {} + } + } + + (excluded_paths, excluded_threads) + } + + pub fn confirm_completion( + &mut self, + crease_text: SharedString, + start: text::Anchor, + content_len: usize, + mention_uri: MentionUri, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self + .editor + .update(cx, |editor, cx| editor.snapshot(window, cx)); + let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { + return; + }; + let Some(anchor) = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, start) + else { + return; + }; + + let Some(crease_id) = crate::context_picker::insert_crease_for_mention( + *excerpt_id, + start, + content_len, + crease_text.clone(), + mention_uri.icon_path(cx), + self.editor.clone(), + window, + cx, + ) else { + return; + }; + self.mention_set.insert_uri(crease_id, mention_uri.clone()); + + match mention_uri { + MentionUri::Fetch { url } => { + self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx); + } + MentionUri::File { + abs_path, + is_directory, + } => { + self.confirm_mention_for_file( + crease_id, + anchor, + abs_path, + is_directory, + window, + cx, + ); + } + MentionUri::Symbol { .. } + | MentionUri::Thread { .. } + | MentionUri::TextThread { .. } + | MentionUri::Rule { .. } + | MentionUri::Selection { .. } => {} + } + } + + fn confirm_mention_for_file( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + abs_path: PathBuf, + _is_directory: bool, + window: &mut Window, + cx: &mut Context, + ) { + let extension = abs_path + .extension() + .and_then(OsStr::to_str) + .unwrap_or_default(); + let project = self.project.clone(); + let Some(project_path) = project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return; + }; + + if Img::extensions().contains(&extension) && !extension.contains("svg") { + let image = cx.spawn(async move |_, cx| { + let image = project + .update(cx, |project, cx| project.open_image(project_path, cx))? + .await?; + image.read_with(cx, |image, _cx| image.image.clone()) + }); + self.confirm_mention_for_image(crease_id, anchor, Some(abs_path), image, window, cx); + } + } + + fn confirm_mention_for_fetch( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + url: url::Url, + window: &mut Window, + cx: &mut Context, + ) { + let Some(http_client) = self + .workspace + .update(cx, |workspace, _cx| workspace.client().http_client()) + .ok() + else { + return; + }; + + let url_string = url.to_string(); + let fetch = cx + .background_executor() + .spawn(async move { + fetch_url_content(http_client, url_string) + .map_err(|e| e.to_string()) + .await + }) + .shared(); + self.mention_set + .add_fetch_result(url.clone(), fetch.clone()); + + cx.spawn_in(window, async move |this, cx| { + let fetch = fetch.await.notify_async_err(cx); + this.update(cx, |this, cx| { + let mention_uri = MentionUri::Fetch { url }; + if fetch.is_some() { + this.mention_set.insert_uri(crease_id, mention_uri.clone()); + } else { + // Remove crease if we failed to fetch + this.editor.update(cx, |editor, cx| { + editor.display_map.update(cx, |display_map, cx| { + display_map.unfold_intersecting(vec![anchor..anchor], true, cx); + }); + editor.remove_creases([crease_id], cx); + }); + } + }) + .ok(); + }) + .detach(); + } + + pub fn confirm_mention_for_selection( + &mut self, + source_range: Range, + selections: Vec<(Entity, Range, Range)>, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self.editor.read(cx).buffer().read(cx).snapshot(cx); + let Some((&excerpt_id, _, _)) = snapshot.as_singleton() else { + return; + }; + let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, source_range.start) else { + return; + }; + + let offset = start.to_offset(&snapshot); + + for (buffer, selection_range, range_to_fold) in selections { + let range = snapshot.anchor_after(offset + range_to_fold.start) + ..snapshot.anchor_after(offset + range_to_fold.end); + + let path = buffer + .read(cx) + .file() + .map_or(PathBuf::from("untitled"), |file| file.path().to_path_buf()); + let snapshot = buffer.read(cx).snapshot(); + + let point_range = selection_range.to_point(&snapshot); + let line_range = point_range.start.row..point_range.end.row; + + let uri = MentionUri::Selection { + path: path.clone(), + line_range: line_range.clone(), + }; + let crease = crate::context_picker::crease_for_mention( + selection_name(&path, &line_range).into(), + uri.icon_path(cx), + range, + self.editor.downgrade(), + ); + + let crease_id = self.editor.update(cx, |editor, cx| { + let crease_ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + crease_ids.first().copied().unwrap() + }); + + self.mention_set + .insert_uri(crease_id, MentionUri::Selection { path, line_range }); + } + } + pub fn contents( &self, window: &mut Window, cx: &mut Context, ) -> Task>> { - let contents = self.mention_set.lock().contents( + let contents = self.mention_set.contents( self.project.clone(), self.thread_store.clone(), self.text_thread_store.clone(), @@ -198,15 +428,15 @@ impl MessageEditor { pub fn clear(&mut self, window: &mut Window, cx: &mut Context) { self.editor.update(cx, |editor, cx| { editor.clear(window, cx); - editor.remove_creases(self.mention_set.lock().drain(), cx) + editor.remove_creases(self.mention_set.drain(), cx) }); } - fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context) { + fn send(&mut self, _: &Chat, _: &mut Window, cx: &mut Context) { cx.emit(MessageEditorEvent::Send) } - fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + fn cancel(&mut self, _: &editor::actions::Cancel, _: &mut Window, cx: &mut Context) { cx.emit(MessageEditorEvent::Cancel) } @@ -233,27 +463,46 @@ impl MessageEditor { let replacement_text = "image"; for image in images { - let (excerpt_id, anchor) = self.editor.update(cx, |message_editor, cx| { - let snapshot = message_editor.snapshot(window, cx); - let (excerpt_id, _, snapshot) = snapshot.buffer_snapshot.as_singleton().unwrap(); + let (excerpt_id, text_anchor, multibuffer_anchor) = + self.editor.update(cx, |message_editor, cx| { + let snapshot = message_editor.snapshot(window, cx); + let (excerpt_id, _, buffer_snapshot) = + snapshot.buffer_snapshot.as_singleton().unwrap(); - let anchor = snapshot.anchor_before(snapshot.len()); - message_editor.edit( - [( - multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), - format!("{replacement_text} "), - )], - cx, - ); - (*excerpt_id, anchor) - }); + let text_anchor = buffer_snapshot.anchor_before(buffer_snapshot.len()); + let multibuffer_anchor = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, text_anchor); + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + format!("{replacement_text} "), + )], + cx, + ); + (*excerpt_id, text_anchor, multibuffer_anchor) + }); - self.insert_image( + let content_len = replacement_text.len(); + let Some(anchor) = multibuffer_anchor else { + return; + }; + let Some(crease_id) = insert_crease_for_image( excerpt_id, + text_anchor, + content_len, + None.clone(), + self.editor.clone(), + window, + cx, + ) else { + return; + }; + self.confirm_mention_for_image( + crease_id, anchor, - replacement_text.len(), - Arc::new(image), None, + Task::ready(Ok(Arc::new(image))), window, cx, ); @@ -267,9 +516,6 @@ impl MessageEditor { cx: &mut Context, ) { let buffer = self.editor.read(cx).buffer().clone(); - let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else { - return; - }; let Some(buffer) = buffer.read(cx).as_singleton() else { return; }; @@ -292,10 +538,8 @@ impl MessageEditor { &path_prefix, false, entry.is_dir(), - excerpt_id, anchor..anchor, - self.editor.clone(), - self.mention_set.clone(), + cx.weak_entity(), self.project.clone(), cx, ) else { @@ -317,33 +561,32 @@ impl MessageEditor { } } - fn insert_image( + pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context) { + self.editor.update(cx, |message_editor, cx| { + message_editor.set_read_only(read_only); + cx.notify() + }) + } + + fn confirm_mention_for_image( &mut self, - excerpt_id: ExcerptId, - crease_start: text::Anchor, - content_len: usize, - image: Arc, - abs_path: Option>, + crease_id: CreaseId, + anchor: Anchor, + abs_path: Option, + image: Task>>, window: &mut Window, cx: &mut Context, ) { - let Some(crease_id) = insert_crease_for_image( - excerpt_id, - crease_start, - content_len, - self.editor.clone(), - window, - cx, - ) else { - return; - }; self.editor.update(cx, |_editor, cx| { - let format = image.format; - let convert = LanguageModelImage::from_image(image, cx); - let task = cx .spawn_in(window, async move |editor, cx| { - if let Some(image) = convert.await { + let image = image.await.map_err(|e| e.to_string())?; + let format = image.format; + let image = cx + .update(|_, cx| LanguageModelImage::from_image(image, cx)) + .map_err(|e| e.to_string())? + .await; + if let Some(image) = image { Ok(MentionImage { abs_path, data: image.source, @@ -352,12 +595,6 @@ impl MessageEditor { } else { editor .update(cx, |editor, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let Some(anchor) = - snapshot.anchor_in_excerpt(excerpt_id, crease_start) - else { - return; - }; editor.display_map.update(cx, |display_map, cx| { display_map.unfold_intersecting(vec![anchor..anchor], true, cx); }); @@ -375,7 +612,7 @@ impl MessageEditor { }) .detach(); - self.mention_set.lock().insert_image(crease_id, task); + self.mention_set.insert_image(crease_id, task); }); } @@ -392,6 +629,8 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { + self.clear(window, cx); + let mut text = String::new(); let mut mentions = Vec::new(); let mut images = Vec::new(); @@ -429,7 +668,6 @@ impl MessageEditor { editor.buffer().read(cx).snapshot(cx) }); - self.mention_set.lock().clear(); for (range, mention_uri) in mentions { let anchor = snapshot.anchor_before(range.start); let crease_id = crate::context_picker::insert_crease_for_mention( @@ -444,7 +682,7 @@ impl MessageEditor { ); if let Some(crease_id) = crease_id { - self.mention_set.lock().insert_uri(crease_id, mention_uri); + self.mention_set.insert_uri(crease_id, mention_uri); } } for (range, content) in images { @@ -479,7 +717,7 @@ impl MessageEditor { let data: SharedString = content.data.to_string().into(); if let Some(crease_id) = crease_id { - self.mention_set.lock().insert_image( + self.mention_set.insert_image( crease_id, Task::ready(Ok(MentionImage { abs_path, @@ -499,6 +737,11 @@ impl MessageEditor { editor.set_text(text, window, cx); }); } + + #[cfg(test)] + pub fn text(&self, cx: &App) -> String { + self.editor.read(cx).text(cx) + } } impl Focusable for MessageEditor { @@ -511,7 +754,7 @@ impl Render for MessageEditor { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { div() .key_context("MessageEditor") - .on_action(cx.listener(Self::chat)) + .on_action(cx.listener(Self::send)) .on_action(cx.listener(Self::cancel)) .capture_action(cx.listener(Self::paste)) .flex_1() @@ -550,20 +793,78 @@ pub(crate) fn insert_crease_for_image( excerpt_id: ExcerptId, anchor: text::Anchor, content_len: usize, + abs_path: Option>, editor: Entity, window: &mut Window, cx: &mut App, ) -> Option { - crate::context_picker::insert_crease_for_mention( - excerpt_id, - anchor, - content_len, - "Image".into(), - IconName::Image.path().into(), - editor, - window, - cx, - ) + let crease_label = abs_path + .as_ref() + .and_then(|path| path.file_name()) + .map(|name| name.to_string_lossy().to_string().into()) + .unwrap_or(SharedString::from("Image")); + + editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + + let start = snapshot.anchor_in_excerpt(excerpt_id, anchor)?; + + let start = start.bias_right(&snapshot); + let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len); + + let placeholder = FoldPlaceholder { + render: render_image_fold_icon_button(crease_label, cx.weak_entity()), + merge_adjacent: false, + ..Default::default() + }; + + let crease = Crease::Inline { + range: start..end, + placeholder, + render_toggle: None, + render_trailer: None, + metadata: None, + }; + + let ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + + Some(ids[0]) + }) +} + +fn render_image_fold_icon_button( + label: SharedString, + editor: WeakEntity, +) -> Arc, &mut App) -> AnyElement> { + Arc::new({ + move |fold_id, fold_range, cx| { + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); + + ButtonLike::new(fold_id) + .style(ButtonStyle::Filled) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + .toggle_state(is_in_text_selection) + .child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Image) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new(label.clone()) + .size(LabelSize::Small) + .buffer_font(cx) + .single_line(), + ), + ) + .into_any_element() + } + }) } #[cfg(test)] diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 3ce68985e0..e0545878f8 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -45,6 +45,7 @@ use zed_actions::assistant::OpenRulesLibrary; use super::entry_view_state::EntryViewState; use crate::acp::AcpModelSelectorPopover; +use crate::acp::entry_view_state::{EntryViewEvent, ViewEvent}; use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::agent_diff::AgentDiff; use crate::profile_selector::{ProfileProvider, ProfileSelector}; @@ -101,10 +102,8 @@ pub struct AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, - thread_store: Entity, - text_thread_store: Entity, thread_state: ThreadState, - entry_view_state: EntryViewState, + entry_view_state: Entity, message_editor: Entity, model_selector: Option>, profile_selector: Option>, @@ -119,16 +118,9 @@ pub struct AcpThreadView { plan_expanded: bool, editor_expanded: bool, terminal_expanded: bool, - editing_message: Option, + editing_message: Option, _cancel_task: Option>, - _subscriptions: [Subscription; 2], -} - -struct EditingMessage { - index: usize, - message_id: UserMessageId, - editor: Entity, - _subscription: Subscription, + _subscriptions: [Subscription; 3], } enum ThreadState { @@ -175,25 +167,33 @@ impl AcpThreadView { let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); + let entry_view_state = cx.new(|_| { + EntryViewState::new( + workspace.clone(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + ) + }); + let subscriptions = [ cx.observe_global_in::(window, Self::settings_changed), - cx.subscribe_in(&message_editor, window, Self::on_message_editor_event), + cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event), + cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event), ]; Self { agent: agent.clone(), workspace: workspace.clone(), project: project.clone(), - thread_store, - text_thread_store, + entry_view_state, thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, model_selector: None, profile_selector: None, notifications: Vec::new(), notification_subscriptions: HashMap::default(), - entry_view_state: EntryViewState::default(), - list_state, + list_state: list_state, thread_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), @@ -412,7 +412,7 @@ impl AcpThreadView { cx.notify(); } - pub fn on_message_editor_event( + pub fn handle_message_editor_event( &mut self, _: &Entity, event: &MessageEditorEvent, @@ -422,6 +422,28 @@ impl AcpThreadView { match event { MessageEditorEvent::Send => self.send(window, cx), MessageEditorEvent::Cancel => self.cancel_generation(cx), + MessageEditorEvent::Focus => {} + } + } + + pub fn handle_entry_view_event( + &mut self, + _: &Entity, + event: &EntryViewEvent, + window: &mut Window, + cx: &mut Context, + ) { + match &event.view_event { + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => { + self.editing_message = Some(event.entry_index); + cx.notify(); + } + ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => { + self.regenerate(event.entry_index, editor, window, cx); + } + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => { + self.cancel_editing(&Default::default(), window, cx); + } } } @@ -492,27 +514,56 @@ impl AcpThreadView { .detach(); } - fn cancel_editing(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context) { - self.editing_message.take(); - cx.notify(); - } - - fn regenerate(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { - let Some(editing_message) = self.editing_message.take() else { - return; - }; - + fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { let Some(thread) = self.thread().cloned() else { return; }; - let rewind = thread.update(cx, |thread, cx| { - thread.rewind(editing_message.message_id, cx) - }); + if let Some(index) = self.editing_message.take() { + if let Some(editor) = self + .entry_view_state + .read(cx) + .entry(index) + .and_then(|e| e.message_editor()) + .cloned() + { + editor.update(cx, |editor, cx| { + if let Some(user_message) = thread + .read(cx) + .entries() + .get(index) + .and_then(|e| e.user_message()) + { + editor.set_message(user_message.chunks.clone(), window, cx); + } + }) + } + }; + self.focus_handle(cx).focus(window); + cx.notify(); + } + + fn regenerate( + &mut self, + entry_ix: usize, + message_editor: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread().cloned() else { + return; + }; + + let Some(rewind) = thread.update(cx, |thread, cx| { + let user_message_id = thread.entries().get(entry_ix)?.user_message()?.id.clone()?; + Some(thread.rewind(user_message_id, cx)) + }) else { + return; + }; + + let contents = + message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx)); - let contents = editing_message - .editor - .update(cx, |message_editor, cx| message_editor.contents(window, cx)); let task = cx.foreground_executor().spawn(async move { rewind.await?; contents.await @@ -568,27 +619,20 @@ impl AcpThreadView { AcpThreadEvent::NewEntry => { let len = thread.read(cx).entries().len(); let index = len - 1; - self.entry_view_state.sync_entry( - self.workspace.clone(), - thread.clone(), - index, - window, - cx, - ); + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(index, &thread, window, cx) + }); self.list_state.splice(index..index, 1); } AcpThreadEvent::EntryUpdated(index) => { - self.entry_view_state.sync_entry( - self.workspace.clone(), - thread.clone(), - *index, - window, - cx, - ); + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(*index, &thread, window, cx) + }); self.list_state.splice(*index..index + 1, 1); } AcpThreadEvent::EntriesRemoved(range) => { - self.entry_view_state.remove(range.clone()); + self.entry_view_state + .update(cx, |view_state, _cx| view_state.remove(range.clone())); self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { @@ -720,29 +764,15 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .id("message") - .on_click(cx.listener({ - move |this, _, window, cx| { - this.start_editing_message(entry_ix, window, cx) - } - })) .children( - if let Some(editing) = self.editing_message.as_ref() - && Some(&editing.message_id) == message.id.as_ref() - { - Some( - self.render_edit_message_editor(editing, cx) - .into_any_element(), - ) - } else { - message.content.markdown().map(|md| { - self.render_markdown( - md.clone(), - user_message_markdown_style(window, cx), - ) - .into_any_element() - }) - }, + self.entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| entry.message_editor()) + .map(|editor| { + self.render_sent_message_editor(entry_ix, editor, cx) + .into_any_element() + }), ), ) .into_any(), @@ -817,8 +847,8 @@ impl AcpThreadView { primary }; - if let Some(editing) = self.editing_message.as_ref() - && editing.index < entry_ix + if let Some(editing_index) = self.editing_message.as_ref() + && *editing_index < entry_ix { let backdrop = div() .id(("backdrop", entry_ix)) @@ -832,8 +862,8 @@ impl AcpThreadView { div() .relative() - .child(backdrop) .child(primary) + .child(backdrop) .into_any_element() } else { primary @@ -1254,9 +1284,7 @@ impl AcpThreadView { Empty.into_any_element() } } - ToolCallContent::Diff(diff) => { - self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx) - } + ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, &diff, cx), ToolCallContent::Terminal(terminal) => { self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } @@ -1403,7 +1431,7 @@ impl AcpThreadView { fn render_diff_editor( &self, entry_ix: usize, - multibuffer: &Entity, + diff: &Entity, cx: &Context, ) -> AnyElement { v_flex() @@ -1411,8 +1439,8 @@ impl AcpThreadView { .border_t_1() .border_color(self.tool_card_border_color(cx)) .child( - if let Some(entry) = self.entry_view_state.entry(entry_ix) - && let Some(editor) = entry.editor_for_diff(&multibuffer) + if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix) + && let Some(editor) = entry.editor_for_diff(&diff) { editor.clone().into_any_element() } else { @@ -1615,6 +1643,7 @@ impl AcpThreadView { let terminal_view = self .entry_view_state + .read(cx) .entry(entry_ix) .and_then(|entry| entry.terminal(&terminal)); let show_output = self.terminal_expanded && terminal_view.is_some(); @@ -2483,82 +2512,38 @@ impl AcpThreadView { ) } - fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context) { - let Some(thread) = self.thread() else { - return; - }; - let Some(AgentThreadEntry::UserMessage(message)) = thread.read(cx).entries().get(index) - else { - return; - }; - let Some(message_id) = message.id.clone() else { - return; - }; - - self.list_state.scroll_to_reveal_item(index); - - let chunks = message.chunks.clone(); - let editor = cx.new(|cx| { - let mut editor = MessageEditor::new( - self.workspace.clone(), - self.project.clone(), - self.thread_store.clone(), - self.text_thread_store.clone(), - editor::EditorMode::AutoHeight { - min_lines: 1, - max_lines: None, - }, - window, - cx, - ); - editor.set_message(chunks, window, cx); - editor - }); - let subscription = - cx.subscribe_in(&editor, window, |this, _, event, window, cx| match event { - MessageEditorEvent::Send => { - this.regenerate(&Default::default(), window, cx); - } - MessageEditorEvent::Cancel => { - this.cancel_editing(&Default::default(), window, cx); - } - }); - editor.focus_handle(cx).focus(window); - - self.editing_message.replace(EditingMessage { - index: index, - message_id: message_id.clone(), - editor, - _subscription: subscription, - }); - cx.notify(); - } - - fn render_edit_message_editor(&self, editing: &EditingMessage, cx: &Context) -> Div { - v_flex() - .w_full() - .gap_2() - .child(editing.editor.clone()) - .child( - h_flex() - .gap_1() - .child( - Icon::new(IconName::Warning) - .color(Color::Warning) - .size(IconSize::XSmall), - ) - .child( - Label::new("Editing will restart the thread from this point.") - .color(Color::Muted) - .size(LabelSize::XSmall), - ) - .child(self.render_editing_message_editor_buttons(editing, cx)), - ) - } - - fn render_editing_message_editor_buttons( + fn render_sent_message_editor( &self, - editing: &EditingMessage, + entry_ix: usize, + editor: &Entity, + cx: &Context, + ) -> Div { + v_flex().w_full().gap_2().child(editor.clone()).when( + self.editing_message == Some(entry_ix), + |el| { + el.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::XSmall), + ) + .child( + Label::new("Editing will restart the thread from this point.") + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .child(self.render_sent_message_editor_buttons(entry_ix, editor, cx)), + ) + }, + ) + } + + fn render_sent_message_editor_buttons( + &self, + entry_ix: usize, + editor: &Entity, cx: &Context, ) -> Div { h_flex() @@ -2571,7 +2556,7 @@ impl AcpThreadView { .icon_color(Color::Error) .icon_size(IconSize::Small) .tooltip({ - let focus_handle = editing.editor.focus_handle(cx); + let focus_handle = editor.focus_handle(cx); move |window, cx| { Tooltip::for_action_in( "Cancel Edit", @@ -2586,12 +2571,12 @@ impl AcpThreadView { ) .child( IconButton::new("confirm-edit-message", IconName::Return) - .disabled(editing.editor.read(cx).is_empty(cx)) + .disabled(editor.read(cx).is_empty(cx)) .shape(ui::IconButtonShape::Square) .icon_color(Color::Muted) .icon_size(IconSize::Small) .tooltip({ - let focus_handle = editing.editor.focus_handle(cx); + let focus_handle = editor.focus_handle(cx); move |window, cx| { Tooltip::for_action_in( "Regenerate", @@ -2602,7 +2587,12 @@ impl AcpThreadView { ) } }) - .on_click(cx.listener(Self::regenerate)), + .on_click(cx.listener({ + let editor = editor.clone(); + move |this, _, window, cx| { + this.regenerate(entry_ix, &editor, window, cx); + } + })), ) } @@ -3102,7 +3092,9 @@ impl AcpThreadView { } fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { - self.entry_view_state.settings_changed(cx); + self.entry_view_state.update(cx, |entry_view_state, cx| { + entry_view_state.settings_changed(cx); + }); } pub(crate) fn insert_dragged_files( @@ -3117,9 +3109,7 @@ impl AcpThreadView { drop(added_worktrees); }) } -} -impl AcpThreadView { fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option
{ let content = match self.thread_error.as_ref()? { ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), @@ -3411,35 +3401,6 @@ impl Render for AcpThreadView { } } -fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let mut style = default_markdown_style(false, window, cx); - let mut text_style = window.text_style(); - let theme_settings = ThemeSettings::get_global(cx); - - let buffer_font = theme_settings.buffer_font.family.clone(); - let buffer_font_size = TextSize::Small.rems(cx); - - text_style.refine(&TextStyleRefinement { - font_family: Some(buffer_font), - font_size: Some(buffer_font_size.into()), - ..Default::default() - }); - - style.base_text_style = text_style; - style.link_callback = Some(Rc::new(move |url, cx| { - if MentionUri::parse(url).is_ok() { - let colors = cx.theme().colors(); - Some(TextStyleRefinement { - background_color: Some(colors.element_background), - ..Default::default() - }) - } else { - None - } - })); - style -} - fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); @@ -3598,12 +3559,13 @@ pub(crate) mod tests { use agent_client_protocol::SessionId; use editor::EditorSettings; use fs::FakeFs; - use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext}; use project::Project; use serde_json::json; use settings::SettingsStore; use std::any::Any; use std::path::Path; + use workspace::Item; use super::*; @@ -3750,6 +3712,50 @@ pub(crate) mod tests { (thread_view, cx) } + fn add_to_workspace(thread_view: Entity, cx: &mut VisualTestContext) { + let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone()); + + workspace + .update_in(cx, |workspace, window, cx| { + workspace.add_item_to_active_pane( + Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))), + None, + true, + window, + cx, + ); + }) + .unwrap(); + } + + struct ThreadViewItem(Entity); + + impl Item for ThreadViewItem { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for ThreadViewItem {} + + impl Focusable for ThreadViewItem { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx).clone() + } + } + + impl Render for ThreadViewItem { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + struct StubAgentServer { connection: C, } @@ -3771,19 +3777,19 @@ pub(crate) mod tests { C: 'static + AgentConnection + Send + Clone, { fn logo(&self) -> ui::IconName { - unimplemented!() + ui::IconName::Ai } fn name(&self) -> &'static str { - unimplemented!() + "Test" } fn empty_state_headline(&self) -> &'static str { - unimplemented!() + "Test" } fn empty_state_message(&self) -> &'static str { - unimplemented!() + "Test" } fn connect( @@ -3932,9 +3938,17 @@ pub(crate) mod tests { assert_eq!(thread.entries().len(), 2); }); - thread_view.read_with(cx, |view, _| { - assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + }); }); // Second user message @@ -3963,18 +3977,31 @@ pub(crate) mod tests { let second_user_message_id = thread.read_with(cx, |thread, _| { assert_eq!(thread.entries().len(), 4); - let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap() - else { + let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else { panic!(); }; user_message.id.clone().unwrap() }); - thread_view.read_with(cx, |view, _| { - assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); - assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1); + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + assert!( + entry_view_state + .entry(2) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(3).unwrap().has_content()); + }); }); // Rewind to first message @@ -3989,13 +4016,169 @@ pub(crate) mod tests { assert_eq!(thread.entries().len(), 2); }); - thread_view.read_with(cx, |view, _| { - assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); - // Old views should be dropped - assert!(view.entry_view_state.entry(2).is_none()); - assert!(view.entry_view_state.entry(3).is_none()); + // Old views should be dropped + assert!(entry_view_state.entry(2).is_none()); + assert!(entry_view_state.entry(3).is_none()); + }); }); } + + #[gpui::test] + async fn test_message_editing_cancel(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); + }); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Cancel + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(editor::actions::Cancel), cx); + }); + + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, None); + }); + + user_message_editor.read_with(cx, |editor, cx| { + assert_eq!(editor.text(cx), "Original message to edit"); + }); + } + + #[gpui::test] + async fn test_message_editing_regenerate(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Send + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "New Response".into(), + annotations: None, + }), + }]); + + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(Chat), cx); + }); + + cx.run_until_parked(); + + thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + let entries = view.thread().unwrap().read(cx).entries(); + assert_eq!(entries.len(), 2); + assert_eq!( + entries[0].to_markdown(cx), + "## User\n\nEdited message content\n\n" + ); + assert_eq!( + entries[1].to_markdown(cx), + "## Assistant\n\nNew Response\n\n" + ); + + let new_editor = view.entry_view_state.read_with(cx, |state, _cx| { + assert!(!state.entry(1).unwrap().has_content()); + state.entry(0).unwrap().message_editor().unwrap().clone() + }); + + assert_eq!(new_editor.read(cx).text(cx), "Edited message content"); + }) + } } diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 96e8b2191f..3d258ffc57 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -462,7 +462,7 @@ impl AgentConfiguration { "modifier-send", "Use modifier to submit a message", Some( - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(), ), use_modifier_to_send, move |state, _window, cx| { diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 73915195f5..519f7980ff 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -818,12 +818,10 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } - ActiveView::ExternalAgentThread { thread_view, .. } => { - thread_view.update(cx, |thread_element, cx| { - thread_element.cancel_generation(cx) - }); - } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => {} } } diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 7dc00bfae2..131023d249 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -13,7 +13,7 @@ use anyhow::{Result, anyhow}; use collections::HashSet; pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId}; -use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; +use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset}; use fetch_context_picker::FetchContextPicker; use file_context_picker::FileContextPicker; use file_context_picker::render_file_context_entry; @@ -228,7 +228,7 @@ impl ContextPicker { } fn build_menu(&mut self, window: &mut Window, cx: &mut Context) -> Entity { - let context_picker = cx.entity().clone(); + let context_picker = cx.entity(); let menu = ContextMenu::build(window, cx, move |menu, _window, cx| { let recent = self.recent_entries(cx); @@ -837,42 +837,9 @@ fn render_fold_icon_button( ) -> Arc, &mut App) -> AnyElement> { Arc::new({ move |fold_id, fold_range, cx| { - let is_in_text_selection = editor.upgrade().is_some_and(|editor| { - editor.update(cx, |editor, cx| { - let snapshot = editor - .buffer() - .update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx)); - - let is_in_pending_selection = || { - editor - .selections - .pending - .as_ref() - .is_some_and(|pending_selection| { - pending_selection - .selection - .range() - .includes(&fold_range, &snapshot) - }) - }; - - let mut is_in_complete_selection = || { - editor - .selections - .disjoint_in_range::(fold_range.clone(), cx) - .into_iter() - .any(|selection| { - // This is needed to cover a corner case, if we just check for an existing - // selection in the fold range, having a cursor at the start of the fold - // marks it as selected. Non-empty selections don't cause this. - let length = selection.end - selection.start; - length > 0 - }) - }; - - is_in_pending_selection() || is_in_complete_selection() - }) - }); + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); ButtonLike::new(fold_id) .style(ButtonStyle::Filled) diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 4a4a747899..bbd3595805 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -72,7 +72,7 @@ pub fn init( let Some(window) = window else { return; }; - let workspace = cx.entity().clone(); + let workspace = cx.entity(); InlineAssistant::update_global(cx, |inline_assistant, cx| { inline_assistant.register_workspace(&workspace, window, cx) }); diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7121624c87..bb8514a224 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,5 +1,6 @@ use std::{cmp::Reverse, sync::Arc}; +use cloud_llm_client::Plan; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; @@ -10,7 +11,6 @@ use language_model::{ }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; -use proto::Plan; use ui::{ListItem, ListItemSpacing, prelude::*}; const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; @@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { ) -> Option { use feature_flags::FeatureFlagAppExt; - let plan = proto::Plan::ZedPro; + let plan = Plan::ZedPro; Some( h_flex() @@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { window .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx) }), - Plan::Free | Plan::ZedProTrial => Button::new( + Plan::ZedFree | Plan::ZedProTrial => Button::new( "try-pro", if plan == Plan::ZedProTrial { "Upgrade to Pro" diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index 27ca69590f..ce25f531e2 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -163,7 +163,7 @@ impl Render for ProfileSelector { .unwrap_or_else(|| "Unknown".into()); if self.provider.profiles_supported(cx) { - let this = cx.entity().clone(); + let this = cx.entity(); let focus_handle = self.focus_handle.clone(); let trigger_button = Button::new("profile-selector-model", selected_profile) .label_size(LabelSize::Small) diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 439fb100d2..3c451fcb01 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,16 +1,12 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; -use chrono::Duration; use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; -use rpc::{ - ConnectionId, Peer, Receipt, TypedEnvelope, - proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, -}; +use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto}; use std::sync::Arc; pub struct FakeServer { @@ -187,50 +183,27 @@ impl FakeServer { pub async fn receive(&self) -> Result> { self.executor.start_waiting(); - loop { - let message = self - .state - .lock() - .incoming - .as_mut() - .expect("not connected") - .next() - .await - .context("other half hung up")?; - self.executor.finish_waiting(); - let type_name = message.payload_type_name(); - let message = message.into_any(); + let message = self + .state + .lock() + .incoming + .as_mut() + .expect("not connected") + .next() + .await + .context("other half hung up")?; + self.executor.finish_waiting(); + let type_name = message.payload_type_name(); + let message = message.into_any(); - if message.is::>() { - return Ok(*message.downcast().unwrap()); - } - - let accepted_tos_at = chrono::Utc::now() - .checked_sub_signed(Duration::hours(5)) - .expect("failed to build accepted_tos_at") - .timestamp() as u64; - - if message.is::>() { - self.respond( - message - .downcast::>() - .unwrap() - .receipt(), - GetPrivateUserInfoResponse { - metrics_id: "the-metrics-id".into(), - staff: false, - flags: Default::default(), - accepted_tos_at: Some(accepted_tos_at), - }, - ); - continue; - } - - panic!( - "fake server received unexpected message type: {:?}", - type_name - ); + if message.is::>() { + return Ok(*message.downcast().unwrap()); } + + panic!( + "fake server received unexpected message type: {:?}", + type_name + ); } pub fn respond(&self, receipt: Receipt, response: T::Response) { diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index faf46945d8..da7f50076b 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -177,7 +177,6 @@ impl UserStore { let (mut current_user_tx, current_user_rx) = watch::channel(); let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); let rpc_subscriptions = vec![ - client.add_message_handler(cx.weak_entity(), Self::handle_update_plan), client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts), client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info), client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts), @@ -343,26 +342,6 @@ impl UserStore { Ok(()) } - async fn handle_update_plan( - this: Entity, - _message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - let client = this - .read_with(&cx, |this, _| this.client.upgrade())? - .context("client was dropped")?; - - let response = client - .cloud_client() - .get_authenticated_user() - .await - .context("failed to fetch authenticated user")?; - - this.update(&mut cx, |this, cx| { - this.update_authenticated_user(response, cx); - }) - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -1019,19 +998,6 @@ impl RequestUsage { } } - pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { - let limit = match limit.variant? { - proto::usage_limit::Variant::Limited(limited) => { - UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, - }; - Some(RequestUsage { - limit, - amount: amount as i32, - }) - } - fn from_headers( limit_name: &str, amount_name: &str, diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9af95317e6..6fc591be13 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -19,7 +19,6 @@ test-support = ["sqlite"] [dependencies] anyhow.workspace = true -async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } @@ -33,13 +32,11 @@ clock.workspace = true cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true -derive_more.workspace = true envy = "0.4.2" futures.workspace = true gpui.workspace = true hex.workspace = true http_client.workspace = true -jsonwebtoken.workspace = true livekit_api.workspace = true log.workspace = true nanoid.workspace = true @@ -65,7 +62,6 @@ subtle.workspace = true supermaven_api.workspace = true telemetry_events.workspace = true text.workspace = true -thiserror.workspace = true time.workspace = true tokio = { workspace = true, features = ["full"] } toml.workspace = true @@ -136,6 +132,3 @@ util.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } zlog.workspace = true - -[package.metadata.cargo-machete] -ignored = ["async-stripe"] diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 45fc018a4a..214b550ac2 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -219,12 +219,6 @@ spec: secretKeyRef: name: slack key: panics_webhook - - name: STRIPE_API_KEY - valueFrom: - secretKeyRef: - name: stripe - key: api_key - optional: true - name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR value: "1000" - name: SUPERMAVEN_ADMIN_API_KEY diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 078a4469ae..0cc7e2b2e9 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,16 +1,10 @@ -pub mod billing; pub mod contributors; pub mod events; pub mod extensions; pub mod ips_file; pub mod slack; -use crate::db::Database; -use crate::{ - AppState, Error, Result, auth, - db::{User, UserId}, - rpc, -}; +use crate::{AppState, Error, Result, auth, db::UserId, rpc}; use anyhow::Context as _; use axum::{ Extension, Json, Router, @@ -97,7 +91,6 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(contributors::router()) @@ -139,99 +132,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct LookUpUserParams { - identifier: String, -} - -#[derive(Debug, Serialize)] -struct LookUpUserResponse { - user: Option, -} - -async fn look_up_user( - Query(params): Query, - Extension(app): Extension>, -) -> Result> { - let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?; - let user = if let Some(user) = user { - match user { - UserOrId::User(user) => Some(user), - UserOrId::Id(id) => app.db.get_user_by_id(id).await?, - } - } else { - None - }; - - Ok(Json(LookUpUserResponse { user })) -} - -enum UserOrId { - User(User), - Id(UserId), -} - -async fn resolve_identifier_to_user( - db: &Arc, - identifier: &str, -) -> Result> { - if let Some(identifier) = identifier.parse::().ok() { - let user = db.get_user_by_id(UserId(identifier)).await?; - - return Ok(user.map(UserOrId::User)); - } - - if identifier.starts_with("cus_") { - let billing_customer = db - .get_billing_customer_by_stripe_customer_id(&identifier) - .await?; - - return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))); - } - - if identifier.starts_with("sub_") { - let billing_subscription = db - .get_billing_subscription_by_stripe_subscription_id(&identifier) - .await?; - - if let Some(billing_subscription) = billing_subscription { - let billing_customer = db - .get_billing_customer_by_id(billing_subscription.billing_customer_id) - .await?; - - return Ok( - billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)) - ); - } else { - return Ok(None); - } - } - - if identifier.contains('@') { - let user = db.get_user_by_email(identifier).await?; - - return Ok(user.map(UserOrId::User)); - } - - if let Some(user) = db.get_user_by_github_login(identifier).await? { - return Ok(Some(UserOrId::User(user))); - } - - Ok(None) -} - -#[derive(Deserialize, Debug)] -struct CreateUserParams { - github_user_id: i32, - github_login: String, - email_address: String, - email_confirmation_code: Option, - #[serde(default)] - admin: bool, - #[serde(default)] - invite_count: i32, -} - async fn get_rpc_server_snapshot( Extension(rpc_server): Extension>, ) -> Result { diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs deleted file mode 100644 index a0325d14c4..0000000000 --- a/crates/collab/src/api/billing.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::sync::Arc; -use stripe::SubscriptionStatus; - -use crate::AppState; -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::{CreateBillingCustomerParams, billing_customer}; -use crate::stripe_client::{StripeClient, StripeCustomerId}; - -impl From for StripeSubscriptionStatus { - fn from(value: SubscriptionStatus) -> Self { - match value { - SubscriptionStatus::Incomplete => Self::Incomplete, - SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired, - SubscriptionStatus::Trialing => Self::Trialing, - SubscriptionStatus::Active => Self::Active, - SubscriptionStatus::PastDue => Self::PastDue, - SubscriptionStatus::Canceled => Self::Canceled, - SubscriptionStatus::Unpaid => Self::Unpaid, - SubscriptionStatus::Paused => Self::Paused, - } - } -} - -/// Finds or creates a billing customer using the provided customer. -pub async fn find_or_create_billing_customer( - app: &Arc, - stripe_client: &dyn StripeClient, - customer_id: &StripeCustomerId, -) -> anyhow::Result> { - // If we already have a billing customer record associated with the Stripe customer, - // there's nothing more we need to do. - if let Some(billing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref()) - .await? - { - return Ok(Some(billing_customer)); - } - - let customer = stripe_client.get_customer(customer_id).await?; - - let Some(email) = customer.email else { - return Ok(None); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - return Ok(None); - }; - - let billing_customer = app - .db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - - Ok(Some(billing_customer)) -} diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 2f34a843a8..cd1dc42e64 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -564,170 +564,10 @@ fn for_snowflake( country_code: Option, checksum_matched: bool, ) -> impl Iterator { - body.events.into_iter().filter_map(move |event| { + body.events.into_iter().map(move |event| { let timestamp = first_event_at + Duration::milliseconds(event.milliseconds_since_first_event); - // We will need to double check, but I believe all of the events that - // are being transformed here are now migrated over to use the - // telemetry::event! macro, as of this commit so this code can go away - // when we feel enough users have upgraded past this point. let (event_type, mut event_properties) = match &event.event { - Event::Editor(e) => ( - match e.operation.as_str() { - "open" => "Editor Opened".to_string(), - "save" => "Editor Saved".to_string(), - _ => format!("Unknown Editor Event: {}", e.operation), - }, - serde_json::to_value(e).unwrap(), - ), - Event::EditPrediction(e) => ( - format!( - "Edit Prediction {}", - if e.suggestion_accepted { - "Accepted" - } else { - "Discarded" - } - ), - serde_json::to_value(e).unwrap(), - ), - Event::EditPredictionRating(e) => ( - "Edit Prediction Rated".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Call(e) => { - let event_type = match e.operation.trim() { - "unshare project" => "Project Unshared".to_string(), - "open channel notes" => "Channel Notes Opened".to_string(), - "share project" => "Project Shared".to_string(), - "join channel" => "Channel Joined".to_string(), - "hang up" => "Call Ended".to_string(), - "accept incoming" => "Incoming Call Accepted".to_string(), - "invite" => "Participant Invited".to_string(), - "disable microphone" => "Microphone Disabled".to_string(), - "enable microphone" => "Microphone Enabled".to_string(), - "enable screen share" => "Screen Share Enabled".to_string(), - "disable screen share" => "Screen Share Disabled".to_string(), - "decline incoming" => "Incoming Call Declined".to_string(), - _ => format!("Unknown Call Event: {}", e.operation), - }; - - (event_type, serde_json::to_value(e).unwrap()) - } - Event::Assistant(e) => ( - match e.phase { - telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(), - telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(), - telemetry_events::AssistantPhase::Accepted => { - "Assistant Response Accepted".to_string() - } - telemetry_events::AssistantPhase::Rejected => { - "Assistant Response Rejected".to_string() - } - }, - serde_json::to_value(e).unwrap(), - ), - Event::Cpu(_) | Event::Memory(_) => return None, - Event::App(e) => { - let mut properties = json!({}); - let event_type = match e.operation.trim() { - // App - "open" => "App Opened".to_string(), - "first open" => "App First Opened".to_string(), - "first open for release channel" => { - "App First Opened For Release Channel".to_string() - } - "close" => "App Closed".to_string(), - - // Project - "open project" => "Project Opened".to_string(), - "open node project" => { - properties["project_type"] = json!("node"); - "Project Opened".to_string() - } - "open pnpm project" => { - properties["project_type"] = json!("pnpm"); - "Project Opened".to_string() - } - "open yarn project" => { - properties["project_type"] = json!("yarn"); - "Project Opened".to_string() - } - - // SSH - "create ssh server" => "SSH Server Created".to_string(), - "create ssh project" => "SSH Project Created".to_string(), - "open ssh project" => "SSH Project Opened".to_string(), - - // Welcome Page - "welcome page: change keymap" => "Welcome Keymap Changed".to_string(), - "welcome page: change theme" => "Welcome Theme Changed".to_string(), - "welcome page: close" => "Welcome Page Closed".to_string(), - "welcome page: edit settings" => "Welcome Settings Edited".to_string(), - "welcome page: install cli" => "Welcome CLI Installed".to_string(), - "welcome page: open" => "Welcome Page Opened".to_string(), - "welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(), - "welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(), - "welcome page: toggle diagnostic telemetry" => { - "Welcome Diagnostic Telemetry Toggled".to_string() - } - "welcome page: toggle metric telemetry" => { - "Welcome Metric Telemetry Toggled".to_string() - } - "welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(), - "welcome page: view docs" => "Welcome Documentation Viewed".to_string(), - - // Extensions - "extensions page: open" => "Extensions Page Opened".to_string(), - "extensions: install extension" => "Extension Installed".to_string(), - "extensions: uninstall extension" => "Extension Uninstalled".to_string(), - - // Misc - "markdown preview: open" => "Markdown Preview Opened".to_string(), - "project diagnostics: open" => "Project Diagnostics Opened".to_string(), - "project search: open" => "Project Search Opened".to_string(), - "repl sessions: open" => "REPL Session Started".to_string(), - - // Feature Upsell - "feature upsell: toggle vim" => { - properties["source"] = json!("Feature Upsell"); - "Vim Mode Toggled".to_string() - } - _ => e - .operation - .strip_prefix("feature upsell: viewed docs (") - .and_then(|s| s.strip_suffix(')')) - .map_or_else( - || format!("Unknown App Event: {}", e.operation), - |docs_url| { - properties["url"] = json!(docs_url); - properties["source"] = json!("Feature Upsell"); - "Documentation Viewed".to_string() - }, - ), - }; - (event_type, properties) - } - Event::Setting(e) => ( - "Settings Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Extension(e) => ( - "Extension Loaded".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Edit(e) => ( - "Editor Edited".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Action(e) => ( - "Action Invoked".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Repl(e) => ( - "Kernel Status Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), Event::Flexible(e) => ( e.event_type.clone(), serde_json::to_value(&e.event_properties).unwrap(), @@ -759,7 +599,7 @@ fn for_snowflake( }) }); - Some(SnowflakeRow { + SnowflakeRow { time: timestamp, user_id: body.metrics_id.clone(), device_id: body.system_id.clone(), @@ -767,7 +607,7 @@ fn for_snowflake( event_properties, user_properties, insert_id: Some(Uuid::new_v4().to_string()), - }) + } }) } diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 522973dbc9..f5684aeec3 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -1,5 +1,4 @@ use crate::db::{BillingCustomerId, BillingSubscriptionId}; -use crate::stripe_client; use chrono::{Datelike as _, NaiveDate, Utc}; use sea_orm::entity::prelude::*; use serde::Serialize; @@ -160,17 +159,3 @@ pub enum StripeCancellationReason { #[sea_orm(string_value = "payment_failed")] PaymentFailed, } - -impl From for StripeCancellationReason { - fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self { - match value { - stripe_client::StripeCancellationDetailsReason::CancellationRequested => { - Self::CancellationRequested - } - stripe_client::StripeCancellationDetailsReason::PaymentDisputed => { - Self::PaymentDisputed - } - stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 905859ca69..a68286a5a3 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -7,8 +7,6 @@ pub mod llm; pub mod migrations; pub mod rpc; pub mod seed; -pub mod stripe_billing; -pub mod stripe_client; pub mod user_backfiller; #[cfg(test)] @@ -27,16 +25,12 @@ use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{RealStripeClient, StripeClient}; - pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String, HeaderMap), Database(sea_orm::error::DbErr), Internal(anyhow::Error), - Stripe(stripe::StripeError), } impl From for Error { @@ -51,12 +45,6 @@ impl From for Error { } } -impl From for Error { - fn from(error: stripe::StripeError) -> Self { - Self::Stripe(error) - } -} - impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -104,14 +92,6 @@ impl IntoResponse for Error { ); (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } - Error::Stripe(error) => { - log::error!( - "HTTP error {}: {:?}", - StatusCode::INTERNAL_SERVER_ERROR, - &error - ); - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() - } } } } @@ -122,7 +102,6 @@ impl std::fmt::Debug for Error { Error::Http(code, message, _headers) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -133,7 +112,6 @@ impl std::fmt::Display for Error { Error::Http(code, message, _) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -179,7 +157,6 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, - pub stripe_api_key: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -234,7 +211,6 @@ impl Config { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -269,11 +245,6 @@ pub struct AppState { pub llm_db: Option>, pub livekit_client: Option>, pub blob_store_client: Option, - /// This is a real instance of the Stripe client; we're working to replace references to this with the - /// [`StripeClient`] trait. - pub real_stripe_client: Option>, - pub stripe_client: Option>, - pub stripe_billing: Option>, pub executor: Executor, pub kinesis_client: Option<::aws_sdk_kinesis::Client>, pub config: Config, @@ -316,18 +287,11 @@ impl AppState { }; let db = Arc::new(db); - let stripe_client = build_stripe_client(&config).map(Arc::new).log_err(); let this = Self { db: db.clone(), llm_db, livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_billing: stripe_client - .clone() - .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), - real_stripe_client: stripe_client.clone(), - stripe_client: stripe_client - .map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _), executor, kinesis_client: if config.kinesis_access_key.is_some() { build_kinesis_client(&config).await.log_err() @@ -340,14 +304,6 @@ impl AppState { } } -fn build_stripe_client(config: &Config) -> anyhow::Result { - let api_key = config - .stripe_api_key - .as_ref() - .context("missing stripe_api_key")?; - Ok(stripe::Client::new(api_key)) -} - async fn build_blob_store_client(config: &Config) -> anyhow::Result { let keys = aws_sdk_s3::config::Credentials::new( config diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index de74858168..dec10232bd 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,12 +1 @@ pub mod db; -mod token; - -pub use token::*; - -pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; - -/// The name of the feature flag that bypasses the account age check. -pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check"; - -/// The minimum account age an account must have in order to use the LLM service. -pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30); diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs deleted file mode 100644 index da01c7f3be..0000000000 --- a/crates/collab/src/llm/token.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::db::billing_subscription::SubscriptionKind; -use crate::db::{billing_customer, billing_subscription, user}; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG}; -use crate::{Config, db::billing_preference}; -use anyhow::{Context as _, Result}; -use chrono::{NaiveDateTime, Utc}; -use cloud_llm_client::Plan; -use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; -use thiserror::Error; -use uuid::Uuid; - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LlmTokenClaims { - pub iat: u64, - pub exp: u64, - pub jti: String, - pub user_id: u64, - pub system_id: Option, - pub metrics_id: Uuid, - pub github_user_login: String, - pub account_created_at: NaiveDateTime, - pub is_staff: bool, - pub has_llm_closed_beta_feature_flag: bool, - pub bypass_account_age_check: bool, - pub use_llm_request_queue: bool, - pub plan: Plan, - pub has_extended_trial: bool, - pub subscription_period: (NaiveDateTime, NaiveDateTime), - pub enable_model_request_overages: bool, - pub model_request_overages_spend_limit_in_cents: u32, - pub can_use_web_search_tool: bool, - #[serde(default)] - pub has_overdue_invoices: bool, -} - -const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); - -impl LlmTokenClaims { - pub fn create( - user: &user::Model, - is_staff: bool, - billing_customer: billing_customer::Model, - billing_preferences: Option, - feature_flags: &Vec, - subscription: billing_subscription::Model, - system_id: Option, - config: &Config, - ) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - let plan = if is_staff { - Plan::ZedPro - } else { - subscription.kind.map_or(Plan::ZedFree, |kind| match kind { - SubscriptionKind::ZedFree => Plan::ZedFree, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) - }; - let subscription_period = - billing_subscription::Model::current_period(Some(subscription), is_staff) - .map(|(start, end)| (start.naive_utc(), end.naive_utc())) - .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?; - - let now = Utc::now(); - let claims = Self { - iat: now.timestamp() as u64, - exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, - jti: uuid::Uuid::new_v4().to_string(), - user_id: user.id.to_proto(), - system_id, - metrics_id: user.metrics_id, - github_user_login: user.github_login.clone(), - account_created_at: user.account_created_at(), - is_staff, - has_llm_closed_beta_feature_flag: feature_flags - .iter() - .any(|flag| flag == "llm-closed-beta"), - bypass_account_age_check: feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG), - can_use_web_search_tool: true, - use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"), - plan, - has_extended_trial: feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period, - enable_model_request_overages: billing_preferences - .as_ref() - .map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: billing_preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents as u32 - }), - has_overdue_invoices: billing_customer.has_overdue_invoices, - }; - - Ok(jsonwebtoken::encode( - &Header::default(), - &claims, - &EncodingKey::from_secret(secret.as_ref()), - )?) - } - - pub fn validate(token: &str, config: &Config) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - match jsonwebtoken::decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::default(), - ) { - Ok(token) => Ok(token.claims), - Err(e) => { - if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature { - Err(ValidateLlmTokenError::Expired) - } else { - Err(ValidateLlmTokenError::JwtError(e)) - } - } - } - } -} - -#[derive(Error, Debug)] -pub enum ValidateLlmTokenError { - #[error("access token is expired")] - Expired, - #[error("access token validation error: {0}")] - JwtError(#[from] jsonwebtoken::errors::Error), - #[error("{0}")] - Other(#[from] anyhow::Error), -} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 20641cb232..177c97f076 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -102,13 +102,6 @@ async fn main() -> Result<()> { let state = AppState::new(config, Executor::Production).await?; - if let Some(stripe_billing) = state.stripe_billing.clone() { - let executor = state.executor.clone(); - executor.spawn_detached(async move { - stripe_billing.initialize().await.trace_err(); - }); - } - if mode.is_collab() { state.db.purge_old_embeddings().await.trace_err(); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 584970a4c6..ef749ac9b7 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,14 +1,6 @@ mod connection_pool; -use crate::api::billing::find_or_create_billing_customer; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; -use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::db::LlmDatabase; -use crate::llm::{ - AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims, - MIN_ACCOUNT_AGE_FOR_LLM_USE, -}; -use crate::stripe_client::StripeCustomerId; use crate::{ AppState, Error, Result, auth, db::{ @@ -37,7 +29,6 @@ use axum::{ response::IntoResponse, routing::get, }; -use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; @@ -148,13 +139,6 @@ pub enum Principal { } impl Principal { - fn user(&self) -> &User { - match self { - Principal::User(user) => user, - Principal::Impersonated { user, .. } => user, - } - } - fn update_span(&self, span: &tracing::Span) { match &self { Principal::User(user) => { @@ -218,6 +202,7 @@ struct Session { /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, + #[allow(unused)] system_id: Option, _executor: Executor, } @@ -463,9 +448,6 @@ impl Server { .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) - .add_request_handler(get_private_user_info) - .add_request_handler(get_llm_api_token) - .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) .add_request_handler(get_supermaven_api_key) @@ -1000,8 +982,6 @@ impl Server { .await?; } - update_user_plan(session).await?; - let contacts = self.app_state.db.get_contacts(user.id).await?; { @@ -2835,214 +2815,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn current_plan(db: &Arc, user_id: UserId, is_staff: bool) -> Result { - if is_staff { - return Ok(proto::Plan::ZedPro); - } - - let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_kind = subscription.and_then(|subscription| subscription.kind); - - let plan = if let Some(subscription_kind) = subscription_kind { - match subscription_kind { - SubscriptionKind::ZedPro => proto::Plan::ZedPro, - SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial, - SubscriptionKind::ZedFree => proto::Plan::Free, - } - } else { - proto::Plan::Free - }; - - Ok(plan) -} - -async fn make_update_user_plan_message( - user: &User, - is_staff: bool, - db: &Arc, - llm_db: Option>, -) -> Result { - let feature_flags = db.get_user_flags(user.id).await?; - let plan = current_plan(db, user.id, is_staff).await?; - let billing_customer = db.get_billing_customer_by_user_id(user.id).await?; - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let (subscription_period, usage) = if let Some(llm_db) = llm_db { - let subscription = db.get_active_billing_subscription(user.id).await?; - - let subscription_period = - crate::db::billing_subscription::Model::current_period(subscription, is_staff); - - let usage = if let Some((period_start_at, period_end_at)) = subscription_period { - llm_db - .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) - .await? - } else { - None - }; - - (subscription_period, usage) - } else { - (None, None) - }; - - let bypass_account_age_check = feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG); - let account_too_young = !matches!(plan, proto::Plan::ZedPro) - && !bypass_account_age_check - && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE; - - Ok(proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: billing_customer - .as_ref() - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), - is_usage_based_billing_enabled: if is_staff { - Some(true) - } else { - billing_preferences.map(|preferences| preferences.model_request_overages_enabled) - }, - subscription_period: subscription_period.map(|(started_at, ended_at)| { - proto::SubscriptionPeriod { - started_at: started_at.timestamp() as u64, - ended_at: ended_at.timestamp() as u64, - } - }), - account_too_young: Some(account_too_young), - has_overdue_invoices: billing_customer - .map(|billing_customer| billing_customer.has_overdue_invoices), - usage: Some( - usage - .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags)) - .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)), - ), - }) -} - -fn model_requests_limit( - plan: cloud_llm_client::Plan, - feature_flags: &Vec, -) -> cloud_llm_client::UsageLimit { - match plan.model_requests_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == cloud_llm_client::Plan::ZedProTrial - && feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) - { - 1_000 - } else { - limit - }; - - cloud_llm_client::UsageLimit::Limited(limit) - } - cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, - } -} - -fn subscription_usage_to_proto( - plan: proto::Plan, - usage: crate::llm::db::subscription_usage::Model, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: usage.model_requests as u32, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: usage.edit_predictions as u32, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -fn make_default_subscription_usage( - plan: proto::Plan, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: 0, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: 0, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -async fn update_user_plan(session: &Session) -> Result<()> { - let db = session.db().await; - - let update_user_plan = make_update_user_plan_message( - session.principal.user(), - session.is_staff(), - &db.0, - session.app_state.llm_db.clone(), - ) - .await?; - - session - .peer - .send(session.connection_id, update_user_plan) - .trace_err(); - - Ok(()) -} - async fn subscribe_to_channels( _: proto::SubscribeToChannels, session: MessageContext, @@ -4211,139 +3983,6 @@ async fn mark_notification_as_read( Ok(()) } -/// Get the current users information -async fn get_private_user_info( - _request: proto::GetPrivateUserInfo, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let metrics_id = db.get_user_metrics_id(session.user_id()).await?; - let user = db - .get_user_by_id(session.user_id()) - .await? - .context("user not found")?; - let flags = db.get_user_flags(session.user_id()).await?; - - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - flags, - accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64), - })?; - Ok(()) -} - -/// Accept the terms of service (tos) on behalf of the current user -async fn accept_terms_of_service( - _request: proto::AcceptTermsOfService, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let accepted_tos_at = Utc::now(); - db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc())) - .await?; - - response.send(proto::AcceptTermsOfServiceResponse { - accepted_tos_at: accepted_tos_at.timestamp() as u64, - })?; - - // When the user accepts the terms of service, we want to refresh their LLM - // token to grant access. - session - .peer - .send(session.connection_id, proto::RefreshLlmToken {})?; - - Ok(()) -} - -async fn get_llm_api_token( - _request: proto::GetLlmToken, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let flags = db.get_user_flags(session.user_id()).await?; - - let user_id = session.user_id(); - let user = db - .get_user_by_id(user_id) - .await? - .with_context(|| format!("user {user_id} not found"))?; - - if user.accepted_tos_at.is_none() { - Err(anyhow!("terms of service not accepted"))? - } - - let stripe_client = session - .app_state - .stripe_client - .as_ref() - .context("failed to retrieve Stripe client")?; - - let stripe_billing = session - .app_state - .stripe_billing - .as_ref() - .context("failed to retrieve Stripe billing object")?; - - let billing_customer = if let Some(billing_customer) = - db.get_billing_customer_by_user_id(user.id).await? - { - billing_customer - } else { - let customer_id = stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await?; - - find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id) - .await? - .context("billing customer not found")? - }; - - let billing_subscription = - if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { - billing_subscription - } else { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let stripe_subscription = stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - - db.create_billing_subscription(&db::CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: Some(SubscriptionKind::ZedFree), - stripe_subscription_id: stripe_subscription.id.to_string(), - stripe_subscription_status: stripe_subscription.status.into(), - stripe_cancellation_reason: None, - stripe_current_period_start: Some(stripe_subscription.current_period_start), - stripe_current_period_end: Some(stripe_subscription.current_period_end), - }) - .await? - }; - - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let token = LlmTokenClaims::create( - &user, - session.is_staff(), - billing_customer, - billing_preferences, - &flags, - billing_subscription, - session.system_id.clone(), - &session.app_state.config, - )?; - response.send(proto::GetLlmTokenResponse { token })?; - Ok(()) -} - fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result { let message = match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()), diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs deleted file mode 100644 index ef5bef3e7e..0000000000 --- a/crates/collab/src/stripe_billing.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::sync::Arc; - -use anyhow::anyhow; -use collections::HashMap; -use stripe::SubscriptionStatus; -use tokio::sync::RwLock; - -use crate::Result; -use crate::stripe_client::{ - RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, - StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, - StripeSubscription, -}; - -pub struct StripeBilling { - state: RwLock, - client: Arc, -} - -#[derive(Default)] -struct StripeBillingState { - prices_by_lookup_key: HashMap, -} - -impl StripeBilling { - pub fn new(client: Arc) -> Self { - Self { - client: Arc::new(RealStripeClient::new(client.clone())), - state: RwLock::default(), - } - } - - #[cfg(test)] - pub fn test(client: Arc) -> Self { - Self { - client, - state: RwLock::default(), - } - } - - pub fn client(&self) -> &Arc { - &self.client - } - - pub async fn initialize(&self) -> Result<()> { - log::info!("StripeBilling: initializing"); - - let mut state = self.state.write().await; - - let prices = self.client.list_prices().await?; - - for price in prices { - if let Some(lookup_key) = price.lookup_key.clone() { - state.prices_by_lookup_key.insert(lookup_key, price); - } - } - - log::info!("StripeBilling: initialized"); - - Ok(()) - } - - pub async fn zed_pro_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-pro").await - } - - pub async fn zed_free_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-free").await - } - - pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .map(|price| price.id.clone()) - .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) - } - - pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .cloned() - .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) - } - - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does - /// not already exist. - /// - /// Always returns a new Stripe customer if the email address is `None`. - pub async fn find_or_create_customer_by_email( - &self, - email_address: Option<&str>, - ) -> Result { - let existing_customer = if let Some(email) = email_address { - let customers = self.client.list_customers_by_email(email).await?; - - customers.first().cloned() - } else { - None - }; - - let customer_id = if let Some(existing_customer) = existing_customer { - existing_customer.id - } else { - let customer = self - .client - .create_customer(crate::stripe_client::CreateCustomerParams { - email: email_address, - }) - .await?; - - customer.id - }; - - Ok(customer_id) - } - - pub async fn subscribe_to_zed_free( - &self, - customer_id: StripeCustomerId, - ) -> Result { - let zed_free_price_id = self.zed_free_price_id().await?; - - let existing_subscriptions = self - .client - .list_subscriptions_for_customer(&customer_id) - .await?; - - let existing_active_subscription = - existing_subscriptions.into_iter().find(|subscription| { - subscription.status == SubscriptionStatus::Active - || subscription.status == SubscriptionStatus::Trialing - }); - if let Some(subscription) = existing_active_subscription { - return Ok(subscription); - } - - let params = StripeCreateSubscriptionParams { - customer: customer_id, - items: vec![StripeCreateSubscriptionItems { - price: Some(zed_free_price_id), - quantity: Some(1), - }], - automatic_tax: Some(StripeAutomaticTax { enabled: true }), - }; - - let subscription = self.client.create_subscription(params).await?; - - Ok(subscription) - } -} diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs deleted file mode 100644 index 6e75a4d874..0000000000 --- a/crates/collab/src/stripe_client.rs +++ /dev/null @@ -1,285 +0,0 @@ -#[cfg(test)] -mod fake_stripe_client; -mod real_stripe_client; - -use std::collections::HashMap; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; - -#[cfg(test)] -pub use fake_stripe_client::*; -pub use real_stripe_client::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)] -pub struct StripeCustomerId(pub Arc); - -#[derive(Debug, Clone)] -pub struct StripeCustomer { - pub id: StripeCustomerId, - pub email: Option, -} - -#[derive(Debug)] -pub struct CreateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug)] -pub struct UpdateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscription { - pub id: StripeSubscriptionId, - pub customer: StripeCustomerId, - // TODO: Create our own version of this enum. - pub status: stripe::SubscriptionStatus, - pub current_period_end: i64, - pub current_period_start: i64, - pub items: Vec, - pub cancel_at: Option, - pub cancellation_details: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionItemId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionItem { - pub id: StripeSubscriptionItemId, - pub price: Option, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StripeCancellationDetails { - pub reason: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCancellationDetailsReason { - CancellationRequested, - PaymentDisputed, - PaymentFailed, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionParams { - pub customer: StripeCustomerId, - pub items: Vec, - pub automatic_tax: Option, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, Clone)] -pub struct UpdateSubscriptionParams { - pub items: Option>, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct UpdateSubscriptionItems { - pub price: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettings { - pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettingsEndBehavior { - pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { - Cancel, - CreateInvoice, - Pause, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripePriceId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePrice { - pub id: StripePriceId, - pub unit_amount: Option, - pub lookup_key: Option, - pub recurring: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePriceRecurring { - pub meter: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)] -pub struct StripeMeterId(pub Arc); - -#[derive(Debug, Clone, Deserialize)] -pub struct StripeMeter { - pub id: StripeMeterId, - pub event_name: String, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventParams<'a> { - pub identifier: &'a str, - pub event_name: &'a str, - pub payload: StripeCreateMeterEventPayload<'a>, - pub timestamp: Option, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventPayload<'a> { - pub value: u64, - pub stripe_customer_id: &'a StripeCustomerId, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeBillingAddressCollection { - Auto, - Required, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCustomerUpdate { - pub address: Option, - pub name: Option, - pub shipping: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateAddress { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateName { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateShipping { - Auto, - Never, -} - -#[derive(Debug, Default)] -pub struct StripeCreateCheckoutSessionParams<'a> { - pub customer: Option<&'a StripeCustomerId>, - pub client_reference_id: Option<&'a str>, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option<&'a str>, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionMode { - Payment, - Setup, - Subscription, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionLineItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionPaymentMethodCollection { - Always, - IfRequired, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionSubscriptionData { - pub metadata: Option>, - pub trial_period_days: Option, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeTaxIdCollection { - pub enabled: bool, -} - -#[derive(Debug, Clone)] -pub struct StripeAutomaticTax { - pub enabled: bool, -} - -#[derive(Debug)] -pub struct StripeCheckoutSession { - pub url: Option, -} - -#[async_trait] -pub trait StripeClient: Send + Sync { - async fn list_customers_by_email(&self, email: &str) -> Result>; - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result; - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result; - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result>; - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result; - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result; - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()>; - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>; - - async fn list_prices(&self) -> Result>; - - async fn list_meters(&self) -> Result>; - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>; - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result; -} diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs deleted file mode 100644 index 9bb08443ec..0000000000 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Result, anyhow}; -use async_trait::async_trait; -use chrono::{Duration, Utc}; -use collections::HashMap; -use parking_lot::Mutex; -use uuid::Uuid; - -use crate::stripe_client::{ - CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -#[derive(Debug, Clone)] -pub struct StripeCreateMeterEventCall { - pub identifier: Arc, - pub event_name: Arc, - pub value: u64, - pub stripe_customer_id: StripeCustomerId, - pub timestamp: Option, -} - -#[derive(Debug, Clone)] -pub struct StripeCreateCheckoutSessionCall { - pub customer: Option, - pub client_reference_id: Option, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -pub struct FakeStripeClient { - pub customers: Arc>>, - pub subscriptions: Arc>>, - pub update_subscription_calls: - Arc>>, - pub prices: Arc>>, - pub meters: Arc>>, - pub create_meter_event_calls: Arc>>, - pub create_checkout_session_calls: Arc>>, -} - -impl FakeStripeClient { - pub fn new() -> Self { - Self { - customers: Arc::new(Mutex::new(HashMap::default())), - subscriptions: Arc::new(Mutex::new(HashMap::default())), - update_subscription_calls: Arc::new(Mutex::new(Vec::new())), - prices: Arc::new(Mutex::new(HashMap::default())), - meters: Arc::new(Mutex::new(HashMap::default())), - create_meter_event_calls: Arc::new(Mutex::new(Vec::new())), - create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())), - } - } -} - -#[async_trait] -impl StripeClient for FakeStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - Ok(self - .customers - .lock() - .values() - .filter(|customer| customer.email.as_deref() == Some(email)) - .cloned() - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - self.customers - .lock() - .get(customer_id) - .cloned() - .ok_or_else(|| anyhow!("no customer found for {customer_id:?}")) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = StripeCustomer { - id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()), - email: params.email.map(|email| email.to_string()), - }; - - self.customers - .lock() - .insert(customer.id.clone(), customer.clone()); - - Ok(customer) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let mut customers = self.customers.lock(); - if let Some(customer) = customers.get_mut(customer_id) { - if let Some(email) = params.email { - customer.email = Some(email.to_string()); - } - Ok(customer.clone()) - } else { - Err(anyhow!("no customer found for {customer_id:?}")) - } - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let subscriptions = self - .subscriptions - .lock() - .values() - .filter(|subscription| subscription.customer == *customer_id) - .cloned() - .collect(); - - Ok(subscriptions) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - self.subscriptions - .lock() - .get(subscription_id) - .cloned() - .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}")) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let now = Utc::now(); - - let subscription = StripeSubscription { - id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()), - customer: params.customer, - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: params - .items - .into_iter() - .map(|item| StripeSubscriptionItem { - id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()), - price: item - .price - .and_then(|price_id| self.prices.lock().get(&price_id).cloned()), - }) - .collect(), - cancel_at: None, - cancellation_details: None, - }; - - self.subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - Ok(subscription) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription = self.get_subscription(subscription_id).await?; - - self.update_subscription_calls - .lock() - .push((subscription.id, params)); - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - // TODO: Implement fake subscription cancellation. - let _ = subscription_id; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let prices = self.prices.lock().values().cloned().collect(); - - Ok(prices) - } - - async fn list_meters(&self) -> Result> { - let meters = self.meters.lock().values().cloned().collect(); - - Ok(meters) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - self.create_meter_event_calls - .lock() - .push(StripeCreateMeterEventCall { - identifier: params.identifier.into(), - event_name: params.event_name.into(), - value: params.payload.value, - stripe_customer_id: params.payload.stripe_customer_id.clone(), - timestamp: params.timestamp, - }); - - Ok(()) - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - self.create_checkout_session_calls - .lock() - .push(StripeCreateCheckoutSessionCall { - customer: params.customer.cloned(), - client_reference_id: params.client_reference_id.map(|id| id.to_string()), - mode: params.mode, - line_items: params.line_items, - payment_method_collection: params.payment_method_collection, - subscription_data: params.subscription_data, - success_url: params.success_url.map(|url| url.to_string()), - billing_address_collection: params.billing_address_collection, - customer_update: params.customer_update, - tax_id_collection: params.tax_id_collection, - }); - - Ok(StripeCheckoutSession { - url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()), - }) - } -} diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs deleted file mode 100644 index 07c191ff30..0000000000 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::str::FromStr as _; -use std::sync::Arc; - -use anyhow::{Context as _, Result, anyhow}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use stripe::{ - CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode, - CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, - CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price, - PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, - UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, - UpdateSubscriptionTrialSettingsEndBehavior, - UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -}; - -use crate::stripe_client::{ - CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection, - StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping, - StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -pub struct RealStripeClient { - client: Arc, -} - -impl RealStripeClient { - pub fn new(client: Arc) -> Self { - Self { client } - } -} - -#[async_trait] -impl StripeClient for RealStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - let response = Customer::list( - &self.client, - &ListCustomers { - email: Some(email), - ..Default::default() - }, - ) - .await?; - - Ok(response - .data - .into_iter() - .map(StripeCustomer::from) - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - let customer_id = customer_id.try_into()?; - - let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = Customer::create( - &self.client, - CreateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let customer = Customer::update( - &self.client, - &customer_id.try_into()?, - UpdateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let customer_id = customer_id.try_into()?; - - let subscriptions = stripe::Subscription::list( - &self.client, - &stripe::ListSubscriptions { - customer: Some(customer_id), - status: None, - ..Default::default() - }, - ) - .await?; - - Ok(subscriptions - .data - .into_iter() - .map(StripeSubscription::from) - .collect()) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - let subscription_id = subscription_id.try_into()?; - - let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let customer_id = params.customer.try_into()?; - - let mut create_subscription = stripe::CreateSubscription::new(customer_id); - create_subscription.items = Some( - params - .items - .into_iter() - .map(|item| stripe::CreateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - quantity: item.quantity, - ..Default::default() - }) - .collect(), - ); - create_subscription.automatic_tax = params.automatic_tax.map(Into::into); - - let subscription = Subscription::create(&self.client, create_subscription).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - stripe::Subscription::update( - &self.client, - &subscription_id, - stripe::UpdateSubscription { - items: params.items.map(|items| { - items - .into_iter() - .map(|item| UpdateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - ..Default::default() - }) - .collect() - }), - trial_settings: params.trial_settings.map(Into::into), - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - Subscription::cancel( - &self.client, - &subscription_id, - stripe::CancelSubscription { - invoice_now: None, - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let response = stripe::Price::list( - &self.client, - &stripe::ListPrices { - limit: Some(100), - ..Default::default() - }, - ) - .await?; - - Ok(response.data.into_iter().map(StripePrice::from).collect()) - } - - async fn list_meters(&self) -> Result> { - #[derive(Serialize)] - struct Params { - #[serde(skip_serializing_if = "Option::is_none")] - limit: Option, - } - - let response = self - .client - .get_query::, _>( - "/billing/meters", - Params { limit: Some(100) }, - ) - .await?; - - Ok(response.data) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - #[derive(Deserialize)] - struct StripeMeterEvent { - pub identifier: String, - } - - let identifier = params.identifier; - match self - .client - .post_form::("/billing/meter_events", params) - .await - { - Ok(_event) => Ok(()), - Err(stripe::StripeError::Stripe(error)) => { - if error.http_status == 400 - && error - .message - .as_ref() - .map_or(false, |message| message.contains(identifier)) - { - Ok(()) - } else { - Err(anyhow!(stripe::StripeError::Stripe(error))) - } - } - Err(error) => Err(anyhow!("failed to create meter event: {error:?}")), - } - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - let params = params.try_into()?; - let session = CheckoutSession::create(&self.client, params).await?; - - Ok(session.into()) - } -} - -impl From for StripeCustomerId { - fn from(value: CustomerId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl TryFrom<&StripeCustomerId> for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: &StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl From for StripeCustomer { - fn from(value: Customer) -> Self { - StripeCustomer { - id: value.id.into(), - email: value.email, - } - } -} - -impl From for StripeSubscriptionId { - fn from(value: SubscriptionId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom<&StripeSubscriptionId> for SubscriptionId { - type Error = anyhow::Error; - - fn try_from(value: &StripeSubscriptionId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID") - } -} - -impl From for StripeSubscription { - fn from(value: Subscription) -> Self { - Self { - id: value.id.into(), - customer: value.customer.id().into(), - status: value.status, - current_period_start: value.current_period_start, - current_period_end: value.current_period_end, - items: value.items.data.into_iter().map(Into::into).collect(), - cancel_at: value.cancel_at, - cancellation_details: value.cancellation_details.map(Into::into), - } - } -} - -impl From for StripeCancellationDetails { - fn from(value: CancellationDetails) -> Self { - Self { - reason: value.reason.map(Into::into), - } - } -} - -impl From for StripeCancellationDetailsReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - -impl From for StripeSubscriptionItemId { - fn from(value: SubscriptionItemId) -> Self { - Self(value.as_str().into()) - } -} - -impl From for StripeSubscriptionItem { - fn from(value: SubscriptionItem) -> Self { - Self { - id: value.id.into(), - price: value.price.map(Into::into), - } - } -} - -impl From for CreateSubscriptionAutomaticTax { - fn from(value: StripeAutomaticTax) -> Self { - Self { - enabled: value.enabled, - liability: None, - } - } -} - -impl From for UpdateSubscriptionTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripePriceId { - fn from(value: PriceId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for PriceId { - type Error = anyhow::Error; - - fn try_from(value: StripePriceId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID") - } -} - -impl From for StripePrice { - fn from(value: Price) -> Self { - Self { - id: value.id.into(), - unit_amount: value.unit_amount, - lookup_key: value.lookup_key, - recurring: value.recurring.map(StripePriceRecurring::from), - } - } -} - -impl From for StripePriceRecurring { - fn from(value: Recurring) -> Self { - Self { meter: value.meter } - } -} - -impl<'a> TryFrom> for CreateCheckoutSession<'a> { - type Error = anyhow::Error; - - fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result { - Ok(Self { - customer: value - .customer - .map(|customer_id| customer_id.try_into()) - .transpose()?, - client_reference_id: value.client_reference_id, - mode: value.mode.map(Into::into), - line_items: value - .line_items - .map(|line_items| line_items.into_iter().map(Into::into).collect()), - payment_method_collection: value.payment_method_collection.map(Into::into), - subscription_data: value.subscription_data.map(Into::into), - success_url: value.success_url, - billing_address_collection: value.billing_address_collection.map(Into::into), - customer_update: value.customer_update.map(Into::into), - tax_id_collection: value.tax_id_collection.map(Into::into), - ..Default::default() - }) - } -} - -impl From for CheckoutSessionMode { - fn from(value: StripeCheckoutSessionMode) -> Self { - match value { - StripeCheckoutSessionMode::Payment => Self::Payment, - StripeCheckoutSessionMode::Setup => Self::Setup, - StripeCheckoutSessionMode::Subscription => Self::Subscription, - } - } -} - -impl From for CreateCheckoutSessionLineItems { - fn from(value: StripeCreateCheckoutSessionLineItems) -> Self { - Self { - price: value.price, - quantity: value.quantity, - ..Default::default() - } - } -} - -impl From for CheckoutSessionPaymentMethodCollection { - fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self { - match value { - StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always, - StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired, - } - } -} - -impl From for CreateCheckoutSessionSubscriptionData { - fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self { - Self { - trial_period_days: value.trial_period_days, - trial_settings: value.trial_settings.map(Into::into), - metadata: value.metadata, - ..Default::default() - } - } -} - -impl From for CreateCheckoutSessionSubscriptionDataTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripeCheckoutSession { - fn from(value: CheckoutSession) -> Self { - Self { url: value.url } - } -} - -impl From for stripe::CheckoutSessionBillingAddressCollection { - fn from(value: StripeBillingAddressCollection) -> Self { - match value { - StripeBillingAddressCollection::Auto => { - stripe::CheckoutSessionBillingAddressCollection::Auto - } - StripeBillingAddressCollection::Required => { - stripe::CheckoutSessionBillingAddressCollection::Required - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateAddress { - fn from(value: StripeCustomerUpdateAddress) -> Self { - match value { - StripeCustomerUpdateAddress::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto - } - StripeCustomerUpdateAddress::Never => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateName { - fn from(value: StripeCustomerUpdateName) -> Self { - match value { - StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto, - StripeCustomerUpdateName::Never => { - stripe::CreateCheckoutSessionCustomerUpdateName::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateShipping { - fn from(value: StripeCustomerUpdateShipping) -> Self { - match value { - StripeCustomerUpdateShipping::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto - } - StripeCustomerUpdateShipping::Never => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdate { - fn from(value: StripeCustomerUpdate) -> Self { - stripe::CreateCheckoutSessionCustomerUpdate { - address: value.address.map(Into::into), - name: value.name.map(Into::into), - shipping: value.shipping.map(Into::into), - } - } -} - -impl From for stripe::CreateCheckoutSessionTaxIdCollection { - fn from(value: StripeTaxIdCollection) -> Self { - stripe::CreateCheckoutSessionTaxIdCollection { - enabled: value.enabled, - } - } -} diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 8d5d076780..ddf245b06f 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -8,7 +8,6 @@ mod channel_buffer_tests; mod channel_guest_tests; mod channel_message_tests; mod channel_tests; -// mod debug_panel_tests; mod editor_tests; mod following_tests; mod git_tests; @@ -18,7 +17,6 @@ mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; mod remote_editing_collaboration_tests; -mod stripe_billing_tests; mod test_server; use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs deleted file mode 100644 index bb84bedfcf..0000000000 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::sync::Arc; - -use pretty_assertions::assert_eq; - -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; - -fn make_stripe_billing() -> (StripeBilling, Arc) { - let stripe_client = Arc::new(FakeStripeClient::new()); - let stripe_billing = StripeBilling::test(stripe_client.clone()); - - (stripe_billing, stripe_client) -} - -#[gpui::test] -async fn test_initialize() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Add test prices - let price1 = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(1_000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - let price2 = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - let price3 = StripePrice { - id: StripePriceId("price_3".into()), - unit_amount: Some(500), - lookup_key: None, - recurring: Some(StripePriceRecurring { - meter: Some("meter_1".to_string()), - }), - }; - stripe_client - .prices - .lock() - .insert(price1.id.clone(), price1); - stripe_client - .prices - .lock() - .insert(price2.id.clone(), price2); - stripe_client - .prices - .lock() - .insert(price3.id.clone(), price3); - - // Initialize the billing system - stripe_billing.initialize().await.unwrap(); - - // Verify that prices can be found by lookup key - let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap(); - assert_eq!(zed_pro_price_id.to_string(), "price_1"); - - let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap(); - assert_eq!(zed_free_price_id.to_string(), "price_2"); - - // Verify that a price can be found by lookup key - let zed_pro_price = stripe_billing - .find_price_by_lookup_key("zed-pro") - .await - .unwrap(); - assert_eq!(zed_pro_price.id.to_string(), "price_1"); - assert_eq!(zed_pro_price.unit_amount, Some(1_000)); - - // Verify that finding a non-existent lookup key returns an error - let result = stripe_billing - .find_price_by_lookup_key("non-existent") - .await; - assert!(result.is_err()); -} - -#[gpui::test] -async fn test_find_or_create_customer_by_email() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Create a customer with an email that doesn't yet correspond to a customer. - { - let email = "user@example.com"; - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } - - // Create a customer with an email that corresponds to an existing customer. - { - let email = "user2@example.com"; - - let existing_customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - assert_eq!(customer_id, existing_customer_id); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index f5a0e8ea81..8c545b0670 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,4 +1,3 @@ -use crate::stripe_client::FakeStripeClient; use crate::{ AppState, Config, db::{NewUserParams, UserId, tests::TestDb}, @@ -569,9 +568,6 @@ impl TestServer { llm_db: None, livekit_client: Some(Arc::new(livekit_test_server.create_api_client())), blob_store_client: None, - real_stripe_client: None, - stripe_client: Some(Arc::new(FakeStripeClient::new())), - stripe_billing: None, executor, kinesis_client: None, config: Config { @@ -608,7 +604,6 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 51d9f003f8..2bbaa8446c 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -674,7 +674,7 @@ impl ChatPanel { }) }) .when_some(message_id, |el, message_id| { - let this = cx.entity().clone(); + let this = cx.entity(); el.child( self.render_popover_button( diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 430b447580..c2cc6a7ad5 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -95,7 +95,7 @@ pub fn init(cx: &mut App) { .and_then(|room| room.read(cx).channel_id()); if let Some(channel_id) = channel_id { - let workspace = cx.entity().clone(); + let workspace = cx.entity(); window.defer(cx, move |window, cx| { ChannelView::open(channel_id, None, workspace, window, cx) .detach_and_log_err(cx) @@ -1142,7 +1142,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); if !(role == proto::ChannelRole::Guest || role == proto::ChannelRole::Talker || role == proto::ChannelRole::Member) @@ -1272,7 +1272,7 @@ impl CollabPanel { .channel_for_id(clipboard.channel_id) .map(|channel| channel.name.clone()) }); - let this = cx.entity().clone(); + let this = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| { if self.has_subchannels(ix) { @@ -1439,7 +1439,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); let in_room = ActiveCall::global(cx).read(cx).room().is_some(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| { diff --git a/crates/collab_ui/src/collab_panel/channel_modal.rs b/crates/collab_ui/src/collab_panel/channel_modal.rs index c0d3130ee9..e558835dba 100644 --- a/crates/collab_ui/src/collab_panel/channel_modal.rs +++ b/crates/collab_ui/src/collab_panel/channel_modal.rs @@ -586,7 +586,7 @@ impl ChannelModalDelegate { return; }; let user_id = membership.user.id; - let picker = cx.entity().clone(); + let picker = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| { let role = membership.role; diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index 3a280ff667..a3420d603b 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -321,7 +321,7 @@ impl NotificationPanel { .justify_end() .child(Button::new("decline", "Decline").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( @@ -334,7 +334,7 @@ impl NotificationPanel { })) .child(Button::new("accept", "Accept").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index c8bee42039..f3117aee07 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -291,7 +291,7 @@ pub(crate) fn new_debugger_pane( let Some(project) = project.upgrade() else { return ControlFlow::Break(()); }; - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { @@ -502,7 +502,7 @@ pub(crate) fn new_debugger_pane( .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail: 0, is_active: selected, ix, diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 8fb223b2cb..5df1b13897 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -971,7 +971,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -1065,7 +1065,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -1239,7 +1239,7 @@ async fn test_diagnostics_with_links(cx: &mut TestAppContext) { } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { @@ -1293,7 +1293,7 @@ async fn test_hover_diagnostic_and_info_popovers(cx: &mut gpui::TestAppContext) fn «test»() { println!(); } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { lsp_store.update_diagnostics( @@ -1450,7 +1450,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {"error warning info hiˇnt"}); diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 3d3b43d71b..4632a03daf 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -127,7 +127,7 @@ impl Render for EditPredictionButton { }), ); } - let this = cx.entity().clone(); + let this = cx.entity(); div().child( PopoverMenu::new("copilot") @@ -182,7 +182,7 @@ impl Render for EditPredictionButton { let icon = status.to_icon(); let tooltip_text = status.to_tooltip(); let has_menu = status.has_menu(); - let this = cx.entity().clone(); + let this = cx.entity(); let fs = self.fs.clone(); return div().child( @@ -331,7 +331,7 @@ impl Render for EditPredictionButton { }) }); - let this = cx.entity().clone(); + let this = cx.entity(); let mut popover_menu = PopoverMenu::new("zeta") .menu(move |window, cx| { diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index a461fec791..0bda46774a 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1039,9 +1039,7 @@ pub struct Editor { inline_diagnostics: Vec<(Anchor, InlineDiagnostic)>, soft_wrap_mode_override: Option, hard_wrap: Option, - - // TODO: make this a access method - pub project: Option>, + project: Option>, semantics_provider: Option>, completion_provider: Option>, collaboration_hub: Option>, @@ -2326,7 +2324,7 @@ impl Editor { editor.go_to_active_debug_line(window, cx); if let Some(buffer) = buffer.read(cx).as_singleton() { - if let Some(project) = editor.project.as_ref() { + if let Some(project) = editor.project() { let handle = project.update(cx, |project, cx| { project.register_buffer_with_language_servers(&buffer, cx) }); @@ -2371,6 +2369,34 @@ impl Editor { .is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window)) } + pub fn is_range_selected(&mut self, range: &Range, cx: &mut Context) -> bool { + if self + .selections + .pending + .as_ref() + .is_some_and(|pending_selection| { + let snapshot = self.buffer().read(cx).snapshot(cx); + pending_selection + .selection + .range() + .includes(&range, &snapshot) + }) + { + return true; + } + + self.selections + .disjoint_in_range::(range.clone(), cx) + .into_iter() + .any(|selection| { + // This is needed to cover a corner case, if we just check for an existing + // selection in the fold range, having a cursor at the start of the fold + // marks it as selected. Non-empty selections don't cause this. + let length = selection.end - selection.start; + length > 0 + }) + } + pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { self.key_context_internal(self.has_active_edit_prediction(), window, cx) } @@ -2626,6 +2652,10 @@ impl Editor { &self.buffer } + pub fn project(&self) -> Option<&Entity> { + self.project.as_ref() + } + pub fn workspace(&self) -> Option> { self.workspace.as_ref()?.0.upgrade() } @@ -5212,7 +5242,7 @@ impl Editor { restrict_to_languages: Option<&HashSet>>, cx: &mut Context, ) -> HashMap, clock::Global, Range)> { - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return HashMap::default(); }; let project = project.read(cx); @@ -5294,7 +5324,7 @@ impl Editor { return None; } - let project = self.project.as_ref()?; + let project = self.project()?; let position = self.selections.newest_anchor().head(); let (buffer, buffer_position) = self .buffer @@ -6141,7 +6171,7 @@ impl Editor { cx: &mut App, ) -> Task> { maybe!({ - let project = self.project.as_ref()?; + let project = self.project()?; let dap_store = project.read(cx).dap_store(); let mut scenarios = vec![]; let resolved_tasks = resolved_tasks.as_ref()?; @@ -7907,7 +7937,7 @@ impl Editor { let snapshot = self.snapshot(window, cx); let multi_buffer_snapshot = &snapshot.display_snapshot.buffer_snapshot; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return breakpoint_display_points; }; @@ -10501,7 +10531,7 @@ impl Editor { ) { if let Some(working_directory) = self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let parent = match &entry.canonical_path { Some(canonical_path) => canonical_path.to_path_buf(), @@ -14875,7 +14905,7 @@ impl Editor { self.clear_tasks(); return Task::ready(()); } - let project = self.project.as_ref().map(Entity::downgrade); + let project = self.project().map(Entity::downgrade); let task_sources = self.lsp_task_sources(cx); let multi_buffer = self.buffer.downgrade(); cx.spawn_in(window, async move |editor, cx| { @@ -17054,7 +17084,7 @@ impl Editor { if !pull_diagnostics_settings.enabled { return None; } - let project = self.project.as_ref()?.downgrade(); + let project = self.project()?.downgrade(); let debounce = Duration::from_millis(pull_diagnostics_settings.debounce_ms); let mut buffers = self.buffer.read(cx).all_buffers(); if let Some(buffer_id) = buffer_id { @@ -18018,7 +18048,7 @@ impl Editor { hunks: impl Iterator, cx: &mut App, ) -> Option<()> { - let project = self.project.as_ref()?; + let project = self.project()?; let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?; let diff = self.buffer.read(cx).diff_for(buffer_id)?; let buffer_snapshot = buffer.read(cx).snapshot(); @@ -18678,7 +18708,7 @@ impl Editor { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let buffer = buffer.read(cx); if let Some(project_path) = buffer.project_path(cx) { - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); project.absolute_path(&project_path, cx) } else { buffer @@ -18691,7 +18721,7 @@ impl Editor { fn target_file_path(&self, cx: &mut Context) -> Option { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let path = entry.path.to_path_buf(); Some(path) @@ -18912,7 +18942,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if let Some(project) = self.project.as_ref() { + if let Some(project) = self.project() { let Some(buffer) = self.buffer().read(cx).as_singleton() else { return; }; @@ -19028,7 +19058,7 @@ impl Editor { return Task::ready(Err(anyhow!("failed to determine buffer and selection"))); }; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return Task::ready(Err(anyhow!("editor does not have project"))); }; @@ -21015,7 +21045,7 @@ impl Editor { cx: &mut Context, ) { let workspace = self.workspace(); - let project = self.project.as_ref(); + let project = self.project(); let save_tasks = self.buffer().update(cx, |multi_buffer, cx| { let mut tasks = Vec::new(); for (buffer_id, changes) in revert_changes { diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index a5966b3301..ef2bdc5da3 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -74,7 +74,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let editor1 = cx.add_window({ let events = events.clone(); |window, cx| { - let entity = cx.entity().clone(); + let entity = cx.entity(); cx.subscribe_in( &entity, window, @@ -95,7 +95,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let events = events.clone(); |window, cx| { cx.subscribe_in( - &cx.entity().clone(), + &cx.entity(), window, move |_, _, event: &EditorEvent, _, _| match event { EditorEvent::Edited { .. } => events.borrow_mut().push(("editor2", "edited")), @@ -15082,7 +15082,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -19634,13 +19634,8 @@ fn test_crease_insertion_and_rendering(cx: &mut TestAppContext) { editor.insert_creases(Some(crease), cx); let snapshot = editor.snapshot(window, cx); - let _div = snapshot.render_crease_toggle( - MultiBufferRow(1), - false, - cx.entity().clone(), - window, - cx, - ); + let _div = + snapshot.render_crease_toggle(MultiBufferRow(1), false, cx.entity(), window, cx); snapshot }) .unwrap(); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 465650a38e..64844270f4 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -7818,7 +7818,7 @@ impl Element for EditorElement { min_lines, max_lines, } => { - let editor_handle = cx.entity().clone(); + let editor_handle = cx.entity(); let max_line_number_width = self.max_line_number_width(&editor.snapshot(window, cx), window); window.request_measured_layout( diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index b5904abca5..a443de6c34 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -250,7 +250,7 @@ fn show_hover( let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?; - let language_registry = editor.project.as_ref()?.read(cx).languages().clone(); + let language_registry = editor.project()?.read(cx).languages().clone(); let provider = editor.semantics_provider.clone()?; if !ignore_timeout { diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 45a4f7365c..34533002ff 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -678,7 +678,7 @@ impl Item for Editor { let buffer = buffer.read(cx); let path = buffer.project_path(cx)?; let buffer_id = buffer.remote_id(); - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&path, cx)?; let (repo, repo_path) = project .git_store() diff --git a/crates/editor/src/linked_editing_ranges.rs b/crates/editor/src/linked_editing_ranges.rs index a185de33ca..aaf9032b04 100644 --- a/crates/editor/src/linked_editing_ranges.rs +++ b/crates/editor/src/linked_editing_ranges.rs @@ -51,7 +51,7 @@ pub(super) fn refresh_linked_ranges( if editor.pending_rename.is_some() { return None; } - let project = editor.project.as_ref()?.downgrade(); + let project = editor.project()?.downgrade(); editor.linked_editing_range_task = Some(cx.spawn_in(window, async move |editor, cx| { cx.background_executor().timer(UPDATE_DEBOUNCE).await; diff --git a/crates/editor/src/signature_help.rs b/crates/editor/src/signature_help.rs index 6c1fed7ef8..d242e886c2 100644 --- a/crates/editor/src/signature_help.rs +++ b/crates/editor/src/signature_help.rs @@ -169,7 +169,7 @@ impl Editor { else { return; }; - let Some(lsp_store) = self.project.as_ref().map(|p| p.read(cx).lsp_store()) else { + let Some(lsp_store) = self.project().map(|p| p.read(cx).lsp_store()) else { return; }; let task = lsp_store.update(cx, |lsp_store, cx| { diff --git a/crates/editor/src/test/editor_test_context.rs b/crates/editor/src/test/editor_test_context.rs index bdf73da5fb..dbb519c40e 100644 --- a/crates/editor/src/test/editor_test_context.rs +++ b/crates/editor/src/test/editor_test_context.rs @@ -297,9 +297,8 @@ impl EditorTestContext { pub fn set_head_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_head_for_repo( &Self::root_path().join(".git"), @@ -311,18 +310,16 @@ impl EditorTestContext { pub fn clear_index_text(&mut self) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); fs.set_index_for_repo(&Self::root_path().join(".git"), &[]); self.cx.run_until_parked(); } pub fn set_index_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_index_for_repo( &Self::root_path().join(".git"), @@ -333,9 +330,8 @@ impl EditorTestContext { #[track_caller] pub fn assert_index_text(&mut self, expected: Option<&str>) { - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); let mut found = None; fs.with_git_state(&Self::root_path().join(".git"), false, |git_state| { diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index ee671236c6..f63287f55a 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -701,7 +701,7 @@ impl ExtensionsPage { extension: &ExtensionMetadata, cx: &mut Context, ) -> ExtensionCard { - let this = cx.entity().clone(); + let this = cx.entity(); let status = Self::extension_status(&extension.id, cx); let has_dev_extension = Self::dev_extension_exists(&extension.id, cx); diff --git a/crates/git_ui/src/conflict_view.rs b/crates/git_ui/src/conflict_view.rs index 0bbb9411be..6482ebb9f8 100644 --- a/crates/git_ui/src/conflict_view.rs +++ b/crates/git_ui/src/conflict_view.rs @@ -112,7 +112,7 @@ fn excerpt_for_buffer_updated( } fn buffer_added(editor: &mut Editor, buffer: Entity, cx: &mut Context) { - let Some(project) = &editor.project else { + let Some(project) = editor.project() else { return; }; let git_store = project.read(cx).git_store().clone(); @@ -469,7 +469,7 @@ pub(crate) fn resolve_conflict( let Some((workspace, project, multibuffer, buffer)) = editor .update(cx, |editor, cx| { let workspace = editor.workspace()?; - let project = editor.project.clone()?; + let project = editor.project()?.clone(); let multibuffer = editor.buffer().clone(); let buffer_id = resolved_conflict.ours.end.buffer_id?; let buffer = multibuffer.read(cx).buffer(buffer_id)?; diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index aa305cbd8d..4b30bdebc0 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -3246,7 +3246,7 @@ impl GitPanel { * MAX_PANEL_EDITOR_LINES + gap; - let git_panel = cx.entity().clone(); + let git_panel = cx.entity(); let display_name = SharedString::from(Arc::from( active_repository .read(cx) diff --git a/crates/gpui/examples/input.rs b/crates/gpui/examples/input.rs index 52a5b08b96..b0f560e38d 100644 --- a/crates/gpui/examples/input.rs +++ b/crates/gpui/examples/input.rs @@ -595,9 +595,7 @@ impl Render for TextInput { .w_full() .p(px(4.)) .bg(white()) - .child(TextElement { - input: cx.entity().clone(), - }), + .child(TextElement { input: cx.entity() }), ) } } diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 606f3a3f0e..823d59ce12 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -1358,7 +1358,7 @@ impl Render for LspLogToolbarItemView { }) .collect(); - let log_toolbar_view = cx.entity().clone(); + let log_toolbar_view = cx.entity(); let lsp_menu = PopoverMenu::new("LspLogView") .anchor(Corner::TopLeft) diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 50547253a9..3244350a34 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1007,7 +1007,7 @@ impl Render for LspTool { (None, "All Servers Operational") }; - let lsp_tool = cx.entity().clone(); + let lsp_tool = cx.entity(); div().child( PopoverMenu::new("lsp-tool") diff --git a/crates/language_tools/src/syntax_tree_view.rs b/crates/language_tools/src/syntax_tree_view.rs index eadba2c1d2..9946442ec8 100644 --- a/crates/language_tools/src/syntax_tree_view.rs +++ b/crates/language_tools/src/syntax_tree_view.rs @@ -456,7 +456,7 @@ impl SyntaxTreeToolbarItemView { let active_layer = buffer_state.active_layer.clone()?; let active_buffer = buffer_state.buffer.read(cx).snapshot(); - let view = cx.entity().clone(); + let view = cx.entity(); Some( PopoverMenu::new("Syntax Tree") .trigger(Self::render_header(&active_layer)) diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index b248b4d4a2..64ac605bb7 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -4655,51 +4655,45 @@ impl OutlinePanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |outline_panel, range, _, _| { - let entries = outline_panel.cached_entries.get(range); - if let Some(entries) = entries { - entries.into_iter().map(|item| item.depth).collect() - } else { - smallvec::SmallVec::new() - } - }, - ) - .with_render_fn( - cx.entity().clone(), - move |outline_panel, params, _, _| { - const LEFT_OFFSET: Pixels = px(14.); + .with_compute_indents_fn(cx.entity(), |outline_panel, range, _, _| { + let entries = outline_panel.cached_entries.get(range); + if let Some(entries) = entries { + entries.into_iter().map(|item| item.depth).collect() + } else { + smallvec::SmallVec::new() + } + }) + .with_render_fn(cx.entity(), move |outline_panel, params, _, _| { + const LEFT_OFFSET: Pixels = px(14.); - let indent_size = params.indent_size; - let item_height = params.item_height; - let active_indent_guide_ix = find_active_indent_guide_ix( - outline_panel, - ¶ms.indent_guides, - ); + let indent_size = params.indent_size; + let item_height = params.item_height; + let active_indent_guide_ix = find_active_indent_guide_ix( + outline_panel, + ¶ms.indent_guides, + ); - params - .indent_guides - .into_iter() - .enumerate() - .map(|(ix, layout)| { - let bounds = Bounds::new( - point( - layout.offset.x * indent_size + LEFT_OFFSET, - layout.offset.y * item_height, - ), - size(px(1.), layout.length * item_height), - ); - ui::RenderedIndentGuide { - bounds, - layout, - is_active: active_indent_guide_ix == Some(ix), - hitbox: None, - } - }) - .collect() - }, - ), + params + .indent_guides + .into_iter() + .enumerate() + .map(|(ix, layout)| { + let bounds = Bounds::new( + point( + layout.offset.x * indent_size + LEFT_OFFSET, + layout.offset.y * item_height, + ), + size(px(1.), layout.length * item_height), + ); + ui::RenderedIndentGuide { + bounds, + layout, + is_active: active_indent_guide_ix == Some(ix), + hitbox: None, + } + }) + .collect() + }), ) }) .custom_scrollbars( diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 5dc1915c0f..e05466a57c 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -5187,26 +5187,22 @@ impl Render for ProjectPanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |this, range, window, cx| { - let mut items = - SmallVec::with_capacity(range.end - range.start); - this.iter_visible_entries( - range, - window, - cx, - |entry, _, entries, _, _| { - let (depth, _) = - Self::calculate_depth_and_difference( - entry, entries, - ); - items.push(depth); - }, - ); - items - }, - ) + .with_compute_indents_fn(cx.entity(), |this, range, window, cx| { + let mut items = + SmallVec::with_capacity(range.end - range.start); + this.iter_visible_entries( + range, + window, + cx, + |entry, _, entries, _, _| { + let (depth, _) = Self::calculate_depth_and_difference( + entry, entries, + ); + items.push(depth); + }, + ); + items + }) .on_click(cx.listener( |this, active_indent_guide: &IndentGuideLayout, window, cx| { if window.modifiers().secondary() { @@ -5230,7 +5226,7 @@ impl Render for ProjectPanel { } }, )) - .with_render_fn(cx.entity().clone(), move |this, params, _, cx| { + .with_render_fn(cx.entity(), move |this, params, _, cx| { const LEFT_OFFSET: Pixels = px(14.); const PADDING_Y: Pixels = px(4.); const HITBOX_OVERDRAW: Pixels = px(3.); @@ -5283,7 +5279,7 @@ impl Render for ProjectPanel { }) .when(show_sticky_entries, |list| { let sticky_items = ui::sticky_items( - cx.entity().clone(), + cx.entity(), |this, range, window, cx| { let mut items = SmallVec::with_capacity(range.end - range.start); this.iter_visible_entries( @@ -5310,7 +5306,7 @@ impl Render for ProjectPanel { list.with_decoration(if show_indent_guides { sticky_items.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_render_fn(cx.entity().clone(), move |_, params, _, _| { + .with_render_fn(cx.entity(), move |_, params, _, _| { const LEFT_OFFSET: Pixels = px(14.); let indent_size = params.indent_size; diff --git a/crates/proto/proto/ai.proto b/crates/proto/proto/ai.proto index 67c2224387..1064ed2f8d 100644 --- a/crates/proto/proto/ai.proto +++ b/crates/proto/proto/ai.proto @@ -158,14 +158,6 @@ message SynchronizeContextsResponse { repeated ContextVersion contexts = 1; } -message GetLlmToken {} - -message GetLlmTokenResponse { - string token = 1; -} - -message RefreshLlmToken {} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 353f19adb2..1f2ab1f539 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -6,62 +6,6 @@ message UpdateInviteInfo { uint32 count = 2; } -message GetPrivateUserInfo {} - -message GetPrivateUserInfoResponse { - string metrics_id = 1; - bool staff = 2; - repeated string flags = 3; - optional uint64 accepted_tos_at = 4; -} - -enum Plan { - Free = 0; - ZedPro = 1; - ZedProTrial = 2; -} - -message UpdateUserPlan { - Plan plan = 1; - optional uint64 trial_started_at = 2; - optional bool is_usage_based_billing_enabled = 3; - optional SubscriptionUsage usage = 4; - optional SubscriptionPeriod subscription_period = 5; - optional bool account_too_young = 6; - optional bool has_overdue_invoices = 7; -} - -message SubscriptionPeriod { - uint64 started_at = 1; - uint64 ended_at = 2; -} - -message SubscriptionUsage { - uint32 model_requests_usage_amount = 1; - UsageLimit model_requests_usage_limit = 2; - uint32 edit_predictions_usage_amount = 3; - UsageLimit edit_predictions_usage_limit = 4; -} - -message UsageLimit { - oneof variant { - Limited limited = 1; - Unlimited unlimited = 2; - } - - message Limited { - uint32 limit = 1; - } - - message Unlimited {} -} - -message AcceptTermsOfService {} - -message AcceptTermsOfServiceResponse { - uint64 accepted_tos_at = 1; -} - message ShutdownRemoteServer {} message Toast { diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 856a793c2f..310fcf584e 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -135,12 +135,7 @@ message Envelope { FollowResponse follow_response = 99; UpdateFollowers update_followers = 100; Unfollow unfollow = 101; - GetPrivateUserInfo get_private_user_info = 102; - GetPrivateUserInfoResponse get_private_user_info_response = 103; - UpdateUserPlan update_user_plan = 234; UpdateDiffBases update_diff_bases = 104; - AcceptTermsOfService accept_terms_of_service = 239; - AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; OnTypeFormatting on_type_formatting = 105; OnTypeFormattingResponse on_type_formatting_response = 106; @@ -250,10 +245,6 @@ message Envelope { AddWorktree add_worktree = 222; AddWorktreeResponse add_worktree_response = 223; - GetLlmToken get_llm_token = 235; - GetLlmTokenResponse get_llm_token_response = 236; - RefreshLlmToken refresh_llm_token = 259; - LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241; LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242; @@ -406,6 +397,7 @@ message Envelope { } reserved 87 to 88; + reserved 102 to 103; reserved 158 to 161; reserved 164; reserved 166 to 169; @@ -419,10 +411,13 @@ message Envelope { reserved 221; reserved 224 to 229; reserved 230 to 231; + reserved 234 to 236; + reserved 239 to 240; reserved 246; - reserved 270; reserved 247 to 254; reserved 255 to 256; + reserved 259; + reserved 270; reserved 280 to 281; reserved 332 to 333; } diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index a5dd97661f..802db09590 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -20,8 +20,6 @@ pub const SSH_PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 }; pub const SSH_PROJECT_ID: u64 = 0; messages!( - (AcceptTermsOfService, Foreground), - (AcceptTermsOfServiceResponse, Foreground), (Ack, Foreground), (AckBufferOperation, Background), (AckChannelMessage, Background), @@ -105,8 +103,6 @@ messages!( (GetPathMetadataResponse, Background), (GetPermalinkToLine, Foreground), (GetPermalinkToLineResponse, Foreground), - (GetPrivateUserInfo, Foreground), - (GetPrivateUserInfoResponse, Foreground), (GetProjectSymbols, Background), (GetProjectSymbolsResponse, Background), (GetReferences, Background), @@ -119,8 +115,6 @@ messages!( (GetTypeDefinitionResponse, Background), (GetImplementation, Background), (GetImplementationResponse, Background), - (GetLlmToken, Background), - (GetLlmTokenResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -196,7 +190,6 @@ messages!( (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), (RefreshInlayHints, Foreground), - (RefreshLlmToken, Background), (RegisterBufferWithLanguageServers, Background), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -280,7 +273,6 @@ messages!( (UpdateProject, Foreground), (UpdateProjectCollaborator, Foreground), (UpdateUserChannels, Foreground), - (UpdateUserPlan, Foreground), (UpdateWorktree, Foreground), (UpdateWorktreeSettings, Foreground), (UpdateRepository, Foreground), @@ -321,7 +313,6 @@ messages!( ); request_messages!( - (AcceptTermsOfService, AcceptTermsOfServiceResponse), (ApplyCodeAction, ApplyCodeActionResponse), ( ApplyCompletionAdditionalEdits, @@ -354,9 +345,7 @@ request_messages!( (GetDocumentHighlights, GetDocumentHighlightsResponse), (GetDocumentSymbols, GetDocumentSymbolsResponse), (GetHover, GetHoverResponse), - (GetLlmToken, GetLlmTokenResponse), (GetNotifications, GetNotificationsResponse), - (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), (GetReferences, GetReferencesResponse), (GetSignatureHelp, GetSignatureHelpResponse), diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index cd65f8fa6a..81562f5b14 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -1291,7 +1291,7 @@ impl RemoteServerProjects { let connection_string = connection_string.clone(); move |_, _: &menu::Confirm, window, cx| { remove_ssh_server( - cx.entity().clone(), + cx.entity(), server_index, connection_string.clone(), window, @@ -1311,7 +1311,7 @@ impl RemoteServerProjects { .child(Label::new("Remove Server").color(Color::Error)) .on_click(cx.listener(move |_, _, window, cx| { remove_ssh_server( - cx.entity().clone(), + cx.entity(), server_index, connection_string.clone(), window, diff --git a/crates/repl/src/session.rs b/crates/repl/src/session.rs index 729a616135..f945e5ed9f 100644 --- a/crates/repl/src/session.rs +++ b/crates/repl/src/session.rs @@ -244,7 +244,7 @@ impl Session { repl_session_id = cx.entity_id().to_string(), ); - let session_view = cx.entity().clone(); + let session_view = cx.entity(); let kernel = match self.kernel_specification.clone() { KernelSpecification::Jupyter(kernel_specification) diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index da2d35d74c..189f48e6b6 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -2,9 +2,9 @@ mod registrar; use crate::{ FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOption, - SearchOptions, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive, - ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord, - search_bar::{input_base_styles, render_action_button, render_text_input}, + SearchOptions, SearchSource, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, + ToggleCaseSensitive, ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord, + search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input}, }; use any_vec::AnyVec; use anyhow::Context as _; @@ -213,22 +213,25 @@ impl Render for BufferSearchBar { h_flex() .gap_1() .when(case, |div| { - div.child( - SearchOption::CaseSensitive - .as_button(self.search_options, focus_handle.clone()), - ) + div.child(SearchOption::CaseSensitive.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) }) .when(word, |div| { - div.child( - SearchOption::WholeWord - .as_button(self.search_options, focus_handle.clone()), - ) + div.child(SearchOption::WholeWord.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) }) .when(regex, |div| { - div.child( - SearchOption::Regex - .as_button(self.search_options, focus_handle.clone()), - ) + div.child(SearchOption::Regex.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) }), ) }); @@ -240,7 +243,7 @@ impl Render for BufferSearchBar { this.child(render_action_button( "buffer-search-bar-toggle", IconName::Replace, - self.replace_enabled, + self.replace_enabled.then_some(ActionButtonState::Toggled), "Toggle Replace", &ToggleReplace, focus_handle.clone(), @@ -285,7 +288,9 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-nav-button", ui::IconName::ChevronLeft, - self.active_match_index.is_some(), + self.active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Previous Match", &SelectPreviousMatch, query_focus.clone(), @@ -293,7 +298,9 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-nav-button", ui::IconName::ChevronRight, - self.active_match_index.is_some(), + self.active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Next Match", &SelectNextMatch, query_focus.clone(), @@ -313,7 +320,7 @@ impl Render for BufferSearchBar { el.child(render_action_button( "buffer-search-nav-button", IconName::SelectAll, - true, + Default::default(), "Select All Matches", &SelectAllMatches, query_focus, @@ -324,7 +331,7 @@ impl Render for BufferSearchBar { el.child(render_action_button( "buffer-search", IconName::Close, - true, + Default::default(), "Close Search Bar", &Dismiss, focus_handle.clone(), @@ -352,7 +359,7 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-replace-button", IconName::ReplaceNext, - true, + Default::default(), "Replace Next Match", &ReplaceNext, focus_handle.clone(), @@ -360,7 +367,7 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-replace-button", IconName::ReplaceAll, - true, + Default::default(), "Replace All Matches", &ReplaceAll, focus_handle, @@ -394,7 +401,7 @@ impl Render for BufferSearchBar { div.child(h_flex().absolute().right_0().child(render_action_button( "buffer-search", IconName::Close, - true, + Default::default(), "Close Search Bar", &Dismiss, focus_handle.clone(), diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index b791f748ad..056c3556ba 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -1,9 +1,9 @@ use crate::{ BufferSearchBar, FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, - SearchOption, SearchOptions, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive, - ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord, + SearchOption, SearchOptions, SearchSource, SelectNextMatch, SelectPreviousMatch, + ToggleCaseSensitive, ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord, buffer_search::Deploy, - search_bar::{input_base_styles, render_action_button, render_text_input}, + search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input}, }; use anyhow::Context as _; use collections::HashMap; @@ -1665,7 +1665,7 @@ impl ProjectSearchBar { }); } - fn toggle_search_option( + pub(crate) fn toggle_search_option( &mut self, option: SearchOptions, window: &mut Window, @@ -1962,17 +1962,21 @@ impl Render for ProjectSearchBar { .child( h_flex() .gap_1() - .child( - SearchOption::CaseSensitive - .as_button(search.search_options, focus_handle.clone()), - ) - .child( - SearchOption::WholeWord - .as_button(search.search_options, focus_handle.clone()), - ) - .child( - SearchOption::Regex.as_button(search.search_options, focus_handle.clone()), - ), + .child(SearchOption::CaseSensitive.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )) + .child(SearchOption::WholeWord.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )) + .child(SearchOption::Regex.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )), ); let query_focus = search.query_editor.focus_handle(cx); @@ -1985,7 +1989,10 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-nav-button", IconName::ChevronLeft, - search.active_match_index.is_some(), + search + .active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Previous Match", &SelectPreviousMatch, query_focus.clone(), @@ -1993,7 +2000,10 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-nav-button", IconName::ChevronRight, - search.active_match_index.is_some(), + search + .active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Next Match", &SelectNextMatch, query_focus, @@ -2054,7 +2064,7 @@ impl Render for ProjectSearchBar { self.active_project_search .as_ref() .map(|search| search.read(cx).replace_enabled) - .unwrap_or_default(), + .and_then(|enabled| enabled.then_some(ActionButtonState::Toggled)), "Toggle Replace", &ToggleReplace, focus_handle.clone(), @@ -2079,7 +2089,7 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-replace-button", IconName::ReplaceNext, - true, + Default::default(), "Replace Next Match", &ReplaceNext, focus_handle.clone(), @@ -2087,7 +2097,7 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-replace-button", IconName::ReplaceAll, - true, + Default::default(), "Replace All Matches", &ReplaceAll, focus_handle, @@ -2129,10 +2139,11 @@ impl Render for ProjectSearchBar { this.toggle_opened_only(window, cx); })), ) - .child( - SearchOption::IncludeIgnored - .as_button(search.search_options, focus_handle.clone()), - ); + .child(SearchOption::IncludeIgnored.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )); h_flex() .w_full() .gap_2() diff --git a/crates/search/src/search.rs b/crates/search/src/search.rs index 89064e0a27..904c74d03c 100644 --- a/crates/search/src/search.rs +++ b/crates/search/src/search.rs @@ -1,7 +1,7 @@ use bitflags::bitflags; pub use buffer_search::BufferSearchBar; use editor::SearchSettings; -use gpui::{Action, App, FocusHandle, IntoElement, actions}; +use gpui::{Action, App, ClickEvent, FocusHandle, IntoElement, actions}; use project::search::SearchQuery; pub use project_search::ProjectSearchView; use ui::{ButtonStyle, IconButton, IconButtonShape}; @@ -11,6 +11,8 @@ use workspace::{Toast, Workspace}; pub use search_status_button::SEARCH_ICON; +use crate::project_search::ProjectSearchBar; + pub mod buffer_search; pub mod project_search; pub(crate) mod search_bar; @@ -83,9 +85,14 @@ pub enum SearchOption { Backwards, } +pub(crate) enum SearchSource<'a, 'b> { + Buffer, + Project(&'a Context<'b, ProjectSearchBar>), +} + impl SearchOption { - pub fn as_options(self) -> SearchOptions { - SearchOptions::from_bits(1 << self as u8).unwrap() + pub fn as_options(&self) -> SearchOptions { + SearchOptions::from_bits(1 << *self as u8).unwrap() } pub fn label(&self) -> &'static str { @@ -119,25 +126,41 @@ impl SearchOption { } } - pub fn as_button(&self, active: SearchOptions, focus_handle: FocusHandle) -> impl IntoElement { + pub(crate) fn as_button( + &self, + active: SearchOptions, + search_source: SearchSource, + focus_handle: FocusHandle, + ) -> impl IntoElement { let action = self.to_toggle_action(); let label = self.label(); - IconButton::new(label, self.icon()) - .on_click({ + IconButton::new( + (label, matches!(search_source, SearchSource::Buffer) as u32), + self.icon(), + ) + .map(|button| match search_source { + SearchSource::Buffer => { let focus_handle = focus_handle.clone(); - move |_, window, cx| { + button.on_click(move |_: &ClickEvent, window, cx| { if !focus_handle.is_focused(&window) { window.focus(&focus_handle); } - window.dispatch_action(action.boxed_clone(), cx) - } - }) - .style(ButtonStyle::Subtle) - .shape(IconButtonShape::Square) - .toggle_state(active.contains(self.as_options())) - .tooltip({ - move |window, cx| Tooltip::for_action_in(label, action, &focus_handle, window, cx) - }) + window.dispatch_action(action.boxed_clone(), cx); + }) + } + SearchSource::Project(cx) => { + let options = self.as_options(); + button.on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { + this.toggle_search_option(options, window, cx); + })) + } + }) + .style(ButtonStyle::Subtle) + .shape(IconButtonShape::Square) + .toggle_state(active.contains(self.as_options())) + .tooltip({ + move |window, cx| Tooltip::for_action_in(label, action, &focus_handle, window, cx) + }) } } diff --git a/crates/search/src/search_bar.rs b/crates/search/src/search_bar.rs index 094ce3638e..8cc838a8a6 100644 --- a/crates/search/src/search_bar.rs +++ b/crates/search/src/search_bar.rs @@ -5,10 +5,15 @@ use theme::ThemeSettings; use ui::{IconButton, IconButtonShape}; use ui::{Tooltip, prelude::*}; +pub(super) enum ActionButtonState { + Disabled, + Toggled, +} + pub(super) fn render_action_button( id_prefix: &'static str, icon: ui::IconName, - active: bool, + button_state: Option, tooltip: &'static str, action: &'static dyn Action, focus_handle: FocusHandle, @@ -28,7 +33,10 @@ pub(super) fn render_action_button( } }) .tooltip(move |window, cx| Tooltip::for_action_in(tooltip, action, &focus_handle, window, cx)) - .disabled(!active) + .when_some(button_state, |this, state| match state { + ActionButtonState::Toggled => this.toggle_state(true), + ActionButtonState::Disabled => this.disabled(true), + }) } pub(crate) fn input_base_styles(border_color: Hsla, map: impl FnOnce(Div) -> Div) -> Div { diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index c6fab10a68..5d7de81ba1 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -2177,6 +2177,7 @@ impl KeybindingEditorModal { let value = action_arguments .as_ref() + .filter(|args| !args.is_empty()) .map(|args| { serde_json::from_str(args).context("Failed to parse action arguments as JSON") }) diff --git a/crates/storybook/src/stories/indent_guides.rs b/crates/storybook/src/stories/indent_guides.rs index e4f9669b1f..db23ea79bd 100644 --- a/crates/storybook/src/stories/indent_guides.rs +++ b/crates/storybook/src/stories/indent_guides.rs @@ -65,7 +65,7 @@ impl Render for IndentGuidesStory { }, ) .with_compute_indents_fn( - cx.entity().clone(), + cx.entity(), |this, range, _cx, _context| { this.depths .iter() diff --git a/crates/telemetry_events/src/telemetry_events.rs b/crates/telemetry_events/src/telemetry_events.rs index 735a1310ae..12d8d4c04b 100644 --- a/crates/telemetry_events/src/telemetry_events.rs +++ b/crates/telemetry_events/src/telemetry_events.rs @@ -2,7 +2,7 @@ use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt::Display, sync::Arc, time::Duration}; +use std::{collections::HashMap, fmt::Display, time::Duration}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventRequestBody { @@ -93,19 +93,6 @@ impl Display for AssistantPhase { #[serde(tag = "type")] pub enum Event { Flexible(FlexibleEvent), - Editor(EditorEvent), - EditPrediction(EditPredictionEvent), - EditPredictionRating(EditPredictionRatingEvent), - Call(CallEvent), - Assistant(AssistantEventData), - Cpu(CpuEvent), - Memory(MemoryEvent), - App(AppEvent), - Setting(SettingEvent), - Extension(ExtensionEvent), - Edit(EditEvent), - Action(ActionEvent), - Repl(ReplEvent), } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -114,54 +101,12 @@ pub struct FlexibleEvent { pub event_properties: HashMap, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditorEvent { - /// The editor operation performed (open, save) - pub operation: String, - /// The extension of the file that was opened or saved - pub file_extension: Option, - /// Whether the user is in vim mode or not - pub vim_mode: bool, - /// Whether the user has copilot enabled or not - pub copilot_enabled: bool, - /// Whether the user has copilot enabled for the language of the file opened or saved - pub copilot_enabled_for_language: bool, - /// Whether the client is opening/saving a local file or a remote file via SSH - #[serde(default)] - pub is_via_ssh: bool, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditPredictionEvent { - /// Provider of the completion suggestion (e.g. copilot, supermaven) - pub provider: String, - pub suggestion_accepted: bool, - pub file_extension: Option, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum EditPredictionRating { Positive, Negative, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditPredictionRatingEvent { - pub rating: EditPredictionRating, - pub input_events: Arc, - pub input_excerpt: Arc, - pub output_excerpt: Arc, - pub feedback: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CallEvent { - /// Operation performed: invite/join call; begin/end screenshare; share/unshare project; etc - pub operation: String, - pub room_id: Option, - pub channel_id: Option, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AssistantEventData { /// Unique random identifier for each assistant tab (None for inline assist) @@ -180,57 +125,6 @@ pub struct AssistantEventData { pub language_name: Option, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CpuEvent { - pub usage_as_percentage: f32, - pub core_count: u32, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct MemoryEvent { - pub memory_in_bytes: u64, - pub virtual_memory_in_bytes: u64, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ActionEvent { - pub source: String, - pub action: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditEvent { - pub duration: i64, - pub environment: String, - /// Whether the edits occurred locally or remotely via SSH - #[serde(default)] - pub is_via_ssh: bool, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct SettingEvent { - pub setting: String, - pub value: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ExtensionEvent { - pub extension_id: Arc, - pub version: Arc, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct AppEvent { - pub operation: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ReplEvent { - pub kernel_language: String, - pub kernel_status: String, - pub repl_session_id: String, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct BacktraceFrame { pub ip: usize, diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index c9528c39b9..568dc1db2e 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -947,7 +947,7 @@ pub fn new_terminal_pane( cx: &mut Context, ) -> Entity { let is_local = project.read(cx).is_local(); - let terminal_panel = cx.entity().clone(); + let terminal_panel = cx.entity(); let pane = cx.new(|cx| { let mut pane = Pane::new( workspace.clone(), @@ -1009,7 +1009,7 @@ pub fn new_terminal_pane( return ControlFlow::Break(()); }; if let Some(tab) = dropped_item.downcast_ref::() { - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 7234c16d2c..ca0a79ccfd 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -1391,7 +1391,7 @@ impl Render for TerminalView { } let terminal_handle = self.terminal.clone(); - let terminal_view_handle = cx.entity().clone(); + let terminal_view_handle = cx.entity(); let focused = self.focus_handle.is_focused(window); diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 264fa4bf2f..ce5e5a0300 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -299,7 +299,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, action: &VimSave, window, cx| { vim.update_editor(cx, |_, editor, cx| { - let Some(project) = editor.project.clone() else { + let Some(project) = editor.project().cloned() else { return; }; let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else { @@ -436,7 +436,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { let Some(workspace) = vim.workspace(window) else { return; }; - let Some(project) = editor.project.clone() else { + let Some(project) = editor.project().cloned() else { return; }; let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else { diff --git a/crates/vim/src/mode_indicator.rs b/crates/vim/src/mode_indicator.rs index d54b270074..714b74f239 100644 --- a/crates/vim/src/mode_indicator.rs +++ b/crates/vim/src/mode_indicator.rs @@ -20,7 +20,7 @@ impl ModeIndicator { }) .detach(); - let handle = cx.entity().clone(); + let handle = cx.entity(); let window_handle = window.window_handle(); cx.observe_new::(move |_, window, cx| { let Some(window) = window else { @@ -29,7 +29,7 @@ impl ModeIndicator { if window.window_handle() != window_handle { return; } - let vim = cx.entity().clone(); + let vim = cx.entity(); handle.update(cx, |_, cx| { cx.subscribe(&vim, |mode_indicator, vim, event, cx| match event { VimEvent::Focused => { diff --git a/crates/vim/src/normal/search.rs b/crates/vim/src/normal/search.rs index e4e95ca48e..4054c552ae 100644 --- a/crates/vim/src/normal/search.rs +++ b/crates/vim/src/normal/search.rs @@ -332,7 +332,7 @@ impl Vim { Vim::take_forced_motion(cx); let prior_selections = self.editor_selections(window, cx); let cursor_word = self.editor_cursor_word(window, cx); - let vim = cx.entity().clone(); + let vim = cx.entity(); let searched = pane.update(cx, |pane, cx| { self.search.direction = direction; diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 51bf2dd131..44d9b8f456 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -402,7 +402,7 @@ impl Vim { const NAMESPACE: &'static str = "vim"; pub fn new(window: &mut Window, cx: &mut Context) -> Entity { - let editor = cx.entity().clone(); + let editor = cx.entity(); let mut initial_mode = VimSettings::get_global(cx).default_mode; if initial_mode == Mode::Normal && HelixModeSetting::get_global(cx).0 { diff --git a/crates/workspace/src/dock.rs b/crates/workspace/src/dock.rs index ca63d3e553..ae72df3971 100644 --- a/crates/workspace/src/dock.rs +++ b/crates/workspace/src/dock.rs @@ -253,7 +253,7 @@ impl Dock { cx: &mut Context, ) -> Entity { let focus_handle = cx.focus_handle(); - let workspace = cx.entity().clone(); + let workspace = cx.entity(); let dock = cx.new(|cx| { let focus_subscription = cx.on_focus(&focus_handle, window, |dock: &mut Dock, window, cx| { diff --git a/crates/workspace/src/notifications.rs b/crates/workspace/src/notifications.rs index 7d8a28b0f1..1356322a5c 100644 --- a/crates/workspace/src/notifications.rs +++ b/crates/workspace/src/notifications.rs @@ -346,7 +346,7 @@ impl Render for LanguageServerPrompt { ) .child(Label::new(request.message.to_string()).size(LabelSize::Small)) .children(request.actions.iter().enumerate().map(|(ix, action)| { - let this_handle = cx.entity().clone(); + let this_handle = cx.entity(); Button::new(ix, action.title.clone()) .size(ButtonSize::Large) .on_click(move |_, window, cx| { diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 759e91f758..860a57c21f 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -2198,7 +2198,7 @@ impl Pane { fn update_status_bar(&mut self, window: &mut Window, cx: &mut Context) { let workspace = self.workspace.clone(); - let pane = cx.entity().clone(); + let pane = cx.entity(); window.defer(cx, move |window, cx| { let Ok(status_bar) = @@ -2279,7 +2279,7 @@ impl Pane { cx: &mut Context, ) { maybe!({ - let pane = cx.entity().clone(); + let pane = cx.entity(); let destination_index = match operation { PinOperation::Pin => self.pinned_tab_count.min(ix), @@ -2473,7 +2473,7 @@ impl Pane { .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail, is_active, ix, @@ -2832,7 +2832,7 @@ impl Pane { let navigate_backward = IconButton::new("navigate_backward", IconName::ArrowLeft) .icon_size(IconSize::Small) .on_click({ - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, window, cx| { entity.update(cx, |pane, cx| pane.navigate_backward(window, cx)) } @@ -2848,7 +2848,7 @@ impl Pane { let navigate_forward = IconButton::new("navigate_forward", IconName::ArrowRight) .icon_size(IconSize::Small) .on_click({ - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, window, cx| entity.update(cx, |pane, cx| pane.navigate_forward(window, cx)) }) .disabled(!self.can_navigate_forward()) @@ -3054,7 +3054,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let split_direction = self.drag_split_direction; let item_id = dragged_tab.item.item_id(); if let Some(preview_item_id) = self.preview_item_id { @@ -3163,7 +3163,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let split_direction = self.drag_split_direction; let project_entry_id = *project_entry_id; self.workspace @@ -3239,7 +3239,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let mut split_direction = self.drag_split_direction; let paths = paths.paths().to_vec(); let is_remote = self diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index ade6838fad..1eaa125ba5 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -6338,7 +6338,7 @@ impl Render for Workspace { .border_b_1() .border_color(colors.border) .child({ - let this = cx.entity().clone(); + let this = cx.entity(); canvas( move |bounds, window, cx| { this.update(cx, |this, cx| { diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 84145a1be4..b06652b2ce 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -319,7 +319,7 @@ pub fn initialize_workspace( return; }; - let workspace_handle = cx.entity().clone(); + let workspace_handle = cx.entity(); let center_pane = workspace.active_pane().clone(); initialize_pane(workspace, ¢er_pane, window, cx); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index da4b6e78c6..5b0826413b 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -229,8 +229,7 @@ fn assign_edit_prediction_provider( if let Some(file) = buffer.read(cx).file() { let id = file.worktree_id(cx); if let Some(inner_worktree) = editor - .project - .as_ref() + .project() .and_then(|project| project.read(cx).worktree_for_id(id, cx)) { worktree = Some(inner_worktree);