diff --git a/Cargo.lock b/Cargo.lock index c1e3a936d4..4e4c18b9e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4021,7 +4021,7 @@ dependencies = [ "util", "uuid", "workspace", - "zed_predict_tos", + "zed_predict_onboarding", ] [[package]] @@ -6415,7 +6415,7 @@ dependencies = [ "ui", "workspace", "zed_actions", - "zed_predict_tos", + "zed_predict_onboarding", "zeta", ] @@ -13541,6 +13541,7 @@ dependencies = [ "windows 0.58.0", "workspace", "zed_actions", + "zed_predict_onboarding", ] [[package]] @@ -16557,7 +16558,7 @@ dependencies = [ "winresource", "workspace", "zed_actions", - "zed_predict_tos", + "zed_predict_onboarding", "zeta", ] @@ -16672,13 +16673,21 @@ dependencies = [ ] [[package]] -name = "zed_predict_tos" +name = "zed_predict_onboarding" version = "0.1.0" dependencies = [ + "chrono", "client", + "db", + "feature_flags", + "fs", "gpui", + "language", "menu", + "settings", + "theme", "ui", + "util", "workspace", ] @@ -16872,6 +16881,7 @@ dependencies = [ "collections", "command_palette_hooks", "ctor", + "db", "editor", "env_logger 0.11.6", "feature_flags", @@ -16886,6 +16896,7 @@ dependencies = [ "menu", "reqwest_client", "rpc", + "serde", "serde_json", "settings", "similar", diff --git a/Cargo.toml b/Cargo.toml index 569e358190..2bdbe2a567 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -152,7 +152,7 @@ members = [ "crates/worktree", "crates/zed", "crates/zed_actions", - "crates/zed_predict_tos", + "crates/zed_predict_onboarding", "crates/zeta", # @@ -201,7 +201,6 @@ edition = "2021" activity_indicator = { path = "crates/activity_indicator" } ai = { path = "crates/ai" } -zed_predict_tos = { path = "crates/zed_predict_tos" } anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } @@ -350,6 +349,7 @@ workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } +zed_predict_onboarding = { path = "crates/zed_predict_onboarding" } zeta = { path = "crates/zeta" } # diff --git a/assets/icons/zed_predict.svg b/assets/icons/zed_predict.svg index 75d64a9439..79fd8c8fc1 100644 --- a/assets/icons/zed_predict.svg +++ b/assets/icons/zed_predict.svg @@ -1,5 +1,5 @@ - + diff --git a/assets/icons/zed_predict_bg.svg b/assets/icons/zed_predict_bg.svg new file mode 100644 index 0000000000..de2a0d444c --- /dev/null +++ b/assets/icons/zed_predict_bg.svg @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index a5b062c22a..5957001a27 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -823,5 +823,12 @@ "shift-end": "terminal::ScrollToBottom", "ctrl-shift-space": "terminal::ToggleViMode" } + }, + { + "context": "ZedPredictModal", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel" + } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 693337f644..0ea34140ab 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -883,7 +883,7 @@ } }, { - "context": "ZedPredictTos", + "context": "ZedPredictModal", "use_key_equivalents": true, "bindings": { "escape": "menu::Cancel" diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 91af9a1a7f..a8d298e1c4 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -121,9 +121,7 @@ pub enum Event { }, ShowContacts, ParticipantIndicesChanged, - TermsStatusUpdated { - accepted: bool, - }, + PrivateUserInfoUpdated, } #[derive(Clone, Copy)] @@ -227,9 +225,7 @@ impl UserStore { }; this.set_current_user_accepted_tos_at(accepted_tos_at); - cx.emit(Event::TermsStatusUpdated { - accepted: accepted_tos_at.is_some(), - }); + cx.emit(Event::PrivateUserInfoUpdated); }) } else { anyhow::Ok(()) @@ -244,6 +240,8 @@ impl UserStore { Status::SignedOut => { current_user_tx.send(None).await.ok(); this.update(&mut cx, |this, cx| { + this.accepted_tos_at = None; + cx.emit(Event::PrivateUserInfoUpdated); cx.notify(); this.clear_contacts() })? @@ -714,7 +712,7 @@ impl UserStore { this.update(&mut cx, |this, cx| { this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); - cx.emit(Event::TermsStatusUpdated { accepted: true }); + cx.emit(Event::PrivateUserInfoUpdated); }) } else { Err(anyhow!("client not found")) diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 3234092637..b5ee3713fc 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -447,7 +447,7 @@ async fn predict_edits( )); } - let sample_input_output = claims.is_staff && rand::random::() < 0.1; + let should_sample = claims.is_staff || params.can_collect_data; let api_url = state .config @@ -541,7 +541,7 @@ async fn predict_edits( let output = choice.text.clone(); async move { - let properties = if sample_input_output { + let properties = if should_sample { json!({ "model": model.to_string(), "headers": response.headers, diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 95afdc4c1a..9a35c17ad1 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -88,7 +88,7 @@ url.workspace = true util.workspace = true uuid.workspace = true workspace.workspace = true -zed_predict_tos.workspace = true +zed_predict_onboarding.workspace = true [dev-dependencies] ctor.workspace = true diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index 7238fc65fe..c8cc4872d8 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -652,7 +652,7 @@ impl CompletionsMenu { ) .on_click(cx.listener(move |editor, _event, window, cx| { cx.stop_propagation(); - editor.toggle_zed_predict_tos(window, cx); + editor.toggle_zed_predict_onboarding(window, cx); })), ), diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index b3df175f14..3972925435 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -69,7 +69,7 @@ pub use element::{ }; use futures::{future, FutureExt}; use fuzzy::StringMatchCandidate; -use zed_predict_tos::ZedPredictTos; +use zed_predict_onboarding::ZedPredictModal; use code_context_menus::{ AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu, @@ -3948,12 +3948,21 @@ impl Editor { self.do_completion(action.item_ix, CompletionIntent::Compose, window, cx) } - fn toggle_zed_predict_tos(&mut self, window: &mut Window, cx: &mut Context) { + fn toggle_zed_predict_onboarding(&mut self, window: &mut Window, cx: &mut Context) { let (Some(workspace), Some(project)) = (self.workspace(), self.project.as_ref()) else { return; }; - ZedPredictTos::toggle(workspace, project.read(cx).user_store().clone(), window, cx); + let project = project.read(cx); + + ZedPredictModal::toggle( + workspace, + project.user_store().clone(), + project.client().clone(), + project.fs().clone(), + window, + cx, + ); } fn do_completion( @@ -3985,7 +3994,7 @@ impl Editor { )) => { drop(entries); drop(context_menu); - self.toggle_zed_predict_tos(window, cx); + self.toggle_zed_predict_onboarding(window, cx); return Some(Task::ready(Ok(()))); } _ => {} diff --git a/crates/editor/src/persistence.rs b/crates/editor/src/persistence.rs index 06e2ea1f9b..4895512417 100644 --- a/crates/editor/src/persistence.rs +++ b/crates/editor/src/persistence.rs @@ -87,8 +87,8 @@ define_connection!( // mtime_seconds: Option, // mtime_nanos: Option, // ) - pub static ref DB: EditorDb = - &[sql! ( + pub static ref DB: EditorDb = &[ + sql! ( CREATE TABLE editors( item_id INTEGER NOT NULL, workspace_id INTEGER NOT NULL, @@ -134,7 +134,7 @@ define_connection!( ALTER TABLE editors ADD COLUMN mtime_seconds INTEGER DEFAULT NULL; ALTER TABLE editors ADD COLUMN mtime_nanos INTEGER DEFAULT NULL; ), - ]; + ]; ); impl EditorDb { diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/inline_completion/src/inline_completion.rs index 5b99dcbb79..089600ef2c 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/inline_completion/src/inline_completion.rs @@ -18,6 +18,31 @@ pub struct InlineCompletion { pub edit_preview: Option, } +pub enum DataCollectionState { + /// The provider doesn't support data collection. + Unsupported, + /// When there's a file not saved yet. In this case, we can't tell to which project it belongs. + Unknown, + /// Data collection is enabled + Enabled, + /// Data collection is disabled or unanswered. + Disabled, +} + +impl DataCollectionState { + pub fn is_supported(&self) -> bool { + !matches!(self, DataCollectionState::Unsupported) + } + + pub fn is_unknown(&self) -> bool { + matches!(self, DataCollectionState::Unknown) + } + + pub fn is_enabled(&self) -> bool { + matches!(self, DataCollectionState::Enabled) + } +} + pub trait InlineCompletionProvider: 'static + Sized { fn name() -> &'static str; fn display_name() -> &'static str; @@ -26,6 +51,10 @@ pub trait InlineCompletionProvider: 'static + Sized { fn show_tab_accept_marker() -> bool { false } + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { + DataCollectionState::Unsupported + } + fn toggle_data_collection(&mut self, _cx: &mut App) {} fn is_enabled( &self, buffer: &Entity, @@ -72,6 +101,8 @@ pub trait InlineCompletionProviderHandle { fn show_completions_in_menu(&self) -> bool; fn show_completions_in_normal_mode(&self) -> bool; fn show_tab_accept_marker(&self) -> bool; + fn data_collection_state(&self, cx: &App) -> DataCollectionState; + fn toggle_data_collection(&self, cx: &mut App); fn needs_terms_acceptance(&self, cx: &App) -> bool; fn is_refreshing(&self, cx: &App) -> bool; fn refresh( @@ -122,6 +153,14 @@ where T::show_tab_accept_marker() } + fn data_collection_state(&self, cx: &App) -> DataCollectionState { + self.read(cx).data_collection_state(cx) + } + + fn toggle_data_collection(&self, cx: &mut App) { + self.update(cx, |this, cx| this.toggle_data_collection(cx)) + } + fn is_enabled( &self, buffer: &Entity, diff --git a/crates/inline_completion_button/Cargo.toml b/crates/inline_completion_button/Cargo.toml index 27aaf093a6..627b7b791c 100644 --- a/crates/inline_completion_button/Cargo.toml +++ b/crates/inline_completion_button/Cargo.toml @@ -29,7 +29,7 @@ workspace.workspace = true zed_actions.workspace = true zeta.workspace = true client.workspace = true -zed_predict_tos.workspace = true +zed_predict_onboarding.workspace = true [dev-dependencies] copilot = { workspace = true, features = ["test-support"] } diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs index 84c77e79d6..82e2239345 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use client::UserStore; +use client::{Client, UserStore}; use copilot::{Copilot, Status}; use editor::{actions::ShowInlineCompletion, scroll::Autoscroll, Editor}; use feature_flags::{ @@ -20,18 +20,16 @@ use language::{ use settings::{update_settings_file, Settings, SettingsStore}; use std::{path::Path, sync::Arc, time::Duration}; use supermaven::{AccountStatus, Supermaven}; -use ui::{prelude::*, ButtonLike, Color, Icon, IconWithIndicator, Indicator, PopoverMenuHandle}; +use ui::{ + prelude::*, ButtonLike, Clickable, ContextMenu, ContextMenuEntry, IconButton, + IconWithIndicator, Indicator, PopoverMenu, PopoverMenuHandle, Tooltip, +}; use workspace::{ - create_and_open_local_file, - item::ItemHandle, - notifications::NotificationId, - ui::{ - ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, PopoverMenu, Tooltip, - }, - StatusItemView, Toast, Workspace, + create_and_open_local_file, item::ItemHandle, notifications::NotificationId, StatusItemView, + Toast, Workspace, }; use zed_actions::OpenBrowser; -use zed_predict_tos::ZedPredictTos; +use zed_predict_onboarding::ZedPredictModal; use zeta::RateCompletionModal; actions!(zeta, [RateCompletions]); @@ -48,6 +46,7 @@ pub struct InlineCompletionButton { language: Option>, file: Option>, inline_completion_provider: Option>, + client: Arc, fs: Arc, workspace: WeakEntity, user_store: Entity, @@ -231,14 +230,16 @@ impl Render for InlineCompletionButton { return div(); } - if !self - .user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false) - { + let current_user_terms_accepted = + self.user_store.read(cx).current_user_has_accepted_terms(); + + if !current_user_terms_accepted.unwrap_or(false) { let workspace = self.workspace.clone(); let user_store = self.user_store.clone(); + let client = self.client.clone(); + let fs = self.fs.clone(); + + let signed_in = current_user_terms_accepted.is_some(); return div().child( ButtonLike::new("zeta-pending-tos-icon") @@ -252,20 +253,29 @@ impl Render for InlineCompletionButton { )) .into_any_element(), ) - .tooltip(|window, cx| { + .tooltip(move |window, cx| { Tooltip::with_meta( "Edit Predictions", None, - "Read Terms of Service", + if signed_in { + "Read Terms of Service" + } else { + "Sign in to use" + }, window, cx, ) }) .on_click(cx.listener(move |_, _, window, cx| { - let user_store = user_store.clone(); - if let Some(workspace) = workspace.upgrade() { - ZedPredictTos::toggle(workspace, user_store, window, cx); + ZedPredictModal::toggle( + workspace, + user_store.clone(), + client.clone(), + fs.clone(), + window, + cx, + ); } })), ); @@ -318,6 +328,7 @@ impl InlineCompletionButton { workspace: WeakEntity, fs: Arc, user_store: Entity, + client: Arc, popover_menu_handle: PopoverMenuHandle, cx: &mut Context, ) -> Self { @@ -337,6 +348,7 @@ impl InlineCompletionButton { inline_completion_provider: None, popover_menu_handle, workspace, + client, fs, user_store, } @@ -430,6 +442,22 @@ impl InlineCompletionButton { move |_, cx| toggle_inline_completions_globally(fs.clone(), cx), ); + if let Some(provider) = &self.inline_completion_provider { + let data_collection = provider.data_collection_state(cx); + + if data_collection.is_supported() { + let provider = provider.clone(); + menu = menu.separator().item( + ContextMenuEntry::new("Data Collection") + .toggleable(IconPosition::Start, data_collection.is_enabled()) + .disabled(data_collection.is_unknown()) + .handler(move |_, cx| { + provider.toggle_data_collection(cx); + }), + ); + } + } + if let Some(editor_focus_handle) = self.editor_focus_handle.clone() { menu = menu .separator() diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index df48d9ed92..c1612662bf 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -39,6 +39,9 @@ pub struct PredictEditsParams { pub outline: Option, pub input_events: String, pub input_excerpt: String, + /// Whether the user provided consent for sampling this interaction. + #[serde(default)] + pub can_collect_data: bool, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index a07c2e4b64..42b05d28b3 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -48,6 +48,7 @@ telemetry.workspace = true workspace.workspace = true zed_actions.workspace = true git_ui.workspace = true +zed_predict_onboarding.workspace = true [target.'cfg(windows)'.dependencies] windows.workspace = true diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index d18e783995..7ed4d4609a 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -37,6 +37,7 @@ use ui::{ use util::ResultExt; use workspace::{notifications::NotifyResultExt, Workspace}; use zed_actions::{OpenBrowser, OpenRecent, OpenRemote}; +use zed_predict_onboarding::ZedPredictBanner; #[cfg(feature = "stories")] pub use stories::*; @@ -113,6 +114,7 @@ pub struct TitleBar { application_menu: Option>, _subscriptions: Vec, git_ui_enabled: Arc, + zed_predict_banner: Entity, } impl Render for TitleBar { @@ -196,6 +198,7 @@ impl Render for TitleBar { .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()), ) .child(self.render_collaborator_list(window, cx)) + .child(self.zed_predict_banner.clone()) .child( h_flex() .gap_1() @@ -271,6 +274,7 @@ impl TitleBar { let project = workspace.project().clone(); let user_store = workspace.app_state().user_store.clone(); let client = workspace.app_state().client.clone(); + let fs = workspace.app_state().fs.clone(); let active_call = ActiveCall::global(cx); let platform_style = PlatformStyle::platform(); @@ -306,6 +310,16 @@ impl TitleBar { } })); + let zed_predict_banner = cx.new(|cx| { + ZedPredictBanner::new( + workspace.weak_handle(), + user_store.clone(), + client.clone(), + fs.clone(), + cx, + ) + }); + Self { platform_style, content: div().id(id.into()), @@ -319,6 +333,7 @@ impl TitleBar { client, _subscriptions: subscriptions, git_ui_enabled: is_git_ui_enabled, + zed_predict_banner, } } diff --git a/crates/ui/src/components/context_menu.rs b/crates/ui/src/components/context_menu.rs index b4e01a05d2..765c216ccd 100644 --- a/crates/ui/src/components/context_menu.rs +++ b/crates/ui/src/components/context_menu.rs @@ -64,6 +64,11 @@ impl ContextMenuEntry { } } + pub fn toggleable(mut self, toggle_position: IconPosition, toggled: bool) -> Self { + self.toggle = Some((toggle_position, toggled)); + self + } + pub fn icon(mut self, icon: IconName) -> Self { self.icon = Some(icon); self diff --git a/crates/workspace/src/notifications.rs b/crates/workspace/src/notifications.rs index e805527b91..145c0b7aec 100644 --- a/crates/workspace/src/notifications.rs +++ b/crates/workspace/src/notifications.rs @@ -379,6 +379,12 @@ pub mod simple_message_notification { click_message: Option, secondary_click_message: Option, secondary_on_click: Option)>>, + tertiary_click_message: Option, + tertiary_on_click: Option)>>, + more_info_message: Option, + more_info_url: Option>, + show_close_button: bool, + title: Option, } impl EventEmitter for MessageNotification {} @@ -402,6 +408,12 @@ pub mod simple_message_notification { click_message: None, secondary_on_click: None, secondary_click_message: None, + tertiary_on_click: None, + tertiary_click_message: None, + more_info_message: None, + more_info_url: None, + show_close_button: true, + title: None, } } @@ -437,31 +449,85 @@ pub mod simple_message_notification { self } + pub fn with_tertiary_click_message(mut self, message: S) -> Self + where + S: Into, + { + self.tertiary_click_message = Some(message.into()); + self + } + + pub fn on_tertiary_click(mut self, on_click: F) -> Self + where + F: 'static + Fn(&mut Window, &mut Context), + { + self.tertiary_on_click = Some(Arc::new(on_click)); + self + } + + pub fn more_info_message(mut self, message: S) -> Self + where + S: Into, + { + self.more_info_message = Some(message.into()); + self + } + + pub fn more_info_url(mut self, url: S) -> Self + where + S: Into>, + { + self.more_info_url = Some(url.into()); + self + } + pub fn dismiss(&mut self, cx: &mut Context) { cx.emit(DismissEvent); } + + pub fn show_close_button(mut self, show: bool) -> Self { + self.show_close_button = show; + self + } + + pub fn with_title(mut self, title: S) -> Self + where + S: Into, + { + self.title = Some(title.into()); + self + } } impl Render for MessageNotification { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() .p_3() - .gap_2() + .gap_3() .elevation_3(cx) .child( h_flex() .gap_4() .justify_between() .items_start() - .child(div().max_w_96().child((self.build_content)(window, cx))) .child( - IconButton::new("close", IconName::Close) - .on_click(cx.listener(|this, _, _, cx| this.dismiss(cx))), - ), + v_flex() + .gap_0p5() + .when_some(self.title.clone(), |element, title| { + element.child(Label::new(title)) + }) + .child(div().max_w_96().child((self.build_content)(window, cx))), + ) + .when(self.show_close_button, |this| { + this.child( + IconButton::new("close", IconName::Close) + .on_click(cx.listener(|this, _, _, cx| this.dismiss(cx))), + ) + }), ) .child( h_flex() - .gap_2() + .gap_1() .children(self.click_message.iter().map(|message| { Button::new(message.clone(), message.clone()) .label_size(LabelSize::Small) @@ -489,7 +555,40 @@ pub mod simple_message_notification { }; this.dismiss(cx) })) - })), + })) + .child( + h_flex() + .w_full() + .gap_1() + .justify_end() + .children(self.tertiary_click_message.iter().map(|message| { + Button::new(message.clone(), message.clone()) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + if let Some(on_click) = this.tertiary_on_click.as_ref() + { + (on_click)(window, cx) + }; + this.dismiss(cx) + })) + })) + .children( + self.more_info_message + .iter() + .zip(self.more_info_url.iter()) + .map(|(message, url)| { + let url = url.clone(); + Button::new(message.clone(), message.clone()) + .label_size(LabelSize::Small) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Indicator) + .icon_color(Color::Muted) + .on_click(cx.listener(move |_, _, _, cx| { + cx.open_url(&url); + })) + }), + ), + ), ) } } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index bedadc41bd..89bcd53d6d 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -58,7 +58,9 @@ use persistence::{ SerializedWindowBounds, DB, }; use postage::stream::Stream; -use project::{DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree}; +use project::{ + DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId, +}; use remote::{ssh_session::ConnectionIdentifier, SshClientDelegate, SshConnectionOptions}; use schemars::JsonSchema; use serde::Deserialize; @@ -2200,6 +2202,18 @@ impl Workspace { } } + pub fn absolute_path_of_worktree( + &self, + worktree_id: WorktreeId, + cx: &mut Context, + ) -> Option { + self.project + .read(cx) + .worktree_for_id(worktree_id, cx) + // TODO: use `abs_path` or `root_dir` + .map(|wt| wt.read(cx).abs_path().as_ref().to_path_buf()) + } + fn add_folder_to_project( &mut self, _: &AddFolderToProject, diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 64d0b0a90a..2c3d75fd82 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -2751,6 +2751,8 @@ impl Snapshot { self.entry_for_path("") } + /// TODO: what's the difference between `root_dir` and `abs_path`? + /// is there any? if so, document it. pub fn root_dir(&self) -> Option> { self.root_entry() .filter(|entry| entry.is_dir()) diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 7e53c12ae1..a141c054f0 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -16,7 +16,7 @@ path = "src/main.rs" [dependencies] activity_indicator.workspace = true -zed_predict_tos.workspace = true +zed_predict_onboarding.workspace = true anyhow.workspace = true assets.workspace = true assistant.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 35103fbba5..881d256f5b 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -439,6 +439,7 @@ fn main() { inline_completion_registry::init( app_state.client.clone(), app_state.user_store.clone(), + app_state.fs.clone(), cx, ); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 422695aa4d..1cfc9460d4 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -176,6 +176,7 @@ pub fn initialize_workspace( workspace.weak_handle(), app_state.fs.clone(), app_state.user_store.clone(), + app_state.client.clone(), popover_menu_handle.clone(), cx, ) diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index 457cc4ae28..98d710e978 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -5,13 +5,17 @@ use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::{Editor, EditorMode}; use feature_flags::{FeatureFlagAppExt, PredictEditsFeatureFlag}; -use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity, Window}; +use fs::Fs; +use gpui::{AnyWindowHandle, App, AppContext, Context, Entity, WeakEntity}; use language::language_settings::{all_language_settings, InlineCompletionProvider}; use settings::SettingsStore; use supermaven::{Supermaven, SupermavenCompletionProvider}; -use zed_predict_tos::ZedPredictTos; +use ui::Window; +use workspace::Workspace; +use zed_predict_onboarding::ZedPredictModal; +use zeta::ProviderDataCollection; -pub fn init(client: Arc, user_store: Entity, cx: &mut App) { +pub fn init(client: Arc, user_store: Entity, fs: Arc, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); cx.observe_new({ let editors = editors.clone(); @@ -37,6 +41,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { } }) .detach(); + editors .borrow_mut() .insert(editor_handle, window.window_handle()); @@ -91,6 +96,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let editors = editors.clone(); let client = client.clone(); let user_store = user_store.clone(); + let fs = fs.clone(); move |cx| { let new_provider = all_language_settings(None, cx).inline_completions.provider; if new_provider != provider { @@ -123,9 +129,11 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { window .update(cx, |_, window, cx| { - ZedPredictTos::toggle( + ZedPredictModal::toggle( workspace, user_store.clone(), + client.clone(), + fs.clone(), window, cx, ); @@ -214,17 +222,19 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context, user_store: Entity, window: &mut Window, cx: &mut Context, ) { + let singleton_buffer = editor.buffer().read(cx).as_singleton(); + match provider { - language::language_settings::InlineCompletionProvider::None => {} - language::language_settings::InlineCompletionProvider::Copilot => { + InlineCompletionProvider::None => {} + InlineCompletionProvider::Copilot => { if let Some(copilot) = Copilot::global(cx) { - if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + if let Some(buffer) = singleton_buffer { if buffer.read(cx).file().is_some() { copilot.update(cx, |copilot, cx| { copilot.register_buffer(&buffer, cx); @@ -235,26 +245,35 @@ fn assign_inline_completion_provider( editor.set_inline_completion_provider(Some(provider), window, cx); } } - language::language_settings::InlineCompletionProvider::Supermaven => { + InlineCompletionProvider::Supermaven => { if let Some(supermaven) = Supermaven::global(cx) { let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven)); editor.set_inline_completion_provider(Some(provider), window, cx); } } - - language::language_settings::InlineCompletionProvider::Zed => { + InlineCompletionProvider::Zed => { if cx.has_flag::() || (cfg!(debug_assertions) && client.status().borrow().is_connected()) { let zeta = zeta::Zeta::register(client.clone(), user_store, cx); - if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + if let Some(buffer) = &singleton_buffer { if buffer.read(cx).file().is_some() { zeta.update(cx, |zeta, cx| { zeta.register_buffer(&buffer, cx); }); } } - let provider = cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta)); + + let data_collection = ProviderDataCollection::new( + zeta.clone(), + window.root::().flatten(), + singleton_buffer, + cx, + ); + + let provider = + cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta, data_collection)); + editor.set_inline_completion_provider(Some(provider), window, cx); } } diff --git a/crates/zed_predict_tos/Cargo.toml b/crates/zed_predict_onboarding/Cargo.toml similarity index 55% rename from crates/zed_predict_tos/Cargo.toml rename to crates/zed_predict_onboarding/Cargo.toml index 657cf4b1b0..c444f321e9 100644 --- a/crates/zed_predict_tos/Cargo.toml +++ b/crates/zed_predict_onboarding/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "zed_predict_tos" +name = "zed_predict_onboarding" version = "0.1.0" edition = "2021" publish = false @@ -9,15 +9,23 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/zed_predict_tos.rs" +path = "src/lib.rs" doctest = false [features] test-support = [] [dependencies] +chrono.workspace = true client.workspace = true +db.workspace = true +feature_flags.workspace = true +fs.workspace = true gpui.workspace = true -ui.workspace = true -workspace.workspace = true +language.workspace = true menu.workspace = true +settings.workspace = true +theme.workspace = true +ui.workspace = true +util.workspace = true +workspace.workspace = true diff --git a/crates/zed_predict_tos/LICENSE-GPL b/crates/zed_predict_onboarding/LICENSE-GPL similarity index 100% rename from crates/zed_predict_tos/LICENSE-GPL rename to crates/zed_predict_onboarding/LICENSE-GPL diff --git a/crates/zed_predict_onboarding/src/banner.rs b/crates/zed_predict_onboarding/src/banner.rs new file mode 100644 index 0000000000..76ca956be7 --- /dev/null +++ b/crates/zed_predict_onboarding/src/banner.rs @@ -0,0 +1,168 @@ +use std::sync::Arc; + +use crate::ZedPredictModal; +use chrono::Utc; +use client::{Client, UserStore}; +use feature_flags::{FeatureFlagAppExt as _, PredictEditsFeatureFlag}; +use fs::Fs; +use gpui::{Entity, Subscription, WeakEntity}; +use language::language_settings::{all_language_settings, InlineCompletionProvider}; +use settings::SettingsStore; +use ui::{prelude::*, ButtonLike, Tooltip}; +use util::ResultExt; +use workspace::Workspace; + +/// Prompts user to try AI inline prediction feature +pub struct ZedPredictBanner { + workspace: WeakEntity, + user_store: Entity, + client: Arc, + fs: Arc, + dismissed: bool, + _subscription: Subscription, +} + +impl ZedPredictBanner { + pub fn new( + workspace: WeakEntity, + user_store: Entity, + client: Arc, + fs: Arc, + cx: &mut Context, + ) -> Self { + Self { + workspace, + user_store, + client, + fs, + dismissed: get_dismissed(), + _subscription: cx.observe_global::(Self::handle_settings_changed), + } + } + + fn should_show(&self, cx: &mut App) -> bool { + if !cx.has_flag::() || self.dismissed { + return false; + } + + let provider = all_language_settings(None, cx).inline_completions.provider; + + match provider { + InlineCompletionProvider::None + | InlineCompletionProvider::Copilot + | InlineCompletionProvider::Supermaven => true, + InlineCompletionProvider::Zed => false, + } + } + + fn handle_settings_changed(&mut self, cx: &mut Context) { + if self.dismissed { + return; + } + + let provider = all_language_settings(None, cx).inline_completions.provider; + + match provider { + InlineCompletionProvider::None + | InlineCompletionProvider::Copilot + | InlineCompletionProvider::Supermaven => {} + InlineCompletionProvider::Zed => { + self.dismiss(cx); + } + } + } + + fn dismiss(&mut self, cx: &mut Context) { + persist_dismissed(cx); + self.dismissed = true; + cx.notify(); + } +} + +const DISMISSED_AT_KEY: &str = "zed_predict_banner_dismissed_at"; + +pub(crate) fn get_dismissed() -> bool { + db::kvp::KEY_VALUE_STORE + .read_kvp(DISMISSED_AT_KEY) + .log_err() + .map_or(false, |dismissed| dismissed.is_some()) +} + +pub(crate) fn persist_dismissed(cx: &mut App) { + cx.spawn(|_| { + let time = Utc::now().to_rfc3339(); + db::kvp::KEY_VALUE_STORE.write_kvp(DISMISSED_AT_KEY.into(), time) + }) + .detach_and_log_err(cx); +} + +impl Render for ZedPredictBanner { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if !self.should_show(cx) { + return div(); + } + + let border_color = cx.theme().colors().editor_foreground.opacity(0.3); + let banner = h_flex() + .rounded_md() + .border_1() + .border_color(border_color) + .child( + ButtonLike::new("try-zed-predict") + .child( + h_flex() + .h_full() + .items_center() + .gap_1p5() + .child(Icon::new(IconName::ZedPredict).size(IconSize::Small)) + .child( + h_flex() + .gap_0p5() + .child( + Label::new("Introducing:") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Label::new("Edit Prediction").size(LabelSize::Small)), + ), + ) + .on_click({ + let workspace = self.workspace.clone(); + let user_store = self.user_store.clone(); + let client = self.client.clone(); + let fs = self.fs.clone(); + move |_, window, cx| { + let Some(workspace) = workspace.upgrade() else { + return; + }; + ZedPredictModal::toggle( + workspace, + user_store.clone(), + client.clone(), + fs.clone(), + window, + cx, + ); + } + }), + ) + .child( + div().border_l_1().border_color(border_color).child( + IconButton::new("close", IconName::Close) + .icon_size(IconSize::Indicator) + .on_click(cx.listener(|this, _, _window, cx| this.dismiss(cx))) + .tooltip(|window, cx| { + Tooltip::with_meta( + "Close Announcement Banner", + None, + "It won't show again for this feature", + window, + cx, + ) + }), + ), + ); + + div().pr_1().child(banner) + } +} diff --git a/crates/zed_predict_onboarding/src/lib.rs b/crates/zed_predict_onboarding/src/lib.rs new file mode 100644 index 0000000000..75c77843cd --- /dev/null +++ b/crates/zed_predict_onboarding/src/lib.rs @@ -0,0 +1,5 @@ +mod banner; +mod modal; + +pub use banner::ZedPredictBanner; +pub use modal::ZedPredictModal; diff --git a/crates/zed_predict_onboarding/src/modal.rs b/crates/zed_predict_onboarding/src/modal.rs new file mode 100644 index 0000000000..e353a40aec --- /dev/null +++ b/crates/zed_predict_onboarding/src/modal.rs @@ -0,0 +1,313 @@ +use std::{sync::Arc, time::Duration}; + +use client::{Client, UserStore}; +use feature_flags::FeatureFlagAppExt as _; +use fs::Fs; +use gpui::{ + ease_in_out, svg, Animation, AnimationExt as _, ClickEvent, DismissEvent, Entity, EventEmitter, + FocusHandle, Focusable, MouseDownEvent, Render, +}; +use language::language_settings::{AllLanguageSettings, InlineCompletionProvider}; +use settings::{update_settings_file, Settings}; +use ui::{prelude::*, CheckboxWithLabel, TintColor}; +use workspace::{notifications::NotifyTaskExt, ModalView, Workspace}; + +/// Introduces user to AI inline prediction feature and terms of service +pub struct ZedPredictModal { + user_store: Entity, + client: Arc, + fs: Arc, + focus_handle: FocusHandle, + sign_in_status: SignInStatus, + terms_of_service: bool, +} + +#[derive(PartialEq, Eq)] +enum SignInStatus { + /// Signed out or signed in but not from this modal + Idle, + /// Authentication triggered from this modal + Waiting, + /// Signed in after authentication from this modal + SignedIn, +} + +impl ZedPredictModal { + fn new( + user_store: Entity, + client: Arc, + fs: Arc, + cx: &mut Context, + ) -> Self { + ZedPredictModal { + user_store, + client, + fs, + focus_handle: cx.focus_handle(), + sign_in_status: SignInStatus::Idle, + terms_of_service: false, + } + } + + pub fn toggle( + workspace: Entity, + user_store: Entity, + client: Arc, + fs: Arc, + window: &mut Window, + cx: &mut App, + ) { + workspace.update(cx, |this, cx| { + this.toggle_modal(window, cx, |_window, cx| { + ZedPredictModal::new(user_store, client, fs, cx) + }); + }); + } + + fn view_terms(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context) { + cx.open_url("https://zed.dev/terms-of-service"); + cx.notify(); + } + + fn view_blog(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context) { + cx.open_url("https://zed.dev/blog/"); // TODO Add the link when live + cx.notify(); + } + + fn accept_and_enable(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let task = self + .user_store + .update(cx, |this, cx| this.accept_terms_of_service(cx)); + + cx.spawn(|this, mut cx| async move { + task.await?; + + this.update(&mut cx, |this, cx| { + update_settings_file::(this.fs.clone(), cx, move |file, _| { + file.features + .get_or_insert(Default::default()) + .inline_completion_provider = Some(InlineCompletionProvider::Zed); + }); + + cx.emit(DismissEvent); + }) + }) + .detach_and_notify_err(window, cx); + } + + fn sign_in(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let client = self.client.clone(); + self.sign_in_status = SignInStatus::Waiting; + + cx.spawn(move |this, mut cx| async move { + let result = client.authenticate_and_connect(true, &cx).await; + + let status = match result { + Ok(_) => SignInStatus::SignedIn, + Err(_) => SignInStatus::Idle, + }; + + this.update(&mut cx, |this, cx| { + this.sign_in_status = status; + cx.notify() + })?; + + result + }) + .detach_and_notify_err(window, cx); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } +} + +impl EventEmitter for ZedPredictModal {} + +impl Focusable for ZedPredictModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl ModalView for ZedPredictModal {} + +impl Render for ZedPredictModal { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let base = v_flex() + .w(px(420.)) + .p_4() + .relative() + .gap_2() + .overflow_hidden() + .elevation_3(cx) + .id("zed predict tos") + .track_focus(&self.focus_handle(cx)) + .on_action(cx.listener(Self::cancel)) + .key_context("ZedPredictModal") + .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| { + cx.emit(DismissEvent); + })) + .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| { + this.focus_handle.focus(window); + })) + .child( + div() + .p_1p5() + .absolute() + .top_0() + .left_0() + .right_0() + .h(px(200.)) + .child( + svg() + .path("icons/zed_predict_bg.svg") + .text_color(cx.theme().colors().icon_disabled) + .w(px(416.)) + .h(px(128.)) + .overflow_hidden(), + ), + ) + .child( + h_flex() + .w_full() + .mb_2() + .justify_between() + .child( + v_flex() + .gap_1() + .child( + Label::new("Introducing Zed AI's") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Headline::new("Edit Prediction").size(HeadlineSize::Large)), + ) + .child({ + let tab = |n: usize| { + let text_color = cx.theme().colors().text; + let border_color = cx.theme().colors().text_accent.opacity(0.4); + + h_flex().child( + h_flex() + .px_4() + .py_0p5() + .bg(cx.theme().colors().editor_background) + .border_1() + .border_color(border_color) + .rounded_md() + .font(theme::ThemeSettings::get_global(cx).buffer_font.clone()) + .text_size(TextSize::XSmall.rems(cx)) + .text_color(text_color) + .child("tab") + .with_animation( + ElementId::Integer(n), + Animation::new(Duration::from_secs(2)).repeat(), + move |tab, delta| { + let delta = (delta - 0.15 * n as f32) / 0.7; + let delta = 1.0 - (0.5 - delta).abs() * 2.; + let delta = ease_in_out(delta.clamp(0., 1.)); + let delta = 0.1 + 0.9 * delta; + + tab.border_color(border_color.opacity(delta)) + .text_color(text_color.opacity(delta)) + }, + ), + ) + }; + + v_flex() + .gap_2() + .items_center() + .pr_4() + .child(tab(0).ml_neg_20()) + .child(tab(1)) + .child(tab(2).ml_20()) + }), + ) + .child(h_flex().absolute().top_2().right_2().child( + IconButton::new("cancel", IconName::X).on_click(cx.listener( + |_, _: &ClickEvent, _window, cx| { + cx.emit(DismissEvent); + }, + )), + )); + + let blog_post_button = if cx.is_staff() { + Some( + Button::new("view-blog", "Read the Blog Post") + .full_width() + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Indicator) + .icon_color(Color::Muted) + .on_click(cx.listener(Self::view_blog)), + ) + } else { + // TODO: put back when blog post is published + None + }; + + if self.user_store.read(cx).current_user().is_some() { + let copy = match self.sign_in_status { + SignInStatus::Idle => "Get accurate and helpful edit predictions at every keystroke. To set Zed as your inline completions provider, ensure you:", + SignInStatus::SignedIn => "Almost there! Ensure you:", + SignInStatus::Waiting => unreachable!(), + }; + + base.child(Label::new(copy).color(Color::Muted)) + .child( + h_flex() + .gap_0p5() + .child(CheckboxWithLabel::new( + "tos-checkbox", + Label::new("Have read and accepted the").color(Color::Muted), + self.terms_of_service.into(), + cx.listener(move |this, state, _window, cx| { + this.terms_of_service = *state == ToggleState::Selected; + cx.notify() + }), + )) + .child( + Button::new("view-tos", "Terms of Service") + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Indicator) + .icon_color(Color::Muted) + .on_click(cx.listener(Self::view_terms)), + ), + ) + .child( + v_flex() + .mt_2() + .gap_2() + .w_full() + .child( + Button::new("accept-tos", "Enable Edit Predictions") + .disabled(!self.terms_of_service) + .style(ButtonStyle::Tinted(TintColor::Accent)) + .full_width() + .on_click(cx.listener(Self::accept_and_enable)), + ) + .children(blog_post_button), + ) + } else { + base.child( + Label::new("To set Zed as your inline completions provider, please sign in.") + .color(Color::Muted), + ) + .child( + v_flex() + .mt_2() + .gap_2() + .w_full() + .child( + Button::new("accept-tos", "Sign in with GitHub") + .disabled(self.sign_in_status == SignInStatus::Waiting) + .style(ButtonStyle::Tinted(TintColor::Accent)) + .full_width() + .on_click(cx.listener(Self::sign_in)), + ) + .children(blog_post_button), + ) + } + } +} diff --git a/crates/zed_predict_tos/src/zed_predict_tos.rs b/crates/zed_predict_tos/src/zed_predict_tos.rs deleted file mode 100644 index 9d312bba75..0000000000 --- a/crates/zed_predict_tos/src/zed_predict_tos.rs +++ /dev/null @@ -1,155 +0,0 @@ -//! AI service Terms of Service acceptance modal. - -use client::UserStore; -use gpui::{ - App, ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, - Render, -}; -use ui::{prelude::*, TintColor}; -use workspace::{ModalView, Workspace}; - -/// Terms of acceptance for AI inline prediction. -pub struct ZedPredictTos { - focus_handle: FocusHandle, - user_store: Entity, - workspace: Entity, - viewed: bool, -} - -impl ZedPredictTos { - fn new( - workspace: Entity, - user_store: Entity, - cx: &mut Context, - ) -> Self { - ZedPredictTos { - viewed: false, - focus_handle: cx.focus_handle(), - user_store, - workspace, - } - } - pub fn toggle( - workspace: Entity, - user_store: Entity, - window: &mut Window, - cx: &mut App, - ) { - workspace.update(cx, |this, cx| { - let workspace = cx.entity().clone(); - this.toggle_modal(window, cx, |_window, cx| { - ZedPredictTos::new(workspace, user_store, cx) - }); - }); - } - - fn view_terms(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context) { - self.viewed = true; - cx.open_url("https://zed.dev/terms-of-service"); - cx.notify(); - } - - fn accept_terms(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context) { - let task = self - .user_store - .update(cx, |this, cx| this.accept_terms_of_service(cx)); - - let workspace = self.workspace.clone(); - - cx.spawn(|this, mut cx| async move { - match task.await { - Ok(_) => this.update(&mut cx, |_, cx| { - cx.emit(DismissEvent); - }), - Err(err) => workspace.update(&mut cx, |this, cx| { - this.show_error(&err, cx); - }), - } - }) - .detach_and_log_err(cx); - } - - fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context) { - cx.emit(DismissEvent); - } -} - -impl EventEmitter for ZedPredictTos {} - -impl Focusable for ZedPredictTos { - fn focus_handle(&self, _cx: &App) -> FocusHandle { - self.focus_handle.clone() - } -} - -impl ModalView for ZedPredictTos {} - -impl Render for ZedPredictTos { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - v_flex() - .id("zed predict tos") - .track_focus(&self.focus_handle(cx)) - .on_action(cx.listener(Self::cancel)) - .key_context("ZedPredictTos") - .elevation_3(cx) - .w_96() - .items_center() - .p_4() - .gap_2() - .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| { - cx.emit(DismissEvent); - })) - .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| { - this.focus_handle.focus(window); - })) - .child( - h_flex() - .w_full() - .justify_between() - .child( - v_flex() - .gap_0p5() - .child( - Label::new("Zed AI") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child(Headline::new("Edit Prediction")), - ) - .child(Icon::new(IconName::ZedPredict).size(IconSize::XLarge)), - ) - .child( - Label::new( - "To use Zed AI's Edit Prediction feature, please read and accept our Terms of Service.", - ) - .color(Color::Muted), - ) - .child( - v_flex() - .mt_2() - .gap_0p5() - .w_full() - .child(if self.viewed { - Button::new("accept-tos", "I've Read and Accept the Terms of Service") - .style(ButtonStyle::Tinted(TintColor::Accent)) - .full_width() - .on_click(cx.listener(Self::accept_terms)) - } else { - Button::new("view-tos", "Read Terms of Service") - .style(ButtonStyle::Tinted(TintColor::Accent)) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::XSmall) - .icon_position(IconPosition::End) - .full_width() - .on_click(cx.listener(Self::view_terms)) - }) - .child( - Button::new("cancel", "Cancel") - .full_width() - .on_click(cx.listener(|_, _: &ClickEvent, _window, cx| { - cx.emit(DismissEvent); - })), - ), - ) - } -} diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index a849c315bb..3e4e3fde81 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -22,6 +22,7 @@ arrayvec.workspace = true client.workspace = true collections.workspace = true command_palette_hooks.workspace = true +db.workspace = true editor.workspace = true feature_flags.workspace = true futures.workspace = true @@ -34,6 +35,7 @@ language_models.workspace = true log.workspace = true menu.workspace = true rpc.workspace = true +serde.workspace = true serde_json.workspace = true settings.workspace = true similar.workspace = true diff --git a/crates/zeta/src/persistence.rs b/crates/zeta/src/persistence.rs new file mode 100644 index 0000000000..05a5b12f22 --- /dev/null +++ b/crates/zeta/src/persistence.rs @@ -0,0 +1,54 @@ +use anyhow::Result; +use collections::HashMap; +use std::path::{Path, PathBuf}; +use workspace::WorkspaceDb; + +use db::sqlez_macros::sql; +use db::{define_connection, query}; + +define_connection!( + pub static ref DB: ZetaDb = &[ + sql! ( + CREATE TABLE zeta_preferences( + worktree_path BLOB NOT NULL PRIMARY KEY, + accepted_data_collection INTEGER + ) STRICT; + ), + ]; +); + +impl ZetaDb { + pub fn get_all_zeta_preferences(&self) -> Result> { + Ok(self.get_all_zeta_preferences_query()?.into_iter().collect()) + } + + query! { + fn get_all_zeta_preferences_query() -> Result> { + SELECT worktree_path, accepted_data_collection FROM zeta_preferences + } + } + + query! { + pub fn get_accepted_data_collection(worktree_path: &Path) -> Result> { + SELECT accepted_data_collection FROM zeta_preferences + WHERE worktree_path = ? + } + } + + query! { + pub async fn save_accepted_data_collection(worktree_path: PathBuf, accepted_data_collection: bool) -> Result<()> { + INSERT INTO zeta_preferences + (worktree_path, accepted_data_collection) + VALUES + (?1, ?2) + ON CONFLICT (worktree_path) DO UPDATE SET + accepted_data_collection = ?2 + } + } + + query! { + pub async fn clear_all_zeta_preferences() -> Result<()> { + DELETE FROM zeta_preferences + } + } +} diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 99fae6600f..0fe08cd107 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1,7 +1,10 @@ mod completion_diff_element; +mod persistence; mod rate_completion_modal; pub(crate) use completion_diff_element::*; +use db::kvp::KEY_VALUE_STORE; +use inline_completion::DataCollectionState; pub use rate_completion_modal::*; use anyhow::{anyhow, Context as _, Result}; @@ -12,6 +15,7 @@ use feature_flags::FeatureFlagAppExt as _; use futures::AsyncReadExt; use gpui::{ actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task, + WeakEntity, }; use http_client::{HttpClient, Method}; use language::{ @@ -20,26 +24,33 @@ use language::{ }; use language_models::LlmApiToken; use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME}; +use serde::{Deserialize, Serialize}; use std::{ borrow::Cow, - cmp, + cmp, env, fmt::Write, future::Future, mem, ops::Range, - path::Path, + path::{Path, PathBuf}, sync::Arc, time::{Duration, Instant}, }; use telemetry_events::InlineCompletionRating; use util::ResultExt; use uuid::Uuid; +use workspace::{ + notifications::{simple_message_notification::MessageNotification, NotificationId}, + Workspace, +}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>"; const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>"; const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); +const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str = + "zed_predict_data_collection_never_ask_again"; // TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants? @@ -187,6 +198,7 @@ pub struct Zeta { registered_buffers: HashMap, shown_completions: VecDeque, rated_completions: HashSet, + data_collection_preferences: DataCollectionPreferences, llm_token: LlmApiToken, _llm_token_subscription: Subscription, tos_accepted: bool, // Terms of service accepted @@ -216,13 +228,13 @@ impl Zeta { fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx); - Self { client, events: VecDeque::new(), shown_completions: VecDeque::new(), rated_completions: HashSet::default(), registered_buffers: HashMap::default(), + data_collection_preferences: Self::load_data_collection_preferences(cx), llm_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, @@ -240,11 +252,16 @@ impl Zeta { .read(cx) .current_user_has_accepted_terms() .unwrap_or(false), - _user_store_subscription: cx.subscribe(&user_store, |this, _, event, _| match event { - client::user::Event::TermsStatusUpdated { accepted } => { - this.tos_accepted = *accepted; + _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { + match event { + client::user::Event::PrivateUserInfoUpdated => { + this.tos_accepted = user_store + .read(cx) + .current_user_has_accepted_terms() + .unwrap_or(false); + } + _ => {} } - _ => {} }), } } @@ -308,11 +325,8 @@ impl Zeta { event: &language::BufferEvent, cx: &mut Context, ) { - match event { - language::BufferEvent::Edited => { - self.report_changes_for_buffer(&buffer, cx); - } - _ => {} + if let language::BufferEvent::Edited = event { + self.report_changes_for_buffer(&buffer, cx); } } @@ -320,6 +334,7 @@ impl Zeta { &mut self, buffer: &Entity, cursor: language::Anchor, + can_collect_data: bool, cx: &mut Context, perform_predict_edits: F, ) -> Task>> @@ -370,6 +385,7 @@ impl Zeta { input_events: input_events.clone(), input_excerpt: input_excerpt.clone(), outline: Some(input_outline.clone()), + can_collect_data, }; let response = perform_predict_edits(client, llm_token, is_staff, body).await?; @@ -540,16 +556,25 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(buffer, position, cx, |_, _, _, _| ready(Ok(response))) + self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| { + ready(Ok(response)) + }) } pub fn request_completion( &mut self, buffer: &Entity, position: language::Anchor, + can_collect_data: bool, cx: &mut Context, ) -> Task>> { - self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits) + self.request_completion_impl( + buffer, + position, + can_collect_data, + cx, + Self::perform_predict_edits, + ) } fn perform_predict_edits( @@ -862,6 +887,80 @@ and then another new_snapshot } + + pub fn data_collection_choice_at(&self, path: &Path) -> DataCollectionChoice { + match self.data_collection_preferences.per_worktree.get(path) { + Some(true) => DataCollectionChoice::Enabled, + Some(false) => DataCollectionChoice::Disabled, + None => DataCollectionChoice::NotAnswered, + } + } + + fn update_data_collection_choice_for_worktree( + &mut self, + absolute_path_of_project_worktree: PathBuf, + can_collect_data: bool, + cx: &mut Context, + ) { + self.data_collection_preferences + .per_worktree + .insert(absolute_path_of_project_worktree.clone(), can_collect_data); + + db::write_and_log(cx, move || { + persistence::DB + .save_accepted_data_collection(absolute_path_of_project_worktree, can_collect_data) + }); + } + + fn set_never_ask_again_for_data_collection(&mut self, cx: &mut Context) { + self.data_collection_preferences.never_ask_again = true; + + // persist choice + db::write_and_log(cx, move || { + KEY_VALUE_STORE.write_kvp( + ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into(), + "true".to_string(), + ) + }); + } + + fn load_data_collection_preferences(cx: &mut Context) -> DataCollectionPreferences { + if env::var("ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES").is_ok() { + db::write_and_log(cx, move || async move { + KEY_VALUE_STORE + .delete_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into()) + .await + .log_err(); + + persistence::DB.clear_all_zeta_preferences().await + }); + return DataCollectionPreferences::default(); + } + + let never_ask_again = KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY) + .log_err() + .flatten() + .map(|value| value == "true") + .unwrap_or(false); + + let preferences_per_project = persistence::DB + .get_all_zeta_preferences() + .log_err() + .unwrap_or_else(HashMap::default); + + DataCollectionPreferences { + never_ask_again, + per_worktree: preferences_per_project, + } + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct DataCollectionPreferences { + /// Set when a user clicks on "Never Ask Again", can never be unset. + never_ask_again: bool, + per_worktree: HashMap, } fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { @@ -1276,22 +1375,120 @@ struct PendingCompletion { _task: Task<()>, } +#[derive(Clone, Copy)] +pub enum DataCollectionChoice { + NotAnswered, + Enabled, + Disabled, +} + +impl DataCollectionChoice { + pub fn is_enabled(&self) -> bool { + match self { + Self::Enabled => true, + Self::NotAnswered | Self::Disabled => false, + } + } + + pub fn is_answered(&self) -> bool { + match self { + Self::Enabled | Self::Disabled => true, + Self::NotAnswered => false, + } + } + + pub fn toggle(&self) -> DataCollectionChoice { + match self { + Self::Enabled => Self::Disabled, + Self::Disabled => Self::Enabled, + Self::NotAnswered => Self::Enabled, + } + } +} + pub struct ZetaInlineCompletionProvider { zeta: Entity, pending_completions: ArrayVec, next_pending_completion_id: usize, current_completion: Option, + data_collection: Option, +} + +pub struct ProviderDataCollection { + workspace: WeakEntity, + worktree_root_path: PathBuf, + choice: DataCollectionChoice, +} + +impl ProviderDataCollection { + pub fn new( + zeta: Entity, + workspace: Option>, + buffer: Option>, + cx: &mut App, + ) -> Option { + let workspace = workspace?; + + let worktree_root_path = buffer?.update(cx, |buffer, cx| { + let file = buffer.file()?; + + if !file.is_local() || file.is_private() { + return None; + } + + workspace.update(cx, |workspace, cx| { + Some( + workspace + .absolute_path_of_worktree(file.worktree_id(cx), cx)? + .to_path_buf(), + ) + }) + })?; + + let choice = zeta.read(cx).data_collection_choice_at(&worktree_root_path); + + Some(ProviderDataCollection { + workspace: workspace.downgrade(), + worktree_root_path, + choice, + }) + } + + fn set_choice(&mut self, choice: DataCollectionChoice, zeta: &Entity, cx: &mut App) { + self.choice = choice; + + let worktree_root_path = self.worktree_root_path.clone(); + + zeta.update(cx, |zeta, cx| { + zeta.update_data_collection_choice_for_worktree( + worktree_root_path, + choice.is_enabled(), + cx, + ) + }); + } + + fn toggle_choice(&mut self, zeta: &Entity, cx: &mut App) { + self.set_choice(self.choice.toggle(), zeta, cx); + } } impl ZetaInlineCompletionProvider { pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8); - pub fn new(zeta: Entity) -> Self { + pub fn new(zeta: Entity, data_collection: Option) -> Self { Self { zeta, pending_completions: ArrayVec::new(), next_pending_completion_id: 0, current_completion: None, + data_collection, + } + } + + fn set_data_collection_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) { + if let Some(data_collection) = self.data_collection.as_mut() { + data_collection.set_choice(choice, &self.zeta, cx); } } } @@ -1302,7 +1499,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide } fn display_name() -> &'static str { - "Zed Predict" + "Zed's Edit Predictions" } fn show_completions_in_menu() -> bool { @@ -1317,6 +1514,24 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide true } + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { + let Some(data_collection) = self.data_collection.as_ref() else { + return DataCollectionState::Unknown; + }; + + if data_collection.choice.is_enabled() { + DataCollectionState::Enabled + } else { + DataCollectionState::Disabled + } + } + + fn toggle_data_collection(&mut self, cx: &mut App) { + if let Some(data_collection) = self.data_collection.as_mut() { + data_collection.toggle_choice(&self.zeta, cx); + } + } + fn is_enabled( &self, buffer: &Entity, @@ -1362,6 +1577,10 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide let pending_completion_id = self.next_pending_completion_id; self.next_pending_completion_id += 1; + let can_collect_data = self + .data_collection + .as_ref() + .map_or(false, |data_collection| data_collection.choice.is_enabled()); let task = cx.spawn(|this, mut cx| async move { if debounce { @@ -1370,7 +1589,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide let completion_request = this.update(&mut cx, |this, cx| { this.zeta.update(cx, |zeta, cx| { - zeta.request_completion(&buffer, position, cx) + zeta.request_completion(&buffer, position, can_collect_data, cx) }) }); @@ -1447,8 +1666,80 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide // Right now we don't support cycling. } - fn accept(&mut self, _cx: &mut Context) { + fn accept(&mut self, cx: &mut Context) { self.pending_completions.clear(); + + let Some(data_collection) = self.data_collection.as_mut() else { + return; + }; + + if data_collection.choice.is_answered() + || self + .zeta + .read(cx) + .data_collection_preferences + .never_ask_again + { + return; + } + + struct ZetaDataCollectionNotification; + let notification_id = NotificationId::unique::(); + + const DATA_COLLECTION_INFO_URL: &str = "https://zed.dev/terms-of-service"; // TODO: Replace for a link that's dedicated to Edit Predictions data collection + + let this = cx.entity(); + data_collection + .workspace + .update(cx, |workspace, cx| { + workspace.show_notification(notification_id, cx, |cx| { + let zeta = self.zeta.clone(); + + cx.new(move |_cx| { + let message = + "To allow Zed to suggest better edits, turn on data collection. You \ + can turn off at any time via the status bar menu."; + MessageNotification::new(message) + .with_title("Per-Project Data Collection Program") + .show_close_button(false) + .with_click_message("Turn On") + .on_click({ + let this = this.clone(); + move |_window, cx| { + this.update(cx, |this, cx| { + this.set_data_collection_choice( + DataCollectionChoice::Enabled, + cx, + ) + }); + } + }) + .with_secondary_click_message("Turn Off") + .on_secondary_click({ + move |_window, cx| { + this.update(cx, |this, cx| { + this.set_data_collection_choice( + DataCollectionChoice::Disabled, + cx, + ) + }); + } + }) + .with_tertiary_click_message("Never Ask Again") + .on_tertiary_click({ + let zeta = zeta.clone(); + move |_window, cx| { + zeta.update(cx, |zeta, cx| { + zeta.set_never_ask_again_for_data_collection(cx); + }); + } + }) + .more_info_message("Learn More") + .more_info_url(DATA_COLLECTION_INFO_URL) + }) + }); + }) + .log_err(); } fn discard(&mut self, _cx: &mut Context) { @@ -1688,8 +1979,9 @@ mod tests { let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); - let completion_task = - zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx)); + let completion_task = zeta.update(cx, |zeta, cx| { + zeta.request_completion(&buffer, cursor, false, cx) + }); let token_request = server.receive::().await.unwrap(); server.respond(