Merge branch 'main' into ui-scrollbar-teardown

This commit is contained in:
MrSubidubi 2025-08-16 00:38:25 +02:00
commit 91cdf69924
83 changed files with 1532 additions and 3929 deletions

161
Cargo.lock generated
View file

@ -1262,26 +1262,6 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "async-stripe"
version = "0.40.0"
source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
dependencies = [
"chrono",
"futures-util",
"http-types",
"hyper 0.14.32",
"hyper-rustls 0.24.2",
"serde",
"serde_json",
"serde_path_to_error",
"serde_qs 0.10.1",
"smart-default 0.6.0",
"smol_str 0.1.24",
"thiserror 1.0.69",
"tokio",
]
[[package]]
name = "async-tar"
version = "0.5.0"
@ -2083,12 +2063,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.7"
@ -3281,7 +3255,6 @@ dependencies = [
"anyhow",
"assistant_context",
"assistant_slash_command",
"async-stripe",
"async-trait",
"async-tungstenite",
"audio",
@ -3308,7 +3281,6 @@ dependencies = [
"dap_adapters",
"dashmap 6.1.0",
"debugger_ui",
"derive_more 0.99.19",
"editor",
"envy",
"extension",
@ -3324,7 +3296,6 @@ dependencies = [
"http_client",
"hyper 0.14.32",
"indoc",
"jsonwebtoken",
"language",
"language_model",
"livekit_api",
@ -3370,7 +3341,6 @@ dependencies = [
"telemetry_events",
"text",
"theme",
"thiserror 2.0.12",
"time",
"tokio",
"toml 0.8.20",
@ -3872,7 +3842,7 @@ dependencies = [
"rustc-hash 1.1.0",
"rustybuzz 0.14.1",
"self_cell",
"smol_str 0.2.2",
"smol_str",
"swash",
"sys-locale",
"ttf-parser 0.21.1",
@ -6376,17 +6346,6 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
]
[[package]]
name = "getrandom"
version = "0.2.15"
@ -7990,27 +7949,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
[[package]]
name = "http-types"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
dependencies = [
"anyhow",
"async-channel 1.9.0",
"base64 0.13.1",
"futures-lite 1.13.0",
"http 0.2.12",
"infer",
"pin-project-lite",
"rand 0.7.3",
"serde",
"serde_json",
"serde_qs 0.8.5",
"serde_urlencoded",
"url",
]
[[package]]
name = "http_client"
version = "0.1.0"
@ -8489,12 +8427,6 @@ version = "2.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
[[package]]
name = "infer"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
[[package]]
name = "inherent"
version = "1.0.12"
@ -10271,7 +10203,7 @@ dependencies = [
"num-traits",
"range-map",
"scroll",
"smart-default 0.7.1",
"smart-default",
]
[[package]]
@ -13145,19 +13077,6 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.16",
"libc",
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
]
[[package]]
name = "rand"
version = "0.8.5"
@ -13179,16 +13098,6 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
@ -13209,15 +13118,6 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.16",
]
[[package]]
name = "rand_core"
version = "0.6.4"
@ -13236,15 +13136,6 @@ dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
name = "range-map"
version = "0.2.0"
@ -14899,28 +14790,6 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_qs"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
dependencies = [
"percent-encoding",
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "serde_qs"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa"
dependencies = [
"percent-encoding",
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "serde_repr"
version = "0.1.20"
@ -15297,17 +15166,6 @@ dependencies = [
"serde",
]
[[package]]
name = "smart-default"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "smart-default"
version = "0.7.1"
@ -15336,15 +15194,6 @@ dependencies = [
"futures-lite 2.6.0",
]
[[package]]
name = "smol_str"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9"
dependencies = [
"serde",
]
[[package]]
name = "smol_str"
version = "0.2.2"
@ -18194,12 +18043,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"

View file

@ -667,20 +667,6 @@ workspace-hack = "0.1.0"
yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
zstd = "0.11"
[workspace.dependencies.async-stripe]
git = "https://github.com/zed-industries/async-stripe"
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
default-features = false
features = [
"runtime-tokio-hyper-rustls",
"billing",
"checkout",
"events",
# The features below are only enabled to get the `events` feature to build.
"chrono",
"connect",
]
[workspace.dependencies.windows]
version = "0.61"
features = [

View file

@ -109,7 +109,7 @@ pub enum AgentThreadEntry {
}
impl AgentThreadEntry {
fn to_markdown(&self, cx: &App) -> String {
pub fn to_markdown(&self, cx: &App) -> String {
match self {
Self::UserMessage(message) => message.to_markdown(cx),
Self::AssistantMessage(message) => message.to_markdown(cx),
@ -117,6 +117,14 @@ impl AgentThreadEntry {
}
}
pub fn user_message(&self) -> Option<&UserMessage> {
if let AgentThreadEntry::UserMessage(message) = self {
Some(message)
} else {
None
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
if let AgentThreadEntry::ToolCall(call) = self {
itertools::Either::Left(call.diffs())

View file

@ -309,7 +309,7 @@ pub struct AgentSettingsContent {
///
/// Default: true
expand_terminal_card: Option<bool>,
/// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel.
/// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel.
///
/// Default: false
use_modifier_to_send: Option<bool>,

View file

@ -1,38 +1,31 @@
use std::ffi::OsStr;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use acp_thread::{MentionUri, selection_name};
use acp_thread::MentionUri;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use editor::display_map::CreaseId;
use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _};
use editor::{CompletionProvider, Editor, ExcerptId};
use futures::future::{Shared, try_join_all};
use futures::{FutureExt, TryFutureExt};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{App, Entity, ImageFormat, Img, Task, WeakEntity};
use http_client::HttpClientWithUrl;
use itertools::Itertools as _;
use gpui::{App, Entity, ImageFormat, Task, WeakEntity};
use language::{Buffer, CodeLabel, HighlightId};
use language_model::LanguageModelImage;
use lsp::CompletionContext;
use parking_lot::Mutex;
use project::{
Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId,
};
use prompt_store::PromptStore;
use rope::Point;
use text::{Anchor, OffsetRangeExt as _, ToPoint as _};
use text::{Anchor, ToPoint as _};
use ui::prelude::*;
use url::Url;
use workspace::Workspace;
use workspace::notifications::NotifyResultExt;
use agent::thread_store::{TextThreadStore, ThreadStore};
use crate::context_picker::fetch_context_picker::fetch_url_content;
use crate::acp::message_editor::MessageEditor;
use crate::context_picker::file_context_picker::{FileMatch, search_files};
use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules};
use crate::context_picker::symbol_context_picker::SymbolMatch;
@ -47,14 +40,14 @@ use crate::context_picker::{
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MentionImage {
pub abs_path: Option<Arc<Path>>,
pub abs_path: Option<PathBuf>,
pub data: SharedString,
pub format: ImageFormat,
}
#[derive(Default)]
pub struct MentionSet {
uri_by_crease_id: HashMap<CreaseId, MentionUri>,
pub(crate) uri_by_crease_id: HashMap<CreaseId, MentionUri>,
fetch_results: HashMap<Url, Shared<Task<Result<String, String>>>>,
images: HashMap<CreaseId, Shared<Task<Result<MentionImage, String>>>>,
}
@ -84,11 +77,6 @@ impl MentionSet {
.chain(self.images.drain().map(|(id, _)| id))
}
pub fn clear(&mut self) {
self.fetch_results.clear();
self.uri_by_crease_id.clear();
}
pub fn contents(
&self,
project: Entity<Project>,
@ -97,6 +85,8 @@ impl MentionSet {
window: &mut Window,
cx: &mut App,
) -> Task<Result<HashMap<CreaseId, Mention>>> {
let mut processed_image_creases = HashSet::default();
let mut contents = self
.uri_by_crease_id
.iter()
@ -106,46 +96,15 @@ impl MentionSet {
// TODO directories
let uri = uri.clone();
let abs_path = abs_path.to_path_buf();
let extension = abs_path.extension().and_then(OsStr::to_str).unwrap_or("");
if Img::extensions().contains(&extension) && !extension.contains("svg") {
let open_image_task = project.update(cx, |project, cx| {
let path = project
.find_project_path(&abs_path, cx)
.context("Failed to find project path")?;
anyhow::Ok(project.open_image(path, cx))
if let Some(task) = self.images.get(&crease_id).cloned() {
processed_image_creases.insert(crease_id);
return cx.spawn(async move |_| {
let image = task.await.map_err(|e| anyhow!("{e}"))?;
anyhow::Ok((crease_id, Mention::Image(image)))
});
cx.spawn(async move |cx| {
let image_item = open_image_task?.await?;
let (data, format) = image_item.update(cx, |image_item, cx| {
let format = image_item.image.format;
(
LanguageModelImage::from_image(
image_item.image.clone(),
cx,
),
format,
)
})?;
let data = cx.spawn(async move |_| {
if let Some(data) = data.await {
Ok(data.source)
} else {
anyhow::bail!("Failed to convert image")
}
});
anyhow::Ok((
crease_id,
Mention::Image(MentionImage {
abs_path: Some(abs_path.as_path().into()),
data: data.await?,
format,
}),
))
})
} else {
let buffer_task = project.update(cx, |project, cx| {
let path = project
.find_project_path(abs_path, cx)
@ -159,7 +118,6 @@ impl MentionSet {
anyhow::Ok((crease_id, Mention::Text { uri, content }))
})
}
}
MentionUri::Symbol {
path, line_range, ..
}
@ -252,15 +210,19 @@ impl MentionSet {
})
.collect::<Vec<_>>();
contents.extend(self.images.iter().map(|(crease_id, image)| {
// Handle images that didn't have a mention URI (because they were added by the paste handler).
contents.extend(self.images.iter().filter_map(|(crease_id, image)| {
if processed_image_creases.contains(crease_id) {
return None;
}
let crease_id = *crease_id;
let image = image.clone();
cx.spawn(async move |_| {
Some(cx.spawn(async move |_| {
Ok((
crease_id,
Mention::Image(image.await.map_err(|e| anyhow::anyhow!("{e}"))?),
))
})
}))
}));
cx.spawn(async move |_cx| {
@ -488,36 +450,31 @@ fn search(
}
pub struct ContextPickerCompletionProvider {
mention_set: Arc<Mutex<MentionSet>>,
workspace: WeakEntity<Workspace>,
thread_store: WeakEntity<ThreadStore>,
text_thread_store: WeakEntity<TextThreadStore>,
editor: WeakEntity<Editor>,
message_editor: WeakEntity<MessageEditor>,
}
impl ContextPickerCompletionProvider {
pub fn new(
mention_set: Arc<Mutex<MentionSet>>,
workspace: WeakEntity<Workspace>,
thread_store: WeakEntity<ThreadStore>,
text_thread_store: WeakEntity<TextThreadStore>,
editor: WeakEntity<Editor>,
message_editor: WeakEntity<MessageEditor>,
) -> Self {
Self {
mention_set,
workspace,
thread_store,
text_thread_store,
editor,
message_editor,
}
}
fn completion_for_entry(
entry: ContextPickerEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
message_editor: WeakEntity<MessageEditor>,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Option<Completion> {
@ -538,88 +495,39 @@ impl ContextPickerCompletionProvider {
ContextPickerEntry::Action(action) => {
let (new_text, on_action) = match action {
ContextPickerAction::AddSelections => {
let selections = selection_ranges(workspace, cx);
const PLACEHOLDER: &str = "selection ";
let selections = selection_ranges(workspace, cx)
.into_iter()
.enumerate()
.map(|(ix, (buffer, range))| {
(
buffer,
range,
(PLACEHOLDER.len() * ix)..(PLACEHOLDER.len() * (ix + 1) - 1),
)
})
.collect::<Vec<_>>();
let new_text = std::iter::repeat(PLACEHOLDER)
.take(selections.len())
.chain(std::iter::once(""))
.join(" ");
let new_text: String = PLACEHOLDER.repeat(selections.len());
let callback = Arc::new({
let mention_set = mention_set.clone();
let selections = selections.clone();
let source_range = source_range.clone();
move |_, window: &mut Window, cx: &mut App| {
let editor = editor.clone();
let mention_set = mention_set.clone();
let selections = selections.clone();
let message_editor = message_editor.clone();
let source_range = source_range.clone();
window.defer(cx, move |window, cx| {
let mut current_offset = 0;
for (buffer, selection_range) in selections {
let snapshot =
editor.read(cx).buffer().read(cx).snapshot(cx);
let Some(start) = snapshot
.anchor_in_excerpt(excerpt_id, source_range.start)
else {
return;
};
let offset = start.to_offset(&snapshot) + current_offset;
let text_len = PLACEHOLDER.len() - 1;
let range = snapshot.anchor_after(offset)
..snapshot.anchor_after(offset + text_len);
let path = buffer
.read(cx)
.file()
.map_or(PathBuf::from("untitled"), |file| {
file.path().to_path_buf()
});
let point_range = snapshot
.as_singleton()
.map(|(_, _, snapshot)| {
selection_range.to_point(&snapshot)
})
.unwrap_or_default();
let line_range = point_range.start.row..point_range.end.row;
let uri = MentionUri::Selection {
path: path.clone(),
line_range: line_range.clone(),
};
let crease = crate::context_picker::crease_for_mention(
selection_name(&path, &line_range).into(),
uri.icon_path(cx),
range,
editor.downgrade(),
);
let [crease_id]: [_; 1] =
editor.update(cx, |editor, cx| {
let crease_ids =
editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(
vec![crease],
false,
message_editor
.update(cx, |message_editor, cx| {
message_editor.confirm_mention_for_selection(
source_range,
selections,
window,
cx,
);
crease_ids.try_into().unwrap()
)
})
.ok();
});
mention_set.lock().insert_uri(
crease_id,
MentionUri::Selection { path, line_range },
);
current_offset += text_len + 1;
}
});
false
}
});
@ -647,11 +555,9 @@ impl ContextPickerCompletionProvider {
fn completion_for_thread(
thread_entry: ThreadContextEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
recent: bool,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
editor: WeakEntity<MessageEditor>,
cx: &mut App,
) -> Completion {
let uri = match &thread_entry {
@ -683,13 +589,10 @@ impl ContextPickerCompletionProvider {
source: project::CompletionSource::Custom,
icon_path: Some(icon_for_completion.clone()),
confirm: Some(confirm_completion_callback(
uri.icon_path(cx),
thread_entry.title().clone(),
excerpt_id,
source_range.start,
new_text_len - 1,
editor.clone(),
mention_set,
editor,
uri,
)),
}
@ -697,10 +600,8 @@ impl ContextPickerCompletionProvider {
fn completion_for_rules(
rule: RulesContextEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
editor: WeakEntity<MessageEditor>,
cx: &mut App,
) -> Completion {
let uri = MentionUri::Rule {
@ -719,13 +620,10 @@ impl ContextPickerCompletionProvider {
source: project::CompletionSource::Custom,
icon_path: Some(icon_path.clone()),
confirm: Some(confirm_completion_callback(
icon_path,
rule.title.clone(),
excerpt_id,
source_range.start,
new_text_len - 1,
editor.clone(),
mention_set,
editor,
uri,
)),
}
@ -736,10 +634,8 @@ impl ContextPickerCompletionProvider {
path_prefix: &str,
is_recent: bool,
is_directory: bool,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
message_editor: WeakEntity<MessageEditor>,
project: Entity<Project>,
cx: &mut App,
) -> Option<Completion> {
@ -777,13 +673,10 @@ impl ContextPickerCompletionProvider {
icon_path: Some(completion_icon_path),
insert_text_mode: None,
confirm: Some(confirm_completion_callback(
crease_icon_path,
file_name,
excerpt_id,
source_range.start,
new_text_len - 1,
editor,
mention_set.clone(),
message_editor,
file_uri,
)),
})
@ -791,10 +684,8 @@ impl ContextPickerCompletionProvider {
fn completion_for_symbol(
symbol: Symbol,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
message_editor: WeakEntity<MessageEditor>,
workspace: Entity<Workspace>,
cx: &mut App,
) -> Option<Completion> {
@ -820,13 +711,10 @@ impl ContextPickerCompletionProvider {
icon_path: Some(icon_path.clone()),
insert_text_mode: None,
confirm: Some(confirm_completion_callback(
icon_path,
symbol.name.clone().into(),
excerpt_id,
source_range.start,
new_text_len - 1,
editor.clone(),
mention_set.clone(),
message_editor,
uri,
)),
})
@ -835,116 +723,32 @@ impl ContextPickerCompletionProvider {
fn completion_for_fetch(
source_range: Range<Anchor>,
url_to_fetch: SharedString,
excerpt_id: ExcerptId,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
http_client: Arc<HttpClientWithUrl>,
message_editor: WeakEntity<MessageEditor>,
cx: &mut App,
) -> Option<Completion> {
let new_text = format!("@fetch {} ", url_to_fetch.clone());
let new_text_len = new_text.len();
let mention_uri = MentionUri::Fetch {
url: url::Url::parse(url_to_fetch.as_ref())
let url_to_fetch = url::Url::parse(url_to_fetch.as_ref())
.or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}")))
.ok()?,
.ok()?;
let mention_uri = MentionUri::Fetch {
url: url_to_fetch.clone(),
};
let icon_path = mention_uri.icon_path(cx);
Some(Completion {
replace_range: source_range.clone(),
new_text,
new_text: new_text.clone(),
label: CodeLabel::plain(url_to_fetch.to_string(), None),
documentation: None,
source: project::CompletionSource::Custom,
icon_path: Some(icon_path.clone()),
insert_text_mode: None,
confirm: Some({
let start = source_range.start;
let content_len = new_text_len - 1;
let editor = editor.clone();
let url_to_fetch = url_to_fetch.clone();
let source_range = source_range.clone();
let icon_path = icon_path.clone();
let mention_uri = mention_uri.clone();
Arc::new(move |_, window, cx| {
let Some(url) = url::Url::parse(url_to_fetch.as_ref())
.or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}")))
.notify_app_err(cx)
else {
return false;
};
let editor = editor.clone();
let mention_set = mention_set.clone();
let http_client = http_client.clone();
let source_range = source_range.clone();
let icon_path = icon_path.clone();
let mention_uri = mention_uri.clone();
window.defer(cx, move |window, cx| {
let url = url.clone();
let Some(crease_id) = crate::context_picker::insert_crease_for_mention(
excerpt_id,
start,
content_len,
url.to_string().into(),
icon_path,
editor.clone(),
window,
cx,
) else {
return;
};
let editor = editor.clone();
let mention_set = mention_set.clone();
let http_client = http_client.clone();
let source_range = source_range.clone();
let url_string = url.to_string();
let fetch = cx
.background_executor()
.spawn(async move {
fetch_url_content(http_client, url_string)
.map_err(|e| e.to_string())
.await
})
.shared();
mention_set.lock().add_fetch_result(url, fetch.clone());
window
.spawn(cx, async move |cx| {
if fetch.await.notify_async_err(cx).is_some() {
mention_set
.lock()
.insert_uri(crease_id, mention_uri.clone());
} else {
// Remove crease if we failed to fetch
editor
.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let Some(anchor) = snapshot
.anchor_in_excerpt(excerpt_id, source_range.start)
else {
return;
};
editor.display_map.update(cx, |display_map, cx| {
display_map.unfold_intersecting(
vec![anchor..anchor],
true,
cx,
);
});
editor.remove_creases([crease_id], cx);
})
.ok();
}
Some(())
})
.detach();
});
false
})
}),
confirm: Some(confirm_completion_callback(
url_to_fetch.to_string().into(),
source_range.start,
new_text.len() - 1,
message_editor,
mention_uri,
)),
})
}
}
@ -968,7 +772,7 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx:
impl CompletionProvider for ContextPickerCompletionProvider {
fn completions(
&self,
excerpt_id: ExcerptId,
_excerpt_id: ExcerptId,
buffer: &Entity<Buffer>,
buffer_position: Anchor,
_trigger: CompletionContext,
@ -992,39 +796,24 @@ impl CompletionProvider for ContextPickerCompletionProvider {
};
let project = workspace.read(cx).project().clone();
let http_client = workspace.read(cx).client().http_client();
let snapshot = buffer.read(cx).snapshot();
let source_range = snapshot.anchor_before(state.source_range.start)
..snapshot.anchor_after(state.source_range.end);
let thread_store = self.thread_store.clone();
let text_thread_store = self.text_thread_store.clone();
let editor = self.editor.clone();
let editor = self.message_editor.clone();
let Ok((exclude_paths, exclude_threads)) =
self.message_editor.update(cx, |message_editor, _cx| {
message_editor.mentioned_path_and_threads()
})
else {
return Task::ready(Ok(Vec::new()));
};
let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string());
let (exclude_paths, exclude_threads) = {
let mention_set = self.mention_set.lock();
let mut excluded_paths = HashSet::default();
let mut excluded_threads = HashSet::default();
for uri in mention_set.uri_by_crease_id.values() {
match uri {
MentionUri::File { abs_path, .. } => {
excluded_paths.insert(abs_path.clone());
}
MentionUri::Thread { id, .. } => {
excluded_threads.insert(id.clone());
}
_ => {}
}
}
(excluded_paths, excluded_threads)
};
let recent_entries = recent_context_picker_entries(
Some(thread_store.clone()),
Some(text_thread_store.clone()),
@ -1051,13 +840,8 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
);
let mention_set = self.mention_set.clone();
cx.spawn(async move |_, cx| {
let matches = search_task.await;
let Some(editor) = editor.upgrade() else {
return Ok(Vec::new());
};
let completions = cx.update(|cx| {
matches
@ -1074,10 +858,8 @@ impl CompletionProvider for ContextPickerCompletionProvider {
&mat.path_prefix,
is_recent,
mat.is_dir,
excerpt_id,
source_range.clone(),
editor.clone(),
mention_set.clone(),
project.clone(),
cx,
)
@ -1085,10 +867,8 @@ impl CompletionProvider for ContextPickerCompletionProvider {
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
source_range.clone(),
editor.clone(),
mention_set.clone(),
workspace.clone(),
cx,
),
@ -1097,39 +877,30 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread, is_recent, ..
}) => Some(Self::completion_for_thread(
thread,
excerpt_id,
source_range.clone(),
is_recent,
editor.clone(),
mention_set.clone(),
cx,
)),
Match::Rules(user_rules) => Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
mention_set.clone(),
cx,
)),
Match::Fetch(url) => Self::completion_for_fetch(
source_range.clone(),
url,
excerpt_id,
editor.clone(),
mention_set.clone(),
http_client.clone(),
cx,
),
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
entry,
excerpt_id,
source_range.clone(),
editor.clone(),
mention_set.clone(),
&workspace,
cx,
),
@ -1182,36 +953,30 @@ impl CompletionProvider for ContextPickerCompletionProvider {
}
fn confirm_completion_callback(
crease_icon_path: SharedString,
crease_text: SharedString,
excerpt_id: ExcerptId,
start: Anchor,
content_len: usize,
editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>,
message_editor: WeakEntity<MessageEditor>,
mention_uri: MentionUri,
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
Arc::new(move |_, window, cx| {
let message_editor = message_editor.clone();
let crease_text = crease_text.clone();
let crease_icon_path = crease_icon_path.clone();
let editor = editor.clone();
let mention_set = mention_set.clone();
let mention_uri = mention_uri.clone();
window.defer(cx, move |window, cx| {
if let Some(crease_id) = crate::context_picker::insert_crease_for_mention(
excerpt_id,
message_editor
.clone()
.update(cx, |message_editor, cx| {
message_editor.confirm_completion(
crease_text,
start,
content_len,
crease_text.clone(),
crease_icon_path,
editor.clone(),
mention_uri,
window,
cx,
) {
mention_set
.lock()
.insert_uri(crease_id, mention_uri.clone());
}
)
})
.ok();
});
false
})
@ -1279,13 +1044,13 @@ impl MentionCompletion {
#[cfg(test)]
mod tests {
use super::*;
use editor::AnchorRangeExt;
use editor::{AnchorRangeExt, EditorMode};
use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext};
use project::{Project, ProjectPath};
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::{ops::Deref, path::Path, rc::Rc};
use std::{ops::Deref, path::Path};
use util::path;
use workspace::{AppState, Item};
@ -1359,9 +1124,9 @@ mod tests {
assert_eq!(MentionCompletion::try_parse("test@", 0), None);
}
struct AtMentionEditor(Entity<Editor>);
struct MessageEditorItem(Entity<MessageEditor>);
impl Item for AtMentionEditor {
impl Item for MessageEditorItem {
type Event = ();
fn include_in_nav_history() -> bool {
@ -1373,15 +1138,15 @@ mod tests {
}
}
impl EventEmitter<()> for AtMentionEditor {}
impl EventEmitter<()> for MessageEditorItem {}
impl Focusable for AtMentionEditor {
impl Focusable for MessageEditorItem {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.0.read(cx).focus_handle(cx).clone()
}
}
impl Render for AtMentionEditor {
impl Render for MessageEditorItem {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.0.clone().into_any_element()
}
@ -1467,19 +1232,28 @@ mod tests {
opened_editors.push(buffer);
}
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
let editor = cx.new(|cx| {
Editor::new(
editor::EditorMode::full(),
multi_buffer::MultiBuffer::build_simple("", cx),
None,
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| {
let workspace_handle = cx.weak_entity();
let message_editor = cx.new(|cx| {
MessageEditor::new(
workspace_handle,
project.clone(),
thread_store.clone(),
text_thread_store.clone(),
EditorMode::AutoHeight {
max_lines: None,
min_lines: 1,
},
window,
cx,
)
});
workspace.active_pane().update(cx, |pane, cx| {
pane.add_item(
Box::new(cx.new(|_| AtMentionEditor(editor.clone()))),
Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))),
true,
true,
None,
@ -1487,24 +1261,9 @@ mod tests {
cx,
);
});
editor
});
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
let editor_entity = editor.downgrade();
editor.update_in(&mut cx, |editor, window, cx| {
window.focus(&editor.focus_handle(cx));
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
mention_set.clone(),
workspace.downgrade(),
thread_store.downgrade(),
text_thread_store.downgrade(),
editor_entity,
))));
message_editor.read(cx).focus_handle(cx).focus(window);
let editor = message_editor.read(cx).editor().clone();
(message_editor, editor)
});
cx.simulate_input("Lorem ");
@ -1573,9 +1332,9 @@ mod tests {
);
});
let contents = cx
.update(|window, cx| {
mention_set.lock().contents(
let contents = message_editor
.update_in(&mut cx, |message_editor, window, cx| {
message_editor.mention_set().contents(
project.clone(),
thread_store.clone(),
text_thread_store.clone(),
@ -1641,9 +1400,9 @@ mod tests {
cx.run_until_parked();
let contents = cx
.update(|window, cx| {
mention_set.lock().contents(
let contents = message_editor
.update_in(&mut cx, |message_editor, window, cx| {
message_editor.mention_set().contents(
project.clone(),
thread_store.clone(),
text_thread_store.clone(),
@ -1765,9 +1524,9 @@ mod tests {
editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx);
});
let contents = cx
.update(|window, cx| {
mention_set.lock().contents(
let contents = message_editor
.update_in(&mut cx, |message_editor, window, cx| {
message_editor.mention_set().contents(
project.clone(),
thread_store,
text_thread_store,

View file

@ -1,45 +1,141 @@
use std::{collections::HashMap, ops::Range};
use std::ops::Range;
use acp_thread::AcpThread;
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
use acp_thread::{AcpThread, AgentThreadEntry};
use agent::{TextThreadStore, ThreadStore};
use collections::HashMap;
use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
WeakEntity, Window,
};
use language::language_settings::SoftWrap;
use project::Project;
use settings::Settings as _;
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::TextSize;
use ui::{Context, TextSize};
use workspace::Workspace;
#[derive(Default)]
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
pub struct EntryViewState {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
entries: Vec<Entry>,
}
impl EntryViewState {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
) -> Self {
Self {
workspace,
project,
thread_store,
text_thread_store,
entries: Vec::new(),
}
}
pub fn entry(&self, index: usize) -> Option<&Entry> {
self.entries.get(index)
}
pub fn sync_entry(
&mut self,
workspace: WeakEntity<Workspace>,
thread: Entity<AcpThread>,
index: usize,
thread: &Entity<AcpThread>,
window: &mut Window,
cx: &mut App,
cx: &mut Context<Self>,
) {
debug_assert!(index <= self.entries.len());
let entry = if let Some(entry) = self.entries.get_mut(index) {
entry
} else {
self.entries.push(Entry::default());
self.entries.last_mut().unwrap()
let Some(thread_entry) = thread.read(cx).entries().get(index) else {
return;
};
entry.sync_diff_multibuffers(&thread, index, window, cx);
entry.sync_terminals(&workspace, &thread, index, window, cx);
match thread_entry {
AgentThreadEntry::UserMessage(message) => {
let has_id = message.id.is_some();
let chunks = message.chunks.clone();
let message_editor = cx.new(|cx| {
let mut editor = MessageEditor::new(
self.workspace.clone(),
self.project.clone(),
self.thread_store.clone(),
self.text_thread_store.clone(),
editor::EditorMode::AutoHeight {
min_lines: 1,
max_lines: None,
},
window,
cx,
);
if !has_id {
editor.set_read_only(true, cx);
}
editor.set_message(chunks, window, cx);
editor
});
cx.subscribe(&message_editor, move |_, editor, event, cx| {
cx.emit(EntryViewEvent {
entry_index: index,
view_event: ViewEvent::MessageEditorEvent(editor, *event),
})
})
.detach();
self.set_entry(index, Entry::UserMessage(message_editor));
}
AgentThreadEntry::ToolCall(tool_call) => {
let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
views
} else {
self.set_entry(index, Entry::empty());
let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
unreachable!()
};
views
};
for terminal in terminals {
views.entry(terminal.entity_id()).or_insert_with(|| {
create_terminal(
self.workspace.clone(),
self.project.clone(),
terminal.clone(),
window,
cx,
)
.into_any()
});
}
for diff in diffs {
views
.entry(diff.entity_id())
.or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
}
}
AgentThreadEntry::AssistantMessage(_) => {
if index == self.entries.len() {
self.entries.push(Entry::empty())
}
}
};
}
fn set_entry(&mut self, index: usize, entry: Entry) {
if index == self.entries.len() {
self.entries.push(entry);
} else {
self.entries[index] = entry;
}
}
pub fn remove(&mut self, range: Range<usize>) {
@ -48,26 +144,51 @@ impl EntryViewState {
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
for view in entry.views.values() {
match entry {
Entry::UserMessage { .. } => {}
Entry::Content(response_views) => {
for view in response_views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor
.set_text_style_refinement(diff_editor_text_style_refinement(cx));
diff_editor.set_text_style_refinement(
diff_editor_text_style_refinement(cx),
);
cx.notify();
})
}
}
}
}
}
}
}
pub struct Entry {
views: HashMap<EntityId, AnyEntity>,
impl EventEmitter<EntryViewEvent> for EntryViewState {}
pub struct EntryViewEvent {
pub entry_index: usize,
pub view_event: ViewEvent,
}
pub enum ViewEvent {
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
pub enum Entry {
UserMessage(Entity<MessageEditor>),
Content(HashMap<EntityId, AnyEntity>),
}
impl Entry {
pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
self.views
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
match self {
Self::UserMessage(editor) => Some(editor),
Entry::Content(_) => None,
}
}
pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
self.content_map()?
.get(&diff.entity_id())
.cloned()
.map(|entity| entity.downcast::<Editor>().unwrap())
@ -77,42 +198,66 @@ impl Entry {
&self,
terminal: &Entity<acp_thread::Terminal>,
) -> Option<Entity<TerminalView>> {
self.views
self.content_map()?
.get(&terminal.entity_id())
.cloned()
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
fn sync_diff_multibuffers(
&mut self,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let multibuffers = entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone());
let multibuffers = multibuffers.collect::<Vec<_>>();
for multibuffer in multibuffers {
if self.views.contains_key(&multibuffer.entity_id()) {
return;
fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
match self {
Self::Content(map) => Some(map),
_ => None,
}
}
let editor = cx.new(|cx| {
fn empty() -> Self {
Self::Content(HashMap::default())
}
#[cfg(test)]
pub fn has_content(&self) -> bool {
match self {
Self::Content(map) => !map.is_empty(),
Self::UserMessage(_) => false,
}
}
}
fn create_terminal(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
terminal: Entity<acp_thread::Terminal>,
window: &mut Window,
cx: &mut App,
) -> Entity<TerminalView> {
cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
workspace.clone(),
None,
project.downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
})
}
fn create_editor_diff(
diff: Entity<acp_thread::Diff>,
window: &mut Window,
cx: &mut App,
) -> Entity<Editor> {
cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
diff.read(cx).multibuffer().clone(),
None,
window,
cx,
@ -132,61 +277,7 @@ impl Entry {
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
self.views.insert(entity_id, editor.into_any());
}
}
fn sync_terminals(
&mut self,
workspace: &WeakEntity<Workspace>,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let terminals = entry
.terminals()
.map(|terminal| terminal.clone())
.collect::<Vec<_>>();
for terminal in terminals {
if self.views.contains_key(&terminal.entity_id()) {
return;
}
let Some(strong_workspace) = workspace.upgrade() else {
return;
};
let terminal_view = cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
workspace.clone(),
None,
strong_workspace.read(cx).project().downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
});
let entity_id = terminal.entity_id();
self.views.insert(entity_id, terminal_view.into_any());
}
}
#[cfg(test)]
pub fn len(&self) -> usize {
self.views.len()
}
})
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
@ -201,26 +292,20 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
}
}
impl Default for Entry {
fn default() -> Self {
Self {
// Avoid allocating in the heap by default
views: HashMap::with_capacity(0),
}
}
}
#[cfg(test)]
mod tests {
use std::{path::Path, rc::Rc};
use acp_thread::{AgentConnection, StubAgentConnection};
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use editor::{EditorSettings, RowInfo};
use fs::FakeFs;
use gpui::{SemanticVersion, TestAppContext};
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
use crate::acp::entry_view_state::EntryViewState;
use multi_buffer::MultiBufferRow;
use pretty_assertions::assert_matches;
use project::Project;
@ -230,8 +315,6 @@ mod tests {
use util::path;
use workspace::Workspace;
use crate::acp::entry_view_state::EntryViewState;
#[gpui::test]
async fn test_diff_sync(cx: &mut TestAppContext) {
init_test(cx);
@ -269,7 +352,7 @@ mod tests {
.update(|_, cx| {
connection
.clone()
.new_thread(project, Path::new(path!("/project")), cx)
.new_thread(project.clone(), Path::new(path!("/project")), cx)
})
.await
.unwrap();
@ -279,12 +362,23 @@ mod tests {
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
});
let mut view_state = EntryViewState::default();
cx.update(|window, cx| {
view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
let view_state = cx.new(|_cx| {
EntryViewState::new(
workspace.downgrade(),
project.clone(),
thread_store,
text_thread_store,
)
});
let multibuffer = thread.read_with(cx, |thread, cx| {
view_state.update_in(cx, |view_state, window, cx| {
view_state.sync_entry(0, &thread, window, cx)
});
let diff = thread.read_with(cx, |thread, _cx| {
thread
.entries()
.get(0)
@ -292,15 +386,14 @@ mod tests {
.diffs()
.next()
.unwrap()
.read(cx)
.multibuffer()
.clone()
});
cx.run_until_parked();
let entry = view_state.entry(0).unwrap();
let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
let diff_editor = view_state.read_with(cx, |view_state, _cx| {
view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
});
assert_eq!(
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
"hi world\nhello world"

View file

@ -1,61 +1,63 @@
use crate::acp::completion_provider::ContextPickerCompletionProvider;
use crate::acp::completion_provider::MentionImage;
use crate::acp::completion_provider::MentionSet;
use acp_thread::MentionUri;
use agent::TextThreadStore;
use agent::ThreadStore;
use crate::{
acp::completion_provider::{ContextPickerCompletionProvider, MentionImage, MentionSet},
context_picker::fetch_context_picker::fetch_url_content,
};
use acp_thread::{MentionUri, selection_name};
use agent::{TextThreadStore, ThreadId, ThreadStore};
use agent_client_protocol as acp;
use anyhow::Result;
use collections::HashSet;
use editor::ExcerptId;
use editor::actions::Paste;
use editor::display_map::CreaseId;
use editor::{
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode,
EditorStyle, MultiBuffer,
Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement,
EditorMode, EditorStyle, ExcerptId, FoldPlaceholder, MultiBuffer, ToOffset,
actions::Paste,
display_map::{Crease, CreaseId, FoldId},
};
use futures::FutureExt as _;
use gpui::ClipboardEntry;
use gpui::Image;
use gpui::ImageFormat;
use futures::{FutureExt as _, TryFutureExt as _};
use gpui::{
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity,
AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, Image,
ImageFormat, Img, Task, TextStyle, WeakEntity,
};
use language::Buffer;
use language::Language;
use language::{Buffer, Language};
use language_model::LanguageModelImage;
use parking_lot::Mutex;
use project::{CompletionIntent, Project};
use settings::Settings;
use std::fmt::Write;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use std::{
ffi::OsStr,
fmt::Write,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
};
use text::OffsetRangeExt;
use theme::ThemeSettings;
use ui::IconName;
use ui::SharedString;
use ui::{
ActiveTheme, App, InteractiveElement, IntoElement, ParentElement, Render, Styled, TextSize,
Window, div,
ActiveTheme, AnyElement, App, ButtonCommon, ButtonLike, ButtonStyle, Color, Icon, IconName,
IconSize, InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ParentElement,
Render, SelectableButton, SharedString, Styled, TextSize, TintColor, Toggleable, Window, div,
h_flex,
};
use util::ResultExt;
use workspace::Workspace;
use workspace::notifications::NotifyResultExt as _;
use workspace::{Workspace, notifications::NotifyResultExt as _};
use zed_actions::agent::Chat;
use super::completion_provider::Mention;
pub struct MessageEditor {
mention_set: MentionSet,
editor: Entity<Editor>,
project: Entity<Project>,
workspace: WeakEntity<Workspace>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
mention_set: Arc<Mutex<MentionSet>>,
}
#[derive(Clone, Copy)]
pub enum MessageEditorEvent {
Send,
Cancel,
Focus,
}
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
@ -77,8 +79,13 @@ impl MessageEditor {
},
None,
);
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
let completion_provider = ContextPickerCompletionProvider::new(
workspace.clone(),
thread_store.downgrade(),
text_thread_store.downgrade(),
cx.weak_entity(),
);
let mention_set = MentionSet::default();
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
@ -88,13 +95,7 @@ impl MessageEditor {
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_use_modal_editing(true);
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
mention_set.clone(),
workspace,
thread_store.downgrade(),
text_thread_store.downgrade(),
cx.weak_entity(),
))));
editor.set_completion_provider(Some(Rc::new(completion_provider)));
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
@ -103,25 +104,254 @@ impl MessageEditor {
editor
});
cx.on_focus(&editor.focus_handle(cx), window, |_, _, cx| {
cx.emit(MessageEditorEvent::Focus)
})
.detach();
Self {
editor,
project,
mention_set,
thread_store,
text_thread_store,
workspace,
}
}
#[cfg(test)]
pub(crate) fn editor(&self) -> &Entity<Editor> {
&self.editor
}
#[cfg(test)]
pub(crate) fn mention_set(&mut self) -> &mut MentionSet {
&mut self.mention_set
}
pub fn is_empty(&self, cx: &App) -> bool {
self.editor.read(cx).is_empty(cx)
}
pub fn mentioned_path_and_threads(&self) -> (HashSet<PathBuf>, HashSet<ThreadId>) {
let mut excluded_paths = HashSet::default();
let mut excluded_threads = HashSet::default();
for uri in self.mention_set.uri_by_crease_id.values() {
match uri {
MentionUri::File { abs_path, .. } => {
excluded_paths.insert(abs_path.clone());
}
MentionUri::Thread { id, .. } => {
excluded_threads.insert(id.clone());
}
_ => {}
}
}
(excluded_paths, excluded_threads)
}
pub fn confirm_completion(
&mut self,
crease_text: SharedString,
start: text::Anchor,
content_len: usize,
mention_uri: MentionUri,
window: &mut Window,
cx: &mut Context<Self>,
) {
let snapshot = self
.editor
.update(cx, |editor, cx| editor.snapshot(window, cx));
let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else {
return;
};
let Some(anchor) = snapshot
.buffer_snapshot
.anchor_in_excerpt(*excerpt_id, start)
else {
return;
};
let Some(crease_id) = crate::context_picker::insert_crease_for_mention(
*excerpt_id,
start,
content_len,
crease_text.clone(),
mention_uri.icon_path(cx),
self.editor.clone(),
window,
cx,
) else {
return;
};
self.mention_set.insert_uri(crease_id, mention_uri.clone());
match mention_uri {
MentionUri::Fetch { url } => {
self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx);
}
MentionUri::File {
abs_path,
is_directory,
} => {
self.confirm_mention_for_file(
crease_id,
anchor,
abs_path,
is_directory,
window,
cx,
);
}
MentionUri::Symbol { .. }
| MentionUri::Thread { .. }
| MentionUri::TextThread { .. }
| MentionUri::Rule { .. }
| MentionUri::Selection { .. } => {}
}
}
fn confirm_mention_for_file(
&mut self,
crease_id: CreaseId,
anchor: Anchor,
abs_path: PathBuf,
_is_directory: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
let extension = abs_path
.extension()
.and_then(OsStr::to_str)
.unwrap_or_default();
let project = self.project.clone();
let Some(project_path) = project
.read(cx)
.project_path_for_absolute_path(&abs_path, cx)
else {
return;
};
if Img::extensions().contains(&extension) && !extension.contains("svg") {
let image = cx.spawn(async move |_, cx| {
let image = project
.update(cx, |project, cx| project.open_image(project_path, cx))?
.await?;
image.read_with(cx, |image, _cx| image.image.clone())
});
self.confirm_mention_for_image(crease_id, anchor, Some(abs_path), image, window, cx);
}
}
fn confirm_mention_for_fetch(
&mut self,
crease_id: CreaseId,
anchor: Anchor,
url: url::Url,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(http_client) = self
.workspace
.update(cx, |workspace, _cx| workspace.client().http_client())
.ok()
else {
return;
};
let url_string = url.to_string();
let fetch = cx
.background_executor()
.spawn(async move {
fetch_url_content(http_client, url_string)
.map_err(|e| e.to_string())
.await
})
.shared();
self.mention_set
.add_fetch_result(url.clone(), fetch.clone());
cx.spawn_in(window, async move |this, cx| {
let fetch = fetch.await.notify_async_err(cx);
this.update(cx, |this, cx| {
let mention_uri = MentionUri::Fetch { url };
if fetch.is_some() {
this.mention_set.insert_uri(crease_id, mention_uri.clone());
} else {
// Remove crease if we failed to fetch
this.editor.update(cx, |editor, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map.unfold_intersecting(vec![anchor..anchor], true, cx);
});
editor.remove_creases([crease_id], cx);
});
}
})
.ok();
})
.detach();
}
pub fn confirm_mention_for_selection(
&mut self,
source_range: Range<text::Anchor>,
selections: Vec<(Entity<Buffer>, Range<text::Anchor>, Range<usize>)>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let snapshot = self.editor.read(cx).buffer().read(cx).snapshot(cx);
let Some((&excerpt_id, _, _)) = snapshot.as_singleton() else {
return;
};
let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, source_range.start) else {
return;
};
let offset = start.to_offset(&snapshot);
for (buffer, selection_range, range_to_fold) in selections {
let range = snapshot.anchor_after(offset + range_to_fold.start)
..snapshot.anchor_after(offset + range_to_fold.end);
let path = buffer
.read(cx)
.file()
.map_or(PathBuf::from("untitled"), |file| file.path().to_path_buf());
let snapshot = buffer.read(cx).snapshot();
let point_range = selection_range.to_point(&snapshot);
let line_range = point_range.start.row..point_range.end.row;
let uri = MentionUri::Selection {
path: path.clone(),
line_range: line_range.clone(),
};
let crease = crate::context_picker::crease_for_mention(
selection_name(&path, &line_range).into(),
uri.icon_path(cx),
range,
self.editor.downgrade(),
);
let crease_id = self.editor.update(cx, |editor, cx| {
let crease_ids = editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(vec![crease], false, window, cx);
crease_ids.first().copied().unwrap()
});
self.mention_set
.insert_uri(crease_id, MentionUri::Selection { path, line_range });
}
}
pub fn contents(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<Vec<acp::ContentBlock>>> {
let contents = self.mention_set.lock().contents(
let contents = self.mention_set.contents(
self.project.clone(),
self.thread_store.clone(),
self.text_thread_store.clone(),
@ -198,15 +428,15 @@ impl MessageEditor {
pub fn clear(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.remove_creases(self.mention_set.lock().drain(), cx)
editor.remove_creases(self.mention_set.drain(), cx)
});
}
fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
fn send(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Send)
}
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
fn cancel(&mut self, _: &editor::actions::Cancel, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Cancel)
}
@ -233,11 +463,16 @@ impl MessageEditor {
let replacement_text = "image";
for image in images {
let (excerpt_id, anchor) = self.editor.update(cx, |message_editor, cx| {
let (excerpt_id, text_anchor, multibuffer_anchor) =
self.editor.update(cx, |message_editor, cx| {
let snapshot = message_editor.snapshot(window, cx);
let (excerpt_id, _, snapshot) = snapshot.buffer_snapshot.as_singleton().unwrap();
let (excerpt_id, _, buffer_snapshot) =
snapshot.buffer_snapshot.as_singleton().unwrap();
let anchor = snapshot.anchor_before(snapshot.len());
let text_anchor = buffer_snapshot.anchor_before(buffer_snapshot.len());
let multibuffer_anchor = snapshot
.buffer_snapshot
.anchor_in_excerpt(*excerpt_id, text_anchor);
message_editor.edit(
[(
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
@ -245,15 +480,29 @@ impl MessageEditor {
)],
cx,
);
(*excerpt_id, anchor)
(*excerpt_id, text_anchor, multibuffer_anchor)
});
self.insert_image(
let content_len = replacement_text.len();
let Some(anchor) = multibuffer_anchor else {
return;
};
let Some(crease_id) = insert_crease_for_image(
excerpt_id,
text_anchor,
content_len,
None.clone(),
self.editor.clone(),
window,
cx,
) else {
return;
};
self.confirm_mention_for_image(
crease_id,
anchor,
replacement_text.len(),
Arc::new(image),
None,
Task::ready(Ok(Arc::new(image))),
window,
cx,
);
@ -267,9 +516,6 @@ impl MessageEditor {
cx: &mut Context<Self>,
) {
let buffer = self.editor.read(cx).buffer().clone();
let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else {
return;
};
let Some(buffer) = buffer.read(cx).as_singleton() else {
return;
};
@ -292,10 +538,8 @@ impl MessageEditor {
&path_prefix,
false,
entry.is_dir(),
excerpt_id,
anchor..anchor,
self.editor.clone(),
self.mention_set.clone(),
cx.weak_entity(),
self.project.clone(),
cx,
) else {
@ -317,33 +561,32 @@ impl MessageEditor {
}
}
fn insert_image(
pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context<Self>) {
self.editor.update(cx, |message_editor, cx| {
message_editor.set_read_only(read_only);
cx.notify()
})
}
fn confirm_mention_for_image(
&mut self,
excerpt_id: ExcerptId,
crease_start: text::Anchor,
content_len: usize,
image: Arc<Image>,
abs_path: Option<Arc<Path>>,
crease_id: CreaseId,
anchor: Anchor,
abs_path: Option<PathBuf>,
image: Task<Result<Arc<Image>>>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(crease_id) = insert_crease_for_image(
excerpt_id,
crease_start,
content_len,
self.editor.clone(),
window,
cx,
) else {
return;
};
self.editor.update(cx, |_editor, cx| {
let format = image.format;
let convert = LanguageModelImage::from_image(image, cx);
let task = cx
.spawn_in(window, async move |editor, cx| {
if let Some(image) = convert.await {
let image = image.await.map_err(|e| e.to_string())?;
let format = image.format;
let image = cx
.update(|_, cx| LanguageModelImage::from_image(image, cx))
.map_err(|e| e.to_string())?
.await;
if let Some(image) = image {
Ok(MentionImage {
abs_path,
data: image.source,
@ -352,12 +595,6 @@ impl MessageEditor {
} else {
editor
.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let Some(anchor) =
snapshot.anchor_in_excerpt(excerpt_id, crease_start)
else {
return;
};
editor.display_map.update(cx, |display_map, cx| {
display_map.unfold_intersecting(vec![anchor..anchor], true, cx);
});
@ -375,7 +612,7 @@ impl MessageEditor {
})
.detach();
self.mention_set.lock().insert_image(crease_id, task);
self.mention_set.insert_image(crease_id, task);
});
}
@ -392,6 +629,8 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.clear(window, cx);
let mut text = String::new();
let mut mentions = Vec::new();
let mut images = Vec::new();
@ -429,7 +668,6 @@ impl MessageEditor {
editor.buffer().read(cx).snapshot(cx)
});
self.mention_set.lock().clear();
for (range, mention_uri) in mentions {
let anchor = snapshot.anchor_before(range.start);
let crease_id = crate::context_picker::insert_crease_for_mention(
@ -444,7 +682,7 @@ impl MessageEditor {
);
if let Some(crease_id) = crease_id {
self.mention_set.lock().insert_uri(crease_id, mention_uri);
self.mention_set.insert_uri(crease_id, mention_uri);
}
}
for (range, content) in images {
@ -479,7 +717,7 @@ impl MessageEditor {
let data: SharedString = content.data.to_string().into();
if let Some(crease_id) = crease_id {
self.mention_set.lock().insert_image(
self.mention_set.insert_image(
crease_id,
Task::ready(Ok(MentionImage {
abs_path,
@ -499,6 +737,11 @@ impl MessageEditor {
editor.set_text(text, window, cx);
});
}
#[cfg(test)]
pub fn text(&self, cx: &App) -> String {
self.editor.read(cx).text(cx)
}
}
impl Focusable for MessageEditor {
@ -511,7 +754,7 @@ impl Render for MessageEditor {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.key_context("MessageEditor")
.on_action(cx.listener(Self::chat))
.on_action(cx.listener(Self::send))
.on_action(cx.listener(Self::cancel))
.capture_action(cx.listener(Self::paste))
.flex_1()
@ -550,20 +793,78 @@ pub(crate) fn insert_crease_for_image(
excerpt_id: ExcerptId,
anchor: text::Anchor,
content_len: usize,
abs_path: Option<Arc<Path>>,
editor: Entity<Editor>,
window: &mut Window,
cx: &mut App,
) -> Option<CreaseId> {
crate::context_picker::insert_crease_for_mention(
excerpt_id,
anchor,
content_len,
"Image".into(),
IconName::Image.path().into(),
editor,
window,
cx,
let crease_label = abs_path
.as_ref()
.and_then(|path| path.file_name())
.map(|name| name.to_string_lossy().to_string().into())
.unwrap_or(SharedString::from("Image"));
editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let start = snapshot.anchor_in_excerpt(excerpt_id, anchor)?;
let start = start.bias_right(&snapshot);
let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len);
let placeholder = FoldPlaceholder {
render: render_image_fold_icon_button(crease_label, cx.weak_entity()),
merge_adjacent: false,
..Default::default()
};
let crease = Crease::Inline {
range: start..end,
placeholder,
render_toggle: None,
render_trailer: None,
metadata: None,
};
let ids = editor.insert_creases(vec![crease.clone()], cx);
editor.fold_creases(vec![crease], false, window, cx);
Some(ids[0])
})
}
fn render_image_fold_icon_button(
label: SharedString,
editor: WeakEntity<Editor>,
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
Arc::new({
move |fold_id, fold_range, cx| {
let is_in_text_selection = editor
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
.unwrap_or_default();
ButtonLike::new(fold_id)
.style(ButtonStyle::Filled)
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.toggle_state(is_in_text_selection)
.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Image)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
Label::new(label.clone())
.size(LabelSize::Small)
.buffer_font(cx)
.single_line(),
),
)
.into_any_element()
}
})
}
#[cfg(test)]

View file

@ -45,6 +45,7 @@ use zed_actions::assistant::OpenRulesLibrary;
use super::entry_view_state::EntryViewState;
use crate::acp::AcpModelSelectorPopover;
use crate::acp::entry_view_state::{EntryViewEvent, ViewEvent};
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
use crate::agent_diff::AgentDiff;
use crate::profile_selector::{ProfileProvider, ProfileSelector};
@ -101,10 +102,8 @@ pub struct AcpThreadView {
agent: Rc<dyn AgentServer>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
thread_state: ThreadState,
entry_view_state: EntryViewState,
entry_view_state: Entity<EntryViewState>,
message_editor: Entity<MessageEditor>,
model_selector: Option<Entity<AcpModelSelectorPopover>>,
profile_selector: Option<Entity<ProfileSelector>>,
@ -119,16 +118,9 @@ pub struct AcpThreadView {
plan_expanded: bool,
editor_expanded: bool,
terminal_expanded: bool,
editing_message: Option<EditingMessage>,
editing_message: Option<usize>,
_cancel_task: Option<Task<()>>,
_subscriptions: [Subscription; 2],
}
struct EditingMessage {
index: usize,
message_id: UserMessageId,
editor: Entity<MessageEditor>,
_subscription: Subscription,
_subscriptions: [Subscription; 3],
}
enum ThreadState {
@ -175,25 +167,33 @@ impl AcpThreadView {
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
let entry_view_state = cx.new(|_| {
EntryViewState::new(
workspace.clone(),
project.clone(),
thread_store.clone(),
text_thread_store.clone(),
)
});
let subscriptions = [
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
cx.subscribe_in(&message_editor, window, Self::on_message_editor_event),
cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event),
cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event),
];
Self {
agent: agent.clone(),
workspace: workspace.clone(),
project: project.clone(),
thread_store,
text_thread_store,
entry_view_state,
thread_state: Self::initial_state(agent, workspace, project, window, cx),
message_editor,
model_selector: None,
profile_selector: None,
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
entry_view_state: EntryViewState::default(),
list_state,
list_state: list_state,
thread_error: None,
auth_task: None,
expanded_tool_calls: HashSet::default(),
@ -412,7 +412,7 @@ impl AcpThreadView {
cx.notify();
}
pub fn on_message_editor_event(
pub fn handle_message_editor_event(
&mut self,
_: &Entity<MessageEditor>,
event: &MessageEditorEvent,
@ -422,6 +422,28 @@ impl AcpThreadView {
match event {
MessageEditorEvent::Send => self.send(window, cx),
MessageEditorEvent::Cancel => self.cancel_generation(cx),
MessageEditorEvent::Focus => {}
}
}
pub fn handle_entry_view_event(
&mut self,
_: &Entity<EntryViewState>,
event: &EntryViewEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
match &event.view_event {
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => {
self.editing_message = Some(event.entry_index);
cx.notify();
}
ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => {
self.regenerate(event.entry_index, editor, window, cx);
}
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => {
self.cancel_editing(&Default::default(), window, cx);
}
}
}
@ -492,27 +514,56 @@ impl AcpThreadView {
.detach();
}
fn cancel_editing(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
}
fn regenerate(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(editing_message) = self.editing_message.take() else {
return;
};
fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(thread) = self.thread().cloned() else {
return;
};
let rewind = thread.update(cx, |thread, cx| {
thread.rewind(editing_message.message_id, cx)
});
if let Some(index) = self.editing_message.take() {
if let Some(editor) = self
.entry_view_state
.read(cx)
.entry(index)
.and_then(|e| e.message_editor())
.cloned()
{
editor.update(cx, |editor, cx| {
if let Some(user_message) = thread
.read(cx)
.entries()
.get(index)
.and_then(|e| e.user_message())
{
editor.set_message(user_message.chunks.clone(), window, cx);
}
})
}
};
self.focus_handle(cx).focus(window);
cx.notify();
}
fn regenerate(
&mut self,
entry_ix: usize,
message_editor: &Entity<MessageEditor>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(thread) = self.thread().cloned() else {
return;
};
let Some(rewind) = thread.update(cx, |thread, cx| {
let user_message_id = thread.entries().get(entry_ix)?.user_message()?.id.clone()?;
Some(thread.rewind(user_message_id, cx))
}) else {
return;
};
let contents =
message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx));
let contents = editing_message
.editor
.update(cx, |message_editor, cx| message_editor.contents(window, cx));
let task = cx.foreground_executor().spawn(async move {
rewind.await?;
contents.await
@ -568,27 +619,20 @@ impl AcpThreadView {
AcpThreadEvent::NewEntry => {
let len = thread.read(cx).entries().len();
let index = len - 1;
self.entry_view_state.sync_entry(
self.workspace.clone(),
thread.clone(),
index,
window,
cx,
);
self.entry_view_state.update(cx, |view_state, cx| {
view_state.sync_entry(index, &thread, window, cx)
});
self.list_state.splice(index..index, 1);
}
AcpThreadEvent::EntryUpdated(index) => {
self.entry_view_state.sync_entry(
self.workspace.clone(),
thread.clone(),
*index,
window,
cx,
);
self.entry_view_state.update(cx, |view_state, cx| {
view_state.sync_entry(*index, &thread, window, cx)
});
self.list_state.splice(*index..index + 1, 1);
}
AcpThreadEvent::EntriesRemoved(range) => {
self.entry_view_state.remove(range.clone());
self.entry_view_state
.update(cx, |view_state, _cx| view_state.remove(range.clone()));
self.list_state.splice(range.clone(), 0);
}
AcpThreadEvent::ToolAuthorizationRequired => {
@ -720,29 +764,15 @@ impl AcpThreadView {
.border_1()
.border_color(cx.theme().colors().border)
.text_xs()
.id("message")
.on_click(cx.listener({
move |this, _, window, cx| {
this.start_editing_message(entry_ix, window, cx)
}
}))
.children(
if let Some(editing) = self.editing_message.as_ref()
&& Some(&editing.message_id) == message.id.as_ref()
{
Some(
self.render_edit_message_editor(editing, cx)
.into_any_element(),
)
} else {
message.content.markdown().map(|md| {
self.render_markdown(
md.clone(),
user_message_markdown_style(window, cx),
)
self.entry_view_state
.read(cx)
.entry(entry_ix)
.and_then(|entry| entry.message_editor())
.map(|editor| {
self.render_sent_message_editor(entry_ix, editor, cx)
.into_any_element()
})
},
}),
),
)
.into_any(),
@ -817,8 +847,8 @@ impl AcpThreadView {
primary
};
if let Some(editing) = self.editing_message.as_ref()
&& editing.index < entry_ix
if let Some(editing_index) = self.editing_message.as_ref()
&& *editing_index < entry_ix
{
let backdrop = div()
.id(("backdrop", entry_ix))
@ -832,8 +862,8 @@ impl AcpThreadView {
div()
.relative()
.child(backdrop)
.child(primary)
.child(backdrop)
.into_any_element()
} else {
primary
@ -1254,9 +1284,7 @@ impl AcpThreadView {
Empty.into_any_element()
}
}
ToolCallContent::Diff(diff) => {
self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx)
}
ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, &diff, cx),
ToolCallContent::Terminal(terminal) => {
self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
}
@ -1403,7 +1431,7 @@ impl AcpThreadView {
fn render_diff_editor(
&self,
entry_ix: usize,
multibuffer: &Entity<MultiBuffer>,
diff: &Entity<acp_thread::Diff>,
cx: &Context<Self>,
) -> AnyElement {
v_flex()
@ -1411,8 +1439,8 @@ impl AcpThreadView {
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
if let Some(entry) = self.entry_view_state.entry(entry_ix)
&& let Some(editor) = entry.editor_for_diff(&multibuffer)
if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix)
&& let Some(editor) = entry.editor_for_diff(&diff)
{
editor.clone().into_any_element()
} else {
@ -1615,6 +1643,7 @@ impl AcpThreadView {
let terminal_view = self
.entry_view_state
.read(cx)
.entry(entry_ix)
.and_then(|entry| entry.terminal(&terminal));
let show_output = self.terminal_expanded && terminal_view.is_some();
@ -2483,63 +2512,16 @@ impl AcpThreadView {
)
}
fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
let Some(thread) = self.thread() else {
return;
};
let Some(AgentThreadEntry::UserMessage(message)) = thread.read(cx).entries().get(index)
else {
return;
};
let Some(message_id) = message.id.clone() else {
return;
};
self.list_state.scroll_to_reveal_item(index);
let chunks = message.chunks.clone();
let editor = cx.new(|cx| {
let mut editor = MessageEditor::new(
self.workspace.clone(),
self.project.clone(),
self.thread_store.clone(),
self.text_thread_store.clone(),
editor::EditorMode::AutoHeight {
min_lines: 1,
max_lines: None,
},
window,
cx,
);
editor.set_message(chunks, window, cx);
editor
});
let subscription =
cx.subscribe_in(&editor, window, |this, _, event, window, cx| match event {
MessageEditorEvent::Send => {
this.regenerate(&Default::default(), window, cx);
}
MessageEditorEvent::Cancel => {
this.cancel_editing(&Default::default(), window, cx);
}
});
editor.focus_handle(cx).focus(window);
self.editing_message.replace(EditingMessage {
index: index,
message_id: message_id.clone(),
editor,
_subscription: subscription,
});
cx.notify();
}
fn render_edit_message_editor(&self, editing: &EditingMessage, cx: &Context<Self>) -> Div {
v_flex()
.w_full()
.gap_2()
.child(editing.editor.clone())
.child(
fn render_sent_message_editor(
&self,
entry_ix: usize,
editor: &Entity<MessageEditor>,
cx: &Context<Self>,
) -> Div {
v_flex().w_full().gap_2().child(editor.clone()).when(
self.editing_message == Some(entry_ix),
|el| {
el.child(
h_flex()
.gap_1()
.child(
@ -2552,13 +2534,16 @@ impl AcpThreadView {
.color(Color::Muted)
.size(LabelSize::XSmall),
)
.child(self.render_editing_message_editor_buttons(editing, cx)),
.child(self.render_sent_message_editor_buttons(entry_ix, editor, cx)),
)
},
)
}
fn render_editing_message_editor_buttons(
fn render_sent_message_editor_buttons(
&self,
editing: &EditingMessage,
entry_ix: usize,
editor: &Entity<MessageEditor>,
cx: &Context<Self>,
) -> Div {
h_flex()
@ -2571,7 +2556,7 @@ impl AcpThreadView {
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = editing.editor.focus_handle(cx);
let focus_handle = editor.focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
@ -2586,12 +2571,12 @@ impl AcpThreadView {
)
.child(
IconButton::new("confirm-edit-message", IconName::Return)
.disabled(editing.editor.read(cx).is_empty(cx))
.disabled(editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = editing.editor.focus_handle(cx);
let focus_handle = editor.focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
@ -2602,7 +2587,12 @@ impl AcpThreadView {
)
}
})
.on_click(cx.listener(Self::regenerate)),
.on_click(cx.listener({
let editor = editor.clone();
move |this, _, window, cx| {
this.regenerate(entry_ix, &editor, window, cx);
}
})),
)
}
@ -3102,7 +3092,9 @@ impl AcpThreadView {
}
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
self.entry_view_state.settings_changed(cx);
self.entry_view_state.update(cx, |entry_view_state, cx| {
entry_view_state.settings_changed(cx);
});
}
pub(crate) fn insert_dragged_files(
@ -3117,9 +3109,7 @@ impl AcpThreadView {
drop(added_worktrees);
})
}
}
impl AcpThreadView {
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
let content = match self.thread_error.as_ref()? {
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
@ -3411,35 +3401,6 @@ impl Render for AcpThreadView {
}
}
fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let mut style = default_markdown_style(false, window, cx);
let mut text_style = window.text_style();
let theme_settings = ThemeSettings::get_global(cx);
let buffer_font = theme_settings.buffer_font.family.clone();
let buffer_font_size = TextSize::Small.rems(cx);
text_style.refine(&TextStyleRefinement {
font_family: Some(buffer_font),
font_size: Some(buffer_font_size.into()),
..Default::default()
});
style.base_text_style = text_style;
style.link_callback = Some(Rc::new(move |url, cx| {
if MentionUri::parse(url).is_ok() {
let colors = cx.theme().colors();
Some(TextStyleRefinement {
background_color: Some(colors.element_background),
..Default::default()
})
} else {
None
}
}));
style
}
fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
@ -3598,12 +3559,13 @@ pub(crate) mod tests {
use agent_client_protocol::SessionId;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext};
use project::Project;
use serde_json::json;
use settings::SettingsStore;
use std::any::Any;
use std::path::Path;
use workspace::Item;
use super::*;
@ -3750,6 +3712,50 @@ pub(crate) mod tests {
(thread_view, cx)
}
fn add_to_workspace(thread_view: Entity<AcpThreadView>, cx: &mut VisualTestContext) {
let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone());
workspace
.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(
Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))),
None,
true,
window,
cx,
);
})
.unwrap();
}
struct ThreadViewItem(Entity<AcpThreadView>);
impl Item for ThreadViewItem {
type Event = ();
fn include_in_nav_history() -> bool {
false
}
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
"Test".into()
}
}
impl EventEmitter<()> for ThreadViewItem {}
impl Focusable for ThreadViewItem {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.0.read(cx).focus_handle(cx).clone()
}
}
impl Render for ThreadViewItem {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.0.clone().into_any_element()
}
}
struct StubAgentServer<C> {
connection: C,
}
@ -3771,19 +3777,19 @@ pub(crate) mod tests {
C: 'static + AgentConnection + Send + Clone,
{
fn logo(&self) -> ui::IconName {
unimplemented!()
ui::IconName::Ai
}
fn name(&self) -> &'static str {
unimplemented!()
"Test"
}
fn empty_state_headline(&self) -> &'static str {
unimplemented!()
"Test"
}
fn empty_state_message(&self) -> &'static str {
unimplemented!()
"Test"
}
fn connect(
@ -3932,9 +3938,17 @@ pub(crate) mod tests {
assert_eq!(thread.entries().len(), 2);
});
thread_view.read_with(cx, |view, _| {
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
thread_view.read_with(cx, |view, cx| {
view.entry_view_state.read_with(cx, |entry_view_state, _| {
assert!(
entry_view_state
.entry(0)
.unwrap()
.message_editor()
.is_some()
);
assert!(entry_view_state.entry(1).unwrap().has_content());
});
});
// Second user message
@ -3963,18 +3977,31 @@ pub(crate) mod tests {
let second_user_message_id = thread.read_with(cx, |thread, _| {
assert_eq!(thread.entries().len(), 4);
let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap()
else {
let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else {
panic!();
};
user_message.id.clone().unwrap()
});
thread_view.read_with(cx, |view, _| {
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1);
thread_view.read_with(cx, |view, cx| {
view.entry_view_state.read_with(cx, |entry_view_state, _| {
assert!(
entry_view_state
.entry(0)
.unwrap()
.message_editor()
.is_some()
);
assert!(entry_view_state.entry(1).unwrap().has_content());
assert!(
entry_view_state
.entry(2)
.unwrap()
.message_editor()
.is_some()
);
assert!(entry_view_state.entry(3).unwrap().has_content());
});
});
// Rewind to first message
@ -3989,13 +4016,169 @@ pub(crate) mod tests {
assert_eq!(thread.entries().len(), 2);
});
thread_view.read_with(cx, |view, _| {
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
thread_view.read_with(cx, |view, cx| {
view.entry_view_state.read_with(cx, |entry_view_state, _| {
assert!(
entry_view_state
.entry(0)
.unwrap()
.message_editor()
.is_some()
);
assert!(entry_view_state.entry(1).unwrap().has_content());
// Old views should be dropped
assert!(view.entry_view_state.entry(2).is_none());
assert!(view.entry_view_state.entry(3).is_none());
assert!(entry_view_state.entry(2).is_none());
assert!(entry_view_state.entry(3).is_none());
});
});
}
#[gpui::test]
async fn test_message_editing_cancel(cx: &mut TestAppContext) {
init_test(cx);
let connection = StubAgentConnection::new();
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
content: acp::ContentBlock::Text(acp::TextContent {
text: "Response".into(),
annotations: None,
}),
}]);
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
add_to_workspace(thread_view.clone(), cx);
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Original message to edit", window, cx);
});
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.send(window, cx);
});
cx.run_until_parked();
let user_message_editor = thread_view.read_with(cx, |view, cx| {
assert_eq!(view.editing_message, None);
view.entry_view_state
.read(cx)
.entry(0)
.unwrap()
.message_editor()
.unwrap()
.clone()
});
// Focus
cx.focus(&user_message_editor);
thread_view.read_with(cx, |view, _cx| {
assert_eq!(view.editing_message, Some(0));
});
// Edit
user_message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Edited message content", window, cx);
});
// Cancel
user_message_editor.update_in(cx, |_editor, window, cx| {
window.dispatch_action(Box::new(editor::actions::Cancel), cx);
});
thread_view.read_with(cx, |view, _cx| {
assert_eq!(view.editing_message, None);
});
user_message_editor.read_with(cx, |editor, cx| {
assert_eq!(editor.text(cx), "Original message to edit");
});
}
#[gpui::test]
async fn test_message_editing_regenerate(cx: &mut TestAppContext) {
init_test(cx);
let connection = StubAgentConnection::new();
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
content: acp::ContentBlock::Text(acp::TextContent {
text: "Response".into(),
annotations: None,
}),
}]);
let (thread_view, cx) =
setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
add_to_workspace(thread_view.clone(), cx);
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Original message to edit", window, cx);
});
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.send(window, cx);
});
cx.run_until_parked();
let user_message_editor = thread_view.read_with(cx, |view, cx| {
assert_eq!(view.editing_message, None);
assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2);
view.entry_view_state
.read(cx)
.entry(0)
.unwrap()
.message_editor()
.unwrap()
.clone()
});
// Focus
cx.focus(&user_message_editor);
// Edit
user_message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Edited message content", window, cx);
});
// Send
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
content: acp::ContentBlock::Text(acp::TextContent {
text: "New Response".into(),
annotations: None,
}),
}]);
user_message_editor.update_in(cx, |_editor, window, cx| {
window.dispatch_action(Box::new(Chat), cx);
});
cx.run_until_parked();
thread_view.read_with(cx, |view, cx| {
assert_eq!(view.editing_message, None);
let entries = view.thread().unwrap().read(cx).entries();
assert_eq!(entries.len(), 2);
assert_eq!(
entries[0].to_markdown(cx),
"## User\n\nEdited message content\n\n"
);
assert_eq!(
entries[1].to_markdown(cx),
"## Assistant\n\nNew Response\n\n"
);
let new_editor = view.entry_view_state.read_with(cx, |state, _cx| {
assert!(!state.entry(1).unwrap().has_content());
state.entry(0).unwrap().message_editor().unwrap().clone()
});
assert_eq!(new_editor.read(cx).text(cx), "Edited message content");
})
}
}

View file

@ -462,7 +462,7 @@ impl AgentConfiguration {
"modifier-send",
"Use modifier to submit a message",
Some(
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(),
),
use_modifier_to_send,
move |state, _window, cx| {

View file

@ -818,12 +818,10 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => {
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
}
ActiveView::ExternalAgentThread { thread_view, .. } => {
thread_view.update(cx, |thread_element, cx| {
thread_element.cancel_generation(cx)
});
}
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
ActiveView::ExternalAgentThread { .. }
| ActiveView::TextThread { .. }
| ActiveView::History
| ActiveView::Configuration => {}
}
}

View file

@ -13,7 +13,7 @@ use anyhow::{Result, anyhow};
use collections::HashSet;
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use fetch_context_picker::FetchContextPicker;
use file_context_picker::FileContextPicker;
use file_context_picker::render_file_context_entry;
@ -228,7 +228,7 @@ impl ContextPicker {
}
fn build_menu(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Entity<ContextMenu> {
let context_picker = cx.entity().clone();
let context_picker = cx.entity();
let menu = ContextMenu::build(window, cx, move |menu, _window, cx| {
let recent = self.recent_entries(cx);
@ -837,42 +837,9 @@ fn render_fold_icon_button(
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
Arc::new({
move |fold_id, fold_range, cx| {
let is_in_text_selection = editor.upgrade().is_some_and(|editor| {
editor.update(cx, |editor, cx| {
let snapshot = editor
.buffer()
.update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx));
let is_in_pending_selection = || {
editor
.selections
.pending
.as_ref()
.is_some_and(|pending_selection| {
pending_selection
.selection
.range()
.includes(&fold_range, &snapshot)
})
};
let mut is_in_complete_selection = || {
editor
.selections
.disjoint_in_range::<usize>(fold_range.clone(), cx)
.into_iter()
.any(|selection| {
// This is needed to cover a corner case, if we just check for an existing
// selection in the fold range, having a cursor at the start of the fold
// marks it as selected. Non-empty selections don't cause this.
let length = selection.end - selection.start;
length > 0
})
};
is_in_pending_selection() || is_in_complete_selection()
})
});
let is_in_text_selection = editor
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
.unwrap_or_default();
ButtonLike::new(fold_id)
.style(ButtonStyle::Filled)

View file

@ -72,7 +72,7 @@ pub fn init(
let Some(window) = window else {
return;
};
let workspace = cx.entity().clone();
let workspace = cx.entity();
InlineAssistant::update_global(cx, |inline_assistant, cx| {
inline_assistant.register_workspace(&workspace, window, cx)
});

View file

@ -1,5 +1,6 @@
use std::{cmp::Reverse, sync::Arc};
use cloud_llm_client::Plan;
use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
@ -10,7 +11,6 @@ use language_model::{
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use proto::Plan;
use ui::{ListItem, ListItemSpacing, prelude::*};
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
) -> Option<gpui::AnyElement> {
use feature_flags::FeatureFlagAppExt;
let plan = proto::Plan::ZedPro;
let plan = Plan::ZedPro;
Some(
h_flex()
@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
window
.dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
}),
Plan::Free | Plan::ZedProTrial => Button::new(
Plan::ZedFree | Plan::ZedProTrial => Button::new(
"try-pro",
if plan == Plan::ZedProTrial {
"Upgrade to Pro"

View file

@ -163,7 +163,7 @@ impl Render for ProfileSelector {
.unwrap_or_else(|| "Unknown".into());
if self.provider.profiles_supported(cx) {
let this = cx.entity().clone();
let this = cx.entity();
let focus_handle = self.focus_handle.clone();
let trigger_button = Button::new("profile-selector-model", selected_profile)
.label_size(LabelSize::Small)

View file

@ -1,16 +1,12 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{Context as _, Result, anyhow};
use chrono::Duration;
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use futures::{StreamExt, stream::BoxStream};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex;
use rpc::{
ConnectionId, Peer, Receipt, TypedEnvelope,
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
};
use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
use std::sync::Arc;
pub struct FakeServer {
@ -187,7 +183,6 @@ impl FakeServer {
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
self.executor.start_waiting();
loop {
let message = self
.state
.lock()
@ -205,33 +200,11 @@ impl FakeServer {
return Ok(*message.downcast().unwrap());
}
let accepted_tos_at = chrono::Utc::now()
.checked_sub_signed(Duration::hours(5))
.expect("failed to build accepted_tos_at")
.timestamp() as u64;
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
.downcast::<TypedEnvelope<GetPrivateUserInfo>>()
.unwrap()
.receipt(),
GetPrivateUserInfoResponse {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
accepted_tos_at: Some(accepted_tos_at),
},
);
continue;
}
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}
}
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
self.peer.respond(receipt, response).unwrap()

View file

@ -177,7 +177,6 @@ impl UserStore {
let (mut current_user_tx, current_user_rx) = watch::channel();
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
let rpc_subscriptions = vec![
client.add_message_handler(cx.weak_entity(), Self::handle_update_plan),
client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts),
client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info),
client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
@ -343,26 +342,6 @@ impl UserStore {
Ok(())
}
async fn handle_update_plan(
this: Entity<Self>,
_message: TypedEnvelope<proto::UpdateUserPlan>,
mut cx: AsyncApp,
) -> Result<()> {
let client = this
.read_with(&cx, |this, _| this.client.upgrade())?
.context("client was dropped")?;
let response = client
.cloud_client()
.get_authenticated_user()
.await
.context("failed to fetch authenticated user")?;
this.update(&mut cx, |this, cx| {
this.update_authenticated_user(response, cx);
})
}
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
match message {
UpdateContacts::Wait(barrier) => {
@ -1019,19 +998,6 @@ impl RequestUsage {
}
}
pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option<Self> {
let limit = match limit.variant? {
proto::usage_limit::Variant::Limited(limited) => {
UsageLimit::Limited(limited.limit as i32)
}
proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited,
};
Some(RequestUsage {
limit,
amount: amount as i32,
})
}
fn from_headers(
limit_name: &str,
amount_name: &str,

View file

@ -19,7 +19,6 @@ test-support = ["sqlite"]
[dependencies]
anyhow.workspace = true
async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
@ -33,13 +32,11 @@ clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
gpui.workspace = true
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true
livekit_api.workspace = true
log.workspace = true
nanoid.workspace = true
@ -65,7 +62,6 @@ subtle.workspace = true
supermaven_api.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["full"] }
toml.workspace = true
@ -136,6 +132,3 @@ util.workspace = true
workspace = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
zlog.workspace = true
[package.metadata.cargo-machete]
ignored = ["async-stripe"]

View file

@ -219,12 +219,6 @@ spec:
secretKeyRef:
name: slack
key: panics_webhook
- name: STRIPE_API_KEY
valueFrom:
secretKeyRef:
name: stripe
key: api_key
optional: true
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
value: "1000"
- name: SUPERMAVEN_ADMIN_API_KEY

View file

@ -1,16 +1,10 @@
pub mod billing;
pub mod contributors;
pub mod events;
pub mod extensions;
pub mod ips_file;
pub mod slack;
use crate::db::Database;
use crate::{
AppState, Error, Result, auth,
db::{User, UserId},
rpc,
};
use crate::{AppState, Error, Result, auth, db::UserId, rpc};
use anyhow::Context as _;
use axum::{
Extension, Json, Router,
@ -97,7 +91,6 @@ impl std::fmt::Display for SystemIdHeader {
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/users/look_up", get(look_up_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(contributors::router())
@ -139,99 +132,6 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
Ok::<_, Error>(next.run(req).await)
}
#[derive(Debug, Deserialize)]
struct LookUpUserParams {
identifier: String,
}
#[derive(Debug, Serialize)]
struct LookUpUserResponse {
user: Option<User>,
}
async fn look_up_user(
Query(params): Query<LookUpUserParams>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<LookUpUserResponse>> {
let user = resolve_identifier_to_user(&app.db, &params.identifier).await?;
let user = if let Some(user) = user {
match user {
UserOrId::User(user) => Some(user),
UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
}
} else {
None
};
Ok(Json(LookUpUserResponse { user }))
}
enum UserOrId {
User(User),
Id(UserId),
}
async fn resolve_identifier_to_user(
db: &Arc<Database>,
identifier: &str,
) -> Result<Option<UserOrId>> {
if let Some(identifier) = identifier.parse::<i32>().ok() {
let user = db.get_user_by_id(UserId(identifier)).await?;
return Ok(user.map(UserOrId::User));
}
if identifier.starts_with("cus_") {
let billing_customer = db
.get_billing_customer_by_stripe_customer_id(&identifier)
.await?;
return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
}
if identifier.starts_with("sub_") {
let billing_subscription = db
.get_billing_subscription_by_stripe_subscription_id(&identifier)
.await?;
if let Some(billing_subscription) = billing_subscription {
let billing_customer = db
.get_billing_customer_by_id(billing_subscription.billing_customer_id)
.await?;
return Ok(
billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
);
} else {
return Ok(None);
}
}
if identifier.contains('@') {
let user = db.get_user_by_email(identifier).await?;
return Ok(user.map(UserOrId::User));
}
if let Some(user) = db.get_user_by_github_login(identifier).await? {
return Ok(Some(UserOrId::User(user)));
}
Ok(None)
}
#[derive(Deserialize, Debug)]
struct CreateUserParams {
github_user_id: i32,
github_login: String,
email_address: String,
email_confirmation_code: Option<String>,
#[serde(default)]
admin: bool,
#[serde(default)]
invite_count: i32,
}
async fn get_rpc_server_snapshot(
Extension(rpc_server): Extension<Arc<rpc::Server>>,
) -> Result<ErasedJson> {

View file

@ -1,59 +0,0 @@
use std::sync::Arc;
use stripe::SubscriptionStatus;
use crate::AppState;
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::{CreateBillingCustomerParams, billing_customer};
use crate::stripe_client::{StripeClient, StripeCustomerId};
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
fn from(value: SubscriptionStatus) -> Self {
match value {
SubscriptionStatus::Incomplete => Self::Incomplete,
SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
SubscriptionStatus::Trialing => Self::Trialing,
SubscriptionStatus::Active => Self::Active,
SubscriptionStatus::PastDue => Self::PastDue,
SubscriptionStatus::Canceled => Self::Canceled,
SubscriptionStatus::Unpaid => Self::Unpaid,
SubscriptionStatus::Paused => Self::Paused,
}
}
}
/// Finds or creates a billing customer using the provided customer.
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &dyn StripeClient,
customer_id: &StripeCustomerId,
) -> anyhow::Result<Option<billing_customer::Model>> {
// If we already have a billing customer record associated with the Stripe customer,
// there's nothing more we need to do.
if let Some(billing_customer) = app
.db
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
.await?
{
return Ok(Some(billing_customer));
}
let customer = stripe_client.get_customer(customer_id).await?;
let Some(email) = customer.email else {
return Ok(None);
};
let Some(user) = app.db.get_user_by_email(&email).await? else {
return Ok(None);
};
let billing_customer = app
.db
.create_billing_customer(&CreateBillingCustomerParams {
user_id: user.id,
stripe_customer_id: customer.id.to_string(),
})
.await?;
Ok(Some(billing_customer))
}

View file

@ -564,170 +564,10 @@ fn for_snowflake(
country_code: Option<String>,
checksum_matched: bool,
) -> impl Iterator<Item = SnowflakeRow> {
body.events.into_iter().filter_map(move |event| {
body.events.into_iter().map(move |event| {
let timestamp =
first_event_at + Duration::milliseconds(event.milliseconds_since_first_event);
// We will need to double check, but I believe all of the events that
// are being transformed here are now migrated over to use the
// telemetry::event! macro, as of this commit so this code can go away
// when we feel enough users have upgraded past this point.
let (event_type, mut event_properties) = match &event.event {
Event::Editor(e) => (
match e.operation.as_str() {
"open" => "Editor Opened".to_string(),
"save" => "Editor Saved".to_string(),
_ => format!("Unknown Editor Event: {}", e.operation),
},
serde_json::to_value(e).unwrap(),
),
Event::EditPrediction(e) => (
format!(
"Edit Prediction {}",
if e.suggestion_accepted {
"Accepted"
} else {
"Discarded"
}
),
serde_json::to_value(e).unwrap(),
),
Event::EditPredictionRating(e) => (
"Edit Prediction Rated".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Call(e) => {
let event_type = match e.operation.trim() {
"unshare project" => "Project Unshared".to_string(),
"open channel notes" => "Channel Notes Opened".to_string(),
"share project" => "Project Shared".to_string(),
"join channel" => "Channel Joined".to_string(),
"hang up" => "Call Ended".to_string(),
"accept incoming" => "Incoming Call Accepted".to_string(),
"invite" => "Participant Invited".to_string(),
"disable microphone" => "Microphone Disabled".to_string(),
"enable microphone" => "Microphone Enabled".to_string(),
"enable screen share" => "Screen Share Enabled".to_string(),
"disable screen share" => "Screen Share Disabled".to_string(),
"decline incoming" => "Incoming Call Declined".to_string(),
_ => format!("Unknown Call Event: {}", e.operation),
};
(event_type, serde_json::to_value(e).unwrap())
}
Event::Assistant(e) => (
match e.phase {
telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(),
telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(),
telemetry_events::AssistantPhase::Accepted => {
"Assistant Response Accepted".to_string()
}
telemetry_events::AssistantPhase::Rejected => {
"Assistant Response Rejected".to_string()
}
},
serde_json::to_value(e).unwrap(),
),
Event::Cpu(_) | Event::Memory(_) => return None,
Event::App(e) => {
let mut properties = json!({});
let event_type = match e.operation.trim() {
// App
"open" => "App Opened".to_string(),
"first open" => "App First Opened".to_string(),
"first open for release channel" => {
"App First Opened For Release Channel".to_string()
}
"close" => "App Closed".to_string(),
// Project
"open project" => "Project Opened".to_string(),
"open node project" => {
properties["project_type"] = json!("node");
"Project Opened".to_string()
}
"open pnpm project" => {
properties["project_type"] = json!("pnpm");
"Project Opened".to_string()
}
"open yarn project" => {
properties["project_type"] = json!("yarn");
"Project Opened".to_string()
}
// SSH
"create ssh server" => "SSH Server Created".to_string(),
"create ssh project" => "SSH Project Created".to_string(),
"open ssh project" => "SSH Project Opened".to_string(),
// Welcome Page
"welcome page: change keymap" => "Welcome Keymap Changed".to_string(),
"welcome page: change theme" => "Welcome Theme Changed".to_string(),
"welcome page: close" => "Welcome Page Closed".to_string(),
"welcome page: edit settings" => "Welcome Settings Edited".to_string(),
"welcome page: install cli" => "Welcome CLI Installed".to_string(),
"welcome page: open" => "Welcome Page Opened".to_string(),
"welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(),
"welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(),
"welcome page: toggle diagnostic telemetry" => {
"Welcome Diagnostic Telemetry Toggled".to_string()
}
"welcome page: toggle metric telemetry" => {
"Welcome Metric Telemetry Toggled".to_string()
}
"welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(),
"welcome page: view docs" => "Welcome Documentation Viewed".to_string(),
// Extensions
"extensions page: open" => "Extensions Page Opened".to_string(),
"extensions: install extension" => "Extension Installed".to_string(),
"extensions: uninstall extension" => "Extension Uninstalled".to_string(),
// Misc
"markdown preview: open" => "Markdown Preview Opened".to_string(),
"project diagnostics: open" => "Project Diagnostics Opened".to_string(),
"project search: open" => "Project Search Opened".to_string(),
"repl sessions: open" => "REPL Session Started".to_string(),
// Feature Upsell
"feature upsell: toggle vim" => {
properties["source"] = json!("Feature Upsell");
"Vim Mode Toggled".to_string()
}
_ => e
.operation
.strip_prefix("feature upsell: viewed docs (")
.and_then(|s| s.strip_suffix(')'))
.map_or_else(
|| format!("Unknown App Event: {}", e.operation),
|docs_url| {
properties["url"] = json!(docs_url);
properties["source"] = json!("Feature Upsell");
"Documentation Viewed".to_string()
},
),
};
(event_type, properties)
}
Event::Setting(e) => (
"Settings Changed".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Extension(e) => (
"Extension Loaded".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Edit(e) => (
"Editor Edited".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Action(e) => (
"Action Invoked".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Repl(e) => (
"Kernel Status Changed".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Flexible(e) => (
e.event_type.clone(),
serde_json::to_value(&e.event_properties).unwrap(),
@ -759,7 +599,7 @@ fn for_snowflake(
})
});
Some(SnowflakeRow {
SnowflakeRow {
time: timestamp,
user_id: body.metrics_id.clone(),
device_id: body.system_id.clone(),
@ -767,7 +607,7 @@ fn for_snowflake(
event_properties,
user_properties,
insert_id: Some(Uuid::new_v4().to_string()),
})
}
})
}

View file

@ -1,5 +1,4 @@
use crate::db::{BillingCustomerId, BillingSubscriptionId};
use crate::stripe_client;
use chrono::{Datelike as _, NaiveDate, Utc};
use sea_orm::entity::prelude::*;
use serde::Serialize;
@ -160,17 +159,3 @@ pub enum StripeCancellationReason {
#[sea_orm(string_value = "payment_failed")]
PaymentFailed,
}
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
match value {
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
Self::CancellationRequested
}
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
Self::PaymentDisputed
}
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}

View file

@ -7,8 +7,6 @@ pub mod llm;
pub mod migrations;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
pub mod stripe_client;
pub mod user_backfiller;
#[cfg(test)]
@ -27,16 +25,12 @@ use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{RealStripeClient, StripeClient};
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
Http(StatusCode, String, HeaderMap),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
Stripe(stripe::StripeError),
}
impl From<anyhow::Error> for Error {
@ -51,12 +45,6 @@ impl From<sea_orm::error::DbErr> for Error {
}
}
impl From<stripe::StripeError> for Error {
fn from(error: stripe::StripeError) -> Self {
Self::Stripe(error)
}
}
impl From<axum::Error> for Error {
fn from(error: axum::Error) -> Self {
Self::Internal(error.into())
@ -104,14 +92,6 @@ impl IntoResponse for Error {
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
Error::Stripe(error) => {
log::error!(
"HTTP error {}: {:?}",
StatusCode::INTERNAL_SERVER_ERROR,
&error
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
}
}
}
@ -122,7 +102,6 @@ impl std::fmt::Debug for Error {
Error::Http(code, message, _headers) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
}
}
}
@ -133,7 +112,6 @@ impl std::fmt::Display for Error {
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
}
}
}
@ -179,7 +157,6 @@ pub struct Config {
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@ -234,7 +211,6 @@ impl Config {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
stripe_api_key: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
@ -269,11 +245,6 @@ pub struct AppState {
pub llm_db: Option<Arc<LlmDatabase>>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
/// This is a real instance of the Stripe client; we're working to replace references to this with the
/// [`StripeClient`] trait.
pub real_stripe_client: Option<Arc<stripe::Client>>,
pub stripe_client: Option<Arc<dyn StripeClient>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub executor: Executor,
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
pub config: Config,
@ -316,18 +287,11 @@ impl AppState {
};
let db = Arc::new(db);
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
let this = Self {
db: db.clone(),
llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_billing: stripe_client
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
real_stripe_client: stripe_client.clone(),
stripe_client: stripe_client
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
executor,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
@ -340,14 +304,6 @@ impl AppState {
}
}
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
let api_key = config
.stripe_api_key
.as_ref()
.context("missing stripe_api_key")?;
Ok(stripe::Client::new(api_key))
}
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
let keys = aws_sdk_s3::config::Credentials::new(
config

View file

@ -1,12 +1 @@
pub mod db;
mod token;
pub use token::*;
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
/// The name of the feature flag that bypasses the account age check.
pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check";
/// The minimum account age an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);

View file

@ -1,146 +0,0 @@
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{billing_customer, billing_subscription, user};
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
use crate::{Config, db::billing_preference};
use anyhow::{Context as _, Result};
use chrono::{NaiveDateTime, Utc};
use cloud_llm_client::Plan;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LlmTokenClaims {
pub iat: u64,
pub exp: u64,
pub jti: String,
pub user_id: u64,
pub system_id: Option<String>,
pub metrics_id: Uuid,
pub github_user_login: String,
pub account_created_at: NaiveDateTime,
pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool,
pub bypass_account_age_check: bool,
pub use_llm_request_queue: bool,
pub plan: Plan,
pub has_extended_trial: bool,
pub subscription_period: (NaiveDateTime, NaiveDateTime),
pub enable_model_request_overages: bool,
pub model_request_overages_spend_limit_in_cents: u32,
pub can_use_web_search_tool: bool,
#[serde(default)]
pub has_overdue_invoices: bool,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
pub fn create(
user: &user::Model,
is_staff: bool,
billing_customer: billing_customer::Model,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
subscription: billing_subscription::Model,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
let secret = config
.llm_api_secret
.as_ref()
.context("no LLM API secret")?;
let plan = if is_staff {
Plan::ZedPro
} else {
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
SubscriptionKind::ZedFree => Plan::ZedFree,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
})
};
let subscription_period =
billing_subscription::Model::current_period(Some(subscription), is_staff)
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
.context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
let now = Utc::now();
let claims = Self {
iat: now.timestamp() as u64,
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user.id.to_proto(),
system_id,
metrics_id: user.metrics_id,
github_user_login: user.github_login.clone(),
account_created_at: user.account_created_at(),
is_staff,
has_llm_closed_beta_feature_flag: feature_flags
.iter()
.any(|flag| flag == "llm-closed-beta"),
bypass_account_age_check: feature_flags
.iter()
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
can_use_web_search_tool: true,
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
plan,
has_extended_trial: feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
subscription_period,
enable_model_request_overages: billing_preferences
.as_ref()
.map_or(false, |preferences| {
preferences.model_request_overages_enabled
}),
model_request_overages_spend_limit_in_cents: billing_preferences
.as_ref()
.map_or(0, |preferences| {
preferences.model_request_overages_spend_limit_in_cents as u32
}),
has_overdue_invoices: billing_customer.has_overdue_invoices,
};
Ok(jsonwebtoken::encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)?)
}
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
let secret = config
.llm_api_secret
.as_ref()
.context("no LLM API secret")?;
match jsonwebtoken::decode::<Self>(
token,
&DecodingKey::from_secret(secret.as_ref()),
&Validation::default(),
) {
Ok(token) => Ok(token.claims),
Err(e) => {
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
Err(ValidateLlmTokenError::Expired)
} else {
Err(ValidateLlmTokenError::JwtError(e))
}
}
}
}
}
#[derive(Error, Debug)]
pub enum ValidateLlmTokenError {
#[error("access token is expired")]
Expired,
#[error("access token validation error: {0}")]
JwtError(#[from] jsonwebtoken::errors::Error),
#[error("{0}")]
Other(#[from] anyhow::Error),
}

View file

@ -102,13 +102,6 @@ async fn main() -> Result<()> {
let state = AppState::new(config, Executor::Production).await?;
if let Some(stripe_billing) = state.stripe_billing.clone() {
let executor = state.executor.clone();
executor.spawn_detached(async move {
stripe_billing.initialize().await.trace_err();
});
}
if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err();

View file

@ -1,14 +1,6 @@
mod connection_pool;
use crate::api::billing::find_or_create_billing_customer;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
MIN_ACCOUNT_AGE_FOR_LLM_USE,
};
use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@ -37,7 +29,6 @@ use axum::{
response::IntoResponse,
routing::get,
};
use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
@ -148,13 +139,6 @@ pub enum Principal {
}
impl Principal {
fn user(&self) -> &User {
match self {
Principal::User(user) => user,
Principal::Impersonated { user, .. } => user,
}
}
fn update_span(&self, span: &tracing::Span) {
match &self {
Principal::User(user) => {
@ -218,6 +202,7 @@ struct Session {
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
#[allow(unused)]
system_id: Option<String>,
_executor: Executor,
}
@ -463,9 +448,6 @@ impl Server {
.add_request_handler(follow)
.add_message_handler(unfollow)
.add_message_handler(update_followers)
.add_request_handler(get_private_user_info)
.add_request_handler(get_llm_api_token)
.add_request_handler(accept_terms_of_service)
.add_message_handler(acknowledge_channel_message)
.add_message_handler(acknowledge_buffer_version)
.add_request_handler(get_supermaven_api_key)
@ -1000,8 +982,6 @@ impl Server {
.await?;
}
update_user_plan(session).await?;
let contacts = self.app_state.db.get_contacts(user.id).await?;
{
@ -2835,214 +2815,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
if is_staff {
return Ok(proto::Plan::ZedPro);
}
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
proto::Plan::Free
};
Ok(plan)
}
async fn make_update_user_plan_message(
user: &User,
is_staff: bool,
db: &Arc<Database>,
llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
let feature_flags = db.get_user_flags(user.id).await?;
let plan = current_plan(db, user.id, is_staff).await?;
let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
let billing_preferences = db.get_billing_preferences(user.id).await?;
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
let subscription = db.get_active_billing_subscription(user.id).await?;
let subscription_period =
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
llm_db
.get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
.await?
} else {
None
};
(subscription_period, usage)
} else {
(None, None)
};
let bypass_account_age_check = feature_flags
.iter()
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
let account_too_young = !matches!(plan, proto::Plan::ZedPro)
&& !bypass_account_age_check
&& user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
Ok(proto::UpdateUserPlan {
plan: plan.into(),
trial_started_at: billing_customer
.as_ref()
.and_then(|billing_customer| billing_customer.trial_started_at)
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
is_usage_based_billing_enabled: if is_staff {
Some(true)
} else {
billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
},
subscription_period: subscription_period.map(|(started_at, ended_at)| {
proto::SubscriptionPeriod {
started_at: started_at.timestamp() as u64,
ended_at: ended_at.timestamp() as u64,
}
}),
account_too_young: Some(account_too_young),
has_overdue_invoices: billing_customer
.map(|billing_customer| billing_customer.has_overdue_invoices),
usage: Some(
usage
.map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
.unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
),
})
}
fn model_requests_limit(
plan: cloud_llm_client::Plan,
feature_flags: &Vec<String>,
) -> cloud_llm_client::UsageLimit {
match plan.model_requests_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
{
1_000
} else {
limit
};
cloud_llm_client::UsageLimit::Limited(limit)
}
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
}
}
fn subscription_usage_to_proto(
plan: proto::Plan,
usage: crate::llm::db::subscription_usage::Model,
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
}
}
fn make_default_subscription_usage(
plan: proto::Plan,
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: 0,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
edit_predictions_usage_amount: 0,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
}
}
async fn update_user_plan(session: &Session) -> Result<()> {
let db = session.db().await;
let update_user_plan = make_update_user_plan_message(
session.principal.user(),
session.is_staff(),
&db.0,
session.app_state.llm_db.clone(),
)
.await?;
session
.peer
.send(session.connection_id, update_user_plan)
.trace_err();
Ok(())
}
async fn subscribe_to_channels(
_: proto::SubscribeToChannels,
session: MessageContext,
@ -4211,139 +3983,6 @@ async fn mark_notification_as_read(
Ok(())
}
/// Get the current users information
async fn get_private_user_info(
_request: proto::GetPrivateUserInfo,
response: Response<proto::GetPrivateUserInfo>,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
let user = db
.get_user_by_id(session.user_id())
.await?
.context("user not found")?;
let flags = db.get_user_flags(session.user_id()).await?;
response.send(proto::GetPrivateUserInfoResponse {
metrics_id,
staff: user.admin,
flags,
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
})?;
Ok(())
}
/// Accept the terms of service (tos) on behalf of the current user
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let accepted_tos_at = Utc::now();
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
.await?;
response.send(proto::AcceptTermsOfServiceResponse {
accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
// When the user accepts the terms of service, we want to refresh their LLM
// token to grant access.
session
.peer
.send(session.connection_id, proto::RefreshLlmToken {})?;
Ok(())
}
async fn get_llm_api_token(
_request: proto::GetLlmToken,
response: Response<proto::GetLlmToken>,
session: MessageContext,
) -> Result<()> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
let user_id = session.user_id();
let user = db
.get_user_by_id(user_id)
.await?
.with_context(|| format!("user {user_id} not found"))?;
if user.accepted_tos_at.is_none() {
Err(anyhow!("terms of service not accepted"))?
}
let stripe_client = session
.app_state
.stripe_client
.as_ref()
.context("failed to retrieve Stripe client")?;
let stripe_billing = session
.app_state
.stripe_billing
.as_ref()
.context("failed to retrieve Stripe billing object")?;
let billing_customer = if let Some(billing_customer) =
db.get_billing_customer_by_user_id(user.id).await?
{
billing_customer
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?;
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
.await?
.context("billing customer not found")?
};
let billing_subscription =
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
.await?;
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
billing_customer_id: billing_customer.id,
kind: Some(SubscriptionKind::ZedFree),
stripe_subscription_id: stripe_subscription.id.to_string(),
stripe_subscription_status: stripe_subscription.status.into(),
stripe_cancellation_reason: None,
stripe_current_period_start: Some(stripe_subscription.current_period_start),
stripe_current_period_end: Some(stripe_subscription.current_period_end),
})
.await?
};
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(
&user,
session.is_staff(),
billing_customer,
billing_preferences,
&flags,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),

View file

@ -1,156 +0,0 @@
use std::sync::Arc;
use anyhow::anyhow;
use collections::HashMap;
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use crate::Result;
use crate::stripe_client::{
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
StripeSubscription,
};
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<dyn StripeClient>,
}
#[derive(Default)]
struct StripeBillingState {
prices_by_lookup_key: HashMap<String, StripePrice>,
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
state: RwLock::default(),
}
}
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
client,
state: RwLock::default(),
}
}
pub fn client(&self) -> &Arc<dyn StripeClient> {
&self.client
}
pub async fn initialize(&self) -> Result<()> {
log::info!("StripeBilling: initializing");
let mut state = self.state.write().await;
let prices = self.client.list_prices().await?;
for price in prices {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price);
}
}
log::info!("StripeBilling: initialized");
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.map(|price| price.id.clone())
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
/// not already exist.
///
/// Always returns a new Stripe customer if the email address is `None`.
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
) -> Result<StripeCustomerId> {
let existing_customer = if let Some(email) = email_address {
let customers = self.client.list_customers_by_email(email).await?;
customers.first().cloned()
} else {
None
};
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = self
.client
.create_customer(crate::stripe_client::CreateCustomerParams {
email: email_address,
})
.await?;
customer.id
};
Ok(customer_id)
}
pub async fn subscribe_to_zed_free(
&self,
customer_id: StripeCustomerId,
) -> Result<StripeSubscription> {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = self
.client
.list_subscriptions_for_customer(&customer_id)
.await?;
let existing_active_subscription =
existing_subscriptions.into_iter().find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
if let Some(subscription) = existing_active_subscription {
return Ok(subscription);
}
let params = StripeCreateSubscriptionParams {
customer: customer_id,
items: vec![StripeCreateSubscriptionItems {
price: Some(zed_free_price_id),
quantity: Some(1),
}],
automatic_tax: Some(StripeAutomaticTax { enabled: true }),
};
let subscription = self.client.create_subscription(params).await?;
Ok(subscription)
}
}

View file

@ -1,285 +0,0 @@
#[cfg(test)]
mod fake_stripe_client;
mod real_stripe_client;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
pub struct StripeCustomer {
pub id: StripeCustomerId,
pub email: Option<String>,
}
#[derive(Debug)]
pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug)]
pub struct UpdateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription {
pub id: StripeSubscriptionId,
pub customer: StripeCustomerId,
// TODO: Create our own version of this enum.
pub status: stripe::SubscriptionStatus,
pub current_period_end: i64,
pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>,
pub cancel_at: Option<i64>,
pub cancellation_details: Option<StripeCancellationDetails>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct StripeCancellationDetails {
pub reason: Option<StripeCancellationDetailsReason>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCancellationDetailsReason {
CancellationRequested,
PaymentDisputed,
PaymentFailed,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionParams {
pub customer: StripeCustomerId,
pub items: Vec<StripeCreateSubscriptionItems>,
pub automatic_tax: Option<StripeAutomaticTax>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionItems {
pub price: Option<StripePriceId>,
pub quantity: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct UpdateSubscriptionItems {
pub price: Option<StripePriceId>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettings {
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettingsEndBehavior {
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
Cancel,
CreateInvoice,
Pause,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
pub lookup_key: Option<String>,
pub recurring: Option<StripePriceRecurring>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
pub struct StripeMeterId(pub Arc<str>);
#[derive(Debug, Clone, Deserialize)]
pub struct StripeMeter {
pub id: StripeMeterId,
pub event_name: String,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventParams<'a> {
pub identifier: &'a str,
pub event_name: &'a str,
pub payload: StripeCreateMeterEventPayload<'a>,
pub timestamp: Option<i64>,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventPayload<'a> {
pub value: u64,
pub stripe_customer_id: &'a StripeCustomerId,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeBillingAddressCollection {
Auto,
Required,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCustomerUpdate {
pub address: Option<StripeCustomerUpdateAddress>,
pub name: Option<StripeCustomerUpdateName>,
pub shipping: Option<StripeCustomerUpdateShipping>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCustomerUpdateAddress {
Auto,
Never,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCustomerUpdateName {
Auto,
Never,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCustomerUpdateShipping {
Auto,
Never,
}
#[derive(Debug, Default)]
pub struct StripeCreateCheckoutSessionParams<'a> {
pub customer: Option<&'a StripeCustomerId>,
pub client_reference_id: Option<&'a str>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<&'a str>,
pub billing_address_collection: Option<StripeBillingAddressCollection>,
pub customer_update: Option<StripeCustomerUpdate>,
pub tax_id_collection: Option<StripeTaxIdCollection>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionMode {
Payment,
Setup,
Subscription,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionLineItems {
pub price: Option<String>,
pub quantity: Option<u64>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionPaymentMethodCollection {
Always,
IfRequired,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionSubscriptionData {
pub metadata: Option<HashMap<String, String>>,
pub trial_period_days: Option<u32>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeTaxIdCollection {
pub enabled: bool,
}
#[derive(Debug, Clone)]
pub struct StripeAutomaticTax {
pub enabled: bool,
}
#[derive(Debug)]
pub struct StripeCheckoutSession {
pub url: Option<String>,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn update_customer(
&self,
customer_id: &StripeCustomerId,
params: UpdateCustomerParams<'_>,
) -> Result<StripeCustomer>;
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>>;
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>;
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription>;
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()>;
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession>;
}

View file

@ -1,247 +0,0 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection,
UpdateCustomerParams, UpdateSubscriptionParams,
};
#[derive(Debug, Clone)]
pub struct StripeCreateMeterEventCall {
pub identifier: Arc<str>,
pub event_name: Arc<str>,
pub value: u64,
pub stripe_customer_id: StripeCustomerId,
pub timestamp: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
pub customer: Option<StripeCustomerId>,
pub client_reference_id: Option<String>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<String>,
pub billing_address_collection: Option<StripeBillingAddressCollection>,
pub customer_update: Option<StripeCustomerUpdate>,
pub tax_id_collection: Option<StripeTaxIdCollection>,
}
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
pub update_subscription_calls:
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
subscriptions: Arc::new(Mutex::new(HashMap::default())),
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
#[async_trait]
impl StripeClient for FakeStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
Ok(self
.customers
.lock()
.values()
.filter(|customer| customer.email.as_deref() == Some(email))
.cloned()
.collect())
}
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
self.customers
.lock()
.get(customer_id)
.cloned()
.ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = StripeCustomer {
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
email: params.email.map(|email| email.to_string()),
};
self.customers
.lock()
.insert(customer.id.clone(), customer.clone());
Ok(customer)
}
async fn update_customer(
&self,
customer_id: &StripeCustomerId,
params: UpdateCustomerParams<'_>,
) -> Result<StripeCustomer> {
let mut customers = self.customers.lock();
if let Some(customer) = customers.get_mut(customer_id) {
if let Some(email) = params.email {
customer.email = Some(email.to_string());
}
Ok(customer.clone())
} else {
Err(anyhow!("no customer found for {customer_id:?}"))
}
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let subscriptions = self
.subscriptions
.lock()
.values()
.filter(|subscription| subscription.customer == *customer_id)
.cloned()
.collect();
Ok(subscriptions)
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
self.subscriptions
.lock()
.get(subscription_id)
.cloned()
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
customer: params.customer,
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: params
.items
.into_iter()
.map(|item| StripeSubscriptionItem {
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
price: item
.price
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
})
.collect(),
cancel_at: None,
cancellation_details: None,
};
self.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
Ok(subscription)
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription = self.get_subscription(subscription_id).await?;
self.update_subscription_calls
.lock()
.push((subscription.id, params));
Ok(())
}
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
// TODO: Implement fake subscription cancellation.
let _ = subscription_id;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let prices = self.prices.lock().values().cloned().collect();
Ok(prices)
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
let meters = self.meters.lock().values().cloned().collect();
Ok(meters)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
self.create_meter_event_calls
.lock()
.push(StripeCreateMeterEventCall {
identifier: params.identifier.into(),
event_name: params.event_name.into(),
value: params.payload.value,
stripe_customer_id: params.payload.stripe_customer_id.clone(),
timestamp: params.timestamp,
});
Ok(())
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
self.create_checkout_session_calls
.lock()
.push(StripeCreateCheckoutSessionCall {
customer: params.customer.cloned(),
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
mode: params.mode,
line_items: params.line_items,
payment_method_collection: params.payment_method_collection,
subscription_data: params.subscription_data,
success_url: params.success_url.map(|url| url.to_string()),
billing_address_collection: params.billing_address_collection,
customer_update: params.customer_update,
tax_id_collection: params.tax_id_collection,
});
Ok(StripeCheckoutSession {
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
})
}
}

View file

@ -1,612 +0,0 @@
use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use stripe::{
CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price,
PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId,
UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings,
UpdateSubscriptionTrialSettingsEndBehavior,
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
};
use crate::stripe_client::{
CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection,
StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession,
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping,
StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
UpdateCustomerParams, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
}
impl RealStripeClient {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self { client }
}
}
#[async_trait]
impl StripeClient for RealStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
let response = Customer::list(
&self.client,
&ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
Ok(response
.data
.into_iter()
.map(StripeCustomer::from)
.collect())
}
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
let customer_id = customer_id.try_into()?;
let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
Ok(StripeCustomer::from(customer))
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
async fn update_customer(
&self,
customer_id: &StripeCustomerId,
params: UpdateCustomerParams<'_>,
) -> Result<StripeCustomer> {
let customer = Customer::update(
&self.client,
&customer_id.try_into()?,
UpdateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let customer_id = customer_id.try_into()?;
let subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id),
status: None,
..Default::default()
},
)
.await?;
Ok(subscriptions
.data
.into_iter()
.map(StripeSubscription::from)
.collect())
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
let subscription_id = subscription_id.try_into()?;
let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
Ok(StripeSubscription::from(subscription))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let customer_id = params.customer.try_into()?;
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
create_subscription.items = Some(
params
.items
.into_iter()
.map(|item| stripe::CreateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
quantity: item.quantity,
..Default::default()
})
.collect(),
);
create_subscription.automatic_tax = params.automatic_tax.map(Into::into);
let subscription = Subscription::create(&self.client, create_subscription).await?;
Ok(StripeSubscription::from(subscription))
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
stripe::Subscription::update(
&self.client,
&subscription_id,
stripe::UpdateSubscription {
items: params.items.map(|items| {
items
.into_iter()
.map(|item| UpdateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
..Default::default()
})
.collect()
}),
trial_settings: params.trial_settings.map(Into::into),
..Default::default()
},
)
.await?;
Ok(())
}
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
Subscription::cancel(
&self.client,
&subscription_id,
stripe::CancelSubscription {
invoice_now: None,
..Default::default()
},
)
.await?;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let response = stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
},
)
.await?;
Ok(response.data.into_iter().map(StripePrice::from).collect())
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
let response = self
.client
.get_query::<stripe::List<StripeMeter>, _>(
"/billing/meters",
Params { limit: Some(100) },
)
.await?;
Ok(response.data)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
#[derive(Deserialize)]
struct StripeMeterEvent {
pub identifier: String,
}
let identifier = params.identifier;
match self
.client
.post_form::<StripeMeterEvent, _>("/billing/meter_events", params)
.await
{
Ok(_event) => Ok(()),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(())
} else {
Err(anyhow!(stripe::StripeError::Stripe(error)))
}
}
Err(error) => Err(anyhow!("failed to create meter event: {error:?}")),
}
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
let params = params.try_into()?;
let session = CheckoutSession::create(&self.client, params).await?;
Ok(session.into())
}
}
impl From<CustomerId> for StripeCustomerId {
fn from(value: CustomerId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl TryFrom<&StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
id: value.id.into(),
email: value.email,
}
}
}
impl From<SubscriptionId> for StripeSubscriptionId {
fn from(value: SubscriptionId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
type Error = anyhow::Error;
fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
}
}
impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self {
Self {
id: value.id.into(),
customer: value.customer.id().into(),
status: value.status,
current_period_start: value.current_period_start,
current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(),
cancel_at: value.cancel_at,
cancellation_details: value.cancellation_details.map(Into::into),
}
}
}
impl From<CancellationDetails> for StripeCancellationDetails {
fn from(value: CancellationDetails) -> Self {
Self {
reason: value.reason.map(Into::into),
}
}
}
impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
fn from(value: CancellationDetailsReason) -> Self {
match value {
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}
impl From<SubscriptionItemId> for StripeSubscriptionItemId {
fn from(value: SubscriptionItemId) -> Self {
Self(value.as_str().into())
}
}
impl From<SubscriptionItem> for StripeSubscriptionItem {
fn from(value: SubscriptionItem) -> Self {
Self {
id: value.id.into(),
price: value.price.map(Into::into),
}
}
}
impl From<StripeAutomaticTax> for CreateSubscriptionAutomaticTax {
fn from(value: StripeAutomaticTax) -> Self {
Self {
enabled: value.enabled,
liability: None,
}
}
}
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for UpdateSubscriptionTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<PriceId> for StripePriceId {
fn from(value: PriceId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripePriceId> for PriceId {
type Error = anyhow::Error;
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
}
}
impl From<Price> for StripePrice {
fn from(value: Price) -> Self {
Self {
id: value.id.into(),
unit_amount: value.unit_amount,
lookup_key: value.lookup_key,
recurring: value.recurring.map(StripePriceRecurring::from),
}
}
}
impl From<Recurring> for StripePriceRecurring {
fn from(value: Recurring) -> Self {
Self { meter: value.meter }
}
}
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
type Error = anyhow::Error;
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
Ok(Self {
customer: value
.customer
.map(|customer_id| customer_id.try_into())
.transpose()?,
client_reference_id: value.client_reference_id,
mode: value.mode.map(Into::into),
line_items: value
.line_items
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
payment_method_collection: value.payment_method_collection.map(Into::into),
subscription_data: value.subscription_data.map(Into::into),
success_url: value.success_url,
billing_address_collection: value.billing_address_collection.map(Into::into),
customer_update: value.customer_update.map(Into::into),
tax_id_collection: value.tax_id_collection.map(Into::into),
..Default::default()
})
}
}
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
fn from(value: StripeCheckoutSessionMode) -> Self {
match value {
StripeCheckoutSessionMode::Payment => Self::Payment,
StripeCheckoutSessionMode::Setup => Self::Setup,
StripeCheckoutSessionMode::Subscription => Self::Subscription,
}
}
}
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
Self {
price: value.price,
quantity: value.quantity,
..Default::default()
}
}
}
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
match value {
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
}
}
}
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
Self {
trial_period_days: value.trial_period_days,
trial_settings: value.trial_settings.map(Into::into),
metadata: value.metadata,
..Default::default()
}
}
}
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<CheckoutSession> for StripeCheckoutSession {
fn from(value: CheckoutSession) -> Self {
Self { url: value.url }
}
}
impl From<StripeBillingAddressCollection> for stripe::CheckoutSessionBillingAddressCollection {
fn from(value: StripeBillingAddressCollection) -> Self {
match value {
StripeBillingAddressCollection::Auto => {
stripe::CheckoutSessionBillingAddressCollection::Auto
}
StripeBillingAddressCollection::Required => {
stripe::CheckoutSessionBillingAddressCollection::Required
}
}
}
}
impl From<StripeCustomerUpdateAddress> for stripe::CreateCheckoutSessionCustomerUpdateAddress {
fn from(value: StripeCustomerUpdateAddress) -> Self {
match value {
StripeCustomerUpdateAddress::Auto => {
stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto
}
StripeCustomerUpdateAddress::Never => {
stripe::CreateCheckoutSessionCustomerUpdateAddress::Never
}
}
}
}
impl From<StripeCustomerUpdateName> for stripe::CreateCheckoutSessionCustomerUpdateName {
fn from(value: StripeCustomerUpdateName) -> Self {
match value {
StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto,
StripeCustomerUpdateName::Never => {
stripe::CreateCheckoutSessionCustomerUpdateName::Never
}
}
}
}
impl From<StripeCustomerUpdateShipping> for stripe::CreateCheckoutSessionCustomerUpdateShipping {
fn from(value: StripeCustomerUpdateShipping) -> Self {
match value {
StripeCustomerUpdateShipping::Auto => {
stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto
}
StripeCustomerUpdateShipping::Never => {
stripe::CreateCheckoutSessionCustomerUpdateShipping::Never
}
}
}
}
impl From<StripeCustomerUpdate> for stripe::CreateCheckoutSessionCustomerUpdate {
fn from(value: StripeCustomerUpdate) -> Self {
stripe::CreateCheckoutSessionCustomerUpdate {
address: value.address.map(Into::into),
name: value.name.map(Into::into),
shipping: value.shipping.map(Into::into),
}
}
}
impl From<StripeTaxIdCollection> for stripe::CreateCheckoutSessionTaxIdCollection {
fn from(value: StripeTaxIdCollection) -> Self {
stripe::CreateCheckoutSessionTaxIdCollection {
enabled: value.enabled,
}
}
}

View file

@ -8,7 +8,6 @@ mod channel_buffer_tests;
mod channel_guest_tests;
mod channel_message_tests;
mod channel_tests;
// mod debug_panel_tests;
mod editor_tests;
mod following_tests;
mod git_tests;
@ -18,7 +17,6 @@ mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod stripe_billing_tests;
mod test_server;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};

View file

@ -1,123 +0,0 @@
use std::sync::Arc;
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
let stripe_billing = StripeBilling::test(stripe_client.clone());
(stripe_billing, stripe_client)
}
#[gpui::test]
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(1_000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
let price2 = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
let price3 = StripePrice {
id: StripePriceId("price_3".into()),
unit_amount: Some(500),
lookup_key: None,
recurring: Some(StripePriceRecurring {
meter: Some("meter_1".to_string()),
}),
};
stripe_client
.prices
.lock()
.insert(price1.id.clone(), price1);
stripe_client
.prices
.lock()
.insert(price2.id.clone(), price2);
stripe_client
.prices
.lock()
.insert(price3.id.clone(), price3);
// Initialize the billing system
stripe_billing.initialize().await.unwrap();
// Verify that prices can be found by lookup key
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
assert_eq!(zed_pro_price_id.to_string(), "price_1");
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
assert_eq!(zed_free_price_id.to_string(), "price_2");
// Verify that a price can be found by lookup key
let zed_pro_price = stripe_billing
.find_price_by_lookup_key("zed-pro")
.await
.unwrap();
assert_eq!(zed_pro_price.id.to_string(), "price_1");
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
// Verify that finding a non-existent lookup key returns an error
let result = stripe_billing
.find_price_by_lookup_key("non-existent")
.await;
assert!(result.is_err());
}
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Create a customer with an email that doesn't yet correspond to a customer.
{
let email = "user@example.com";
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
// Create a customer with an email that corresponds to an existing customer.
{
let email = "user2@example.com";
let existing_customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
assert_eq!(customer_id, existing_customer_id);
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
}

View file

@ -1,4 +1,3 @@
use crate::stripe_client::FakeStripeClient;
use crate::{
AppState, Config,
db::{NewUserParams, UserId, tests::TestDb},
@ -569,9 +568,6 @@ impl TestServer {
llm_db: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
real_stripe_client: None,
stripe_client: Some(Arc::new(FakeStripeClient::new())),
stripe_billing: None,
executor,
kinesis_client: None,
config: Config {
@ -608,7 +604,6 @@ impl TestServer {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
stripe_api_key: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View file

@ -674,7 +674,7 @@ impl ChatPanel {
})
})
.when_some(message_id, |el, message_id| {
let this = cx.entity().clone();
let this = cx.entity();
el.child(
self.render_popover_button(

View file

@ -95,7 +95,7 @@ pub fn init(cx: &mut App) {
.and_then(|room| room.read(cx).channel_id());
if let Some(channel_id) = channel_id {
let workspace = cx.entity().clone();
let workspace = cx.entity();
window.defer(cx, move |window, cx| {
ChannelView::open(channel_id, None, workspace, window, cx)
.detach_and_log_err(cx)
@ -1142,7 +1142,7 @@ impl CollabPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
let this = cx.entity().clone();
let this = cx.entity();
if !(role == proto::ChannelRole::Guest
|| role == proto::ChannelRole::Talker
|| role == proto::ChannelRole::Member)
@ -1272,7 +1272,7 @@ impl CollabPanel {
.channel_for_id(clipboard.channel_id)
.map(|channel| channel.name.clone())
});
let this = cx.entity().clone();
let this = cx.entity();
let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| {
if self.has_subchannels(ix) {
@ -1439,7 +1439,7 @@ impl CollabPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
let this = cx.entity().clone();
let this = cx.entity();
let in_room = ActiveCall::global(cx).read(cx).room().is_some();
let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| {

View file

@ -586,7 +586,7 @@ impl ChannelModalDelegate {
return;
};
let user_id = membership.user.id;
let picker = cx.entity().clone();
let picker = cx.entity();
let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| {
let role = membership.role;

View file

@ -321,7 +321,7 @@ impl NotificationPanel {
.justify_end()
.child(Button::new("decline", "Decline").on_click({
let notification = notification.clone();
let entity = cx.entity().clone();
let entity = cx.entity();
move |_, _, cx| {
entity.update(cx, |this, cx| {
this.respond_to_notification(
@ -334,7 +334,7 @@ impl NotificationPanel {
}))
.child(Button::new("accept", "Accept").on_click({
let notification = notification.clone();
let entity = cx.entity().clone();
let entity = cx.entity();
move |_, _, cx| {
entity.update(cx, |this, cx| {
this.respond_to_notification(

View file

@ -291,7 +291,7 @@ pub(crate) fn new_debugger_pane(
let Some(project) = project.upgrade() else {
return ControlFlow::Break(());
};
let this_pane = cx.entity().clone();
let this_pane = cx.entity();
let item = if tab.pane == this_pane {
pane.item_for_index(tab.ix)
} else {
@ -502,7 +502,7 @@ pub(crate) fn new_debugger_pane(
.on_drag(
DraggedTab {
item: item.boxed_clone(),
pane: cx.entity().clone(),
pane: cx.entity(),
detail: 0,
is_active: selected,
ix,

View file

@ -971,7 +971,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext)
let mut cx = EditorTestContext::new(cx).await;
let lsp_store =
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
cx.set_state(indoc! {"
ˇfn func(abc def: i32) -> u32 {
@ -1065,7 +1065,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) {
let mut cx = EditorTestContext::new(cx).await;
let lsp_store =
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
cx.set_state(indoc! {"
ˇfn func(abc def: i32) -> u32 {
@ -1239,7 +1239,7 @@ async fn test_diagnostics_with_links(cx: &mut TestAppContext) {
}
"});
let lsp_store =
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
cx.update(|_, cx| {
lsp_store.update(cx, |lsp_store, cx| {
@ -1293,7 +1293,7 @@ async fn test_hover_diagnostic_and_info_popovers(cx: &mut gpui::TestAppContext)
fn «test»() { println!(); }
"});
let lsp_store =
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
cx.update(|_, cx| {
lsp_store.update(cx, |lsp_store, cx| {
lsp_store.update_diagnostics(
@ -1450,7 +1450,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) {
let mut cx = EditorTestContext::new(cx).await;
let lsp_store =
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
cx.set_state(indoc! {"error warning info hiˇnt"});

View file

@ -127,7 +127,7 @@ impl Render for EditPredictionButton {
}),
);
}
let this = cx.entity().clone();
let this = cx.entity();
div().child(
PopoverMenu::new("copilot")
@ -182,7 +182,7 @@ impl Render for EditPredictionButton {
let icon = status.to_icon();
let tooltip_text = status.to_tooltip();
let has_menu = status.has_menu();
let this = cx.entity().clone();
let this = cx.entity();
let fs = self.fs.clone();
return div().child(
@ -331,7 +331,7 @@ impl Render for EditPredictionButton {
})
});
let this = cx.entity().clone();
let this = cx.entity();
let mut popover_menu = PopoverMenu::new("zeta")
.menu(move |window, cx| {

View file

@ -1039,9 +1039,7 @@ pub struct Editor {
inline_diagnostics: Vec<(Anchor, InlineDiagnostic)>,
soft_wrap_mode_override: Option<language_settings::SoftWrap>,
hard_wrap: Option<usize>,
// TODO: make this a access method
pub project: Option<Entity<Project>>,
project: Option<Entity<Project>>,
semantics_provider: Option<Rc<dyn SemanticsProvider>>,
completion_provider: Option<Rc<dyn CompletionProvider>>,
collaboration_hub: Option<Box<dyn CollaborationHub>>,
@ -2326,7 +2324,7 @@ impl Editor {
editor.go_to_active_debug_line(window, cx);
if let Some(buffer) = buffer.read(cx).as_singleton() {
if let Some(project) = editor.project.as_ref() {
if let Some(project) = editor.project() {
let handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
@ -2371,6 +2369,34 @@ impl Editor {
.is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window))
}
pub fn is_range_selected(&mut self, range: &Range<Anchor>, cx: &mut Context<Self>) -> bool {
if self
.selections
.pending
.as_ref()
.is_some_and(|pending_selection| {
let snapshot = self.buffer().read(cx).snapshot(cx);
pending_selection
.selection
.range()
.includes(&range, &snapshot)
})
{
return true;
}
self.selections
.disjoint_in_range::<usize>(range.clone(), cx)
.into_iter()
.any(|selection| {
// This is needed to cover a corner case, if we just check for an existing
// selection in the fold range, having a cursor at the start of the fold
// marks it as selected. Non-empty selections don't cause this.
let length = selection.end - selection.start;
length > 0
})
}
pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext {
self.key_context_internal(self.has_active_edit_prediction(), window, cx)
}
@ -2626,6 +2652,10 @@ impl Editor {
&self.buffer
}
pub fn project(&self) -> Option<&Entity<Project>> {
self.project.as_ref()
}
pub fn workspace(&self) -> Option<Entity<Workspace>> {
self.workspace.as_ref()?.0.upgrade()
}
@ -5212,7 +5242,7 @@ impl Editor {
restrict_to_languages: Option<&HashSet<Arc<Language>>>,
cx: &mut Context<Editor>,
) -> HashMap<ExcerptId, (Entity<Buffer>, clock::Global, Range<usize>)> {
let Some(project) = self.project.as_ref() else {
let Some(project) = self.project() else {
return HashMap::default();
};
let project = project.read(cx);
@ -5294,7 +5324,7 @@ impl Editor {
return None;
}
let project = self.project.as_ref()?;
let project = self.project()?;
let position = self.selections.newest_anchor().head();
let (buffer, buffer_position) = self
.buffer
@ -6141,7 +6171,7 @@ impl Editor {
cx: &mut App,
) -> Task<Vec<task::DebugScenario>> {
maybe!({
let project = self.project.as_ref()?;
let project = self.project()?;
let dap_store = project.read(cx).dap_store();
let mut scenarios = vec![];
let resolved_tasks = resolved_tasks.as_ref()?;
@ -7907,7 +7937,7 @@ impl Editor {
let snapshot = self.snapshot(window, cx);
let multi_buffer_snapshot = &snapshot.display_snapshot.buffer_snapshot;
let Some(project) = self.project.as_ref() else {
let Some(project) = self.project() else {
return breakpoint_display_points;
};
@ -10501,7 +10531,7 @@ impl Editor {
) {
if let Some(working_directory) = self.active_excerpt(cx).and_then(|(_, buffer, _)| {
let project_path = buffer.read(cx).project_path(cx)?;
let project = self.project.as_ref()?.read(cx);
let project = self.project()?.read(cx);
let entry = project.entry_for_path(&project_path, cx)?;
let parent = match &entry.canonical_path {
Some(canonical_path) => canonical_path.to_path_buf(),
@ -14875,7 +14905,7 @@ impl Editor {
self.clear_tasks();
return Task::ready(());
}
let project = self.project.as_ref().map(Entity::downgrade);
let project = self.project().map(Entity::downgrade);
let task_sources = self.lsp_task_sources(cx);
let multi_buffer = self.buffer.downgrade();
cx.spawn_in(window, async move |editor, cx| {
@ -17054,7 +17084,7 @@ impl Editor {
if !pull_diagnostics_settings.enabled {
return None;
}
let project = self.project.as_ref()?.downgrade();
let project = self.project()?.downgrade();
let debounce = Duration::from_millis(pull_diagnostics_settings.debounce_ms);
let mut buffers = self.buffer.read(cx).all_buffers();
if let Some(buffer_id) = buffer_id {
@ -18018,7 +18048,7 @@ impl Editor {
hunks: impl Iterator<Item = MultiBufferDiffHunk>,
cx: &mut App,
) -> Option<()> {
let project = self.project.as_ref()?;
let project = self.project()?;
let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?;
let diff = self.buffer.read(cx).diff_for(buffer_id)?;
let buffer_snapshot = buffer.read(cx).snapshot();
@ -18678,7 +18708,7 @@ impl Editor {
self.active_excerpt(cx).and_then(|(_, buffer, _)| {
let buffer = buffer.read(cx);
if let Some(project_path) = buffer.project_path(cx) {
let project = self.project.as_ref()?.read(cx);
let project = self.project()?.read(cx);
project.absolute_path(&project_path, cx)
} else {
buffer
@ -18691,7 +18721,7 @@ impl Editor {
fn target_file_path(&self, cx: &mut Context<Self>) -> Option<PathBuf> {
self.active_excerpt(cx).and_then(|(_, buffer, _)| {
let project_path = buffer.read(cx).project_path(cx)?;
let project = self.project.as_ref()?.read(cx);
let project = self.project()?.read(cx);
let entry = project.entry_for_path(&project_path, cx)?;
let path = entry.path.to_path_buf();
Some(path)
@ -18912,7 +18942,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(project) = self.project.as_ref() {
if let Some(project) = self.project() {
let Some(buffer) = self.buffer().read(cx).as_singleton() else {
return;
};
@ -19028,7 +19058,7 @@ impl Editor {
return Task::ready(Err(anyhow!("failed to determine buffer and selection")));
};
let Some(project) = self.project.as_ref() else {
let Some(project) = self.project() else {
return Task::ready(Err(anyhow!("editor does not have project")));
};
@ -21015,7 +21045,7 @@ impl Editor {
cx: &mut Context<Self>,
) {
let workspace = self.workspace();
let project = self.project.as_ref();
let project = self.project();
let save_tasks = self.buffer().update(cx, |multi_buffer, cx| {
let mut tasks = Vec::new();
for (buffer_id, changes) in revert_changes {

View file

@ -74,7 +74,7 @@ fn test_edit_events(cx: &mut TestAppContext) {
let editor1 = cx.add_window({
let events = events.clone();
|window, cx| {
let entity = cx.entity().clone();
let entity = cx.entity();
cx.subscribe_in(
&entity,
window,
@ -95,7 +95,7 @@ fn test_edit_events(cx: &mut TestAppContext) {
let events = events.clone();
|window, cx| {
cx.subscribe_in(
&cx.entity().clone(),
&cx.entity(),
window,
move |_, _, event: &EditorEvent, _, _| match event {
EditorEvent::Edited { .. } => events.borrow_mut().push(("editor2", "edited")),
@ -15082,7 +15082,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu
let mut cx = EditorTestContext::new(cx).await;
let lsp_store =
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
cx.set_state(indoc! {"
ˇfn func(abc def: i32) -> u32 {
@ -19634,13 +19634,8 @@ fn test_crease_insertion_and_rendering(cx: &mut TestAppContext) {
editor.insert_creases(Some(crease), cx);
let snapshot = editor.snapshot(window, cx);
let _div = snapshot.render_crease_toggle(
MultiBufferRow(1),
false,
cx.entity().clone(),
window,
cx,
);
let _div =
snapshot.render_crease_toggle(MultiBufferRow(1), false, cx.entity(), window, cx);
snapshot
})
.unwrap();

View file

@ -7818,7 +7818,7 @@ impl Element for EditorElement {
min_lines,
max_lines,
} => {
let editor_handle = cx.entity().clone();
let editor_handle = cx.entity();
let max_line_number_width =
self.max_line_number_width(&editor.snapshot(window, cx), window);
window.request_measured_layout(

View file

@ -250,7 +250,7 @@ fn show_hover(
let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?;
let language_registry = editor.project.as_ref()?.read(cx).languages().clone();
let language_registry = editor.project()?.read(cx).languages().clone();
let provider = editor.semantics_provider.clone()?;
if !ignore_timeout {

View file

@ -678,7 +678,7 @@ impl Item for Editor {
let buffer = buffer.read(cx);
let path = buffer.project_path(cx)?;
let buffer_id = buffer.remote_id();
let project = self.project.as_ref()?.read(cx);
let project = self.project()?.read(cx);
let entry = project.entry_for_path(&path, cx)?;
let (repo, repo_path) = project
.git_store()

View file

@ -51,7 +51,7 @@ pub(super) fn refresh_linked_ranges(
if editor.pending_rename.is_some() {
return None;
}
let project = editor.project.as_ref()?.downgrade();
let project = editor.project()?.downgrade();
editor.linked_editing_range_task = Some(cx.spawn_in(window, async move |editor, cx| {
cx.background_executor().timer(UPDATE_DEBOUNCE).await;

View file

@ -169,7 +169,7 @@ impl Editor {
else {
return;
};
let Some(lsp_store) = self.project.as_ref().map(|p| p.read(cx).lsp_store()) else {
let Some(lsp_store) = self.project().map(|p| p.read(cx).lsp_store()) else {
return;
};
let task = lsp_store.update(cx, |lsp_store, cx| {

View file

@ -297,9 +297,8 @@ impl EditorTestContext {
pub fn set_head_text(&mut self, diff_base: &str) {
self.cx.run_until_parked();
let fs = self.update_editor(|editor, _, cx| {
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
});
let fs =
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone());
fs.set_head_for_repo(
&Self::root_path().join(".git"),
@ -311,18 +310,16 @@ impl EditorTestContext {
pub fn clear_index_text(&mut self) {
self.cx.run_until_parked();
let fs = self.update_editor(|editor, _, cx| {
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
});
let fs =
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
fs.set_index_for_repo(&Self::root_path().join(".git"), &[]);
self.cx.run_until_parked();
}
pub fn set_index_text(&mut self, diff_base: &str) {
self.cx.run_until_parked();
let fs = self.update_editor(|editor, _, cx| {
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
});
let fs =
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone());
fs.set_index_for_repo(
&Self::root_path().join(".git"),
@ -333,9 +330,8 @@ impl EditorTestContext {
#[track_caller]
pub fn assert_index_text(&mut self, expected: Option<&str>) {
let fs = self.update_editor(|editor, _, cx| {
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
});
let fs =
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone());
let mut found = None;
fs.with_git_state(&Self::root_path().join(".git"), false, |git_state| {

View file

@ -701,7 +701,7 @@ impl ExtensionsPage {
extension: &ExtensionMetadata,
cx: &mut Context<Self>,
) -> ExtensionCard {
let this = cx.entity().clone();
let this = cx.entity();
let status = Self::extension_status(&extension.id, cx);
let has_dev_extension = Self::dev_extension_exists(&extension.id, cx);

View file

@ -112,7 +112,7 @@ fn excerpt_for_buffer_updated(
}
fn buffer_added(editor: &mut Editor, buffer: Entity<Buffer>, cx: &mut Context<Editor>) {
let Some(project) = &editor.project else {
let Some(project) = editor.project() else {
return;
};
let git_store = project.read(cx).git_store().clone();
@ -469,7 +469,7 @@ pub(crate) fn resolve_conflict(
let Some((workspace, project, multibuffer, buffer)) = editor
.update(cx, |editor, cx| {
let workspace = editor.workspace()?;
let project = editor.project.clone()?;
let project = editor.project()?.clone();
let multibuffer = editor.buffer().clone();
let buffer_id = resolved_conflict.ours.end.buffer_id?;
let buffer = multibuffer.read(cx).buffer(buffer_id)?;

View file

@ -3246,7 +3246,7 @@ impl GitPanel {
* MAX_PANEL_EDITOR_LINES
+ gap;
let git_panel = cx.entity().clone();
let git_panel = cx.entity();
let display_name = SharedString::from(Arc::from(
active_repository
.read(cx)

View file

@ -595,9 +595,7 @@ impl Render for TextInput {
.w_full()
.p(px(4.))
.bg(white())
.child(TextElement {
input: cx.entity().clone(),
}),
.child(TextElement { input: cx.entity() }),
)
}
}

View file

@ -1358,7 +1358,7 @@ impl Render for LspLogToolbarItemView {
})
.collect();
let log_toolbar_view = cx.entity().clone();
let log_toolbar_view = cx.entity();
let lsp_menu = PopoverMenu::new("LspLogView")
.anchor(Corner::TopLeft)

View file

@ -1007,7 +1007,7 @@ impl Render for LspTool {
(None, "All Servers Operational")
};
let lsp_tool = cx.entity().clone();
let lsp_tool = cx.entity();
div().child(
PopoverMenu::new("lsp-tool")

View file

@ -456,7 +456,7 @@ impl SyntaxTreeToolbarItemView {
let active_layer = buffer_state.active_layer.clone()?;
let active_buffer = buffer_state.buffer.read(cx).snapshot();
let view = cx.entity().clone();
let view = cx.entity();
Some(
PopoverMenu::new("Syntax Tree")
.trigger(Self::render_header(&active_layer))

View file

@ -4655,20 +4655,15 @@ impl OutlinePanel {
.when(show_indent_guides, |list| {
list.with_decoration(
ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx))
.with_compute_indents_fn(
cx.entity().clone(),
|outline_panel, range, _, _| {
.with_compute_indents_fn(cx.entity(), |outline_panel, range, _, _| {
let entries = outline_panel.cached_entries.get(range);
if let Some(entries) = entries {
entries.into_iter().map(|item| item.depth).collect()
} else {
smallvec::SmallVec::new()
}
},
)
.with_render_fn(
cx.entity().clone(),
move |outline_panel, params, _, _| {
})
.with_render_fn(cx.entity(), move |outline_panel, params, _, _| {
const LEFT_OFFSET: Pixels = px(14.);
let indent_size = params.indent_size;
@ -4698,8 +4693,7 @@ impl OutlinePanel {
}
})
.collect()
},
),
}),
)
})
.custom_scrollbars(

View file

@ -5187,9 +5187,7 @@ impl Render for ProjectPanel {
.when(show_indent_guides, |list| {
list.with_decoration(
ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx))
.with_compute_indents_fn(
cx.entity().clone(),
|this, range, window, cx| {
.with_compute_indents_fn(cx.entity(), |this, range, window, cx| {
let mut items =
SmallVec::with_capacity(range.end - range.start);
this.iter_visible_entries(
@ -5197,16 +5195,14 @@ impl Render for ProjectPanel {
window,
cx,
|entry, _, entries, _, _| {
let (depth, _) =
Self::calculate_depth_and_difference(
let (depth, _) = Self::calculate_depth_and_difference(
entry, entries,
);
items.push(depth);
},
);
items
},
)
})
.on_click(cx.listener(
|this, active_indent_guide: &IndentGuideLayout, window, cx| {
if window.modifiers().secondary() {
@ -5230,7 +5226,7 @@ impl Render for ProjectPanel {
}
},
))
.with_render_fn(cx.entity().clone(), move |this, params, _, cx| {
.with_render_fn(cx.entity(), move |this, params, _, cx| {
const LEFT_OFFSET: Pixels = px(14.);
const PADDING_Y: Pixels = px(4.);
const HITBOX_OVERDRAW: Pixels = px(3.);
@ -5283,7 +5279,7 @@ impl Render for ProjectPanel {
})
.when(show_sticky_entries, |list| {
let sticky_items = ui::sticky_items(
cx.entity().clone(),
cx.entity(),
|this, range, window, cx| {
let mut items = SmallVec::with_capacity(range.end - range.start);
this.iter_visible_entries(
@ -5310,7 +5306,7 @@ impl Render for ProjectPanel {
list.with_decoration(if show_indent_guides {
sticky_items.with_decoration(
ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx))
.with_render_fn(cx.entity().clone(), move |_, params, _, _| {
.with_render_fn(cx.entity(), move |_, params, _, _| {
const LEFT_OFFSET: Pixels = px(14.);
let indent_size = params.indent_size;

View file

@ -158,14 +158,6 @@ message SynchronizeContextsResponse {
repeated ContextVersion contexts = 1;
}
message GetLlmToken {}
message GetLlmTokenResponse {
string token = 1;
}
message RefreshLlmToken {}
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;

View file

@ -6,62 +6,6 @@ message UpdateInviteInfo {
uint32 count = 2;
}
message GetPrivateUserInfo {}
message GetPrivateUserInfoResponse {
string metrics_id = 1;
bool staff = 2;
repeated string flags = 3;
optional uint64 accepted_tos_at = 4;
}
enum Plan {
Free = 0;
ZedPro = 1;
ZedProTrial = 2;
}
message UpdateUserPlan {
Plan plan = 1;
optional uint64 trial_started_at = 2;
optional bool is_usage_based_billing_enabled = 3;
optional SubscriptionUsage usage = 4;
optional SubscriptionPeriod subscription_period = 5;
optional bool account_too_young = 6;
optional bool has_overdue_invoices = 7;
}
message SubscriptionPeriod {
uint64 started_at = 1;
uint64 ended_at = 2;
}
message SubscriptionUsage {
uint32 model_requests_usage_amount = 1;
UsageLimit model_requests_usage_limit = 2;
uint32 edit_predictions_usage_amount = 3;
UsageLimit edit_predictions_usage_limit = 4;
}
message UsageLimit {
oneof variant {
Limited limited = 1;
Unlimited unlimited = 2;
}
message Limited {
uint32 limit = 1;
}
message Unlimited {}
}
message AcceptTermsOfService {}
message AcceptTermsOfServiceResponse {
uint64 accepted_tos_at = 1;
}
message ShutdownRemoteServer {}
message Toast {

View file

@ -135,12 +135,7 @@ message Envelope {
FollowResponse follow_response = 99;
UpdateFollowers update_followers = 100;
Unfollow unfollow = 101;
GetPrivateUserInfo get_private_user_info = 102;
GetPrivateUserInfoResponse get_private_user_info_response = 103;
UpdateUserPlan update_user_plan = 234;
UpdateDiffBases update_diff_bases = 104;
AcceptTermsOfService accept_terms_of_service = 239;
AcceptTermsOfServiceResponse accept_terms_of_service_response = 240;
OnTypeFormatting on_type_formatting = 105;
OnTypeFormattingResponse on_type_formatting_response = 106;
@ -250,10 +245,6 @@ message Envelope {
AddWorktree add_worktree = 222;
AddWorktreeResponse add_worktree_response = 223;
GetLlmToken get_llm_token = 235;
GetLlmTokenResponse get_llm_token_response = 236;
RefreshLlmToken refresh_llm_token = 259;
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
@ -406,6 +397,7 @@ message Envelope {
}
reserved 87 to 88;
reserved 102 to 103;
reserved 158 to 161;
reserved 164;
reserved 166 to 169;
@ -419,10 +411,13 @@ message Envelope {
reserved 221;
reserved 224 to 229;
reserved 230 to 231;
reserved 234 to 236;
reserved 239 to 240;
reserved 246;
reserved 270;
reserved 247 to 254;
reserved 255 to 256;
reserved 259;
reserved 270;
reserved 280 to 281;
reserved 332 to 333;
}

View file

@ -20,8 +20,6 @@ pub const SSH_PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 };
pub const SSH_PROJECT_ID: u64 = 0;
messages!(
(AcceptTermsOfService, Foreground),
(AcceptTermsOfServiceResponse, Foreground),
(Ack, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
@ -105,8 +103,6 @@ messages!(
(GetPathMetadataResponse, Background),
(GetPermalinkToLine, Foreground),
(GetPermalinkToLineResponse, Foreground),
(GetPrivateUserInfo, Foreground),
(GetPrivateUserInfoResponse, Foreground),
(GetProjectSymbols, Background),
(GetProjectSymbolsResponse, Background),
(GetReferences, Background),
@ -119,8 +115,6 @@ messages!(
(GetTypeDefinitionResponse, Background),
(GetImplementation, Background),
(GetImplementationResponse, Background),
(GetLlmToken, Background),
(GetLlmTokenResponse, Background),
(OpenUnstagedDiff, Foreground),
(OpenUnstagedDiffResponse, Foreground),
(OpenUncommittedDiff, Foreground),
@ -196,7 +190,6 @@ messages!(
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
(RefreshInlayHints, Foreground),
(RefreshLlmToken, Background),
(RegisterBufferWithLanguageServers, Background),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@ -280,7 +273,6 @@ messages!(
(UpdateProject, Foreground),
(UpdateProjectCollaborator, Foreground),
(UpdateUserChannels, Foreground),
(UpdateUserPlan, Foreground),
(UpdateWorktree, Foreground),
(UpdateWorktreeSettings, Foreground),
(UpdateRepository, Foreground),
@ -321,7 +313,6 @@ messages!(
);
request_messages!(
(AcceptTermsOfService, AcceptTermsOfServiceResponse),
(ApplyCodeAction, ApplyCodeActionResponse),
(
ApplyCompletionAdditionalEdits,
@ -354,9 +345,7 @@ request_messages!(
(GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetDocumentSymbols, GetDocumentSymbolsResponse),
(GetHover, GetHoverResponse),
(GetLlmToken, GetLlmTokenResponse),
(GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),
(GetReferences, GetReferencesResponse),
(GetSignatureHelp, GetSignatureHelpResponse),

View file

@ -1291,7 +1291,7 @@ impl RemoteServerProjects {
let connection_string = connection_string.clone();
move |_, _: &menu::Confirm, window, cx| {
remove_ssh_server(
cx.entity().clone(),
cx.entity(),
server_index,
connection_string.clone(),
window,
@ -1311,7 +1311,7 @@ impl RemoteServerProjects {
.child(Label::new("Remove Server").color(Color::Error))
.on_click(cx.listener(move |_, _, window, cx| {
remove_ssh_server(
cx.entity().clone(),
cx.entity(),
server_index,
connection_string.clone(),
window,

View file

@ -244,7 +244,7 @@ impl Session {
repl_session_id = cx.entity_id().to_string(),
);
let session_view = cx.entity().clone();
let session_view = cx.entity();
let kernel = match self.kernel_specification.clone() {
KernelSpecification::Jupyter(kernel_specification)

View file

@ -2,9 +2,9 @@ mod registrar;
use crate::{
FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOption,
SearchOptions, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive,
ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord,
search_bar::{input_base_styles, render_action_button, render_text_input},
SearchOptions, SearchSource, SelectAllMatches, SelectNextMatch, SelectPreviousMatch,
ToggleCaseSensitive, ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord,
search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input},
};
use any_vec::AnyVec;
use anyhow::Context as _;
@ -213,22 +213,25 @@ impl Render for BufferSearchBar {
h_flex()
.gap_1()
.when(case, |div| {
div.child(
SearchOption::CaseSensitive
.as_button(self.search_options, focus_handle.clone()),
)
div.child(SearchOption::CaseSensitive.as_button(
self.search_options,
SearchSource::Buffer,
focus_handle.clone(),
))
})
.when(word, |div| {
div.child(
SearchOption::WholeWord
.as_button(self.search_options, focus_handle.clone()),
)
div.child(SearchOption::WholeWord.as_button(
self.search_options,
SearchSource::Buffer,
focus_handle.clone(),
))
})
.when(regex, |div| {
div.child(
SearchOption::Regex
.as_button(self.search_options, focus_handle.clone()),
)
div.child(SearchOption::Regex.as_button(
self.search_options,
SearchSource::Buffer,
focus_handle.clone(),
))
}),
)
});
@ -240,7 +243,7 @@ impl Render for BufferSearchBar {
this.child(render_action_button(
"buffer-search-bar-toggle",
IconName::Replace,
self.replace_enabled,
self.replace_enabled.then_some(ActionButtonState::Toggled),
"Toggle Replace",
&ToggleReplace,
focus_handle.clone(),
@ -285,7 +288,9 @@ impl Render for BufferSearchBar {
.child(render_action_button(
"buffer-search-nav-button",
ui::IconName::ChevronLeft,
self.active_match_index.is_some(),
self.active_match_index
.is_none()
.then_some(ActionButtonState::Disabled),
"Select Previous Match",
&SelectPreviousMatch,
query_focus.clone(),
@ -293,7 +298,9 @@ impl Render for BufferSearchBar {
.child(render_action_button(
"buffer-search-nav-button",
ui::IconName::ChevronRight,
self.active_match_index.is_some(),
self.active_match_index
.is_none()
.then_some(ActionButtonState::Disabled),
"Select Next Match",
&SelectNextMatch,
query_focus.clone(),
@ -313,7 +320,7 @@ impl Render for BufferSearchBar {
el.child(render_action_button(
"buffer-search-nav-button",
IconName::SelectAll,
true,
Default::default(),
"Select All Matches",
&SelectAllMatches,
query_focus,
@ -324,7 +331,7 @@ impl Render for BufferSearchBar {
el.child(render_action_button(
"buffer-search",
IconName::Close,
true,
Default::default(),
"Close Search Bar",
&Dismiss,
focus_handle.clone(),
@ -352,7 +359,7 @@ impl Render for BufferSearchBar {
.child(render_action_button(
"buffer-search-replace-button",
IconName::ReplaceNext,
true,
Default::default(),
"Replace Next Match",
&ReplaceNext,
focus_handle.clone(),
@ -360,7 +367,7 @@ impl Render for BufferSearchBar {
.child(render_action_button(
"buffer-search-replace-button",
IconName::ReplaceAll,
true,
Default::default(),
"Replace All Matches",
&ReplaceAll,
focus_handle,
@ -394,7 +401,7 @@ impl Render for BufferSearchBar {
div.child(h_flex().absolute().right_0().child(render_action_button(
"buffer-search",
IconName::Close,
true,
Default::default(),
"Close Search Bar",
&Dismiss,
focus_handle.clone(),

View file

@ -1,9 +1,9 @@
use crate::{
BufferSearchBar, FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext,
SearchOption, SearchOptions, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive,
ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord,
SearchOption, SearchOptions, SearchSource, SelectNextMatch, SelectPreviousMatch,
ToggleCaseSensitive, ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord,
buffer_search::Deploy,
search_bar::{input_base_styles, render_action_button, render_text_input},
search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input},
};
use anyhow::Context as _;
use collections::HashMap;
@ -1665,7 +1665,7 @@ impl ProjectSearchBar {
});
}
fn toggle_search_option(
pub(crate) fn toggle_search_option(
&mut self,
option: SearchOptions,
window: &mut Window,
@ -1962,17 +1962,21 @@ impl Render for ProjectSearchBar {
.child(
h_flex()
.gap_1()
.child(
SearchOption::CaseSensitive
.as_button(search.search_options, focus_handle.clone()),
)
.child(
SearchOption::WholeWord
.as_button(search.search_options, focus_handle.clone()),
)
.child(
SearchOption::Regex.as_button(search.search_options, focus_handle.clone()),
),
.child(SearchOption::CaseSensitive.as_button(
search.search_options,
SearchSource::Project(cx),
focus_handle.clone(),
))
.child(SearchOption::WholeWord.as_button(
search.search_options,
SearchSource::Project(cx),
focus_handle.clone(),
))
.child(SearchOption::Regex.as_button(
search.search_options,
SearchSource::Project(cx),
focus_handle.clone(),
)),
);
let query_focus = search.query_editor.focus_handle(cx);
@ -1985,7 +1989,10 @@ impl Render for ProjectSearchBar {
.child(render_action_button(
"project-search-nav-button",
IconName::ChevronLeft,
search.active_match_index.is_some(),
search
.active_match_index
.is_none()
.then_some(ActionButtonState::Disabled),
"Select Previous Match",
&SelectPreviousMatch,
query_focus.clone(),
@ -1993,7 +2000,10 @@ impl Render for ProjectSearchBar {
.child(render_action_button(
"project-search-nav-button",
IconName::ChevronRight,
search.active_match_index.is_some(),
search
.active_match_index
.is_none()
.then_some(ActionButtonState::Disabled),
"Select Next Match",
&SelectNextMatch,
query_focus,
@ -2054,7 +2064,7 @@ impl Render for ProjectSearchBar {
self.active_project_search
.as_ref()
.map(|search| search.read(cx).replace_enabled)
.unwrap_or_default(),
.and_then(|enabled| enabled.then_some(ActionButtonState::Toggled)),
"Toggle Replace",
&ToggleReplace,
focus_handle.clone(),
@ -2079,7 +2089,7 @@ impl Render for ProjectSearchBar {
.child(render_action_button(
"project-search-replace-button",
IconName::ReplaceNext,
true,
Default::default(),
"Replace Next Match",
&ReplaceNext,
focus_handle.clone(),
@ -2087,7 +2097,7 @@ impl Render for ProjectSearchBar {
.child(render_action_button(
"project-search-replace-button",
IconName::ReplaceAll,
true,
Default::default(),
"Replace All Matches",
&ReplaceAll,
focus_handle,
@ -2129,10 +2139,11 @@ impl Render for ProjectSearchBar {
this.toggle_opened_only(window, cx);
})),
)
.child(
SearchOption::IncludeIgnored
.as_button(search.search_options, focus_handle.clone()),
);
.child(SearchOption::IncludeIgnored.as_button(
search.search_options,
SearchSource::Project(cx),
focus_handle.clone(),
));
h_flex()
.w_full()
.gap_2()

View file

@ -1,7 +1,7 @@
use bitflags::bitflags;
pub use buffer_search::BufferSearchBar;
use editor::SearchSettings;
use gpui::{Action, App, FocusHandle, IntoElement, actions};
use gpui::{Action, App, ClickEvent, FocusHandle, IntoElement, actions};
use project::search::SearchQuery;
pub use project_search::ProjectSearchView;
use ui::{ButtonStyle, IconButton, IconButtonShape};
@ -11,6 +11,8 @@ use workspace::{Toast, Workspace};
pub use search_status_button::SEARCH_ICON;
use crate::project_search::ProjectSearchBar;
pub mod buffer_search;
pub mod project_search;
pub(crate) mod search_bar;
@ -83,9 +85,14 @@ pub enum SearchOption {
Backwards,
}
pub(crate) enum SearchSource<'a, 'b> {
Buffer,
Project(&'a Context<'b, ProjectSearchBar>),
}
impl SearchOption {
pub fn as_options(self) -> SearchOptions {
SearchOptions::from_bits(1 << self as u8).unwrap()
pub fn as_options(&self) -> SearchOptions {
SearchOptions::from_bits(1 << *self as u8).unwrap()
}
pub fn label(&self) -> &'static str {
@ -119,17 +126,33 @@ impl SearchOption {
}
}
pub fn as_button(&self, active: SearchOptions, focus_handle: FocusHandle) -> impl IntoElement {
pub(crate) fn as_button(
&self,
active: SearchOptions,
search_source: SearchSource,
focus_handle: FocusHandle,
) -> impl IntoElement {
let action = self.to_toggle_action();
let label = self.label();
IconButton::new(label, self.icon())
.on_click({
IconButton::new(
(label, matches!(search_source, SearchSource::Buffer) as u32),
self.icon(),
)
.map(|button| match search_source {
SearchSource::Buffer => {
let focus_handle = focus_handle.clone();
move |_, window, cx| {
button.on_click(move |_: &ClickEvent, window, cx| {
if !focus_handle.is_focused(&window) {
window.focus(&focus_handle);
}
window.dispatch_action(action.boxed_clone(), cx)
window.dispatch_action(action.boxed_clone(), cx);
})
}
SearchSource::Project(cx) => {
let options = self.as_options();
button.on_click(cx.listener(move |this, _: &ClickEvent, window, cx| {
this.toggle_search_option(options, window, cx);
}))
}
})
.style(ButtonStyle::Subtle)

View file

@ -5,10 +5,15 @@ use theme::ThemeSettings;
use ui::{IconButton, IconButtonShape};
use ui::{Tooltip, prelude::*};
pub(super) enum ActionButtonState {
Disabled,
Toggled,
}
pub(super) fn render_action_button(
id_prefix: &'static str,
icon: ui::IconName,
active: bool,
button_state: Option<ActionButtonState>,
tooltip: &'static str,
action: &'static dyn Action,
focus_handle: FocusHandle,
@ -28,7 +33,10 @@ pub(super) fn render_action_button(
}
})
.tooltip(move |window, cx| Tooltip::for_action_in(tooltip, action, &focus_handle, window, cx))
.disabled(!active)
.when_some(button_state, |this, state| match state {
ActionButtonState::Toggled => this.toggle_state(true),
ActionButtonState::Disabled => this.disabled(true),
})
}
pub(crate) fn input_base_styles(border_color: Hsla, map: impl FnOnce(Div) -> Div) -> Div {

View file

@ -2177,6 +2177,7 @@ impl KeybindingEditorModal {
let value = action_arguments
.as_ref()
.filter(|args| !args.is_empty())
.map(|args| {
serde_json::from_str(args).context("Failed to parse action arguments as JSON")
})

View file

@ -65,7 +65,7 @@ impl Render for IndentGuidesStory {
},
)
.with_compute_indents_fn(
cx.entity().clone(),
cx.entity(),
|this, range, _cx, _context| {
this.depths
.iter()

View file

@ -2,7 +2,7 @@
use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fmt::Display, sync::Arc, time::Duration};
use std::{collections::HashMap, fmt::Display, time::Duration};
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EventRequestBody {
@ -93,19 +93,6 @@ impl Display for AssistantPhase {
#[serde(tag = "type")]
pub enum Event {
Flexible(FlexibleEvent),
Editor(EditorEvent),
EditPrediction(EditPredictionEvent),
EditPredictionRating(EditPredictionRatingEvent),
Call(CallEvent),
Assistant(AssistantEventData),
Cpu(CpuEvent),
Memory(MemoryEvent),
App(AppEvent),
Setting(SettingEvent),
Extension(ExtensionEvent),
Edit(EditEvent),
Action(ActionEvent),
Repl(ReplEvent),
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@ -114,54 +101,12 @@ pub struct FlexibleEvent {
pub event_properties: HashMap<String, serde_json::Value>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EditorEvent {
/// The editor operation performed (open, save)
pub operation: String,
/// The extension of the file that was opened or saved
pub file_extension: Option<String>,
/// Whether the user is in vim mode or not
pub vim_mode: bool,
/// Whether the user has copilot enabled or not
pub copilot_enabled: bool,
/// Whether the user has copilot enabled for the language of the file opened or saved
pub copilot_enabled_for_language: bool,
/// Whether the client is opening/saving a local file or a remote file via SSH
#[serde(default)]
pub is_via_ssh: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EditPredictionEvent {
/// Provider of the completion suggestion (e.g. copilot, supermaven)
pub provider: String,
pub suggestion_accepted: bool,
pub file_extension: Option<String>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum EditPredictionRating {
Positive,
Negative,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EditPredictionRatingEvent {
pub rating: EditPredictionRating,
pub input_events: Arc<str>,
pub input_excerpt: Arc<str>,
pub output_excerpt: Arc<str>,
pub feedback: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CallEvent {
/// Operation performed: invite/join call; begin/end screenshare; share/unshare project; etc
pub operation: String,
pub room_id: Option<u64>,
pub channel_id: Option<u64>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AssistantEventData {
/// Unique random identifier for each assistant tab (None for inline assist)
@ -180,57 +125,6 @@ pub struct AssistantEventData {
pub language_name: Option<String>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CpuEvent {
pub usage_as_percentage: f32,
pub core_count: u32,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MemoryEvent {
pub memory_in_bytes: u64,
pub virtual_memory_in_bytes: u64,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ActionEvent {
pub source: String,
pub action: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EditEvent {
pub duration: i64,
pub environment: String,
/// Whether the edits occurred locally or remotely via SSH
#[serde(default)]
pub is_via_ssh: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SettingEvent {
pub setting: String,
pub value: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExtensionEvent {
pub extension_id: Arc<str>,
pub version: Arc<str>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AppEvent {
pub operation: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ReplEvent {
pub kernel_language: String,
pub kernel_status: String,
pub repl_session_id: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BacktraceFrame {
pub ip: usize,

View file

@ -947,7 +947,7 @@ pub fn new_terminal_pane(
cx: &mut Context<TerminalPanel>,
) -> Entity<Pane> {
let is_local = project.read(cx).is_local();
let terminal_panel = cx.entity().clone();
let terminal_panel = cx.entity();
let pane = cx.new(|cx| {
let mut pane = Pane::new(
workspace.clone(),
@ -1009,7 +1009,7 @@ pub fn new_terminal_pane(
return ControlFlow::Break(());
};
if let Some(tab) = dropped_item.downcast_ref::<DraggedTab>() {
let this_pane = cx.entity().clone();
let this_pane = cx.entity();
let item = if tab.pane == this_pane {
pane.item_for_index(tab.ix)
} else {

View file

@ -1391,7 +1391,7 @@ impl Render for TerminalView {
}
let terminal_handle = self.terminal.clone();
let terminal_view_handle = cx.entity().clone();
let terminal_view_handle = cx.entity();
let focused = self.focus_handle.is_focused(window);

View file

@ -299,7 +299,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
Vim::action(editor, cx, |vim, action: &VimSave, window, cx| {
vim.update_editor(cx, |_, editor, cx| {
let Some(project) = editor.project.clone() else {
let Some(project) = editor.project().cloned() else {
return;
};
let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else {
@ -436,7 +436,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
let Some(workspace) = vim.workspace(window) else {
return;
};
let Some(project) = editor.project.clone() else {
let Some(project) = editor.project().cloned() else {
return;
};
let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else {

View file

@ -20,7 +20,7 @@ impl ModeIndicator {
})
.detach();
let handle = cx.entity().clone();
let handle = cx.entity();
let window_handle = window.window_handle();
cx.observe_new::<Vim>(move |_, window, cx| {
let Some(window) = window else {
@ -29,7 +29,7 @@ impl ModeIndicator {
if window.window_handle() != window_handle {
return;
}
let vim = cx.entity().clone();
let vim = cx.entity();
handle.update(cx, |_, cx| {
cx.subscribe(&vim, |mode_indicator, vim, event, cx| match event {
VimEvent::Focused => {

View file

@ -332,7 +332,7 @@ impl Vim {
Vim::take_forced_motion(cx);
let prior_selections = self.editor_selections(window, cx);
let cursor_word = self.editor_cursor_word(window, cx);
let vim = cx.entity().clone();
let vim = cx.entity();
let searched = pane.update(cx, |pane, cx| {
self.search.direction = direction;

View file

@ -402,7 +402,7 @@ impl Vim {
const NAMESPACE: &'static str = "vim";
pub fn new(window: &mut Window, cx: &mut Context<Editor>) -> Entity<Self> {
let editor = cx.entity().clone();
let editor = cx.entity();
let mut initial_mode = VimSettings::get_global(cx).default_mode;
if initial_mode == Mode::Normal && HelixModeSetting::get_global(cx).0 {

View file

@ -253,7 +253,7 @@ impl Dock {
cx: &mut Context<Workspace>,
) -> Entity<Self> {
let focus_handle = cx.focus_handle();
let workspace = cx.entity().clone();
let workspace = cx.entity();
let dock = cx.new(|cx| {
let focus_subscription =
cx.on_focus(&focus_handle, window, |dock: &mut Dock, window, cx| {

View file

@ -346,7 +346,7 @@ impl Render for LanguageServerPrompt {
)
.child(Label::new(request.message.to_string()).size(LabelSize::Small))
.children(request.actions.iter().enumerate().map(|(ix, action)| {
let this_handle = cx.entity().clone();
let this_handle = cx.entity();
Button::new(ix, action.title.clone())
.size(ButtonSize::Large)
.on_click(move |_, window, cx| {

View file

@ -2198,7 +2198,7 @@ impl Pane {
fn update_status_bar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let workspace = self.workspace.clone();
let pane = cx.entity().clone();
let pane = cx.entity();
window.defer(cx, move |window, cx| {
let Ok(status_bar) =
@ -2279,7 +2279,7 @@ impl Pane {
cx: &mut Context<Self>,
) {
maybe!({
let pane = cx.entity().clone();
let pane = cx.entity();
let destination_index = match operation {
PinOperation::Pin => self.pinned_tab_count.min(ix),
@ -2473,7 +2473,7 @@ impl Pane {
.on_drag(
DraggedTab {
item: item.boxed_clone(),
pane: cx.entity().clone(),
pane: cx.entity(),
detail,
is_active,
ix,
@ -2832,7 +2832,7 @@ impl Pane {
let navigate_backward = IconButton::new("navigate_backward", IconName::ArrowLeft)
.icon_size(IconSize::Small)
.on_click({
let entity = cx.entity().clone();
let entity = cx.entity();
move |_, window, cx| {
entity.update(cx, |pane, cx| pane.navigate_backward(window, cx))
}
@ -2848,7 +2848,7 @@ impl Pane {
let navigate_forward = IconButton::new("navigate_forward", IconName::ArrowRight)
.icon_size(IconSize::Small)
.on_click({
let entity = cx.entity().clone();
let entity = cx.entity();
move |_, window, cx| entity.update(cx, |pane, cx| pane.navigate_forward(window, cx))
})
.disabled(!self.can_navigate_forward())
@ -3054,7 +3054,7 @@ impl Pane {
return;
}
}
let mut to_pane = cx.entity().clone();
let mut to_pane = cx.entity();
let split_direction = self.drag_split_direction;
let item_id = dragged_tab.item.item_id();
if let Some(preview_item_id) = self.preview_item_id {
@ -3163,7 +3163,7 @@ impl Pane {
return;
}
}
let mut to_pane = cx.entity().clone();
let mut to_pane = cx.entity();
let split_direction = self.drag_split_direction;
let project_entry_id = *project_entry_id;
self.workspace
@ -3239,7 +3239,7 @@ impl Pane {
return;
}
}
let mut to_pane = cx.entity().clone();
let mut to_pane = cx.entity();
let mut split_direction = self.drag_split_direction;
let paths = paths.paths().to_vec();
let is_remote = self

View file

@ -6338,7 +6338,7 @@ impl Render for Workspace {
.border_b_1()
.border_color(colors.border)
.child({
let this = cx.entity().clone();
let this = cx.entity();
canvas(
move |bounds, window, cx| {
this.update(cx, |this, cx| {

View file

@ -319,7 +319,7 @@ pub fn initialize_workspace(
return;
};
let workspace_handle = cx.entity().clone();
let workspace_handle = cx.entity();
let center_pane = workspace.active_pane().clone();
initialize_pane(workspace, &center_pane, window, cx);

View file

@ -229,8 +229,7 @@ fn assign_edit_prediction_provider(
if let Some(file) = buffer.read(cx).file() {
let id = file.worktree_id(cx);
if let Some(inner_worktree) = editor
.project
.as_ref()
.project()
.and_then(|project| project.read(cx).worktree_for_id(id, cx))
{
worktree = Some(inner_worktree);