From 75b832029a7ab35442e030fff05df55dbbd2d6de Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 13:26:21 -0400 Subject: [PATCH 01/17] Remove RPC messages pertaining to the LLM token (#36252) This PR removes the RPC messages pertaining to the LLM token. We now retrieve the LLM token from Cloud. Release Notes: - N/A --- Cargo.lock | 2 - crates/collab/Cargo.toml | 2 - crates/collab/src/llm.rs | 3 - crates/collab/src/llm/token.rs | 146 --------------------------------- crates/collab/src/rpc.rs | 96 +--------------------- crates/proto/proto/ai.proto | 8 -- crates/proto/proto/zed.proto | 6 +- crates/proto/src/proto.rs | 4 - 8 files changed, 4 insertions(+), 263 deletions(-) delete mode 100644 crates/collab/src/llm/token.rs diff --git a/Cargo.lock b/Cargo.lock index 2353733dc0..bfc797d6cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3324,7 +3324,6 @@ dependencies = [ "http_client", "hyper 0.14.32", "indoc", - "jsonwebtoken", "language", "language_model", "livekit_api", @@ -3370,7 +3369,6 @@ dependencies = [ "telemetry_events", "text", "theme", - "thiserror 2.0.12", "time", "tokio", "toml 0.8.20", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9af95317e6..9a867f9e05 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -39,7 +39,6 @@ futures.workspace = true gpui.workspace = true hex.workspace = true http_client.workspace = true -jsonwebtoken.workspace = true livekit_api.workspace = true log.workspace = true nanoid.workspace = true @@ -65,7 +64,6 @@ subtle.workspace = true supermaven_api.workspace = true telemetry_events.workspace = true text.workspace = true -thiserror.workspace = true time.workspace = true tokio = { workspace = true, features = ["full"] } toml.workspace = true diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index de74858168..ca8e89bc6d 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,7 +1,4 @@ pub mod db; -mod token; - -pub use token::*; pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs deleted file mode 100644 index da01c7f3be..0000000000 --- a/crates/collab/src/llm/token.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::db::billing_subscription::SubscriptionKind; -use crate::db::{billing_customer, billing_subscription, user}; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG}; -use crate::{Config, db::billing_preference}; -use anyhow::{Context as _, Result}; -use chrono::{NaiveDateTime, Utc}; -use cloud_llm_client::Plan; -use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; -use thiserror::Error; -use uuid::Uuid; - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LlmTokenClaims { - pub iat: u64, - pub exp: u64, - pub jti: String, - pub user_id: u64, - pub system_id: Option, - pub metrics_id: Uuid, - pub github_user_login: String, - pub account_created_at: NaiveDateTime, - pub is_staff: bool, - pub has_llm_closed_beta_feature_flag: bool, - pub bypass_account_age_check: bool, - pub use_llm_request_queue: bool, - pub plan: Plan, - pub has_extended_trial: bool, - pub subscription_period: (NaiveDateTime, NaiveDateTime), - pub enable_model_request_overages: bool, - pub model_request_overages_spend_limit_in_cents: u32, - pub can_use_web_search_tool: bool, - #[serde(default)] - pub has_overdue_invoices: bool, -} - -const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); - -impl LlmTokenClaims { - pub fn create( - user: &user::Model, - is_staff: bool, - billing_customer: billing_customer::Model, - billing_preferences: Option, - feature_flags: &Vec, - subscription: billing_subscription::Model, - system_id: Option, - config: &Config, - ) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - let plan = if is_staff { - Plan::ZedPro - } else { - subscription.kind.map_or(Plan::ZedFree, |kind| match kind { - SubscriptionKind::ZedFree => Plan::ZedFree, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) - }; - let subscription_period = - billing_subscription::Model::current_period(Some(subscription), is_staff) - .map(|(start, end)| (start.naive_utc(), end.naive_utc())) - .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?; - - let now = Utc::now(); - let claims = Self { - iat: now.timestamp() as u64, - exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, - jti: uuid::Uuid::new_v4().to_string(), - user_id: user.id.to_proto(), - system_id, - metrics_id: user.metrics_id, - github_user_login: user.github_login.clone(), - account_created_at: user.account_created_at(), - is_staff, - has_llm_closed_beta_feature_flag: feature_flags - .iter() - .any(|flag| flag == "llm-closed-beta"), - bypass_account_age_check: feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG), - can_use_web_search_tool: true, - use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"), - plan, - has_extended_trial: feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period, - enable_model_request_overages: billing_preferences - .as_ref() - .map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: billing_preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents as u32 - }), - has_overdue_invoices: billing_customer.has_overdue_invoices, - }; - - Ok(jsonwebtoken::encode( - &Header::default(), - &claims, - &EncodingKey::from_secret(secret.as_ref()), - )?) - } - - pub fn validate(token: &str, config: &Config) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - match jsonwebtoken::decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::default(), - ) { - Ok(token) => Ok(token.claims), - Err(e) => { - if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature { - Err(ValidateLlmTokenError::Expired) - } else { - Err(ValidateLlmTokenError::JwtError(e)) - } - } - } - } -} - -#[derive(Error, Debug)] -pub enum ValidateLlmTokenError { - #[error("access token is expired")] - Expired, - #[error("access token validation error: {0}")] - JwtError(#[from] jsonwebtoken::errors::Error), - #[error("{0}")] - Other(#[from] anyhow::Error), -} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 584970a4c6..715ff4e67d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,14 +1,12 @@ mod connection_pool; -use crate::api::billing::find_or_create_billing_customer; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::db::LlmDatabase; use crate::llm::{ - AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims, + AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, MIN_ACCOUNT_AGE_FOR_LLM_USE, }; -use crate::stripe_client::StripeCustomerId; use crate::{ AppState, Error, Result, auth, db::{ @@ -218,6 +216,7 @@ struct Session { /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, + #[allow(unused)] system_id: Option, _executor: Executor, } @@ -464,7 +463,6 @@ impl Server { .add_message_handler(unfollow) .add_message_handler(update_followers) .add_request_handler(get_private_user_info) - .add_request_handler(get_llm_api_token) .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) @@ -4251,96 +4249,6 @@ async fn accept_terms_of_service( accepted_tos_at: accepted_tos_at.timestamp() as u64, })?; - // When the user accepts the terms of service, we want to refresh their LLM - // token to grant access. - session - .peer - .send(session.connection_id, proto::RefreshLlmToken {})?; - - Ok(()) -} - -async fn get_llm_api_token( - _request: proto::GetLlmToken, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let flags = db.get_user_flags(session.user_id()).await?; - - let user_id = session.user_id(); - let user = db - .get_user_by_id(user_id) - .await? - .with_context(|| format!("user {user_id} not found"))?; - - if user.accepted_tos_at.is_none() { - Err(anyhow!("terms of service not accepted"))? - } - - let stripe_client = session - .app_state - .stripe_client - .as_ref() - .context("failed to retrieve Stripe client")?; - - let stripe_billing = session - .app_state - .stripe_billing - .as_ref() - .context("failed to retrieve Stripe billing object")?; - - let billing_customer = if let Some(billing_customer) = - db.get_billing_customer_by_user_id(user.id).await? - { - billing_customer - } else { - let customer_id = stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await?; - - find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id) - .await? - .context("billing customer not found")? - }; - - let billing_subscription = - if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { - billing_subscription - } else { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let stripe_subscription = stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - - db.create_billing_subscription(&db::CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: Some(SubscriptionKind::ZedFree), - stripe_subscription_id: stripe_subscription.id.to_string(), - stripe_subscription_status: stripe_subscription.status.into(), - stripe_cancellation_reason: None, - stripe_current_period_start: Some(stripe_subscription.current_period_start), - stripe_current_period_end: Some(stripe_subscription.current_period_end), - }) - .await? - }; - - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let token = LlmTokenClaims::create( - &user, - session.is_staff(), - billing_customer, - billing_preferences, - &flags, - billing_subscription, - session.system_id.clone(), - &session.app_state.config, - )?; - response.send(proto::GetLlmTokenResponse { token })?; Ok(()) } diff --git a/crates/proto/proto/ai.proto b/crates/proto/proto/ai.proto index 67c2224387..1064ed2f8d 100644 --- a/crates/proto/proto/ai.proto +++ b/crates/proto/proto/ai.proto @@ -158,14 +158,6 @@ message SynchronizeContextsResponse { repeated ContextVersion contexts = 1; } -message GetLlmToken {} - -message GetLlmTokenResponse { - string token = 1; -} - -message RefreshLlmToken {} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 856a793c2f..b6c7fc3cac 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -250,10 +250,6 @@ message Envelope { AddWorktree add_worktree = 222; AddWorktreeResponse add_worktree_response = 223; - GetLlmToken get_llm_token = 235; - GetLlmTokenResponse get_llm_token_response = 236; - RefreshLlmToken refresh_llm_token = 259; - LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241; LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242; @@ -419,7 +415,9 @@ message Envelope { reserved 221; reserved 224 to 229; reserved 230 to 231; + reserved 235 to 236; reserved 246; + reserved 259; reserved 270; reserved 247 to 254; reserved 255 to 256; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index a5dd97661f..8be9fed172 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -119,8 +119,6 @@ messages!( (GetTypeDefinitionResponse, Background), (GetImplementation, Background), (GetImplementationResponse, Background), - (GetLlmToken, Background), - (GetLlmTokenResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -196,7 +194,6 @@ messages!( (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), (RefreshInlayHints, Foreground), - (RefreshLlmToken, Background), (RegisterBufferWithLanguageServers, Background), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -354,7 +351,6 @@ request_messages!( (GetDocumentHighlights, GetDocumentHighlightsResponse), (GetDocumentSymbols, GetDocumentSymbolsResponse), (GetHover, GetHoverResponse), - (GetLlmToken, GetLlmTokenResponse), (GetNotifications, GetNotificationsResponse), (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), From e452aba9da0cd66ec227371a2466f7a97847d5a9 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 13:59:08 -0400 Subject: [PATCH 02/17] proto: Order `reserved` fields (#36261) This PR orders the `reserved` fields in the RPC `Envelope`, as they had gotten unsorted. Release Notes: - N/A --- crates/proto/proto/zed.proto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index b6c7fc3cac..7e7bd6b42b 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -417,10 +417,10 @@ message Envelope { reserved 230 to 231; reserved 235 to 236; reserved 246; - reserved 259; - reserved 270; reserved 247 to 254; reserved 255 to 256; + reserved 259; + reserved 270; reserved 280 to 281; reserved 332 to 333; } From bd1fda6782933678be7ed8e39494aba32af871d1 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 14:27:31 -0400 Subject: [PATCH 03/17] proto: Remove `GetPrivateUserInfo` message (#36265) This PR removes the `GetPrivateUserInfo` RPC message. We're no longer using the message after https://github.com/zed-industries/zed/pull/36255. Release Notes: - N/A --- crates/client/src/test.rs | 67 +++++++++++------------------------- crates/collab/src/rpc.rs | 25 -------------- crates/proto/proto/app.proto | 9 ----- crates/proto/proto/zed.proto | 3 +- crates/proto/src/proto.rs | 3 -- 5 files changed, 21 insertions(+), 86 deletions(-) diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 439fb100d2..3c451fcb01 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,16 +1,12 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; -use chrono::Duration; use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; -use rpc::{ - ConnectionId, Peer, Receipt, TypedEnvelope, - proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, -}; +use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto}; use std::sync::Arc; pub struct FakeServer { @@ -187,50 +183,27 @@ impl FakeServer { pub async fn receive(&self) -> Result> { self.executor.start_waiting(); - loop { - let message = self - .state - .lock() - .incoming - .as_mut() - .expect("not connected") - .next() - .await - .context("other half hung up")?; - self.executor.finish_waiting(); - let type_name = message.payload_type_name(); - let message = message.into_any(); + let message = self + .state + .lock() + .incoming + .as_mut() + .expect("not connected") + .next() + .await + .context("other half hung up")?; + self.executor.finish_waiting(); + let type_name = message.payload_type_name(); + let message = message.into_any(); - if message.is::>() { - return Ok(*message.downcast().unwrap()); - } - - let accepted_tos_at = chrono::Utc::now() - .checked_sub_signed(Duration::hours(5)) - .expect("failed to build accepted_tos_at") - .timestamp() as u64; - - if message.is::>() { - self.respond( - message - .downcast::>() - .unwrap() - .receipt(), - GetPrivateUserInfoResponse { - metrics_id: "the-metrics-id".into(), - staff: false, - flags: Default::default(), - accepted_tos_at: Some(accepted_tos_at), - }, - ); - continue; - } - - panic!( - "fake server received unexpected message type: {:?}", - type_name - ); + if message.is::>() { + return Ok(*message.downcast().unwrap()); } + + panic!( + "fake server received unexpected message type: {:?}", + type_name + ); } pub fn respond(&self, receipt: Receipt, response: T::Response) { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 715ff4e67d..8366b2cf13 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -462,7 +462,6 @@ impl Server { .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) - .add_request_handler(get_private_user_info) .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) @@ -4209,30 +4208,6 @@ async fn mark_notification_as_read( Ok(()) } -/// Get the current users information -async fn get_private_user_info( - _request: proto::GetPrivateUserInfo, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let metrics_id = db.get_user_metrics_id(session.user_id()).await?; - let user = db - .get_user_by_id(session.user_id()) - .await? - .context("user not found")?; - let flags = db.get_user_flags(session.user_id()).await?; - - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - flags, - accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64), - })?; - Ok(()) -} - /// Accept the terms of service (tos) on behalf of the current user async fn accept_terms_of_service( _request: proto::AcceptTermsOfService, diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 353f19adb2..66baf968e3 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -6,15 +6,6 @@ message UpdateInviteInfo { uint32 count = 2; } -message GetPrivateUserInfo {} - -message GetPrivateUserInfoResponse { - string metrics_id = 1; - bool staff = 2; - repeated string flags = 3; - optional uint64 accepted_tos_at = 4; -} - enum Plan { Free = 0; ZedPro = 1; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 7e7bd6b42b..8984df2944 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -135,8 +135,6 @@ message Envelope { FollowResponse follow_response = 99; UpdateFollowers update_followers = 100; Unfollow unfollow = 101; - GetPrivateUserInfo get_private_user_info = 102; - GetPrivateUserInfoResponse get_private_user_info_response = 103; UpdateUserPlan update_user_plan = 234; UpdateDiffBases update_diff_bases = 104; AcceptTermsOfService accept_terms_of_service = 239; @@ -402,6 +400,7 @@ message Envelope { } reserved 87 to 88; + reserved 102 to 103; reserved 158 to 161; reserved 164; reserved 166 to 169; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 8be9fed172..82bd1af6db 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -105,8 +105,6 @@ messages!( (GetPathMetadataResponse, Background), (GetPermalinkToLine, Foreground), (GetPermalinkToLineResponse, Foreground), - (GetPrivateUserInfo, Foreground), - (GetPrivateUserInfoResponse, Foreground), (GetProjectSymbols, Background), (GetProjectSymbolsResponse, Background), (GetReferences, Background), @@ -352,7 +350,6 @@ request_messages!( (GetDocumentSymbols, GetDocumentSymbolsResponse), (GetHover, GetHoverResponse), (GetNotifications, GetNotificationsResponse), - (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), (GetReferences, GetReferencesResponse), (GetSignatureHelp, GetSignatureHelpResponse), From 3c5d5a1d57f8569fa2818a0538d0ba950036c710 Mon Sep 17 00:00:00 2001 From: Finn Evers Date: Fri, 15 Aug 2025 20:34:22 +0200 Subject: [PATCH 04/17] editor: Add access method for `project` (#36266) This resolves a `TODO` that I've stumbled upon too many times whilst looking at the editor code. Release Notes: - N/A --- crates/diagnostics/src/diagnostics_tests.rs | 10 +++--- crates/editor/src/editor.rs | 36 ++++++++++--------- crates/editor/src/editor_tests.rs | 2 +- crates/editor/src/hover_popover.rs | 2 +- crates/editor/src/items.rs | 2 +- crates/editor/src/linked_editing_ranges.rs | 2 +- crates/editor/src/signature_help.rs | 2 +- crates/editor/src/test/editor_test_context.rs | 20 +++++------ crates/git_ui/src/conflict_view.rs | 4 +-- crates/vim/src/command.rs | 4 +-- .../zed/src/zed/edit_prediction_registry.rs | 3 +- 11 files changed, 42 insertions(+), 45 deletions(-) diff --git a/crates/diagnostics/src/diagnostics_tests.rs b/crates/diagnostics/src/diagnostics_tests.rs index 8fb223b2cb..5df1b13897 100644 --- a/crates/diagnostics/src/diagnostics_tests.rs +++ b/crates/diagnostics/src/diagnostics_tests.rs @@ -971,7 +971,7 @@ async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -1065,7 +1065,7 @@ async fn cycle_through_same_place_diagnostics(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { @@ -1239,7 +1239,7 @@ async fn test_diagnostics_with_links(cx: &mut TestAppContext) { } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { @@ -1293,7 +1293,7 @@ async fn test_hover_diagnostic_and_info_popovers(cx: &mut gpui::TestAppContext) fn «test»() { println!(); } "}); let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.update(|_, cx| { lsp_store.update(cx, |lsp_store, cx| { lsp_store.update_diagnostics( @@ -1450,7 +1450,7 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) { let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {"error warning info hiˇnt"}); diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index a9780ed6c2..f77e9ae08c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1039,9 +1039,7 @@ pub struct Editor { inline_diagnostics: Vec<(Anchor, InlineDiagnostic)>, soft_wrap_mode_override: Option, hard_wrap: Option, - - // TODO: make this a access method - pub project: Option>, + project: Option>, semantics_provider: Option>, completion_provider: Option>, collaboration_hub: Option>, @@ -2326,7 +2324,7 @@ impl Editor { editor.go_to_active_debug_line(window, cx); if let Some(buffer) = buffer.read(cx).as_singleton() { - if let Some(project) = editor.project.as_ref() { + if let Some(project) = editor.project() { let handle = project.update(cx, |project, cx| { project.register_buffer_with_language_servers(&buffer, cx) }); @@ -2626,6 +2624,10 @@ impl Editor { &self.buffer } + pub fn project(&self) -> Option<&Entity> { + self.project.as_ref() + } + pub fn workspace(&self) -> Option> { self.workspace.as_ref()?.0.upgrade() } @@ -5212,7 +5214,7 @@ impl Editor { restrict_to_languages: Option<&HashSet>>, cx: &mut Context, ) -> HashMap, clock::Global, Range)> { - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return HashMap::default(); }; let project = project.read(cx); @@ -5294,7 +5296,7 @@ impl Editor { return None; } - let project = self.project.as_ref()?; + let project = self.project()?; let position = self.selections.newest_anchor().head(); let (buffer, buffer_position) = self .buffer @@ -6141,7 +6143,7 @@ impl Editor { cx: &mut App, ) -> Task> { maybe!({ - let project = self.project.as_ref()?; + let project = self.project()?; let dap_store = project.read(cx).dap_store(); let mut scenarios = vec![]; let resolved_tasks = resolved_tasks.as_ref()?; @@ -7907,7 +7909,7 @@ impl Editor { let snapshot = self.snapshot(window, cx); let multi_buffer_snapshot = &snapshot.display_snapshot.buffer_snapshot; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return breakpoint_display_points; }; @@ -10501,7 +10503,7 @@ impl Editor { ) { if let Some(working_directory) = self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let parent = match &entry.canonical_path { Some(canonical_path) => canonical_path.to_path_buf(), @@ -14875,7 +14877,7 @@ impl Editor { self.clear_tasks(); return Task::ready(()); } - let project = self.project.as_ref().map(Entity::downgrade); + let project = self.project().map(Entity::downgrade); let task_sources = self.lsp_task_sources(cx); let multi_buffer = self.buffer.downgrade(); cx.spawn_in(window, async move |editor, cx| { @@ -17054,7 +17056,7 @@ impl Editor { if !pull_diagnostics_settings.enabled { return None; } - let project = self.project.as_ref()?.downgrade(); + let project = self.project()?.downgrade(); let debounce = Duration::from_millis(pull_diagnostics_settings.debounce_ms); let mut buffers = self.buffer.read(cx).all_buffers(); if let Some(buffer_id) = buffer_id { @@ -18018,7 +18020,7 @@ impl Editor { hunks: impl Iterator, cx: &mut App, ) -> Option<()> { - let project = self.project.as_ref()?; + let project = self.project()?; let buffer = project.read(cx).buffer_for_id(buffer_id, cx)?; let diff = self.buffer.read(cx).diff_for(buffer_id)?; let buffer_snapshot = buffer.read(cx).snapshot(); @@ -18678,7 +18680,7 @@ impl Editor { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let buffer = buffer.read(cx); if let Some(project_path) = buffer.project_path(cx) { - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); project.absolute_path(&project_path, cx) } else { buffer @@ -18691,7 +18693,7 @@ impl Editor { fn target_file_path(&self, cx: &mut Context) -> Option { self.active_excerpt(cx).and_then(|(_, buffer, _)| { let project_path = buffer.read(cx).project_path(cx)?; - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; let path = entry.path.to_path_buf(); Some(path) @@ -18912,7 +18914,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) { - if let Some(project) = self.project.as_ref() { + if let Some(project) = self.project() { let Some(buffer) = self.buffer().read(cx).as_singleton() else { return; }; @@ -19028,7 +19030,7 @@ impl Editor { return Task::ready(Err(anyhow!("failed to determine buffer and selection"))); }; - let Some(project) = self.project.as_ref() else { + let Some(project) = self.project() else { return Task::ready(Err(anyhow!("editor does not have project"))); }; @@ -21015,7 +21017,7 @@ impl Editor { cx: &mut Context, ) { let workspace = self.workspace(); - let project = self.project.as_ref(); + let project = self.project(); let save_tasks = self.buffer().update(cx, |multi_buffer, cx| { let mut tasks = Vec::new(); for (buffer_id, changes) in revert_changes { diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index a5966b3301..cf9954bc12 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -15082,7 +15082,7 @@ async fn go_to_prev_overlapping_diagnostic(executor: BackgroundExecutor, cx: &mu let mut cx = EditorTestContext::new(cx).await; let lsp_store = - cx.update_editor(|editor, _, cx| editor.project.as_ref().unwrap().read(cx).lsp_store()); + cx.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).lsp_store()); cx.set_state(indoc! {" ˇfn func(abc def: i32) -> u32 { diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index bda229e346..3fc673bad9 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -251,7 +251,7 @@ fn show_hover( let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?; - let language_registry = editor.project.as_ref()?.read(cx).languages().clone(); + let language_registry = editor.project()?.read(cx).languages().clone(); let provider = editor.semantics_provider.clone()?; if !ignore_timeout { diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 45a4f7365c..34533002ff 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -678,7 +678,7 @@ impl Item for Editor { let buffer = buffer.read(cx); let path = buffer.project_path(cx)?; let buffer_id = buffer.remote_id(); - let project = self.project.as_ref()?.read(cx); + let project = self.project()?.read(cx); let entry = project.entry_for_path(&path, cx)?; let (repo, repo_path) = project .git_store() diff --git a/crates/editor/src/linked_editing_ranges.rs b/crates/editor/src/linked_editing_ranges.rs index a185de33ca..aaf9032b04 100644 --- a/crates/editor/src/linked_editing_ranges.rs +++ b/crates/editor/src/linked_editing_ranges.rs @@ -51,7 +51,7 @@ pub(super) fn refresh_linked_ranges( if editor.pending_rename.is_some() { return None; } - let project = editor.project.as_ref()?.downgrade(); + let project = editor.project()?.downgrade(); editor.linked_editing_range_task = Some(cx.spawn_in(window, async move |editor, cx| { cx.background_executor().timer(UPDATE_DEBOUNCE).await; diff --git a/crates/editor/src/signature_help.rs b/crates/editor/src/signature_help.rs index e9f8d2dbd3..e0736a6e9f 100644 --- a/crates/editor/src/signature_help.rs +++ b/crates/editor/src/signature_help.rs @@ -169,7 +169,7 @@ impl Editor { else { return; }; - let Some(lsp_store) = self.project.as_ref().map(|p| p.read(cx).lsp_store()) else { + let Some(lsp_store) = self.project().map(|p| p.read(cx).lsp_store()) else { return; }; let task = lsp_store.update(cx, |lsp_store, cx| { diff --git a/crates/editor/src/test/editor_test_context.rs b/crates/editor/src/test/editor_test_context.rs index bdf73da5fb..dbb519c40e 100644 --- a/crates/editor/src/test/editor_test_context.rs +++ b/crates/editor/src/test/editor_test_context.rs @@ -297,9 +297,8 @@ impl EditorTestContext { pub fn set_head_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_head_for_repo( &Self::root_path().join(".git"), @@ -311,18 +310,16 @@ impl EditorTestContext { pub fn clear_index_text(&mut self) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); fs.set_index_for_repo(&Self::root_path().join(".git"), &[]); self.cx.run_until_parked(); } pub fn set_index_text(&mut self, diff_base: &str) { self.cx.run_until_parked(); - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); fs.set_index_for_repo( &Self::root_path().join(".git"), @@ -333,9 +330,8 @@ impl EditorTestContext { #[track_caller] pub fn assert_index_text(&mut self, expected: Option<&str>) { - let fs = self.update_editor(|editor, _, cx| { - editor.project.as_ref().unwrap().read(cx).fs().as_fake() - }); + let fs = + self.update_editor(|editor, _, cx| editor.project().unwrap().read(cx).fs().as_fake()); let path = self.update_buffer(|buffer, _| buffer.file().unwrap().path().clone()); let mut found = None; fs.with_git_state(&Self::root_path().join(".git"), false, |git_state| { diff --git a/crates/git_ui/src/conflict_view.rs b/crates/git_ui/src/conflict_view.rs index 0bbb9411be..6482ebb9f8 100644 --- a/crates/git_ui/src/conflict_view.rs +++ b/crates/git_ui/src/conflict_view.rs @@ -112,7 +112,7 @@ fn excerpt_for_buffer_updated( } fn buffer_added(editor: &mut Editor, buffer: Entity, cx: &mut Context) { - let Some(project) = &editor.project else { + let Some(project) = editor.project() else { return; }; let git_store = project.read(cx).git_store().clone(); @@ -469,7 +469,7 @@ pub(crate) fn resolve_conflict( let Some((workspace, project, multibuffer, buffer)) = editor .update(cx, |editor, cx| { let workspace = editor.workspace()?; - let project = editor.project.clone()?; + let project = editor.project()?.clone(); let multibuffer = editor.buffer().clone(); let buffer_id = resolved_conflict.ours.end.buffer_id?; let buffer = multibuffer.read(cx).buffer(buffer_id)?; diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 264fa4bf2f..ce5e5a0300 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -299,7 +299,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, action: &VimSave, window, cx| { vim.update_editor(cx, |_, editor, cx| { - let Some(project) = editor.project.clone() else { + let Some(project) = editor.project().cloned() else { return; }; let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else { @@ -436,7 +436,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { let Some(workspace) = vim.workspace(window) else { return; }; - let Some(project) = editor.project.clone() else { + let Some(project) = editor.project().cloned() else { return; }; let Some(worktree) = project.read(cx).visible_worktrees(cx).next() else { diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index da4b6e78c6..5b0826413b 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -229,8 +229,7 @@ fn assign_edit_prediction_provider( if let Some(file) = buffer.read(cx).file() { let id = file.worktree_id(cx); if let Some(inner_worktree) = editor - .project - .as_ref() + .project() .and_then(|project| project.read(cx).worktree_for_id(id, cx)) { worktree = Some(inner_worktree); From 19318897597071a64282d3bf4e1c4846485e7333 Mon Sep 17 00:00:00 2001 From: Cole Miller Date: Fri, 15 Aug 2025 14:55:34 -0400 Subject: [PATCH 05/17] thread_view: Move handlers for confirmed completions to the MessageEditor (#36214) Release Notes: - N/A --------- Co-authored-by: Conrad Irwin --- .../agent_ui/src/acp/completion_provider.rs | 435 +++++------------- crates/agent_ui/src/acp/message_editor.rs | 360 ++++++++++++--- crates/agent_ui/src/context_picker.rs | 41 +- crates/editor/src/editor.rs | 28 ++ 4 files changed, 455 insertions(+), 409 deletions(-) diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index adcfab85b1..4ee1eb6948 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,38 +1,34 @@ use std::ffi::OsStr; use std::ops::Range; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use acp_thread::{MentionUri, selection_name}; +use acp_thread::MentionUri; use anyhow::{Context as _, Result, anyhow}; -use collections::{HashMap, HashSet}; +use collections::HashMap; use editor::display_map::CreaseId; -use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _}; +use editor::{CompletionProvider, Editor, ExcerptId}; use futures::future::{Shared, try_join_all}; -use futures::{FutureExt, TryFutureExt}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{App, Entity, ImageFormat, Img, Task, WeakEntity}; use http_client::HttpClientWithUrl; -use itertools::Itertools as _; use language::{Buffer, CodeLabel, HighlightId}; use language_model::LanguageModelImage; use lsp::CompletionContext; -use parking_lot::Mutex; use project::{ Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId, }; use prompt_store::PromptStore; use rope::Point; -use text::{Anchor, OffsetRangeExt as _, ToPoint as _}; +use text::{Anchor, ToPoint as _}; use ui::prelude::*; use url::Url; use workspace::Workspace; -use workspace::notifications::NotifyResultExt; use agent::thread_store::{TextThreadStore, ThreadStore}; -use crate::context_picker::fetch_context_picker::fetch_url_content; +use crate::acp::message_editor::MessageEditor; use crate::context_picker::file_context_picker::{FileMatch, search_files}; use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; use crate::context_picker::symbol_context_picker::SymbolMatch; @@ -54,7 +50,7 @@ pub struct MentionImage { #[derive(Default)] pub struct MentionSet { - uri_by_crease_id: HashMap, + pub(crate) uri_by_crease_id: HashMap, fetch_results: HashMap>>>, images: HashMap>>>, } @@ -488,36 +484,31 @@ fn search( } pub struct ContextPickerCompletionProvider { - mention_set: Arc>, workspace: WeakEntity, thread_store: WeakEntity, text_thread_store: WeakEntity, - editor: WeakEntity, + message_editor: WeakEntity, } impl ContextPickerCompletionProvider { pub fn new( - mention_set: Arc>, workspace: WeakEntity, thread_store: WeakEntity, text_thread_store: WeakEntity, - editor: WeakEntity, + message_editor: WeakEntity, ) -> Self { Self { - mention_set, workspace, thread_store, text_thread_store, - editor, + message_editor, } } fn completion_for_entry( entry: ContextPickerEntry, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, workspace: &Entity, cx: &mut App, ) -> Option { @@ -538,88 +529,39 @@ impl ContextPickerCompletionProvider { ContextPickerEntry::Action(action) => { let (new_text, on_action) = match action { ContextPickerAction::AddSelections => { - let selections = selection_ranges(workspace, cx); - const PLACEHOLDER: &str = "selection "; + let selections = selection_ranges(workspace, cx) + .into_iter() + .enumerate() + .map(|(ix, (buffer, range))| { + ( + buffer, + range, + (PLACEHOLDER.len() * ix)..(PLACEHOLDER.len() * (ix + 1) - 1), + ) + }) + .collect::>(); - let new_text = std::iter::repeat(PLACEHOLDER) - .take(selections.len()) - .chain(std::iter::once("")) - .join(" "); + let new_text: String = PLACEHOLDER.repeat(selections.len()); let callback = Arc::new({ - let mention_set = mention_set.clone(); - let selections = selections.clone(); + let source_range = source_range.clone(); move |_, window: &mut Window, cx: &mut App| { - let editor = editor.clone(); - let mention_set = mention_set.clone(); let selections = selections.clone(); + let message_editor = message_editor.clone(); + let source_range = source_range.clone(); window.defer(cx, move |window, cx| { - let mut current_offset = 0; - - for (buffer, selection_range) in selections { - let snapshot = - editor.read(cx).buffer().read(cx).snapshot(cx); - let Some(start) = snapshot - .anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; - - let offset = start.to_offset(&snapshot) + current_offset; - let text_len = PLACEHOLDER.len() - 1; - - let range = snapshot.anchor_after(offset) - ..snapshot.anchor_after(offset + text_len); - - let path = buffer - .read(cx) - .file() - .map_or(PathBuf::from("untitled"), |file| { - file.path().to_path_buf() - }); - - let point_range = snapshot - .as_singleton() - .map(|(_, _, snapshot)| { - selection_range.to_point(&snapshot) - }) - .unwrap_or_default(); - let line_range = point_range.start.row..point_range.end.row; - - let uri = MentionUri::Selection { - path: path.clone(), - line_range: line_range.clone(), - }; - let crease = crate::context_picker::crease_for_mention( - selection_name(&path, &line_range).into(), - uri.icon_path(cx), - range, - editor.downgrade(), - ); - - let [crease_id]: [_; 1] = - editor.update(cx, |editor, cx| { - let crease_ids = - editor.insert_creases(vec![crease.clone()], cx); - editor.fold_creases( - vec![crease], - false, - window, - cx, - ); - crease_ids.try_into().unwrap() - }); - - mention_set.lock().insert_uri( - crease_id, - MentionUri::Selection { path, line_range }, - ); - - current_offset += text_len + 1; - } + message_editor + .update(cx, |message_editor, cx| { + message_editor.confirm_mention_for_selection( + source_range, + selections, + window, + cx, + ) + }) + .ok(); }); - false } }); @@ -647,11 +589,9 @@ impl ContextPickerCompletionProvider { fn completion_for_thread( thread_entry: ThreadContextEntry, - excerpt_id: ExcerptId, source_range: Range, recent: bool, - editor: Entity, - mention_set: Arc>, + editor: WeakEntity, cx: &mut App, ) -> Completion { let uri = match &thread_entry { @@ -683,13 +623,10 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_for_completion.clone()), confirm: Some(confirm_completion_callback( - uri.icon_path(cx), thread_entry.title().clone(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set, + editor, uri, )), } @@ -697,10 +634,8 @@ impl ContextPickerCompletionProvider { fn completion_for_rules( rule: RulesContextEntry, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + editor: WeakEntity, cx: &mut App, ) -> Completion { let uri = MentionUri::Rule { @@ -719,13 +654,10 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_path.clone()), confirm: Some(confirm_completion_callback( - icon_path, rule.title.clone(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set, + editor, uri, )), } @@ -736,10 +668,8 @@ impl ContextPickerCompletionProvider { path_prefix: &str, is_recent: bool, is_directory: bool, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, project: Entity, cx: &mut App, ) -> Option { @@ -777,13 +707,10 @@ impl ContextPickerCompletionProvider { icon_path: Some(completion_icon_path), insert_text_mode: None, confirm: Some(confirm_completion_callback( - crease_icon_path, file_name, - excerpt_id, source_range.start, new_text_len - 1, - editor, - mention_set.clone(), + message_editor, file_uri, )), }) @@ -791,10 +718,8 @@ impl ContextPickerCompletionProvider { fn completion_for_symbol( symbol: Symbol, - excerpt_id: ExcerptId, source_range: Range, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, workspace: Entity, cx: &mut App, ) -> Option { @@ -820,13 +745,10 @@ impl ContextPickerCompletionProvider { icon_path: Some(icon_path.clone()), insert_text_mode: None, confirm: Some(confirm_completion_callback( - icon_path, symbol.name.clone().into(), - excerpt_id, source_range.start, new_text_len - 1, - editor.clone(), - mention_set.clone(), + message_editor, uri, )), }) @@ -835,112 +757,46 @@ impl ContextPickerCompletionProvider { fn completion_for_fetch( source_range: Range, url_to_fetch: SharedString, - excerpt_id: ExcerptId, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, http_client: Arc, cx: &mut App, ) -> Option { let new_text = format!("@fetch {} ", url_to_fetch.clone()); - let new_text_len = new_text.len(); + let url_to_fetch = url::Url::parse(url_to_fetch.as_ref()) + .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) + .ok()?; let mention_uri = MentionUri::Fetch { - url: url::Url::parse(url_to_fetch.as_ref()) - .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) - .ok()?, + url: url_to_fetch.clone(), }; let icon_path = mention_uri.icon_path(cx); Some(Completion { replace_range: source_range.clone(), - new_text, + new_text: new_text.clone(), label: CodeLabel::plain(url_to_fetch.to_string(), None), documentation: None, source: project::CompletionSource::Custom, icon_path: Some(icon_path.clone()), insert_text_mode: None, confirm: Some({ - let start = source_range.start; - let content_len = new_text_len - 1; - let editor = editor.clone(); - let url_to_fetch = url_to_fetch.clone(); - let source_range = source_range.clone(); - let icon_path = icon_path.clone(); - let mention_uri = mention_uri.clone(); Arc::new(move |_, window, cx| { - let Some(url) = url::Url::parse(url_to_fetch.as_ref()) - .or_else(|_| url::Url::parse(&format!("https://{url_to_fetch}"))) - .notify_app_err(cx) - else { - return false; - }; - - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let http_client = http_client.clone(); + let url_to_fetch = url_to_fetch.clone(); let source_range = source_range.clone(); - let icon_path = icon_path.clone(); - let mention_uri = mention_uri.clone(); + let message_editor = message_editor.clone(); + let new_text = new_text.clone(); + let http_client = http_client.clone(); window.defer(cx, move |window, cx| { - let url = url.clone(); - - let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - url.to_string().into(), - icon_path, - editor.clone(), - window, - cx, - ) else { - return; - }; - - let editor = editor.clone(); - let mention_set = mention_set.clone(); - let http_client = http_client.clone(); - let source_range = source_range.clone(); - - let url_string = url.to_string(); - let fetch = cx - .background_executor() - .spawn(async move { - fetch_url_content(http_client, url_string) - .map_err(|e| e.to_string()) - .await + message_editor + .update(cx, |message_editor, cx| { + message_editor.confirm_mention_for_fetch( + new_text, + source_range, + url_to_fetch, + http_client, + window, + cx, + ) }) - .shared(); - mention_set.lock().add_fetch_result(url, fetch.clone()); - - window - .spawn(cx, async move |cx| { - if fetch.await.notify_async_err(cx).is_some() { - mention_set - .lock() - .insert_uri(crease_id, mention_uri.clone()); - } else { - // Remove crease if we failed to fetch - editor - .update(cx, |editor, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let Some(anchor) = snapshot - .anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; - editor.display_map.update(cx, |display_map, cx| { - display_map.unfold_intersecting( - vec![anchor..anchor], - true, - cx, - ); - }); - editor.remove_creases([crease_id], cx); - }) - .ok(); - } - Some(()) - }) - .detach(); + .ok(); }); false }) @@ -968,7 +824,7 @@ fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: impl CompletionProvider for ContextPickerCompletionProvider { fn completions( &self, - excerpt_id: ExcerptId, + _excerpt_id: ExcerptId, buffer: &Entity, buffer_position: Anchor, _trigger: CompletionContext, @@ -999,32 +855,18 @@ impl CompletionProvider for ContextPickerCompletionProvider { let thread_store = self.thread_store.clone(); let text_thread_store = self.text_thread_store.clone(); - let editor = self.editor.clone(); + let editor = self.message_editor.clone(); + let Ok((exclude_paths, exclude_threads)) = + self.message_editor.update(cx, |message_editor, cx| { + message_editor.mentioned_path_and_threads(cx) + }) + else { + return Task::ready(Ok(Vec::new())); + }; let MentionCompletion { mode, argument, .. } = state; let query = argument.unwrap_or_else(|| "".to_string()); - let (exclude_paths, exclude_threads) = { - let mention_set = self.mention_set.lock(); - - let mut excluded_paths = HashSet::default(); - let mut excluded_threads = HashSet::default(); - - for uri in mention_set.uri_by_crease_id.values() { - match uri { - MentionUri::File { abs_path, .. } => { - excluded_paths.insert(abs_path.clone()); - } - MentionUri::Thread { id, .. } => { - excluded_threads.insert(id.clone()); - } - _ => {} - } - } - - (excluded_paths, excluded_threads) - }; - let recent_entries = recent_context_picker_entries( Some(thread_store.clone()), Some(text_thread_store.clone()), @@ -1051,13 +893,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx, ); - let mention_set = self.mention_set.clone(); - cx.spawn(async move |_, cx| { let matches = search_task.await; - let Some(editor) = editor.upgrade() else { - return Ok(Vec::new()); - }; let completions = cx.update(|cx| { matches @@ -1074,10 +911,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { &mat.path_prefix, is_recent, mat.is_dir, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), project.clone(), cx, ) @@ -1085,10 +920,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol( symbol, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), workspace.clone(), cx, ), @@ -1097,39 +930,31 @@ impl CompletionProvider for ContextPickerCompletionProvider { thread, is_recent, .. }) => Some(Self::completion_for_thread( thread, - excerpt_id, source_range.clone(), is_recent, editor.clone(), - mention_set.clone(), cx, )), Match::Rules(user_rules) => Some(Self::completion_for_rules( user_rules, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), cx, )), Match::Fetch(url) => Self::completion_for_fetch( source_range.clone(), url, - excerpt_id, editor.clone(), - mention_set.clone(), http_client.clone(), cx, ), Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry( entry, - excerpt_id, source_range.clone(), editor.clone(), - mention_set.clone(), &workspace, cx, ), @@ -1182,36 +1007,30 @@ impl CompletionProvider for ContextPickerCompletionProvider { } fn confirm_completion_callback( - crease_icon_path: SharedString, crease_text: SharedString, - excerpt_id: ExcerptId, start: Anchor, content_len: usize, - editor: Entity, - mention_set: Arc>, + message_editor: WeakEntity, mention_uri: MentionUri, ) -> Arc bool + Send + Sync> { Arc::new(move |_, window, cx| { + let message_editor = message_editor.clone(); let crease_text = crease_text.clone(); - let crease_icon_path = crease_icon_path.clone(); - let editor = editor.clone(); - let mention_set = mention_set.clone(); let mention_uri = mention_uri.clone(); window.defer(cx, move |window, cx| { - if let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - crease_text.clone(), - crease_icon_path, - editor.clone(), - window, - cx, - ) { - mention_set - .lock() - .insert_uri(crease_id, mention_uri.clone()); - } + message_editor + .clone() + .update(cx, |message_editor, cx| { + message_editor.confirm_completion( + crease_text, + start, + content_len, + mention_uri, + window, + cx, + ) + }) + .ok(); }); false }) @@ -1279,13 +1098,13 @@ impl MentionCompletion { #[cfg(test)] mod tests { use super::*; - use editor::AnchorRangeExt; + use editor::{AnchorRangeExt, EditorMode}; use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext}; use project::{Project, ProjectPath}; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{ops::Deref, path::Path, rc::Rc}; + use std::{ops::Deref, path::Path}; use util::path; use workspace::{AppState, Item}; @@ -1359,9 +1178,9 @@ mod tests { assert_eq!(MentionCompletion::try_parse("test@", 0), None); } - struct AtMentionEditor(Entity); + struct MessageEditorItem(Entity); - impl Item for AtMentionEditor { + impl Item for MessageEditorItem { type Event = (); fn include_in_nav_history() -> bool { @@ -1373,15 +1192,15 @@ mod tests { } } - impl EventEmitter<()> for AtMentionEditor {} + impl EventEmitter<()> for MessageEditorItem {} - impl Focusable for AtMentionEditor { + impl Focusable for MessageEditorItem { fn focus_handle(&self, cx: &App) -> FocusHandle { self.0.read(cx).focus_handle(cx).clone() } } - impl Render for AtMentionEditor { + impl Render for MessageEditorItem { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { self.0.clone().into_any_element() } @@ -1467,19 +1286,28 @@ mod tests { opened_editors.push(buffer); } - let editor = workspace.update_in(&mut cx, |workspace, window, cx| { - let editor = cx.new(|cx| { - Editor::new( - editor::EditorMode::full(), - multi_buffer::MultiBuffer::build_simple("", cx), - None, + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| { + let workspace_handle = cx.weak_entity(); + let message_editor = cx.new(|cx| { + MessageEditor::new( + workspace_handle, + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + EditorMode::AutoHeight { + max_lines: None, + min_lines: 1, + }, window, cx, ) }); workspace.active_pane().update(cx, |pane, cx| { pane.add_item( - Box::new(cx.new(|_| AtMentionEditor(editor.clone()))), + Box::new(cx.new(|_| MessageEditorItem(message_editor.clone()))), true, true, None, @@ -1487,24 +1315,9 @@ mod tests { cx, ); }); - editor - }); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); - - let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); - let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); - - let editor_entity = editor.downgrade(); - editor.update_in(&mut cx, |editor, window, cx| { - window.focus(&editor.focus_handle(cx)); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace.downgrade(), - thread_store.downgrade(), - text_thread_store.downgrade(), - editor_entity, - )))); + message_editor.read(cx).focus_handle(cx).focus(window); + let editor = message_editor.read(cx).editor().clone(); + (message_editor, editor) }); cx.simulate_input("Lorem "); @@ -1573,9 +1386,9 @@ mod tests { ); }); - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( project.clone(), thread_store.clone(), text_thread_store.clone(), @@ -1641,9 +1454,9 @@ mod tests { cx.run_until_parked(); - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( project.clone(), thread_store.clone(), text_thread_store.clone(), @@ -1765,9 +1578,9 @@ mod tests { editor.confirm_completion(&editor::actions::ConfirmCompletion::default(), window, cx); }); - let contents = cx - .update(|window, cx| { - mention_set.lock().contents( + let contents = message_editor + .update_in(&mut cx, |message_editor, window, cx| { + message_editor.mention_set().contents( project.clone(), thread_store, text_thread_store, diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 8d512948dd..32c37da519 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -1,56 +1,55 @@ -use crate::acp::completion_provider::ContextPickerCompletionProvider; -use crate::acp::completion_provider::MentionImage; -use crate::acp::completion_provider::MentionSet; -use acp_thread::MentionUri; -use agent::TextThreadStore; -use agent::ThreadStore; +use crate::{ + acp::completion_provider::{ContextPickerCompletionProvider, MentionImage, MentionSet}, + context_picker::fetch_context_picker::fetch_url_content, +}; +use acp_thread::{MentionUri, selection_name}; +use agent::{TextThreadStore, ThreadId, ThreadStore}; use agent_client_protocol as acp; use anyhow::Result; use collections::HashSet; -use editor::ExcerptId; -use editor::actions::Paste; -use editor::display_map::CreaseId; use editor::{ - AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MultiBuffer, + Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, + EditorMode, EditorStyle, ExcerptId, FoldPlaceholder, MultiBuffer, ToOffset, + actions::Paste, + display_map::{Crease, CreaseId, FoldId}, }; -use futures::FutureExt as _; -use gpui::ClipboardEntry; -use gpui::Image; -use gpui::ImageFormat; +use futures::{FutureExt as _, TryFutureExt as _}; use gpui::{ - AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity, + AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, Image, + ImageFormat, Task, TextStyle, WeakEntity, }; -use language::Buffer; -use language::Language; +use http_client::HttpClientWithUrl; +use language::{Buffer, Language}; use language_model::LanguageModelImage; -use parking_lot::Mutex; use project::{CompletionIntent, Project}; use settings::Settings; -use std::fmt::Write; -use std::path::Path; -use std::rc::Rc; -use std::sync::Arc; +use std::{ + fmt::Write, + ops::Range, + path::{Path, PathBuf}, + rc::Rc, + sync::Arc, +}; +use text::OffsetRangeExt; use theme::ThemeSettings; -use ui::IconName; -use ui::SharedString; use ui::{ - ActiveTheme, App, InteractiveElement, IntoElement, ParentElement, Render, Styled, TextSize, - Window, div, + ActiveTheme, AnyElement, App, ButtonCommon, ButtonLike, ButtonStyle, Color, Icon, IconName, + IconSize, InteractiveElement, IntoElement, Label, LabelCommon, LabelSize, ParentElement, + Render, SelectableButton, SharedString, Styled, TextSize, TintColor, Toggleable, Window, div, + h_flex, }; use util::ResultExt; -use workspace::Workspace; -use workspace::notifications::NotifyResultExt as _; +use workspace::{Workspace, notifications::NotifyResultExt as _}; use zed_actions::agent::Chat; use super::completion_provider::Mention; pub struct MessageEditor { + mention_set: MentionSet, editor: Entity, project: Entity, thread_store: Entity, text_thread_store: Entity, - mention_set: Arc>, } pub enum MessageEditorEvent { @@ -77,8 +76,13 @@ impl MessageEditor { }, None, ); - - let mention_set = Arc::new(Mutex::new(MentionSet::default())); + let completion_provider = ContextPickerCompletionProvider::new( + workspace, + thread_store.downgrade(), + text_thread_store.downgrade(), + cx.weak_entity(), + ); + let mention_set = MentionSet::default(); let editor = cx.new(|cx| { let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); @@ -88,13 +92,7 @@ impl MessageEditor { editor.set_show_indent_guides(false, cx); editor.set_soft_wrap(); editor.set_use_modal_editing(true); - editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new( - mention_set.clone(), - workspace, - thread_store.downgrade(), - text_thread_store.downgrade(), - cx.weak_entity(), - )))); + editor.set_completion_provider(Some(Rc::new(completion_provider))); editor.set_context_menu_options(ContextMenuOptions { min_entries_visible: 12, max_entries_visible: 12, @@ -112,16 +110,202 @@ impl MessageEditor { } } + #[cfg(test)] + pub(crate) fn editor(&self) -> &Entity { + &self.editor + } + + #[cfg(test)] + pub(crate) fn mention_set(&mut self) -> &mut MentionSet { + &mut self.mention_set + } + pub fn is_empty(&self, cx: &App) -> bool { self.editor.read(cx).is_empty(cx) } + pub fn mentioned_path_and_threads(&self, _: &App) -> (HashSet, HashSet) { + let mut excluded_paths = HashSet::default(); + let mut excluded_threads = HashSet::default(); + + for uri in self.mention_set.uri_by_crease_id.values() { + match uri { + MentionUri::File { abs_path, .. } => { + excluded_paths.insert(abs_path.clone()); + } + MentionUri::Thread { id, .. } => { + excluded_threads.insert(id.clone()); + } + _ => {} + } + } + + (excluded_paths, excluded_threads) + } + + pub fn confirm_completion( + &mut self, + crease_text: SharedString, + start: text::Anchor, + content_len: usize, + mention_uri: MentionUri, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self + .editor + .update(cx, |editor, cx| editor.snapshot(window, cx)); + let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { + return; + }; + + if let Some(crease_id) = crate::context_picker::insert_crease_for_mention( + *excerpt_id, + start, + content_len, + crease_text.clone(), + mention_uri.icon_path(cx), + self.editor.clone(), + window, + cx, + ) { + self.mention_set.insert_uri(crease_id, mention_uri.clone()); + } + } + + pub fn confirm_mention_for_fetch( + &mut self, + new_text: String, + source_range: Range, + url: url::Url, + http_client: Arc, + window: &mut Window, + cx: &mut Context, + ) { + let mention_uri = MentionUri::Fetch { url: url.clone() }; + let icon_path = mention_uri.icon_path(cx); + + let start = source_range.start; + let content_len = new_text.len() - 1; + + let snapshot = self + .editor + .update(cx, |editor, cx| editor.snapshot(window, cx)); + let Some((&excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { + return; + }; + + let Some(crease_id) = crate::context_picker::insert_crease_for_mention( + excerpt_id, + start, + content_len, + url.to_string().into(), + icon_path, + self.editor.clone(), + window, + cx, + ) else { + return; + }; + + let http_client = http_client.clone(); + let source_range = source_range.clone(); + + let url_string = url.to_string(); + let fetch = cx + .background_executor() + .spawn(async move { + fetch_url_content(http_client, url_string) + .map_err(|e| e.to_string()) + .await + }) + .shared(); + self.mention_set.add_fetch_result(url, fetch.clone()); + + cx.spawn_in(window, async move |this, cx| { + let fetch = fetch.await.notify_async_err(cx); + this.update(cx, |this, cx| { + if fetch.is_some() { + this.mention_set.insert_uri(crease_id, mention_uri.clone()); + } else { + // Remove crease if we failed to fetch + this.editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let Some(anchor) = + snapshot.anchor_in_excerpt(excerpt_id, source_range.start) + else { + return; + }; + editor.display_map.update(cx, |display_map, cx| { + display_map.unfold_intersecting(vec![anchor..anchor], true, cx); + }); + editor.remove_creases([crease_id], cx); + }); + } + }) + .ok(); + }) + .detach(); + } + + pub fn confirm_mention_for_selection( + &mut self, + source_range: Range, + selections: Vec<(Entity, Range, Range)>, + window: &mut Window, + cx: &mut Context, + ) { + let snapshot = self.editor.read(cx).buffer().read(cx).snapshot(cx); + let Some((&excerpt_id, _, _)) = snapshot.as_singleton() else { + return; + }; + let Some(start) = snapshot.anchor_in_excerpt(excerpt_id, source_range.start) else { + return; + }; + + let offset = start.to_offset(&snapshot); + + for (buffer, selection_range, range_to_fold) in selections { + let range = snapshot.anchor_after(offset + range_to_fold.start) + ..snapshot.anchor_after(offset + range_to_fold.end); + + let path = buffer + .read(cx) + .file() + .map_or(PathBuf::from("untitled"), |file| file.path().to_path_buf()); + let snapshot = buffer.read(cx).snapshot(); + + let point_range = selection_range.to_point(&snapshot); + let line_range = point_range.start.row..point_range.end.row; + + let uri = MentionUri::Selection { + path: path.clone(), + line_range: line_range.clone(), + }; + let crease = crate::context_picker::crease_for_mention( + selection_name(&path, &line_range).into(), + uri.icon_path(cx), + range, + self.editor.downgrade(), + ); + + let crease_id = self.editor.update(cx, |editor, cx| { + let crease_ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + crease_ids.first().copied().unwrap() + }); + + self.mention_set + .insert_uri(crease_id, MentionUri::Selection { path, line_range }); + } + } + pub fn contents( &self, window: &mut Window, cx: &mut Context, ) -> Task>> { - let contents = self.mention_set.lock().contents( + let contents = self.mention_set.contents( self.project.clone(), self.thread_store.clone(), self.text_thread_store.clone(), @@ -198,7 +382,7 @@ impl MessageEditor { pub fn clear(&mut self, window: &mut Window, cx: &mut Context) { self.editor.update(cx, |editor, cx| { editor.clear(window, cx); - editor.remove_creases(self.mention_set.lock().drain(), cx) + editor.remove_creases(self.mention_set.drain(), cx) }); } @@ -267,9 +451,6 @@ impl MessageEditor { cx: &mut Context, ) { let buffer = self.editor.read(cx).buffer().clone(); - let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else { - return; - }; let Some(buffer) = buffer.read(cx).as_singleton() else { return; }; @@ -292,10 +473,8 @@ impl MessageEditor { &path_prefix, false, entry.is_dir(), - excerpt_id, anchor..anchor, - self.editor.clone(), - self.mention_set.clone(), + cx.weak_entity(), self.project.clone(), cx, ) else { @@ -331,6 +510,7 @@ impl MessageEditor { excerpt_id, crease_start, content_len, + abs_path.clone(), self.editor.clone(), window, cx, @@ -375,7 +555,7 @@ impl MessageEditor { }) .detach(); - self.mention_set.lock().insert_image(crease_id, task); + self.mention_set.insert_image(crease_id, task); }); } @@ -429,7 +609,7 @@ impl MessageEditor { editor.buffer().read(cx).snapshot(cx) }); - self.mention_set.lock().clear(); + self.mention_set.clear(); for (range, mention_uri) in mentions { let anchor = snapshot.anchor_before(range.start); let crease_id = crate::context_picker::insert_crease_for_mention( @@ -444,7 +624,7 @@ impl MessageEditor { ); if let Some(crease_id) = crease_id { - self.mention_set.lock().insert_uri(crease_id, mention_uri); + self.mention_set.insert_uri(crease_id, mention_uri); } } for (range, content) in images { @@ -479,7 +659,7 @@ impl MessageEditor { let data: SharedString = content.data.to_string().into(); if let Some(crease_id) = crease_id { - self.mention_set.lock().insert_image( + self.mention_set.insert_image( crease_id, Task::ready(Ok(MentionImage { abs_path, @@ -550,20 +730,78 @@ pub(crate) fn insert_crease_for_image( excerpt_id: ExcerptId, anchor: text::Anchor, content_len: usize, + abs_path: Option>, editor: Entity, window: &mut Window, cx: &mut App, ) -> Option { - crate::context_picker::insert_crease_for_mention( - excerpt_id, - anchor, - content_len, - "Image".into(), - IconName::Image.path().into(), - editor, - window, - cx, - ) + let crease_label = abs_path + .as_ref() + .and_then(|path| path.file_name()) + .map(|name| name.to_string_lossy().to_string().into()) + .unwrap_or(SharedString::from("Image")); + + editor.update(cx, |editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + + let start = snapshot.anchor_in_excerpt(excerpt_id, anchor)?; + + let start = start.bias_right(&snapshot); + let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len); + + let placeholder = FoldPlaceholder { + render: render_image_fold_icon_button(crease_label, cx.weak_entity()), + merge_adjacent: false, + ..Default::default() + }; + + let crease = Crease::Inline { + range: start..end, + placeholder, + render_toggle: None, + render_trailer: None, + metadata: None, + }; + + let ids = editor.insert_creases(vec![crease.clone()], cx); + editor.fold_creases(vec![crease], false, window, cx); + + Some(ids[0]) + }) +} + +fn render_image_fold_icon_button( + label: SharedString, + editor: WeakEntity, +) -> Arc, &mut App) -> AnyElement> { + Arc::new({ + move |fold_id, fold_range, cx| { + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); + + ButtonLike::new(fold_id) + .style(ButtonStyle::Filled) + .selected_style(ButtonStyle::Tinted(TintColor::Accent)) + .toggle_state(is_in_text_selection) + .child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Image) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new(label.clone()) + .size(LabelSize::Small) + .buffer_font(cx) + .single_line(), + ), + ) + .into_any_element() + } + }) } #[cfg(test)] diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 7dc00bfae2..6c5546c6bb 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -13,7 +13,7 @@ use anyhow::{Result, anyhow}; use collections::HashSet; pub use completion_provider::ContextPickerCompletionProvider; use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId}; -use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset}; +use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset}; use fetch_context_picker::FetchContextPicker; use file_context_picker::FileContextPicker; use file_context_picker::render_file_context_entry; @@ -837,42 +837,9 @@ fn render_fold_icon_button( ) -> Arc, &mut App) -> AnyElement> { Arc::new({ move |fold_id, fold_range, cx| { - let is_in_text_selection = editor.upgrade().is_some_and(|editor| { - editor.update(cx, |editor, cx| { - let snapshot = editor - .buffer() - .update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx)); - - let is_in_pending_selection = || { - editor - .selections - .pending - .as_ref() - .is_some_and(|pending_selection| { - pending_selection - .selection - .range() - .includes(&fold_range, &snapshot) - }) - }; - - let mut is_in_complete_selection = || { - editor - .selections - .disjoint_in_range::(fold_range.clone(), cx) - .into_iter() - .any(|selection| { - // This is needed to cover a corner case, if we just check for an existing - // selection in the fold range, having a cursor at the start of the fold - // marks it as selected. Non-empty selections don't cause this. - let length = selection.end - selection.start; - length > 0 - }) - }; - - is_in_pending_selection() || is_in_complete_selection() - }) - }); + let is_in_text_selection = editor + .update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx)) + .unwrap_or_default(); ButtonLike::new(fold_id) .style(ButtonStyle::Filled) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index f77e9ae08c..85f2e01ed4 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -2369,6 +2369,34 @@ impl Editor { .is_some_and(|menu| menu.context_menu.focus_handle(cx).is_focused(window)) } + pub fn is_range_selected(&mut self, range: &Range, cx: &mut Context) -> bool { + if self + .selections + .pending + .as_ref() + .is_some_and(|pending_selection| { + let snapshot = self.buffer().read(cx).snapshot(cx); + pending_selection + .selection + .range() + .includes(&range, &snapshot) + }) + { + return true; + } + + self.selections + .disjoint_in_range::(range.clone(), cx) + .into_iter() + .any(|selection| { + // This is needed to cover a corner case, if we just check for an existing + // selection in the fold range, having a cursor at the start of the fold + // marks it as selected. Non-empty selections don't cause this. + let length = selection.end - selection.start; + length > 0 + }) + } + pub fn key_context(&self, window: &Window, cx: &App) -> KeyContext { self.key_context_internal(self.has_active_edit_prediction(), window, cx) } From b3cad8b527c773c3a541e1a9e3ff23a8fbbae548 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 15:21:04 -0400 Subject: [PATCH 06/17] proto: Remove `UpdateUserPlan` message (#36268) This PR removes the `UpdateUserPlan` RPC message. We're no longer using the message after https://github.com/zed-industries/zed/pull/36255. Release Notes: - N/A --- crates/client/src/user.rs | 21 ---- crates/collab/src/llm.rs | 8 -- crates/collab/src/rpc.rs | 223 ----------------------------------- crates/proto/proto/app.proto | 10 -- crates/proto/proto/zed.proto | 3 +- crates/proto/src/proto.rs | 1 - 6 files changed, 1 insertion(+), 265 deletions(-) diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index faf46945d8..33a240eca1 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -177,7 +177,6 @@ impl UserStore { let (mut current_user_tx, current_user_rx) = watch::channel(); let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); let rpc_subscriptions = vec![ - client.add_message_handler(cx.weak_entity(), Self::handle_update_plan), client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts), client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info), client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts), @@ -343,26 +342,6 @@ impl UserStore { Ok(()) } - async fn handle_update_plan( - this: Entity, - _message: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - let client = this - .read_with(&cx, |this, _| this.client.upgrade())? - .context("client was dropped")?; - - let response = client - .cloud_client() - .get_authenticated_user() - .await - .context("failed to fetch authenticated user")?; - - this.update(&mut cx, |this, cx| { - this.update_authenticated_user(response, cx); - }) - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index ca8e89bc6d..dec10232bd 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,9 +1 @@ pub mod db; - -pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; - -/// The name of the feature flag that bypasses the account age check. -pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check"; - -/// The minimum account age an account must have in order to use the LLM service. -pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 8366b2cf13..957cc30fe6 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,12 +1,6 @@ mod connection_pool; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; -use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::db::LlmDatabase; -use crate::llm::{ - AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, - MIN_ACCOUNT_AGE_FOR_LLM_USE, -}; use crate::{ AppState, Error, Result, auth, db::{ @@ -146,13 +140,6 @@ pub enum Principal { } impl Principal { - fn user(&self) -> &User { - match self { - Principal::User(user) => user, - Principal::Impersonated { user, .. } => user, - } - } - fn update_span(&self, span: &tracing::Span) { match &self { Principal::User(user) => { @@ -997,8 +984,6 @@ impl Server { .await?; } - update_user_plan(session).await?; - let contacts = self.app_state.db.get_contacts(user.id).await?; { @@ -2832,214 +2817,6 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn current_plan(db: &Arc, user_id: UserId, is_staff: bool) -> Result { - if is_staff { - return Ok(proto::Plan::ZedPro); - } - - let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_kind = subscription.and_then(|subscription| subscription.kind); - - let plan = if let Some(subscription_kind) = subscription_kind { - match subscription_kind { - SubscriptionKind::ZedPro => proto::Plan::ZedPro, - SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial, - SubscriptionKind::ZedFree => proto::Plan::Free, - } - } else { - proto::Plan::Free - }; - - Ok(plan) -} - -async fn make_update_user_plan_message( - user: &User, - is_staff: bool, - db: &Arc, - llm_db: Option>, -) -> Result { - let feature_flags = db.get_user_flags(user.id).await?; - let plan = current_plan(db, user.id, is_staff).await?; - let billing_customer = db.get_billing_customer_by_user_id(user.id).await?; - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let (subscription_period, usage) = if let Some(llm_db) = llm_db { - let subscription = db.get_active_billing_subscription(user.id).await?; - - let subscription_period = - crate::db::billing_subscription::Model::current_period(subscription, is_staff); - - let usage = if let Some((period_start_at, period_end_at)) = subscription_period { - llm_db - .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) - .await? - } else { - None - }; - - (subscription_period, usage) - } else { - (None, None) - }; - - let bypass_account_age_check = feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG); - let account_too_young = !matches!(plan, proto::Plan::ZedPro) - && !bypass_account_age_check - && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE; - - Ok(proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: billing_customer - .as_ref() - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), - is_usage_based_billing_enabled: if is_staff { - Some(true) - } else { - billing_preferences.map(|preferences| preferences.model_request_overages_enabled) - }, - subscription_period: subscription_period.map(|(started_at, ended_at)| { - proto::SubscriptionPeriod { - started_at: started_at.timestamp() as u64, - ended_at: ended_at.timestamp() as u64, - } - }), - account_too_young: Some(account_too_young), - has_overdue_invoices: billing_customer - .map(|billing_customer| billing_customer.has_overdue_invoices), - usage: Some( - usage - .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags)) - .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)), - ), - }) -} - -fn model_requests_limit( - plan: cloud_llm_client::Plan, - feature_flags: &Vec, -) -> cloud_llm_client::UsageLimit { - match plan.model_requests_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == cloud_llm_client::Plan::ZedProTrial - && feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) - { - 1_000 - } else { - limit - }; - - cloud_llm_client::UsageLimit::Limited(limit) - } - cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited, - } -} - -fn subscription_usage_to_proto( - plan: proto::Plan, - usage: crate::llm::db::subscription_usage::Model, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: usage.model_requests as u32, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: usage.edit_predictions as u32, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -fn make_default_subscription_usage( - plan: proto::Plan, - feature_flags: &Vec, -) -> proto::SubscriptionUsage { - let plan = match plan { - proto::Plan::Free => cloud_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: 0, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit(plan, feature_flags) { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - edit_predictions_usage_amount: 0, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - cloud_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { - limit: limit as u32, - }) - } - cloud_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) - } - }), - }), - } -} - -async fn update_user_plan(session: &Session) -> Result<()> { - let db = session.db().await; - - let update_user_plan = make_update_user_plan_message( - session.principal.user(), - session.is_staff(), - &db.0, - session.app_state.llm_db.clone(), - ) - .await?; - - session - .peer - .send(session.connection_id, update_user_plan) - .trace_err(); - - Ok(()) -} - async fn subscribe_to_channels( _: proto::SubscribeToChannels, session: MessageContext, diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 66baf968e3..fe6f7be1b0 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -12,16 +12,6 @@ enum Plan { ZedProTrial = 2; } -message UpdateUserPlan { - Plan plan = 1; - optional uint64 trial_started_at = 2; - optional bool is_usage_based_billing_enabled = 3; - optional SubscriptionUsage usage = 4; - optional SubscriptionPeriod subscription_period = 5; - optional bool account_too_young = 6; - optional bool has_overdue_invoices = 7; -} - message SubscriptionPeriod { uint64 started_at = 1; uint64 ended_at = 2; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 8984df2944..4b023a46bc 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -135,7 +135,6 @@ message Envelope { FollowResponse follow_response = 99; UpdateFollowers update_followers = 100; Unfollow unfollow = 101; - UpdateUserPlan update_user_plan = 234; UpdateDiffBases update_diff_bases = 104; AcceptTermsOfService accept_terms_of_service = 239; AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; @@ -414,7 +413,7 @@ message Envelope { reserved 221; reserved 224 to 229; reserved 230 to 231; - reserved 235 to 236; + reserved 234 to 236; reserved 246; reserved 247 to 254; reserved 255 to 256; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 82bd1af6db..18abf31c64 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -275,7 +275,6 @@ messages!( (UpdateProject, Foreground), (UpdateProjectCollaborator, Foreground), (UpdateUserChannels, Foreground), - (UpdateUserPlan, Foreground), (UpdateWorktree, Foreground), (UpdateWorktreeSettings, Foreground), (UpdateRepository, Foreground), From 75f85b3aaa202f07185a39d855143851f609ddf7 Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Fri, 15 Aug 2025 15:37:52 -0400 Subject: [PATCH 07/17] Remove old telemetry events and transformation layer (#36263) Successor to: https://github.com/zed-industries/zed/pull/25179 Release Notes: - N/A --- crates/collab/src/api/events.rs | 166 +----------------- .../telemetry_events/src/telemetry_events.rs | 108 +----------- 2 files changed, 4 insertions(+), 270 deletions(-) diff --git a/crates/collab/src/api/events.rs b/crates/collab/src/api/events.rs index 2f34a843a8..cd1dc42e64 100644 --- a/crates/collab/src/api/events.rs +++ b/crates/collab/src/api/events.rs @@ -564,170 +564,10 @@ fn for_snowflake( country_code: Option, checksum_matched: bool, ) -> impl Iterator { - body.events.into_iter().filter_map(move |event| { + body.events.into_iter().map(move |event| { let timestamp = first_event_at + Duration::milliseconds(event.milliseconds_since_first_event); - // We will need to double check, but I believe all of the events that - // are being transformed here are now migrated over to use the - // telemetry::event! macro, as of this commit so this code can go away - // when we feel enough users have upgraded past this point. let (event_type, mut event_properties) = match &event.event { - Event::Editor(e) => ( - match e.operation.as_str() { - "open" => "Editor Opened".to_string(), - "save" => "Editor Saved".to_string(), - _ => format!("Unknown Editor Event: {}", e.operation), - }, - serde_json::to_value(e).unwrap(), - ), - Event::EditPrediction(e) => ( - format!( - "Edit Prediction {}", - if e.suggestion_accepted { - "Accepted" - } else { - "Discarded" - } - ), - serde_json::to_value(e).unwrap(), - ), - Event::EditPredictionRating(e) => ( - "Edit Prediction Rated".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Call(e) => { - let event_type = match e.operation.trim() { - "unshare project" => "Project Unshared".to_string(), - "open channel notes" => "Channel Notes Opened".to_string(), - "share project" => "Project Shared".to_string(), - "join channel" => "Channel Joined".to_string(), - "hang up" => "Call Ended".to_string(), - "accept incoming" => "Incoming Call Accepted".to_string(), - "invite" => "Participant Invited".to_string(), - "disable microphone" => "Microphone Disabled".to_string(), - "enable microphone" => "Microphone Enabled".to_string(), - "enable screen share" => "Screen Share Enabled".to_string(), - "disable screen share" => "Screen Share Disabled".to_string(), - "decline incoming" => "Incoming Call Declined".to_string(), - _ => format!("Unknown Call Event: {}", e.operation), - }; - - (event_type, serde_json::to_value(e).unwrap()) - } - Event::Assistant(e) => ( - match e.phase { - telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(), - telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(), - telemetry_events::AssistantPhase::Accepted => { - "Assistant Response Accepted".to_string() - } - telemetry_events::AssistantPhase::Rejected => { - "Assistant Response Rejected".to_string() - } - }, - serde_json::to_value(e).unwrap(), - ), - Event::Cpu(_) | Event::Memory(_) => return None, - Event::App(e) => { - let mut properties = json!({}); - let event_type = match e.operation.trim() { - // App - "open" => "App Opened".to_string(), - "first open" => "App First Opened".to_string(), - "first open for release channel" => { - "App First Opened For Release Channel".to_string() - } - "close" => "App Closed".to_string(), - - // Project - "open project" => "Project Opened".to_string(), - "open node project" => { - properties["project_type"] = json!("node"); - "Project Opened".to_string() - } - "open pnpm project" => { - properties["project_type"] = json!("pnpm"); - "Project Opened".to_string() - } - "open yarn project" => { - properties["project_type"] = json!("yarn"); - "Project Opened".to_string() - } - - // SSH - "create ssh server" => "SSH Server Created".to_string(), - "create ssh project" => "SSH Project Created".to_string(), - "open ssh project" => "SSH Project Opened".to_string(), - - // Welcome Page - "welcome page: change keymap" => "Welcome Keymap Changed".to_string(), - "welcome page: change theme" => "Welcome Theme Changed".to_string(), - "welcome page: close" => "Welcome Page Closed".to_string(), - "welcome page: edit settings" => "Welcome Settings Edited".to_string(), - "welcome page: install cli" => "Welcome CLI Installed".to_string(), - "welcome page: open" => "Welcome Page Opened".to_string(), - "welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(), - "welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(), - "welcome page: toggle diagnostic telemetry" => { - "Welcome Diagnostic Telemetry Toggled".to_string() - } - "welcome page: toggle metric telemetry" => { - "Welcome Metric Telemetry Toggled".to_string() - } - "welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(), - "welcome page: view docs" => "Welcome Documentation Viewed".to_string(), - - // Extensions - "extensions page: open" => "Extensions Page Opened".to_string(), - "extensions: install extension" => "Extension Installed".to_string(), - "extensions: uninstall extension" => "Extension Uninstalled".to_string(), - - // Misc - "markdown preview: open" => "Markdown Preview Opened".to_string(), - "project diagnostics: open" => "Project Diagnostics Opened".to_string(), - "project search: open" => "Project Search Opened".to_string(), - "repl sessions: open" => "REPL Session Started".to_string(), - - // Feature Upsell - "feature upsell: toggle vim" => { - properties["source"] = json!("Feature Upsell"); - "Vim Mode Toggled".to_string() - } - _ => e - .operation - .strip_prefix("feature upsell: viewed docs (") - .and_then(|s| s.strip_suffix(')')) - .map_or_else( - || format!("Unknown App Event: {}", e.operation), - |docs_url| { - properties["url"] = json!(docs_url); - properties["source"] = json!("Feature Upsell"); - "Documentation Viewed".to_string() - }, - ), - }; - (event_type, properties) - } - Event::Setting(e) => ( - "Settings Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Extension(e) => ( - "Extension Loaded".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Edit(e) => ( - "Editor Edited".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Action(e) => ( - "Action Invoked".to_string(), - serde_json::to_value(e).unwrap(), - ), - Event::Repl(e) => ( - "Kernel Status Changed".to_string(), - serde_json::to_value(e).unwrap(), - ), Event::Flexible(e) => ( e.event_type.clone(), serde_json::to_value(&e.event_properties).unwrap(), @@ -759,7 +599,7 @@ fn for_snowflake( }) }); - Some(SnowflakeRow { + SnowflakeRow { time: timestamp, user_id: body.metrics_id.clone(), device_id: body.system_id.clone(), @@ -767,7 +607,7 @@ fn for_snowflake( event_properties, user_properties, insert_id: Some(Uuid::new_v4().to_string()), - }) + } }) } diff --git a/crates/telemetry_events/src/telemetry_events.rs b/crates/telemetry_events/src/telemetry_events.rs index 735a1310ae..12d8d4c04b 100644 --- a/crates/telemetry_events/src/telemetry_events.rs +++ b/crates/telemetry_events/src/telemetry_events.rs @@ -2,7 +2,7 @@ use semantic_version::SemanticVersion; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt::Display, sync::Arc, time::Duration}; +use std::{collections::HashMap, fmt::Display, time::Duration}; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventRequestBody { @@ -93,19 +93,6 @@ impl Display for AssistantPhase { #[serde(tag = "type")] pub enum Event { Flexible(FlexibleEvent), - Editor(EditorEvent), - EditPrediction(EditPredictionEvent), - EditPredictionRating(EditPredictionRatingEvent), - Call(CallEvent), - Assistant(AssistantEventData), - Cpu(CpuEvent), - Memory(MemoryEvent), - App(AppEvent), - Setting(SettingEvent), - Extension(ExtensionEvent), - Edit(EditEvent), - Action(ActionEvent), - Repl(ReplEvent), } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -114,54 +101,12 @@ pub struct FlexibleEvent { pub event_properties: HashMap, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditorEvent { - /// The editor operation performed (open, save) - pub operation: String, - /// The extension of the file that was opened or saved - pub file_extension: Option, - /// Whether the user is in vim mode or not - pub vim_mode: bool, - /// Whether the user has copilot enabled or not - pub copilot_enabled: bool, - /// Whether the user has copilot enabled for the language of the file opened or saved - pub copilot_enabled_for_language: bool, - /// Whether the client is opening/saving a local file or a remote file via SSH - #[serde(default)] - pub is_via_ssh: bool, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditPredictionEvent { - /// Provider of the completion suggestion (e.g. copilot, supermaven) - pub provider: String, - pub suggestion_accepted: bool, - pub file_extension: Option, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum EditPredictionRating { Positive, Negative, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditPredictionRatingEvent { - pub rating: EditPredictionRating, - pub input_events: Arc, - pub input_excerpt: Arc, - pub output_excerpt: Arc, - pub feedback: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CallEvent { - /// Operation performed: invite/join call; begin/end screenshare; share/unshare project; etc - pub operation: String, - pub room_id: Option, - pub channel_id: Option, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AssistantEventData { /// Unique random identifier for each assistant tab (None for inline assist) @@ -180,57 +125,6 @@ pub struct AssistantEventData { pub language_name: Option, } -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct CpuEvent { - pub usage_as_percentage: f32, - pub core_count: u32, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct MemoryEvent { - pub memory_in_bytes: u64, - pub virtual_memory_in_bytes: u64, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ActionEvent { - pub source: String, - pub action: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct EditEvent { - pub duration: i64, - pub environment: String, - /// Whether the edits occurred locally or remotely via SSH - #[serde(default)] - pub is_via_ssh: bool, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct SettingEvent { - pub setting: String, - pub value: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ExtensionEvent { - pub extension_id: Arc, - pub version: Arc, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct AppEvent { - pub operation: String, -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct ReplEvent { - pub kernel_language: String, - pub kernel_status: String, - pub repl_session_id: String, -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct BacktraceFrame { pub ip: usize, From 2a9d4599cdeb61d5f6cf90f01d7475b14bf5b510 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 15:46:23 -0400 Subject: [PATCH 08/17] proto: Remove unused types (#36269) This PR removes some unused types from the RPC protocol. Release Notes: - N/A --- .../agent_ui/src/language_model_selector.rs | 6 ++-- crates/client/src/user.rs | 13 -------- crates/proto/proto/app.proto | 31 ------------------- 3 files changed, 3 insertions(+), 47 deletions(-) diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7121624c87..bb8514a224 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,5 +1,6 @@ use std::{cmp::Reverse, sync::Arc}; +use cloud_llm_client::Plan; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; @@ -10,7 +11,6 @@ use language_model::{ }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; -use proto::Plan; use ui::{ListItem, ListItemSpacing, prelude::*}; const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; @@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { ) -> Option { use feature_flags::FeatureFlagAppExt; - let plan = proto::Plan::ZedPro; + let plan = Plan::ZedPro; Some( h_flex() @@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { window .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx) }), - Plan::Free | Plan::ZedProTrial => Button::new( + Plan::ZedFree | Plan::ZedProTrial => Button::new( "try-pro", if plan == Plan::ZedProTrial { "Upgrade to Pro" diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 33a240eca1..da7f50076b 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -998,19 +998,6 @@ impl RequestUsage { } } - pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { - let limit = match limit.variant? { - proto::usage_limit::Variant::Limited(limited) => { - UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, - }; - Some(RequestUsage { - limit, - amount: amount as i32, - }) - } - fn from_headers( limit_name: &str, amount_name: &str, diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index fe6f7be1b0..9611b607d0 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -6,37 +6,6 @@ message UpdateInviteInfo { uint32 count = 2; } -enum Plan { - Free = 0; - ZedPro = 1; - ZedProTrial = 2; -} - -message SubscriptionPeriod { - uint64 started_at = 1; - uint64 ended_at = 2; -} - -message SubscriptionUsage { - uint32 model_requests_usage_amount = 1; - UsageLimit model_requests_usage_limit = 2; - uint32 edit_predictions_usage_amount = 3; - UsageLimit edit_predictions_usage_limit = 4; -} - -message UsageLimit { - oneof variant { - Limited limited = 1; - Unlimited unlimited = 2; - } - - message Limited { - uint32 limit = 1; - } - - message Unlimited {} -} - message AcceptTermsOfService {} message AcceptTermsOfServiceResponse { From 65f64aa5138a4cfcede025648cda973eeae21021 Mon Sep 17 00:00:00 2001 From: Finn Evers Date: Fri, 15 Aug 2025 22:21:21 +0200 Subject: [PATCH 09/17] search: Fix recently introduced issues with the search bars (#36271) Follow-up to https://github.com/zed-industries/zed/pull/36233 The above PR simplified the handling but introduced some bugs: The replace buttons were no longer clickable, some buttons also lost their toggle states, some buttons shared their element id and, lastly, some buttons were clickable but would not trigger the right action. This PR fixes all that. Release Notes: - N/A --- crates/search/src/buffer_search.rs | 53 +++++++++++++++----------- crates/search/src/project_search.rs | 59 +++++++++++++++++------------ crates/search/src/search.rs | 55 +++++++++++++++++++-------- crates/search/src/search_bar.rs | 12 +++++- 4 files changed, 114 insertions(+), 65 deletions(-) diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index da2d35d74c..189f48e6b6 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -2,9 +2,9 @@ mod registrar; use crate::{ FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOption, - SearchOptions, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive, - ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord, - search_bar::{input_base_styles, render_action_button, render_text_input}, + SearchOptions, SearchSource, SelectAllMatches, SelectNextMatch, SelectPreviousMatch, + ToggleCaseSensitive, ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord, + search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input}, }; use any_vec::AnyVec; use anyhow::Context as _; @@ -213,22 +213,25 @@ impl Render for BufferSearchBar { h_flex() .gap_1() .when(case, |div| { - div.child( - SearchOption::CaseSensitive - .as_button(self.search_options, focus_handle.clone()), - ) + div.child(SearchOption::CaseSensitive.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) }) .when(word, |div| { - div.child( - SearchOption::WholeWord - .as_button(self.search_options, focus_handle.clone()), - ) + div.child(SearchOption::WholeWord.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) }) .when(regex, |div| { - div.child( - SearchOption::Regex - .as_button(self.search_options, focus_handle.clone()), - ) + div.child(SearchOption::Regex.as_button( + self.search_options, + SearchSource::Buffer, + focus_handle.clone(), + )) }), ) }); @@ -240,7 +243,7 @@ impl Render for BufferSearchBar { this.child(render_action_button( "buffer-search-bar-toggle", IconName::Replace, - self.replace_enabled, + self.replace_enabled.then_some(ActionButtonState::Toggled), "Toggle Replace", &ToggleReplace, focus_handle.clone(), @@ -285,7 +288,9 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-nav-button", ui::IconName::ChevronLeft, - self.active_match_index.is_some(), + self.active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Previous Match", &SelectPreviousMatch, query_focus.clone(), @@ -293,7 +298,9 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-nav-button", ui::IconName::ChevronRight, - self.active_match_index.is_some(), + self.active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Next Match", &SelectNextMatch, query_focus.clone(), @@ -313,7 +320,7 @@ impl Render for BufferSearchBar { el.child(render_action_button( "buffer-search-nav-button", IconName::SelectAll, - true, + Default::default(), "Select All Matches", &SelectAllMatches, query_focus, @@ -324,7 +331,7 @@ impl Render for BufferSearchBar { el.child(render_action_button( "buffer-search", IconName::Close, - true, + Default::default(), "Close Search Bar", &Dismiss, focus_handle.clone(), @@ -352,7 +359,7 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-replace-button", IconName::ReplaceNext, - true, + Default::default(), "Replace Next Match", &ReplaceNext, focus_handle.clone(), @@ -360,7 +367,7 @@ impl Render for BufferSearchBar { .child(render_action_button( "buffer-search-replace-button", IconName::ReplaceAll, - true, + Default::default(), "Replace All Matches", &ReplaceAll, focus_handle, @@ -394,7 +401,7 @@ impl Render for BufferSearchBar { div.child(h_flex().absolute().right_0().child(render_action_button( "buffer-search", IconName::Close, - true, + Default::default(), "Close Search Bar", &Dismiss, focus_handle.clone(), diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index b791f748ad..056c3556ba 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -1,9 +1,9 @@ use crate::{ BufferSearchBar, FocusSearch, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, - SearchOption, SearchOptions, SelectNextMatch, SelectPreviousMatch, ToggleCaseSensitive, - ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord, + SearchOption, SearchOptions, SearchSource, SelectNextMatch, SelectPreviousMatch, + ToggleCaseSensitive, ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord, buffer_search::Deploy, - search_bar::{input_base_styles, render_action_button, render_text_input}, + search_bar::{ActionButtonState, input_base_styles, render_action_button, render_text_input}, }; use anyhow::Context as _; use collections::HashMap; @@ -1665,7 +1665,7 @@ impl ProjectSearchBar { }); } - fn toggle_search_option( + pub(crate) fn toggle_search_option( &mut self, option: SearchOptions, window: &mut Window, @@ -1962,17 +1962,21 @@ impl Render for ProjectSearchBar { .child( h_flex() .gap_1() - .child( - SearchOption::CaseSensitive - .as_button(search.search_options, focus_handle.clone()), - ) - .child( - SearchOption::WholeWord - .as_button(search.search_options, focus_handle.clone()), - ) - .child( - SearchOption::Regex.as_button(search.search_options, focus_handle.clone()), - ), + .child(SearchOption::CaseSensitive.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )) + .child(SearchOption::WholeWord.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )) + .child(SearchOption::Regex.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )), ); let query_focus = search.query_editor.focus_handle(cx); @@ -1985,7 +1989,10 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-nav-button", IconName::ChevronLeft, - search.active_match_index.is_some(), + search + .active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Previous Match", &SelectPreviousMatch, query_focus.clone(), @@ -1993,7 +2000,10 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-nav-button", IconName::ChevronRight, - search.active_match_index.is_some(), + search + .active_match_index + .is_none() + .then_some(ActionButtonState::Disabled), "Select Next Match", &SelectNextMatch, query_focus, @@ -2054,7 +2064,7 @@ impl Render for ProjectSearchBar { self.active_project_search .as_ref() .map(|search| search.read(cx).replace_enabled) - .unwrap_or_default(), + .and_then(|enabled| enabled.then_some(ActionButtonState::Toggled)), "Toggle Replace", &ToggleReplace, focus_handle.clone(), @@ -2079,7 +2089,7 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-replace-button", IconName::ReplaceNext, - true, + Default::default(), "Replace Next Match", &ReplaceNext, focus_handle.clone(), @@ -2087,7 +2097,7 @@ impl Render for ProjectSearchBar { .child(render_action_button( "project-search-replace-button", IconName::ReplaceAll, - true, + Default::default(), "Replace All Matches", &ReplaceAll, focus_handle, @@ -2129,10 +2139,11 @@ impl Render for ProjectSearchBar { this.toggle_opened_only(window, cx); })), ) - .child( - SearchOption::IncludeIgnored - .as_button(search.search_options, focus_handle.clone()), - ); + .child(SearchOption::IncludeIgnored.as_button( + search.search_options, + SearchSource::Project(cx), + focus_handle.clone(), + )); h_flex() .w_full() .gap_2() diff --git a/crates/search/src/search.rs b/crates/search/src/search.rs index 89064e0a27..904c74d03c 100644 --- a/crates/search/src/search.rs +++ b/crates/search/src/search.rs @@ -1,7 +1,7 @@ use bitflags::bitflags; pub use buffer_search::BufferSearchBar; use editor::SearchSettings; -use gpui::{Action, App, FocusHandle, IntoElement, actions}; +use gpui::{Action, App, ClickEvent, FocusHandle, IntoElement, actions}; use project::search::SearchQuery; pub use project_search::ProjectSearchView; use ui::{ButtonStyle, IconButton, IconButtonShape}; @@ -11,6 +11,8 @@ use workspace::{Toast, Workspace}; pub use search_status_button::SEARCH_ICON; +use crate::project_search::ProjectSearchBar; + pub mod buffer_search; pub mod project_search; pub(crate) mod search_bar; @@ -83,9 +85,14 @@ pub enum SearchOption { Backwards, } +pub(crate) enum SearchSource<'a, 'b> { + Buffer, + Project(&'a Context<'b, ProjectSearchBar>), +} + impl SearchOption { - pub fn as_options(self) -> SearchOptions { - SearchOptions::from_bits(1 << self as u8).unwrap() + pub fn as_options(&self) -> SearchOptions { + SearchOptions::from_bits(1 << *self as u8).unwrap() } pub fn label(&self) -> &'static str { @@ -119,25 +126,41 @@ impl SearchOption { } } - pub fn as_button(&self, active: SearchOptions, focus_handle: FocusHandle) -> impl IntoElement { + pub(crate) fn as_button( + &self, + active: SearchOptions, + search_source: SearchSource, + focus_handle: FocusHandle, + ) -> impl IntoElement { let action = self.to_toggle_action(); let label = self.label(); - IconButton::new(label, self.icon()) - .on_click({ + IconButton::new( + (label, matches!(search_source, SearchSource::Buffer) as u32), + self.icon(), + ) + .map(|button| match search_source { + SearchSource::Buffer => { let focus_handle = focus_handle.clone(); - move |_, window, cx| { + button.on_click(move |_: &ClickEvent, window, cx| { if !focus_handle.is_focused(&window) { window.focus(&focus_handle); } - window.dispatch_action(action.boxed_clone(), cx) - } - }) - .style(ButtonStyle::Subtle) - .shape(IconButtonShape::Square) - .toggle_state(active.contains(self.as_options())) - .tooltip({ - move |window, cx| Tooltip::for_action_in(label, action, &focus_handle, window, cx) - }) + window.dispatch_action(action.boxed_clone(), cx); + }) + } + SearchSource::Project(cx) => { + let options = self.as_options(); + button.on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { + this.toggle_search_option(options, window, cx); + })) + } + }) + .style(ButtonStyle::Subtle) + .shape(IconButtonShape::Square) + .toggle_state(active.contains(self.as_options())) + .tooltip({ + move |window, cx| Tooltip::for_action_in(label, action, &focus_handle, window, cx) + }) } } diff --git a/crates/search/src/search_bar.rs b/crates/search/src/search_bar.rs index 094ce3638e..8cc838a8a6 100644 --- a/crates/search/src/search_bar.rs +++ b/crates/search/src/search_bar.rs @@ -5,10 +5,15 @@ use theme::ThemeSettings; use ui::{IconButton, IconButtonShape}; use ui::{Tooltip, prelude::*}; +pub(super) enum ActionButtonState { + Disabled, + Toggled, +} + pub(super) fn render_action_button( id_prefix: &'static str, icon: ui::IconName, - active: bool, + button_state: Option, tooltip: &'static str, action: &'static dyn Action, focus_handle: FocusHandle, @@ -28,7 +33,10 @@ pub(super) fn render_action_button( } }) .tooltip(move |window, cx| Tooltip::for_action_in(tooltip, action, &focus_handle, window, cx)) - .disabled(!active) + .when_some(button_state, |this, state| match state { + ActionButtonState::Toggled => this.toggle_state(true), + ActionButtonState::Disabled => this.disabled(true), + }) } pub(crate) fn input_base_styles(border_color: Hsla, map: impl FnOnce(Div) -> Div) -> Div { From 7199c733b252f62f84135e0b9102fab22d5480e5 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 16:21:45 -0400 Subject: [PATCH 10/17] proto: Remove `AcceptTermsOfService` message (#36272) This PR removes the `AcceptTermsOfService` RPC message. We're no longer using the message after https://github.com/zed-industries/zed/pull/36255. Release Notes: - N/A --- crates/collab/src/rpc.rs | 21 --------------------- crates/proto/proto/app.proto | 6 ------ crates/proto/proto/zed.proto | 3 +-- crates/proto/src/proto.rs | 3 --- 4 files changed, 1 insertion(+), 32 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 957cc30fe6..ef749ac9b7 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -29,7 +29,6 @@ use axum::{ response::IntoResponse, routing::get, }; -use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; @@ -449,7 +448,6 @@ impl Server { .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) - .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) .add_request_handler(get_supermaven_api_key) @@ -3985,25 +3983,6 @@ async fn mark_notification_as_read( Ok(()) } -/// Accept the terms of service (tos) on behalf of the current user -async fn accept_terms_of_service( - _request: proto::AcceptTermsOfService, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let accepted_tos_at = Utc::now(); - db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc())) - .await?; - - response.send(proto::AcceptTermsOfServiceResponse { - accepted_tos_at: accepted_tos_at.timestamp() as u64, - })?; - - Ok(()) -} - fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result { let message = match message { TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()), diff --git a/crates/proto/proto/app.proto b/crates/proto/proto/app.proto index 9611b607d0..1f2ab1f539 100644 --- a/crates/proto/proto/app.proto +++ b/crates/proto/proto/app.proto @@ -6,12 +6,6 @@ message UpdateInviteInfo { uint32 count = 2; } -message AcceptTermsOfService {} - -message AcceptTermsOfServiceResponse { - uint64 accepted_tos_at = 1; -} - message ShutdownRemoteServer {} message Toast { diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 4b023a46bc..310fcf584e 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -136,8 +136,6 @@ message Envelope { UpdateFollowers update_followers = 100; Unfollow unfollow = 101; UpdateDiffBases update_diff_bases = 104; - AcceptTermsOfService accept_terms_of_service = 239; - AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; OnTypeFormatting on_type_formatting = 105; OnTypeFormattingResponse on_type_formatting_response = 106; @@ -414,6 +412,7 @@ message Envelope { reserved 224 to 229; reserved 230 to 231; reserved 234 to 236; + reserved 239 to 240; reserved 246; reserved 247 to 254; reserved 255 to 256; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 18abf31c64..802db09590 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -20,8 +20,6 @@ pub const SSH_PEER_ID: PeerId = PeerId { owner_id: 0, id: 0 }; pub const SSH_PROJECT_ID: u64 = 0; messages!( - (AcceptTermsOfService, Foreground), - (AcceptTermsOfServiceResponse, Foreground), (Ack, Foreground), (AckBufferOperation, Background), (AckChannelMessage, Background), @@ -315,7 +313,6 @@ messages!( ); request_messages!( - (AcceptTermsOfService, AcceptTermsOfServiceResponse), (ApplyCodeAction, ApplyCodeActionResponse), ( ApplyCompletionAdditionalEdits, From 3e0a755486201a2fe6e77213af68494a784a4895 Mon Sep 17 00:00:00 2001 From: Finn Evers Date: Fri, 15 Aug 2025 22:27:44 +0200 Subject: [PATCH 11/17] Remove some redundant entity clones (#36274) `cx.entity()` already returns an owned entity, so there is no need for these clones. Release Notes: - N/A --- crates/agent_ui/src/context_picker.rs | 2 +- crates/agent_ui/src/inline_assistant.rs | 2 +- crates/agent_ui/src/profile_selector.rs | 2 +- crates/collab_ui/src/chat_panel.rs | 2 +- crates/collab_ui/src/collab_panel.rs | 8 +- .../src/collab_panel/channel_modal.rs | 2 +- crates/collab_ui/src/notification_panel.rs | 4 +- crates/debugger_ui/src/session/running.rs | 4 +- .../src/edit_prediction_button.rs | 6 +- crates/editor/src/editor_tests.rs | 13 +-- crates/editor/src/element.rs | 2 +- crates/extensions_ui/src/extensions_ui.rs | 2 +- crates/git_ui/src/git_panel.rs | 2 +- crates/gpui/examples/input.rs | 4 +- crates/language_tools/src/lsp_log.rs | 2 +- crates/language_tools/src/lsp_tool.rs | 2 +- crates/language_tools/src/syntax_tree_view.rs | 2 +- crates/outline_panel/src/outline_panel.rs | 80 +++++++++---------- crates/project_panel/src/project_panel.rs | 42 +++++----- crates/recent_projects/src/remote_servers.rs | 4 +- crates/repl/src/session.rs | 2 +- crates/storybook/src/stories/indent_guides.rs | 2 +- crates/terminal_view/src/terminal_panel.rs | 4 +- crates/terminal_view/src/terminal_view.rs | 2 +- crates/vim/src/mode_indicator.rs | 4 +- crates/vim/src/normal/search.rs | 2 +- crates/vim/src/vim.rs | 2 +- crates/workspace/src/dock.rs | 2 +- crates/workspace/src/notifications.rs | 2 +- crates/workspace/src/pane.rs | 16 ++-- crates/workspace/src/workspace.rs | 2 +- crates/zed/src/zed.rs | 2 +- 32 files changed, 106 insertions(+), 123 deletions(-) diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 6c5546c6bb..131023d249 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -228,7 +228,7 @@ impl ContextPicker { } fn build_menu(&mut self, window: &mut Window, cx: &mut Context) -> Entity { - let context_picker = cx.entity().clone(); + let context_picker = cx.entity(); let menu = ContextMenu::build(window, cx, move |menu, _window, cx| { let recent = self.recent_entries(cx); diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 4a4a747899..bbd3595805 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -72,7 +72,7 @@ pub fn init( let Some(window) = window else { return; }; - let workspace = cx.entity().clone(); + let workspace = cx.entity(); InlineAssistant::update_global(cx, |inline_assistant, cx| { inline_assistant.register_workspace(&workspace, window, cx) }); diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index 27ca69590f..ce25f531e2 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -163,7 +163,7 @@ impl Render for ProfileSelector { .unwrap_or_else(|| "Unknown".into()); if self.provider.profiles_supported(cx) { - let this = cx.entity().clone(); + let this = cx.entity(); let focus_handle = self.focus_handle.clone(); let trigger_button = Button::new("profile-selector-model", selected_profile) .label_size(LabelSize::Small) diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index 51d9f003f8..2bbaa8446c 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -674,7 +674,7 @@ impl ChatPanel { }) }) .when_some(message_id, |el, message_id| { - let this = cx.entity().clone(); + let this = cx.entity(); el.child( self.render_popover_button( diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index 430b447580..c2cc6a7ad5 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -95,7 +95,7 @@ pub fn init(cx: &mut App) { .and_then(|room| room.read(cx).channel_id()); if let Some(channel_id) = channel_id { - let workspace = cx.entity().clone(); + let workspace = cx.entity(); window.defer(cx, move |window, cx| { ChannelView::open(channel_id, None, workspace, window, cx) .detach_and_log_err(cx) @@ -1142,7 +1142,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); if !(role == proto::ChannelRole::Guest || role == proto::ChannelRole::Talker || role == proto::ChannelRole::Member) @@ -1272,7 +1272,7 @@ impl CollabPanel { .channel_for_id(clipboard.channel_id) .map(|channel| channel.name.clone()) }); - let this = cx.entity().clone(); + let this = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, window, cx| { if self.has_subchannels(ix) { @@ -1439,7 +1439,7 @@ impl CollabPanel { window: &mut Window, cx: &mut Context, ) { - let this = cx.entity().clone(); + let this = cx.entity(); let in_room = ActiveCall::global(cx).read(cx).room().is_some(); let context_menu = ContextMenu::build(window, cx, |mut context_menu, _, _| { diff --git a/crates/collab_ui/src/collab_panel/channel_modal.rs b/crates/collab_ui/src/collab_panel/channel_modal.rs index c0d3130ee9..e558835dba 100644 --- a/crates/collab_ui/src/collab_panel/channel_modal.rs +++ b/crates/collab_ui/src/collab_panel/channel_modal.rs @@ -586,7 +586,7 @@ impl ChannelModalDelegate { return; }; let user_id = membership.user.id; - let picker = cx.entity().clone(); + let picker = cx.entity(); let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| { let role = membership.role; diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index 3a280ff667..a3420d603b 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -321,7 +321,7 @@ impl NotificationPanel { .justify_end() .child(Button::new("decline", "Decline").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( @@ -334,7 +334,7 @@ impl NotificationPanel { })) .child(Button::new("accept", "Accept").on_click({ let notification = notification.clone(); - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, _, cx| { entity.update(cx, |this, cx| { this.respond_to_notification( diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index c8bee42039..f3117aee07 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -291,7 +291,7 @@ pub(crate) fn new_debugger_pane( let Some(project) = project.upgrade() else { return ControlFlow::Break(()); }; - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { @@ -502,7 +502,7 @@ pub(crate) fn new_debugger_pane( .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail: 0, is_active: selected, ix, diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 3d3b43d71b..4632a03daf 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -127,7 +127,7 @@ impl Render for EditPredictionButton { }), ); } - let this = cx.entity().clone(); + let this = cx.entity(); div().child( PopoverMenu::new("copilot") @@ -182,7 +182,7 @@ impl Render for EditPredictionButton { let icon = status.to_icon(); let tooltip_text = status.to_tooltip(); let has_menu = status.has_menu(); - let this = cx.entity().clone(); + let this = cx.entity(); let fs = self.fs.clone(); return div().child( @@ -331,7 +331,7 @@ impl Render for EditPredictionButton { }) }); - let this = cx.entity().clone(); + let this = cx.entity(); let mut popover_menu = PopoverMenu::new("zeta") .menu(move |window, cx| { diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index cf9954bc12..ef2bdc5da3 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -74,7 +74,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let editor1 = cx.add_window({ let events = events.clone(); |window, cx| { - let entity = cx.entity().clone(); + let entity = cx.entity(); cx.subscribe_in( &entity, window, @@ -95,7 +95,7 @@ fn test_edit_events(cx: &mut TestAppContext) { let events = events.clone(); |window, cx| { cx.subscribe_in( - &cx.entity().clone(), + &cx.entity(), window, move |_, _, event: &EditorEvent, _, _| match event { EditorEvent::Edited { .. } => events.borrow_mut().push(("editor2", "edited")), @@ -19634,13 +19634,8 @@ fn test_crease_insertion_and_rendering(cx: &mut TestAppContext) { editor.insert_creases(Some(crease), cx); let snapshot = editor.snapshot(window, cx); - let _div = snapshot.render_crease_toggle( - MultiBufferRow(1), - false, - cx.entity().clone(), - window, - cx, - ); + let _div = + snapshot.render_crease_toggle(MultiBufferRow(1), false, cx.entity(), window, cx); snapshot }) .unwrap(); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 8a5c65f994..5edfd7df30 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -7815,7 +7815,7 @@ impl Element for EditorElement { min_lines, max_lines, } => { - let editor_handle = cx.entity().clone(); + let editor_handle = cx.entity(); let max_line_number_width = self.max_line_number_width(&editor.snapshot(window, cx), window); window.request_measured_layout( diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index fe3e94f5c2..4915933920 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -703,7 +703,7 @@ impl ExtensionsPage { extension: &ExtensionMetadata, cx: &mut Context, ) -> ExtensionCard { - let this = cx.entity().clone(); + let this = cx.entity(); let status = Self::extension_status(&extension.id, cx); let has_dev_extension = Self::dev_extension_exists(&extension.id, cx); diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index de308b9dde..70987dd212 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -3410,7 +3410,7 @@ impl GitPanel { * MAX_PANEL_EDITOR_LINES + gap; - let git_panel = cx.entity().clone(); + let git_panel = cx.entity(); let display_name = SharedString::from(Arc::from( active_repository .read(cx) diff --git a/crates/gpui/examples/input.rs b/crates/gpui/examples/input.rs index 52a5b08b96..b0f560e38d 100644 --- a/crates/gpui/examples/input.rs +++ b/crates/gpui/examples/input.rs @@ -595,9 +595,7 @@ impl Render for TextInput { .w_full() .p(px(4.)) .bg(white()) - .child(TextElement { - input: cx.entity().clone(), - }), + .child(TextElement { input: cx.entity() }), ) } } diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index 606f3a3f0e..823d59ce12 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -1358,7 +1358,7 @@ impl Render for LspLogToolbarItemView { }) .collect(); - let log_toolbar_view = cx.entity().clone(); + let log_toolbar_view = cx.entity(); let lsp_menu = PopoverMenu::new("LspLogView") .anchor(Corner::TopLeft) diff --git a/crates/language_tools/src/lsp_tool.rs b/crates/language_tools/src/lsp_tool.rs index 50547253a9..3244350a34 100644 --- a/crates/language_tools/src/lsp_tool.rs +++ b/crates/language_tools/src/lsp_tool.rs @@ -1007,7 +1007,7 @@ impl Render for LspTool { (None, "All Servers Operational") }; - let lsp_tool = cx.entity().clone(); + let lsp_tool = cx.entity(); div().child( PopoverMenu::new("lsp-tool") diff --git a/crates/language_tools/src/syntax_tree_view.rs b/crates/language_tools/src/syntax_tree_view.rs index eadba2c1d2..9946442ec8 100644 --- a/crates/language_tools/src/syntax_tree_view.rs +++ b/crates/language_tools/src/syntax_tree_view.rs @@ -456,7 +456,7 @@ impl SyntaxTreeToolbarItemView { let active_layer = buffer_state.active_layer.clone()?; let active_buffer = buffer_state.buffer.read(cx).snapshot(); - let view = cx.entity().clone(); + let view = cx.entity(); Some( PopoverMenu::new("Syntax Tree") .trigger(Self::render_header(&active_layer)) diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 1cda3897ec..004a27b0cf 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -4815,51 +4815,45 @@ impl OutlinePanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |outline_panel, range, _, _| { - let entries = outline_panel.cached_entries.get(range); - if let Some(entries) = entries { - entries.into_iter().map(|item| item.depth).collect() - } else { - smallvec::SmallVec::new() - } - }, - ) - .with_render_fn( - cx.entity().clone(), - move |outline_panel, params, _, _| { - const LEFT_OFFSET: Pixels = px(14.); + .with_compute_indents_fn(cx.entity(), |outline_panel, range, _, _| { + let entries = outline_panel.cached_entries.get(range); + if let Some(entries) = entries { + entries.into_iter().map(|item| item.depth).collect() + } else { + smallvec::SmallVec::new() + } + }) + .with_render_fn(cx.entity(), move |outline_panel, params, _, _| { + const LEFT_OFFSET: Pixels = px(14.); - let indent_size = params.indent_size; - let item_height = params.item_height; - let active_indent_guide_ix = find_active_indent_guide_ix( - outline_panel, - ¶ms.indent_guides, - ); + let indent_size = params.indent_size; + let item_height = params.item_height; + let active_indent_guide_ix = find_active_indent_guide_ix( + outline_panel, + ¶ms.indent_guides, + ); - params - .indent_guides - .into_iter() - .enumerate() - .map(|(ix, layout)| { - let bounds = Bounds::new( - point( - layout.offset.x * indent_size + LEFT_OFFSET, - layout.offset.y * item_height, - ), - size(px(1.), layout.length * item_height), - ); - ui::RenderedIndentGuide { - bounds, - layout, - is_active: active_indent_guide_ix == Some(ix), - hitbox: None, - } - }) - .collect() - }, - ), + params + .indent_guides + .into_iter() + .enumerate() + .map(|(ix, layout)| { + let bounds = Bounds::new( + point( + layout.offset.x * indent_size + LEFT_OFFSET, + layout.offset.y * item_height, + ), + size(px(1.), layout.length * item_height), + ); + ui::RenderedIndentGuide { + bounds, + layout, + is_active: active_indent_guide_ix == Some(ix), + hitbox: None, + } + }) + .collect() + }), ) }) }; diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 967df41e23..4d7f2faf62 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -5351,26 +5351,22 @@ impl Render for ProjectPanel { .when(show_indent_guides, |list| { list.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_compute_indents_fn( - cx.entity().clone(), - |this, range, window, cx| { - let mut items = - SmallVec::with_capacity(range.end - range.start); - this.iter_visible_entries( - range, - window, - cx, - |entry, _, entries, _, _| { - let (depth, _) = - Self::calculate_depth_and_difference( - entry, entries, - ); - items.push(depth); - }, - ); - items - }, - ) + .with_compute_indents_fn(cx.entity(), |this, range, window, cx| { + let mut items = + SmallVec::with_capacity(range.end - range.start); + this.iter_visible_entries( + range, + window, + cx, + |entry, _, entries, _, _| { + let (depth, _) = Self::calculate_depth_and_difference( + entry, entries, + ); + items.push(depth); + }, + ); + items + }) .on_click(cx.listener( |this, active_indent_guide: &IndentGuideLayout, window, cx| { if window.modifiers().secondary() { @@ -5394,7 +5390,7 @@ impl Render for ProjectPanel { } }, )) - .with_render_fn(cx.entity().clone(), move |this, params, _, cx| { + .with_render_fn(cx.entity(), move |this, params, _, cx| { const LEFT_OFFSET: Pixels = px(14.); const PADDING_Y: Pixels = px(4.); const HITBOX_OVERDRAW: Pixels = px(3.); @@ -5447,7 +5443,7 @@ impl Render for ProjectPanel { }) .when(show_sticky_entries, |list| { let sticky_items = ui::sticky_items( - cx.entity().clone(), + cx.entity(), |this, range, window, cx| { let mut items = SmallVec::with_capacity(range.end - range.start); this.iter_visible_entries( @@ -5474,7 +5470,7 @@ impl Render for ProjectPanel { list.with_decoration(if show_indent_guides { sticky_items.with_decoration( ui::indent_guides(px(indent_size), IndentGuideColors::panel(cx)) - .with_render_fn(cx.entity().clone(), move |_, params, _, _| { + .with_render_fn(cx.entity(), move |_, params, _, _| { const LEFT_OFFSET: Pixels = px(14.); let indent_size = params.indent_size; diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 354434a7fc..e5e166cb4c 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -1292,7 +1292,7 @@ impl RemoteServerProjects { let connection_string = connection_string.clone(); move |_, _: &menu::Confirm, window, cx| { remove_ssh_server( - cx.entity().clone(), + cx.entity(), server_index, connection_string.clone(), window, @@ -1312,7 +1312,7 @@ impl RemoteServerProjects { .child(Label::new("Remove Server").color(Color::Error)) .on_click(cx.listener(move |_, _, window, cx| { remove_ssh_server( - cx.entity().clone(), + cx.entity(), server_index, connection_string.clone(), window, diff --git a/crates/repl/src/session.rs b/crates/repl/src/session.rs index 729a616135..f945e5ed9f 100644 --- a/crates/repl/src/session.rs +++ b/crates/repl/src/session.rs @@ -244,7 +244,7 @@ impl Session { repl_session_id = cx.entity_id().to_string(), ); - let session_view = cx.entity().clone(); + let session_view = cx.entity(); let kernel = match self.kernel_specification.clone() { KernelSpecification::Jupyter(kernel_specification) diff --git a/crates/storybook/src/stories/indent_guides.rs b/crates/storybook/src/stories/indent_guides.rs index e4f9669b1f..db23ea79bd 100644 --- a/crates/storybook/src/stories/indent_guides.rs +++ b/crates/storybook/src/stories/indent_guides.rs @@ -65,7 +65,7 @@ impl Render for IndentGuidesStory { }, ) .with_compute_indents_fn( - cx.entity().clone(), + cx.entity(), |this, range, _cx, _context| { this.depths .iter() diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index c9528c39b9..568dc1db2e 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -947,7 +947,7 @@ pub fn new_terminal_pane( cx: &mut Context, ) -> Entity { let is_local = project.read(cx).is_local(); - let terminal_panel = cx.entity().clone(); + let terminal_panel = cx.entity(); let pane = cx.new(|cx| { let mut pane = Pane::new( workspace.clone(), @@ -1009,7 +1009,7 @@ pub fn new_terminal_pane( return ControlFlow::Break(()); }; if let Some(tab) = dropped_item.downcast_ref::() { - let this_pane = cx.entity().clone(); + let this_pane = cx.entity(); let item = if tab.pane == this_pane { pane.item_for_index(tab.ix) } else { diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 219238496c..534c0a8051 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -1491,7 +1491,7 @@ impl TerminalView { impl Render for TerminalView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let terminal_handle = self.terminal.clone(); - let terminal_view_handle = cx.entity().clone(); + let terminal_view_handle = cx.entity(); let focused = self.focus_handle.is_focused(window); diff --git a/crates/vim/src/mode_indicator.rs b/crates/vim/src/mode_indicator.rs index d54b270074..714b74f239 100644 --- a/crates/vim/src/mode_indicator.rs +++ b/crates/vim/src/mode_indicator.rs @@ -20,7 +20,7 @@ impl ModeIndicator { }) .detach(); - let handle = cx.entity().clone(); + let handle = cx.entity(); let window_handle = window.window_handle(); cx.observe_new::(move |_, window, cx| { let Some(window) = window else { @@ -29,7 +29,7 @@ impl ModeIndicator { if window.window_handle() != window_handle { return; } - let vim = cx.entity().clone(); + let vim = cx.entity(); handle.update(cx, |_, cx| { cx.subscribe(&vim, |mode_indicator, vim, event, cx| match event { VimEvent::Focused => { diff --git a/crates/vim/src/normal/search.rs b/crates/vim/src/normal/search.rs index e4e95ca48e..4054c552ae 100644 --- a/crates/vim/src/normal/search.rs +++ b/crates/vim/src/normal/search.rs @@ -332,7 +332,7 @@ impl Vim { Vim::take_forced_motion(cx); let prior_selections = self.editor_selections(window, cx); let cursor_word = self.editor_cursor_word(window, cx); - let vim = cx.entity().clone(); + let vim = cx.entity(); let searched = pane.update(cx, |pane, cx| { self.search.direction = direction; diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 51bf2dd131..44d9b8f456 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -402,7 +402,7 @@ impl Vim { const NAMESPACE: &'static str = "vim"; pub fn new(window: &mut Window, cx: &mut Context) -> Entity { - let editor = cx.entity().clone(); + let editor = cx.entity(); let mut initial_mode = VimSettings::get_global(cx).default_mode; if initial_mode == Mode::Normal && HelixModeSetting::get_global(cx).0 { diff --git a/crates/workspace/src/dock.rs b/crates/workspace/src/dock.rs index ca63d3e553..ae72df3971 100644 --- a/crates/workspace/src/dock.rs +++ b/crates/workspace/src/dock.rs @@ -253,7 +253,7 @@ impl Dock { cx: &mut Context, ) -> Entity { let focus_handle = cx.focus_handle(); - let workspace = cx.entity().clone(); + let workspace = cx.entity(); let dock = cx.new(|cx| { let focus_subscription = cx.on_focus(&focus_handle, window, |dock: &mut Dock, window, cx| { diff --git a/crates/workspace/src/notifications.rs b/crates/workspace/src/notifications.rs index 7d8a28b0f1..1356322a5c 100644 --- a/crates/workspace/src/notifications.rs +++ b/crates/workspace/src/notifications.rs @@ -346,7 +346,7 @@ impl Render for LanguageServerPrompt { ) .child(Label::new(request.message.to_string()).size(LabelSize::Small)) .children(request.actions.iter().enumerate().map(|(ix, action)| { - let this_handle = cx.entity().clone(); + let this_handle = cx.entity(); Button::new(ix, action.title.clone()) .size(ButtonSize::Large) .on_click(move |_, window, cx| { diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 759e91f758..860a57c21f 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -2198,7 +2198,7 @@ impl Pane { fn update_status_bar(&mut self, window: &mut Window, cx: &mut Context) { let workspace = self.workspace.clone(); - let pane = cx.entity().clone(); + let pane = cx.entity(); window.defer(cx, move |window, cx| { let Ok(status_bar) = @@ -2279,7 +2279,7 @@ impl Pane { cx: &mut Context, ) { maybe!({ - let pane = cx.entity().clone(); + let pane = cx.entity(); let destination_index = match operation { PinOperation::Pin => self.pinned_tab_count.min(ix), @@ -2473,7 +2473,7 @@ impl Pane { .on_drag( DraggedTab { item: item.boxed_clone(), - pane: cx.entity().clone(), + pane: cx.entity(), detail, is_active, ix, @@ -2832,7 +2832,7 @@ impl Pane { let navigate_backward = IconButton::new("navigate_backward", IconName::ArrowLeft) .icon_size(IconSize::Small) .on_click({ - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, window, cx| { entity.update(cx, |pane, cx| pane.navigate_backward(window, cx)) } @@ -2848,7 +2848,7 @@ impl Pane { let navigate_forward = IconButton::new("navigate_forward", IconName::ArrowRight) .icon_size(IconSize::Small) .on_click({ - let entity = cx.entity().clone(); + let entity = cx.entity(); move |_, window, cx| entity.update(cx, |pane, cx| pane.navigate_forward(window, cx)) }) .disabled(!self.can_navigate_forward()) @@ -3054,7 +3054,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let split_direction = self.drag_split_direction; let item_id = dragged_tab.item.item_id(); if let Some(preview_item_id) = self.preview_item_id { @@ -3163,7 +3163,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let split_direction = self.drag_split_direction; let project_entry_id = *project_entry_id; self.workspace @@ -3239,7 +3239,7 @@ impl Pane { return; } } - let mut to_pane = cx.entity().clone(); + let mut to_pane = cx.entity(); let mut split_direction = self.drag_split_direction; let paths = paths.paths().to_vec(); let is_remote = self diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index ade6838fad..1eaa125ba5 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -6338,7 +6338,7 @@ impl Render for Workspace { .border_b_1() .border_color(colors.border) .child({ - let this = cx.entity().clone(); + let this = cx.entity(); canvas( move |bounds, window, cx| { this.update(cx, |this, cx| { diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 84145a1be4..b06652b2ce 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -319,7 +319,7 @@ pub fn initialize_workspace( return; }; - let workspace_handle = cx.entity().clone(); + let workspace_handle = cx.entity(); let center_pane = workspace.active_pane().clone(); initialize_pane(workspace, ¢er_pane, window, cx); From 239e479aedebb45cbc2efd7d0417808a3001710c Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 16:49:56 -0400 Subject: [PATCH 12/17] collab: Remove Stripe code (#36275) This PR removes the code for integrating with Stripe from Collab. All of these concerns are now handled by Cloud. Release Notes: - N/A --- Cargo.lock | 159 +---- Cargo.toml | 14 - crates/collab/Cargo.toml | 5 - crates/collab/k8s/collab.template.yml | 6 - crates/collab/src/api.rs | 1 - crates/collab/src/api/billing.rs | 59 -- .../src/db/tables/billing_subscription.rs | 15 - crates/collab/src/lib.rs | 44 -- crates/collab/src/main.rs | 7 - crates/collab/src/stripe_billing.rs | 156 ----- crates/collab/src/stripe_client.rs | 285 -------- .../src/stripe_client/fake_stripe_client.rs | 247 ------- .../src/stripe_client/real_stripe_client.rs | 612 ------------------ crates/collab/src/tests.rs | 2 - .../collab/src/tests/stripe_billing_tests.rs | 123 ---- crates/collab/src/tests/test_server.rs | 5 - 16 files changed, 2 insertions(+), 1738 deletions(-) delete mode 100644 crates/collab/src/api/billing.rs delete mode 100644 crates/collab/src/stripe_billing.rs delete mode 100644 crates/collab/src/stripe_client.rs delete mode 100644 crates/collab/src/stripe_client/fake_stripe_client.rs delete mode 100644 crates/collab/src/stripe_client/real_stripe_client.rs delete mode 100644 crates/collab/src/tests/stripe_billing_tests.rs diff --git a/Cargo.lock b/Cargo.lock index bfc797d6cd..2be16cc22f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1262,26 +1262,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "async-stripe" -version = "0.40.0" -source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735" -dependencies = [ - "chrono", - "futures-util", - "http-types", - "hyper 0.14.32", - "hyper-rustls 0.24.2", - "serde", - "serde_json", - "serde_path_to_error", - "serde_qs 0.10.1", - "smart-default 0.6.0", - "smol_str 0.1.24", - "thiserror 1.0.69", - "tokio", -] - [[package]] name = "async-tar" version = "0.5.0" @@ -2083,12 +2063,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.21.7" @@ -3281,7 +3255,6 @@ dependencies = [ "anyhow", "assistant_context", "assistant_slash_command", - "async-stripe", "async-trait", "async-tungstenite", "audio", @@ -3308,7 +3281,6 @@ dependencies = [ "dap_adapters", "dashmap 6.1.0", "debugger_ui", - "derive_more 0.99.19", "editor", "envy", "extension", @@ -3870,7 +3842,7 @@ dependencies = [ "rustc-hash 1.1.0", "rustybuzz 0.14.1", "self_cell", - "smol_str 0.2.2", + "smol_str", "swash", "sys-locale", "ttf-parser 0.21.1", @@ -6374,17 +6346,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.15" @@ -7988,27 +7949,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" -[[package]] -name = "http-types" -version = "2.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" -dependencies = [ - "anyhow", - "async-channel 1.9.0", - "base64 0.13.1", - "futures-lite 1.13.0", - "http 0.2.12", - "infer", - "pin-project-lite", - "rand 0.7.3", - "serde", - "serde_json", - "serde_qs 0.8.5", - "serde_urlencoded", - "url", -] - [[package]] name = "http_client" version = "0.1.0" @@ -8487,12 +8427,6 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" -[[package]] -name = "infer" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" - [[package]] name = "inherent" version = "1.0.12" @@ -10269,7 +10203,7 @@ dependencies = [ "num-traits", "range-map", "scroll", - "smart-default 0.7.1", + "smart-default", ] [[package]] @@ -13143,19 +13077,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -13177,16 +13098,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", -] - [[package]] name = "rand_chacha" version = "0.3.1" @@ -13207,15 +13118,6 @@ dependencies = [ "rand_core 0.9.3", ] -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", -] - [[package]] name = "rand_core" version = "0.6.4" @@ -13234,15 +13136,6 @@ dependencies = [ "getrandom 0.3.2", ] -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - [[package]] name = "range-map" version = "0.2.0" @@ -14897,28 +14790,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_qs" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - -[[package]] -name = "serde_qs" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa" -dependencies = [ - "percent-encoding", - "serde", - "thiserror 1.0.69", -] - [[package]] name = "serde_repr" version = "0.1.20" @@ -15295,17 +15166,6 @@ dependencies = [ "serde", ] -[[package]] -name = "smart-default" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "smart-default" version = "0.7.1" @@ -15334,15 +15194,6 @@ dependencies = [ "futures-lite 2.6.0", ] -[[package]] -name = "smol_str" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" -dependencies = [ - "serde", -] - [[package]] name = "smol_str" version = "0.2.2" @@ -18191,12 +18042,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index baa4ee7f4e..644b6c0f40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -667,20 +667,6 @@ workspace-hack = "0.1.0" yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" } zstd = "0.11" -[workspace.dependencies.async-stripe] -git = "https://github.com/zed-industries/async-stripe" -rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735" -default-features = false -features = [ - "runtime-tokio-hyper-rustls", - "billing", - "checkout", - "events", - # The features below are only enabled to get the `events` feature to build. - "chrono", - "connect", -] - [workspace.dependencies.windows] version = "0.61" features = [ diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9a867f9e05..6fc591be13 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -19,7 +19,6 @@ test-support = ["sqlite"] [dependencies] anyhow.workspace = true -async-stripe.workspace = true async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } @@ -33,7 +32,6 @@ clock.workspace = true cloud_llm_client.workspace = true collections.workspace = true dashmap.workspace = true -derive_more.workspace = true envy = "0.4.2" futures.workspace = true gpui.workspace = true @@ -134,6 +132,3 @@ util.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } zlog.workspace = true - -[package.metadata.cargo-machete] -ignored = ["async-stripe"] diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 45fc018a4a..214b550ac2 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -219,12 +219,6 @@ spec: secretKeyRef: name: slack key: panics_webhook - - name: STRIPE_API_KEY - valueFrom: - secretKeyRef: - name: stripe - key: api_key - optional: true - name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR value: "1000" - name: SUPERMAVEN_ADMIN_API_KEY diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 078a4469ae..143e764eb3 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,4 +1,3 @@ -pub mod billing; pub mod contributors; pub mod events; pub mod extensions; diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs deleted file mode 100644 index a0325d14c4..0000000000 --- a/crates/collab/src/api/billing.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::sync::Arc; -use stripe::SubscriptionStatus; - -use crate::AppState; -use crate::db::billing_subscription::StripeSubscriptionStatus; -use crate::db::{CreateBillingCustomerParams, billing_customer}; -use crate::stripe_client::{StripeClient, StripeCustomerId}; - -impl From for StripeSubscriptionStatus { - fn from(value: SubscriptionStatus) -> Self { - match value { - SubscriptionStatus::Incomplete => Self::Incomplete, - SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired, - SubscriptionStatus::Trialing => Self::Trialing, - SubscriptionStatus::Active => Self::Active, - SubscriptionStatus::PastDue => Self::PastDue, - SubscriptionStatus::Canceled => Self::Canceled, - SubscriptionStatus::Unpaid => Self::Unpaid, - SubscriptionStatus::Paused => Self::Paused, - } - } -} - -/// Finds or creates a billing customer using the provided customer. -pub async fn find_or_create_billing_customer( - app: &Arc, - stripe_client: &dyn StripeClient, - customer_id: &StripeCustomerId, -) -> anyhow::Result> { - // If we already have a billing customer record associated with the Stripe customer, - // there's nothing more we need to do. - if let Some(billing_customer) = app - .db - .get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref()) - .await? - { - return Ok(Some(billing_customer)); - } - - let customer = stripe_client.get_customer(customer_id).await?; - - let Some(email) = customer.email else { - return Ok(None); - }; - - let Some(user) = app.db.get_user_by_email(&email).await? else { - return Ok(None); - }; - - let billing_customer = app - .db - .create_billing_customer(&CreateBillingCustomerParams { - user_id: user.id, - stripe_customer_id: customer.id.to_string(), - }) - .await?; - - Ok(Some(billing_customer)) -} diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 522973dbc9..f5684aeec3 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -1,5 +1,4 @@ use crate::db::{BillingCustomerId, BillingSubscriptionId}; -use crate::stripe_client; use chrono::{Datelike as _, NaiveDate, Utc}; use sea_orm::entity::prelude::*; use serde::Serialize; @@ -160,17 +159,3 @@ pub enum StripeCancellationReason { #[sea_orm(string_value = "payment_failed")] PaymentFailed, } - -impl From for StripeCancellationReason { - fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self { - match value { - stripe_client::StripeCancellationDetailsReason::CancellationRequested => { - Self::CancellationRequested - } - stripe_client::StripeCancellationDetailsReason::PaymentDisputed => { - Self::PaymentDisputed - } - stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 905859ca69..a68286a5a3 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -7,8 +7,6 @@ pub mod llm; pub mod migrations; pub mod rpc; pub mod seed; -pub mod stripe_billing; -pub mod stripe_client; pub mod user_backfiller; #[cfg(test)] @@ -27,16 +25,12 @@ use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{RealStripeClient, StripeClient}; - pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String, HeaderMap), Database(sea_orm::error::DbErr), Internal(anyhow::Error), - Stripe(stripe::StripeError), } impl From for Error { @@ -51,12 +45,6 @@ impl From for Error { } } -impl From for Error { - fn from(error: stripe::StripeError) -> Self { - Self::Stripe(error) - } -} - impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -104,14 +92,6 @@ impl IntoResponse for Error { ); (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } - Error::Stripe(error) => { - log::error!( - "HTTP error {}: {:?}", - StatusCode::INTERNAL_SERVER_ERROR, - &error - ); - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() - } } } } @@ -122,7 +102,6 @@ impl std::fmt::Debug for Error { Error::Http(code, message, _headers) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -133,7 +112,6 @@ impl std::fmt::Display for Error { Error::Http(code, message, _) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), - Error::Stripe(error) => error.fmt(f), } } } @@ -179,7 +157,6 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, - pub stripe_api_key: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -234,7 +211,6 @@ impl Config { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, @@ -269,11 +245,6 @@ pub struct AppState { pub llm_db: Option>, pub livekit_client: Option>, pub blob_store_client: Option, - /// This is a real instance of the Stripe client; we're working to replace references to this with the - /// [`StripeClient`] trait. - pub real_stripe_client: Option>, - pub stripe_client: Option>, - pub stripe_billing: Option>, pub executor: Executor, pub kinesis_client: Option<::aws_sdk_kinesis::Client>, pub config: Config, @@ -316,18 +287,11 @@ impl AppState { }; let db = Arc::new(db); - let stripe_client = build_stripe_client(&config).map(Arc::new).log_err(); let this = Self { db: db.clone(), llm_db, livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_billing: stripe_client - .clone() - .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), - real_stripe_client: stripe_client.clone(), - stripe_client: stripe_client - .map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _), executor, kinesis_client: if config.kinesis_access_key.is_some() { build_kinesis_client(&config).await.log_err() @@ -340,14 +304,6 @@ impl AppState { } } -fn build_stripe_client(config: &Config) -> anyhow::Result { - let api_key = config - .stripe_api_key - .as_ref() - .context("missing stripe_api_key")?; - Ok(stripe::Client::new(api_key)) -} - async fn build_blob_store_client(config: &Config) -> anyhow::Result { let keys = aws_sdk_s3::config::Credentials::new( config diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 20641cb232..177c97f076 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -102,13 +102,6 @@ async fn main() -> Result<()> { let state = AppState::new(config, Executor::Production).await?; - if let Some(stripe_billing) = state.stripe_billing.clone() { - let executor = state.executor.clone(); - executor.spawn_detached(async move { - stripe_billing.initialize().await.trace_err(); - }); - } - if mode.is_collab() { state.db.purge_old_embeddings().await.trace_err(); diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs deleted file mode 100644 index ef5bef3e7e..0000000000 --- a/crates/collab/src/stripe_billing.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::sync::Arc; - -use anyhow::anyhow; -use collections::HashMap; -use stripe::SubscriptionStatus; -use tokio::sync::RwLock; - -use crate::Result; -use crate::stripe_client::{ - RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems, - StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId, - StripeSubscription, -}; - -pub struct StripeBilling { - state: RwLock, - client: Arc, -} - -#[derive(Default)] -struct StripeBillingState { - prices_by_lookup_key: HashMap, -} - -impl StripeBilling { - pub fn new(client: Arc) -> Self { - Self { - client: Arc::new(RealStripeClient::new(client.clone())), - state: RwLock::default(), - } - } - - #[cfg(test)] - pub fn test(client: Arc) -> Self { - Self { - client, - state: RwLock::default(), - } - } - - pub fn client(&self) -> &Arc { - &self.client - } - - pub async fn initialize(&self) -> Result<()> { - log::info!("StripeBilling: initializing"); - - let mut state = self.state.write().await; - - let prices = self.client.list_prices().await?; - - for price in prices { - if let Some(lookup_key) = price.lookup_key.clone() { - state.prices_by_lookup_key.insert(lookup_key, price); - } - } - - log::info!("StripeBilling: initialized"); - - Ok(()) - } - - pub async fn zed_pro_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-pro").await - } - - pub async fn zed_free_price_id(&self) -> Result { - self.find_price_id_by_lookup_key("zed-free").await - } - - pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .map(|price| price.id.clone()) - .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) - } - - pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { - self.state - .read() - .await - .prices_by_lookup_key - .get(lookup_key) - .cloned() - .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) - } - - /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does - /// not already exist. - /// - /// Always returns a new Stripe customer if the email address is `None`. - pub async fn find_or_create_customer_by_email( - &self, - email_address: Option<&str>, - ) -> Result { - let existing_customer = if let Some(email) = email_address { - let customers = self.client.list_customers_by_email(email).await?; - - customers.first().cloned() - } else { - None - }; - - let customer_id = if let Some(existing_customer) = existing_customer { - existing_customer.id - } else { - let customer = self - .client - .create_customer(crate::stripe_client::CreateCustomerParams { - email: email_address, - }) - .await?; - - customer.id - }; - - Ok(customer_id) - } - - pub async fn subscribe_to_zed_free( - &self, - customer_id: StripeCustomerId, - ) -> Result { - let zed_free_price_id = self.zed_free_price_id().await?; - - let existing_subscriptions = self - .client - .list_subscriptions_for_customer(&customer_id) - .await?; - - let existing_active_subscription = - existing_subscriptions.into_iter().find(|subscription| { - subscription.status == SubscriptionStatus::Active - || subscription.status == SubscriptionStatus::Trialing - }); - if let Some(subscription) = existing_active_subscription { - return Ok(subscription); - } - - let params = StripeCreateSubscriptionParams { - customer: customer_id, - items: vec![StripeCreateSubscriptionItems { - price: Some(zed_free_price_id), - quantity: Some(1), - }], - automatic_tax: Some(StripeAutomaticTax { enabled: true }), - }; - - let subscription = self.client.create_subscription(params).await?; - - Ok(subscription) - } -} diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs deleted file mode 100644 index 6e75a4d874..0000000000 --- a/crates/collab/src/stripe_client.rs +++ /dev/null @@ -1,285 +0,0 @@ -#[cfg(test)] -mod fake_stripe_client; -mod real_stripe_client; - -use std::collections::HashMap; -use std::sync::Arc; - -use anyhow::Result; -use async_trait::async_trait; - -#[cfg(test)] -pub use fake_stripe_client::*; -pub use real_stripe_client::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)] -pub struct StripeCustomerId(pub Arc); - -#[derive(Debug, Clone)] -pub struct StripeCustomer { - pub id: StripeCustomerId, - pub email: Option, -} - -#[derive(Debug)] -pub struct CreateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug)] -pub struct UpdateCustomerParams<'a> { - pub email: Option<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscription { - pub id: StripeSubscriptionId, - pub customer: StripeCustomerId, - // TODO: Create our own version of this enum. - pub status: stripe::SubscriptionStatus, - pub current_period_end: i64, - pub current_period_start: i64, - pub items: Vec, - pub cancel_at: Option, - pub cancellation_details: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripeSubscriptionItemId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionItem { - pub id: StripeSubscriptionItemId, - pub price: Option, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct StripeCancellationDetails { - pub reason: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCancellationDetailsReason { - CancellationRequested, - PaymentDisputed, - PaymentFailed, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionParams { - pub customer: StripeCustomerId, - pub items: Vec, - pub automatic_tax: Option, -} - -#[derive(Debug)] -pub struct StripeCreateSubscriptionItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, Clone)] -pub struct UpdateSubscriptionParams { - pub items: Option>, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct UpdateSubscriptionItems { - pub price: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettings { - pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeSubscriptionTrialSettingsEndBehavior { - pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { - Cancel, - CreateInvoice, - Pause, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] -pub struct StripePriceId(pub Arc); - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePrice { - pub id: StripePriceId, - pub unit_amount: Option, - pub lookup_key: Option, - pub recurring: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripePriceRecurring { - pub meter: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)] -pub struct StripeMeterId(pub Arc); - -#[derive(Debug, Clone, Deserialize)] -pub struct StripeMeter { - pub id: StripeMeterId, - pub event_name: String, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventParams<'a> { - pub identifier: &'a str, - pub event_name: &'a str, - pub payload: StripeCreateMeterEventPayload<'a>, - pub timestamp: Option, -} - -#[derive(Debug, Serialize)] -pub struct StripeCreateMeterEventPayload<'a> { - pub value: u64, - pub stripe_customer_id: &'a StripeCustomerId, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeBillingAddressCollection { - Auto, - Required, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCustomerUpdate { - pub address: Option, - pub name: Option, - pub shipping: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateAddress { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateName { - Auto, - Never, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCustomerUpdateShipping { - Auto, - Never, -} - -#[derive(Debug, Default)] -pub struct StripeCreateCheckoutSessionParams<'a> { - pub customer: Option<&'a StripeCustomerId>, - pub client_reference_id: Option<&'a str>, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option<&'a str>, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionMode { - Payment, - Setup, - Subscription, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionLineItems { - pub price: Option, - pub quantity: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum StripeCheckoutSessionPaymentMethodCollection { - Always, - IfRequired, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeCreateCheckoutSessionSubscriptionData { - pub metadata: Option>, - pub trial_period_days: Option, - pub trial_settings: Option, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct StripeTaxIdCollection { - pub enabled: bool, -} - -#[derive(Debug, Clone)] -pub struct StripeAutomaticTax { - pub enabled: bool, -} - -#[derive(Debug)] -pub struct StripeCheckoutSession { - pub url: Option, -} - -#[async_trait] -pub trait StripeClient: Send + Sync { - async fn list_customers_by_email(&self, email: &str) -> Result>; - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result; - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result; - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result>; - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result; - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result; - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()>; - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>; - - async fn list_prices(&self) -> Result>; - - async fn list_meters(&self) -> Result>; - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>; - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result; -} diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs deleted file mode 100644 index 9bb08443ec..0000000000 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Result, anyhow}; -use async_trait::async_trait; -use chrono::{Duration, Utc}; -use collections::HashMap; -use parking_lot::Mutex; -use uuid::Uuid; - -use crate::stripe_client::{ - CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -#[derive(Debug, Clone)] -pub struct StripeCreateMeterEventCall { - pub identifier: Arc, - pub event_name: Arc, - pub value: u64, - pub stripe_customer_id: StripeCustomerId, - pub timestamp: Option, -} - -#[derive(Debug, Clone)] -pub struct StripeCreateCheckoutSessionCall { - pub customer: Option, - pub client_reference_id: Option, - pub mode: Option, - pub line_items: Option>, - pub payment_method_collection: Option, - pub subscription_data: Option, - pub success_url: Option, - pub billing_address_collection: Option, - pub customer_update: Option, - pub tax_id_collection: Option, -} - -pub struct FakeStripeClient { - pub customers: Arc>>, - pub subscriptions: Arc>>, - pub update_subscription_calls: - Arc>>, - pub prices: Arc>>, - pub meters: Arc>>, - pub create_meter_event_calls: Arc>>, - pub create_checkout_session_calls: Arc>>, -} - -impl FakeStripeClient { - pub fn new() -> Self { - Self { - customers: Arc::new(Mutex::new(HashMap::default())), - subscriptions: Arc::new(Mutex::new(HashMap::default())), - update_subscription_calls: Arc::new(Mutex::new(Vec::new())), - prices: Arc::new(Mutex::new(HashMap::default())), - meters: Arc::new(Mutex::new(HashMap::default())), - create_meter_event_calls: Arc::new(Mutex::new(Vec::new())), - create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())), - } - } -} - -#[async_trait] -impl StripeClient for FakeStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - Ok(self - .customers - .lock() - .values() - .filter(|customer| customer.email.as_deref() == Some(email)) - .cloned() - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - self.customers - .lock() - .get(customer_id) - .cloned() - .ok_or_else(|| anyhow!("no customer found for {customer_id:?}")) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = StripeCustomer { - id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()), - email: params.email.map(|email| email.to_string()), - }; - - self.customers - .lock() - .insert(customer.id.clone(), customer.clone()); - - Ok(customer) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let mut customers = self.customers.lock(); - if let Some(customer) = customers.get_mut(customer_id) { - if let Some(email) = params.email { - customer.email = Some(email.to_string()); - } - Ok(customer.clone()) - } else { - Err(anyhow!("no customer found for {customer_id:?}")) - } - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let subscriptions = self - .subscriptions - .lock() - .values() - .filter(|subscription| subscription.customer == *customer_id) - .cloned() - .collect(); - - Ok(subscriptions) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - self.subscriptions - .lock() - .get(subscription_id) - .cloned() - .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}")) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let now = Utc::now(); - - let subscription = StripeSubscription { - id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()), - customer: params.customer, - status: stripe::SubscriptionStatus::Active, - current_period_start: now.timestamp(), - current_period_end: (now + Duration::days(30)).timestamp(), - items: params - .items - .into_iter() - .map(|item| StripeSubscriptionItem { - id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()), - price: item - .price - .and_then(|price_id| self.prices.lock().get(&price_id).cloned()), - }) - .collect(), - cancel_at: None, - cancellation_details: None, - }; - - self.subscriptions - .lock() - .insert(subscription.id.clone(), subscription.clone()); - - Ok(subscription) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription = self.get_subscription(subscription_id).await?; - - self.update_subscription_calls - .lock() - .push((subscription.id, params)); - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - // TODO: Implement fake subscription cancellation. - let _ = subscription_id; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let prices = self.prices.lock().values().cloned().collect(); - - Ok(prices) - } - - async fn list_meters(&self) -> Result> { - let meters = self.meters.lock().values().cloned().collect(); - - Ok(meters) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - self.create_meter_event_calls - .lock() - .push(StripeCreateMeterEventCall { - identifier: params.identifier.into(), - event_name: params.event_name.into(), - value: params.payload.value, - stripe_customer_id: params.payload.stripe_customer_id.clone(), - timestamp: params.timestamp, - }); - - Ok(()) - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - self.create_checkout_session_calls - .lock() - .push(StripeCreateCheckoutSessionCall { - customer: params.customer.cloned(), - client_reference_id: params.client_reference_id.map(|id| id.to_string()), - mode: params.mode, - line_items: params.line_items, - payment_method_collection: params.payment_method_collection, - subscription_data: params.subscription_data, - success_url: params.success_url.map(|url| url.to_string()), - billing_address_collection: params.billing_address_collection, - customer_update: params.customer_update, - tax_id_collection: params.tax_id_collection, - }); - - Ok(StripeCheckoutSession { - url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()), - }) - } -} diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs deleted file mode 100644 index 07c191ff30..0000000000 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::str::FromStr as _; -use std::sync::Arc; - -use anyhow::{Context as _, Result, anyhow}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use stripe::{ - CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode, - CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior, - CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod, - CreateCustomer, CreateSubscriptionAutomaticTax, Customer, CustomerId, ListCustomers, Price, - PriceId, Recurring, Subscription, SubscriptionId, SubscriptionItem, SubscriptionItemId, - UpdateCustomer, UpdateSubscriptionItems, UpdateSubscriptionTrialSettings, - UpdateSubscriptionTrialSettingsEndBehavior, - UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, -}; - -use crate::stripe_client::{ - CreateCustomerParams, StripeAutomaticTax, StripeBillingAddressCollection, - StripeCancellationDetails, StripeCancellationDetailsReason, StripeCheckoutSession, - StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, - StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, - StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeCustomerUpdate, - StripeCustomerUpdateAddress, StripeCustomerUpdateName, StripeCustomerUpdateShipping, - StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, - StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, StripeTaxIdCollection, - UpdateCustomerParams, UpdateSubscriptionParams, -}; - -pub struct RealStripeClient { - client: Arc, -} - -impl RealStripeClient { - pub fn new(client: Arc) -> Self { - Self { client } - } -} - -#[async_trait] -impl StripeClient for RealStripeClient { - async fn list_customers_by_email(&self, email: &str) -> Result> { - let response = Customer::list( - &self.client, - &ListCustomers { - email: Some(email), - ..Default::default() - }, - ) - .await?; - - Ok(response - .data - .into_iter() - .map(StripeCustomer::from) - .collect()) - } - - async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result { - let customer_id = customer_id.try_into()?; - - let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { - let customer = Customer::create( - &self.client, - CreateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn update_customer( - &self, - customer_id: &StripeCustomerId, - params: UpdateCustomerParams<'_>, - ) -> Result { - let customer = Customer::update( - &self.client, - &customer_id.try_into()?, - UpdateCustomer { - email: params.email, - ..Default::default() - }, - ) - .await?; - - Ok(StripeCustomer::from(customer)) - } - - async fn list_subscriptions_for_customer( - &self, - customer_id: &StripeCustomerId, - ) -> Result> { - let customer_id = customer_id.try_into()?; - - let subscriptions = stripe::Subscription::list( - &self.client, - &stripe::ListSubscriptions { - customer: Some(customer_id), - status: None, - ..Default::default() - }, - ) - .await?; - - Ok(subscriptions - .data - .into_iter() - .map(StripeSubscription::from) - .collect()) - } - - async fn get_subscription( - &self, - subscription_id: &StripeSubscriptionId, - ) -> Result { - let subscription_id = subscription_id.try_into()?; - - let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn create_subscription( - &self, - params: StripeCreateSubscriptionParams, - ) -> Result { - let customer_id = params.customer.try_into()?; - - let mut create_subscription = stripe::CreateSubscription::new(customer_id); - create_subscription.items = Some( - params - .items - .into_iter() - .map(|item| stripe::CreateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - quantity: item.quantity, - ..Default::default() - }) - .collect(), - ); - create_subscription.automatic_tax = params.automatic_tax.map(Into::into); - - let subscription = Subscription::create(&self.client, create_subscription).await?; - - Ok(StripeSubscription::from(subscription)) - } - - async fn update_subscription( - &self, - subscription_id: &StripeSubscriptionId, - params: UpdateSubscriptionParams, - ) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - stripe::Subscription::update( - &self.client, - &subscription_id, - stripe::UpdateSubscription { - items: params.items.map(|items| { - items - .into_iter() - .map(|item| UpdateSubscriptionItems { - price: item.price.map(|price| price.to_string()), - ..Default::default() - }) - .collect() - }), - trial_settings: params.trial_settings.map(Into::into), - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> { - let subscription_id = subscription_id.try_into()?; - - Subscription::cancel( - &self.client, - &subscription_id, - stripe::CancelSubscription { - invoice_now: None, - ..Default::default() - }, - ) - .await?; - - Ok(()) - } - - async fn list_prices(&self) -> Result> { - let response = stripe::Price::list( - &self.client, - &stripe::ListPrices { - limit: Some(100), - ..Default::default() - }, - ) - .await?; - - Ok(response.data.into_iter().map(StripePrice::from).collect()) - } - - async fn list_meters(&self) -> Result> { - #[derive(Serialize)] - struct Params { - #[serde(skip_serializing_if = "Option::is_none")] - limit: Option, - } - - let response = self - .client - .get_query::, _>( - "/billing/meters", - Params { limit: Some(100) }, - ) - .await?; - - Ok(response.data) - } - - async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { - #[derive(Deserialize)] - struct StripeMeterEvent { - pub identifier: String, - } - - let identifier = params.identifier; - match self - .client - .post_form::("/billing/meter_events", params) - .await - { - Ok(_event) => Ok(()), - Err(stripe::StripeError::Stripe(error)) => { - if error.http_status == 400 - && error - .message - .as_ref() - .map_or(false, |message| message.contains(identifier)) - { - Ok(()) - } else { - Err(anyhow!(stripe::StripeError::Stripe(error))) - } - } - Err(error) => Err(anyhow!("failed to create meter event: {error:?}")), - } - } - - async fn create_checkout_session( - &self, - params: StripeCreateCheckoutSessionParams<'_>, - ) -> Result { - let params = params.try_into()?; - let session = CheckoutSession::create(&self.client, params).await?; - - Ok(session.into()) - } -} - -impl From for StripeCustomerId { - fn from(value: CustomerId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl TryFrom<&StripeCustomerId> for CustomerId { - type Error = anyhow::Error; - - fn try_from(value: &StripeCustomerId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") - } -} - -impl From for StripeCustomer { - fn from(value: Customer) -> Self { - StripeCustomer { - id: value.id.into(), - email: value.email, - } - } -} - -impl From for StripeSubscriptionId { - fn from(value: SubscriptionId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom<&StripeSubscriptionId> for SubscriptionId { - type Error = anyhow::Error; - - fn try_from(value: &StripeSubscriptionId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID") - } -} - -impl From for StripeSubscription { - fn from(value: Subscription) -> Self { - Self { - id: value.id.into(), - customer: value.customer.id().into(), - status: value.status, - current_period_start: value.current_period_start, - current_period_end: value.current_period_end, - items: value.items.data.into_iter().map(Into::into).collect(), - cancel_at: value.cancel_at, - cancellation_details: value.cancellation_details.map(Into::into), - } - } -} - -impl From for StripeCancellationDetails { - fn from(value: CancellationDetails) -> Self { - Self { - reason: value.reason.map(Into::into), - } - } -} - -impl From for StripeCancellationDetailsReason { - fn from(value: CancellationDetailsReason) -> Self { - match value { - CancellationDetailsReason::CancellationRequested => Self::CancellationRequested, - CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed, - CancellationDetailsReason::PaymentFailed => Self::PaymentFailed, - } - } -} - -impl From for StripeSubscriptionItemId { - fn from(value: SubscriptionItemId) -> Self { - Self(value.as_str().into()) - } -} - -impl From for StripeSubscriptionItem { - fn from(value: SubscriptionItem) -> Self { - Self { - id: value.id.into(), - price: value.price.map(Into::into), - } - } -} - -impl From for CreateSubscriptionAutomaticTax { - fn from(value: StripeAutomaticTax) -> Self { - Self { - enabled: value.enabled, - liability: None, - } - } -} - -impl From for UpdateSubscriptionTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripePriceId { - fn from(value: PriceId) -> Self { - Self(value.as_str().into()) - } -} - -impl TryFrom for PriceId { - type Error = anyhow::Error; - - fn try_from(value: StripePriceId) -> Result { - Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID") - } -} - -impl From for StripePrice { - fn from(value: Price) -> Self { - Self { - id: value.id.into(), - unit_amount: value.unit_amount, - lookup_key: value.lookup_key, - recurring: value.recurring.map(StripePriceRecurring::from), - } - } -} - -impl From for StripePriceRecurring { - fn from(value: Recurring) -> Self { - Self { meter: value.meter } - } -} - -impl<'a> TryFrom> for CreateCheckoutSession<'a> { - type Error = anyhow::Error; - - fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result { - Ok(Self { - customer: value - .customer - .map(|customer_id| customer_id.try_into()) - .transpose()?, - client_reference_id: value.client_reference_id, - mode: value.mode.map(Into::into), - line_items: value - .line_items - .map(|line_items| line_items.into_iter().map(Into::into).collect()), - payment_method_collection: value.payment_method_collection.map(Into::into), - subscription_data: value.subscription_data.map(Into::into), - success_url: value.success_url, - billing_address_collection: value.billing_address_collection.map(Into::into), - customer_update: value.customer_update.map(Into::into), - tax_id_collection: value.tax_id_collection.map(Into::into), - ..Default::default() - }) - } -} - -impl From for CheckoutSessionMode { - fn from(value: StripeCheckoutSessionMode) -> Self { - match value { - StripeCheckoutSessionMode::Payment => Self::Payment, - StripeCheckoutSessionMode::Setup => Self::Setup, - StripeCheckoutSessionMode::Subscription => Self::Subscription, - } - } -} - -impl From for CreateCheckoutSessionLineItems { - fn from(value: StripeCreateCheckoutSessionLineItems) -> Self { - Self { - price: value.price, - quantity: value.quantity, - ..Default::default() - } - } -} - -impl From for CheckoutSessionPaymentMethodCollection { - fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self { - match value { - StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always, - StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired, - } - } -} - -impl From for CreateCheckoutSessionSubscriptionData { - fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self { - Self { - trial_period_days: value.trial_period_days, - trial_settings: value.trial_settings.map(Into::into), - metadata: value.metadata, - ..Default::default() - } - } -} - -impl From for CreateCheckoutSessionSubscriptionDataTrialSettings { - fn from(value: StripeSubscriptionTrialSettings) -> Self { - Self { - end_behavior: value.end_behavior.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self { - Self { - missing_payment_method: value.missing_payment_method.into(), - } - } -} - -impl From - for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod -{ - fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self { - match value { - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { - Self::CreateInvoice - } - StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, - } - } -} - -impl From for StripeCheckoutSession { - fn from(value: CheckoutSession) -> Self { - Self { url: value.url } - } -} - -impl From for stripe::CheckoutSessionBillingAddressCollection { - fn from(value: StripeBillingAddressCollection) -> Self { - match value { - StripeBillingAddressCollection::Auto => { - stripe::CheckoutSessionBillingAddressCollection::Auto - } - StripeBillingAddressCollection::Required => { - stripe::CheckoutSessionBillingAddressCollection::Required - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateAddress { - fn from(value: StripeCustomerUpdateAddress) -> Self { - match value { - StripeCustomerUpdateAddress::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Auto - } - StripeCustomerUpdateAddress::Never => { - stripe::CreateCheckoutSessionCustomerUpdateAddress::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateName { - fn from(value: StripeCustomerUpdateName) -> Self { - match value { - StripeCustomerUpdateName::Auto => stripe::CreateCheckoutSessionCustomerUpdateName::Auto, - StripeCustomerUpdateName::Never => { - stripe::CreateCheckoutSessionCustomerUpdateName::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdateShipping { - fn from(value: StripeCustomerUpdateShipping) -> Self { - match value { - StripeCustomerUpdateShipping::Auto => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Auto - } - StripeCustomerUpdateShipping::Never => { - stripe::CreateCheckoutSessionCustomerUpdateShipping::Never - } - } - } -} - -impl From for stripe::CreateCheckoutSessionCustomerUpdate { - fn from(value: StripeCustomerUpdate) -> Self { - stripe::CreateCheckoutSessionCustomerUpdate { - address: value.address.map(Into::into), - name: value.name.map(Into::into), - shipping: value.shipping.map(Into::into), - } - } -} - -impl From for stripe::CreateCheckoutSessionTaxIdCollection { - fn from(value: StripeTaxIdCollection) -> Self { - stripe::CreateCheckoutSessionTaxIdCollection { - enabled: value.enabled, - } - } -} diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 8d5d076780..ddf245b06f 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -8,7 +8,6 @@ mod channel_buffer_tests; mod channel_guest_tests; mod channel_message_tests; mod channel_tests; -// mod debug_panel_tests; mod editor_tests; mod following_tests; mod git_tests; @@ -18,7 +17,6 @@ mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; mod remote_editing_collaboration_tests; -mod stripe_billing_tests; mod test_server; use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs deleted file mode 100644 index bb84bedfcf..0000000000 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::sync::Arc; - -use pretty_assertions::assert_eq; - -use crate::stripe_billing::StripeBilling; -use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring}; - -fn make_stripe_billing() -> (StripeBilling, Arc) { - let stripe_client = Arc::new(FakeStripeClient::new()); - let stripe_billing = StripeBilling::test(stripe_client.clone()); - - (stripe_billing, stripe_client) -} - -#[gpui::test] -async fn test_initialize() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Add test prices - let price1 = StripePrice { - id: StripePriceId("price_1".into()), - unit_amount: Some(1_000), - lookup_key: Some("zed-pro".to_string()), - recurring: None, - }; - let price2 = StripePrice { - id: StripePriceId("price_2".into()), - unit_amount: Some(0), - lookup_key: Some("zed-free".to_string()), - recurring: None, - }; - let price3 = StripePrice { - id: StripePriceId("price_3".into()), - unit_amount: Some(500), - lookup_key: None, - recurring: Some(StripePriceRecurring { - meter: Some("meter_1".to_string()), - }), - }; - stripe_client - .prices - .lock() - .insert(price1.id.clone(), price1); - stripe_client - .prices - .lock() - .insert(price2.id.clone(), price2); - stripe_client - .prices - .lock() - .insert(price3.id.clone(), price3); - - // Initialize the billing system - stripe_billing.initialize().await.unwrap(); - - // Verify that prices can be found by lookup key - let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap(); - assert_eq!(zed_pro_price_id.to_string(), "price_1"); - - let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap(); - assert_eq!(zed_free_price_id.to_string(), "price_2"); - - // Verify that a price can be found by lookup key - let zed_pro_price = stripe_billing - .find_price_by_lookup_key("zed-pro") - .await - .unwrap(); - assert_eq!(zed_pro_price.id.to_string(), "price_1"); - assert_eq!(zed_pro_price.unit_amount, Some(1_000)); - - // Verify that finding a non-existent lookup key returns an error - let result = stripe_billing - .find_price_by_lookup_key("non-existent") - .await; - assert!(result.is_err()); -} - -#[gpui::test] -async fn test_find_or_create_customer_by_email() { - let (stripe_billing, stripe_client) = make_stripe_billing(); - - // Create a customer with an email that doesn't yet correspond to a customer. - { - let email = "user@example.com"; - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } - - // Create a customer with an email that corresponds to an existing customer. - { - let email = "user2@example.com"; - - let existing_customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - - let customer_id = stripe_billing - .find_or_create_customer_by_email(Some(email)) - .await - .unwrap(); - assert_eq!(customer_id, existing_customer_id); - - let customer = stripe_client - .customers - .lock() - .get(&customer_id) - .unwrap() - .clone(); - assert_eq!(customer.email.as_deref(), Some(email)); - } -} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index f5a0e8ea81..8c545b0670 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,4 +1,3 @@ -use crate::stripe_client::FakeStripeClient; use crate::{ AppState, Config, db::{NewUserParams, UserId, tests::TestDb}, @@ -569,9 +568,6 @@ impl TestServer { llm_db: None, livekit_client: Some(Arc::new(livekit_test_server.create_api_client())), blob_store_client: None, - real_stripe_client: None, - stripe_client: Some(Arc::new(FakeStripeClient::new())), - stripe_billing: None, executor, kinesis_client: None, config: Config { @@ -608,7 +604,6 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, - stripe_api_key: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, From 9eb1ff272693a811c8f3f1b251a67c3a97f856e4 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 15 Aug 2025 18:03:36 -0300 Subject: [PATCH 13/17] acp thread view: Always use editors for user messages (#36256) This means the cursor will be at the position you clicked: https://github.com/user-attachments/assets/0693950d-7513-4d90-88e2-55817df7213a Release Notes: - N/A --- crates/acp_thread/src/acp_thread.rs | 10 +- .../agent_ui/src/acp/completion_provider.rs | 5 - crates/agent_ui/src/acp/entry_view_state.rs | 387 ++++++----- crates/agent_ui/src/acp/message_editor.rs | 28 +- crates/agent_ui/src/acp/thread_view.rs | 605 ++++++++++++------ crates/agent_ui/src/agent_panel.rs | 10 +- 6 files changed, 671 insertions(+), 374 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 4995ddb9df..2ef94a3cbe 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -109,7 +109,7 @@ pub enum AgentThreadEntry { } impl AgentThreadEntry { - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), @@ -117,6 +117,14 @@ impl AgentThreadEntry { } } + pub fn user_message(&self) -> Option<&UserMessage> { + if let AgentThreadEntry::UserMessage(message) = self { + Some(message) + } else { + None + } + } + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index 4ee1eb6948..d7d2cd5d0e 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -80,11 +80,6 @@ impl MentionSet { .chain(self.images.drain().map(|(id, _)| id)) } - pub fn clear(&mut self) { - self.fetch_results.clear(); - self.uri_by_crease_id.clear(); - } - pub fn contents( &self, project: Entity, diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 2f5f855e90..e99d1f6323 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -1,45 +1,141 @@ -use std::{collections::HashMap, ops::Range}; +use std::ops::Range; -use acp_thread::AcpThread; -use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer}; +use acp_thread::{AcpThread, AgentThreadEntry}; +use agent::{TextThreadStore, ThreadStore}; +use collections::HashMap; +use editor::{Editor, EditorMode, MinimapVisibility}; use gpui::{ - AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window, + AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement, + WeakEntity, Window, }; use language::language_settings::SoftWrap; +use project::Project; use settings::Settings as _; use terminal_view::TerminalView; use theme::ThemeSettings; -use ui::TextSize; +use ui::{Context, TextSize}; use workspace::Workspace; -#[derive(Default)] +use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; + pub struct EntryViewState { + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, entries: Vec, } impl EntryViewState { + pub fn new( + workspace: WeakEntity, + project: Entity, + thread_store: Entity, + text_thread_store: Entity, + ) -> Self { + Self { + workspace, + project, + thread_store, + text_thread_store, + entries: Vec::new(), + } + } + pub fn entry(&self, index: usize) -> Option<&Entry> { self.entries.get(index) } pub fn sync_entry( &mut self, - workspace: WeakEntity, - thread: Entity, index: usize, + thread: &Entity, window: &mut Window, - cx: &mut App, + cx: &mut Context, ) { - debug_assert!(index <= self.entries.len()); - let entry = if let Some(entry) = self.entries.get_mut(index) { - entry - } else { - self.entries.push(Entry::default()); - self.entries.last_mut().unwrap() + let Some(thread_entry) = thread.read(cx).entries().get(index) else { + return; }; - entry.sync_diff_multibuffers(&thread, index, window, cx); - entry.sync_terminals(&workspace, &thread, index, window, cx); + match thread_entry { + AgentThreadEntry::UserMessage(message) => { + let has_id = message.id.is_some(); + let chunks = message.chunks.clone(); + let message_editor = cx.new(|cx| { + let mut editor = MessageEditor::new( + self.workspace.clone(), + self.project.clone(), + self.thread_store.clone(), + self.text_thread_store.clone(), + editor::EditorMode::AutoHeight { + min_lines: 1, + max_lines: None, + }, + window, + cx, + ); + if !has_id { + editor.set_read_only(true, cx); + } + editor.set_message(chunks, window, cx); + editor + }); + cx.subscribe(&message_editor, move |_, editor, event, cx| { + cx.emit(EntryViewEvent { + entry_index: index, + view_event: ViewEvent::MessageEditorEvent(editor, *event), + }) + }) + .detach(); + self.set_entry(index, Entry::UserMessage(message_editor)); + } + AgentThreadEntry::ToolCall(tool_call) => { + let terminals = tool_call.terminals().cloned().collect::>(); + let diffs = tool_call.diffs().cloned().collect::>(); + + let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) { + views + } else { + self.set_entry(index, Entry::empty()); + let Some(Entry::Content(views)) = self.entries.get_mut(index) else { + unreachable!() + }; + views + }; + + for terminal in terminals { + views.entry(terminal.entity_id()).or_insert_with(|| { + create_terminal( + self.workspace.clone(), + self.project.clone(), + terminal.clone(), + window, + cx, + ) + .into_any() + }); + } + + for diff in diffs { + views + .entry(diff.entity_id()) + .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any()); + } + } + AgentThreadEntry::AssistantMessage(_) => { + if index == self.entries.len() { + self.entries.push(Entry::empty()) + } + } + }; + } + + fn set_entry(&mut self, index: usize, entry: Entry) { + if index == self.entries.len() { + self.entries.push(entry); + } else { + self.entries[index] = entry; + } } pub fn remove(&mut self, range: Range) { @@ -48,26 +144,51 @@ impl EntryViewState { pub fn settings_changed(&mut self, cx: &mut App) { for entry in self.entries.iter() { - for view in entry.views.values() { - if let Ok(diff_editor) = view.clone().downcast::() { - diff_editor.update(cx, |diff_editor, cx| { - diff_editor - .set_text_style_refinement(diff_editor_text_style_refinement(cx)); - cx.notify(); - }) + match entry { + Entry::UserMessage { .. } => {} + Entry::Content(response_views) => { + for view in response_views.values() { + if let Ok(diff_editor) = view.clone().downcast::() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor.set_text_style_refinement( + diff_editor_text_style_refinement(cx), + ); + cx.notify(); + }) + } + } } } } } } -pub struct Entry { - views: HashMap, +impl EventEmitter for EntryViewState {} + +pub struct EntryViewEvent { + pub entry_index: usize, + pub view_event: ViewEvent, +} + +pub enum ViewEvent { + MessageEditorEvent(Entity, MessageEditorEvent), +} + +pub enum Entry { + UserMessage(Entity), + Content(HashMap), } impl Entry { - pub fn editor_for_diff(&self, diff: &Entity) -> Option> { - self.views + pub fn message_editor(&self) -> Option<&Entity> { + match self { + Self::UserMessage(editor) => Some(editor), + Entry::Content(_) => None, + } + } + + pub fn editor_for_diff(&self, diff: &Entity) -> Option> { + self.content_map()? .get(&diff.entity_id()) .cloned() .map(|entity| entity.downcast::().unwrap()) @@ -77,118 +198,88 @@ impl Entry { &self, terminal: &Entity, ) -> Option> { - self.views + self.content_map()? .get(&terminal.entity_id()) .cloned() .map(|entity| entity.downcast::().unwrap()) } - fn sync_diff_multibuffers( - &mut self, - thread: &Entity, - index: usize, - window: &mut Window, - cx: &mut App, - ) { - let Some(entry) = thread.read(cx).entries().get(index) else { - return; - }; - - let multibuffers = entry - .diffs() - .map(|diff| diff.read(cx).multibuffer().clone()); - - let multibuffers = multibuffers.collect::>(); - - for multibuffer in multibuffers { - if self.views.contains_key(&multibuffer.entity_id()) { - return; - } - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - editor - }); - - let entity_id = multibuffer.entity_id(); - self.views.insert(entity_id, editor.into_any()); + fn content_map(&self) -> Option<&HashMap> { + match self { + Self::Content(map) => Some(map), + _ => None, } } - fn sync_terminals( - &mut self, - workspace: &WeakEntity, - thread: &Entity, - index: usize, - window: &mut Window, - cx: &mut App, - ) { - let Some(entry) = thread.read(cx).entries().get(index) else { - return; - }; - - let terminals = entry - .terminals() - .map(|terminal| terminal.clone()) - .collect::>(); - - for terminal in terminals { - if self.views.contains_key(&terminal.entity_id()) { - return; - } - - let Some(strong_workspace) = workspace.upgrade() else { - return; - }; - - let terminal_view = cx.new(|cx| { - let mut view = TerminalView::new( - terminal.read(cx).inner().clone(), - workspace.clone(), - None, - strong_workspace.read(cx).project().downgrade(), - window, - cx, - ); - view.set_embedded_mode(Some(1000), cx); - view - }); - - let entity_id = terminal.entity_id(); - self.views.insert(entity_id, terminal_view.into_any()); - } + fn empty() -> Self { + Self::Content(HashMap::default()) } #[cfg(test)] - pub fn len(&self) -> usize { - self.views.len() + pub fn has_content(&self) -> bool { + match self { + Self::Content(map) => !map.is_empty(), + Self::UserMessage(_) => false, + } } } +fn create_terminal( + workspace: WeakEntity, + project: Entity, + terminal: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut view = TerminalView::new( + terminal.read(cx).inner().clone(), + workspace.clone(), + None, + project.downgrade(), + window, + cx, + ); + view.set_embedded_mode(Some(1000), cx); + view + }) +} + +fn create_editor_diff( + diff: Entity, + window: &mut Window, + cx: &mut App, +) -> Entity { + cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + diff.read(cx).multibuffer().clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + editor + }) +} + fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { TextStyleRefinement { font_size: Some( @@ -201,26 +292,20 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { } } -impl Default for Entry { - fn default() -> Self { - Self { - // Avoid allocating in the heap by default - views: HashMap::with_capacity(0), - } - } -} - #[cfg(test)] mod tests { use std::{path::Path, rc::Rc}; use acp_thread::{AgentConnection, StubAgentConnection}; + use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; use editor::{EditorSettings, RowInfo}; use fs::FakeFs; - use gpui::{SemanticVersion, TestAppContext}; + use gpui::{AppContext as _, SemanticVersion, TestAppContext}; + + use crate::acp::entry_view_state::EntryViewState; use multi_buffer::MultiBufferRow; use pretty_assertions::assert_matches; use project::Project; @@ -230,8 +315,6 @@ mod tests { use util::path; use workspace::Workspace; - use crate::acp::entry_view_state::EntryViewState; - #[gpui::test] async fn test_diff_sync(cx: &mut TestAppContext) { init_test(cx); @@ -269,7 +352,7 @@ mod tests { .update(|_, cx| { connection .clone() - .new_thread(project, Path::new(path!("/project")), cx) + .new_thread(project.clone(), Path::new(path!("/project")), cx) }) .await .unwrap(); @@ -279,12 +362,23 @@ mod tests { connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) }); - let mut view_state = EntryViewState::default(); - cx.update(|window, cx| { - view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx); + let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx)); + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + + let view_state = cx.new(|_cx| { + EntryViewState::new( + workspace.downgrade(), + project.clone(), + thread_store, + text_thread_store, + ) }); - let multibuffer = thread.read_with(cx, |thread, cx| { + view_state.update_in(cx, |view_state, window, cx| { + view_state.sync_entry(0, &thread, window, cx) + }); + + let diff = thread.read_with(cx, |thread, _cx| { thread .entries() .get(0) @@ -292,15 +386,14 @@ mod tests { .diffs() .next() .unwrap() - .read(cx) - .multibuffer() .clone() }); cx.run_until_parked(); - let entry = view_state.entry(0).unwrap(); - let diff_editor = entry.editor_for_diff(&multibuffer).unwrap(); + let diff_editor = view_state.read_with(cx, |view_state, _cx| { + view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap() + }); assert_eq!( diff_editor.read_with(cx, |editor, cx| editor.text(cx)), "hi world\nhello world" diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 32c37da519..90827e5514 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -52,9 +52,11 @@ pub struct MessageEditor { text_thread_store: Entity, } +#[derive(Clone, Copy)] pub enum MessageEditorEvent { Send, Cancel, + Focus, } impl EventEmitter for MessageEditor {} @@ -101,6 +103,11 @@ impl MessageEditor { editor }); + cx.on_focus(&editor.focus_handle(cx), window, |_, _, cx| { + cx.emit(MessageEditorEvent::Focus) + }) + .detach(); + Self { editor, project, @@ -386,11 +393,11 @@ impl MessageEditor { }); } - fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context) { + fn send(&mut self, _: &Chat, _: &mut Window, cx: &mut Context) { cx.emit(MessageEditorEvent::Send) } - fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + fn cancel(&mut self, _: &editor::actions::Cancel, _: &mut Window, cx: &mut Context) { cx.emit(MessageEditorEvent::Cancel) } @@ -496,6 +503,13 @@ impl MessageEditor { } } + pub fn set_read_only(&mut self, read_only: bool, cx: &mut Context) { + self.editor.update(cx, |message_editor, cx| { + message_editor.set_read_only(read_only); + cx.notify() + }) + } + fn insert_image( &mut self, excerpt_id: ExcerptId, @@ -572,6 +586,8 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { + self.clear(window, cx); + let mut text = String::new(); let mut mentions = Vec::new(); let mut images = Vec::new(); @@ -609,7 +625,6 @@ impl MessageEditor { editor.buffer().read(cx).snapshot(cx) }); - self.mention_set.clear(); for (range, mention_uri) in mentions { let anchor = snapshot.anchor_before(range.start); let crease_id = crate::context_picker::insert_crease_for_mention( @@ -679,6 +694,11 @@ impl MessageEditor { editor.set_text(text, window, cx); }); } + + #[cfg(test)] + pub fn text(&self, cx: &App) -> String { + self.editor.read(cx).text(cx) + } } impl Focusable for MessageEditor { @@ -691,7 +711,7 @@ impl Render for MessageEditor { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { div() .key_context("MessageEditor") - .on_action(cx.listener(Self::chat)) + .on_action(cx.listener(Self::send)) .on_action(cx.listener(Self::cancel)) .capture_action(cx.listener(Self::paste)) .flex_1() diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index cb1a62fd11..17341e4c8a 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -45,6 +45,7 @@ use zed_actions::assistant::OpenRulesLibrary; use super::entry_view_state::EntryViewState; use crate::acp::AcpModelSelectorPopover; +use crate::acp::entry_view_state::{EntryViewEvent, ViewEvent}; use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::agent_diff::AgentDiff; use crate::profile_selector::{ProfileProvider, ProfileSelector}; @@ -101,10 +102,8 @@ pub struct AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, - thread_store: Entity, - text_thread_store: Entity, thread_state: ThreadState, - entry_view_state: EntryViewState, + entry_view_state: Entity, message_editor: Entity, model_selector: Option>, profile_selector: Option>, @@ -120,16 +119,9 @@ pub struct AcpThreadView { plan_expanded: bool, editor_expanded: bool, terminal_expanded: bool, - editing_message: Option, + editing_message: Option, _cancel_task: Option>, - _subscriptions: [Subscription; 2], -} - -struct EditingMessage { - index: usize, - message_id: UserMessageId, - editor: Entity, - _subscription: Subscription, + _subscriptions: [Subscription; 3], } enum ThreadState { @@ -176,24 +168,32 @@ impl AcpThreadView { let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0)); + let entry_view_state = cx.new(|_| { + EntryViewState::new( + workspace.clone(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + ) + }); + let subscriptions = [ cx.observe_global_in::(window, Self::settings_changed), - cx.subscribe_in(&message_editor, window, Self::on_message_editor_event), + cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event), + cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event), ]; Self { agent: agent.clone(), workspace: workspace.clone(), project: project.clone(), - thread_store, - text_thread_store, + entry_view_state, thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, model_selector: None, profile_selector: None, notifications: Vec::new(), notification_subscriptions: HashMap::default(), - entry_view_state: EntryViewState::default(), list_state: list_state.clone(), scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), thread_error: None, @@ -414,7 +414,7 @@ impl AcpThreadView { cx.notify(); } - pub fn on_message_editor_event( + pub fn handle_message_editor_event( &mut self, _: &Entity, event: &MessageEditorEvent, @@ -424,6 +424,28 @@ impl AcpThreadView { match event { MessageEditorEvent::Send => self.send(window, cx), MessageEditorEvent::Cancel => self.cancel_generation(cx), + MessageEditorEvent::Focus => {} + } + } + + pub fn handle_entry_view_event( + &mut self, + _: &Entity, + event: &EntryViewEvent, + window: &mut Window, + cx: &mut Context, + ) { + match &event.view_event { + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => { + self.editing_message = Some(event.entry_index); + cx.notify(); + } + ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => { + self.regenerate(event.entry_index, editor, window, cx); + } + ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => { + self.cancel_editing(&Default::default(), window, cx); + } } } @@ -494,27 +516,56 @@ impl AcpThreadView { .detach(); } - fn cancel_editing(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context) { - self.editing_message.take(); - cx.notify(); - } - - fn regenerate(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { - let Some(editing_message) = self.editing_message.take() else { - return; - }; - + fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { let Some(thread) = self.thread().cloned() else { return; }; - let rewind = thread.update(cx, |thread, cx| { - thread.rewind(editing_message.message_id, cx) - }); + if let Some(index) = self.editing_message.take() { + if let Some(editor) = self + .entry_view_state + .read(cx) + .entry(index) + .and_then(|e| e.message_editor()) + .cloned() + { + editor.update(cx, |editor, cx| { + if let Some(user_message) = thread + .read(cx) + .entries() + .get(index) + .and_then(|e| e.user_message()) + { + editor.set_message(user_message.chunks.clone(), window, cx); + } + }) + } + }; + self.focus_handle(cx).focus(window); + cx.notify(); + } + + fn regenerate( + &mut self, + entry_ix: usize, + message_editor: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread().cloned() else { + return; + }; + + let Some(rewind) = thread.update(cx, |thread, cx| { + let user_message_id = thread.entries().get(entry_ix)?.user_message()?.id.clone()?; + Some(thread.rewind(user_message_id, cx)) + }) else { + return; + }; + + let contents = + message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx)); - let contents = editing_message - .editor - .update(cx, |message_editor, cx| message_editor.contents(window, cx)); let task = cx.foreground_executor().spawn(async move { rewind.await?; contents.await @@ -570,27 +621,20 @@ impl AcpThreadView { AcpThreadEvent::NewEntry => { let len = thread.read(cx).entries().len(); let index = len - 1; - self.entry_view_state.sync_entry( - self.workspace.clone(), - thread.clone(), - index, - window, - cx, - ); + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(index, &thread, window, cx) + }); self.list_state.splice(index..index, 1); } AcpThreadEvent::EntryUpdated(index) => { - self.entry_view_state.sync_entry( - self.workspace.clone(), - thread.clone(), - *index, - window, - cx, - ); + self.entry_view_state.update(cx, |view_state, cx| { + view_state.sync_entry(*index, &thread, window, cx) + }); self.list_state.splice(*index..index + 1, 1); } AcpThreadEvent::EntriesRemoved(range) => { - self.entry_view_state.remove(range.clone()); + self.entry_view_state + .update(cx, |view_state, _cx| view_state.remove(range.clone())); self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { @@ -722,29 +766,15 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .id("message") - .on_click(cx.listener({ - move |this, _, window, cx| { - this.start_editing_message(entry_ix, window, cx) - } - })) .children( - if let Some(editing) = self.editing_message.as_ref() - && Some(&editing.message_id) == message.id.as_ref() - { - Some( - self.render_edit_message_editor(editing, cx) - .into_any_element(), - ) - } else { - message.content.markdown().map(|md| { - self.render_markdown( - md.clone(), - user_message_markdown_style(window, cx), - ) - .into_any_element() - }) - }, + self.entry_view_state + .read(cx) + .entry(entry_ix) + .and_then(|entry| entry.message_editor()) + .map(|editor| { + self.render_sent_message_editor(entry_ix, editor, cx) + .into_any_element() + }), ), ) .into_any(), @@ -819,8 +849,8 @@ impl AcpThreadView { primary }; - if let Some(editing) = self.editing_message.as_ref() - && editing.index < entry_ix + if let Some(editing_index) = self.editing_message.as_ref() + && *editing_index < entry_ix { let backdrop = div() .id(("backdrop", entry_ix)) @@ -834,8 +864,8 @@ impl AcpThreadView { div() .relative() - .child(backdrop) .child(primary) + .child(backdrop) .into_any_element() } else { primary @@ -1256,9 +1286,7 @@ impl AcpThreadView { Empty.into_any_element() } } - ToolCallContent::Diff(diff) => { - self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx) - } + ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, &diff, cx), ToolCallContent::Terminal(terminal) => { self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } @@ -1405,7 +1433,7 @@ impl AcpThreadView { fn render_diff_editor( &self, entry_ix: usize, - multibuffer: &Entity, + diff: &Entity, cx: &Context, ) -> AnyElement { v_flex() @@ -1413,8 +1441,8 @@ impl AcpThreadView { .border_t_1() .border_color(self.tool_card_border_color(cx)) .child( - if let Some(entry) = self.entry_view_state.entry(entry_ix) - && let Some(editor) = entry.editor_for_diff(&multibuffer) + if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix) + && let Some(editor) = entry.editor_for_diff(&diff) { editor.clone().into_any_element() } else { @@ -1617,6 +1645,7 @@ impl AcpThreadView { let terminal_view = self .entry_view_state + .read(cx) .entry(entry_ix) .and_then(|entry| entry.terminal(&terminal)); let show_output = self.terminal_expanded && terminal_view.is_some(); @@ -2485,82 +2514,38 @@ impl AcpThreadView { ) } - fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context) { - let Some(thread) = self.thread() else { - return; - }; - let Some(AgentThreadEntry::UserMessage(message)) = thread.read(cx).entries().get(index) - else { - return; - }; - let Some(message_id) = message.id.clone() else { - return; - }; - - self.list_state.scroll_to_reveal_item(index); - - let chunks = message.chunks.clone(); - let editor = cx.new(|cx| { - let mut editor = MessageEditor::new( - self.workspace.clone(), - self.project.clone(), - self.thread_store.clone(), - self.text_thread_store.clone(), - editor::EditorMode::AutoHeight { - min_lines: 1, - max_lines: None, - }, - window, - cx, - ); - editor.set_message(chunks, window, cx); - editor - }); - let subscription = - cx.subscribe_in(&editor, window, |this, _, event, window, cx| match event { - MessageEditorEvent::Send => { - this.regenerate(&Default::default(), window, cx); - } - MessageEditorEvent::Cancel => { - this.cancel_editing(&Default::default(), window, cx); - } - }); - editor.focus_handle(cx).focus(window); - - self.editing_message.replace(EditingMessage { - index: index, - message_id: message_id.clone(), - editor, - _subscription: subscription, - }); - cx.notify(); - } - - fn render_edit_message_editor(&self, editing: &EditingMessage, cx: &Context) -> Div { - v_flex() - .w_full() - .gap_2() - .child(editing.editor.clone()) - .child( - h_flex() - .gap_1() - .child( - Icon::new(IconName::Warning) - .color(Color::Warning) - .size(IconSize::XSmall), - ) - .child( - Label::new("Editing will restart the thread from this point.") - .color(Color::Muted) - .size(LabelSize::XSmall), - ) - .child(self.render_editing_message_editor_buttons(editing, cx)), - ) - } - - fn render_editing_message_editor_buttons( + fn render_sent_message_editor( &self, - editing: &EditingMessage, + entry_ix: usize, + editor: &Entity, + cx: &Context, + ) -> Div { + v_flex().w_full().gap_2().child(editor.clone()).when( + self.editing_message == Some(entry_ix), + |el| { + el.child( + h_flex() + .gap_1() + .child( + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::XSmall), + ) + .child( + Label::new("Editing will restart the thread from this point.") + .color(Color::Muted) + .size(LabelSize::XSmall), + ) + .child(self.render_sent_message_editor_buttons(entry_ix, editor, cx)), + ) + }, + ) + } + + fn render_sent_message_editor_buttons( + &self, + entry_ix: usize, + editor: &Entity, cx: &Context, ) -> Div { h_flex() @@ -2573,7 +2558,7 @@ impl AcpThreadView { .icon_color(Color::Error) .icon_size(IconSize::Small) .tooltip({ - let focus_handle = editing.editor.focus_handle(cx); + let focus_handle = editor.focus_handle(cx); move |window, cx| { Tooltip::for_action_in( "Cancel Edit", @@ -2588,12 +2573,12 @@ impl AcpThreadView { ) .child( IconButton::new("confirm-edit-message", IconName::Return) - .disabled(editing.editor.read(cx).is_empty(cx)) + .disabled(editor.read(cx).is_empty(cx)) .shape(ui::IconButtonShape::Square) .icon_color(Color::Muted) .icon_size(IconSize::Small) .tooltip({ - let focus_handle = editing.editor.focus_handle(cx); + let focus_handle = editor.focus_handle(cx); move |window, cx| { Tooltip::for_action_in( "Regenerate", @@ -2604,7 +2589,12 @@ impl AcpThreadView { ) } }) - .on_click(cx.listener(Self::regenerate)), + .on_click(cx.listener({ + let editor = editor.clone(); + move |this, _, window, cx| { + this.regenerate(entry_ix, &editor, window, cx); + } + })), ) } @@ -3137,7 +3127,9 @@ impl AcpThreadView { } fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { - self.entry_view_state.settings_changed(cx); + self.entry_view_state.update(cx, |entry_view_state, cx| { + entry_view_state.settings_changed(cx); + }); } pub(crate) fn insert_dragged_files( @@ -3152,9 +3144,7 @@ impl AcpThreadView { drop(added_worktrees); }) } -} -impl AcpThreadView { fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option
{ let content = match self.thread_error.as_ref()? { ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), @@ -3439,35 +3429,6 @@ impl Render for AcpThreadView { } } -fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let mut style = default_markdown_style(false, window, cx); - let mut text_style = window.text_style(); - let theme_settings = ThemeSettings::get_global(cx); - - let buffer_font = theme_settings.buffer_font.family.clone(); - let buffer_font_size = TextSize::Small.rems(cx); - - text_style.refine(&TextStyleRefinement { - font_family: Some(buffer_font), - font_size: Some(buffer_font_size.into()), - ..Default::default() - }); - - style.base_text_style = text_style; - style.link_callback = Some(Rc::new(move |url, cx| { - if MentionUri::parse(url).is_ok() { - let colors = cx.theme().colors(); - Some(TextStyleRefinement { - background_color: Some(colors.element_background), - ..Default::default() - }) - } else { - None - } - })); - style -} - fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); @@ -3626,12 +3587,13 @@ pub(crate) mod tests { use agent_client_protocol::SessionId; use editor::EditorSettings; use fs::FakeFs; - use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; + use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext}; use project::Project; use serde_json::json; use settings::SettingsStore; use std::any::Any; use std::path::Path; + use workspace::Item; use super::*; @@ -3778,6 +3740,50 @@ pub(crate) mod tests { (thread_view, cx) } + fn add_to_workspace(thread_view: Entity, cx: &mut VisualTestContext) { + let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone()); + + workspace + .update_in(cx, |workspace, window, cx| { + workspace.add_item_to_active_pane( + Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))), + None, + true, + window, + cx, + ); + }) + .unwrap(); + } + + struct ThreadViewItem(Entity); + + impl Item for ThreadViewItem { + type Event = (); + + fn include_in_nav_history() -> bool { + false + } + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Test".into() + } + } + + impl EventEmitter<()> for ThreadViewItem {} + + impl Focusable for ThreadViewItem { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.0.read(cx).focus_handle(cx).clone() + } + } + + impl Render for ThreadViewItem { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.0.clone().into_any_element() + } + } + struct StubAgentServer { connection: C, } @@ -3799,19 +3805,19 @@ pub(crate) mod tests { C: 'static + AgentConnection + Send + Clone, { fn logo(&self) -> ui::IconName { - unimplemented!() + ui::IconName::Ai } fn name(&self) -> &'static str { - unimplemented!() + "Test" } fn empty_state_headline(&self) -> &'static str { - unimplemented!() + "Test" } fn empty_state_message(&self) -> &'static str { - unimplemented!() + "Test" } fn connect( @@ -3960,9 +3966,17 @@ pub(crate) mod tests { assert_eq!(thread.entries().len(), 2); }); - thread_view.read_with(cx, |view, _| { - assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + }); }); // Second user message @@ -3991,18 +4005,31 @@ pub(crate) mod tests { let second_user_message_id = thread.read_with(cx, |thread, _| { assert_eq!(thread.entries().len(), 4); - let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap() - else { + let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else { panic!(); }; user_message.id.clone().unwrap() }); - thread_view.read_with(cx, |view, _| { - assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); - assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1); + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); + assert!( + entry_view_state + .entry(2) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(3).unwrap().has_content()); + }); }); // Rewind to first message @@ -4017,13 +4044,169 @@ pub(crate) mod tests { assert_eq!(thread.entries().len(), 2); }); - thread_view.read_with(cx, |view, _| { - assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); - assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + thread_view.read_with(cx, |view, cx| { + view.entry_view_state.read_with(cx, |entry_view_state, _| { + assert!( + entry_view_state + .entry(0) + .unwrap() + .message_editor() + .is_some() + ); + assert!(entry_view_state.entry(1).unwrap().has_content()); - // Old views should be dropped - assert!(view.entry_view_state.entry(2).is_none()); - assert!(view.entry_view_state.entry(3).is_none()); + // Old views should be dropped + assert!(entry_view_state.entry(2).is_none()); + assert!(entry_view_state.entry(3).is_none()); + }); }); } + + #[gpui::test] + async fn test_message_editing_cancel(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, Some(0)); + }); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Cancel + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(editor::actions::Cancel), cx); + }); + + thread_view.read_with(cx, |view, _cx| { + assert_eq!(view.editing_message, None); + }); + + user_message_editor.read_with(cx, |editor, cx| { + assert_eq!(editor.text(cx), "Original message to edit"); + }); + } + + #[gpui::test] + async fn test_message_editing_regenerate(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "Response".into(), + annotations: None, + }), + }]); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Original message to edit", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + let user_message_editor = thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2); + + view.entry_view_state + .read(cx) + .entry(0) + .unwrap() + .message_editor() + .unwrap() + .clone() + }); + + // Focus + cx.focus(&user_message_editor); + + // Edit + user_message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Edited message content", window, cx); + }); + + // Send + connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text(acp::TextContent { + text: "New Response".into(), + annotations: None, + }), + }]); + + user_message_editor.update_in(cx, |_editor, window, cx| { + window.dispatch_action(Box::new(Chat), cx); + }); + + cx.run_until_parked(); + + thread_view.read_with(cx, |view, cx| { + assert_eq!(view.editing_message, None); + + let entries = view.thread().unwrap().read(cx).entries(); + assert_eq!(entries.len(), 2); + assert_eq!( + entries[0].to_markdown(cx), + "## User\n\nEdited message content\n\n" + ); + assert_eq!( + entries[1].to_markdown(cx), + "## Assistant\n\nNew Response\n\n" + ); + + let new_editor = view.entry_view_state.read_with(cx, |state, _cx| { + assert!(!state.entry(1).unwrap().has_content()); + state.entry(0).unwrap().message_editor().unwrap().clone() + }); + + assert_eq!(new_editor.read(cx).text(cx), "Edited message content"); + }) + } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 73915195f5..519f7980ff 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -818,12 +818,10 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } - ActiveView::ExternalAgentThread { thread_view, .. } => { - thread_view.update(cx, |thread_element, cx| { - thread_element.cancel_generation(cx) - }); - } - ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} + ActiveView::ExternalAgentThread { .. } + | ActiveView::TextThread { .. } + | ActiveView::History + | ActiveView::Configuration => {} } } From f3654036189ba5ca414f9827aee52c0a9f7e95d9 Mon Sep 17 00:00:00 2001 From: Yang Gang Date: Sat, 16 Aug 2025 05:03:50 +0800 Subject: [PATCH 14/17] agent: Update use_modifier_to_send behavior description for Windows (#36230) Release Notes: - N/A Signed-off-by: Yang Gang --- crates/agent_settings/src/agent_settings.rs | 2 +- crates/agent_ui/src/agent_configuration.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index d9557c5d00..fd38ba1f7f 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -309,7 +309,7 @@ pub struct AgentSettingsContent { /// /// Default: true expand_terminal_card: Option, - /// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel. + /// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel. /// /// Default: false use_modifier_to_send: Option, diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 5f72fa58c8..96558f1bea 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -465,7 +465,7 @@ impl AgentConfiguration { "modifier-send", "Use modifier to submit a message", Some( - "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(), + "Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(), ), use_modifier_to_send, move |state, _window, cx| { From 3d77ad7e1a8a7afe068aac600d2ab56225fe1fed Mon Sep 17 00:00:00 2001 From: Cole Miller Date: Fri, 15 Aug 2025 17:39:33 -0400 Subject: [PATCH 15/17] thread_view: Start loading images as soon as they're added (#36276) Release Notes: - N/A --- .../agent_ui/src/acp/completion_provider.rs | 129 +++------- crates/agent_ui/src/acp/message_editor.rs | 229 +++++++++++------- 2 files changed, 176 insertions(+), 182 deletions(-) diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index d7d2cd5d0e..1a9861d13a 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -1,20 +1,17 @@ -use std::ffi::OsStr; use std::ops::Range; -use std::path::Path; +use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; use acp_thread::MentionUri; use anyhow::{Context as _, Result, anyhow}; -use collections::HashMap; +use collections::{HashMap, HashSet}; use editor::display_map::CreaseId; use editor::{CompletionProvider, Editor, ExcerptId}; use futures::future::{Shared, try_join_all}; use fuzzy::{StringMatch, StringMatchCandidate}; -use gpui::{App, Entity, ImageFormat, Img, Task, WeakEntity}; -use http_client::HttpClientWithUrl; +use gpui::{App, Entity, ImageFormat, Task, WeakEntity}; use language::{Buffer, CodeLabel, HighlightId}; -use language_model::LanguageModelImage; use lsp::CompletionContext; use project::{ Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, Symbol, WorktreeId, @@ -43,7 +40,7 @@ use crate::context_picker::{ #[derive(Clone, Debug, Eq, PartialEq)] pub struct MentionImage { - pub abs_path: Option>, + pub abs_path: Option, pub data: SharedString, pub format: ImageFormat, } @@ -88,6 +85,8 @@ impl MentionSet { window: &mut Window, cx: &mut App, ) -> Task>> { + let mut processed_image_creases = HashSet::default(); + let mut contents = self .uri_by_crease_id .iter() @@ -97,59 +96,27 @@ impl MentionSet { // TODO directories let uri = uri.clone(); let abs_path = abs_path.to_path_buf(); - let extension = abs_path.extension().and_then(OsStr::to_str).unwrap_or(""); - if Img::extensions().contains(&extension) && !extension.contains("svg") { - let open_image_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(&abs_path, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_image(path, cx)) + if let Some(task) = self.images.get(&crease_id).cloned() { + processed_image_creases.insert(crease_id); + return cx.spawn(async move |_| { + let image = task.await.map_err(|e| anyhow!("{e}"))?; + anyhow::Ok((crease_id, Mention::Image(image))) }); - - cx.spawn(async move |cx| { - let image_item = open_image_task?.await?; - let (data, format) = image_item.update(cx, |image_item, cx| { - let format = image_item.image.format; - ( - LanguageModelImage::from_image( - image_item.image.clone(), - cx, - ), - format, - ) - })?; - let data = cx.spawn(async move |_| { - if let Some(data) = data.await { - Ok(data.source) - } else { - anyhow::bail!("Failed to convert image") - } - }); - - anyhow::Ok(( - crease_id, - Mention::Image(MentionImage { - abs_path: Some(abs_path.as_path().into()), - data: data.await?, - format, - }), - )) - }) - } else { - let buffer_task = project.update(cx, |project, cx| { - let path = project - .find_project_path(abs_path, cx) - .context("Failed to find project path")?; - anyhow::Ok(project.open_buffer(path, cx)) - }); - cx.spawn(async move |cx| { - let buffer = buffer_task?.await?; - let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; - - anyhow::Ok((crease_id, Mention::Text { uri, content })) - }) } + + let buffer_task = project.update(cx, |project, cx| { + let path = project + .find_project_path(abs_path, cx) + .context("Failed to find project path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + cx.spawn(async move |cx| { + let buffer = buffer_task?.await?; + let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + anyhow::Ok((crease_id, Mention::Text { uri, content })) + }) } MentionUri::Symbol { path, line_range, .. @@ -243,15 +210,19 @@ impl MentionSet { }) .collect::>(); - contents.extend(self.images.iter().map(|(crease_id, image)| { + // Handle images that didn't have a mention URI (because they were added by the paste handler). + contents.extend(self.images.iter().filter_map(|(crease_id, image)| { + if processed_image_creases.contains(crease_id) { + return None; + } let crease_id = *crease_id; let image = image.clone(); - cx.spawn(async move |_| { + Some(cx.spawn(async move |_| { Ok(( crease_id, Mention::Image(image.await.map_err(|e| anyhow::anyhow!("{e}"))?), )) - }) + })) })); cx.spawn(async move |_cx| { @@ -753,7 +724,6 @@ impl ContextPickerCompletionProvider { source_range: Range, url_to_fetch: SharedString, message_editor: WeakEntity, - http_client: Arc, cx: &mut App, ) -> Option { let new_text = format!("@fetch {} ", url_to_fetch.clone()); @@ -772,30 +742,13 @@ impl ContextPickerCompletionProvider { source: project::CompletionSource::Custom, icon_path: Some(icon_path.clone()), insert_text_mode: None, - confirm: Some({ - Arc::new(move |_, window, cx| { - let url_to_fetch = url_to_fetch.clone(); - let source_range = source_range.clone(); - let message_editor = message_editor.clone(); - let new_text = new_text.clone(); - let http_client = http_client.clone(); - window.defer(cx, move |window, cx| { - message_editor - .update(cx, |message_editor, cx| { - message_editor.confirm_mention_for_fetch( - new_text, - source_range, - url_to_fetch, - http_client, - window, - cx, - ) - }) - .ok(); - }); - false - }) - }), + confirm: Some(confirm_completion_callback( + url_to_fetch.to_string().into(), + source_range.start, + new_text.len() - 1, + message_editor, + mention_uri, + )), }) } } @@ -843,7 +796,6 @@ impl CompletionProvider for ContextPickerCompletionProvider { }; let project = workspace.read(cx).project().clone(); - let http_client = workspace.read(cx).client().http_client(); let snapshot = buffer.read(cx).snapshot(); let source_range = snapshot.anchor_before(state.source_range.start) ..snapshot.anchor_after(state.source_range.end); @@ -852,8 +804,8 @@ impl CompletionProvider for ContextPickerCompletionProvider { let text_thread_store = self.text_thread_store.clone(); let editor = self.message_editor.clone(); let Ok((exclude_paths, exclude_threads)) = - self.message_editor.update(cx, |message_editor, cx| { - message_editor.mentioned_path_and_threads(cx) + self.message_editor.update(cx, |message_editor, _cx| { + message_editor.mentioned_path_and_threads() }) else { return Task::ready(Ok(Vec::new())); @@ -942,7 +894,6 @@ impl CompletionProvider for ContextPickerCompletionProvider { source_range.clone(), url, editor.clone(), - http_client.clone(), cx, ), diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index 90827e5514..a4d74db266 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -16,14 +16,14 @@ use editor::{ use futures::{FutureExt as _, TryFutureExt as _}; use gpui::{ AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, Image, - ImageFormat, Task, TextStyle, WeakEntity, + ImageFormat, Img, Task, TextStyle, WeakEntity, }; -use http_client::HttpClientWithUrl; use language::{Buffer, Language}; use language_model::LanguageModelImage; use project::{CompletionIntent, Project}; use settings::Settings; use std::{ + ffi::OsStr, fmt::Write, ops::Range, path::{Path, PathBuf}, @@ -48,6 +48,7 @@ pub struct MessageEditor { mention_set: MentionSet, editor: Entity, project: Entity, + workspace: WeakEntity, thread_store: Entity, text_thread_store: Entity, } @@ -79,7 +80,7 @@ impl MessageEditor { None, ); let completion_provider = ContextPickerCompletionProvider::new( - workspace, + workspace.clone(), thread_store.downgrade(), text_thread_store.downgrade(), cx.weak_entity(), @@ -114,6 +115,7 @@ impl MessageEditor { mention_set, thread_store, text_thread_store, + workspace, } } @@ -131,7 +133,7 @@ impl MessageEditor { self.editor.read(cx).is_empty(cx) } - pub fn mentioned_path_and_threads(&self, _: &App) -> (HashSet, HashSet) { + pub fn mentioned_path_and_threads(&self) -> (HashSet, HashSet) { let mut excluded_paths = HashSet::default(); let mut excluded_threads = HashSet::default(); @@ -165,8 +167,14 @@ impl MessageEditor { let Some((excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { return; }; + let Some(anchor) = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, start) + else { + return; + }; - if let Some(crease_id) = crate::context_picker::insert_crease_for_mention( + let Some(crease_id) = crate::context_picker::insert_crease_for_mention( *excerpt_id, start, content_len, @@ -175,48 +183,83 @@ impl MessageEditor { self.editor.clone(), window, cx, - ) { - self.mention_set.insert_uri(crease_id, mention_uri.clone()); - } - } - - pub fn confirm_mention_for_fetch( - &mut self, - new_text: String, - source_range: Range, - url: url::Url, - http_client: Arc, - window: &mut Window, - cx: &mut Context, - ) { - let mention_uri = MentionUri::Fetch { url: url.clone() }; - let icon_path = mention_uri.icon_path(cx); - - let start = source_range.start; - let content_len = new_text.len() - 1; - - let snapshot = self - .editor - .update(cx, |editor, cx| editor.snapshot(window, cx)); - let Some((&excerpt_id, _, _)) = snapshot.buffer_snapshot.as_singleton() else { - return; - }; - - let Some(crease_id) = crate::context_picker::insert_crease_for_mention( - excerpt_id, - start, - content_len, - url.to_string().into(), - icon_path, - self.editor.clone(), - window, - cx, ) else { return; }; + self.mention_set.insert_uri(crease_id, mention_uri.clone()); - let http_client = http_client.clone(); - let source_range = source_range.clone(); + match mention_uri { + MentionUri::Fetch { url } => { + self.confirm_mention_for_fetch(crease_id, anchor, url, window, cx); + } + MentionUri::File { + abs_path, + is_directory, + } => { + self.confirm_mention_for_file( + crease_id, + anchor, + abs_path, + is_directory, + window, + cx, + ); + } + MentionUri::Symbol { .. } + | MentionUri::Thread { .. } + | MentionUri::TextThread { .. } + | MentionUri::Rule { .. } + | MentionUri::Selection { .. } => {} + } + } + + fn confirm_mention_for_file( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + abs_path: PathBuf, + _is_directory: bool, + window: &mut Window, + cx: &mut Context, + ) { + let extension = abs_path + .extension() + .and_then(OsStr::to_str) + .unwrap_or_default(); + let project = self.project.clone(); + let Some(project_path) = project + .read(cx) + .project_path_for_absolute_path(&abs_path, cx) + else { + return; + }; + + if Img::extensions().contains(&extension) && !extension.contains("svg") { + let image = cx.spawn(async move |_, cx| { + let image = project + .update(cx, |project, cx| project.open_image(project_path, cx))? + .await?; + image.read_with(cx, |image, _cx| image.image.clone()) + }); + self.confirm_mention_for_image(crease_id, anchor, Some(abs_path), image, window, cx); + } + } + + fn confirm_mention_for_fetch( + &mut self, + crease_id: CreaseId, + anchor: Anchor, + url: url::Url, + window: &mut Window, + cx: &mut Context, + ) { + let Some(http_client) = self + .workspace + .update(cx, |workspace, _cx| workspace.client().http_client()) + .ok() + else { + return; + }; let url_string = url.to_string(); let fetch = cx @@ -227,22 +270,18 @@ impl MessageEditor { .await }) .shared(); - self.mention_set.add_fetch_result(url, fetch.clone()); + self.mention_set + .add_fetch_result(url.clone(), fetch.clone()); cx.spawn_in(window, async move |this, cx| { let fetch = fetch.await.notify_async_err(cx); this.update(cx, |this, cx| { + let mention_uri = MentionUri::Fetch { url }; if fetch.is_some() { this.mention_set.insert_uri(crease_id, mention_uri.clone()); } else { // Remove crease if we failed to fetch this.editor.update(cx, |editor, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let Some(anchor) = - snapshot.anchor_in_excerpt(excerpt_id, source_range.start) - else { - return; - }; editor.display_map.update(cx, |display_map, cx| { display_map.unfold_intersecting(vec![anchor..anchor], true, cx); }); @@ -424,27 +463,46 @@ impl MessageEditor { let replacement_text = "image"; for image in images { - let (excerpt_id, anchor) = self.editor.update(cx, |message_editor, cx| { - let snapshot = message_editor.snapshot(window, cx); - let (excerpt_id, _, snapshot) = snapshot.buffer_snapshot.as_singleton().unwrap(); + let (excerpt_id, text_anchor, multibuffer_anchor) = + self.editor.update(cx, |message_editor, cx| { + let snapshot = message_editor.snapshot(window, cx); + let (excerpt_id, _, buffer_snapshot) = + snapshot.buffer_snapshot.as_singleton().unwrap(); - let anchor = snapshot.anchor_before(snapshot.len()); - message_editor.edit( - [( - multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), - format!("{replacement_text} "), - )], - cx, - ); - (*excerpt_id, anchor) - }); + let text_anchor = buffer_snapshot.anchor_before(buffer_snapshot.len()); + let multibuffer_anchor = snapshot + .buffer_snapshot + .anchor_in_excerpt(*excerpt_id, text_anchor); + message_editor.edit( + [( + multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), + format!("{replacement_text} "), + )], + cx, + ); + (*excerpt_id, text_anchor, multibuffer_anchor) + }); - self.insert_image( + let content_len = replacement_text.len(); + let Some(anchor) = multibuffer_anchor else { + return; + }; + let Some(crease_id) = insert_crease_for_image( excerpt_id, + text_anchor, + content_len, + None.clone(), + self.editor.clone(), + window, + cx, + ) else { + return; + }; + self.confirm_mention_for_image( + crease_id, anchor, - replacement_text.len(), - Arc::new(image), None, + Task::ready(Ok(Arc::new(image))), window, cx, ); @@ -510,34 +568,25 @@ impl MessageEditor { }) } - fn insert_image( + fn confirm_mention_for_image( &mut self, - excerpt_id: ExcerptId, - crease_start: text::Anchor, - content_len: usize, - image: Arc, - abs_path: Option>, + crease_id: CreaseId, + anchor: Anchor, + abs_path: Option, + image: Task>>, window: &mut Window, cx: &mut Context, ) { - let Some(crease_id) = insert_crease_for_image( - excerpt_id, - crease_start, - content_len, - abs_path.clone(), - self.editor.clone(), - window, - cx, - ) else { - return; - }; self.editor.update(cx, |_editor, cx| { - let format = image.format; - let convert = LanguageModelImage::from_image(image, cx); - let task = cx .spawn_in(window, async move |editor, cx| { - if let Some(image) = convert.await { + let image = image.await.map_err(|e| e.to_string())?; + let format = image.format; + let image = cx + .update(|_, cx| LanguageModelImage::from_image(image, cx)) + .map_err(|e| e.to_string())? + .await; + if let Some(image) = image { Ok(MentionImage { abs_path, data: image.source, @@ -546,12 +595,6 @@ impl MessageEditor { } else { editor .update(cx, |editor, cx| { - let snapshot = editor.buffer().read(cx).snapshot(cx); - let Some(anchor) = - snapshot.anchor_in_excerpt(excerpt_id, crease_start) - else { - return; - }; editor.display_map.update(cx, |display_map, cx| { display_map.unfold_intersecting(vec![anchor..anchor], true, cx); }); From f642f7615f876f56b1cb5bad90c9ee2bbf574bf0 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Fri, 15 Aug 2025 16:59:57 -0500 Subject: [PATCH 16/17] keymap_ui: Don't try to parse empty action arguments as JSON (#36278) Closes #ISSUE Release Notes: - Keymap Editor: Fixed an issue where leaving the arguments field empty would result in an error even if arguments were optional --- crates/settings_ui/src/keybindings.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/settings_ui/src/keybindings.rs b/crates/settings_ui/src/keybindings.rs index 1aaab211aa..b4e871c617 100644 --- a/crates/settings_ui/src/keybindings.rs +++ b/crates/settings_ui/src/keybindings.rs @@ -2181,6 +2181,7 @@ impl KeybindingEditorModal { let value = action_arguments .as_ref() + .filter(|args| !args.is_empty()) .map(|args| { serde_json::from_str(args).context("Failed to parse action arguments as JSON") }) From b9c110e63e02eea44cde2c1e24d6d332e2a6f0ee Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 15 Aug 2025 18:01:41 -0400 Subject: [PATCH 17/17] collab: Remove `GET /users/look_up` endpoint (#36279) This PR removes the `GET /users/look_up` endpoint from Collab, as it has been moved to Cloud. Release Notes: - N/A --- crates/collab/src/api.rs | 101 +-------------------------------------- 1 file changed, 1 insertion(+), 100 deletions(-) diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 143e764eb3..0cc7e2b2e9 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -4,12 +4,7 @@ pub mod extensions; pub mod ips_file; pub mod slack; -use crate::db::Database; -use crate::{ - AppState, Error, Result, auth, - db::{User, UserId}, - rpc, -}; +use crate::{AppState, Error, Result, auth, db::UserId, rpc}; use anyhow::Context as _; use axum::{ Extension, Json, Router, @@ -96,7 +91,6 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() - .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(contributors::router()) @@ -138,99 +132,6 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR Ok::<_, Error>(next.run(req).await) } -#[derive(Debug, Deserialize)] -struct LookUpUserParams { - identifier: String, -} - -#[derive(Debug, Serialize)] -struct LookUpUserResponse { - user: Option, -} - -async fn look_up_user( - Query(params): Query, - Extension(app): Extension>, -) -> Result> { - let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?; - let user = if let Some(user) = user { - match user { - UserOrId::User(user) => Some(user), - UserOrId::Id(id) => app.db.get_user_by_id(id).await?, - } - } else { - None - }; - - Ok(Json(LookUpUserResponse { user })) -} - -enum UserOrId { - User(User), - Id(UserId), -} - -async fn resolve_identifier_to_user( - db: &Arc, - identifier: &str, -) -> Result> { - if let Some(identifier) = identifier.parse::().ok() { - let user = db.get_user_by_id(UserId(identifier)).await?; - - return Ok(user.map(UserOrId::User)); - } - - if identifier.starts_with("cus_") { - let billing_customer = db - .get_billing_customer_by_stripe_customer_id(&identifier) - .await?; - - return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))); - } - - if identifier.starts_with("sub_") { - let billing_subscription = db - .get_billing_subscription_by_stripe_subscription_id(&identifier) - .await?; - - if let Some(billing_subscription) = billing_subscription { - let billing_customer = db - .get_billing_customer_by_id(billing_subscription.billing_customer_id) - .await?; - - return Ok( - billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)) - ); - } else { - return Ok(None); - } - } - - if identifier.contains('@') { - let user = db.get_user_by_email(identifier).await?; - - return Ok(user.map(UserOrId::User)); - } - - if let Some(user) = db.get_user_by_github_login(identifier).await? { - return Ok(Some(UserOrId::User(user))); - } - - Ok(None) -} - -#[derive(Deserialize, Debug)] -struct CreateUserParams { - github_user_id: i32, - github_login: String, - email_address: String, - email_confirmation_code: Option, - #[serde(default)] - admin: bool, - #[serde(default)] - invite_count: i32, -} - async fn get_rpc_server_snapshot( Extension(rpc_server): Extension>, ) -> Result {