Merge branch 'main' into ui-scrollbar-teardown
This commit is contained in:
commit
91cdf69924
83 changed files with 1532 additions and 3929 deletions
161
Cargo.lock
generated
161
Cargo.lock
generated
|
@ -1262,26 +1262,6 @@ dependencies = [
|
|||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stripe"
|
||||
version = "0.40.0"
|
||||
source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"http-types",
|
||||
"hyper 0.14.32",
|
||||
"hyper-rustls 0.24.2",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_qs 0.10.1",
|
||||
"smart-default 0.6.0",
|
||||
"smol_str 0.1.24",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-tar"
|
||||
version = "0.5.0"
|
||||
|
@ -2083,12 +2063,6 @@ version = "0.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
|
@ -3281,7 +3255,6 @@ dependencies = [
|
|||
"anyhow",
|
||||
"assistant_context",
|
||||
"assistant_slash_command",
|
||||
"async-stripe",
|
||||
"async-trait",
|
||||
"async-tungstenite",
|
||||
"audio",
|
||||
|
@ -3308,7 +3281,6 @@ dependencies = [
|
|||
"dap_adapters",
|
||||
"dashmap 6.1.0",
|
||||
"debugger_ui",
|
||||
"derive_more 0.99.19",
|
||||
"editor",
|
||||
"envy",
|
||||
"extension",
|
||||
|
@ -3324,7 +3296,6 @@ dependencies = [
|
|||
"http_client",
|
||||
"hyper 0.14.32",
|
||||
"indoc",
|
||||
"jsonwebtoken",
|
||||
"language",
|
||||
"language_model",
|
||||
"livekit_api",
|
||||
|
@ -3370,7 +3341,6 @@ dependencies = [
|
|||
"telemetry_events",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"tokio",
|
||||
"toml 0.8.20",
|
||||
|
@ -3872,7 +3842,7 @@ dependencies = [
|
|||
"rustc-hash 1.1.0",
|
||||
"rustybuzz 0.14.1",
|
||||
"self_cell",
|
||||
"smol_str 0.2.2",
|
||||
"smol_str",
|
||||
"swash",
|
||||
"sys-locale",
|
||||
"ttf-parser 0.21.1",
|
||||
|
@ -6376,17 +6346,6 @@ dependencies = [
|
|||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi 0.9.0+wasi-snapshot-preview1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.15"
|
||||
|
@ -7990,27 +7949,6 @@ version = "0.3.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
|
||||
|
||||
[[package]]
|
||||
name = "http-types"
|
||||
version = "2.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-channel 1.9.0",
|
||||
"base64 0.13.1",
|
||||
"futures-lite 1.13.0",
|
||||
"http 0.2.12",
|
||||
"infer",
|
||||
"pin-project-lite",
|
||||
"rand 0.7.3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_qs 0.8.5",
|
||||
"serde_urlencoded",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http_client"
|
||||
version = "0.1.0"
|
||||
|
@ -8489,12 +8427,6 @@ version = "2.0.6"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
|
||||
|
||||
[[package]]
|
||||
name = "infer"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
|
||||
|
||||
[[package]]
|
||||
name = "inherent"
|
||||
version = "1.0.12"
|
||||
|
@ -10271,7 +10203,7 @@ dependencies = [
|
|||
"num-traits",
|
||||
"range-map",
|
||||
"scroll",
|
||||
"smart-default 0.7.1",
|
||||
"smart-default",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -13145,19 +13077,6 @@ version = "0.7.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
|
||||
dependencies = [
|
||||
"getrandom 0.1.16",
|
||||
"libc",
|
||||
"rand_chacha 0.2.2",
|
||||
"rand_core 0.5.1",
|
||||
"rand_hc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
|
@ -13179,16 +13098,6 @@ dependencies = [
|
|||
"rand_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
|
@ -13209,15 +13118,6 @@ dependencies = [
|
|||
"rand_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
dependencies = [
|
||||
"getrandom 0.1.16",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
|
@ -13236,15 +13136,6 @@ dependencies = [
|
|||
"getrandom 0.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
dependencies = [
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "range-map"
|
||||
version = "0.2.0"
|
||||
|
@ -14899,28 +14790,6 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_qs"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_qs"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_repr"
|
||||
version = "0.1.20"
|
||||
|
@ -15297,17 +15166,6 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smart-default"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smart-default"
|
||||
version = "0.7.1"
|
||||
|
@ -15336,15 +15194,6 @@ dependencies = [
|
|||
"futures-lite 2.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
version = "0.2.2"
|
||||
|
@ -18194,12 +18043,6 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.9.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
|
|
14
Cargo.toml
14
Cargo.toml
|
@ -667,20 +667,6 @@ workspace-hack = "0.1.0"
|
|||
yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
git = "https://github.com/zed-industries/async-stripe"
|
||||
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
default-features = false
|
||||
features = [
|
||||
"runtime-tokio-hyper-rustls",
|
||||
"billing",
|
||||
"checkout",
|
||||
"events",
|
||||
# The features below are only enabled to get the `events` feature to build.
|
||||
"chrono",
|
||||
"connect",
|
||||
]
|
||||
|
||||
[workspace.dependencies.windows]
|
||||
version = "0.61"
|
||||
features = [
|
||||
|
|
|
@ -109,7 +109,7 @@ pub enum AgentThreadEntry {
|
|||
}
|
||||
|
||||
impl AgentThreadEntry {
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
pub fn to_markdown(&self, cx: &App) -> String {
|
||||
match self {
|
||||
Self::UserMessage(message) => message.to_markdown(cx),
|
||||
Self::AssistantMessage(message) => message.to_markdown(cx),
|
||||
|
@ -117,6 +117,14 @@ impl AgentThreadEntry {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn user_message(&self) -> Option<&UserMessage> {
|
||||
if let AgentThreadEntry::UserMessage(message) = self {
|
||||
Some(message)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
|
||||
if let AgentThreadEntry::ToolCall(call) = self {
|
||||
itertools::Either::Left(call.diffs())
|
||||
|
|
|
@ -309,7 +309,7 @@ pub struct AgentSettingsContent {
|
|||
///
|
||||
/// Default: true
|
||||
expand_terminal_card: Option<bool>,
|
||||
/// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel.
|
||||
/// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel.
|
||||
///
|
||||
/// Default: false
|
||||
use_modifier_to_send: Option<bool>,
|
||||
|
|
|
@ -1,38 +1,31 @@
|
|||
use std::ffi::OsStr;
|
||||
use std::ops::Range;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use acp_thread::{MentionUri, selection_name};
|
||||
use acp_thread::MentionUri;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::display_map::CreaseId;
|
||||
use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _};
|
||||
use editor::{CompletionProvider, Editor, ExcerptId};
|
||||
use futures::future::{Shared, try_join_all};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{App, Entity, ImageFormat, Img, Task, WeakEntity};
|
||||
use http_client::HttpClientWithUrl;
|
||||
use itertools::Itertools as _;
|
||||
use gpui::{App, Entity, ImageFormat, Task, WeakEntity};
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use language_model::LanguageModelImage;
|
||||
use lsp::CompletionContext;
|
||||
use parking_lot::Mutex;
|
||||
use project::{
|
||||
Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId,
|
||||
};
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Point;
|
||||
use text::{Anchor, OffsetRangeExt as _, ToPoint as _};
|
||||
use text::{Anchor, ToPoint as _};
|
||||
use ui::prelude::*;
|
||||
use url::Url;
|
||||
use workspace::Workspace;
|
||||
use workspace::notifications::NotifyResultExt;
|
||||
|
||||
use agent::thread_store::{TextThreadStore, ThreadStore};
|
||||
|
||||
use crate::context_picker::fetch_context_picker::fetch_url_content;
|
||||
use crate::acp::message_editor::MessageEditor;
|
||||
use crate::context_picker::file_context_picker::{FileMatch, search_files};
|
||||
use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules};
|
||||
use crate::context_picker::symbol_context_picker::SymbolMatch;
|
||||
|
@ -47,14 +40,14 @@ use crate::context_picker::{
|
|||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct MentionImage {
|
||||
pub abs_path: Option<Arc<Path>>,
|
||||
pub abs_path: Option<PathBuf>,
|
||||
pub data: SharedString,
|
||||
pub format: ImageFormat,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MentionSet {
|
||||
uri_by_crease_id: HashMap<CreaseId, MentionUri>,
|
||||
pub(crate) uri_by_crease_id: HashMap<CreaseId, MentionUri>,
|
||||
fetch_results: HashMap<Url, Shared<Task<Result<String, String>>>>,
|
||||
images: HashMap<CreaseId, Shared<Task<Result<MentionImage, String>>>>,
|
||||
}
|
||||
|
@ -84,11 +77,6 @@ impl MentionSet {
|
|||
.chain(self.images.drain().map(|(id, _)| id))
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.fetch_results.clear();
|
||||
self.uri_by_crease_id.clear();
|
||||
}
|
||||
|
||||
pub fn contents(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
|
@ -97,6 +85,8 @@ impl MentionSet {
|
|||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<HashMap<CreaseId, Mention>>> {
|
||||
let mut processed_image_creases = HashSet::default();
|
||||
|
||||
let mut contents = self
|
||||
.uri_by_crease_id
|
||||
.iter()
|
||||
|
@ -106,46 +96,15 @@ impl MentionSet {
|
|||
// TODO directories
|
||||
let uri = uri.clone();
|
||||
let abs_path = abs_path.to_path_buf();
|
||||
let extension = abs_path.extension().and_then(OsStr::to_str).unwrap_or("");
|
||||
|
||||
if Img::extensions().contains(&extension) && !extension.contains("svg") {
|
||||
let open_image_task = project.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(&abs_path, cx)
|
||||
.context("Failed to find project path")?;
|
||||
anyhow::Ok(project.open_image(path, cx))
|
||||
if let Some(task) = self.images.get(&crease_id).cloned() {
|
||||
processed_image_creases.insert(crease_id);
|
||||
return cx.spawn(async move |_| {
|
||||
let image = task.await.map_err(|e| anyhow!("{e}"))?;
|
||||
anyhow::Ok((crease_id, Mention::Image(image)))
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let image_item = open_image_task?.await?;
|
||||
let (data, format) = image_item.update(cx, |image_item, cx| {
|
||||
let format = image_item.image.format;
|
||||
(
|
||||
LanguageModelImage::from_image(
|
||||
image_item.image.clone(),
|
||||
cx,
|
||||
),
|
||||
format,
|
||||
)
|
||||
})?;
|
||||
let data = cx.spawn(async move |_| {
|
||||
if let Some(data) = data.await {
|
||||
Ok(data.source)
|
||||
} else {
|
||||
anyhow::bail!("Failed to convert image")
|
||||
}
|
||||
});
|
||||
|
||||
anyhow::Ok((
|
||||
crease_id,
|
||||
Mention::Image(MentionImage {
|
||||
abs_path: Some(abs_path.as_path().into()),
|
||||
data: data.await?,
|
||||
format,
|
||||
}),
|
||||
))
|
||||
})
|
||||
} else {
|
||||
let buffer_task = project.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(abs_path, cx)
|
||||
|
@ -159,7 +118,6 @@ impl MentionSet {
|
|||
anyhow::Ok((crease_id, Mention::Text { uri, content }))
|
||||
})
|
||||
}
|
||||
}
|
||||
MentionUri::Symbol {
|
||||
path, line_range, ..
|
||||
}
|
||||
|
@ -252,15 +210,19 @@ impl MentionSet {
|
|||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
contents.extend(self.images.iter().map(|(crease_id, image)| {
|
||||
// Handle images that didn't have a mention URI (because they were added by the paste handler).
|
||||
contents.extend(self.images.iter().filter_map(|(crease_id, image)| {
|
||||
if processed_image_creases.contains(crease_id) {
|
||||
return None;
|
||||
}
|
||||
let crease_id = *crease_id;
|
||||
let image = image.clone();
|
||||
cx.spawn(async move |_| {
|
||||
Some(cx.spawn(async move |_| {
|
||||
Ok((
|
||||
crease_id,
|
||||
Mention::Image(image.await.map_err(|e| anyhow::anyhow!("{e}"))?),
|
||||
))
|
||||
})
|
||||
}))
|
||||
}));
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
|
@ -488,36 +450,31 @@ fn search(
|
|||
}
|
||||
|
||||
pub struct ContextPickerCompletionProvider {
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
text_thread_store: WeakEntity<TextThreadStore>,
|
||||
editor: WeakEntity<Editor>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
}
|
||||
|
||||
impl ContextPickerCompletionProvider {
|
||||
pub fn new(
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
text_thread_store: WeakEntity<TextThreadStore>,
|
||||
editor: WeakEntity<Editor>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
) -> Self {
|
||||
Self {
|
||||
mention_set,
|
||||
workspace,
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
editor,
|
||||
message_editor,
|
||||
}
|
||||
}
|
||||
|
||||
fn completion_for_entry(
|
||||
entry: ContextPickerEntry,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> Option<Completion> {
|
||||
|
@ -538,88 +495,39 @@ impl ContextPickerCompletionProvider {
|
|||
ContextPickerEntry::Action(action) => {
|
||||
let (new_text, on_action) = match action {
|
||||
ContextPickerAction::AddSelections => {
|
||||
let selections = selection_ranges(workspace, cx);
|
||||
|
||||
const PLACEHOLDER: &str = "selection ";
|
||||
let selections = selection_ranges(workspace, cx)
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(ix, (buffer, range))| {
|
||||
(
|
||||
buffer,
|
||||
range,
|
||||
(PLACEHOLDER.len() * ix)..(PLACEHOLDER.len() * (ix + 1) - 1),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let new_text = std::iter::repeat(PLACEHOLDER)
|
||||
.take(selections.len())
|
||||
.chain(std::iter::once(""))
|
||||
.join(" ");
|
||||
let new_text: String = PLACEHOLDER.repeat(selections.len());
|
||||
|
||||
let callback = Arc::new({
|
||||
let mention_set = mention_set.clone();
|
||||
let selections = selections.clone();
|
||||
let source_range = source_range.clone();
|
||||
move |_, window: &mut Window, cx: &mut App| {
|
||||
let editor = editor.clone();
|
||||
let mention_set = mention_set.clone();
|
||||
let selections = selections.clone();
|
||||
let message_editor = message_editor.clone();
|
||||
let source_range = source_range.clone();
|
||||
window.defer(cx, move |window, cx| {
|
||||
let mut current_offset = 0;
|
||||
|
||||
for (buffer, selection_range) in selections {
|
||||
let snapshot =
|
||||
editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let Some(start) = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, source_range.start)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let offset = start.to_offset(&snapshot) + current_offset;
|
||||
let text_len = PLACEHOLDER.len() - 1;
|
||||
|
||||
let range = snapshot.anchor_after(offset)
|
||||
..snapshot.anchor_after(offset + text_len);
|
||||
|
||||
let path = buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map_or(PathBuf::from("untitled"), |file| {
|
||||
file.path().to_path_buf()
|
||||
});
|
||||
|
||||
let point_range = snapshot
|
||||
.as_singleton()
|
||||
.map(|(_, _, snapshot)| {
|
||||
selection_range.to_point(&snapshot)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let line_range = point_range.start.row..point_range.end.row;
|
||||
|
||||
let uri = MentionUri::Selection {
|
||||
path: path.clone(),
|
||||
line_range: line_range.clone(),
|
||||
};
|
||||
let crease = crate::context_picker::crease_for_mention(
|
||||
selection_name(&path, &line_range).into(),
|
||||
uri.icon_path(cx),
|
||||
range,
|
||||
editor.downgrade(),
|
||||
);
|
||||
|
||||
let [crease_id]: [_; 1] =
|
||||
editor.update(cx, |editor, cx| {
|
||||
let crease_ids =
|
||||
editor.insert_creases(vec![crease.clone()], cx);
|
||||
editor.fold_creases(
|
||||
vec![crease],
|
||||
false,
|
||||
message_editor
|
||||
.update(cx, |message_editor, cx| {
|
||||
message_editor.confirm_mention_for_selection(
|
||||
source_range,
|
||||
selections,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
crease_ids.try_into().unwrap()
|
||||
)
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
mention_set.lock().insert_uri(
|
||||
crease_id,
|
||||
MentionUri::Selection { path, line_range },
|
||||
);
|
||||
|
||||
current_offset += text_len + 1;
|
||||
}
|
||||
});
|
||||
|
||||
false
|
||||
}
|
||||
});
|
||||
|
@ -647,11 +555,9 @@ impl ContextPickerCompletionProvider {
|
|||
|
||||
fn completion_for_thread(
|
||||
thread_entry: ThreadContextEntry,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
recent: bool,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
editor: WeakEntity<MessageEditor>,
|
||||
cx: &mut App,
|
||||
) -> Completion {
|
||||
let uri = match &thread_entry {
|
||||
|
@ -683,13 +589,10 @@ impl ContextPickerCompletionProvider {
|
|||
source: project::CompletionSource::Custom,
|
||||
icon_path: Some(icon_for_completion.clone()),
|
||||
confirm: Some(confirm_completion_callback(
|
||||
uri.icon_path(cx),
|
||||
thread_entry.title().clone(),
|
||||
excerpt_id,
|
||||
source_range.start,
|
||||
new_text_len - 1,
|
||||
editor.clone(),
|
||||
mention_set,
|
||||
editor,
|
||||
uri,
|
||||
)),
|
||||
}
|
||||
|
@ -697,10 +600,8 @@ impl ContextPickerCompletionProvider {
|
|||
|
||||
fn completion_for_rules(
|
||||
rule: RulesContextEntry,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
editor: WeakEntity<MessageEditor>,
|
||||
cx: &mut App,
|
||||
) -> Completion {
|
||||
let uri = MentionUri::Rule {
|
||||
|
@ -719,13 +620,10 @@ impl ContextPickerCompletionProvider {
|
|||
source: project::CompletionSource::Custom,
|
||||
icon_path: Some(icon_path.clone()),
|
||||
confirm: Some(confirm_completion_callback(
|
||||
icon_path,
|
||||
rule.title.clone(),
|
||||
excerpt_id,
|
||||
source_range.start,
|
||||
new_text_len - 1,
|
||||
editor.clone(),
|
||||
mention_set,
|
||||
editor,
|
||||
uri,
|
||||
)),
|
||||
}
|
||||
|
@ -736,10 +634,8 @@ impl ContextPickerCompletionProvider {
|
|||
path_prefix: &str,
|
||||
is_recent: bool,
|
||||
is_directory: bool,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<Completion> {
|
||||
|
@ -777,13 +673,10 @@ impl ContextPickerCompletionProvider {
|
|||
icon_path: Some(completion_icon_path),
|
||||
insert_text_mode: None,
|
||||
confirm: Some(confirm_completion_callback(
|
||||
crease_icon_path,
|
||||
file_name,
|
||||
excerpt_id,
|
||||
source_range.start,
|
||||
new_text_len - 1,
|
||||
editor,
|
||||
mention_set.clone(),
|
||||
message_editor,
|
||||
file_uri,
|
||||
)),
|
||||
})
|
||||
|
@ -791,10 +684,8 @@ impl ContextPickerCompletionProvider {
|
|||
|
||||
fn completion_for_symbol(
|
||||
symbol: Symbol,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
workspace: Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> Option<Completion> {
|
||||
|
@ -820,13 +711,10 @@ impl ContextPickerCompletionProvider {
|
|||
icon_path: Some(icon_path.clone()),
|
||||
insert_text_mode: None,
|
||||
confirm: Some(confirm_completion_callback(
|
||||
icon_path,
|
||||
symbol.name.clone().into(),
|
||||
excerpt_id,
|
||||
source_range.start,
|
||||
new_text_len - 1,
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
message_editor,
|
||||
uri,
|
||||
)),
|
||||
})
|
||||
|
@ -835,116 +723,32 @@ impl ContextPickerCompletionProvider {
|
|||
fn completion_for_fetch(
|
||||
source_range: Range<Anchor>,
|
||||
url_to_fetch: SharedString,
|
||||
excerpt_id: ExcerptId,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
cx: &mut App,
|
||||
) -> Option<Completion> {
|
||||
let new_text = format!("@fetch {} ", url_to_fetch.clone());
|
||||
let new_text_len = new_text.len();
|
||||
let mention_uri = MentionUri::Fetch {
|
||||
url: url::Url::parse(url_to_fetch.as_ref())
|
||||
let url_to_fetch = url::Url::parse(url_to_fetch.as_ref())
|
||||
.or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}")))
|
||||
.ok()?,
|
||||
.ok()?;
|
||||
let mention_uri = MentionUri::Fetch {
|
||||
url: url_to_fetch.clone(),
|
||||
};
|
||||
let icon_path = mention_uri.icon_path(cx);
|
||||
Some(Completion {
|
||||
replace_range: source_range.clone(),
|
||||
new_text,
|
||||
new_text: new_text.clone(),
|
||||
label: CodeLabel::plain(url_to_fetch.to_string(), None),
|
||||
documentation: None,
|
||||
source: project::CompletionSource::Custom,
|
||||
icon_path: Some(icon_path.clone()),
|
||||
insert_text_mode: None,
|
||||
confirm: Some({
|
||||
let start = source_range.start;
|
||||
let content_len = new_text_len - 1;
|
||||
let editor = editor.clone();
|
||||
let url_to_fetch = url_to_fetch.clone();
|
||||
let source_range = source_range.clone();
|
||||
let icon_path = icon_path.clone();
|
||||
let mention_uri = mention_uri.clone();
|
||||
Arc::new(move |_, window, cx| {
|
||||
let Some(url) = url::Url::parse(url_to_fetch.as_ref())
|
||||
.or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}")))
|
||||
.notify_app_err(cx)
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let editor = editor.clone();
|
||||
let mention_set = mention_set.clone();
|
||||
let http_client = http_client.clone();
|
||||
let source_range = source_range.clone();
|
||||
let icon_path = icon_path.clone();
|
||||
let mention_uri = mention_uri.clone();
|
||||
window.defer(cx, move |window, cx| {
|
||||
let url = url.clone();
|
||||
|
||||
let Some(crease_id) = crate::context_picker::insert_crease_for_mention(
|
||||
excerpt_id,
|
||||
start,
|
||||
content_len,
|
||||
url.to_string().into(),
|
||||
icon_path,
|
||||
editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let editor = editor.clone();
|
||||
let mention_set = mention_set.clone();
|
||||
let http_client = http_client.clone();
|
||||
let source_range = source_range.clone();
|
||||
|
||||
let url_string = url.to_string();
|
||||
let fetch = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
fetch_url_content(http_client, url_string)
|
||||
.map_err(|e| e.to_string())
|
||||
.await
|
||||
})
|
||||
.shared();
|
||||
mention_set.lock().add_fetch_result(url, fetch.clone());
|
||||
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
if fetch.await.notify_async_err(cx).is_some() {
|
||||
mention_set
|
||||
.lock()
|
||||
.insert_uri(crease_id, mention_uri.clone());
|
||||
} else {
|
||||
// Remove crease if we failed to fetch
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let Some(anchor) = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, source_range.start)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
editor.display_map.update(cx, |display_map, cx| {
|
||||
display_map.unfold_intersecting(
|
||||
vec![anchor..anchor],
|
||||
true,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
editor.remove_creases([crease_id], cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Some(())
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
false
|
||||
})
|
||||
}),
|
||||
confirm: Some(confirm_completion_callback(
|
||||
url_to_fetch.to_string().into(),
|
||||
source_range.start,
|
||||
new_text.len() - 1,
|
||||
message_editor,
|
||||
mention_uri,
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -968,7 +772,7 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx:
|
|||
impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
fn completions(
|
||||
&self,
|
||||
excerpt_id: ExcerptId,
|
||||
_excerpt_id: ExcerptId,
|
||||
buffer: &Entity<Buffer>,
|
||||
buffer_position: Anchor,
|
||||
_trigger: CompletionContext,
|
||||
|
@ -992,39 +796,24 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||
};
|
||||
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let http_client = workspace.read(cx).client().http_client();
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let source_range = snapshot.anchor_before(state.source_range.start)
|
||||
..snapshot.anchor_after(state.source_range.end);
|
||||
|
||||
let thread_store = self.thread_store.clone();
|
||||
let text_thread_store = self.text_thread_store.clone();
|
||||
let editor = self.editor.clone();
|
||||
let editor = self.message_editor.clone();
|
||||
let Ok((exclude_paths, exclude_threads)) =
|
||||
self.message_editor.update(cx, |message_editor, _cx| {
|
||||
message_editor.mentioned_path_and_threads()
|
||||
})
|
||||
else {
|
||||
return Task::ready(Ok(Vec::new()));
|
||||
};
|
||||
|
||||
let MentionCompletion { mode, argument, .. } = state;
|
||||
let query = argument.unwrap_or_else(|| "".to_string());
|
||||
|
||||
let (exclude_paths, exclude_threads) = {
|
||||
let mention_set = self.mention_set.lock();
|
||||
|
||||
let mut excluded_paths = HashSet::default();
|
||||
let mut excluded_threads = HashSet::default();
|
||||
|
||||
for uri in mention_set.uri_by_crease_id.values() {
|
||||
match uri {
|
||||
MentionUri::File { abs_path, .. } => {
|
||||
excluded_paths.insert(abs_path.clone());
|
||||
}
|
||||
MentionUri::Thread { id, .. } => {
|
||||
excluded_threads.insert(id.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
(excluded_paths, excluded_threads)
|
||||
};
|
||||
|
||||
let recent_entries = recent_context_picker_entries(
|
||||
Some(thread_store.clone()),
|
||||
Some(text_thread_store.clone()),
|
||||
|
@ -1051,13 +840,8 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||
cx,
|
||||
);
|
||||
|
||||
let mention_set = self.mention_set.clone();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let matches = search_task.await;
|
||||
let Some(editor) = editor.upgrade() else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
let completions = cx.update(|cx| {
|
||||
matches
|
||||
|
@ -1074,10 +858,8 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||
&mat.path_prefix,
|
||||
is_recent,
|
||||
mat.is_dir,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
project.clone(),
|
||||
cx,
|
||||
)
|
||||
|
@ -1085,10 +867,8 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||
|
||||
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
|
||||
symbol,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
),
|
||||
|
@ -1097,39 +877,30 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||
thread, is_recent, ..
|
||||
}) => Some(Self::completion_for_thread(
|
||||
thread,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
is_recent,
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
cx,
|
||||
)),
|
||||
|
||||
Match::Rules(user_rules) => Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
cx,
|
||||
)),
|
||||
|
||||
Match::Fetch(url) => Self::completion_for_fetch(
|
||||
source_range.clone(),
|
||||
url,
|
||||
excerpt_id,
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
http_client.clone(),
|
||||
cx,
|
||||
),
|
||||
|
||||
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
|
||||
entry,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
&workspace,
|
||||
cx,
|
||||
),
|
||||
|
@ -1182,36 +953,30 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
|||
}
|
||||
|
||||
fn confirm_completion_callback(
|
||||
crease_icon_path: SharedString,
|
||||
crease_text: SharedString,
|
||||
excerpt_id: ExcerptId,
|
||||
start: Anchor,
|
||||
content_len: usize,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
message_editor: WeakEntity<MessageEditor>,
|
||||
mention_uri: MentionUri,
|
||||
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
|
||||
Arc::new(move |_, window, cx| {
|
||||
let message_editor = message_editor.clone();
|
||||
let crease_text = crease_text.clone();
|
||||
let crease_icon_path = crease_icon_path.clone();
|
||||
let editor = editor.clone();
|
||||
let mention_set = mention_set.clone();
|
||||
let mention_uri = mention_uri.clone();
|
||||
window.defer(cx, move |window, cx| {
|
||||
if let Some(crease_id) = crate::context_picker::insert_crease_for_mention(
|
||||
excerpt_id,
|
||||
message_editor
|
||||
.clone()
|
||||
.update(cx, |message_editor, cx| {
|
||||
message_editor.confirm_completion(
|
||||
crease_text,
|
||||
start,
|
||||
content_len,
|
||||
crease_text.clone(),
|
||||
crease_icon_path,
|
||||
editor.clone(),
|
||||
mention_uri,
|
||||
window,
|
||||
cx,
|
||||
) {
|
||||
mention_set
|
||||
.lock()
|
||||
.insert_uri(crease_id, mention_uri.clone());
|
||||
}
|
||||
)
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
false
|
||||
})
|
||||
|
@ -1279,13 +1044,13 @@ impl MentionCompletion {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use editor::AnchorRangeExt;
|
||||
use editor::{AnchorRangeExt, EditorMode};
|
||||
use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext};
|
||||
use project::{Project, ProjectPath};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::{ops::Deref, path::Path, rc::Rc};
|
||||
use std::{ops::Deref, path::Path};
|
||||
use util::path;
|
||||
use workspace::{AppState, Item};
|
||||
|
||||
|
@ -1359,9 +1124,9 @@ mod tests {
|
|||
assert_eq!(MentionCompletion::try_parse("test@", 0), None);
|
||||
}
|
||||
|
||||
struct AtMentionEditor(Entity<Editor>);
|
||||
struct MessageEditorItem(Entity<MessageEditor>);
|
||||
|
||||
impl Item for AtMentionEditor {
|
||||
impl Item for MessageEditorItem {
|
||||
type Event = ();
|
||||
|
||||
fn include_in_nav_history() -> bool {
|
||||
|
@ -1373,15 +1138,15 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<()> for AtMentionEditor {}
|
||||
impl EventEmitter<()> for MessageEditorItem {}
|
||||
|
||||
impl Focusable for AtMentionEditor {
|
||||
impl Focusable for MessageEditorItem {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.0.read(cx).focus_handle(cx).clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AtMentionEditor {
|
||||
impl Render for MessageEditorItem {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
self.0.clone().into_any_element()
|
||||
}
|
||||
|
@ -1467,19 +1232,28 @@ mod tests {
|
|||
opened_editors.push(buffer);
|
||||
}
|
||||
|
||||
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
|
||||
let editor = cx.new(|cx| {
|
||||
Editor::new(
|
||||
editor::EditorMode::full(),
|
||||
multi_buffer::MultiBuffer::build_simple("", cx),
|
||||
None,
|
||||
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
|
||||
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
|
||||
|
||||
let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| {
|
||||
let workspace_handle = cx.weak_entity();
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace_handle,
|
||||
project.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
EditorMode::AutoHeight {
|
||||
max_lines: None,
|
||||
min_lines: 1,
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
workspace.active_pane().update(cx, |pane, cx| {
|
||||
pane.add_item(
|
||||
Box::new(cx.new(|_| AtMentionEditor(editor.clone()))),
|
||||
Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))),
|
||||
true,
|
||||
true,
|
||||
None,
|
||||
|
@ -1487,24 +1261,9 @@ mod tests {
|
|||
cx,
|
||||
);
|
||||
});
|
||||
editor
|
||||
});
|
||||
|
||||
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
|
||||
|
||||
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
|
||||
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
|
||||
|
||||
let editor_entity = editor.downgrade();
|
||||
editor.update_in(&mut cx, |editor, window, cx| {
|
||||
window.focus(&editor.focus_handle(cx));
|
||||
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
|
||||
mention_set.clone(),
|
||||
workspace.downgrade(),
|
||||
thread_store.downgrade(),
|
||||
text_thread_store.downgrade(),
|
||||
editor_entity,
|
||||
))));
|
||||
message_editor.read(cx).focus_handle(cx).focus(window);
|
||||
let editor = message_editor.read(cx).editor().clone();
|
||||
(message_editor, editor)
|
||||
});
|
||||
|
||||
cx.simulate_input("Lorem ");
|
||||
|
@ -1573,9 +1332,9 @@ mod tests {
|
|||
);
|
||||
});
|
||||
|
||||
let contents = cx
|
||||
.update(|window, cx| {
|
||||
mention_set.lock().contents(
|
||||
let contents = message_editor
|
||||
.update_in(&mut cx, |message_editor, window, cx| {
|
||||
message_editor.mention_set().contents(
|
||||
project.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
|
@ -1641,9 +1400,9 @@ mod tests {
|
|||
|
||||
cx.run_until_parked();
|
||||
|
||||
let contents = cx
|
||||
.update(|window, cx| {
|
||||
mention_set.lock().contents(
|
||||
let contents = message_editor
|
||||
.update_in(&mut cx, |message_editor, window, cx| {
|
||||
message_editor.mention_set().contents(
|
||||
project.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
|
@ -1765,9 +1524,9 @@ mod tests {
|
|||
editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx);
|
||||
});
|
||||
|
||||
let contents = cx
|
||||
.update(|window, cx| {
|
||||
mention_set.lock().contents(
|
||||
let contents = message_editor
|
||||
.update_in(&mut cx, |message_editor, window, cx| {
|
||||
message_editor.mention_set().contents(
|
||||
project.clone(),
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
|
|
|
@ -1,45 +1,141 @@
|
|||
use std::{collections::HashMap, ops::Range};
|
||||
use std::ops::Range;
|
||||
|
||||
use acp_thread::AcpThread;
|
||||
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
|
||||
use acp_thread::{AcpThread, AgentThreadEntry};
|
||||
use agent::{TextThreadStore, ThreadStore};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, EditorMode, MinimapVisibility};
|
||||
use gpui::{
|
||||
AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
|
||||
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
|
||||
WeakEntity, Window,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use project::Project;
|
||||
use settings::Settings as _;
|
||||
use terminal_view::TerminalView;
|
||||
use theme::ThemeSettings;
|
||||
use ui::TextSize;
|
||||
use ui::{Context, TextSize};
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Default)]
|
||||
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
|
||||
pub struct EntryViewState {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
entries: Vec<Entry>,
|
||||
}
|
||||
|
||||
impl EntryViewState {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
workspace,
|
||||
project,
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn entry(&self, index: usize) -> Option<&Entry> {
|
||||
self.entries.get(index)
|
||||
}
|
||||
|
||||
pub fn sync_entry(
|
||||
&mut self,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread: Entity<AcpThread>,
|
||||
index: usize,
|
||||
thread: &Entity<AcpThread>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
debug_assert!(index <= self.entries.len());
|
||||
let entry = if let Some(entry) = self.entries.get_mut(index) {
|
||||
entry
|
||||
} else {
|
||||
self.entries.push(Entry::default());
|
||||
self.entries.last_mut().unwrap()
|
||||
let Some(thread_entry) = thread.read(cx).entries().get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
entry.sync_diff_multibuffers(&thread, index, window, cx);
|
||||
entry.sync_terminals(&workspace, &thread, index, window, cx);
|
||||
match thread_entry {
|
||||
AgentThreadEntry::UserMessage(message) => {
|
||||
let has_id = message.id.is_some();
|
||||
let chunks = message.chunks.clone();
|
||||
let message_editor = cx.new(|cx| {
|
||||
let mut editor = MessageEditor::new(
|
||||
self.workspace.clone(),
|
||||
self.project.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.text_thread_store.clone(),
|
||||
editor::EditorMode::AutoHeight {
|
||||
min_lines: 1,
|
||||
max_lines: None,
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
if !has_id {
|
||||
editor.set_read_only(true, cx);
|
||||
}
|
||||
editor.set_message(chunks, window, cx);
|
||||
editor
|
||||
});
|
||||
cx.subscribe(&message_editor, move |_, editor, event, cx| {
|
||||
cx.emit(EntryViewEvent {
|
||||
entry_index: index,
|
||||
view_event: ViewEvent::MessageEditorEvent(editor, *event),
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
self.set_entry(index, Entry::UserMessage(message_editor));
|
||||
}
|
||||
AgentThreadEntry::ToolCall(tool_call) => {
|
||||
let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
|
||||
let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
|
||||
|
||||
let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
|
||||
views
|
||||
} else {
|
||||
self.set_entry(index, Entry::empty());
|
||||
let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
|
||||
unreachable!()
|
||||
};
|
||||
views
|
||||
};
|
||||
|
||||
for terminal in terminals {
|
||||
views.entry(terminal.entity_id()).or_insert_with(|| {
|
||||
create_terminal(
|
||||
self.workspace.clone(),
|
||||
self.project.clone(),
|
||||
terminal.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any()
|
||||
});
|
||||
}
|
||||
|
||||
for diff in diffs {
|
||||
views
|
||||
.entry(diff.entity_id())
|
||||
.or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
|
||||
}
|
||||
}
|
||||
AgentThreadEntry::AssistantMessage(_) => {
|
||||
if index == self.entries.len() {
|
||||
self.entries.push(Entry::empty())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn set_entry(&mut self, index: usize, entry: Entry) {
|
||||
if index == self.entries.len() {
|
||||
self.entries.push(entry);
|
||||
} else {
|
||||
self.entries[index] = entry;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, range: Range<usize>) {
|
||||
|
@ -48,26 +144,51 @@ impl EntryViewState {
|
|||
|
||||
pub fn settings_changed(&mut self, cx: &mut App) {
|
||||
for entry in self.entries.iter() {
|
||||
for view in entry.views.values() {
|
||||
match entry {
|
||||
Entry::UserMessage { .. } => {}
|
||||
Entry::Content(response_views) => {
|
||||
for view in response_views.values() {
|
||||
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
|
||||
diff_editor.update(cx, |diff_editor, cx| {
|
||||
diff_editor
|
||||
.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
diff_editor.set_text_style_refinement(
|
||||
diff_editor_text_style_refinement(cx),
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Entry {
|
||||
views: HashMap<EntityId, AnyEntity>,
|
||||
impl EventEmitter<EntryViewEvent> for EntryViewState {}
|
||||
|
||||
pub struct EntryViewEvent {
|
||||
pub entry_index: usize,
|
||||
pub view_event: ViewEvent,
|
||||
}
|
||||
|
||||
pub enum ViewEvent {
|
||||
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
|
||||
}
|
||||
|
||||
pub enum Entry {
|
||||
UserMessage(Entity<MessageEditor>),
|
||||
Content(HashMap<EntityId, AnyEntity>),
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
|
||||
self.views
|
||||
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
|
||||
match self {
|
||||
Self::UserMessage(editor) => Some(editor),
|
||||
Entry::Content(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
|
||||
self.content_map()?
|
||||
.get(&diff.entity_id())
|
||||
.cloned()
|
||||
.map(|entity| entity.downcast::<Editor>().unwrap())
|
||||
|
@ -77,42 +198,66 @@ impl Entry {
|
|||
&self,
|
||||
terminal: &Entity<acp_thread::Terminal>,
|
||||
) -> Option<Entity<TerminalView>> {
|
||||
self.views
|
||||
self.content_map()?
|
||||
.get(&terminal.entity_id())
|
||||
.cloned()
|
||||
.map(|entity| entity.downcast::<TerminalView>().unwrap())
|
||||
}
|
||||
|
||||
fn sync_diff_multibuffers(
|
||||
&mut self,
|
||||
thread: &Entity<AcpThread>,
|
||||
index: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(entry) = thread.read(cx).entries().get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let multibuffers = entry
|
||||
.diffs()
|
||||
.map(|diff| diff.read(cx).multibuffer().clone());
|
||||
|
||||
let multibuffers = multibuffers.collect::<Vec<_>>();
|
||||
|
||||
for multibuffer in multibuffers {
|
||||
if self.views.contains_key(&multibuffer.entity_id()) {
|
||||
return;
|
||||
fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
|
||||
match self {
|
||||
Self::Content(map) => Some(map),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
fn empty() -> Self {
|
||||
Self::Content(HashMap::default())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn has_content(&self) -> bool {
|
||||
match self {
|
||||
Self::Content(map) => !map.is_empty(),
|
||||
Self::UserMessage(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_terminal(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
terminal: Entity<acp_thread::Terminal>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<TerminalView> {
|
||||
cx.new(|cx| {
|
||||
let mut view = TerminalView::new(
|
||||
terminal.read(cx).inner().clone(),
|
||||
workspace.clone(),
|
||||
None,
|
||||
project.downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
view.set_embedded_mode(Some(1000), cx);
|
||||
view
|
||||
})
|
||||
}
|
||||
|
||||
fn create_editor_diff(
|
||||
diff: Entity<acp_thread::Diff>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Editor> {
|
||||
cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
diff.read(cx).multibuffer().clone(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
|
@ -132,61 +277,7 @@ impl Entry {
|
|||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
editor
|
||||
});
|
||||
|
||||
let entity_id = multibuffer.entity_id();
|
||||
self.views.insert(entity_id, editor.into_any());
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_terminals(
|
||||
&mut self,
|
||||
workspace: &WeakEntity<Workspace>,
|
||||
thread: &Entity<AcpThread>,
|
||||
index: usize,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(entry) = thread.read(cx).entries().get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let terminals = entry
|
||||
.terminals()
|
||||
.map(|terminal| terminal.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for terminal in terminals {
|
||||
if self.views.contains_key(&terminal.entity_id()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(strong_workspace) = workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let terminal_view = cx.new(|cx| {
|
||||
let mut view = TerminalView::new(
|
||||
terminal.read(cx).inner().clone(),
|
||||
workspace.clone(),
|
||||
None,
|
||||
strong_workspace.read(cx).project().downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
view.set_embedded_mode(Some(1000), cx);
|
||||
view
|
||||
});
|
||||
|
||||
let entity_id = terminal.entity_id();
|
||||
self.views.insert(entity_id, terminal_view.into_any());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.views.len()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
|
||||
|
@ -201,26 +292,20 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
|
|||
}
|
||||
}
|
||||
|
||||
impl Default for Entry {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// Avoid allocating in the heap by default
|
||||
views: HashMap::with_capacity(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{path::Path, rc::Rc};
|
||||
|
||||
use acp_thread::{AgentConnection, StubAgentConnection};
|
||||
use agent::{TextThreadStore, ThreadStore};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
|
||||
use editor::{EditorSettings, RowInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::{SemanticVersion, TestAppContext};
|
||||
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
|
||||
|
||||
use crate::acp::entry_view_state::EntryViewState;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use pretty_assertions::assert_matches;
|
||||
use project::Project;
|
||||
|
@ -230,8 +315,6 @@ mod tests {
|
|||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::acp::entry_view_state::EntryViewState;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_diff_sync(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
@ -269,7 +352,7 @@ mod tests {
|
|||
.update(|_, cx| {
|
||||
connection
|
||||
.clone()
|
||||
.new_thread(project, Path::new(path!("/project")), cx)
|
||||
.new_thread(project.clone(), Path::new(path!("/project")), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -279,12 +362,23 @@ mod tests {
|
|||
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
|
||||
});
|
||||
|
||||
let mut view_state = EntryViewState::default();
|
||||
cx.update(|window, cx| {
|
||||
view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
|
||||
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
|
||||
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
|
||||
|
||||
let view_state = cx.new(|_cx| {
|
||||
EntryViewState::new(
|
||||
workspace.downgrade(),
|
||||
project.clone(),
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
)
|
||||
});
|
||||
|
||||
let multibuffer = thread.read_with(cx, |thread, cx| {
|
||||
view_state.update_in(cx, |view_state, window, cx| {
|
||||
view_state.sync_entry(0, &thread, window, cx)
|
||||
});
|
||||
|
||||
let diff = thread.read_with(cx, |thread, _cx| {
|
||||
thread
|
||||
.entries()
|
||||
.get(0)
|
||||
|
@ -292,15 +386,14 @@ mod tests {
|
|||
.diffs()
|
||||
.next()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.multibuffer()
|
||||
.clone()
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let entry = view_state.entry(0).unwrap();
|
||||
let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
|
||||
let diff_editor = view_state.read_with(cx, |view_state, _cx| {
|
||||
view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
|
||||
});
|
||||
assert_eq!(
|
||||
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
|
||||
"hi world\nhello world"
|
||||
|
|
|
@ -1,61 +1,63 @@
|
|||
use crate::acp::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::acp::completion_provider::MentionImage;
|
||||
use crate::acp::completion_provider::MentionSet;
|
||||
use acp_thread::MentionUri;
|
||||
use agent::TextThreadStore;
|
||||
use agent::ThreadStore;
|
||||
use crate::{
|
||||
acp::completion_provider::{ContextPickerCompletionProvider, MentionImage, MentionSet},
|
||||
context_picker::fetch_context_picker::fetch_url_content,
|
||||
};
|
||||
use acp_thread::{MentionUri, selection_name};
|
||||
use agent::{TextThreadStore, ThreadId, ThreadStore};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use editor::ExcerptId;
|
||||
use editor::actions::Paste;
|
||||
use editor::display_map::CreaseId;
|
||||
use editor::{
|
||||
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode,
|
||||
EditorStyle, MultiBuffer,
|
||||
Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement,
|
||||
EditorMode, EditorStyle, ExcerptId, FoldPlaceholder, MultiBuffer, ToOffset,
|
||||
actions::Paste,
|
||||
display_map::{Crease, CreaseId, FoldId},
|
||||
};
|
||||
use futures::FutureExt as _;
|
||||
use gpui::ClipboardEntry;
|
||||
use gpui::Image;
|
||||
use gpui::ImageFormat;
|
||||
use futures::{FutureExt as _, TryFutureExt as _};
|
||||
use gpui::{
|
||||
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity,
|
||||
AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, Image,
|
||||
ImageFormat, Img, Task, TextStyle, WeakEntity,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language::Language;
|
||||
use language::{Buffer, Language};
|
||||
use language_model::LanguageModelImage;
|
||||
use parking_lot::Mutex;
|
||||
use project::{CompletionIntent, Project};
|
||||
use settings::Settings;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
ffi::OsStr,
|
||||
fmt::Write,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
};
|
||||
use text::OffsetRangeExt;
|
||||
use theme::ThemeSettings;
|
||||
use ui::IconName;
|
||||
use ui::SharedString;
|
||||
use ui::{
|
||||
ActiveTheme, App, InteractiveElement, IntoElement, ParentElement, Render, Styled, TextSize,
|
||||
Window, div,
|
||||
ActiveTheme, AnyElement, App, ButtonCommon, ButtonLike, ButtonStyle, Color, Icon, IconName,
|
||||
IconSize, InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ParentElement,
|
||||
Render, SelectableButton, SharedString, Styled, TextSize, TintColor, Toggleable, Window, div,
|
||||
h_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use workspace::notifications::NotifyResultExt as _;
|
||||
use workspace::{Workspace, notifications::NotifyResultExt as _};
|
||||
use zed_actions::agent::Chat;
|
||||
|
||||
use super::completion_provider::Mention;
|
||||
|
||||
pub struct MessageEditor {
|
||||
mention_set: MentionSet,
|
||||
editor: Entity<Editor>,
|
||||
project: Entity<Project>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum MessageEditorEvent {
|
||||
Send,
|
||||
Cancel,
|
||||
Focus,
|
||||
}
|
||||
|
||||
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
|
||||
|
@ -77,8 +79,13 @@ impl MessageEditor {
|
|||
},
|
||||
None,
|
||||
);
|
||||
|
||||
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
|
||||
let completion_provider = ContextPickerCompletionProvider::new(
|
||||
workspace.clone(),
|
||||
thread_store.downgrade(),
|
||||
text_thread_store.downgrade(),
|
||||
cx.weak_entity(),
|
||||
);
|
||||
let mention_set = MentionSet::default();
|
||||
let editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
|
@ -88,13 +95,7 @@ impl MessageEditor {
|
|||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_soft_wrap();
|
||||
editor.set_use_modal_editing(true);
|
||||
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
|
||||
mention_set.clone(),
|
||||
workspace,
|
||||
thread_store.downgrade(),
|
||||
text_thread_store.downgrade(),
|
||||
cx.weak_entity(),
|
||||
))));
|
||||
editor.set_completion_provider(Some(Rc::new(completion_provider)));
|
||||
editor.set_context_menu_options(ContextMenuOptions {
|
||||
min_entries_visible: 12,
|
||||
max_entries_visible: 12,
|
||||
|
@ -103,25 +104,254 @@ impl MessageEditor {
|
|||
editor
|
||||
});
|
||||
|
||||
cx.on_focus(&editor.focus_handle(cx), window, |_, _, cx| {
|
||||
cx.emit(MessageEditorEvent::Focus)
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
editor,
|
||||
project,
|
||||
mention_set,
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
workspace,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn editor(&self) -> &Entity<Editor> {
|
||||
&self.editor
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn mention_set(&mut self) -> &mut MentionSet {
|
||||
&mut self.mention_set
|
||||
}
|
||||
|
||||
pub fn is_empty(&self, cx: &App) -> bool {
|
||||
self.editor.read(cx).is_empty(cx)
|
||||
}
|
||||
|
||||
pub fn mentioned_path_and_threads(&self) -> (HashSet<PathBuf>, HashSet<ThreadId>) {
|
||||
let mut excluded_paths = HashSet::default();
|
||||
let mut excluded_threads = HashSet::default();
|
||||
|
||||
for uri in self.mention_set.uri_by_crease_id.values() {
|
||||
match uri {
|
||||
MentionUri::File { abs_path, .. } => {
|
||||
excluded_paths.insert(abs_path.clone());
|
||||
}
|
||||
MentionUri::Thread { id, .. } => {
|
||||
excluded_threads.insert(id.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
(excluded_paths, excluded_threads)
|
||||
}
|
||||
|
||||
pub fn confirm_completion(
|
||||
&mut self,
|
||||
crease_text: SharedString,
|
||||
start: text::Anchor,
|
||||
content_len: usize,
|
||||
mention_uri: MentionUri,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let snapshot = self
|
||||
.editor
|
||||
.update(cx, |editor, cx| editor.snapshot(window, cx));
|
||||
let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let Some(anchor) = snapshot
|
||||
.buffer_snapshot
|
||||
.anchor_in_excerpt(*excerpt_id, start)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(crease_id) = crate::context_picker::insert_crease_for_mention(
|
||||
*excerpt_id,
|
||||
start,
|
||||
content_len,
|
||||
crease_text.clone(),
|
||||
mention_uri.icon_path(cx),
|
||||
self.editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
) else {
|
||||
return;
|
||||
};
|
||||
self.mention_set.insert_uri(crease_id, mention_uri.clone());
|
||||
|
||||
match mention_uri {
|
||||
MentionUri::Fetch { url } => {
|
||||
self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx);
|
||||
}
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
self.confirm_mention_for_file(
|
||||
crease_id,
|
||||
anchor,
|
||||
abs_path,
|
||||
is_directory,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
MentionUri::Symbol { .. }
|
||||
| MentionUri::Thread { .. }
|
||||
| MentionUri::TextThread { .. }
|
||||
| MentionUri::Rule { .. }
|
||||
| MentionUri::Selection { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn confirm_mention_for_file(
|
||||
&mut self,
|
||||
crease_id: CreaseId,
|
||||
anchor: Anchor,
|
||||
abs_path: PathBuf,
|
||||
_is_directory: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let extension = abs_path
|
||||
.extension()
|
||||
.and_then(OsStr::to_str)
|
||||
.unwrap_or_default();
|
||||
let project = self.project.clone();
|
||||
let Some(project_path) = project
|
||||
.read(cx)
|
||||
.project_path_for_absolute_path(&abs_path, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
if Img::extensions().contains(&extension) && !extension.contains("svg") {
|
||||
let image = cx.spawn(async move |_, cx| {
|
||||
let image = project
|
||||
.update(cx, |project, cx| project.open_image(project_path, cx))?
|
||||
.await?;
|
||||
image.read_with(cx, |image, _cx| image.image.clone())
|
||||
});
|
||||
self.confirm_mention_for_image(crease_id, anchor, Some(abs_path), image, window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn confirm_mention_for_fetch(
|
||||
&mut self,
|
||||
crease_id: CreaseId,
|
||||
anchor: Anchor,
|
||||
url: url::Url,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(http_client) = self
|
||||
.workspace
|
||||
.update(cx, |workspace, _cx| workspace.client().http_client())
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let url_string = url.to_string();
|
||||
let fetch = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
fetch_url_content(http_client, url_string)
|
||||
.map_err(|e| e.to_string())
|
||||
.await
|
||||
})
|
||||
.shared();
|
||||
self.mention_set
|
||||
.add_fetch_result(url.clone(), fetch.clone());
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let fetch = fetch.await.notify_async_err(cx);
|
||||
this.update(cx, |this, cx| {
|
||||
let mention_uri = MentionUri::Fetch { url };
|
||||
if fetch.is_some() {
|
||||
this.mention_set.insert_uri(crease_id, mention_uri.clone());
|
||||
} else {
|
||||
// Remove crease if we failed to fetch
|
||||
this.editor.update(cx, |editor, cx| {
|
||||
editor.display_map.update(cx, |display_map, cx| {
|
||||
display_map.unfold_intersecting(vec![anchor..anchor], true, cx);
|
||||
});
|
||||
editor.remove_creases([crease_id], cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub fn confirm_mention_for_selection(
|
||||
&mut self,
|
||||
source_range: Range<text::Anchor>,
|
||||
selections: Vec<(Entity<Buffer>, Range<text::Anchor>, Range<usize>)>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let snapshot = self.editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let Some((&excerpt_id, _, _)) = snapshot.as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, source_range.start) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let offset = start.to_offset(&snapshot);
|
||||
|
||||
for (buffer, selection_range, range_to_fold) in selections {
|
||||
let range = snapshot.anchor_after(offset + range_to_fold.start)
|
||||
..snapshot.anchor_after(offset + range_to_fold.end);
|
||||
|
||||
let path = buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map_or(PathBuf::from("untitled"), |file| file.path().to_path_buf());
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
let point_range = selection_range.to_point(&snapshot);
|
||||
let line_range = point_range.start.row..point_range.end.row;
|
||||
|
||||
let uri = MentionUri::Selection {
|
||||
path: path.clone(),
|
||||
line_range: line_range.clone(),
|
||||
};
|
||||
let crease = crate::context_picker::crease_for_mention(
|
||||
selection_name(&path, &line_range).into(),
|
||||
uri.icon_path(cx),
|
||||
range,
|
||||
self.editor.downgrade(),
|
||||
);
|
||||
|
||||
let crease_id = self.editor.update(cx, |editor, cx| {
|
||||
let crease_ids = editor.insert_creases(vec![crease.clone()], cx);
|
||||
editor.fold_creases(vec![crease], false, window, cx);
|
||||
crease_ids.first().copied().unwrap()
|
||||
});
|
||||
|
||||
self.mention_set
|
||||
.insert_uri(crease_id, MentionUri::Selection { path, line_range });
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contents(
|
||||
&self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Vec<acp::ContentBlock>>> {
|
||||
let contents = self.mention_set.lock().contents(
|
||||
let contents = self.mention_set.contents(
|
||||
self.project.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.text_thread_store.clone(),
|
||||
|
@ -198,15 +428,15 @@ impl MessageEditor {
|
|||
pub fn clear(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.clear(window, cx);
|
||||
editor.remove_creases(self.mention_set.lock().drain(), cx)
|
||||
editor.remove_creases(self.mention_set.drain(), cx)
|
||||
});
|
||||
}
|
||||
|
||||
fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
|
||||
fn send(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Send)
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
fn cancel(&mut self, _: &editor::actions::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Cancel)
|
||||
}
|
||||
|
||||
|
@ -233,11 +463,16 @@ impl MessageEditor {
|
|||
|
||||
let replacement_text = "image";
|
||||
for image in images {
|
||||
let (excerpt_id, anchor) = self.editor.update(cx, |message_editor, cx| {
|
||||
let (excerpt_id, text_anchor, multibuffer_anchor) =
|
||||
self.editor.update(cx, |message_editor, cx| {
|
||||
let snapshot = message_editor.snapshot(window, cx);
|
||||
let (excerpt_id, _, snapshot) = snapshot.buffer_snapshot.as_singleton().unwrap();
|
||||
let (excerpt_id, _, buffer_snapshot) =
|
||||
snapshot.buffer_snapshot.as_singleton().unwrap();
|
||||
|
||||
let anchor = snapshot.anchor_before(snapshot.len());
|
||||
let text_anchor = buffer_snapshot.anchor_before(buffer_snapshot.len());
|
||||
let multibuffer_anchor = snapshot
|
||||
.buffer_snapshot
|
||||
.anchor_in_excerpt(*excerpt_id, text_anchor);
|
||||
message_editor.edit(
|
||||
[(
|
||||
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
|
||||
|
@ -245,15 +480,29 @@ impl MessageEditor {
|
|||
)],
|
||||
cx,
|
||||
);
|
||||
(*excerpt_id, anchor)
|
||||
(*excerpt_id, text_anchor, multibuffer_anchor)
|
||||
});
|
||||
|
||||
self.insert_image(
|
||||
let content_len = replacement_text.len();
|
||||
let Some(anchor) = multibuffer_anchor else {
|
||||
return;
|
||||
};
|
||||
let Some(crease_id) = insert_crease_for_image(
|
||||
excerpt_id,
|
||||
text_anchor,
|
||||
content_len,
|
||||
None.clone(),
|
||||
self.editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
) else {
|
||||
return;
|
||||
};
|
||||
self.confirm_mention_for_image(
|
||||
crease_id,
|
||||
anchor,
|
||||
replacement_text.len(),
|
||||
Arc::new(image),
|
||||
None,
|
||||
Task::ready(Ok(Arc::new(image))),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
@ -267,9 +516,6 @@ impl MessageEditor {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let buffer = self.editor.read(cx).buffer().clone();
|
||||
let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer) = buffer.read(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
|
@ -292,10 +538,8 @@ impl MessageEditor {
|
|||
&path_prefix,
|
||||
false,
|
||||
entry.is_dir(),
|
||||
excerpt_id,
|
||||
anchor..anchor,
|
||||
self.editor.clone(),
|
||||
self.mention_set.clone(),
|
||||
cx.weak_entity(),
|
||||
self.project.clone(),
|
||||
cx,
|
||||
) else {
|
||||
|
@ -317,33 +561,32 @@ impl MessageEditor {
|
|||
}
|
||||
}
|
||||
|
||||
fn insert_image(
|
||||
pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |message_editor, cx| {
|
||||
message_editor.set_read_only(read_only);
|
||||
cx.notify()
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm_mention_for_image(
|
||||
&mut self,
|
||||
excerpt_id: ExcerptId,
|
||||
crease_start: text::Anchor,
|
||||
content_len: usize,
|
||||
image: Arc<Image>,
|
||||
abs_path: Option<Arc<Path>>,
|
||||
crease_id: CreaseId,
|
||||
anchor: Anchor,
|
||||
abs_path: Option<PathBuf>,
|
||||
image: Task<Result<Arc<Image>>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(crease_id) = insert_crease_for_image(
|
||||
excerpt_id,
|
||||
crease_start,
|
||||
content_len,
|
||||
self.editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
) else {
|
||||
return;
|
||||
};
|
||||
self.editor.update(cx, |_editor, cx| {
|
||||
let format = image.format;
|
||||
let convert = LanguageModelImage::from_image(image, cx);
|
||||
|
||||
let task = cx
|
||||
.spawn_in(window, async move |editor, cx| {
|
||||
if let Some(image) = convert.await {
|
||||
let image = image.await.map_err(|e| e.to_string())?;
|
||||
let format = image.format;
|
||||
let image = cx
|
||||
.update(|_, cx| LanguageModelImage::from_image(image, cx))
|
||||
.map_err(|e| e.to_string())?
|
||||
.await;
|
||||
if let Some(image) = image {
|
||||
Ok(MentionImage {
|
||||
abs_path,
|
||||
data: image.source,
|
||||
|
@ -352,12 +595,6 @@ impl MessageEditor {
|
|||
} else {
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let Some(anchor) =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, crease_start)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
editor.display_map.update(cx, |display_map, cx| {
|
||||
display_map.unfold_intersecting(vec![anchor..anchor], true, cx);
|
||||
});
|
||||
|
@ -375,7 +612,7 @@ impl MessageEditor {
|
|||
})
|
||||
.detach();
|
||||
|
||||
self.mention_set.lock().insert_image(crease_id, task);
|
||||
self.mention_set.insert_image(crease_id, task);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -392,6 +629,8 @@ impl MessageEditor {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.clear(window, cx);
|
||||
|
||||
let mut text = String::new();
|
||||
let mut mentions = Vec::new();
|
||||
let mut images = Vec::new();
|
||||
|
@ -429,7 +668,6 @@ impl MessageEditor {
|
|||
editor.buffer().read(cx).snapshot(cx)
|
||||
});
|
||||
|
||||
self.mention_set.lock().clear();
|
||||
for (range, mention_uri) in mentions {
|
||||
let anchor = snapshot.anchor_before(range.start);
|
||||
let crease_id = crate::context_picker::insert_crease_for_mention(
|
||||
|
@ -444,7 +682,7 @@ impl MessageEditor {
|
|||
);
|
||||
|
||||
if let Some(crease_id) = crease_id {
|
||||
self.mention_set.lock().insert_uri(crease_id, mention_uri);
|
||||
self.mention_set.insert_uri(crease_id, mention_uri);
|
||||
}
|
||||
}
|
||||
for (range, content) in images {
|
||||
|
@ -479,7 +717,7 @@ impl MessageEditor {
|
|||
let data: SharedString = content.data.to_string().into();
|
||||
|
||||
if let Some(crease_id) = crease_id {
|
||||
self.mention_set.lock().insert_image(
|
||||
self.mention_set.insert_image(
|
||||
crease_id,
|
||||
Task::ready(Ok(MentionImage {
|
||||
abs_path,
|
||||
|
@ -499,6 +737,11 @@ impl MessageEditor {
|
|||
editor.set_text(text, window, cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn text(&self, cx: &App) -> String {
|
||||
self.editor.read(cx).text(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for MessageEditor {
|
||||
|
@ -511,7 +754,7 @@ impl Render for MessageEditor {
|
|||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.key_context("MessageEditor")
|
||||
.on_action(cx.listener(Self::chat))
|
||||
.on_action(cx.listener(Self::send))
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.flex_1()
|
||||
|
@ -550,20 +793,78 @@ pub(crate) fn insert_crease_for_image(
|
|||
excerpt_id: ExcerptId,
|
||||
anchor: text::Anchor,
|
||||
content_len: usize,
|
||||
abs_path: Option<Arc<Path>>,
|
||||
editor: Entity<Editor>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<CreaseId> {
|
||||
crate::context_picker::insert_crease_for_mention(
|
||||
excerpt_id,
|
||||
anchor,
|
||||
content_len,
|
||||
"Image".into(),
|
||||
IconName::Image.path().into(),
|
||||
editor,
|
||||
window,
|
||||
cx,
|
||||
let crease_label = abs_path
|
||||
.as_ref()
|
||||
.and_then(|path| path.file_name())
|
||||
.map(|name| name.to_string_lossy().to_string().into())
|
||||
.unwrap_or(SharedString::from("Image"));
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
|
||||
let start = snapshot.anchor_in_excerpt(excerpt_id, anchor)?;
|
||||
|
||||
let start = start.bias_right(&snapshot);
|
||||
let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len);
|
||||
|
||||
let placeholder = FoldPlaceholder {
|
||||
render: render_image_fold_icon_button(crease_label, cx.weak_entity()),
|
||||
merge_adjacent: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let crease = Crease::Inline {
|
||||
range: start..end,
|
||||
placeholder,
|
||||
render_toggle: None,
|
||||
render_trailer: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let ids = editor.insert_creases(vec![crease.clone()], cx);
|
||||
editor.fold_creases(vec![crease], false, window, cx);
|
||||
|
||||
Some(ids[0])
|
||||
})
|
||||
}
|
||||
|
||||
fn render_image_fold_icon_button(
|
||||
label: SharedString,
|
||||
editor: WeakEntity<Editor>,
|
||||
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
|
||||
Arc::new({
|
||||
move |fold_id, fold_range, cx| {
|
||||
let is_in_text_selection = editor
|
||||
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
|
||||
.unwrap_or_default();
|
||||
|
||||
ButtonLike::new(fold_id)
|
||||
.style(ButtonStyle::Filled)
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.toggle_state(is_in_text_selection)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Icon::new(IconName::Image)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new(label.clone())
|
||||
.size(LabelSize::Small)
|
||||
.buffer_font(cx)
|
||||
.single_line(),
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -45,6 +45,7 @@ use zed_actions::assistant::OpenRulesLibrary;
|
|||
|
||||
use super::entry_view_state::EntryViewState;
|
||||
use crate::acp::AcpModelSelectorPopover;
|
||||
use crate::acp::entry_view_state::{EntryViewEvent, ViewEvent};
|
||||
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
use crate::agent_diff::AgentDiff;
|
||||
use crate::profile_selector::{ProfileProvider, ProfileSelector};
|
||||
|
@ -101,10 +102,8 @@ pub struct AcpThreadView {
|
|||
agent: Rc<dyn AgentServer>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
thread_state: ThreadState,
|
||||
entry_view_state: EntryViewState,
|
||||
entry_view_state: Entity<EntryViewState>,
|
||||
message_editor: Entity<MessageEditor>,
|
||||
model_selector: Option<Entity<AcpModelSelectorPopover>>,
|
||||
profile_selector: Option<Entity<ProfileSelector>>,
|
||||
|
@ -119,16 +118,9 @@ pub struct AcpThreadView {
|
|||
plan_expanded: bool,
|
||||
editor_expanded: bool,
|
||||
terminal_expanded: bool,
|
||||
editing_message: Option<EditingMessage>,
|
||||
editing_message: Option<usize>,
|
||||
_cancel_task: Option<Task<()>>,
|
||||
_subscriptions: [Subscription; 2],
|
||||
}
|
||||
|
||||
struct EditingMessage {
|
||||
index: usize,
|
||||
message_id: UserMessageId,
|
||||
editor: Entity<MessageEditor>,
|
||||
_subscription: Subscription,
|
||||
_subscriptions: [Subscription; 3],
|
||||
}
|
||||
|
||||
enum ThreadState {
|
||||
|
@ -175,25 +167,33 @@ impl AcpThreadView {
|
|||
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
|
||||
|
||||
let entry_view_state = cx.new(|_| {
|
||||
EntryViewState::new(
|
||||
workspace.clone(),
|
||||
project.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
)
|
||||
});
|
||||
|
||||
let subscriptions = [
|
||||
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
|
||||
cx.subscribe_in(&message_editor, window, Self::on_message_editor_event),
|
||||
cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event),
|
||||
cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event),
|
||||
];
|
||||
|
||||
Self {
|
||||
agent: agent.clone(),
|
||||
workspace: workspace.clone(),
|
||||
project: project.clone(),
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
entry_view_state,
|
||||
thread_state: Self::initial_state(agent, workspace, project, window, cx),
|
||||
message_editor,
|
||||
model_selector: None,
|
||||
profile_selector: None,
|
||||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
entry_view_state: EntryViewState::default(),
|
||||
list_state,
|
||||
list_state: list_state,
|
||||
thread_error: None,
|
||||
auth_task: None,
|
||||
expanded_tool_calls: HashSet::default(),
|
||||
|
@ -412,7 +412,7 @@ impl AcpThreadView {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn on_message_editor_event(
|
||||
pub fn handle_message_editor_event(
|
||||
&mut self,
|
||||
_: &Entity<MessageEditor>,
|
||||
event: &MessageEditorEvent,
|
||||
|
@ -422,6 +422,28 @@ impl AcpThreadView {
|
|||
match event {
|
||||
MessageEditorEvent::Send => self.send(window, cx),
|
||||
MessageEditorEvent::Cancel => self.cancel_generation(cx),
|
||||
MessageEditorEvent::Focus => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_entry_view_event(
|
||||
&mut self,
|
||||
_: &Entity<EntryViewState>,
|
||||
event: &EntryViewEvent,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match &event.view_event {
|
||||
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => {
|
||||
self.editing_message = Some(event.entry_index);
|
||||
cx.notify();
|
||||
}
|
||||
ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => {
|
||||
self.regenerate(event.entry_index, editor, window, cx);
|
||||
}
|
||||
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => {
|
||||
self.cancel_editing(&Default::default(), window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -492,27 +514,56 @@ impl AcpThreadView {
|
|||
.detach();
|
||||
}
|
||||
|
||||
fn cancel_editing(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editing_message.take();
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn regenerate(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(editing_message) = self.editing_message.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(thread) = self.thread().cloned() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let rewind = thread.update(cx, |thread, cx| {
|
||||
thread.rewind(editing_message.message_id, cx)
|
||||
});
|
||||
if let Some(index) = self.editing_message.take() {
|
||||
if let Some(editor) = self
|
||||
.entry_view_state
|
||||
.read(cx)
|
||||
.entry(index)
|
||||
.and_then(|e| e.message_editor())
|
||||
.cloned()
|
||||
{
|
||||
editor.update(cx, |editor, cx| {
|
||||
if let Some(user_message) = thread
|
||||
.read(cx)
|
||||
.entries()
|
||||
.get(index)
|
||||
.and_then(|e| e.user_message())
|
||||
{
|
||||
editor.set_message(user_message.chunks.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
self.focus_handle(cx).focus(window);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn regenerate(
|
||||
&mut self,
|
||||
entry_ix: usize,
|
||||
message_editor: &Entity<MessageEditor>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(thread) = self.thread().cloned() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(rewind) = thread.update(cx, |thread, cx| {
|
||||
let user_message_id = thread.entries().get(entry_ix)?.user_message()?.id.clone()?;
|
||||
Some(thread.rewind(user_message_id, cx))
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let contents =
|
||||
message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx));
|
||||
|
||||
let contents = editing_message
|
||||
.editor
|
||||
.update(cx, |message_editor, cx| message_editor.contents(window, cx));
|
||||
let task = cx.foreground_executor().spawn(async move {
|
||||
rewind.await?;
|
||||
contents.await
|
||||
|
@ -568,27 +619,20 @@ impl AcpThreadView {
|
|||
AcpThreadEvent::NewEntry => {
|
||||
let len = thread.read(cx).entries().len();
|
||||
let index = len - 1;
|
||||
self.entry_view_state.sync_entry(
|
||||
self.workspace.clone(),
|
||||
thread.clone(),
|
||||
index,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.entry_view_state.update(cx, |view_state, cx| {
|
||||
view_state.sync_entry(index, &thread, window, cx)
|
||||
});
|
||||
self.list_state.splice(index..index, 1);
|
||||
}
|
||||
AcpThreadEvent::EntryUpdated(index) => {
|
||||
self.entry_view_state.sync_entry(
|
||||
self.workspace.clone(),
|
||||
thread.clone(),
|
||||
*index,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
self.entry_view_state.update(cx, |view_state, cx| {
|
||||
view_state.sync_entry(*index, &thread, window, cx)
|
||||
});
|
||||
self.list_state.splice(*index..index + 1, 1);
|
||||
}
|
||||
AcpThreadEvent::EntriesRemoved(range) => {
|
||||
self.entry_view_state.remove(range.clone());
|
||||
self.entry_view_state
|
||||
.update(cx, |view_state, _cx| view_state.remove(range.clone()));
|
||||
self.list_state.splice(range.clone(), 0);
|
||||
}
|
||||
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||
|
@ -720,29 +764,15 @@ impl AcpThreadView {
|
|||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.text_xs()
|
||||
.id("message")
|
||||
.on_click(cx.listener({
|
||||
move |this, _, window, cx| {
|
||||
this.start_editing_message(entry_ix, window, cx)
|
||||
}
|
||||
}))
|
||||
.children(
|
||||
if let Some(editing) = self.editing_message.as_ref()
|
||||
&& Some(&editing.message_id) == message.id.as_ref()
|
||||
{
|
||||
Some(
|
||||
self.render_edit_message_editor(editing, cx)
|
||||
.into_any_element(),
|
||||
)
|
||||
} else {
|
||||
message.content.markdown().map(|md| {
|
||||
self.render_markdown(
|
||||
md.clone(),
|
||||
user_message_markdown_style(window, cx),
|
||||
)
|
||||
self.entry_view_state
|
||||
.read(cx)
|
||||
.entry(entry_ix)
|
||||
.and_then(|entry| entry.message_editor())
|
||||
.map(|editor| {
|
||||
self.render_sent_message_editor(entry_ix, editor, cx)
|
||||
.into_any_element()
|
||||
})
|
||||
},
|
||||
}),
|
||||
),
|
||||
)
|
||||
.into_any(),
|
||||
|
@ -817,8 +847,8 @@ impl AcpThreadView {
|
|||
primary
|
||||
};
|
||||
|
||||
if let Some(editing) = self.editing_message.as_ref()
|
||||
&& editing.index < entry_ix
|
||||
if let Some(editing_index) = self.editing_message.as_ref()
|
||||
&& *editing_index < entry_ix
|
||||
{
|
||||
let backdrop = div()
|
||||
.id(("backdrop", entry_ix))
|
||||
|
@ -832,8 +862,8 @@ impl AcpThreadView {
|
|||
|
||||
div()
|
||||
.relative()
|
||||
.child(backdrop)
|
||||
.child(primary)
|
||||
.child(backdrop)
|
||||
.into_any_element()
|
||||
} else {
|
||||
primary
|
||||
|
@ -1254,9 +1284,7 @@ impl AcpThreadView {
|
|||
Empty.into_any_element()
|
||||
}
|
||||
}
|
||||
ToolCallContent::Diff(diff) => {
|
||||
self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx)
|
||||
}
|
||||
ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, &diff, cx),
|
||||
ToolCallContent::Terminal(terminal) => {
|
||||
self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
|
||||
}
|
||||
|
@ -1403,7 +1431,7 @@ impl AcpThreadView {
|
|||
fn render_diff_editor(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
multibuffer: &Entity<MultiBuffer>,
|
||||
diff: &Entity<acp_thread::Diff>,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
v_flex()
|
||||
|
@ -1411,8 +1439,8 @@ impl AcpThreadView {
|
|||
.border_t_1()
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
.child(
|
||||
if let Some(entry) = self.entry_view_state.entry(entry_ix)
|
||||
&& let Some(editor) = entry.editor_for_diff(&multibuffer)
|
||||
if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix)
|
||||
&& let Some(editor) = entry.editor_for_diff(&diff)
|
||||
{
|
||||
editor.clone().into_any_element()
|
||||
} else {
|
||||
|
@ -1615,6 +1643,7 @@ impl AcpThreadView {
|
|||
|
||||
let terminal_view = self
|
||||
.entry_view_state
|
||||
.read(cx)
|
||||
.entry(entry_ix)
|
||||
.and_then(|entry| entry.terminal(&terminal));
|
||||
let show_output = self.terminal_expanded && terminal_view.is_some();
|
||||
|
@ -2483,63 +2512,16 @@ impl AcpThreadView {
|
|||
)
|
||||
}
|
||||
|
||||
fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
let Some(AgentThreadEntry::UserMessage(message)) = thread.read(cx).entries().get(index)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let Some(message_id) = message.id.clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.list_state.scroll_to_reveal_item(index);
|
||||
|
||||
let chunks = message.chunks.clone();
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = MessageEditor::new(
|
||||
self.workspace.clone(),
|
||||
self.project.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.text_thread_store.clone(),
|
||||
editor::EditorMode::AutoHeight {
|
||||
min_lines: 1,
|
||||
max_lines: None,
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_message(chunks, window, cx);
|
||||
editor
|
||||
});
|
||||
let subscription =
|
||||
cx.subscribe_in(&editor, window, |this, _, event, window, cx| match event {
|
||||
MessageEditorEvent::Send => {
|
||||
this.regenerate(&Default::default(), window, cx);
|
||||
}
|
||||
MessageEditorEvent::Cancel => {
|
||||
this.cancel_editing(&Default::default(), window, cx);
|
||||
}
|
||||
});
|
||||
editor.focus_handle(cx).focus(window);
|
||||
|
||||
self.editing_message.replace(EditingMessage {
|
||||
index: index,
|
||||
message_id: message_id.clone(),
|
||||
editor,
|
||||
_subscription: subscription,
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_edit_message_editor(&self, editing: &EditingMessage, cx: &Context<Self>) -> Div {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(editing.editor.clone())
|
||||
.child(
|
||||
fn render_sent_message_editor(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
editor: &Entity<MessageEditor>,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
v_flex().w_full().gap_2().child(editor.clone()).when(
|
||||
self.editing_message == Some(entry_ix),
|
||||
|el| {
|
||||
el.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
|
@ -2552,13 +2534,16 @@ impl AcpThreadView {
|
|||
.color(Color::Muted)
|
||||
.size(LabelSize::XSmall),
|
||||
)
|
||||
.child(self.render_editing_message_editor_buttons(editing, cx)),
|
||||
.child(self.render_sent_message_editor_buttons(entry_ix, editor, cx)),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn render_editing_message_editor_buttons(
|
||||
fn render_sent_message_editor_buttons(
|
||||
&self,
|
||||
editing: &EditingMessage,
|
||||
entry_ix: usize,
|
||||
editor: &Entity<MessageEditor>,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
h_flex()
|
||||
|
@ -2571,7 +2556,7 @@ impl AcpThreadView {
|
|||
.icon_color(Color::Error)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip({
|
||||
let focus_handle = editing.editor.focus_handle(cx);
|
||||
let focus_handle = editor.focus_handle(cx);
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Cancel Edit",
|
||||
|
@ -2586,12 +2571,12 @@ impl AcpThreadView {
|
|||
)
|
||||
.child(
|
||||
IconButton::new("confirm-edit-message", IconName::Return)
|
||||
.disabled(editing.editor.read(cx).is_empty(cx))
|
||||
.disabled(editor.read(cx).is_empty(cx))
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip({
|
||||
let focus_handle = editing.editor.focus_handle(cx);
|
||||
let focus_handle = editor.focus_handle(cx);
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Regenerate",
|
||||
|
@ -2602,7 +2587,12 @@ impl AcpThreadView {
|
|||
)
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(Self::regenerate)),
|
||||
.on_click(cx.listener({
|
||||
let editor = editor.clone();
|
||||
move |this, _, window, cx| {
|
||||
this.regenerate(entry_ix, &editor, window, cx);
|
||||
}
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -3102,7 +3092,9 @@ impl AcpThreadView {
|
|||
}
|
||||
|
||||
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.entry_view_state.settings_changed(cx);
|
||||
self.entry_view_state.update(cx, |entry_view_state, cx| {
|
||||
entry_view_state.settings_changed(cx);
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn insert_dragged_files(
|
||||
|
@ -3117,9 +3109,7 @@ impl AcpThreadView {
|
|||
drop(added_worktrees);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AcpThreadView {
|
||||
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
|
||||
let content = match self.thread_error.as_ref()? {
|
||||
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
|
||||
|
@ -3411,35 +3401,6 @@ impl Render for AcpThreadView {
|
|||
}
|
||||
}
|
||||
|
||||
fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let mut style = default_markdown_style(false, window, cx);
|
||||
let mut text_style = window.text_style();
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
|
||||
let buffer_font = theme_settings.buffer_font.family.clone();
|
||||
let buffer_font_size = TextSize::Small.rems(cx);
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(buffer_font),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
style.base_text_style = text_style;
|
||||
style.link_callback = Some(Rc::new(move |url, cx| {
|
||||
if MentionUri::parse(url).is_ok() {
|
||||
let colors = cx.theme().colors();
|
||||
Some(TextStyleRefinement {
|
||||
background_color: Some(colors.element_background),
|
||||
..Default::default()
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}));
|
||||
style
|
||||
}
|
||||
|
||||
fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
|
@ -3598,12 +3559,13 @@ pub(crate) mod tests {
|
|||
use agent_client_protocol::SessionId;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use project::Project;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::any::Any;
|
||||
use std::path::Path;
|
||||
use workspace::Item;
|
||||
|
||||
use super::*;
|
||||
|
||||
|
@ -3750,6 +3712,50 @@ pub(crate) mod tests {
|
|||
(thread_view, cx)
|
||||
}
|
||||
|
||||
fn add_to_workspace(thread_view: Entity<AcpThreadView>, cx: &mut VisualTestContext) {
|
||||
let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone());
|
||||
|
||||
workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
workspace.add_item_to_active_pane(
|
||||
Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))),
|
||||
None,
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
struct ThreadViewItem(Entity<AcpThreadView>);
|
||||
|
||||
impl Item for ThreadViewItem {
|
||||
type Event = ();
|
||||
|
||||
fn include_in_nav_history() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"Test".into()
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<()> for ThreadViewItem {}
|
||||
|
||||
impl Focusable for ThreadViewItem {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.0.read(cx).focus_handle(cx).clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ThreadViewItem {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
self.0.clone().into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
struct StubAgentServer<C> {
|
||||
connection: C,
|
||||
}
|
||||
|
@ -3771,19 +3777,19 @@ pub(crate) mod tests {
|
|||
C: 'static + AgentConnection + Send + Clone,
|
||||
{
|
||||
fn logo(&self) -> ui::IconName {
|
||||
unimplemented!()
|
||||
ui::IconName::Ai
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
"Test"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
"Test"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
"Test"
|
||||
}
|
||||
|
||||
fn connect(
|
||||
|
@ -3932,9 +3938,17 @@ pub(crate) mod tests {
|
|||
assert_eq!(thread.entries().len(), 2);
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _| {
|
||||
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
|
||||
thread_view.read_with(cx, |view, cx| {
|
||||
view.entry_view_state.read_with(cx, |entry_view_state, _| {
|
||||
assert!(
|
||||
entry_view_state
|
||||
.entry(0)
|
||||
.unwrap()
|
||||
.message_editor()
|
||||
.is_some()
|
||||
);
|
||||
assert!(entry_view_state.entry(1).unwrap().has_content());
|
||||
});
|
||||
});
|
||||
|
||||
// Second user message
|
||||
|
@ -3963,18 +3977,31 @@ pub(crate) mod tests {
|
|||
|
||||
let second_user_message_id = thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 4);
|
||||
let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap()
|
||||
else {
|
||||
let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else {
|
||||
panic!();
|
||||
};
|
||||
user_message.id.clone().unwrap()
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _| {
|
||||
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
|
||||
assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1);
|
||||
thread_view.read_with(cx, |view, cx| {
|
||||
view.entry_view_state.read_with(cx, |entry_view_state, _| {
|
||||
assert!(
|
||||
entry_view_state
|
||||
.entry(0)
|
||||
.unwrap()
|
||||
.message_editor()
|
||||
.is_some()
|
||||
);
|
||||
assert!(entry_view_state.entry(1).unwrap().has_content());
|
||||
assert!(
|
||||
entry_view_state
|
||||
.entry(2)
|
||||
.unwrap()
|
||||
.message_editor()
|
||||
.is_some()
|
||||
);
|
||||
assert!(entry_view_state.entry(3).unwrap().has_content());
|
||||
});
|
||||
});
|
||||
|
||||
// Rewind to first message
|
||||
|
@ -3989,13 +4016,169 @@ pub(crate) mod tests {
|
|||
assert_eq!(thread.entries().len(), 2);
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _| {
|
||||
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
|
||||
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
|
||||
thread_view.read_with(cx, |view, cx| {
|
||||
view.entry_view_state.read_with(cx, |entry_view_state, _| {
|
||||
assert!(
|
||||
entry_view_state
|
||||
.entry(0)
|
||||
.unwrap()
|
||||
.message_editor()
|
||||
.is_some()
|
||||
);
|
||||
assert!(entry_view_state.entry(1).unwrap().has_content());
|
||||
|
||||
// Old views should be dropped
|
||||
assert!(view.entry_view_state.entry(2).is_none());
|
||||
assert!(view.entry_view_state.entry(3).is_none());
|
||||
assert!(entry_view_state.entry(2).is_none());
|
||||
assert!(entry_view_state.entry(3).is_none());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_message_editing_cancel(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let connection = StubAgentConnection::new();
|
||||
|
||||
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Response".into(),
|
||||
annotations: None,
|
||||
}),
|
||||
}]);
|
||||
|
||||
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
|
||||
add_to_workspace(thread_view.clone(), cx);
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Original message to edit", window, cx);
|
||||
});
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.send(window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let user_message_editor = thread_view.read_with(cx, |view, cx| {
|
||||
assert_eq!(view.editing_message, None);
|
||||
|
||||
view.entry_view_state
|
||||
.read(cx)
|
||||
.entry(0)
|
||||
.unwrap()
|
||||
.message_editor()
|
||||
.unwrap()
|
||||
.clone()
|
||||
});
|
||||
|
||||
// Focus
|
||||
cx.focus(&user_message_editor);
|
||||
thread_view.read_with(cx, |view, _cx| {
|
||||
assert_eq!(view.editing_message, Some(0));
|
||||
});
|
||||
|
||||
// Edit
|
||||
user_message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Edited message content", window, cx);
|
||||
});
|
||||
|
||||
// Cancel
|
||||
user_message_editor.update_in(cx, |_editor, window, cx| {
|
||||
window.dispatch_action(Box::new(editor::actions::Cancel), cx);
|
||||
});
|
||||
|
||||
thread_view.read_with(cx, |view, _cx| {
|
||||
assert_eq!(view.editing_message, None);
|
||||
});
|
||||
|
||||
user_message_editor.read_with(cx, |editor, cx| {
|
||||
assert_eq!(editor.text(cx), "Original message to edit");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_message_editing_regenerate(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let connection = StubAgentConnection::new();
|
||||
|
||||
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Response".into(),
|
||||
annotations: None,
|
||||
}),
|
||||
}]);
|
||||
|
||||
let (thread_view, cx) =
|
||||
setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
|
||||
add_to_workspace(thread_view.clone(), cx);
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Original message to edit", window, cx);
|
||||
});
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.send(window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let user_message_editor = thread_view.read_with(cx, |view, cx| {
|
||||
assert_eq!(view.editing_message, None);
|
||||
assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2);
|
||||
|
||||
view.entry_view_state
|
||||
.read(cx)
|
||||
.entry(0)
|
||||
.unwrap()
|
||||
.message_editor()
|
||||
.unwrap()
|
||||
.clone()
|
||||
});
|
||||
|
||||
// Focus
|
||||
cx.focus(&user_message_editor);
|
||||
|
||||
// Edit
|
||||
user_message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Edited message content", window, cx);
|
||||
});
|
||||
|
||||
// Send
|
||||
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "New Response".into(),
|
||||
annotations: None,
|
||||
}),
|
||||
}]);
|
||||
|
||||
user_message_editor.update_in(cx, |_editor, window, cx| {
|
||||
window.dispatch_action(Box::new(Chat), cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
thread_view.read_with(cx, |view, cx| {
|
||||
assert_eq!(view.editing_message, None);
|
||||
|
||||
let entries = view.thread().unwrap().read(cx).entries();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert_eq!(
|
||||
entries[0].to_markdown(cx),
|
||||
"## User\n\nEdited message content\n\n"
|
||||
);
|
||||
assert_eq!(
|
||||
entries[1].to_markdown(cx),
|
||||
"## Assistant\n\nNew Response\n\n"
|
||||
);
|
||||
|
||||
let new_editor = view.entry_view_state.read_with(cx, |state, _cx| {
|
||||
assert!(!state.entry(1).unwrap().has_content());
|
||||
state.entry(0).unwrap().message_editor().unwrap().clone()
|
||||
});
|
||||
|
||||
assert_eq!(new_editor.read(cx).text(cx), "Edited message content");
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -462,7 +462,7 @@ impl AgentConfiguration {
|
|||
"modifier-send",
|
||||
"Use modifier to submit a message",
|
||||
Some(
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(),
|
||||
),
|
||||
use_modifier_to_send,
|
||||
move |state, _window, cx| {
|
||||
|
|
|
@ -818,12 +818,10 @@ impl AgentPanel {
|
|||
ActiveView::Thread { thread, .. } => {
|
||||
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
}
|
||||
ActiveView::ExternalAgentThread { thread_view, .. } => {
|
||||
thread_view.update(cx, |thread_element, cx| {
|
||||
thread_element.cancel_generation(cx)
|
||||
});
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ use anyhow::{Result, anyhow};
|
|||
use collections::HashSet;
|
||||
pub use completion_provider::ContextPickerCompletionProvider;
|
||||
use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use fetch_context_picker::FetchContextPicker;
|
||||
use file_context_picker::FileContextPicker;
|
||||
use file_context_picker::render_file_context_entry;
|
||||
|
@ -228,7 +228,7 @@ impl ContextPicker {
|
|||
}
|
||||
|
||||
fn build_menu(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Entity<ContextMenu> {
|
||||
let context_picker = cx.entity().clone();
|
||||
let context_picker = cx.entity();
|
||||
|
||||
let menu = ContextMenu::build(window, cx, move |menu, _window, cx| {
|
||||
let recent = self.recent_entries(cx);
|
||||
|
@ -837,42 +837,9 @@ fn render_fold_icon_button(
|
|||
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
|
||||
Arc::new({
|
||||
move |fold_id, fold_range, cx| {
|
||||
let is_in_text_selection = editor.upgrade().is_some_and(|editor| {
|
||||
editor.update(cx, |editor, cx| {
|
||||
let snapshot = editor
|
||||
.buffer()
|
||||
.update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx));
|
||||
|
||||
let is_in_pending_selection = || {
|
||||
editor
|
||||
.selections
|
||||
.pending
|
||||
.as_ref()
|
||||
.is_some_and(|pending_selection| {
|
||||
pending_selection
|
||||
.selection
|
||||
.range()
|
||||
.includes(&fold_range, &snapshot)
|
||||
})
|
||||
};
|
||||
|
||||
let mut is_in_complete_selection = || {
|
||||
editor
|
||||
.selections
|
||||
.disjoint_in_range::<usize>(fold_range.clone(), cx)
|
||||
.into_iter()
|
||||
.any(|selection| {
|
||||
// This is needed to cover a corner case, if we just check for an existing
|
||||
// selection in the fold range, having a cursor at the start of the fold
|
||||
// marks it as selected. Non-empty selections don't cause this.
|
||||
let length = selection.end - selection.start;
|
||||
length > 0
|
||||
})
|
||||
};
|
||||
|
||||
is_in_pending_selection() || is_in_complete_selection()
|
||||
})
|
||||
});
|
||||
let is_in_text_selection = editor
|
||||
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
|
||||
.unwrap_or_default();
|
||||
|
||||
ButtonLike::new(fold_id)
|
||||
.style(ButtonStyle::Filled)
|
||||
|
|
|
@ -72,7 +72,7 @@ pub fn init(
|
|||
let Some(window) = window else {
|
||||
return;
|
||||
};
|
||||
let workspace = cx.entity().clone();
|
||||
let workspace = cx.entity();
|
||||
InlineAssistant::update_global(cx, |inline_assistant, cx| {
|
||||
inline_assistant.register_workspace(&workspace, window, cx)
|
||||
});
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use std::{cmp::Reverse, sync::Arc};
|
||||
|
||||
use cloud_llm_client::Plan;
|
||||
use collections::{HashSet, IndexMap};
|
||||
use feature_flags::ZedProFeatureFlag;
|
||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||
|
@ -10,7 +11,6 @@ use language_model::{
|
|||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use proto::Plan;
|
||||
use ui::{ListItem, ListItemSpacing, prelude::*};
|
||||
|
||||
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
|
||||
|
@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
) -> Option<gpui::AnyElement> {
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
|
||||
let plan = proto::Plan::ZedPro;
|
||||
let plan = Plan::ZedPro;
|
||||
|
||||
Some(
|
||||
h_flex()
|
||||
|
@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
window
|
||||
.dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
|
||||
}),
|
||||
Plan::Free | Plan::ZedProTrial => Button::new(
|
||||
Plan::ZedFree | Plan::ZedProTrial => Button::new(
|
||||
"try-pro",
|
||||
if plan == Plan::ZedProTrial {
|
||||
"Upgrade to Pro"
|
||||
|
|
|
@ -163,7 +163,7 @@ impl Render for ProfileSelector {
|
|||
.unwrap_or_else(|| "Unknown".into());
|
||||
|
||||
if self.provider.profiles_supported(cx) {
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
let trigger_button = Button::new("profile-selector-model", selected_profile)
|
||||
.label_size(LabelSize::Small)
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::Duration;
|
||||
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
|
||||
use futures::{StreamExt, stream::BoxStream};
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
|
||||
use http_client::{AsyncBody, Method, Request, http};
|
||||
use parking_lot::Mutex;
|
||||
use rpc::{
|
||||
ConnectionId, Peer, Receipt, TypedEnvelope,
|
||||
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
|
||||
};
|
||||
use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct FakeServer {
|
||||
|
@ -187,7 +183,6 @@ impl FakeServer {
|
|||
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
|
||||
self.executor.start_waiting();
|
||||
|
||||
loop {
|
||||
let message = self
|
||||
.state
|
||||
.lock()
|
||||
|
@ -205,33 +200,11 @@ impl FakeServer {
|
|||
return Ok(*message.downcast().unwrap());
|
||||
}
|
||||
|
||||
let accepted_tos_at = chrono::Utc::now()
|
||||
.checked_sub_signed(Duration::hours(5))
|
||||
.expect("failed to build accepted_tos_at")
|
||||
.timestamp() as u64;
|
||||
|
||||
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
|
||||
self.respond(
|
||||
message
|
||||
.downcast::<TypedEnvelope<GetPrivateUserInfo>>()
|
||||
.unwrap()
|
||||
.receipt(),
|
||||
GetPrivateUserInfoResponse {
|
||||
metrics_id: "the-metrics-id".into(),
|
||||
staff: false,
|
||||
flags: Default::default(),
|
||||
accepted_tos_at: Some(accepted_tos_at),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
panic!(
|
||||
"fake server received unexpected message type: {:?}",
|
||||
type_name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
|
||||
self.peer.respond(receipt, response).unwrap()
|
||||
|
|
|
@ -177,7 +177,6 @@ impl UserStore {
|
|||
let (mut current_user_tx, current_user_rx) = watch::channel();
|
||||
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
|
||||
let rpc_subscriptions = vec![
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_update_plan),
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts),
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info),
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
|
||||
|
@ -343,26 +342,6 @@ impl UserStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_update_plan(
|
||||
this: Entity<Self>,
|
||||
_message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = this
|
||||
.read_with(&cx, |this, _| this.client.upgrade())?
|
||||
.context("client was dropped")?;
|
||||
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user")?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.update_authenticated_user(response, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
match message {
|
||||
UpdateContacts::Wait(barrier) => {
|
||||
|
@ -1019,19 +998,6 @@ impl RequestUsage {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option<Self> {
|
||||
let limit = match limit.variant? {
|
||||
proto::usage_limit::Variant::Limited(limited) => {
|
||||
UsageLimit::Limited(limited.limit as i32)
|
||||
}
|
||||
proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited,
|
||||
};
|
||||
Some(RequestUsage {
|
||||
limit,
|
||||
amount: amount as i32,
|
||||
})
|
||||
}
|
||||
|
||||
fn from_headers(
|
||||
limit_name: &str,
|
||||
amount_name: &str,
|
||||
|
|
|
@ -19,7 +19,6 @@ test-support = ["sqlite"]
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-stripe.workspace = true
|
||||
async-trait.workspace = true
|
||||
async-tungstenite.workspace = true
|
||||
aws-config = { version = "1.1.5" }
|
||||
|
@ -33,13 +32,11 @@ clock.workspace = true
|
|||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
livekit_api.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
|
@ -65,7 +62,6 @@ subtle.workspace = true
|
|||
supermaven_api.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
toml.workspace = true
|
||||
|
@ -136,6 +132,3 @@ util.workspace = true
|
|||
workspace = { workspace = true, features = ["test-support"] }
|
||||
worktree = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
[package.metadata.cargo-machete]
|
||||
ignored = ["async-stripe"]
|
||||
|
|
|
@ -219,12 +219,6 @@ spec:
|
|||
secretKeyRef:
|
||||
name: slack
|
||||
key: panics_webhook
|
||||
- name: STRIPE_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: stripe
|
||||
key: api_key
|
||||
optional: true
|
||||
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
|
||||
value: "1000"
|
||||
- name: SUPERMAVEN_ADMIN_API_KEY
|
||||
|
|
|
@ -1,16 +1,10 @@
|
|||
pub mod billing;
|
||||
pub mod contributors;
|
||||
pub mod events;
|
||||
pub mod extensions;
|
||||
pub mod ips_file;
|
||||
pub mod slack;
|
||||
|
||||
use crate::db::Database;
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
db::{User, UserId},
|
||||
rpc,
|
||||
};
|
||||
use crate::{AppState, Error, Result, auth, db::UserId, rpc};
|
||||
use anyhow::Context as _;
|
||||
use axum::{
|
||||
Extension, Json, Router,
|
||||
|
@ -97,7 +91,6 @@ impl std::fmt::Display for SystemIdHeader {
|
|||
|
||||
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/users/look_up", get(look_up_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
|
||||
.merge(contributors::router())
|
||||
|
@ -139,99 +132,6 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
|||
Ok::<_, Error>(next.run(req).await)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LookUpUserParams {
|
||||
identifier: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct LookUpUserResponse {
|
||||
user: Option<User>,
|
||||
}
|
||||
|
||||
async fn look_up_user(
|
||||
Query(params): Query<LookUpUserParams>,
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
) -> Result<Json<LookUpUserResponse>> {
|
||||
let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
|
||||
let user = if let Some(user) = user {
|
||||
match user {
|
||||
UserOrId::User(user) => Some(user),
|
||||
UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Json(LookUpUserResponse { user }))
|
||||
}
|
||||
|
||||
enum UserOrId {
|
||||
User(User),
|
||||
Id(UserId),
|
||||
}
|
||||
|
||||
async fn resolve_identifier_to_user(
|
||||
db: &Arc<Database>,
|
||||
identifier: &str,
|
||||
) -> Result<Option<UserOrId>> {
|
||||
if let Some(identifier) = identifier.parse::<i32>().ok() {
|
||||
let user = db.get_user_by_id(UserId(identifier)).await?;
|
||||
|
||||
return Ok(user.map(UserOrId::User));
|
||||
}
|
||||
|
||||
if identifier.starts_with("cus_") {
|
||||
let billing_customer = db
|
||||
.get_billing_customer_by_stripe_customer_id(&identifier)
|
||||
.await?;
|
||||
|
||||
return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
|
||||
}
|
||||
|
||||
if identifier.starts_with("sub_") {
|
||||
let billing_subscription = db
|
||||
.get_billing_subscription_by_stripe_subscription_id(&identifier)
|
||||
.await?;
|
||||
|
||||
if let Some(billing_subscription) = billing_subscription {
|
||||
let billing_customer = db
|
||||
.get_billing_customer_by_id(billing_subscription.billing_customer_id)
|
||||
.await?;
|
||||
|
||||
return Ok(
|
||||
billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
|
||||
);
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
if identifier.contains('@') {
|
||||
let user = db.get_user_by_email(identifier).await?;
|
||||
|
||||
return Ok(user.map(UserOrId::User));
|
||||
}
|
||||
|
||||
if let Some(user) = db.get_user_by_github_login(identifier).await? {
|
||||
return Ok(Some(UserOrId::User(user)));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct CreateUserParams {
|
||||
github_user_id: i32,
|
||||
github_login: String,
|
||||
email_address: String,
|
||||
email_confirmation_code: Option<String>,
|
||||
#[serde(default)]
|
||||
admin: bool,
|
||||
#[serde(default)]
|
||||
invite_count: i32,
|
||||
}
|
||||
|
||||
async fn get_rpc_server_snapshot(
|
||||
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||
) -> Result<ErasedJson> {
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
use stripe::SubscriptionStatus;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::StripeSubscriptionStatus;
|
||||
use crate::db::{CreateBillingCustomerParams, billing_customer};
|
||||
use crate::stripe_client::{StripeClient, StripeCustomerId};
|
||||
|
||||
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
fn from(value: SubscriptionStatus) -> Self {
|
||||
match value {
|
||||
SubscriptionStatus::Incomplete => Self::Incomplete,
|
||||
SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
|
||||
SubscriptionStatus::Trialing => Self::Trialing,
|
||||
SubscriptionStatus::Active => Self::Active,
|
||||
SubscriptionStatus::PastDue => Self::PastDue,
|
||||
SubscriptionStatus::Canceled => Self::Canceled,
|
||||
SubscriptionStatus::Unpaid => Self::Unpaid,
|
||||
SubscriptionStatus::Paused => Self::Paused,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds or creates a billing customer using the provided customer.
|
||||
pub async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &dyn StripeClient,
|
||||
customer_id: &StripeCustomerId,
|
||||
) -> anyhow::Result<Option<billing_customer::Model>> {
|
||||
// If we already have a billing customer record associated with the Stripe customer,
|
||||
// there's nothing more we need to do.
|
||||
if let Some(billing_customer) = app
|
||||
.db
|
||||
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
|
||||
.await?
|
||||
{
|
||||
return Ok(Some(billing_customer));
|
||||
}
|
||||
|
||||
let customer = stripe_client.get_customer(customer_id).await?;
|
||||
|
||||
let Some(email) = customer.email else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some(user) = app.db.get_user_by_email(&email).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let billing_customer = app
|
||||
.db
|
||||
.create_billing_customer(&CreateBillingCustomerParams {
|
||||
user_id: user.id,
|
||||
stripe_customer_id: customer.id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(Some(billing_customer))
|
||||
}
|
|
@ -564,170 +564,10 @@ fn for_snowflake(
|
|||
country_code: Option<String>,
|
||||
checksum_matched: bool,
|
||||
) -> impl Iterator<Item = SnowflakeRow> {
|
||||
body.events.into_iter().filter_map(move |event| {
|
||||
body.events.into_iter().map(move |event| {
|
||||
let timestamp =
|
||||
first_event_at + Duration::milliseconds(event.milliseconds_since_first_event);
|
||||
// We will need to double check, but I believe all of the events that
|
||||
// are being transformed here are now migrated over to use the
|
||||
// telemetry::event! macro, as of this commit so this code can go away
|
||||
// when we feel enough users have upgraded past this point.
|
||||
let (event_type, mut event_properties) = match &event.event {
|
||||
Event::Editor(e) => (
|
||||
match e.operation.as_str() {
|
||||
"open" => "Editor Opened".to_string(),
|
||||
"save" => "Editor Saved".to_string(),
|
||||
_ => format!("Unknown Editor Event: {}", e.operation),
|
||||
},
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::EditPrediction(e) => (
|
||||
format!(
|
||||
"Edit Prediction {}",
|
||||
if e.suggestion_accepted {
|
||||
"Accepted"
|
||||
} else {
|
||||
"Discarded"
|
||||
}
|
||||
),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::EditPredictionRating(e) => (
|
||||
"Edit Prediction Rated".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Call(e) => {
|
||||
let event_type = match e.operation.trim() {
|
||||
"unshare project" => "Project Unshared".to_string(),
|
||||
"open channel notes" => "Channel Notes Opened".to_string(),
|
||||
"share project" => "Project Shared".to_string(),
|
||||
"join channel" => "Channel Joined".to_string(),
|
||||
"hang up" => "Call Ended".to_string(),
|
||||
"accept incoming" => "Incoming Call Accepted".to_string(),
|
||||
"invite" => "Participant Invited".to_string(),
|
||||
"disable microphone" => "Microphone Disabled".to_string(),
|
||||
"enable microphone" => "Microphone Enabled".to_string(),
|
||||
"enable screen share" => "Screen Share Enabled".to_string(),
|
||||
"disable screen share" => "Screen Share Disabled".to_string(),
|
||||
"decline incoming" => "Incoming Call Declined".to_string(),
|
||||
_ => format!("Unknown Call Event: {}", e.operation),
|
||||
};
|
||||
|
||||
(event_type, serde_json::to_value(e).unwrap())
|
||||
}
|
||||
Event::Assistant(e) => (
|
||||
match e.phase {
|
||||
telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(),
|
||||
telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(),
|
||||
telemetry_events::AssistantPhase::Accepted => {
|
||||
"Assistant Response Accepted".to_string()
|
||||
}
|
||||
telemetry_events::AssistantPhase::Rejected => {
|
||||
"Assistant Response Rejected".to_string()
|
||||
}
|
||||
},
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Cpu(_) | Event::Memory(_) => return None,
|
||||
Event::App(e) => {
|
||||
let mut properties = json!({});
|
||||
let event_type = match e.operation.trim() {
|
||||
// App
|
||||
"open" => "App Opened".to_string(),
|
||||
"first open" => "App First Opened".to_string(),
|
||||
"first open for release channel" => {
|
||||
"App First Opened For Release Channel".to_string()
|
||||
}
|
||||
"close" => "App Closed".to_string(),
|
||||
|
||||
// Project
|
||||
"open project" => "Project Opened".to_string(),
|
||||
"open node project" => {
|
||||
properties["project_type"] = json!("node");
|
||||
"Project Opened".to_string()
|
||||
}
|
||||
"open pnpm project" => {
|
||||
properties["project_type"] = json!("pnpm");
|
||||
"Project Opened".to_string()
|
||||
}
|
||||
"open yarn project" => {
|
||||
properties["project_type"] = json!("yarn");
|
||||
"Project Opened".to_string()
|
||||
}
|
||||
|
||||
// SSH
|
||||
"create ssh server" => "SSH Server Created".to_string(),
|
||||
"create ssh project" => "SSH Project Created".to_string(),
|
||||
"open ssh project" => "SSH Project Opened".to_string(),
|
||||
|
||||
// Welcome Page
|
||||
"welcome page: change keymap" => "Welcome Keymap Changed".to_string(),
|
||||
"welcome page: change theme" => "Welcome Theme Changed".to_string(),
|
||||
"welcome page: close" => "Welcome Page Closed".to_string(),
|
||||
"welcome page: edit settings" => "Welcome Settings Edited".to_string(),
|
||||
"welcome page: install cli" => "Welcome CLI Installed".to_string(),
|
||||
"welcome page: open" => "Welcome Page Opened".to_string(),
|
||||
"welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(),
|
||||
"welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(),
|
||||
"welcome page: toggle diagnostic telemetry" => {
|
||||
"Welcome Diagnostic Telemetry Toggled".to_string()
|
||||
}
|
||||
"welcome page: toggle metric telemetry" => {
|
||||
"Welcome Metric Telemetry Toggled".to_string()
|
||||
}
|
||||
"welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(),
|
||||
"welcome page: view docs" => "Welcome Documentation Viewed".to_string(),
|
||||
|
||||
// Extensions
|
||||
"extensions page: open" => "Extensions Page Opened".to_string(),
|
||||
"extensions: install extension" => "Extension Installed".to_string(),
|
||||
"extensions: uninstall extension" => "Extension Uninstalled".to_string(),
|
||||
|
||||
// Misc
|
||||
"markdown preview: open" => "Markdown Preview Opened".to_string(),
|
||||
"project diagnostics: open" => "Project Diagnostics Opened".to_string(),
|
||||
"project search: open" => "Project Search Opened".to_string(),
|
||||
"repl sessions: open" => "REPL Session Started".to_string(),
|
||||
|
||||
// Feature Upsell
|
||||
"feature upsell: toggle vim" => {
|
||||
properties["source"] = json!("Feature Upsell");
|
||||
"Vim Mode Toggled".to_string()
|
||||
}
|
||||
_ => e
|
||||
.operation
|
||||
.strip_prefix("feature upsell: viewed docs (")
|
||||
.and_then(|s| s.strip_suffix(')'))
|
||||
.map_or_else(
|
||||
|| format!("Unknown App Event: {}", e.operation),
|
||||
|docs_url| {
|
||||
properties["url"] = json!(docs_url);
|
||||
properties["source"] = json!("Feature Upsell");
|
||||
"Documentation Viewed".to_string()
|
||||
},
|
||||
),
|
||||
};
|
||||
(event_type, properties)
|
||||
}
|
||||
Event::Setting(e) => (
|
||||
"Settings Changed".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Extension(e) => (
|
||||
"Extension Loaded".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Edit(e) => (
|
||||
"Editor Edited".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Action(e) => (
|
||||
"Action Invoked".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Repl(e) => (
|
||||
"Kernel Status Changed".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Flexible(e) => (
|
||||
e.event_type.clone(),
|
||||
serde_json::to_value(&e.event_properties).unwrap(),
|
||||
|
@ -759,7 +599,7 @@ fn for_snowflake(
|
|||
})
|
||||
});
|
||||
|
||||
Some(SnowflakeRow {
|
||||
SnowflakeRow {
|
||||
time: timestamp,
|
||||
user_id: body.metrics_id.clone(),
|
||||
device_id: body.system_id.clone(),
|
||||
|
@ -767,7 +607,7 @@ fn for_snowflake(
|
|||
event_properties,
|
||||
user_properties,
|
||||
insert_id: Some(Uuid::new_v4().to_string()),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::db::{BillingCustomerId, BillingSubscriptionId};
|
||||
use crate::stripe_client;
|
||||
use chrono::{Datelike as _, NaiveDate, Utc};
|
||||
use sea_orm::entity::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
@ -160,17 +159,3 @@ pub enum StripeCancellationReason {
|
|||
#[sea_orm(string_value = "payment_failed")]
|
||||
PaymentFailed,
|
||||
}
|
||||
|
||||
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
|
||||
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
|
||||
match value {
|
||||
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
|
||||
Self::CancellationRequested
|
||||
}
|
||||
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
|
||||
Self::PaymentDisputed
|
||||
}
|
||||
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,6 @@ pub mod llm;
|
|||
pub mod migrations;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
pub mod stripe_client;
|
||||
pub mod user_backfiller;
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -27,16 +25,12 @@ use serde::Deserialize;
|
|||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{RealStripeClient, StripeClient};
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
Http(StatusCode, String, HeaderMap),
|
||||
Database(sea_orm::error::DbErr),
|
||||
Internal(anyhow::Error),
|
||||
Stripe(stripe::StripeError),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for Error {
|
||||
|
@ -51,12 +45,6 @@ impl From<sea_orm::error::DbErr> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<stripe::StripeError> for Error {
|
||||
fn from(error: stripe::StripeError) -> Self {
|
||||
Self::Stripe(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<axum::Error> for Error {
|
||||
fn from(error: axum::Error) -> Self {
|
||||
Self::Internal(error.into())
|
||||
|
@ -104,14 +92,6 @@ impl IntoResponse for Error {
|
|||
);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
|
||||
}
|
||||
Error::Stripe(error) => {
|
||||
log::error!(
|
||||
"HTTP error {}: {:?}",
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&error
|
||||
);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -122,7 +102,6 @@ impl std::fmt::Debug for Error {
|
|||
Error::Http(code, message, _headers) => (code, message).fmt(f),
|
||||
Error::Database(error) => error.fmt(f),
|
||||
Error::Internal(error) => error.fmt(f),
|
||||
Error::Stripe(error) => error.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -133,7 +112,6 @@ impl std::fmt::Display for Error {
|
|||
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
|
||||
Error::Database(error) => error.fmt(f),
|
||||
Error::Internal(error) => error.fmt(f),
|
||||
Error::Stripe(error) => error.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -179,7 +157,6 @@ pub struct Config {
|
|||
pub zed_client_checksum_seed: Option<String>,
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
|
@ -234,7 +211,6 @@ impl Config {
|
|||
auto_join_channel_id: None,
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
|
@ -269,11 +245,6 @@ pub struct AppState {
|
|||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
/// This is a real instance of the Stripe client; we're working to replace references to this with the
|
||||
/// [`StripeClient`] trait.
|
||||
pub real_stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_client: Option<Arc<dyn StripeClient>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub executor: Executor,
|
||||
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
||||
pub config: Config,
|
||||
|
@ -316,18 +287,11 @@ impl AppState {
|
|||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
livekit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
real_stripe_client: stripe_client.clone(),
|
||||
stripe_client: stripe_client
|
||||
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
|
||||
executor,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
|
@ -340,14 +304,6 @@ impl AppState {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
let api_key = config
|
||||
.stripe_api_key
|
||||
.as_ref()
|
||||
.context("missing stripe_api_key")?;
|
||||
Ok(stripe::Client::new(api_key))
|
||||
}
|
||||
|
||||
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
|
||||
let keys = aws_sdk_s3::config::Credentials::new(
|
||||
config
|
||||
|
|
|
@ -1,12 +1 @@
|
|||
pub mod db;
|
||||
mod token;
|
||||
|
||||
pub use token::*;
|
||||
|
||||
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
|
||||
|
||||
/// The name of the feature flag that bypasses the account age check.
|
||||
pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check";
|
||||
|
||||
/// The minimum account age an account must have in order to use the LLM service.
|
||||
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
|
||||
|
|
|
@ -1,146 +0,0 @@
|
|||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::db::{billing_customer, billing_subscription, user};
|
||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
|
||||
use crate::{Config, db::billing_preference};
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use cloud_llm_client::Plan;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LlmTokenClaims {
|
||||
pub iat: u64,
|
||||
pub exp: u64,
|
||||
pub jti: String,
|
||||
pub user_id: u64,
|
||||
pub system_id: Option<String>,
|
||||
pub metrics_id: Uuid,
|
||||
pub github_user_login: String,
|
||||
pub account_created_at: NaiveDateTime,
|
||||
pub is_staff: bool,
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
pub bypass_account_age_check: bool,
|
||||
pub use_llm_request_queue: bool,
|
||||
pub plan: Plan,
|
||||
pub has_extended_trial: bool,
|
||||
pub subscription_period: (NaiveDateTime, NaiveDateTime),
|
||||
pub enable_model_request_overages: bool,
|
||||
pub model_request_overages_spend_limit_in_cents: u32,
|
||||
pub can_use_web_search_tool: bool,
|
||||
#[serde(default)]
|
||||
pub has_overdue_invoices: bool,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
pub fn create(
|
||||
user: &user::Model,
|
||||
is_staff: bool,
|
||||
billing_customer: billing_customer::Model,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
feature_flags: &Vec<String>,
|
||||
subscription: billing_subscription::Model,
|
||||
system_id: Option<String>,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
.context("no LLM API secret")?;
|
||||
|
||||
let plan = if is_staff {
|
||||
Plan::ZedPro
|
||||
} else {
|
||||
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
};
|
||||
let subscription_period =
|
||||
billing_subscription::Model::current_period(Some(subscription), is_staff)
|
||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
||||
.context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
|
||||
|
||||
let now = Utc::now();
|
||||
let claims = Self {
|
||||
iat: now.timestamp() as u64,
|
||||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user.id.to_proto(),
|
||||
system_id,
|
||||
metrics_id: user.metrics_id,
|
||||
github_user_login: user.github_login.clone(),
|
||||
account_created_at: user.account_created_at(),
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == "llm-closed-beta"),
|
||||
bypass_account_age_check: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
|
||||
can_use_web_search_tool: true,
|
||||
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
|
||||
plan,
|
||||
has_extended_trial: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
|
||||
subscription_period,
|
||||
enable_model_request_overages: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(false, |preferences| {
|
||||
preferences.model_request_overages_enabled
|
||||
}),
|
||||
model_request_overages_spend_limit_in_cents: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(0, |preferences| {
|
||||
preferences.model_request_overages_spend_limit_in_cents as u32
|
||||
}),
|
||||
has_overdue_invoices: billing_customer.has_overdue_invoices,
|
||||
};
|
||||
|
||||
Ok(jsonwebtoken::encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_ref()),
|
||||
)?)
|
||||
}
|
||||
|
||||
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
.context("no LLM API secret")?;
|
||||
|
||||
match jsonwebtoken::decode::<Self>(
|
||||
token,
|
||||
&DecodingKey::from_secret(secret.as_ref()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(token) => Ok(token.claims),
|
||||
Err(e) => {
|
||||
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
|
||||
Err(ValidateLlmTokenError::Expired)
|
||||
} else {
|
||||
Err(ValidateLlmTokenError::JwtError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidateLlmTokenError {
|
||||
#[error("access token is expired")]
|
||||
Expired,
|
||||
#[error("access token validation error: {0}")]
|
||||
JwtError(#[from] jsonwebtoken::errors::Error),
|
||||
#[error("{0}")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
|
@ -102,13 +102,6 @@ async fn main() -> Result<()> {
|
|||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
if let Some(stripe_billing) = state.stripe_billing.clone() {
|
||||
let executor = state.executor.clone();
|
||||
executor.spawn_detached(async move {
|
||||
stripe_billing.initialize().await.trace_err();
|
||||
});
|
||||
}
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
|
||||
|
|
|
@ -1,14 +1,6 @@
|
|||
mod connection_pool;
|
||||
|
||||
use crate::api::billing::find_or_create_billing_customer;
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::llm::{
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
|
||||
MIN_ACCOUNT_AGE_FOR_LLM_USE,
|
||||
};
|
||||
use crate::stripe_client::StripeCustomerId;
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
db::{
|
||||
|
@ -37,7 +29,6 @@ use axum::{
|
|||
response::IntoResponse,
|
||||
routing::get,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
|
@ -148,13 +139,6 @@ pub enum Principal {
|
|||
}
|
||||
|
||||
impl Principal {
|
||||
fn user(&self) -> &User {
|
||||
match self {
|
||||
Principal::User(user) => user,
|
||||
Principal::Impersonated { user, .. } => user,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_span(&self, span: &tracing::Span) {
|
||||
match &self {
|
||||
Principal::User(user) => {
|
||||
|
@ -218,6 +202,7 @@ struct Session {
|
|||
/// The GeoIP country code for the user.
|
||||
#[allow(unused)]
|
||||
geoip_country_code: Option<String>,
|
||||
#[allow(unused)]
|
||||
system_id: Option<String>,
|
||||
_executor: Executor,
|
||||
}
|
||||
|
@ -463,9 +448,6 @@ impl Server {
|
|||
.add_request_handler(follow)
|
||||
.add_message_handler(unfollow)
|
||||
.add_message_handler(update_followers)
|
||||
.add_request_handler(get_private_user_info)
|
||||
.add_request_handler(get_llm_api_token)
|
||||
.add_request_handler(accept_terms_of_service)
|
||||
.add_message_handler(acknowledge_channel_message)
|
||||
.add_message_handler(acknowledge_buffer_version)
|
||||
.add_request_handler(get_supermaven_api_key)
|
||||
|
@ -1000,8 +982,6 @@ impl Server {
|
|||
.await?;
|
||||
}
|
||||
|
||||
update_user_plan(session).await?;
|
||||
|
||||
let contacts = self.app_state.db.get_contacts(user.id).await?;
|
||||
|
||||
{
|
||||
|
@ -2835,214 +2815,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
|
|||
version.0.minor() < 139
|
||||
}
|
||||
|
||||
async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
|
||||
if is_staff {
|
||||
return Ok(proto::Plan::ZedPro);
|
||||
}
|
||||
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
||||
|
||||
let plan = if let Some(subscription_kind) = subscription_kind {
|
||||
match subscription_kind {
|
||||
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => proto::Plan::Free,
|
||||
}
|
||||
} else {
|
||||
proto::Plan::Free
|
||||
};
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
async fn make_update_user_plan_message(
|
||||
user: &User,
|
||||
is_staff: bool,
|
||||
db: &Arc<Database>,
|
||||
llm_db: Option<Arc<LlmDatabase>>,
|
||||
) -> Result<proto::UpdateUserPlan> {
|
||||
let feature_flags = db.get_user_flags(user.id).await?;
|
||||
let plan = current_plan(db, user.id, is_staff).await?;
|
||||
let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
|
||||
let subscription = db.get_active_billing_subscription(user.id).await?;
|
||||
|
||||
let subscription_period =
|
||||
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
|
||||
|
||||
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
|
||||
llm_db
|
||||
.get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
|
||||
.await?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(subscription_period, usage)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let bypass_account_age_check = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
|
||||
let account_too_young = !matches!(plan, proto::Plan::ZedPro)
|
||||
&& !bypass_account_age_check
|
||||
&& user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
|
||||
|
||||
Ok(proto::UpdateUserPlan {
|
||||
plan: plan.into(),
|
||||
trial_started_at: billing_customer
|
||||
.as_ref()
|
||||
.and_then(|billing_customer| billing_customer.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
|
||||
is_usage_based_billing_enabled: if is_staff {
|
||||
Some(true)
|
||||
} else {
|
||||
billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
|
||||
},
|
||||
subscription_period: subscription_period.map(|(started_at, ended_at)| {
|
||||
proto::SubscriptionPeriod {
|
||||
started_at: started_at.timestamp() as u64,
|
||||
ended_at: ended_at.timestamp() as u64,
|
||||
}
|
||||
}),
|
||||
account_too_young: Some(account_too_young),
|
||||
has_overdue_invoices: billing_customer
|
||||
.map(|billing_customer| billing_customer.has_overdue_invoices),
|
||||
usage: Some(
|
||||
usage
|
||||
.map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
|
||||
.unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn model_requests_limit(
|
||||
plan: cloud_llm_client::Plan,
|
||||
feature_flags: &Vec<String>,
|
||||
) -> cloud_llm_client::UsageLimit {
|
||||
match plan.model_requests_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
|
||||
&& feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
||||
{
|
||||
1_000
|
||||
} else {
|
||||
limit
|
||||
};
|
||||
|
||||
cloud_llm_client::UsageLimit::Limited(limit)
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
|
||||
}
|
||||
}
|
||||
|
||||
fn subscription_usage_to_proto(
|
||||
plan: proto::Plan,
|
||||
usage: crate::llm::db::subscription_usage::Model,
|
||||
feature_flags: &Vec<String>,
|
||||
) -> proto::SubscriptionUsage {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: usage.model_requests as u32,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit(plan, feature_flags) {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}),
|
||||
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_default_subscription_usage(
|
||||
plan: proto::Plan,
|
||||
feature_flags: &Vec<String>,
|
||||
) -> proto::SubscriptionUsage {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: 0,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit(plan, feature_flags) {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}),
|
||||
edit_predictions_usage_amount: 0,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_user_plan(session: &Session) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let update_user_plan = make_update_user_plan_message(
|
||||
session.principal.user(),
|
||||
session.is_staff(),
|
||||
&db.0,
|
||||
session.app_state.llm_db.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
session
|
||||
.peer
|
||||
.send(session.connection_id, update_user_plan)
|
||||
.trace_err();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subscribe_to_channels(
|
||||
_: proto::SubscribeToChannels,
|
||||
session: MessageContext,
|
||||
|
@ -4211,139 +3983,6 @@ async fn mark_notification_as_read(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the current users information
|
||||
async fn get_private_user_info(
|
||||
_request: proto::GetPrivateUserInfo,
|
||||
response: Response<proto::GetPrivateUserInfo>,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
|
||||
let user = db
|
||||
.get_user_by_id(session.user_id())
|
||||
.await?
|
||||
.context("user not found")?;
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
|
||||
response.send(proto::GetPrivateUserInfoResponse {
|
||||
metrics_id,
|
||||
staff: user.admin,
|
||||
flags,
|
||||
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Accept the terms of service (tos) on behalf of the current user
|
||||
async fn accept_terms_of_service(
|
||||
_request: proto::AcceptTermsOfService,
|
||||
response: Response<proto::AcceptTermsOfService>,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let accepted_tos_at = Utc::now();
|
||||
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
|
||||
.await?;
|
||||
|
||||
response.send(proto::AcceptTermsOfServiceResponse {
|
||||
accepted_tos_at: accepted_tos_at.timestamp() as u64,
|
||||
})?;
|
||||
|
||||
// When the user accepts the terms of service, we want to refresh their LLM
|
||||
// token to grant access.
|
||||
session
|
||||
.peer
|
||||
.send(session.connection_id, proto::RefreshLlmToken {})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_llm_api_token(
|
||||
_request: proto::GetLlmToken,
|
||||
response: Response<proto::GetLlmToken>,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
|
||||
let user_id = session.user_id();
|
||||
let user = db
|
||||
.get_user_by_id(user_id)
|
||||
.await?
|
||||
.with_context(|| format!("user {user_id} not found"))?;
|
||||
|
||||
if user.accepted_tos_at.is_none() {
|
||||
Err(anyhow!("terms of service not accepted"))?
|
||||
}
|
||||
|
||||
let stripe_client = session
|
||||
.app_state
|
||||
.stripe_client
|
||||
.as_ref()
|
||||
.context("failed to retrieve Stripe client")?;
|
||||
|
||||
let stripe_billing = session
|
||||
.app_state
|
||||
.stripe_billing
|
||||
.as_ref()
|
||||
.context("failed to retrieve Stripe billing object")?;
|
||||
|
||||
let billing_customer = if let Some(billing_customer) =
|
||||
db.get_billing_customer_by_user_id(user.id).await?
|
||||
{
|
||||
billing_customer
|
||||
} else {
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?;
|
||||
|
||||
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
|
||||
.await?
|
||||
.context("billing customer not found")?
|
||||
};
|
||||
|
||||
let billing_subscription =
|
||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||
billing_subscription
|
||||
} else {
|
||||
let stripe_customer_id =
|
||||
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||
|
||||
let stripe_subscription = stripe_billing
|
||||
.subscribe_to_zed_free(stripe_customer_id)
|
||||
.await?;
|
||||
|
||||
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
|
||||
billing_customer_id: billing_customer.id,
|
||||
kind: Some(SubscriptionKind::ZedFree),
|
||||
stripe_subscription_id: stripe_subscription.id.to_string(),
|
||||
stripe_subscription_status: stripe_subscription.status.into(),
|
||||
stripe_cancellation_reason: None,
|
||||
stripe_current_period_start: Some(stripe_subscription.current_period_start),
|
||||
stripe_current_period_end: Some(stripe_subscription.current_period_end),
|
||||
})
|
||||
.await?
|
||||
};
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
&user,
|
||||
session.is_staff(),
|
||||
billing_customer,
|
||||
billing_preferences,
|
||||
&flags,
|
||||
billing_subscription,
|
||||
session.system_id.clone(),
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
response.send(proto::GetLlmTokenResponse { token })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
|
||||
let message = match message {
|
||||
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
|
||||
|
|
|
@ -1,156 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use stripe::SubscriptionStatus;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::Result;
|
||||
use crate::stripe_client::{
|
||||
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
|
||||
StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
|
||||
StripeSubscription,
|
||||
};
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
client: Arc<dyn StripeClient>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct StripeBillingState {
|
||||
prices_by_lookup_key: HashMap<String, StripePrice>,
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client: Arc::new(RealStripeClient::new(client.clone())),
|
||||
state: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
state: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn client(&self) -> &Arc<dyn StripeClient> {
|
||||
&self.client
|
||||
}
|
||||
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
log::info!("StripeBilling: initializing");
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
let prices = self.client.list_prices().await?;
|
||||
|
||||
for price in prices {
|
||||
if let Some(lookup_key) = price.lookup_key.clone() {
|
||||
state.prices_by_lookup_key.insert(lookup_key, price);
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("StripeBilling: initialized");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
|
||||
self.find_price_id_by_lookup_key("zed-pro").await
|
||||
}
|
||||
|
||||
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
|
||||
self.find_price_id_by_lookup_key("zed-free").await
|
||||
}
|
||||
|
||||
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
|
||||
self.state
|
||||
.read()
|
||||
.await
|
||||
.prices_by_lookup_key
|
||||
.get(lookup_key)
|
||||
.map(|price| price.id.clone())
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
|
||||
self.state
|
||||
.read()
|
||||
.await
|
||||
.prices_by_lookup_key
|
||||
.get(lookup_key)
|
||||
.cloned()
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
|
||||
/// not already exist.
|
||||
///
|
||||
/// Always returns a new Stripe customer if the email address is `None`.
|
||||
pub async fn find_or_create_customer_by_email(
|
||||
&self,
|
||||
email_address: Option<&str>,
|
||||
) -> Result<StripeCustomerId> {
|
||||
let existing_customer = if let Some(email) = email_address {
|
||||
let customers = self.client.list_customers_by_email(email).await?;
|
||||
|
||||
customers.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = self
|
||||
.client
|
||||
.create_customer(crate::stripe_client::CreateCustomerParams {
|
||||
email: email_address,
|
||||
})
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
};
|
||||
|
||||
Ok(customer_id)
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_zed_free(
|
||||
&self,
|
||||
customer_id: StripeCustomerId,
|
||||
) -> Result<StripeSubscription> {
|
||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||
|
||||
let existing_subscriptions = self
|
||||
.client
|
||||
.list_subscriptions_for_customer(&customer_id)
|
||||
.await?;
|
||||
|
||||
let existing_active_subscription =
|
||||
existing_subscriptions.into_iter().find(|subscription| {
|
||||
subscription.status == SubscriptionStatus::Active
|
||||
|| subscription.status == SubscriptionStatus::Trialing
|
||||
});
|
||||
if let Some(subscription) = existing_active_subscription {
|
||||
return Ok(subscription);
|
||||
}
|
||||
|
||||
let params = StripeCreateSubscriptionParams {
|
||||
customer: customer_id,
|
||||
items: vec![StripeCreateSubscriptionItems {
|
||||
price: Some(zed_free_price_id),
|
||||
quantity: Some(1),
|
||||
}],
|
||||
automatic_tax: Some(StripeAutomaticTax { enabled: true }),
|
||||
};
|
||||
|
||||
let subscription = self.client.create_subscription(params).await?;
|
||||
|
||||
Ok(subscription)
|
||||
}
|
||||
}
|
|
@ -1,285 +0,0 @@
|
|||
#[cfg(test)]
|
||||
mod fake_stripe_client;
|
||||
mod real_stripe_client;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use fake_stripe_client::*;
|
||||
pub use real_stripe_client::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
|
||||
pub struct StripeCustomerId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripeCustomer {
|
||||
pub id: StripeCustomerId,
|
||||
pub email: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateCustomerParams<'a> {
|
||||
pub email: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UpdateCustomerParams<'a> {
|
||||
pub email: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||
pub struct StripeSubscriptionId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeSubscription {
|
||||
pub id: StripeSubscriptionId,
|
||||
pub customer: StripeCustomerId,
|
||||
// TODO: Create our own version of this enum.
|
||||
pub status: stripe::SubscriptionStatus,
|
||||
pub current_period_end: i64,
|
||||
pub current_period_start: i64,
|
||||
pub items: Vec<StripeSubscriptionItem>,
|
||||
pub cancel_at: Option<i64>,
|
||||
pub cancellation_details: Option<StripeCancellationDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||
pub struct StripeSubscriptionItemId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeSubscriptionItem {
|
||||
pub id: StripeSubscriptionItemId,
|
||||
pub price: Option<StripePrice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct StripeCancellationDetails {
|
||||
pub reason: Option<StripeCancellationDetailsReason>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCancellationDetailsReason {
|
||||
CancellationRequested,
|
||||
PaymentDisputed,
|
||||
PaymentFailed,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StripeCreateSubscriptionParams {
|
||||
pub customer: StripeCustomerId,
|
||||
pub items: Vec<StripeCreateSubscriptionItems>,
|
||||
pub automatic_tax: Option<StripeAutomaticTax>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StripeCreateSubscriptionItems {
|
||||
pub price: Option<StripePriceId>,
|
||||
pub quantity: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateSubscriptionParams {
|
||||
pub items: Option<Vec<UpdateSubscriptionItems>>,
|
||||
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct UpdateSubscriptionItems {
|
||||
pub price: Option<StripePriceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeSubscriptionTrialSettings {
|
||||
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeSubscriptionTrialSettingsEndBehavior {
|
||||
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
|
||||
Cancel,
|
||||
CreateInvoice,
|
||||
Pause,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||
pub struct StripePriceId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripePrice {
|
||||
pub id: StripePriceId,
|
||||
pub unit_amount: Option<i64>,
|
||||
pub lookup_key: Option<String>,
|
||||
pub recurring: Option<StripePriceRecurring>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripePriceRecurring {
|
||||
pub meter: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
|
||||
pub struct StripeMeterId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct StripeMeter {
|
||||
pub id: StripeMeterId,
|
||||
pub event_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct StripeCreateMeterEventParams<'a> {
|
||||
pub identifier: &'a str,
|
||||
pub event_name: &'a str,
|
||||
pub payload: StripeCreateMeterEventPayload<'a>,
|
||||
pub timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct StripeCreateMeterEventPayload<'a> {
|
||||
pub value: u64,
|
||||
pub stripe_customer_id: &'a StripeCustomerId,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeBillingAddressCollection {
|
||||
Auto,
|
||||
Required,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeCustomerUpdate {
|
||||
pub address: Option<StripeCustomerUpdateAddress>,
|
||||
pub name: Option<StripeCustomerUpdateName>,
|
||||
pub shipping: Option<StripeCustomerUpdateShipping>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCustomerUpdateAddress {
|
||||
Auto,
|
||||
Never,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCustomerUpdateName {
|
||||
Auto,
|
||||
Never,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCustomerUpdateShipping {
|
||||
Auto,
|
||||
Never,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct StripeCreateCheckoutSessionParams<'a> {
|
||||
pub customer: Option<&'a StripeCustomerId>,
|
||||
pub client_reference_id: Option<&'a str>,
|
||||
pub mode: Option<StripeCheckoutSessionMode>,
|
||||
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
|
||||
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
|
||||
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
|
||||
pub success_url: Option<&'a str>,
|
||||
pub billing_address_collection: Option<StripeBillingAddressCollection>,
|
||||
pub customer_update: Option<StripeCustomerUpdate>,
|
||||
pub tax_id_collection: Option<StripeTaxIdCollection>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCheckoutSessionMode {
|
||||
Payment,
|
||||
Setup,
|
||||
Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeCreateCheckoutSessionLineItems {
|
||||
pub price: Option<String>,
|
||||
pub quantity: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCheckoutSessionPaymentMethodCollection {
|
||||
Always,
|
||||
IfRequired,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeCreateCheckoutSessionSubscriptionData {
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
pub trial_period_days: Option<u32>,
|
||||
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeTaxIdCollection {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripeAutomaticTax {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StripeCheckoutSession {
|
||||
pub url: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait StripeClient: Send + Sync {
|
||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||
|
||||
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
|
||||
|
||||
async fn update_customer(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
params: UpdateCustomerParams<'_>,
|
||||
) -> Result<StripeCustomer>;
|
||||
|
||||
async fn list_subscriptions_for_customer(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
) -> Result<Vec<StripeSubscription>>;
|
||||
|
||||
async fn get_subscription(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
) -> Result<StripeSubscription>;
|
||||
|
||||
async fn create_subscription(
|
||||
&self,
|
||||
params: StripeCreateSubscriptionParams,
|
||||
) -> Result<StripeSubscription>;
|
||||
|
||||
async fn update_subscription(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
params: UpdateSubscriptionParams,
|
||||
) -> Result<()>;
|
||||
|
||||
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
||||
|
||||
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
|
||||
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
params: StripeCreateCheckoutSessionParams<'_>,
|
||||
) -> Result<StripeCheckoutSession>;
|
||||
}
|
|
@ -1,247 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use parking_lot::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession,
|
||||
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
|
||||
StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
|
||||
StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection,
|
||||
UpdateCustomerParams, UpdateSubscriptionParams,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripeCreateMeterEventCall {
|
||||
pub identifier: Arc<str>,
|
||||
pub event_name: Arc<str>,
|
||||
pub value: u64,
|
||||
pub stripe_customer_id: StripeCustomerId,
|
||||
pub timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripeCreateCheckoutSessionCall {
|
||||
pub customer: Option<StripeCustomerId>,
|
||||
pub client_reference_id: Option<String>,
|
||||
pub mode: Option<StripeCheckoutSessionMode>,
|
||||
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
|
||||
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
|
||||
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
|
||||
pub success_url: Option<String>,
|
||||
pub billing_address_collection: Option<StripeBillingAddressCollection>,
|
||||
pub customer_update: Option<StripeCustomerUpdate>,
|
||||
pub tax_id_collection: Option<StripeTaxIdCollection>,
|
||||
}
|
||||
|
||||
pub struct FakeStripeClient {
|
||||
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
||||
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
|
||||
pub update_subscription_calls:
|
||||
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
|
||||
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
|
||||
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
|
||||
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
|
||||
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
|
||||
}
|
||||
|
||||
impl FakeStripeClient {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
customers: Arc::new(Mutex::new(HashMap::default())),
|
||||
subscriptions: Arc::new(Mutex::new(HashMap::default())),
|
||||
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
|
||||
prices: Arc::new(Mutex::new(HashMap::default())),
|
||||
meters: Arc::new(Mutex::new(HashMap::default())),
|
||||
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
|
||||
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StripeClient for FakeStripeClient {
|
||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
|
||||
Ok(self
|
||||
.customers
|
||||
.lock()
|
||||
.values()
|
||||
.filter(|customer| customer.email.as_deref() == Some(email))
|
||||
.cloned()
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
|
||||
self.customers
|
||||
.lock()
|
||||
.get(customer_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
|
||||
}
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||
let customer = StripeCustomer {
|
||||
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
|
||||
email: params.email.map(|email| email.to_string()),
|
||||
};
|
||||
|
||||
self.customers
|
||||
.lock()
|
||||
.insert(customer.id.clone(), customer.clone());
|
||||
|
||||
Ok(customer)
|
||||
}
|
||||
|
||||
async fn update_customer(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
params: UpdateCustomerParams<'_>,
|
||||
) -> Result<StripeCustomer> {
|
||||
let mut customers = self.customers.lock();
|
||||
if let Some(customer) = customers.get_mut(customer_id) {
|
||||
if let Some(email) = params.email {
|
||||
customer.email = Some(email.to_string());
|
||||
}
|
||||
Ok(customer.clone())
|
||||
} else {
|
||||
Err(anyhow!("no customer found for {customer_id:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_subscriptions_for_customer(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
) -> Result<Vec<StripeSubscription>> {
|
||||
let subscriptions = self
|
||||
.subscriptions
|
||||
.lock()
|
||||
.values()
|
||||
.filter(|subscription| subscription.customer == *customer_id)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
Ok(subscriptions)
|
||||
}
|
||||
|
||||
async fn get_subscription(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
) -> Result<StripeSubscription> {
|
||||
self.subscriptions
|
||||
.lock()
|
||||
.get(subscription_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
|
||||
}
|
||||
|
||||
async fn create_subscription(
|
||||
&self,
|
||||
params: StripeCreateSubscriptionParams,
|
||||
) -> Result<StripeSubscription> {
|
||||
let now = Utc::now();
|
||||
|
||||
let subscription = StripeSubscription {
|
||||
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
|
||||
customer: params.customer,
|
||||
status: stripe::SubscriptionStatus::Active,
|
||||
current_period_start: now.timestamp(),
|
||||
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||
items: params
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|item| StripeSubscriptionItem {
|
||||
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
|
||||
price: item
|
||||
.price
|
||||
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
|
||||
})
|
||||
.collect(),
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
|
||||
self.subscriptions
|
||||
.lock()
|
||||
.insert(subscription.id.clone(), subscription.clone());
|
||||
|
||||
Ok(subscription)
|
||||
}
|
||||
|
||||
async fn update_subscription(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
params: UpdateSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
let subscription = self.get_subscription(subscription_id).await?;
|
||||
|
||||
self.update_subscription_calls
|
||||
.lock()
|
||||
.push((subscription.id, params));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
|
||||
// TODO: Implement fake subscription cancellation.
|
||||
let _ = subscription_id;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||
let prices = self.prices.lock().values().cloned().collect();
|
||||
|
||||
Ok(prices)
|
||||
}
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
|
||||
let meters = self.meters.lock().values().cloned().collect();
|
||||
|
||||
Ok(meters)
|
||||
}
|
||||
|
||||
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
|
||||
self.create_meter_event_calls
|
||||
.lock()
|
||||
.push(StripeCreateMeterEventCall {
|
||||
identifier: params.identifier.into(),
|
||||
event_name: params.event_name.into(),
|
||||
value: params.payload.value,
|
||||
stripe_customer_id: params.payload.stripe_customer_id.clone(),
|
||||
timestamp: params.timestamp,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
params: StripeCreateCheckoutSessionParams<'_>,
|
||||
) -> Result<StripeCheckoutSession> {
|
||||
self.create_checkout_session_calls
|
||||
.lock()
|
||||
.push(StripeCreateCheckoutSessionCall {
|
||||
customer: params.customer.cloned(),
|
||||
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
|
||||
mode: params.mode,
|
||||
line_items: params.line_items,
|
||||
payment_method_collection: params.payment_method_collection,
|
||||
subscription_data: params.subscription_data,
|
||||
success_url: params.success_url.map(|url| url.to_string()),
|
||||
billing_address_collection: params.billing_address_collection,
|
||||
customer_update: params.customer_update,
|
||||
tax_id_collection: params.tax_id_collection,
|
||||
});
|
||||
|
||||
Ok(StripeCheckoutSession {
|
||||
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,612 +0,0 @@
|
|||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::{
|
||||
CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
|
||||
CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price,
|
||||
PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId,
|
||||
UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings,
|
||||
UpdateSubscriptionTrialSettingsEndBehavior,
|
||||
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
};
|
||||
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection,
|
||||
StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession,
|
||||
StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate,
|
||||
StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping,
|
||||
StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription,
|
||||
StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
|
||||
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection,
|
||||
UpdateCustomerParams, UpdateSubscriptionParams,
|
||||
};
|
||||
|
||||
pub struct RealStripeClient {
|
||||
client: Arc<stripe::Client>,
|
||||
}
|
||||
|
||||
impl RealStripeClient {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StripeClient for RealStripeClient {
|
||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
|
||||
let response = Customer::list(
|
||||
&self.client,
|
||||
&ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(StripeCustomer::from)
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
|
||||
let customer_id = customer_id.try_into()?;
|
||||
|
||||
let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
|
||||
|
||||
Ok(StripeCustomer::from(customer))
|
||||
}
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||
let customer = Customer::create(
|
||||
&self.client,
|
||||
CreateCustomer {
|
||||
email: params.email,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(StripeCustomer::from(customer))
|
||||
}
|
||||
|
||||
async fn update_customer(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
params: UpdateCustomerParams<'_>,
|
||||
) -> Result<StripeCustomer> {
|
||||
let customer = Customer::update(
|
||||
&self.client,
|
||||
&customer_id.try_into()?,
|
||||
UpdateCustomer {
|
||||
email: params.email,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(StripeCustomer::from(customer))
|
||||
}
|
||||
|
||||
async fn list_subscriptions_for_customer(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
) -> Result<Vec<StripeSubscription>> {
|
||||
let customer_id = customer_id.try_into()?;
|
||||
|
||||
let subscriptions = stripe::Subscription::list(
|
||||
&self.client,
|
||||
&stripe::ListSubscriptions {
|
||||
customer: Some(customer_id),
|
||||
status: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(subscriptions
|
||||
.data
|
||||
.into_iter()
|
||||
.map(StripeSubscription::from)
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_subscription(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
) -> Result<StripeSubscription> {
|
||||
let subscription_id = subscription_id.try_into()?;
|
||||
|
||||
let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
Ok(StripeSubscription::from(subscription))
|
||||
}
|
||||
|
||||
async fn create_subscription(
|
||||
&self,
|
||||
params: StripeCreateSubscriptionParams,
|
||||
) -> Result<StripeSubscription> {
|
||||
let customer_id = params.customer.try_into()?;
|
||||
|
||||
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
|
||||
create_subscription.items = Some(
|
||||
params
|
||||
.items
|
||||
.into_iter()
|
||||
.map(|item| stripe::CreateSubscriptionItems {
|
||||
price: item.price.map(|price| price.to_string()),
|
||||
quantity: item.quantity,
|
||||
..Default::default()
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
create_subscription.automatic_tax = params.automatic_tax.map(Into::into);
|
||||
|
||||
let subscription = Subscription::create(&self.client, create_subscription).await?;
|
||||
|
||||
Ok(StripeSubscription::from(subscription))
|
||||
}
|
||||
|
||||
async fn update_subscription(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
params: UpdateSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
let subscription_id = subscription_id.try_into()?;
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
&subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: params.items.map(|items| {
|
||||
items
|
||||
.into_iter()
|
||||
.map(|item| UpdateSubscriptionItems {
|
||||
price: item.price.map(|price| price.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
trial_settings: params.trial_settings.map(Into::into),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
|
||||
let subscription_id = subscription_id.try_into()?;
|
||||
|
||||
Subscription::cancel(
|
||||
&self.client,
|
||||
&subscription_id,
|
||||
stripe::CancelSubscription {
|
||||
invoice_now: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||
let response = stripe::Price::list(
|
||||
&self.client,
|
||||
&stripe::ListPrices {
|
||||
limit: Some(100),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response.data.into_iter().map(StripePrice::from).collect())
|
||||
}
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
limit: Option<u64>,
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get_query::<stripe::List<StripeMeter>, _>(
|
||||
"/billing/meters",
|
||||
Params { limit: Some(100) },
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response.data)
|
||||
}
|
||||
|
||||
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
|
||||
#[derive(Deserialize)]
|
||||
struct StripeMeterEvent {
|
||||
pub identifier: String,
|
||||
}
|
||||
|
||||
let identifier = params.identifier;
|
||||
match self
|
||||
.client
|
||||
.post_form::<StripeMeterEvent, _>("/billing/meter_events", params)
|
||||
.await
|
||||
{
|
||||
Ok(_event) => Ok(()),
|
||||
Err(stripe::StripeError::Stripe(error)) => {
|
||||
if error.http_status == 400
|
||||
&& error
|
||||
.message
|
||||
.as_ref()
|
||||
.map_or(false, |message| message.contains(identifier))
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!(stripe::StripeError::Stripe(error)))
|
||||
}
|
||||
}
|
||||
Err(error) => Err(anyhow!("failed to create meter event: {error:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
params: StripeCreateCheckoutSessionParams<'_>,
|
||||
) -> Result<StripeCheckoutSession> {
|
||||
let params = params.try_into()?;
|
||||
let session = CheckoutSession::create(&self.client, params).await?;
|
||||
|
||||
Ok(session.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CustomerId> for StripeCustomerId {
|
||||
fn from(value: CustomerId) -> Self {
|
||||
Self(value.as_str().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<StripeCustomerId> for CustomerId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
|
||||
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&StripeCustomerId> for CustomerId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
|
||||
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Customer> for StripeCustomer {
|
||||
fn from(value: Customer) -> Self {
|
||||
StripeCustomer {
|
||||
id: value.id.into(),
|
||||
email: value.email,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SubscriptionId> for StripeSubscriptionId {
|
||||
fn from(value: SubscriptionId) -> Self {
|
||||
Self(value.as_str().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
|
||||
Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Subscription> for StripeSubscription {
|
||||
fn from(value: Subscription) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
customer: value.customer.id().into(),
|
||||
status: value.status,
|
||||
current_period_start: value.current_period_start,
|
||||
current_period_end: value.current_period_end,
|
||||
items: value.items.data.into_iter().map(Into::into).collect(),
|
||||
cancel_at: value.cancel_at,
|
||||
cancellation_details: value.cancellation_details.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CancellationDetails> for StripeCancellationDetails {
|
||||
fn from(value: CancellationDetails) -> Self {
|
||||
Self {
|
||||
reason: value.reason.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
|
||||
fn from(value: CancellationDetailsReason) -> Self {
|
||||
match value {
|
||||
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
|
||||
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
|
||||
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SubscriptionItemId> for StripeSubscriptionItemId {
|
||||
fn from(value: SubscriptionItemId) -> Self {
|
||||
Self(value.as_str().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SubscriptionItem> for StripeSubscriptionItem {
|
||||
fn from(value: SubscriptionItem) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
price: value.price.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeAutomaticTax> for CreateSubscriptionAutomaticTax {
|
||||
fn from(value: StripeAutomaticTax) -> Self {
|
||||
Self {
|
||||
enabled: value.enabled,
|
||||
liability: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
|
||||
fn from(value: StripeSubscriptionTrialSettings) -> Self {
|
||||
Self {
|
||||
end_behavior: value.end_behavior.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehavior>
|
||||
for UpdateSubscriptionTrialSettingsEndBehavior
|
||||
{
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||
Self {
|
||||
missing_payment_method: value.missing_payment_method.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
|
||||
{
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
|
||||
match value {
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||
Self::CreateInvoice
|
||||
}
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PriceId> for StripePriceId {
|
||||
fn from(value: PriceId) -> Self {
|
||||
Self(value.as_str().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<StripePriceId> for PriceId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
|
||||
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Price> for StripePrice {
|
||||
fn from(value: Price) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
unit_amount: value.unit_amount,
|
||||
lookup_key: value.lookup_key,
|
||||
recurring: value.recurring.map(StripePriceRecurring::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Recurring> for StripePriceRecurring {
|
||||
fn from(value: Recurring) -> Self {
|
||||
Self { meter: value.meter }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
customer: value
|
||||
.customer
|
||||
.map(|customer_id| customer_id.try_into())
|
||||
.transpose()?,
|
||||
client_reference_id: value.client_reference_id,
|
||||
mode: value.mode.map(Into::into),
|
||||
line_items: value
|
||||
.line_items
|
||||
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
|
||||
payment_method_collection: value.payment_method_collection.map(Into::into),
|
||||
subscription_data: value.subscription_data.map(Into::into),
|
||||
success_url: value.success_url,
|
||||
billing_address_collection: value.billing_address_collection.map(Into::into),
|
||||
customer_update: value.customer_update.map(Into::into),
|
||||
tax_id_collection: value.tax_id_collection.map(Into::into),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
|
||||
fn from(value: StripeCheckoutSessionMode) -> Self {
|
||||
match value {
|
||||
StripeCheckoutSessionMode::Payment => Self::Payment,
|
||||
StripeCheckoutSessionMode::Setup => Self::Setup,
|
||||
StripeCheckoutSessionMode::Subscription => Self::Subscription,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
|
||||
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
|
||||
Self {
|
||||
price: value.price,
|
||||
quantity: value.quantity,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
|
||||
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
|
||||
match value {
|
||||
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
|
||||
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
|
||||
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
|
||||
Self {
|
||||
trial_period_days: value.trial_period_days,
|
||||
trial_settings: value.trial_settings.map(Into::into),
|
||||
metadata: value.metadata,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
|
||||
fn from(value: StripeSubscriptionTrialSettings) -> Self {
|
||||
Self {
|
||||
end_behavior: value.end_behavior.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehavior>
|
||||
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
|
||||
{
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||
Self {
|
||||
missing_payment_method: value.missing_payment_method.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
|
||||
{
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
|
||||
match value {
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||
Self::CreateInvoice
|
||||
}
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CheckoutSession> for StripeCheckoutSession {
|
||||
fn from(value: CheckoutSession) -> Self {
|
||||
Self { url: value.url }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeBillingAddressCollection> for stripe::CheckoutSessionBillingAddressCollection {
|
||||
fn from(value: StripeBillingAddressCollection) -> Self {
|
||||
match value {
|
||||
StripeBillingAddressCollection::Auto => {
|
||||
stripe::CheckoutSessionBillingAddressCollection::Auto
|
||||
}
|
||||
StripeBillingAddressCollection::Required => {
|
||||
stripe::CheckoutSessionBillingAddressCollection::Required
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCustomerUpdateAddress> for stripe::CreateCheckoutSessionCustomerUpdateAddress {
|
||||
fn from(value: StripeCustomerUpdateAddress) -> Self {
|
||||
match value {
|
||||
StripeCustomerUpdateAddress::Auto => {
|
||||
stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto
|
||||
}
|
||||
StripeCustomerUpdateAddress::Never => {
|
||||
stripe::CreateCheckoutSessionCustomerUpdateAddress::Never
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCustomerUpdateName> for stripe::CreateCheckoutSessionCustomerUpdateName {
|
||||
fn from(value: StripeCustomerUpdateName) -> Self {
|
||||
match value {
|
||||
StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto,
|
||||
StripeCustomerUpdateName::Never => {
|
||||
stripe::CreateCheckoutSessionCustomerUpdateName::Never
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCustomerUpdateShipping> for stripe::CreateCheckoutSessionCustomerUpdateShipping {
|
||||
fn from(value: StripeCustomerUpdateShipping) -> Self {
|
||||
match value {
|
||||
StripeCustomerUpdateShipping::Auto => {
|
||||
stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto
|
||||
}
|
||||
StripeCustomerUpdateShipping::Never => {
|
||||
stripe::CreateCheckoutSessionCustomerUpdateShipping::Never
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCustomerUpdate> for stripe::CreateCheckoutSessionCustomerUpdate {
|
||||
fn from(value: StripeCustomerUpdate) -> Self {
|
||||
stripe::CreateCheckoutSessionCustomerUpdate {
|
||||
address: value.address.map(Into::into),
|
||||
name: value.name.map(Into::into),
|
||||
shipping: value.shipping.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeTaxIdCollection> for stripe::CreateCheckoutSessionTaxIdCollection {
|
||||
fn from(value: StripeTaxIdCollection) -> Self {
|
||||
stripe::CreateCheckoutSessionTaxIdCollection {
|
||||
enabled: value.enabled,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,7 +8,6 @@ mod channel_buffer_tests;
|
|||
mod channel_guest_tests;
|
||||
mod channel_message_tests;
|
||||
mod channel_tests;
|
||||
// mod debug_panel_tests;
|
||||
mod editor_tests;
|
||||
mod following_tests;
|
||||
mod git_tests;
|
||||
|
@ -18,7 +17,6 @@ mod random_channel_buffer_tests;
|
|||
mod random_project_collaboration_tests;
|
||||
mod randomized_test_helpers;
|
||||
mod remote_editing_collaboration_tests;
|
||||
mod stripe_billing_tests;
|
||||
mod test_server;
|
||||
|
||||
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
|
|
|
@ -1,123 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring};
|
||||
|
||||
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||
let stripe_client = Arc::new(FakeStripeClient::new());
|
||||
let stripe_billing = StripeBilling::test(stripe_client.clone());
|
||||
|
||||
(stripe_billing, stripe_client)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_initialize() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
// Add test prices
|
||||
let price1 = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(1_000),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
let price2 = StripePrice {
|
||||
id: StripePriceId("price_2".into()),
|
||||
unit_amount: Some(0),
|
||||
lookup_key: Some("zed-free".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
let price3 = StripePrice {
|
||||
id: StripePriceId("price_3".into()),
|
||||
unit_amount: Some(500),
|
||||
lookup_key: None,
|
||||
recurring: Some(StripePriceRecurring {
|
||||
meter: Some("meter_1".to_string()),
|
||||
}),
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price1.id.clone(), price1);
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price2.id.clone(), price2);
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price3.id.clone(), price3);
|
||||
|
||||
// Initialize the billing system
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
// Verify that prices can be found by lookup key
|
||||
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
|
||||
assert_eq!(zed_pro_price_id.to_string(), "price_1");
|
||||
|
||||
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
|
||||
assert_eq!(zed_free_price_id.to_string(), "price_2");
|
||||
|
||||
// Verify that a price can be found by lookup key
|
||||
let zed_pro_price = stripe_billing
|
||||
.find_price_by_lookup_key("zed-pro")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(zed_pro_price.id.to_string(), "price_1");
|
||||
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
|
||||
|
||||
// Verify that finding a non-existent lookup key returns an error
|
||||
let result = stripe_billing
|
||||
.find_price_by_lookup_key("non-existent")
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_find_or_create_customer_by_email() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
// Create a customer with an email that doesn't yet correspond to a customer.
|
||||
{
|
||||
let email = "user@example.com";
|
||||
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(Some(email))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let customer = stripe_client
|
||||
.customers
|
||||
.lock()
|
||||
.get(&customer_id)
|
||||
.unwrap()
|
||||
.clone();
|
||||
assert_eq!(customer.email.as_deref(), Some(email));
|
||||
}
|
||||
|
||||
// Create a customer with an email that corresponds to an existing customer.
|
||||
{
|
||||
let email = "user2@example.com";
|
||||
|
||||
let existing_customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(Some(email))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(Some(email))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(customer_id, existing_customer_id);
|
||||
|
||||
let customer = stripe_client
|
||||
.customers
|
||||
.lock()
|
||||
.get(&customer_id)
|
||||
.unwrap()
|
||||
.clone();
|
||||
assert_eq!(customer.email.as_deref(), Some(email));
|
||||
}
|
||||
}
|
|
@ -1,4 +1,3 @@
|
|||
use crate::stripe_client::FakeStripeClient;
|
||||
use crate::{
|
||||
AppState, Config,
|
||||
db::{NewUserParams, UserId, tests::TestDb},
|
||||
|
@ -569,9 +568,6 @@ impl TestServer {
|
|||
llm_db: None,
|
||||
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
real_stripe_client: None,
|
||||
stripe_client: Some(Arc::new(FakeStripeClient::new())),
|
||||
stripe_billing: None,
|
||||
executor,
|
||||
kinesis_client: None,
|
||||
config: Config {
|
||||
|
@ -608,7 +604,6 @@ impl TestServer {
|
|||
auto_join_channel_id: None,
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
|
|
|
@ -674,7 +674,7 @@ impl ChatPanel {
|
|||
})
|
||||
})
|
||||
.when_some(message_id, |el, message_id| {
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
|
||||
el.child(
|
||||
self.render_popover_button(
|
||||
|
|
|
@ -95,7 +95,7 @@ pub fn init(cx: &mut App) {
|
|||
.and_then(|room| room.read(cx).channel_id());
|
||||
|
||||
if let Some(channel_id) = channel_id {
|
||||
let workspace = cx.entity().clone();
|
||||
let workspace = cx.entity();
|
||||
window.defer(cx, move |window, cx| {
|
||||
ChannelView::open(channel_id, None, workspace, window, cx)
|
||||
.detach_and_log_err(cx)
|
||||
|
@ -1142,7 +1142,7 @@ impl CollabPanel {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
if !(role == proto::ChannelRole::Guest
|
||||
|| role == proto::ChannelRole::Talker
|
||||
|| role == proto::ChannelRole::Member)
|
||||
|
@ -1272,7 +1272,7 @@ impl CollabPanel {
|
|||
.channel_for_id(clipboard.channel_id)
|
||||
.map(|channel| channel.name.clone())
|
||||
});
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
|
||||
let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| {
|
||||
if self.has_subchannels(ix) {
|
||||
|
@ -1439,7 +1439,7 @@ impl CollabPanel {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
let in_room = ActiveCall::global(cx).read(cx).room().is_some();
|
||||
|
||||
let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| {
|
||||
|
|
|
@ -586,7 +586,7 @@ impl ChannelModalDelegate {
|
|||
return;
|
||||
};
|
||||
let user_id = membership.user.id;
|
||||
let picker = cx.entity().clone();
|
||||
let picker = cx.entity();
|
||||
let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| {
|
||||
let role = membership.role;
|
||||
|
||||
|
|
|
@ -321,7 +321,7 @@ impl NotificationPanel {
|
|||
.justify_end()
|
||||
.child(Button::new("decline", "Decline").on_click({
|
||||
let notification = notification.clone();
|
||||
let entity = cx.entity().clone();
|
||||
let entity = cx.entity();
|
||||
move |_, _, cx| {
|
||||
entity.update(cx, |this, cx| {
|
||||
this.respond_to_notification(
|
||||
|
@ -334,7 +334,7 @@ impl NotificationPanel {
|
|||
}))
|
||||
.child(Button::new("accept", "Accept").on_click({
|
||||
let notification = notification.clone();
|
||||
let entity = cx.entity().clone();
|
||||
let entity = cx.entity();
|
||||
move |_, _, cx| {
|
||||
entity.update(cx, |this, cx| {
|
||||
this.respond_to_notification(
|
||||
|
|
|
@ -291,7 +291,7 @@ pub(crate) fn new_debugger_pane(
|
|||
let Some(project) = project.upgrade() else {
|
||||
return ControlFlow::Break(());
|
||||
};
|
||||
let this_pane = cx.entity().clone();
|
||||
let this_pane = cx.entity();
|
||||
let item = if tab.pane == this_pane {
|
||||
pane.item_for_index(tab.ix)
|
||||
} else {
|
||||
|
@ -502,7 +502,7 @@ pub(crate) fn new_debugger_pane(
|
|||
.on_drag(
|
||||
DraggedTab {
|
||||
item: item.boxed_clone(),
|
||||
pane: cx.entity().clone(),
|
||||
pane: cx.entity(),
|
||||
detail: 0,
|
||||
is_active: selected,
|
||||
ix,
|
||||
|
|
|
@ -971,7 +971,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext)
|
|||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
let lsp_store =
|
||||
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
|
||||
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
|
||||
|
||||
cx.set_state(indoc! {"
|
||||
ˇfn func(abc def: i32) -> u32 {
|
||||
|
@ -1065,7 +1065,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) {
|
|||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
let lsp_store =
|
||||
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
|
||||
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
|
||||
|
||||
cx.set_state(indoc! {"
|
||||
ˇfn func(abc def: i32) -> u32 {
|
||||
|
@ -1239,7 +1239,7 @@ async fn test_diagnostics_with_links(cx: &mut TestAppContext) {
|
|||
}
|
||||
"});
|
||||
let lsp_store =
|
||||
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
|
||||
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
|
||||
|
||||
cx.update(|_, cx| {
|
||||
lsp_store.update(cx, |lsp_store, cx| {
|
||||
|
@ -1293,7 +1293,7 @@ async fn test_hover_diagnostic_and_info_popovers(cx: &mut gpui::TestAppContext)
|
|||
fn «test»() { println!(); }
|
||||
"});
|
||||
let lsp_store =
|
||||
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
|
||||
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
|
||||
cx.update(|_, cx| {
|
||||
lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store.update_diagnostics(
|
||||
|
@ -1450,7 +1450,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) {
|
|||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
let lsp_store =
|
||||
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
|
||||
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
|
||||
|
||||
cx.set_state(indoc! {"error warning info hiˇnt"});
|
||||
|
||||
|
|
|
@ -127,7 +127,7 @@ impl Render for EditPredictionButton {
|
|||
}),
|
||||
);
|
||||
}
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
|
||||
div().child(
|
||||
PopoverMenu::new("copilot")
|
||||
|
@ -182,7 +182,7 @@ impl Render for EditPredictionButton {
|
|||
let icon = status.to_icon();
|
||||
let tooltip_text = status.to_tooltip();
|
||||
let has_menu = status.has_menu();
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
let fs = self.fs.clone();
|
||||
|
||||
return div().child(
|
||||
|
@ -331,7 +331,7 @@ impl Render for EditPredictionButton {
|
|||
})
|
||||
});
|
||||
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
|
||||
let mut popover_menu = PopoverMenu::new("zeta")
|
||||
.menu(move |window, cx| {
|
||||
|
|
|
@ -1039,9 +1039,7 @@ pub struct Editor {
|
|||
inline_diagnostics: Vec<(Anchor, InlineDiagnostic)>,
|
||||
soft_wrap_mode_override: Option<language_settings::SoftWrap>,
|
||||
hard_wrap: Option<usize>,
|
||||
|
||||
// TODO: make this a access method
|
||||
pub project: Option<Entity<Project>>,
|
||||
project: Option<Entity<Project>>,
|
||||
semantics_provider: Option<Rc<dyn SemanticsProvider>>,
|
||||
completion_provider: Option<Rc<dyn CompletionProvider>>,
|
||||
collaboration_hub: Option<Box<dyn CollaborationHub>>,
|
||||
|
@ -2326,7 +2324,7 @@ impl Editor {
|
|||
editor.go_to_active_debug_line(window, cx);
|
||||
|
||||
if let Some(buffer) = buffer.read(cx).as_singleton() {
|
||||
if let Some(project) = editor.project.as_ref() {
|
||||
if let Some(project) = editor.project() {
|
||||
let handle = project.update(cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&buffer, cx)
|
||||
});
|
||||
|
@ -2371,6 +2369,34 @@ impl Editor {
|
|||
.is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window))
|
||||
}
|
||||
|
||||
pub fn is_range_selected(&mut self, range: &Range<Anchor>, cx: &mut Context<Self>) -> bool {
|
||||
if self
|
||||
.selections
|
||||
.pending
|
||||
.as_ref()
|
||||
.is_some_and(|pending_selection| {
|
||||
let snapshot = self.buffer().read(cx).snapshot(cx);
|
||||
pending_selection
|
||||
.selection
|
||||
.range()
|
||||
.includes(&range, &snapshot)
|
||||
})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
self.selections
|
||||
.disjoint_in_range::<usize>(range.clone(), cx)
|
||||
.into_iter()
|
||||
.any(|selection| {
|
||||
// This is needed to cover a corner case, if we just check for an existing
|
||||
// selection in the fold range, having a cursor at the start of the fold
|
||||
// marks it as selected. Non-empty selections don't cause this.
|
||||
let length = selection.end - selection.start;
|
||||
length > 0
|
||||
})
|
||||
}
|
||||
|
||||
pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext {
|
||||
self.key_context_internal(self.has_active_edit_prediction(), window, cx)
|
||||
}
|
||||
|
@ -2626,6 +2652,10 @@ impl Editor {
|
|||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn project(&self) -> Option<&Entity<Project>> {
|
||||
self.project.as_ref()
|
||||
}
|
||||
|
||||
pub fn workspace(&self) -> Option<Entity<Workspace>> {
|
||||
self.workspace.as_ref()?.0.upgrade()
|
||||
}
|
||||
|
@ -5212,7 +5242,7 @@ impl Editor {
|
|||
restrict_to_languages: Option<&HashSet<Arc<Language>>>,
|
||||
cx: &mut Context<Editor>,
|
||||
) -> HashMap<ExcerptId, (Entity<Buffer>, clock::Global, Range<usize>)> {
|
||||
let Some(project) = self.project.as_ref() else {
|
||||
let Some(project) = self.project() else {
|
||||
return HashMap::default();
|
||||
};
|
||||
let project = project.read(cx);
|
||||
|
@ -5294,7 +5324,7 @@ impl Editor {
|
|||
return None;
|
||||
}
|
||||
|
||||
let project = self.project.as_ref()?;
|
||||
let project = self.project()?;
|
||||
let position = self.selections.newest_anchor().head();
|
||||
let (buffer, buffer_position) = self
|
||||
.buffer
|
||||
|
@ -6141,7 +6171,7 @@ impl Editor {
|
|||
cx: &mut App,
|
||||
) -> Task<Vec<task::DebugScenario>> {
|
||||
maybe!({
|
||||
let project = self.project.as_ref()?;
|
||||
let project = self.project()?;
|
||||
let dap_store = project.read(cx).dap_store();
|
||||
let mut scenarios = vec![];
|
||||
let resolved_tasks = resolved_tasks.as_ref()?;
|
||||
|
@ -7907,7 +7937,7 @@ impl Editor {
|
|||
let snapshot = self.snapshot(window, cx);
|
||||
|
||||
let multi_buffer_snapshot = &snapshot.display_snapshot.buffer_snapshot;
|
||||
let Some(project) = self.project.as_ref() else {
|
||||
let Some(project) = self.project() else {
|
||||
return breakpoint_display_points;
|
||||
};
|
||||
|
||||
|
@ -10501,7 +10531,7 @@ impl Editor {
|
|||
) {
|
||||
if let Some(working_directory) = self.active_excerpt(cx).and_then(|(_, buffer, _)| {
|
||||
let project_path = buffer.read(cx).project_path(cx)?;
|
||||
let project = self.project.as_ref()?.read(cx);
|
||||
let project = self.project()?.read(cx);
|
||||
let entry = project.entry_for_path(&project_path, cx)?;
|
||||
let parent = match &entry.canonical_path {
|
||||
Some(canonical_path) => canonical_path.to_path_buf(),
|
||||
|
@ -14875,7 +14905,7 @@ impl Editor {
|
|||
self.clear_tasks();
|
||||
return Task::ready(());
|
||||
}
|
||||
let project = self.project.as_ref().map(Entity::downgrade);
|
||||
let project = self.project().map(Entity::downgrade);
|
||||
let task_sources = self.lsp_task_sources(cx);
|
||||
let multi_buffer = self.buffer.downgrade();
|
||||
cx.spawn_in(window, async move |editor, cx| {
|
||||
|
@ -17054,7 +17084,7 @@ impl Editor {
|
|||
if !pull_diagnostics_settings.enabled {
|
||||
return None;
|
||||
}
|
||||
let project = self.project.as_ref()?.downgrade();
|
||||
let project = self.project()?.downgrade();
|
||||
let debounce = Duration::from_millis(pull_diagnostics_settings.debounce_ms);
|
||||
let mut buffers = self.buffer.read(cx).all_buffers();
|
||||
if let Some(buffer_id) = buffer_id {
|
||||
|
@ -18018,7 +18048,7 @@ impl Editor {
|
|||
hunks: impl Iterator<Item = MultiBufferDiffHunk>,
|
||||
cx: &mut App,
|
||||
) -> Option<()> {
|
||||
let project = self.project.as_ref()?;
|
||||
let project = self.project()?;
|
||||
let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?;
|
||||
let diff = self.buffer.read(cx).diff_for(buffer_id)?;
|
||||
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||
|
@ -18678,7 +18708,7 @@ impl Editor {
|
|||
self.active_excerpt(cx).and_then(|(_, buffer, _)| {
|
||||
let buffer = buffer.read(cx);
|
||||
if let Some(project_path) = buffer.project_path(cx) {
|
||||
let project = self.project.as_ref()?.read(cx);
|
||||
let project = self.project()?.read(cx);
|
||||
project.absolute_path(&project_path, cx)
|
||||
} else {
|
||||
buffer
|
||||
|
@ -18691,7 +18721,7 @@ impl Editor {
|
|||
fn target_file_path(&self, cx: &mut Context<Self>) -> Option<PathBuf> {
|
||||
self.active_excerpt(cx).and_then(|(_, buffer, _)| {
|
||||
let project_path = buffer.read(cx).project_path(cx)?;
|
||||
let project = self.project.as_ref()?.read(cx);
|
||||
let project = self.project()?.read(cx);
|
||||
let entry = project.entry_for_path(&project_path, cx)?;
|
||||
let path = entry.path.to_path_buf();
|
||||
Some(path)
|
||||
|
@ -18912,7 +18942,7 @@ impl Editor {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(project) = self.project.as_ref() {
|
||||
if let Some(project) = self.project() {
|
||||
let Some(buffer) = self.buffer().read(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
|
@ -19028,7 +19058,7 @@ impl Editor {
|
|||
return Task::ready(Err(anyhow!("failed to determine buffer and selection")));
|
||||
};
|
||||
|
||||
let Some(project) = self.project.as_ref() else {
|
||||
let Some(project) = self.project() else {
|
||||
return Task::ready(Err(anyhow!("editor does not have project")));
|
||||
};
|
||||
|
||||
|
@ -21015,7 +21045,7 @@ impl Editor {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let workspace = self.workspace();
|
||||
let project = self.project.as_ref();
|
||||
let project = self.project();
|
||||
let save_tasks = self.buffer().update(cx, |multi_buffer, cx| {
|
||||
let mut tasks = Vec::new();
|
||||
for (buffer_id, changes) in revert_changes {
|
||||
|
|
|
@ -74,7 +74,7 @@ fn test_edit_events(cx: &mut TestAppContext) {
|
|||
let editor1 = cx.add_window({
|
||||
let events = events.clone();
|
||||
|window, cx| {
|
||||
let entity = cx.entity().clone();
|
||||
let entity = cx.entity();
|
||||
cx.subscribe_in(
|
||||
&entity,
|
||||
window,
|
||||
|
@ -95,7 +95,7 @@ fn test_edit_events(cx: &mut TestAppContext) {
|
|||
let events = events.clone();
|
||||
|window, cx| {
|
||||
cx.subscribe_in(
|
||||
&cx.entity().clone(),
|
||||
&cx.entity(),
|
||||
window,
|
||||
move |_, _, event: &EditorEvent, _, _| match event {
|
||||
EditorEvent::Edited { .. } => events.borrow_mut().push(("editor2", "edited")),
|
||||
|
@ -15082,7 +15082,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu
|
|||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
let lsp_store =
|
||||
cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store());
|
||||
cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store());
|
||||
|
||||
cx.set_state(indoc! {"
|
||||
ˇfn func(abc def: i32) -> u32 {
|
||||
|
@ -19634,13 +19634,8 @@ fn test_crease_insertion_and_rendering(cx: &mut TestAppContext) {
|
|||
|
||||
editor.insert_creases(Some(crease), cx);
|
||||
let snapshot = editor.snapshot(window, cx);
|
||||
let _div = snapshot.render_crease_toggle(
|
||||
MultiBufferRow(1),
|
||||
false,
|
||||
cx.entity().clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let _div =
|
||||
snapshot.render_crease_toggle(MultiBufferRow(1), false, cx.entity(), window, cx);
|
||||
snapshot
|
||||
})
|
||||
.unwrap();
|
||||
|
|
|
@ -7818,7 +7818,7 @@ impl Element for EditorElement {
|
|||
min_lines,
|
||||
max_lines,
|
||||
} => {
|
||||
let editor_handle = cx.entity().clone();
|
||||
let editor_handle = cx.entity();
|
||||
let max_line_number_width =
|
||||
self.max_line_number_width(&editor.snapshot(window, cx), window);
|
||||
window.request_measured_layout(
|
||||
|
|
|
@ -250,7 +250,7 @@ fn show_hover(
|
|||
|
||||
let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?;
|
||||
|
||||
let language_registry = editor.project.as_ref()?.read(cx).languages().clone();
|
||||
let language_registry = editor.project()?.read(cx).languages().clone();
|
||||
let provider = editor.semantics_provider.clone()?;
|
||||
|
||||
if !ignore_timeout {
|
||||
|
|
|
@ -678,7 +678,7 @@ impl Item for Editor {
|
|||
let buffer = buffer.read(cx);
|
||||
let path = buffer.project_path(cx)?;
|
||||
let buffer_id = buffer.remote_id();
|
||||
let project = self.project.as_ref()?.read(cx);
|
||||
let project = self.project()?.read(cx);
|
||||
let entry = project.entry_for_path(&path, cx)?;
|
||||
let (repo, repo_path) = project
|
||||
.git_store()
|
||||
|
|
|
@ -51,7 +51,7 @@ pub(super) fn refresh_linked_ranges(
|
|||
if editor.pending_rename.is_some() {
|
||||
return None;
|
||||
}
|
||||
let project = editor.project.as_ref()?.downgrade();
|
||||
let project = editor.project()?.downgrade();
|
||||
|
||||
editor.linked_editing_range_task = Some(cx.spawn_in(window, async move |editor, cx| {
|
||||
cx.background_executor().timer(UPDATE_DEBOUNCE).await;
|
||||
|
|
|
@ -169,7 +169,7 @@ impl Editor {
|
|||
else {
|
||||
return;
|
||||
};
|
||||
let Some(lsp_store) = self.project.as_ref().map(|p| p.read(cx).lsp_store()) else {
|
||||
let Some(lsp_store) = self.project().map(|p| p.read(cx).lsp_store()) else {
|
||||
return;
|
||||
};
|
||||
let task = lsp_store.update(cx, |lsp_store, cx| {
|
||||
|
|
|
@ -297,9 +297,8 @@ impl EditorTestContext {
|
|||
|
||||
pub fn set_head_text(&mut self, diff_base: &str) {
|
||||
self.cx.run_until_parked();
|
||||
let fs = self.update_editor(|editor, _, cx| {
|
||||
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
|
||||
});
|
||||
let fs =
|
||||
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
|
||||
let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone());
|
||||
fs.set_head_for_repo(
|
||||
&Self::root_path().join(".git"),
|
||||
|
@ -311,18 +310,16 @@ impl EditorTestContext {
|
|||
|
||||
pub fn clear_index_text(&mut self) {
|
||||
self.cx.run_until_parked();
|
||||
let fs = self.update_editor(|editor, _, cx| {
|
||||
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
|
||||
});
|
||||
let fs =
|
||||
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
|
||||
fs.set_index_for_repo(&Self::root_path().join(".git"), &[]);
|
||||
self.cx.run_until_parked();
|
||||
}
|
||||
|
||||
pub fn set_index_text(&mut self, diff_base: &str) {
|
||||
self.cx.run_until_parked();
|
||||
let fs = self.update_editor(|editor, _, cx| {
|
||||
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
|
||||
});
|
||||
let fs =
|
||||
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
|
||||
let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone());
|
||||
fs.set_index_for_repo(
|
||||
&Self::root_path().join(".git"),
|
||||
|
@ -333,9 +330,8 @@ impl EditorTestContext {
|
|||
|
||||
#[track_caller]
|
||||
pub fn assert_index_text(&mut self, expected: Option<&str>) {
|
||||
let fs = self.update_editor(|editor, _, cx| {
|
||||
editor.project.as_ref().unwrap().read(cx).fs().as_fake()
|
||||
});
|
||||
let fs =
|
||||
self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake());
|
||||
let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone());
|
||||
let mut found = None;
|
||||
fs.with_git_state(&Self::root_path().join(".git"), false, |git_state| {
|
||||
|
|
|
@ -701,7 +701,7 @@ impl ExtensionsPage {
|
|||
extension: &ExtensionMetadata,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ExtensionCard {
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
let status = Self::extension_status(&extension.id, cx);
|
||||
let has_dev_extension = Self::dev_extension_exists(&extension.id, cx);
|
||||
|
||||
|
|
|
@ -112,7 +112,7 @@ fn excerpt_for_buffer_updated(
|
|||
}
|
||||
|
||||
fn buffer_added(editor: &mut Editor, buffer: Entity<Buffer>, cx: &mut Context<Editor>) {
|
||||
let Some(project) = &editor.project else {
|
||||
let Some(project) = editor.project() else {
|
||||
return;
|
||||
};
|
||||
let git_store = project.read(cx).git_store().clone();
|
||||
|
@ -469,7 +469,7 @@ pub(crate) fn resolve_conflict(
|
|||
let Some((workspace, project, multibuffer, buffer)) = editor
|
||||
.update(cx, |editor, cx| {
|
||||
let workspace = editor.workspace()?;
|
||||
let project = editor.project.clone()?;
|
||||
let project = editor.project()?.clone();
|
||||
let multibuffer = editor.buffer().clone();
|
||||
let buffer_id = resolved_conflict.ours.end.buffer_id?;
|
||||
let buffer = multibuffer.read(cx).buffer(buffer_id)?;
|
||||
|
|
|
@ -3246,7 +3246,7 @@ impl GitPanel {
|
|||
* MAX_PANEL_EDITOR_LINES
|
||||
+ gap;
|
||||
|
||||
let git_panel = cx.entity().clone();
|
||||
let git_panel = cx.entity();
|
||||
let display_name = SharedString::from(Arc::from(
|
||||
active_repository
|
||||
.read(cx)
|
||||
|
|
|
@ -595,9 +595,7 @@ impl Render for TextInput {
|
|||
.w_full()
|
||||
.p(px(4.))
|
||||
.bg(white())
|
||||
.child(TextElement {
|
||||
input: cx.entity().clone(),
|
||||
}),
|
||||
.child(TextElement { input: cx.entity() }),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1358,7 +1358,7 @@ impl Render for LspLogToolbarItemView {
|
|||
})
|
||||
.collect();
|
||||
|
||||
let log_toolbar_view = cx.entity().clone();
|
||||
let log_toolbar_view = cx.entity();
|
||||
|
||||
let lsp_menu = PopoverMenu::new("LspLogView")
|
||||
.anchor(Corner::TopLeft)
|
||||
|
|
|
@ -1007,7 +1007,7 @@ impl Render for LspTool {
|
|||
(None, "All Servers Operational")
|
||||
};
|
||||
|
||||
let lsp_tool = cx.entity().clone();
|
||||
let lsp_tool = cx.entity();
|
||||
|
||||
div().child(
|
||||
PopoverMenu::new("lsp-tool")
|
||||
|
|
|
@ -456,7 +456,7 @@ impl SyntaxTreeToolbarItemView {
|
|||
let active_layer = buffer_state.active_layer.clone()?;
|
||||
let active_buffer = buffer_state.buffer.read(cx).snapshot();
|
||||
|
||||
let view = cx.entity().clone();
|
||||
let view = cx.entity();
|
||||
Some(
|
||||
PopoverMenu::new("Syntax Tree")
|
||||
.trigger(Self::render_header(&active_layer))
|
||||
|
|
|
@ -4655,20 +4655,15 @@ impl OutlinePanel {
|
|||
.when(show_indent_guides, |list| {
|
||||
list.with_decoration(
|
||||
ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx))
|
||||
.with_compute_indents_fn(
|
||||
cx.entity().clone(),
|
||||
|outline_panel, range, _, _| {
|
||||
.with_compute_indents_fn(cx.entity(), |outline_panel, range, _, _| {
|
||||
let entries = outline_panel.cached_entries.get(range);
|
||||
if let Some(entries) = entries {
|
||||
entries.into_iter().map(|item| item.depth).collect()
|
||||
} else {
|
||||
smallvec::SmallVec::new()
|
||||
}
|
||||
},
|
||||
)
|
||||
.with_render_fn(
|
||||
cx.entity().clone(),
|
||||
move |outline_panel, params, _, _| {
|
||||
})
|
||||
.with_render_fn(cx.entity(), move |outline_panel, params, _, _| {
|
||||
const LEFT_OFFSET: Pixels = px(14.);
|
||||
|
||||
let indent_size = params.indent_size;
|
||||
|
@ -4698,8 +4693,7 @@ impl OutlinePanel {
|
|||
}
|
||||
})
|
||||
.collect()
|
||||
},
|
||||
),
|
||||
}),
|
||||
)
|
||||
})
|
||||
.custom_scrollbars(
|
||||
|
|
|
@ -5187,9 +5187,7 @@ impl Render for ProjectPanel {
|
|||
.when(show_indent_guides, |list| {
|
||||
list.with_decoration(
|
||||
ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx))
|
||||
.with_compute_indents_fn(
|
||||
cx.entity().clone(),
|
||||
|this, range, window, cx| {
|
||||
.with_compute_indents_fn(cx.entity(), |this, range, window, cx| {
|
||||
let mut items =
|
||||
SmallVec::with_capacity(range.end - range.start);
|
||||
this.iter_visible_entries(
|
||||
|
@ -5197,16 +5195,14 @@ impl Render for ProjectPanel {
|
|||
window,
|
||||
cx,
|
||||
|entry, _, entries, _, _| {
|
||||
let (depth, _) =
|
||||
Self::calculate_depth_and_difference(
|
||||
let (depth, _) = Self::calculate_depth_and_difference(
|
||||
entry, entries,
|
||||
);
|
||||
items.push(depth);
|
||||
},
|
||||
);
|
||||
items
|
||||
},
|
||||
)
|
||||
})
|
||||
.on_click(cx.listener(
|
||||
|this, active_indent_guide: &IndentGuideLayout, window, cx| {
|
||||
if window.modifiers().secondary() {
|
||||
|
@ -5230,7 +5226,7 @@ impl Render for ProjectPanel {
|
|||
}
|
||||
},
|
||||
))
|
||||
.with_render_fn(cx.entity().clone(), move |this, params, _, cx| {
|
||||
.with_render_fn(cx.entity(), move |this, params, _, cx| {
|
||||
const LEFT_OFFSET: Pixels = px(14.);
|
||||
const PADDING_Y: Pixels = px(4.);
|
||||
const HITBOX_OVERDRAW: Pixels = px(3.);
|
||||
|
@ -5283,7 +5279,7 @@ impl Render for ProjectPanel {
|
|||
})
|
||||
.when(show_sticky_entries, |list| {
|
||||
let sticky_items = ui::sticky_items(
|
||||
cx.entity().clone(),
|
||||
cx.entity(),
|
||||
|this, range, window, cx| {
|
||||
let mut items = SmallVec::with_capacity(range.end - range.start);
|
||||
this.iter_visible_entries(
|
||||
|
@ -5310,7 +5306,7 @@ impl Render for ProjectPanel {
|
|||
list.with_decoration(if show_indent_guides {
|
||||
sticky_items.with_decoration(
|
||||
ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx))
|
||||
.with_render_fn(cx.entity().clone(), move |_, params, _, _| {
|
||||
.with_render_fn(cx.entity(), move |_, params, _, _| {
|
||||
const LEFT_OFFSET: Pixels = px(14.);
|
||||
|
||||
let indent_size = params.indent_size;
|
||||
|
|
|
@ -158,14 +158,6 @@ message SynchronizeContextsResponse {
|
|||
repeated ContextVersion contexts = 1;
|
||||
}
|
||||
|
||||
message GetLlmToken {}
|
||||
|
||||
message GetLlmTokenResponse {
|
||||
string token = 1;
|
||||
}
|
||||
|
||||
message RefreshLlmToken {}
|
||||
|
||||
enum LanguageModelRole {
|
||||
LanguageModelUser = 0;
|
||||
LanguageModelAssistant = 1;
|
||||
|
|
|
@ -6,62 +6,6 @@ message UpdateInviteInfo {
|
|||
uint32 count = 2;
|
||||
}
|
||||
|
||||
message GetPrivateUserInfo {}
|
||||
|
||||
message GetPrivateUserInfoResponse {
|
||||
string metrics_id = 1;
|
||||
bool staff = 2;
|
||||
repeated string flags = 3;
|
||||
optional uint64 accepted_tos_at = 4;
|
||||
}
|
||||
|
||||
enum Plan {
|
||||
Free = 0;
|
||||
ZedPro = 1;
|
||||
ZedProTrial = 2;
|
||||
}
|
||||
|
||||
message UpdateUserPlan {
|
||||
Plan plan = 1;
|
||||
optional uint64 trial_started_at = 2;
|
||||
optional bool is_usage_based_billing_enabled = 3;
|
||||
optional SubscriptionUsage usage = 4;
|
||||
optional SubscriptionPeriod subscription_period = 5;
|
||||
optional bool account_too_young = 6;
|
||||
optional bool has_overdue_invoices = 7;
|
||||
}
|
||||
|
||||
message SubscriptionPeriod {
|
||||
uint64 started_at = 1;
|
||||
uint64 ended_at = 2;
|
||||
}
|
||||
|
||||
message SubscriptionUsage {
|
||||
uint32 model_requests_usage_amount = 1;
|
||||
UsageLimit model_requests_usage_limit = 2;
|
||||
uint32 edit_predictions_usage_amount = 3;
|
||||
UsageLimit edit_predictions_usage_limit = 4;
|
||||
}
|
||||
|
||||
message UsageLimit {
|
||||
oneof variant {
|
||||
Limited limited = 1;
|
||||
Unlimited unlimited = 2;
|
||||
}
|
||||
|
||||
message Limited {
|
||||
uint32 limit = 1;
|
||||
}
|
||||
|
||||
message Unlimited {}
|
||||
}
|
||||
|
||||
message AcceptTermsOfService {}
|
||||
|
||||
message AcceptTermsOfServiceResponse {
|
||||
uint64 accepted_tos_at = 1;
|
||||
}
|
||||
|
||||
message ShutdownRemoteServer {}
|
||||
|
||||
message Toast {
|
||||
|
|
|
@ -135,12 +135,7 @@ message Envelope {
|
|||
FollowResponse follow_response = 99;
|
||||
UpdateFollowers update_followers = 100;
|
||||
Unfollow unfollow = 101;
|
||||
GetPrivateUserInfo get_private_user_info = 102;
|
||||
GetPrivateUserInfoResponse get_private_user_info_response = 103;
|
||||
UpdateUserPlan update_user_plan = 234;
|
||||
UpdateDiffBases update_diff_bases = 104;
|
||||
AcceptTermsOfService accept_terms_of_service = 239;
|
||||
AcceptTermsOfServiceResponse accept_terms_of_service_response = 240;
|
||||
|
||||
OnTypeFormatting on_type_formatting = 105;
|
||||
OnTypeFormattingResponse on_type_formatting_response = 106;
|
||||
|
@ -250,10 +245,6 @@ message Envelope {
|
|||
AddWorktree add_worktree = 222;
|
||||
AddWorktreeResponse add_worktree_response = 223;
|
||||
|
||||
GetLlmToken get_llm_token = 235;
|
||||
GetLlmTokenResponse get_llm_token_response = 236;
|
||||
RefreshLlmToken refresh_llm_token = 259;
|
||||
|
||||
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
|
||||
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
|
||||
|
||||
|
@ -406,6 +397,7 @@ message Envelope {
|
|||
}
|
||||
|
||||
reserved 87 to 88;
|
||||
reserved 102 to 103;
|
||||
reserved 158 to 161;
|
||||
reserved 164;
|
||||
reserved 166 to 169;
|
||||
|
@ -419,10 +411,13 @@ message Envelope {
|
|||
reserved 221;
|
||||
reserved 224 to 229;
|
||||
reserved 230 to 231;
|
||||
reserved 234 to 236;
|
||||
reserved 239 to 240;
|
||||
reserved 246;
|
||||
reserved 270;
|
||||
reserved 247 to 254;
|
||||
reserved 255 to 256;
|
||||
reserved 259;
|
||||
reserved 270;
|
||||
reserved 280 to 281;
|
||||
reserved 332 to 333;
|
||||
}
|
||||
|
|
|
@ -20,8 +20,6 @@ pub const SSH_PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 };
|
|||
pub const SSH_PROJECT_ID: u64 = 0;
|
||||
|
||||
messages!(
|
||||
(AcceptTermsOfService, Foreground),
|
||||
(AcceptTermsOfServiceResponse, Foreground),
|
||||
(Ack, Foreground),
|
||||
(AckBufferOperation, Background),
|
||||
(AckChannelMessage, Background),
|
||||
|
@ -105,8 +103,6 @@ messages!(
|
|||
(GetPathMetadataResponse, Background),
|
||||
(GetPermalinkToLine, Foreground),
|
||||
(GetPermalinkToLineResponse, Foreground),
|
||||
(GetPrivateUserInfo, Foreground),
|
||||
(GetPrivateUserInfoResponse, Foreground),
|
||||
(GetProjectSymbols, Background),
|
||||
(GetProjectSymbolsResponse, Background),
|
||||
(GetReferences, Background),
|
||||
|
@ -119,8 +115,6 @@ messages!(
|
|||
(GetTypeDefinitionResponse, Background),
|
||||
(GetImplementation, Background),
|
||||
(GetImplementationResponse, Background),
|
||||
(GetLlmToken, Background),
|
||||
(GetLlmTokenResponse, Background),
|
||||
(OpenUnstagedDiff, Foreground),
|
||||
(OpenUnstagedDiffResponse, Foreground),
|
||||
(OpenUncommittedDiff, Foreground),
|
||||
|
@ -196,7 +190,6 @@ messages!(
|
|||
(PrepareRenameResponse, Background),
|
||||
(ProjectEntryResponse, Foreground),
|
||||
(RefreshInlayHints, Foreground),
|
||||
(RefreshLlmToken, Background),
|
||||
(RegisterBufferWithLanguageServers, Background),
|
||||
(RejoinChannelBuffers, Foreground),
|
||||
(RejoinChannelBuffersResponse, Foreground),
|
||||
|
@ -280,7 +273,6 @@ messages!(
|
|||
(UpdateProject, Foreground),
|
||||
(UpdateProjectCollaborator, Foreground),
|
||||
(UpdateUserChannels, Foreground),
|
||||
(UpdateUserPlan, Foreground),
|
||||
(UpdateWorktree, Foreground),
|
||||
(UpdateWorktreeSettings, Foreground),
|
||||
(UpdateRepository, Foreground),
|
||||
|
@ -321,7 +313,6 @@ messages!(
|
|||
);
|
||||
|
||||
request_messages!(
|
||||
(AcceptTermsOfService, AcceptTermsOfServiceResponse),
|
||||
(ApplyCodeAction, ApplyCodeActionResponse),
|
||||
(
|
||||
ApplyCompletionAdditionalEdits,
|
||||
|
@ -354,9 +345,7 @@ request_messages!(
|
|||
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
||||
(GetDocumentSymbols, GetDocumentSymbolsResponse),
|
||||
(GetHover, GetHoverResponse),
|
||||
(GetLlmToken, GetLlmTokenResponse),
|
||||
(GetNotifications, GetNotificationsResponse),
|
||||
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
||||
(GetProjectSymbols, GetProjectSymbolsResponse),
|
||||
(GetReferences, GetReferencesResponse),
|
||||
(GetSignatureHelp, GetSignatureHelpResponse),
|
||||
|
|
|
@ -1291,7 +1291,7 @@ impl RemoteServerProjects {
|
|||
let connection_string = connection_string.clone();
|
||||
move |_, _: &menu::Confirm, window, cx| {
|
||||
remove_ssh_server(
|
||||
cx.entity().clone(),
|
||||
cx.entity(),
|
||||
server_index,
|
||||
connection_string.clone(),
|
||||
window,
|
||||
|
@ -1311,7 +1311,7 @@ impl RemoteServerProjects {
|
|||
.child(Label::new("Remove Server").color(Color::Error))
|
||||
.on_click(cx.listener(move |_, _, window, cx| {
|
||||
remove_ssh_server(
|
||||
cx.entity().clone(),
|
||||
cx.entity(),
|
||||
server_index,
|
||||
connection_string.clone(),
|
||||
window,
|
||||
|
|
|
@ -244,7 +244,7 @@ impl Session {
|
|||
repl_session_id = cx.entity_id().to_string(),
|
||||
);
|
||||
|
||||
let session_view = cx.entity().clone();
|
||||
let session_view = cx.entity();
|
||||
|
||||
let kernel = match self.kernel_specification.clone() {
|
||||
KernelSpecification::Jupyter(kernel_specification)
|
||||
|
|
|
@ -2,9 +2,9 @@ mod registrar;
|
|||
|
||||
use crate::{
|
||||
FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOption,
|
||||
SearchOptions, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive,
|
||||
ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord,
|
||||
search_bar::{input_base_styles, render_action_button, render_text_input},
|
||||
SearchOptions, SearchSource, SelectAllMatches, SelectNextMatch, SelectPreviousMatch,
|
||||
ToggleCaseSensitive, ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord,
|
||||
search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input},
|
||||
};
|
||||
use any_vec::AnyVec;
|
||||
use anyhow::Context as _;
|
||||
|
@ -213,22 +213,25 @@ impl Render for BufferSearchBar {
|
|||
h_flex()
|
||||
.gap_1()
|
||||
.when(case, |div| {
|
||||
div.child(
|
||||
SearchOption::CaseSensitive
|
||||
.as_button(self.search_options, focus_handle.clone()),
|
||||
)
|
||||
div.child(SearchOption::CaseSensitive.as_button(
|
||||
self.search_options,
|
||||
SearchSource::Buffer,
|
||||
focus_handle.clone(),
|
||||
))
|
||||
})
|
||||
.when(word, |div| {
|
||||
div.child(
|
||||
SearchOption::WholeWord
|
||||
.as_button(self.search_options, focus_handle.clone()),
|
||||
)
|
||||
div.child(SearchOption::WholeWord.as_button(
|
||||
self.search_options,
|
||||
SearchSource::Buffer,
|
||||
focus_handle.clone(),
|
||||
))
|
||||
})
|
||||
.when(regex, |div| {
|
||||
div.child(
|
||||
SearchOption::Regex
|
||||
.as_button(self.search_options, focus_handle.clone()),
|
||||
)
|
||||
div.child(SearchOption::Regex.as_button(
|
||||
self.search_options,
|
||||
SearchSource::Buffer,
|
||||
focus_handle.clone(),
|
||||
))
|
||||
}),
|
||||
)
|
||||
});
|
||||
|
@ -240,7 +243,7 @@ impl Render for BufferSearchBar {
|
|||
this.child(render_action_button(
|
||||
"buffer-search-bar-toggle",
|
||||
IconName::Replace,
|
||||
self.replace_enabled,
|
||||
self.replace_enabled.then_some(ActionButtonState::Toggled),
|
||||
"Toggle Replace",
|
||||
&ToggleReplace,
|
||||
focus_handle.clone(),
|
||||
|
@ -285,7 +288,9 @@ impl Render for BufferSearchBar {
|
|||
.child(render_action_button(
|
||||
"buffer-search-nav-button",
|
||||
ui::IconName::ChevronLeft,
|
||||
self.active_match_index.is_some(),
|
||||
self.active_match_index
|
||||
.is_none()
|
||||
.then_some(ActionButtonState::Disabled),
|
||||
"Select Previous Match",
|
||||
&SelectPreviousMatch,
|
||||
query_focus.clone(),
|
||||
|
@ -293,7 +298,9 @@ impl Render for BufferSearchBar {
|
|||
.child(render_action_button(
|
||||
"buffer-search-nav-button",
|
||||
ui::IconName::ChevronRight,
|
||||
self.active_match_index.is_some(),
|
||||
self.active_match_index
|
||||
.is_none()
|
||||
.then_some(ActionButtonState::Disabled),
|
||||
"Select Next Match",
|
||||
&SelectNextMatch,
|
||||
query_focus.clone(),
|
||||
|
@ -313,7 +320,7 @@ impl Render for BufferSearchBar {
|
|||
el.child(render_action_button(
|
||||
"buffer-search-nav-button",
|
||||
IconName::SelectAll,
|
||||
true,
|
||||
Default::default(),
|
||||
"Select All Matches",
|
||||
&SelectAllMatches,
|
||||
query_focus,
|
||||
|
@ -324,7 +331,7 @@ impl Render for BufferSearchBar {
|
|||
el.child(render_action_button(
|
||||
"buffer-search",
|
||||
IconName::Close,
|
||||
true,
|
||||
Default::default(),
|
||||
"Close Search Bar",
|
||||
&Dismiss,
|
||||
focus_handle.clone(),
|
||||
|
@ -352,7 +359,7 @@ impl Render for BufferSearchBar {
|
|||
.child(render_action_button(
|
||||
"buffer-search-replace-button",
|
||||
IconName::ReplaceNext,
|
||||
true,
|
||||
Default::default(),
|
||||
"Replace Next Match",
|
||||
&ReplaceNext,
|
||||
focus_handle.clone(),
|
||||
|
@ -360,7 +367,7 @@ impl Render for BufferSearchBar {
|
|||
.child(render_action_button(
|
||||
"buffer-search-replace-button",
|
||||
IconName::ReplaceAll,
|
||||
true,
|
||||
Default::default(),
|
||||
"Replace All Matches",
|
||||
&ReplaceAll,
|
||||
focus_handle,
|
||||
|
@ -394,7 +401,7 @@ impl Render for BufferSearchBar {
|
|||
div.child(h_flex().absolute().right_0().child(render_action_button(
|
||||
"buffer-search",
|
||||
IconName::Close,
|
||||
true,
|
||||
Default::default(),
|
||||
"Close Search Bar",
|
||||
&Dismiss,
|
||||
focus_handle.clone(),
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use crate::{
|
||||
BufferSearchBar, FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext,
|
||||
SearchOption, SearchOptions, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive,
|
||||
ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord,
|
||||
SearchOption, SearchOptions, SearchSource, SelectNextMatch, SelectPreviousMatch,
|
||||
ToggleCaseSensitive, ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord,
|
||||
buffer_search::Deploy,
|
||||
search_bar::{input_base_styles, render_action_button, render_text_input},
|
||||
search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input},
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use collections::HashMap;
|
||||
|
@ -1665,7 +1665,7 @@ impl ProjectSearchBar {
|
|||
});
|
||||
}
|
||||
|
||||
fn toggle_search_option(
|
||||
pub(crate) fn toggle_search_option(
|
||||
&mut self,
|
||||
option: SearchOptions,
|
||||
window: &mut Window,
|
||||
|
@ -1962,17 +1962,21 @@ impl Render for ProjectSearchBar {
|
|||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
SearchOption::CaseSensitive
|
||||
.as_button(search.search_options, focus_handle.clone()),
|
||||
)
|
||||
.child(
|
||||
SearchOption::WholeWord
|
||||
.as_button(search.search_options, focus_handle.clone()),
|
||||
)
|
||||
.child(
|
||||
SearchOption::Regex.as_button(search.search_options, focus_handle.clone()),
|
||||
),
|
||||
.child(SearchOption::CaseSensitive.as_button(
|
||||
search.search_options,
|
||||
SearchSource::Project(cx),
|
||||
focus_handle.clone(),
|
||||
))
|
||||
.child(SearchOption::WholeWord.as_button(
|
||||
search.search_options,
|
||||
SearchSource::Project(cx),
|
||||
focus_handle.clone(),
|
||||
))
|
||||
.child(SearchOption::Regex.as_button(
|
||||
search.search_options,
|
||||
SearchSource::Project(cx),
|
||||
focus_handle.clone(),
|
||||
)),
|
||||
);
|
||||
|
||||
let query_focus = search.query_editor.focus_handle(cx);
|
||||
|
@ -1985,7 +1989,10 @@ impl Render for ProjectSearchBar {
|
|||
.child(render_action_button(
|
||||
"project-search-nav-button",
|
||||
IconName::ChevronLeft,
|
||||
search.active_match_index.is_some(),
|
||||
search
|
||||
.active_match_index
|
||||
.is_none()
|
||||
.then_some(ActionButtonState::Disabled),
|
||||
"Select Previous Match",
|
||||
&SelectPreviousMatch,
|
||||
query_focus.clone(),
|
||||
|
@ -1993,7 +2000,10 @@ impl Render for ProjectSearchBar {
|
|||
.child(render_action_button(
|
||||
"project-search-nav-button",
|
||||
IconName::ChevronRight,
|
||||
search.active_match_index.is_some(),
|
||||
search
|
||||
.active_match_index
|
||||
.is_none()
|
||||
.then_some(ActionButtonState::Disabled),
|
||||
"Select Next Match",
|
||||
&SelectNextMatch,
|
||||
query_focus,
|
||||
|
@ -2054,7 +2064,7 @@ impl Render for ProjectSearchBar {
|
|||
self.active_project_search
|
||||
.as_ref()
|
||||
.map(|search| search.read(cx).replace_enabled)
|
||||
.unwrap_or_default(),
|
||||
.and_then(|enabled| enabled.then_some(ActionButtonState::Toggled)),
|
||||
"Toggle Replace",
|
||||
&ToggleReplace,
|
||||
focus_handle.clone(),
|
||||
|
@ -2079,7 +2089,7 @@ impl Render for ProjectSearchBar {
|
|||
.child(render_action_button(
|
||||
"project-search-replace-button",
|
||||
IconName::ReplaceNext,
|
||||
true,
|
||||
Default::default(),
|
||||
"Replace Next Match",
|
||||
&ReplaceNext,
|
||||
focus_handle.clone(),
|
||||
|
@ -2087,7 +2097,7 @@ impl Render for ProjectSearchBar {
|
|||
.child(render_action_button(
|
||||
"project-search-replace-button",
|
||||
IconName::ReplaceAll,
|
||||
true,
|
||||
Default::default(),
|
||||
"Replace All Matches",
|
||||
&ReplaceAll,
|
||||
focus_handle,
|
||||
|
@ -2129,10 +2139,11 @@ impl Render for ProjectSearchBar {
|
|||
this.toggle_opened_only(window, cx);
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
SearchOption::IncludeIgnored
|
||||
.as_button(search.search_options, focus_handle.clone()),
|
||||
);
|
||||
.child(SearchOption::IncludeIgnored.as_button(
|
||||
search.search_options,
|
||||
SearchSource::Project(cx),
|
||||
focus_handle.clone(),
|
||||
));
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use bitflags::bitflags;
|
||||
pub use buffer_search::BufferSearchBar;
|
||||
use editor::SearchSettings;
|
||||
use gpui::{Action, App, FocusHandle, IntoElement, actions};
|
||||
use gpui::{Action, App, ClickEvent, FocusHandle, IntoElement, actions};
|
||||
use project::search::SearchQuery;
|
||||
pub use project_search::ProjectSearchView;
|
||||
use ui::{ButtonStyle, IconButton, IconButtonShape};
|
||||
|
@ -11,6 +11,8 @@ use workspace::{Toast, Workspace};
|
|||
|
||||
pub use search_status_button::SEARCH_ICON;
|
||||
|
||||
use crate::project_search::ProjectSearchBar;
|
||||
|
||||
pub mod buffer_search;
|
||||
pub mod project_search;
|
||||
pub(crate) mod search_bar;
|
||||
|
@ -83,9 +85,14 @@ pub enum SearchOption {
|
|||
Backwards,
|
||||
}
|
||||
|
||||
pub(crate) enum SearchSource<'a, 'b> {
|
||||
Buffer,
|
||||
Project(&'a Context<'b, ProjectSearchBar>),
|
||||
}
|
||||
|
||||
impl SearchOption {
|
||||
pub fn as_options(self) -> SearchOptions {
|
||||
SearchOptions::from_bits(1 << self as u8).unwrap()
|
||||
pub fn as_options(&self) -> SearchOptions {
|
||||
SearchOptions::from_bits(1 << *self as u8).unwrap()
|
||||
}
|
||||
|
||||
pub fn label(&self) -> &'static str {
|
||||
|
@ -119,17 +126,33 @@ impl SearchOption {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn as_button(&self, active: SearchOptions, focus_handle: FocusHandle) -> impl IntoElement {
|
||||
pub(crate) fn as_button(
|
||||
&self,
|
||||
active: SearchOptions,
|
||||
search_source: SearchSource,
|
||||
focus_handle: FocusHandle,
|
||||
) -> impl IntoElement {
|
||||
let action = self.to_toggle_action();
|
||||
let label = self.label();
|
||||
IconButton::new(label, self.icon())
|
||||
.on_click({
|
||||
IconButton::new(
|
||||
(label, matches!(search_source, SearchSource::Buffer) as u32),
|
||||
self.icon(),
|
||||
)
|
||||
.map(|button| match search_source {
|
||||
SearchSource::Buffer => {
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |_, window, cx| {
|
||||
button.on_click(move |_: &ClickEvent, window, cx| {
|
||||
if !focus_handle.is_focused(&window) {
|
||||
window.focus(&focus_handle);
|
||||
}
|
||||
window.dispatch_action(action.boxed_clone(), cx)
|
||||
window.dispatch_action(action.boxed_clone(), cx);
|
||||
})
|
||||
}
|
||||
SearchSource::Project(cx) => {
|
||||
let options = self.as_options();
|
||||
button.on_click(cx.listener(move |this, _: &ClickEvent, window, cx| {
|
||||
this.toggle_search_option(options, window, cx);
|
||||
}))
|
||||
}
|
||||
})
|
||||
.style(ButtonStyle::Subtle)
|
||||
|
|
|
@ -5,10 +5,15 @@ use theme::ThemeSettings;
|
|||
use ui::{IconButton, IconButtonShape};
|
||||
use ui::{Tooltip, prelude::*};
|
||||
|
||||
pub(super) enum ActionButtonState {
|
||||
Disabled,
|
||||
Toggled,
|
||||
}
|
||||
|
||||
pub(super) fn render_action_button(
|
||||
id_prefix: &'static str,
|
||||
icon: ui::IconName,
|
||||
active: bool,
|
||||
button_state: Option<ActionButtonState>,
|
||||
tooltip: &'static str,
|
||||
action: &'static dyn Action,
|
||||
focus_handle: FocusHandle,
|
||||
|
@ -28,7 +33,10 @@ pub(super) fn render_action_button(
|
|||
}
|
||||
})
|
||||
.tooltip(move |window, cx| Tooltip::for_action_in(tooltip, action, &focus_handle, window, cx))
|
||||
.disabled(!active)
|
||||
.when_some(button_state, |this, state| match state {
|
||||
ActionButtonState::Toggled => this.toggle_state(true),
|
||||
ActionButtonState::Disabled => this.disabled(true),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn input_base_styles(border_color: Hsla, map: impl FnOnce(Div) -> Div) -> Div {
|
||||
|
|
|
@ -2177,6 +2177,7 @@ impl KeybindingEditorModal {
|
|||
|
||||
let value = action_arguments
|
||||
.as_ref()
|
||||
.filter(|args| !args.is_empty())
|
||||
.map(|args| {
|
||||
serde_json::from_str(args).context("Failed to parse action arguments as JSON")
|
||||
})
|
||||
|
|
|
@ -65,7 +65,7 @@ impl Render for IndentGuidesStory {
|
|||
},
|
||||
)
|
||||
.with_compute_indents_fn(
|
||||
cx.entity().clone(),
|
||||
cx.entity(),
|
||||
|this, range, _cx, _context| {
|
||||
this.depths
|
||||
.iter()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
use semantic_version::SemanticVersion;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, fmt::Display, sync::Arc, time::Duration};
|
||||
use std::{collections::HashMap, fmt::Display, time::Duration};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct EventRequestBody {
|
||||
|
@ -93,19 +93,6 @@ impl Display for AssistantPhase {
|
|||
#[serde(tag = "type")]
|
||||
pub enum Event {
|
||||
Flexible(FlexibleEvent),
|
||||
Editor(EditorEvent),
|
||||
EditPrediction(EditPredictionEvent),
|
||||
EditPredictionRating(EditPredictionRatingEvent),
|
||||
Call(CallEvent),
|
||||
Assistant(AssistantEventData),
|
||||
Cpu(CpuEvent),
|
||||
Memory(MemoryEvent),
|
||||
App(AppEvent),
|
||||
Setting(SettingEvent),
|
||||
Extension(ExtensionEvent),
|
||||
Edit(EditEvent),
|
||||
Action(ActionEvent),
|
||||
Repl(ReplEvent),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
|
@ -114,54 +101,12 @@ pub struct FlexibleEvent {
|
|||
pub event_properties: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EditorEvent {
|
||||
/// The editor operation performed (open, save)
|
||||
pub operation: String,
|
||||
/// The extension of the file that was opened or saved
|
||||
pub file_extension: Option<String>,
|
||||
/// Whether the user is in vim mode or not
|
||||
pub vim_mode: bool,
|
||||
/// Whether the user has copilot enabled or not
|
||||
pub copilot_enabled: bool,
|
||||
/// Whether the user has copilot enabled for the language of the file opened or saved
|
||||
pub copilot_enabled_for_language: bool,
|
||||
/// Whether the client is opening/saving a local file or a remote file via SSH
|
||||
#[serde(default)]
|
||||
pub is_via_ssh: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EditPredictionEvent {
|
||||
/// Provider of the completion suggestion (e.g. copilot, supermaven)
|
||||
pub provider: String,
|
||||
pub suggestion_accepted: bool,
|
||||
pub file_extension: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum EditPredictionRating {
|
||||
Positive,
|
||||
Negative,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EditPredictionRatingEvent {
|
||||
pub rating: EditPredictionRating,
|
||||
pub input_events: Arc<str>,
|
||||
pub input_excerpt: Arc<str>,
|
||||
pub output_excerpt: Arc<str>,
|
||||
pub feedback: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CallEvent {
|
||||
/// Operation performed: invite/join call; begin/end screenshare; share/unshare project; etc
|
||||
pub operation: String,
|
||||
pub room_id: Option<u64>,
|
||||
pub channel_id: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AssistantEventData {
|
||||
/// Unique random identifier for each assistant tab (None for inline assist)
|
||||
|
@ -180,57 +125,6 @@ pub struct AssistantEventData {
|
|||
pub language_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CpuEvent {
|
||||
pub usage_as_percentage: f32,
|
||||
pub core_count: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MemoryEvent {
|
||||
pub memory_in_bytes: u64,
|
||||
pub virtual_memory_in_bytes: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ActionEvent {
|
||||
pub source: String,
|
||||
pub action: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EditEvent {
|
||||
pub duration: i64,
|
||||
pub environment: String,
|
||||
/// Whether the edits occurred locally or remotely via SSH
|
||||
#[serde(default)]
|
||||
pub is_via_ssh: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SettingEvent {
|
||||
pub setting: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ExtensionEvent {
|
||||
pub extension_id: Arc<str>,
|
||||
pub version: Arc<str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AppEvent {
|
||||
pub operation: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ReplEvent {
|
||||
pub kernel_language: String,
|
||||
pub kernel_status: String,
|
||||
pub repl_session_id: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct BacktraceFrame {
|
||||
pub ip: usize,
|
||||
|
|
|
@ -947,7 +947,7 @@ pub fn new_terminal_pane(
|
|||
cx: &mut Context<TerminalPanel>,
|
||||
) -> Entity<Pane> {
|
||||
let is_local = project.read(cx).is_local();
|
||||
let terminal_panel = cx.entity().clone();
|
||||
let terminal_panel = cx.entity();
|
||||
let pane = cx.new(|cx| {
|
||||
let mut pane = Pane::new(
|
||||
workspace.clone(),
|
||||
|
@ -1009,7 +1009,7 @@ pub fn new_terminal_pane(
|
|||
return ControlFlow::Break(());
|
||||
};
|
||||
if let Some(tab) = dropped_item.downcast_ref::<DraggedTab>() {
|
||||
let this_pane = cx.entity().clone();
|
||||
let this_pane = cx.entity();
|
||||
let item = if tab.pane == this_pane {
|
||||
pane.item_for_index(tab.ix)
|
||||
} else {
|
||||
|
|
|
@ -1391,7 +1391,7 @@ impl Render for TerminalView {
|
|||
}
|
||||
|
||||
let terminal_handle = self.terminal.clone();
|
||||
let terminal_view_handle = cx.entity().clone();
|
||||
let terminal_view_handle = cx.entity();
|
||||
|
||||
let focused = self.focus_handle.is_focused(window);
|
||||
|
||||
|
|
|
@ -299,7 +299,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
|
|||
|
||||
Vim::action(editor, cx, |vim, action: &VimSave, window, cx| {
|
||||
vim.update_editor(cx, |_, editor, cx| {
|
||||
let Some(project) = editor.project.clone() else {
|
||||
let Some(project) = editor.project().cloned() else {
|
||||
return;
|
||||
};
|
||||
let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else {
|
||||
|
@ -436,7 +436,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
|
|||
let Some(workspace) = vim.workspace(window) else {
|
||||
return;
|
||||
};
|
||||
let Some(project) = editor.project.clone() else {
|
||||
let Some(project) = editor.project().cloned() else {
|
||||
return;
|
||||
};
|
||||
let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else {
|
||||
|
|
|
@ -20,7 +20,7 @@ impl ModeIndicator {
|
|||
})
|
||||
.detach();
|
||||
|
||||
let handle = cx.entity().clone();
|
||||
let handle = cx.entity();
|
||||
let window_handle = window.window_handle();
|
||||
cx.observe_new::<Vim>(move |_, window, cx| {
|
||||
let Some(window) = window else {
|
||||
|
@ -29,7 +29,7 @@ impl ModeIndicator {
|
|||
if window.window_handle() != window_handle {
|
||||
return;
|
||||
}
|
||||
let vim = cx.entity().clone();
|
||||
let vim = cx.entity();
|
||||
handle.update(cx, |_, cx| {
|
||||
cx.subscribe(&vim, |mode_indicator, vim, event, cx| match event {
|
||||
VimEvent::Focused => {
|
||||
|
|
|
@ -332,7 +332,7 @@ impl Vim {
|
|||
Vim::take_forced_motion(cx);
|
||||
let prior_selections = self.editor_selections(window, cx);
|
||||
let cursor_word = self.editor_cursor_word(window, cx);
|
||||
let vim = cx.entity().clone();
|
||||
let vim = cx.entity();
|
||||
|
||||
let searched = pane.update(cx, |pane, cx| {
|
||||
self.search.direction = direction;
|
||||
|
|
|
@ -402,7 +402,7 @@ impl Vim {
|
|||
const NAMESPACE: &'static str = "vim";
|
||||
|
||||
pub fn new(window: &mut Window, cx: &mut Context<Editor>) -> Entity<Self> {
|
||||
let editor = cx.entity().clone();
|
||||
let editor = cx.entity();
|
||||
|
||||
let mut initial_mode = VimSettings::get_global(cx).default_mode;
|
||||
if initial_mode == Mode::Normal && HelixModeSetting::get_global(cx).0 {
|
||||
|
|
|
@ -253,7 +253,7 @@ impl Dock {
|
|||
cx: &mut Context<Workspace>,
|
||||
) -> Entity<Self> {
|
||||
let focus_handle = cx.focus_handle();
|
||||
let workspace = cx.entity().clone();
|
||||
let workspace = cx.entity();
|
||||
let dock = cx.new(|cx| {
|
||||
let focus_subscription =
|
||||
cx.on_focus(&focus_handle, window, |dock: &mut Dock, window, cx| {
|
||||
|
|
|
@ -346,7 +346,7 @@ impl Render for LanguageServerPrompt {
|
|||
)
|
||||
.child(Label::new(request.message.to_string()).size(LabelSize::Small))
|
||||
.children(request.actions.iter().enumerate().map(|(ix, action)| {
|
||||
let this_handle = cx.entity().clone();
|
||||
let this_handle = cx.entity();
|
||||
Button::new(ix, action.title.clone())
|
||||
.size(ButtonSize::Large)
|
||||
.on_click(move |_, window, cx| {
|
||||
|
|
|
@ -2198,7 +2198,7 @@ impl Pane {
|
|||
|
||||
fn update_status_bar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let workspace = self.workspace.clone();
|
||||
let pane = cx.entity().clone();
|
||||
let pane = cx.entity();
|
||||
|
||||
window.defer(cx, move |window, cx| {
|
||||
let Ok(status_bar) =
|
||||
|
@ -2279,7 +2279,7 @@ impl Pane {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
maybe!({
|
||||
let pane = cx.entity().clone();
|
||||
let pane = cx.entity();
|
||||
|
||||
let destination_index = match operation {
|
||||
PinOperation::Pin => self.pinned_tab_count.min(ix),
|
||||
|
@ -2473,7 +2473,7 @@ impl Pane {
|
|||
.on_drag(
|
||||
DraggedTab {
|
||||
item: item.boxed_clone(),
|
||||
pane: cx.entity().clone(),
|
||||
pane: cx.entity(),
|
||||
detail,
|
||||
is_active,
|
||||
ix,
|
||||
|
@ -2832,7 +2832,7 @@ impl Pane {
|
|||
let navigate_backward = IconButton::new("navigate_backward", IconName::ArrowLeft)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click({
|
||||
let entity = cx.entity().clone();
|
||||
let entity = cx.entity();
|
||||
move |_, window, cx| {
|
||||
entity.update(cx, |pane, cx| pane.navigate_backward(window, cx))
|
||||
}
|
||||
|
@ -2848,7 +2848,7 @@ impl Pane {
|
|||
let navigate_forward = IconButton::new("navigate_forward", IconName::ArrowRight)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click({
|
||||
let entity = cx.entity().clone();
|
||||
let entity = cx.entity();
|
||||
move |_, window, cx| entity.update(cx, |pane, cx| pane.navigate_forward(window, cx))
|
||||
})
|
||||
.disabled(!self.can_navigate_forward())
|
||||
|
@ -3054,7 +3054,7 @@ impl Pane {
|
|||
return;
|
||||
}
|
||||
}
|
||||
let mut to_pane = cx.entity().clone();
|
||||
let mut to_pane = cx.entity();
|
||||
let split_direction = self.drag_split_direction;
|
||||
let item_id = dragged_tab.item.item_id();
|
||||
if let Some(preview_item_id) = self.preview_item_id {
|
||||
|
@ -3163,7 +3163,7 @@ impl Pane {
|
|||
return;
|
||||
}
|
||||
}
|
||||
let mut to_pane = cx.entity().clone();
|
||||
let mut to_pane = cx.entity();
|
||||
let split_direction = self.drag_split_direction;
|
||||
let project_entry_id = *project_entry_id;
|
||||
self.workspace
|
||||
|
@ -3239,7 +3239,7 @@ impl Pane {
|
|||
return;
|
||||
}
|
||||
}
|
||||
let mut to_pane = cx.entity().clone();
|
||||
let mut to_pane = cx.entity();
|
||||
let mut split_direction = self.drag_split_direction;
|
||||
let paths = paths.paths().to_vec();
|
||||
let is_remote = self
|
||||
|
|
|
@ -6338,7 +6338,7 @@ impl Render for Workspace {
|
|||
.border_b_1()
|
||||
.border_color(colors.border)
|
||||
.child({
|
||||
let this = cx.entity().clone();
|
||||
let this = cx.entity();
|
||||
canvas(
|
||||
move |bounds, window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
|
|
|
@ -319,7 +319,7 @@ pub fn initialize_workspace(
|
|||
return;
|
||||
};
|
||||
|
||||
let workspace_handle = cx.entity().clone();
|
||||
let workspace_handle = cx.entity();
|
||||
let center_pane = workspace.active_pane().clone();
|
||||
initialize_pane(workspace, ¢er_pane, window, cx);
|
||||
|
||||
|
|
|
@ -229,8 +229,7 @@ fn assign_edit_prediction_provider(
|
|||
if let Some(file) = buffer.read(cx).file() {
|
||||
let id = file.worktree_id(cx);
|
||||
if let Some(inner_worktree) = editor
|
||||
.project
|
||||
.as_ref()
|
||||
.project()
|
||||
.and_then(|project| project.read(cx).worktree_for_id(id, cx))
|
||||
{
|
||||
worktree = Some(inner_worktree);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue