zeta: Onboarding and title bar banner (#23797)

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Danilo <danilo@zed.dev>
Co-authored-by: João Marcos <joao@zed.dev>
This commit is contained in:
Agus Zubiaga 2025-01-30 16:55:32 -03:00 committed by GitHub
parent 4ab372d6b5
commit e23e03592b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 1207 additions and 249 deletions

19
Cargo.lock generated
View file

@ -4021,7 +4021,7 @@ dependencies = [
"util",
"uuid",
"workspace",
"zed_predict_tos",
"zed_predict_onboarding",
]
[[package]]
@ -6415,7 +6415,7 @@ dependencies = [
"ui",
"workspace",
"zed_actions",
"zed_predict_tos",
"zed_predict_onboarding",
"zeta",
]
@ -13541,6 +13541,7 @@ dependencies = [
"windows 0.58.0",
"workspace",
"zed_actions",
"zed_predict_onboarding",
]
[[package]]
@ -16557,7 +16558,7 @@ dependencies = [
"winresource",
"workspace",
"zed_actions",
"zed_predict_tos",
"zed_predict_onboarding",
"zeta",
]
@ -16672,13 +16673,21 @@ dependencies = [
]
[[package]]
name = "zed_predict_tos"
name = "zed_predict_onboarding"
version = "0.1.0"
dependencies = [
"chrono",
"client",
"db",
"feature_flags",
"fs",
"gpui",
"language",
"menu",
"settings",
"theme",
"ui",
"util",
"workspace",
]
@ -16872,6 +16881,7 @@ dependencies = [
"collections",
"command_palette_hooks",
"ctor",
"db",
"editor",
"env_logger 0.11.6",
"feature_flags",
@ -16886,6 +16896,7 @@ dependencies = [
"menu",
"reqwest_client",
"rpc",
"serde",
"serde_json",
"settings",
"similar",

View file

@ -152,7 +152,7 @@ members = [
"crates/worktree",
"crates/zed",
"crates/zed_actions",
"crates/zed_predict_tos",
"crates/zed_predict_onboarding",
"crates/zeta",
#
@ -201,7 +201,6 @@ edition = "2021"
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
zed_predict_tos = { path = "crates/zed_predict_tos" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
@ -350,6 +349,7 @@ workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_predict_onboarding = { path = "crates/zed_predict_onboarding" }
zeta = { path = "crates/zeta" }
#

View file

@ -1,5 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7 8.9V11C5.34478 11 4.65522 11 3 11V10.4L7 5.6V5H3V7.1" stroke="black" stroke-width="1.5"/>
<path d="M12 5L14 8L12 11" stroke="black" stroke-width="1.5"/>
<path d="M10 6.5L11 8L10 9.5" stroke="black" stroke-width="1.5"/>
<path d="M7.5 8.9V11C5.43097 11 4.56903 11 2.5 11V10.4L7.5 5.6V5H2.5V7.1" stroke="black" stroke-width="1.5"/>
</svg>

Before

Width:  |  Height:  |  Size: 334 B

After

Width:  |  Height:  |  Size: 342 B

Before After
Before After

View file

@ -0,0 +1,19 @@
<svg width="420" height="128" xmlns="http://www.w3.org/2000/svg">
<defs>
<pattern id="tilePattern" width="22" height="22" patternUnits="userSpaceOnUse">
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 5L14 8L12 11" stroke="black" stroke-width="1.5"/>
<path d="M10 6.5L11 8L10 9.5" stroke="black" stroke-width="1.5"/>
<path d="M7.5 8.9V11C5.43097 11 4.56903 11 2.5 11V10.4L7.5 5.6V5H2.5V7.1" stroke="black" stroke-width="1.5"/>
</svg>
</pattern>
<linearGradient id="fade" y2="1" x2="0">
<stop offset="0" stop-color="white" stop-opacity=".24"/>
<stop offset="1" stop-color="white" stop-opacity="0"/>
</linearGradient>
<mask id="fadeMask" maskContentUnits="objectBoundingBox">
<rect width="1" height="1" fill="url(#fade)"/>
</mask>
</defs>
<rect width="100%" height="100%" fill="url(#tilePattern)" mask="url(#fadeMask)"/>
</svg>

After

Width:  |  Height:  |  Size: 971 B

View file

@ -823,5 +823,12 @@
"shift-end": "terminal::ScrollToBottom",
"ctrl-shift-space": "terminal::ToggleViMode"
}
},
{
"context": "ZedPredictModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"
}
}
]

View file

@ -883,7 +883,7 @@
}
},
{
"context": "ZedPredictTos",
"context": "ZedPredictModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"

View file

@ -121,9 +121,7 @@ pub enum Event {
},
ShowContacts,
ParticipantIndicesChanged,
TermsStatusUpdated {
accepted: bool,
},
PrivateUserInfoUpdated,
}
#[derive(Clone, Copy)]
@ -227,9 +225,7 @@ impl UserStore {
};
this.set_current_user_accepted_tos_at(accepted_tos_at);
cx.emit(Event::TermsStatusUpdated {
accepted: accepted_tos_at.is_some(),
});
cx.emit(Event::PrivateUserInfoUpdated);
})
} else {
anyhow::Ok(())
@ -244,6 +240,8 @@ impl UserStore {
Status::SignedOut => {
current_user_tx.send(None).await.ok();
this.update(&mut cx, |this, cx| {
this.accepted_tos_at = None;
cx.emit(Event::PrivateUserInfoUpdated);
cx.notify();
this.clear_contacts()
})?
@ -714,7 +712,7 @@ impl UserStore {
this.update(&mut cx, |this, cx| {
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
cx.emit(Event::TermsStatusUpdated { accepted: true });
cx.emit(Event::PrivateUserInfoUpdated);
})
} else {
Err(anyhow!("client not found"))

View file

@ -447,7 +447,7 @@ async fn predict_edits(
));
}
let sample_input_output = claims.is_staff && rand::random::<f32>() < 0.1;
let should_sample = claims.is_staff || params.can_collect_data;
let api_url = state
.config
@ -541,7 +541,7 @@ async fn predict_edits(
let output = choice.text.clone();
async move {
let properties = if sample_input_output {
let properties = if should_sample {
json!({
"model": model.to_string(),
"headers": response.headers,

View file

@ -88,7 +88,7 @@ url.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
zed_predict_tos.workspace = true
zed_predict_onboarding.workspace = true
[dev-dependencies]
ctor.workspace = true

View file

@ -652,7 +652,7 @@ impl CompletionsMenu {
)
.on_click(cx.listener(move |editor, _event, window, cx| {
cx.stop_propagation();
editor.toggle_zed_predict_tos(window, cx);
editor.toggle_zed_predict_onboarding(window, cx);
})),
),

View file

@ -69,7 +69,7 @@ pub use element::{
};
use futures::{future, FutureExt};
use fuzzy::StringMatchCandidate;
use zed_predict_tos::ZedPredictTos;
use zed_predict_onboarding::ZedPredictModal;
use code_context_menus::{
AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu,
@ -3948,12 +3948,21 @@ impl Editor {
self.do_completion(action.item_ix, CompletionIntent::Compose, window, cx)
}
fn toggle_zed_predict_tos(&mut self, window: &mut Window, cx: &mut Context<Self>) {
fn toggle_zed_predict_onboarding(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let (Some(workspace), Some(project)) = (self.workspace(), self.project.as_ref()) else {
return;
};
ZedPredictTos::toggle(workspace, project.read(cx).user_store().clone(), window, cx);
let project = project.read(cx);
ZedPredictModal::toggle(
workspace,
project.user_store().clone(),
project.client().clone(),
project.fs().clone(),
window,
cx,
);
}
fn do_completion(
@ -3985,7 +3994,7 @@ impl Editor {
)) => {
drop(entries);
drop(context_menu);
self.toggle_zed_predict_tos(window, cx);
self.toggle_zed_predict_onboarding(window, cx);
return Some(Task::ready(Ok(())));
}
_ => {}

View file

@ -87,8 +87,8 @@ define_connection!(
// mtime_seconds: Option<i64>,
// mtime_nanos: Option<i32>,
// )
pub static ref DB: EditorDb<WorkspaceDb> =
&[sql! (
pub static ref DB: EditorDb<WorkspaceDb> = &[
sql! (
CREATE TABLE editors(
item_id INTEGER NOT NULL,
workspace_id INTEGER NOT NULL,
@ -134,7 +134,7 @@ define_connection!(
ALTER TABLE editors ADD COLUMN mtime_seconds INTEGER DEFAULT NULL;
ALTER TABLE editors ADD COLUMN mtime_nanos INTEGER DEFAULT NULL;
),
];
];
);
impl EditorDb {

View file

@ -18,6 +18,31 @@ pub struct InlineCompletion {
pub edit_preview: Option<language::EditPreview>,
}
pub enum DataCollectionState {
/// The provider doesn't support data collection.
Unsupported,
/// When there's a file not saved yet. In this case, we can't tell to which project it belongs.
Unknown,
/// Data collection is enabled
Enabled,
/// Data collection is disabled or unanswered.
Disabled,
}
impl DataCollectionState {
pub fn is_supported(&self) -> bool {
!matches!(self, DataCollectionState::Unsupported)
}
pub fn is_unknown(&self) -> bool {
matches!(self, DataCollectionState::Unknown)
}
pub fn is_enabled(&self) -> bool {
matches!(self, DataCollectionState::Enabled)
}
}
pub trait InlineCompletionProvider: 'static + Sized {
fn name() -> &'static str;
fn display_name() -> &'static str;
@ -26,6 +51,10 @@ pub trait InlineCompletionProvider: 'static + Sized {
fn show_tab_accept_marker() -> bool {
false
}
fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
DataCollectionState::Unsupported
}
fn toggle_data_collection(&mut self, _cx: &mut App) {}
fn is_enabled(
&self,
buffer: &Entity<Buffer>,
@ -72,6 +101,8 @@ pub trait InlineCompletionProviderHandle {
fn show_completions_in_menu(&self) -> bool;
fn show_completions_in_normal_mode(&self) -> bool;
fn show_tab_accept_marker(&self) -> bool;
fn data_collection_state(&self, cx: &App) -> DataCollectionState;
fn toggle_data_collection(&self, cx: &mut App);
fn needs_terms_acceptance(&self, cx: &App) -> bool;
fn is_refreshing(&self, cx: &App) -> bool;
fn refresh(
@ -122,6 +153,14 @@ where
T::show_tab_accept_marker()
}
fn data_collection_state(&self, cx: &App) -> DataCollectionState {
self.read(cx).data_collection_state(cx)
}
fn toggle_data_collection(&self, cx: &mut App) {
self.update(cx, |this, cx| this.toggle_data_collection(cx))
}
fn is_enabled(
&self,
buffer: &Entity<Buffer>,

View file

@ -29,7 +29,7 @@ workspace.workspace = true
zed_actions.workspace = true
zeta.workspace = true
client.workspace = true
zed_predict_tos.workspace = true
zed_predict_onboarding.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }

View file

@ -1,5 +1,5 @@
use anyhow::Result;
use client::UserStore;
use client::{Client, UserStore};
use copilot::{Copilot, Status};
use editor::{actions::ShowInlineCompletion, scroll::Autoscroll, Editor};
use feature_flags::{
@ -20,18 +20,16 @@ use language::{
use settings::{update_settings_file, Settings, SettingsStore};
use std::{path::Path, sync::Arc, time::Duration};
use supermaven::{AccountStatus, Supermaven};
use ui::{prelude::*, ButtonLike, Color, Icon, IconWithIndicator, Indicator, PopoverMenuHandle};
use ui::{
prelude::*, ButtonLike, Clickable, ContextMenu, ContextMenuEntry, IconButton,
IconWithIndicator, Indicator, PopoverMenu, PopoverMenuHandle, Tooltip,
};
use workspace::{
create_and_open_local_file,
item::ItemHandle,
notifications::NotificationId,
ui::{
ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, PopoverMenu, Tooltip,
},
StatusItemView, Toast, Workspace,
create_and_open_local_file, item::ItemHandle, notifications::NotificationId, StatusItemView,
Toast, Workspace,
};
use zed_actions::OpenBrowser;
use zed_predict_tos::ZedPredictTos;
use zed_predict_onboarding::ZedPredictModal;
use zeta::RateCompletionModal;
actions!(zeta, [RateCompletions]);
@ -48,6 +46,7 @@ pub struct InlineCompletionButton {
language: Option<Arc<Language>>,
file: Option<Arc<dyn File>>,
inline_completion_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>,
client: Arc<Client>,
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
@ -231,14 +230,16 @@ impl Render for InlineCompletionButton {
return div();
}
if !self
.user_store
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false)
{
let current_user_terms_accepted =
self.user_store.read(cx).current_user_has_accepted_terms();
if !current_user_terms_accepted.unwrap_or(false) {
let workspace = self.workspace.clone();
let user_store = self.user_store.clone();
let client = self.client.clone();
let fs = self.fs.clone();
let signed_in = current_user_terms_accepted.is_some();
return div().child(
ButtonLike::new("zeta-pending-tos-icon")
@ -252,20 +253,29 @@ impl Render for InlineCompletionButton {
))
.into_any_element(),
)
.tooltip(|window, cx| {
.tooltip(move |window, cx| {
Tooltip::with_meta(
"Edit Predictions",
None,
"Read Terms of Service",
if signed_in {
"Read Terms of Service"
} else {
"Sign in to use"
},
window,
cx,
)
})
.on_click(cx.listener(move |_, _, window, cx| {
let user_store = user_store.clone();
if let Some(workspace) = workspace.upgrade() {
ZedPredictTos::toggle(workspace, user_store, window, cx);
ZedPredictModal::toggle(
workspace,
user_store.clone(),
client.clone(),
fs.clone(),
window,
cx,
);
}
})),
);
@ -318,6 +328,7 @@ impl InlineCompletionButton {
workspace: WeakEntity<Workspace>,
fs: Arc<dyn Fs>,
user_store: Entity<UserStore>,
client: Arc<Client>,
popover_menu_handle: PopoverMenuHandle<ContextMenu>,
cx: &mut Context<Self>,
) -> Self {
@ -337,6 +348,7 @@ impl InlineCompletionButton {
inline_completion_provider: None,
popover_menu_handle,
workspace,
client,
fs,
user_store,
}
@ -430,6 +442,22 @@ impl InlineCompletionButton {
move |_, cx| toggle_inline_completions_globally(fs.clone(), cx),
);
if let Some(provider) = &self.inline_completion_provider {
let data_collection = provider.data_collection_state(cx);
if data_collection.is_supported() {
let provider = provider.clone();
menu = menu.separator().item(
ContextMenuEntry::new("Data Collection")
.toggleable(IconPosition::Start, data_collection.is_enabled())
.disabled(data_collection.is_unknown())
.handler(move |_, cx| {
provider.toggle_data_collection(cx);
}),
);
}
}
if let Some(editor_focus_handle) = self.editor_focus_handle.clone() {
menu = menu
.separator()

View file

@ -39,6 +39,9 @@ pub struct PredictEditsParams {
pub outline: Option<String>,
pub input_events: String,
pub input_excerpt: String,
/// Whether the user provided consent for sampling this interaction.
#[serde(default)]
pub can_collect_data: bool,
}
#[derive(Debug, Serialize, Deserialize)]

View file

@ -48,6 +48,7 @@ telemetry.workspace = true
workspace.workspace = true
zed_actions.workspace = true
git_ui.workspace = true
zed_predict_onboarding.workspace = true
[target.'cfg(windows)'.dependencies]
windows.workspace = true

View file

@ -37,6 +37,7 @@ use ui::{
use util::ResultExt;
use workspace::{notifications::NotifyResultExt, Workspace};
use zed_actions::{OpenBrowser, OpenRecent, OpenRemote};
use zed_predict_onboarding::ZedPredictBanner;
#[cfg(feature = "stories")]
pub use stories::*;
@ -113,6 +114,7 @@ pub struct TitleBar {
application_menu: Option<Entity<ApplicationMenu>>,
_subscriptions: Vec<Subscription>,
git_ui_enabled: Arc<AtomicBool>,
zed_predict_banner: Entity<ZedPredictBanner>,
}
impl Render for TitleBar {
@ -196,6 +198,7 @@ impl Render for TitleBar {
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()),
)
.child(self.render_collaborator_list(window, cx))
.child(self.zed_predict_banner.clone())
.child(
h_flex()
.gap_1()
@ -271,6 +274,7 @@ impl TitleBar {
let project = workspace.project().clone();
let user_store = workspace.app_state().user_store.clone();
let client = workspace.app_state().client.clone();
let fs = workspace.app_state().fs.clone();
let active_call = ActiveCall::global(cx);
let platform_style = PlatformStyle::platform();
@ -306,6 +310,16 @@ impl TitleBar {
}
}));
let zed_predict_banner = cx.new(|cx| {
ZedPredictBanner::new(
workspace.weak_handle(),
user_store.clone(),
client.clone(),
fs.clone(),
cx,
)
});
Self {
platform_style,
content: div().id(id.into()),
@ -319,6 +333,7 @@ impl TitleBar {
client,
_subscriptions: subscriptions,
git_ui_enabled: is_git_ui_enabled,
zed_predict_banner,
}
}

View file

@ -64,6 +64,11 @@ impl ContextMenuEntry {
}
}
pub fn toggleable(mut self, toggle_position: IconPosition, toggled: bool) -> Self {
self.toggle = Some((toggle_position, toggled));
self
}
pub fn icon(mut self, icon: IconName) -> Self {
self.icon = Some(icon);
self

View file

@ -379,6 +379,12 @@ pub mod simple_message_notification {
click_message: Option<SharedString>,
secondary_click_message: Option<SharedString>,
secondary_on_click: Option<Arc<dyn Fn(&mut Window, &mut Context<Self>)>>,
tertiary_click_message: Option<SharedString>,
tertiary_on_click: Option<Arc<dyn Fn(&mut Window, &mut Context<Self>)>>,
more_info_message: Option<SharedString>,
more_info_url: Option<Arc<str>>,
show_close_button: bool,
title: Option<SharedString>,
}
impl EventEmitter<DismissEvent> for MessageNotification {}
@ -402,6 +408,12 @@ pub mod simple_message_notification {
click_message: None,
secondary_on_click: None,
secondary_click_message: None,
tertiary_on_click: None,
tertiary_click_message: None,
more_info_message: None,
more_info_url: None,
show_close_button: true,
title: None,
}
}
@ -437,31 +449,85 @@ pub mod simple_message_notification {
self
}
pub fn with_tertiary_click_message<S>(mut self, message: S) -> Self
where
S: Into<SharedString>,
{
self.tertiary_click_message = Some(message.into());
self
}
pub fn on_tertiary_click<F>(mut self, on_click: F) -> Self
where
F: 'static + Fn(&mut Window, &mut Context<Self>),
{
self.tertiary_on_click = Some(Arc::new(on_click));
self
}
pub fn more_info_message<S>(mut self, message: S) -> Self
where
S: Into<SharedString>,
{
self.more_info_message = Some(message.into());
self
}
pub fn more_info_url<S>(mut self, url: S) -> Self
where
S: Into<Arc<str>>,
{
self.more_info_url = Some(url.into());
self
}
pub fn dismiss(&mut self, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
pub fn show_close_button(mut self, show: bool) -> Self {
self.show_close_button = show;
self
}
pub fn with_title<S>(mut self, title: S) -> Self
where
S: Into<SharedString>,
{
self.title = Some(title.into());
self
}
}
impl Render for MessageNotification {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.p_3()
.gap_2()
.gap_3()
.elevation_3(cx)
.child(
h_flex()
.gap_4()
.justify_between()
.items_start()
.child(div().max_w_96().child((self.build_content)(window, cx)))
.child(
IconButton::new("close", IconName::Close)
.on_click(cx.listener(|this, _, _, cx| this.dismiss(cx))),
),
v_flex()
.gap_0p5()
.when_some(self.title.clone(), |element, title| {
element.child(Label::new(title))
})
.child(div().max_w_96().child((self.build_content)(window, cx))),
)
.when(self.show_close_button, |this| {
this.child(
IconButton::new("close", IconName::Close)
.on_click(cx.listener(|this, _, _, cx| this.dismiss(cx))),
)
}),
)
.child(
h_flex()
.gap_2()
.gap_1()
.children(self.click_message.iter().map(|message| {
Button::new(message.clone(), message.clone())
.label_size(LabelSize::Small)
@ -489,7 +555,40 @@ pub mod simple_message_notification {
};
this.dismiss(cx)
}))
})),
}))
.child(
h_flex()
.w_full()
.gap_1()
.justify_end()
.children(self.tertiary_click_message.iter().map(|message| {
Button::new(message.clone(), message.clone())
.label_size(LabelSize::Small)
.on_click(cx.listener(|this, _, window, cx| {
if let Some(on_click) = this.tertiary_on_click.as_ref()
{
(on_click)(window, cx)
};
this.dismiss(cx)
}))
}))
.children(
self.more_info_message
.iter()
.zip(self.more_info_url.iter())
.map(|(message, url)| {
let url = url.clone();
Button::new(message.clone(), message.clone())
.label_size(LabelSize::Small)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Indicator)
.icon_color(Color::Muted)
.on_click(cx.listener(move |_, _, _, cx| {
cx.open_url(&url);
}))
}),
),
),
)
}
}

View file

@ -58,7 +58,9 @@ use persistence::{
SerializedWindowBounds, DB,
};
use postage::stream::Stream;
use project::{DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree};
use project::{
DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId,
};
use remote::{ssh_session::ConnectionIdentifier, SshClientDelegate, SshConnectionOptions};
use schemars::JsonSchema;
use serde::Deserialize;
@ -2200,6 +2202,18 @@ impl Workspace {
}
}
pub fn absolute_path_of_worktree(
&self,
worktree_id: WorktreeId,
cx: &mut Context<Self>,
) -> Option<PathBuf> {
self.project
.read(cx)
.worktree_for_id(worktree_id, cx)
// TODO: use `abs_path` or `root_dir`
.map(|wt| wt.read(cx).abs_path().as_ref().to_path_buf())
}
fn add_folder_to_project(
&mut self,
_: &AddFolderToProject,

View file

@ -2751,6 +2751,8 @@ impl Snapshot {
self.entry_for_path("")
}
/// TODO: what's the difference between `root_dir` and `abs_path`?
/// is there any? if so, document it.
pub fn root_dir(&self) -> Option<Arc<Path>> {
self.root_entry()
.filter(|entry| entry.is_dir())

View file

@ -16,7 +16,7 @@ path = "src/main.rs"
[dependencies]
activity_indicator.workspace = true
zed_predict_tos.workspace = true
zed_predict_onboarding.workspace = true
anyhow.workspace = true
assets.workspace = true
assistant.workspace = true

View file

@ -439,6 +439,7 @@ fn main() {
inline_completion_registry::init(
app_state.client.clone(),
app_state.user_store.clone(),
app_state.fs.clone(),
cx,
);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);

View file

@ -176,6 +176,7 @@ pub fn initialize_workspace(
workspace.weak_handle(),
app_state.fs.clone(),
app_state.user_store.clone(),
app_state.client.clone(),
popover_menu_handle.clone(),
cx,
)

View file

@ -5,13 +5,17 @@ use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider};
use editor::{Editor, EditorMode};
use feature_flags::{FeatureFlagAppExt, PredictEditsFeatureFlag};
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity, Window};
use fs::Fs;
use gpui::{AnyWindowHandle, App, AppContext, Context, Entity, WeakEntity};
use language::language_settings::{all_language_settings, InlineCompletionProvider};
use settings::SettingsStore;
use supermaven::{Supermaven, SupermavenCompletionProvider};
use zed_predict_tos::ZedPredictTos;
use ui::Window;
use workspace::Workspace;
use zed_predict_onboarding::ZedPredictModal;
use zeta::ProviderDataCollection;
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, fs: Arc<dyn Fs>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
cx.observe_new({
let editors = editors.clone();
@ -37,6 +41,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
})
.detach();
editors
.borrow_mut()
.insert(editor_handle, window.window_handle());
@ -91,6 +96,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors = editors.clone();
let client = client.clone();
let user_store = user_store.clone();
let fs = fs.clone();
move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
@ -123,9 +129,11 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
window
.update(cx, |_, window, cx| {
ZedPredictTos::toggle(
ZedPredictModal::toggle(
workspace,
user_store.clone(),
client.clone(),
fs.clone(),
window,
cx,
);
@ -214,17 +222,19 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed
fn assign_inline_completion_provider(
editor: &mut Editor,
provider: language::language_settings::InlineCompletionProvider,
provider: InlineCompletionProvider,
client: &Arc<Client>,
user_store: Entity<UserStore>,
window: &mut Window,
cx: &mut Context<Editor>,
) {
let singleton_buffer = editor.buffer().read(cx).as_singleton();
match provider {
language::language_settings::InlineCompletionProvider::None => {}
language::language_settings::InlineCompletionProvider::Copilot => {
InlineCompletionProvider::None => {}
InlineCompletionProvider::Copilot => {
if let Some(copilot) = Copilot::global(cx) {
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
if let Some(buffer) = singleton_buffer {
if buffer.read(cx).file().is_some() {
copilot.update(cx, |copilot, cx| {
copilot.register_buffer(&buffer, cx);
@ -235,26 +245,35 @@ fn assign_inline_completion_provider(
editor.set_inline_completion_provider(Some(provider), window, cx);
}
}
language::language_settings::InlineCompletionProvider::Supermaven => {
InlineCompletionProvider::Supermaven => {
if let Some(supermaven) = Supermaven::global(cx) {
let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven));
editor.set_inline_completion_provider(Some(provider), window, cx);
}
}
language::language_settings::InlineCompletionProvider::Zed => {
InlineCompletionProvider::Zed => {
if cx.has_flag::<PredictEditsFeatureFlag>()
|| (cfg!(debug_assertions) && client.status().borrow().is_connected())
{
let zeta = zeta::Zeta::register(client.clone(), user_store, cx);
if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
if let Some(buffer) = &singleton_buffer {
if buffer.read(cx).file().is_some() {
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(&buffer, cx);
});
}
}
let provider = cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta));
let data_collection = ProviderDataCollection::new(
zeta.clone(),
window.root::<Workspace>().flatten(),
singleton_buffer,
cx,
);
let provider =
cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta, data_collection));
editor.set_inline_completion_provider(Some(provider), window, cx);
}
}

View file

@ -1,5 +1,5 @@
[package]
name = "zed_predict_tos"
name = "zed_predict_onboarding"
version = "0.1.0"
edition = "2021"
publish = false
@ -9,15 +9,23 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/zed_predict_tos.rs"
path = "src/lib.rs"
doctest = false
[features]
test-support = []
[dependencies]
chrono.workspace = true
client.workspace = true
db.workspace = true
feature_flags.workspace = true
fs.workspace = true
gpui.workspace = true
ui.workspace = true
workspace.workspace = true
language.workspace = true
menu.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true

View file

@ -0,0 +1,168 @@
use std::sync::Arc;
use crate::ZedPredictModal;
use chrono::Utc;
use client::{Client, UserStore};
use feature_flags::{FeatureFlagAppExt as _, PredictEditsFeatureFlag};
use fs::Fs;
use gpui::{Entity, Subscription, WeakEntity};
use language::language_settings::{all_language_settings, InlineCompletionProvider};
use settings::SettingsStore;
use ui::{prelude::*, ButtonLike, Tooltip};
use util::ResultExt;
use workspace::Workspace;
/// Prompts user to try AI inline prediction feature
pub struct ZedPredictBanner {
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
client: Arc<Client>,
fs: Arc<dyn Fs>,
dismissed: bool,
_subscription: Subscription,
}
impl ZedPredictBanner {
pub fn new(
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
client: Arc<Client>,
fs: Arc<dyn Fs>,
cx: &mut Context<Self>,
) -> Self {
Self {
workspace,
user_store,
client,
fs,
dismissed: get_dismissed(),
_subscription: cx.observe_global::<SettingsStore>(Self::handle_settings_changed),
}
}
fn should_show(&self, cx: &mut App) -> bool {
if !cx.has_flag::<PredictEditsFeatureFlag>() || self.dismissed {
return false;
}
let provider = all_language_settings(None, cx).inline_completions.provider;
match provider {
InlineCompletionProvider::None
| InlineCompletionProvider::Copilot
| InlineCompletionProvider::Supermaven => true,
InlineCompletionProvider::Zed => false,
}
}
fn handle_settings_changed(&mut self, cx: &mut Context<Self>) {
if self.dismissed {
return;
}
let provider = all_language_settings(None, cx).inline_completions.provider;
match provider {
InlineCompletionProvider::None
| InlineCompletionProvider::Copilot
| InlineCompletionProvider::Supermaven => {}
InlineCompletionProvider::Zed => {
self.dismiss(cx);
}
}
}
fn dismiss(&mut self, cx: &mut Context<Self>) {
persist_dismissed(cx);
self.dismissed = true;
cx.notify();
}
}
const DISMISSED_AT_KEY: &str = "zed_predict_banner_dismissed_at";
pub(crate) fn get_dismissed() -> bool {
db::kvp::KEY_VALUE_STORE
.read_kvp(DISMISSED_AT_KEY)
.log_err()
.map_or(false, |dismissed| dismissed.is_some())
}
pub(crate) fn persist_dismissed(cx: &mut App) {
cx.spawn(|_| {
let time = Utc::now().to_rfc3339();
db::kvp::KEY_VALUE_STORE.write_kvp(DISMISSED_AT_KEY.into(), time)
})
.detach_and_log_err(cx);
}
impl Render for ZedPredictBanner {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
if !self.should_show(cx) {
return div();
}
let border_color = cx.theme().colors().editor_foreground.opacity(0.3);
let banner = h_flex()
.rounded_md()
.border_1()
.border_color(border_color)
.child(
ButtonLike::new("try-zed-predict")
.child(
h_flex()
.h_full()
.items_center()
.gap_1p5()
.child(Icon::new(IconName::ZedPredict).size(IconSize::Small))
.child(
h_flex()
.gap_0p5()
.child(
Label::new("Introducing:")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(Label::new("Edit Prediction").size(LabelSize::Small)),
),
)
.on_click({
let workspace = self.workspace.clone();
let user_store = self.user_store.clone();
let client = self.client.clone();
let fs = self.fs.clone();
move |_, window, cx| {
let Some(workspace) = workspace.upgrade() else {
return;
};
ZedPredictModal::toggle(
workspace,
user_store.clone(),
client.clone(),
fs.clone(),
window,
cx,
);
}
}),
)
.child(
div().border_l_1().border_color(border_color).child(
IconButton::new("close", IconName::Close)
.icon_size(IconSize::Indicator)
.on_click(cx.listener(|this, _, _window, cx| this.dismiss(cx)))
.tooltip(|window, cx| {
Tooltip::with_meta(
"Close Announcement Banner",
None,
"It won't show again for this feature",
window,
cx,
)
}),
),
);
div().pr_1().child(banner)
}
}

View file

@ -0,0 +1,5 @@
mod banner;
mod modal;
pub use banner::ZedPredictBanner;
pub use modal::ZedPredictModal;

View file

@ -0,0 +1,313 @@
use std::{sync::Arc, time::Duration};
use client::{Client, UserStore};
use feature_flags::FeatureFlagAppExt as _;
use fs::Fs;
use gpui::{
ease_in_out, svg, Animation, AnimationExt as _, ClickEvent, DismissEvent, Entity, EventEmitter,
FocusHandle, Focusable, MouseDownEvent, Render,
};
use language::language_settings::{AllLanguageSettings, InlineCompletionProvider};
use settings::{update_settings_file, Settings};
use ui::{prelude::*, CheckboxWithLabel, TintColor};
use workspace::{notifications::NotifyTaskExt, ModalView, Workspace};
/// Introduces user to AI inline prediction feature and terms of service
pub struct ZedPredictModal {
user_store: Entity<UserStore>,
client: Arc<Client>,
fs: Arc<dyn Fs>,
focus_handle: FocusHandle,
sign_in_status: SignInStatus,
terms_of_service: bool,
}
#[derive(PartialEq, Eq)]
enum SignInStatus {
/// Signed out or signed in but not from this modal
Idle,
/// Authentication triggered from this modal
Waiting,
/// Signed in after authentication from this modal
SignedIn,
}
impl ZedPredictModal {
fn new(
user_store: Entity<UserStore>,
client: Arc<Client>,
fs: Arc<dyn Fs>,
cx: &mut Context<Self>,
) -> Self {
ZedPredictModal {
user_store,
client,
fs,
focus_handle: cx.focus_handle(),
sign_in_status: SignInStatus::Idle,
terms_of_service: false,
}
}
pub fn toggle(
workspace: Entity<Workspace>,
user_store: Entity<UserStore>,
client: Arc<Client>,
fs: Arc<dyn Fs>,
window: &mut Window,
cx: &mut App,
) {
workspace.update(cx, |this, cx| {
this.toggle_modal(window, cx, |_window, cx| {
ZedPredictModal::new(user_store, client, fs, cx)
});
});
}
fn view_terms(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) {
cx.open_url("https://zed.dev/terms-of-service");
cx.notify();
}
fn view_blog(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) {
cx.open_url("https://zed.dev/blog/"); // TODO Add the link when live
cx.notify();
}
fn accept_and_enable(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let task = self
.user_store
.update(cx, |this, cx| this.accept_terms_of_service(cx));
cx.spawn(|this, mut cx| async move {
task.await?;
this.update(&mut cx, |this, cx| {
update_settings_file::<AllLanguageSettings>(this.fs.clone(), cx, move |file, _| {
file.features
.get_or_insert(Default::default())
.inline_completion_provider = Some(InlineCompletionProvider::Zed);
});
cx.emit(DismissEvent);
})
})
.detach_and_notify_err(window, cx);
}
fn sign_in(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let client = self.client.clone();
self.sign_in_status = SignInStatus::Waiting;
cx.spawn(move |this, mut cx| async move {
let result = client.authenticate_and_connect(true, &cx).await;
let status = match result {
Ok(_) => SignInStatus::SignedIn,
Err(_) => SignInStatus::Idle,
};
this.update(&mut cx, |this, cx| {
this.sign_in_status = status;
cx.notify()
})?;
result
})
.detach_and_notify_err(window, cx);
}
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
}
impl EventEmitter<DismissEvent> for ZedPredictModal {}
impl Focusable for ZedPredictModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl ModalView for ZedPredictModal {}
impl Render for ZedPredictModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let base = v_flex()
.w(px(420.))
.p_4()
.relative()
.gap_2()
.overflow_hidden()
.elevation_3(cx)
.id("zed predict tos")
.track_focus(&self.focus_handle(cx))
.on_action(cx.listener(Self::cancel))
.key_context("ZedPredictModal")
.on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| {
cx.emit(DismissEvent);
}))
.on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| {
this.focus_handle.focus(window);
}))
.child(
div()
.p_1p5()
.absolute()
.top_0()
.left_0()
.right_0()
.h(px(200.))
.child(
svg()
.path("icons/zed_predict_bg.svg")
.text_color(cx.theme().colors().icon_disabled)
.w(px(416.))
.h(px(128.))
.overflow_hidden(),
),
)
.child(
h_flex()
.w_full()
.mb_2()
.justify_between()
.child(
v_flex()
.gap_1()
.child(
Label::new("Introducing Zed AI's")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(Headline::new("Edit Prediction").size(HeadlineSize::Large)),
)
.child({
let tab = |n: usize| {
let text_color = cx.theme().colors().text;
let border_color = cx.theme().colors().text_accent.opacity(0.4);
h_flex().child(
h_flex()
.px_4()
.py_0p5()
.bg(cx.theme().colors().editor_background)
.border_1()
.border_color(border_color)
.rounded_md()
.font(theme::ThemeSettings::get_global(cx).buffer_font.clone())
.text_size(TextSize::XSmall.rems(cx))
.text_color(text_color)
.child("tab")
.with_animation(
ElementId::Integer(n),
Animation::new(Duration::from_secs(2)).repeat(),
move |tab, delta| {
let delta = (delta - 0.15 * n as f32) / 0.7;
let delta = 1.0 - (0.5 - delta).abs() * 2.;
let delta = ease_in_out(delta.clamp(0., 1.));
let delta = 0.1 + 0.9 * delta;
tab.border_color(border_color.opacity(delta))
.text_color(text_color.opacity(delta))
},
),
)
};
v_flex()
.gap_2()
.items_center()
.pr_4()
.child(tab(0).ml_neg_20())
.child(tab(1))
.child(tab(2).ml_20())
}),
)
.child(h_flex().absolute().top_2().right_2().child(
IconButton::new("cancel", IconName::X).on_click(cx.listener(
|_, _: &ClickEvent, _window, cx| {
cx.emit(DismissEvent);
},
)),
));
let blog_post_button = if cx.is_staff() {
Some(
Button::new("view-blog", "Read the Blog Post")
.full_width()
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Indicator)
.icon_color(Color::Muted)
.on_click(cx.listener(Self::view_blog)),
)
} else {
// TODO: put back when blog post is published
None
};
if self.user_store.read(cx).current_user().is_some() {
let copy = match self.sign_in_status {
SignInStatus::Idle => "Get accurate and helpful edit predictions at every keystroke. To set Zed as your inline completions provider, ensure you:",
SignInStatus::SignedIn => "Almost there! Ensure you:",
SignInStatus::Waiting => unreachable!(),
};
base.child(Label::new(copy).color(Color::Muted))
.child(
h_flex()
.gap_0p5()
.child(CheckboxWithLabel::new(
"tos-checkbox",
Label::new("Have read and accepted the").color(Color::Muted),
self.terms_of_service.into(),
cx.listener(move |this, state, _window, cx| {
this.terms_of_service = *state == ToggleState::Selected;
cx.notify()
}),
))
.child(
Button::new("view-tos", "Terms of Service")
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Indicator)
.icon_color(Color::Muted)
.on_click(cx.listener(Self::view_terms)),
),
)
.child(
v_flex()
.mt_2()
.gap_2()
.w_full()
.child(
Button::new("accept-tos", "Enable Edit Predictions")
.disabled(!self.terms_of_service)
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::accept_and_enable)),
)
.children(blog_post_button),
)
} else {
base.child(
Label::new("To set Zed as your inline completions provider, please sign in.")
.color(Color::Muted),
)
.child(
v_flex()
.mt_2()
.gap_2()
.w_full()
.child(
Button::new("accept-tos", "Sign in with GitHub")
.disabled(self.sign_in_status == SignInStatus::Waiting)
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::sign_in)),
)
.children(blog_post_button),
)
}
}
}

View file

@ -1,155 +0,0 @@
//! AI service Terms of Service acceptance modal.
use client::UserStore;
use gpui::{
App, ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
Render,
};
use ui::{prelude::*, TintColor};
use workspace::{ModalView, Workspace};
/// Terms of acceptance for AI inline prediction.
pub struct ZedPredictTos {
focus_handle: FocusHandle,
user_store: Entity<UserStore>,
workspace: Entity<Workspace>,
viewed: bool,
}
impl ZedPredictTos {
fn new(
workspace: Entity<Workspace>,
user_store: Entity<UserStore>,
cx: &mut Context<Self>,
) -> Self {
ZedPredictTos {
viewed: false,
focus_handle: cx.focus_handle(),
user_store,
workspace,
}
}
pub fn toggle(
workspace: Entity<Workspace>,
user_store: Entity<UserStore>,
window: &mut Window,
cx: &mut App,
) {
workspace.update(cx, |this, cx| {
let workspace = cx.entity().clone();
this.toggle_modal(window, cx, |_window, cx| {
ZedPredictTos::new(workspace, user_store, cx)
});
});
}
fn view_terms(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
self.viewed = true;
cx.open_url("https://zed.dev/terms-of-service");
cx.notify();
}
fn accept_terms(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
let task = self
.user_store
.update(cx, |this, cx| this.accept_terms_of_service(cx));
let workspace = self.workspace.clone();
cx.spawn(|this, mut cx| async move {
match task.await {
Ok(_) => this.update(&mut cx, |_, cx| {
cx.emit(DismissEvent);
}),
Err(err) => workspace.update(&mut cx, |this, cx| {
this.show_error(&err, cx);
}),
}
})
.detach_and_log_err(cx);
}
fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
}
impl EventEmitter<DismissEvent> for ZedPredictTos {}
impl Focusable for ZedPredictTos {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl ModalView for ZedPredictTos {}
impl Render for ZedPredictTos {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.id("zed predict tos")
.track_focus(&self.focus_handle(cx))
.on_action(cx.listener(Self::cancel))
.key_context("ZedPredictTos")
.elevation_3(cx)
.w_96()
.items_center()
.p_4()
.gap_2()
.on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| {
cx.emit(DismissEvent);
}))
.on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| {
this.focus_handle.focus(window);
}))
.child(
h_flex()
.w_full()
.justify_between()
.child(
v_flex()
.gap_0p5()
.child(
Label::new("Zed AI")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(Headline::new("Edit Prediction")),
)
.child(Icon::new(IconName::ZedPredict).size(IconSize::XLarge)),
)
.child(
Label::new(
"To use Zed AI's Edit Prediction feature, please read and accept our Terms of Service.",
)
.color(Color::Muted),
)
.child(
v_flex()
.mt_2()
.gap_0p5()
.w_full()
.child(if self.viewed {
Button::new("accept-tos", "I've Read and Accept the Terms of Service")
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::accept_terms))
} else {
Button::new("view-tos", "Read Terms of Service")
.style(ButtonStyle::Tinted(TintColor::Accent))
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.full_width()
.on_click(cx.listener(Self::view_terms))
})
.child(
Button::new("cancel", "Cancel")
.full_width()
.on_click(cx.listener(|_, _: &ClickEvent, _window, cx| {
cx.emit(DismissEvent);
})),
),
)
}
}

View file

@ -22,6 +22,7 @@ arrayvec.workspace = true
client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
db.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
@ -34,6 +35,7 @@ language_models.workspace = true
log.workspace = true
menu.workspace = true
rpc.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
similar.workspace = true

View file

@ -0,0 +1,54 @@
use anyhow::Result;
use collections::HashMap;
use std::path::{Path, PathBuf};
use workspace::WorkspaceDb;
use db::sqlez_macros::sql;
use db::{define_connection, query};
define_connection!(
pub static ref DB: ZetaDb<WorkspaceDb> = &[
sql! (
CREATE TABLE zeta_preferences(
worktree_path BLOB NOT NULL PRIMARY KEY,
accepted_data_collection INTEGER
) STRICT;
),
];
);
impl ZetaDb {
pub fn get_all_zeta_preferences(&self) -> Result<HashMap<PathBuf, bool>> {
Ok(self.get_all_zeta_preferences_query()?.into_iter().collect())
}
query! {
fn get_all_zeta_preferences_query() -> Result<Vec<(PathBuf, bool)>> {
SELECT worktree_path, accepted_data_collection FROM zeta_preferences
}
}
query! {
pub fn get_accepted_data_collection(worktree_path: &Path) -> Result<Option<bool>> {
SELECT accepted_data_collection FROM zeta_preferences
WHERE worktree_path = ?
}
}
query! {
pub async fn save_accepted_data_collection(worktree_path: PathBuf, accepted_data_collection: bool) -> Result<()> {
INSERT INTO zeta_preferences
(worktree_path, accepted_data_collection)
VALUES
(?1, ?2)
ON CONFLICT (worktree_path) DO UPDATE SET
accepted_data_collection = ?2
}
}
query! {
pub async fn clear_all_zeta_preferences() -> Result<()> {
DELETE FROM zeta_preferences
}
}
}

View file

@ -1,7 +1,10 @@
mod completion_diff_element;
mod persistence;
mod rate_completion_modal;
pub(crate) use completion_diff_element::*;
use db::kvp::KEY_VALUE_STORE;
use inline_completion::DataCollectionState;
pub use rate_completion_modal::*;
use anyhow::{anyhow, Context as _, Result};
@ -12,6 +15,7 @@ use feature_flags::FeatureFlagAppExt as _;
use futures::AsyncReadExt;
use gpui::{
actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
WeakEntity,
};
use http_client::{HttpClient, Method};
use language::{
@ -20,26 +24,33 @@ use language::{
};
use language_models::LlmApiToken;
use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
cmp,
cmp, env,
fmt::Write,
future::Future,
mem,
ops::Range,
path::Path,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, Instant},
};
use telemetry_events::InlineCompletionRating;
use util::ResultExt;
use uuid::Uuid;
use workspace::{
notifications::{simple_message_notification::MessageNotification, NotificationId},
Workspace,
};
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str =
"zed_predict_data_collection_never_ask_again";
// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants?
@ -187,6 +198,7 @@ pub struct Zeta {
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
shown_completions: VecDeque<InlineCompletion>,
rated_completions: HashSet<InlineCompletionId>,
data_collection_preferences: DataCollectionPreferences,
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
tos_accepted: bool, // Terms of service accepted
@ -216,13 +228,13 @@ impl Zeta {
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
Self {
client,
events: VecDeque::new(),
shown_completions: VecDeque::new(),
rated_completions: HashSet::default(),
registered_buffers: HashMap::default(),
data_collection_preferences: Self::load_data_collection_preferences(cx),
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
@ -240,11 +252,16 @@ impl Zeta {
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false),
_user_store_subscription: cx.subscribe(&user_store, |this, _, event, _| match event {
client::user::Event::TermsStatusUpdated { accepted } => {
this.tos_accepted = *accepted;
_user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
match event {
client::user::Event::PrivateUserInfoUpdated => {
this.tos_accepted = user_store
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false);
}
_ => {}
}
_ => {}
}),
}
}
@ -308,11 +325,8 @@ impl Zeta {
event: &language::BufferEvent,
cx: &mut Context<Self>,
) {
match event {
language::BufferEvent::Edited => {
self.report_changes_for_buffer(&buffer, cx);
}
_ => {}
if let language::BufferEvent::Edited = event {
self.report_changes_for_buffer(&buffer, cx);
}
}
@ -320,6 +334,7 @@ impl Zeta {
&mut self,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
can_collect_data: bool,
cx: &mut Context<Self>,
perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
@ -370,6 +385,7 @@ impl Zeta {
input_events: input_events.clone(),
input_excerpt: input_excerpt.clone(),
outline: Some(input_outline.clone()),
can_collect_data,
};
let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
@ -540,16 +556,25 @@ and then another
) -> Task<Result<Option<InlineCompletion>>> {
use std::future::ready;
self.request_completion_impl(buffer, position, cx, |_, _, _, _| ready(Ok(response)))
self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| {
ready(Ok(response))
})
}
pub fn request_completion(
&mut self,
buffer: &Entity<Buffer>,
position: language::Anchor,
can_collect_data: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
self.request_completion_impl(
buffer,
position,
can_collect_data,
cx,
Self::perform_predict_edits,
)
}
fn perform_predict_edits(
@ -862,6 +887,80 @@ and then another
new_snapshot
}
pub fn data_collection_choice_at(&self, path: &Path) -> DataCollectionChoice {
match self.data_collection_preferences.per_worktree.get(path) {
Some(true) => DataCollectionChoice::Enabled,
Some(false) => DataCollectionChoice::Disabled,
None => DataCollectionChoice::NotAnswered,
}
}
fn update_data_collection_choice_for_worktree(
&mut self,
absolute_path_of_project_worktree: PathBuf,
can_collect_data: bool,
cx: &mut Context<Self>,
) {
self.data_collection_preferences
.per_worktree
.insert(absolute_path_of_project_worktree.clone(), can_collect_data);
db::write_and_log(cx, move || {
persistence::DB
.save_accepted_data_collection(absolute_path_of_project_worktree, can_collect_data)
});
}
fn set_never_ask_again_for_data_collection(&mut self, cx: &mut Context<Self>) {
self.data_collection_preferences.never_ask_again = true;
// persist choice
db::write_and_log(cx, move || {
KEY_VALUE_STORE.write_kvp(
ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into(),
"true".to_string(),
)
});
}
fn load_data_collection_preferences(cx: &mut Context<Self>) -> DataCollectionPreferences {
if env::var("ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES").is_ok() {
db::write_and_log(cx, move || async move {
KEY_VALUE_STORE
.delete_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into())
.await
.log_err();
persistence::DB.clear_all_zeta_preferences().await
});
return DataCollectionPreferences::default();
}
let never_ask_again = KEY_VALUE_STORE
.read_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY)
.log_err()
.flatten()
.map(|value| value == "true")
.unwrap_or(false);
let preferences_per_project = persistence::DB
.get_all_zeta_preferences()
.log_err()
.unwrap_or_else(HashMap::default);
DataCollectionPreferences {
never_ask_again,
per_worktree: preferences_per_project,
}
}
}
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct DataCollectionPreferences {
/// Set when a user clicks on "Never Ask Again", can never be unset.
never_ask_again: bool,
per_worktree: HashMap<PathBuf, bool>,
}
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
@ -1276,22 +1375,120 @@ struct PendingCompletion {
_task: Task<()>,
}
#[derive(Clone, Copy)]
pub enum DataCollectionChoice {
NotAnswered,
Enabled,
Disabled,
}
impl DataCollectionChoice {
pub fn is_enabled(&self) -> bool {
match self {
Self::Enabled => true,
Self::NotAnswered | Self::Disabled => false,
}
}
pub fn is_answered(&self) -> bool {
match self {
Self::Enabled | Self::Disabled => true,
Self::NotAnswered => false,
}
}
pub fn toggle(&self) -> DataCollectionChoice {
match self {
Self::Enabled => Self::Disabled,
Self::Disabled => Self::Enabled,
Self::NotAnswered => Self::Enabled,
}
}
}
pub struct ZetaInlineCompletionProvider {
zeta: Entity<Zeta>,
pending_completions: ArrayVec<PendingCompletion, 2>,
next_pending_completion_id: usize,
current_completion: Option<CurrentInlineCompletion>,
data_collection: Option<ProviderDataCollection>,
}
pub struct ProviderDataCollection {
workspace: WeakEntity<Workspace>,
worktree_root_path: PathBuf,
choice: DataCollectionChoice,
}
impl ProviderDataCollection {
pub fn new(
zeta: Entity<Zeta>,
workspace: Option<Entity<Workspace>>,
buffer: Option<Entity<Buffer>>,
cx: &mut App,
) -> Option<ProviderDataCollection> {
let workspace = workspace?;
let worktree_root_path = buffer?.update(cx, |buffer, cx| {
let file = buffer.file()?;
if !file.is_local() || file.is_private() {
return None;
}
workspace.update(cx, |workspace, cx| {
Some(
workspace
.absolute_path_of_worktree(file.worktree_id(cx), cx)?
.to_path_buf(),
)
})
})?;
let choice = zeta.read(cx).data_collection_choice_at(&worktree_root_path);
Some(ProviderDataCollection {
workspace: workspace.downgrade(),
worktree_root_path,
choice,
})
}
fn set_choice(&mut self, choice: DataCollectionChoice, zeta: &Entity<Zeta>, cx: &mut App) {
self.choice = choice;
let worktree_root_path = self.worktree_root_path.clone();
zeta.update(cx, |zeta, cx| {
zeta.update_data_collection_choice_for_worktree(
worktree_root_path,
choice.is_enabled(),
cx,
)
});
}
fn toggle_choice(&mut self, zeta: &Entity<Zeta>, cx: &mut App) {
self.set_choice(self.choice.toggle(), zeta, cx);
}
}
impl ZetaInlineCompletionProvider {
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
pub fn new(zeta: Entity<Zeta>) -> Self {
pub fn new(zeta: Entity<Zeta>, data_collection: Option<ProviderDataCollection>) -> Self {
Self {
zeta,
pending_completions: ArrayVec::new(),
next_pending_completion_id: 0,
current_completion: None,
data_collection,
}
}
fn set_data_collection_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) {
if let Some(data_collection) = self.data_collection.as_mut() {
data_collection.set_choice(choice, &self.zeta, cx);
}
}
}
@ -1302,7 +1499,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
}
fn display_name() -> &'static str {
"Zed Predict"
"Zed's Edit Predictions"
}
fn show_completions_in_menu() -> bool {
@ -1317,6 +1514,24 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
true
}
fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
let Some(data_collection) = self.data_collection.as_ref() else {
return DataCollectionState::Unknown;
};
if data_collection.choice.is_enabled() {
DataCollectionState::Enabled
} else {
DataCollectionState::Disabled
}
}
fn toggle_data_collection(&mut self, cx: &mut App) {
if let Some(data_collection) = self.data_collection.as_mut() {
data_collection.toggle_choice(&self.zeta, cx);
}
}
fn is_enabled(
&self,
buffer: &Entity<Buffer>,
@ -1362,6 +1577,10 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
let pending_completion_id = self.next_pending_completion_id;
self.next_pending_completion_id += 1;
let can_collect_data = self
.data_collection
.as_ref()
.map_or(false, |data_collection| data_collection.choice.is_enabled());
let task = cx.spawn(|this, mut cx| async move {
if debounce {
@ -1370,7 +1589,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
let completion_request = this.update(&mut cx, |this, cx| {
this.zeta.update(cx, |zeta, cx| {
zeta.request_completion(&buffer, position, cx)
zeta.request_completion(&buffer, position, can_collect_data, cx)
})
});
@ -1447,8 +1666,80 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
// Right now we don't support cycling.
}
fn accept(&mut self, _cx: &mut Context<Self>) {
fn accept(&mut self, cx: &mut Context<Self>) {
self.pending_completions.clear();
let Some(data_collection) = self.data_collection.as_mut() else {
return;
};
if data_collection.choice.is_answered()
|| self
.zeta
.read(cx)
.data_collection_preferences
.never_ask_again
{
return;
}
struct ZetaDataCollectionNotification;
let notification_id = NotificationId::unique::<ZetaDataCollectionNotification>();
const DATA_COLLECTION_INFO_URL: &str = "https://zed.dev/terms-of-service"; // TODO: Replace for a link that's dedicated to Edit Predictions data collection
let this = cx.entity();
data_collection
.workspace
.update(cx, |workspace, cx| {
workspace.show_notification(notification_id, cx, |cx| {
let zeta = self.zeta.clone();
cx.new(move |_cx| {
let message =
"To allow Zed to suggest better edits, turn on data collection. You \
can turn off at any time via the status bar menu.";
MessageNotification::new(message)
.with_title("Per-Project Data Collection Program")
.show_close_button(false)
.with_click_message("Turn On")
.on_click({
let this = this.clone();
move |_window, cx| {
this.update(cx, |this, cx| {
this.set_data_collection_choice(
DataCollectionChoice::Enabled,
cx,
)
});
}
})
.with_secondary_click_message("Turn Off")
.on_secondary_click({
move |_window, cx| {
this.update(cx, |this, cx| {
this.set_data_collection_choice(
DataCollectionChoice::Disabled,
cx,
)
});
}
})
.with_tertiary_click_message("Never Ask Again")
.on_tertiary_click({
let zeta = zeta.clone();
move |_window, cx| {
zeta.update(cx, |zeta, cx| {
zeta.set_never_ask_again_for_data_collection(cx);
});
}
})
.more_info_message("Learn More")
.more_info_url(DATA_COLLECTION_INFO_URL)
})
});
})
.log_err();
}
fn discard(&mut self, _cx: &mut Context<Self>) {
@ -1688,8 +1979,9 @@ mod tests {
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
let completion_task =
zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));
let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(&buffer, cursor, false, cx)
});
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
server.respond(