From 50482a6bc2c5274634f9f1f2446b05e425a142ad Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 7 Aug 2025 19:00:45 -0400 Subject: [PATCH 1/6] language_model: Refresh the LLM token upon receiving a `UserUpdated` message from Cloud (#35839) This PR makes it so we refresh the LLM token upon receiving a `UserUpdated` message from Cloud over the WebSocket connection. Release Notes: - N/A --- Cargo.lock | 1 + crates/client/src/client.rs | 4 +-- crates/language_model/Cargo.toml | 1 + .../language_model/src/model/cloud_model.rs | 34 +++++++++---------- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8c1f1d00ba..39ee75f6dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9127,6 +9127,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_api_types", "cloud_llm_client", "collections", "futures 0.3.31", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 12ea4bcd3e..f09c012a85 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -193,7 +193,7 @@ pub fn init(client: &Arc, cx: &mut App) { }); } -pub type MessageToClientHandler = Box; +pub type MessageToClientHandler = Box; struct GlobalClient(Arc); @@ -1684,7 +1684,7 @@ impl Client { pub fn add_message_to_client_handler( self: &Arc, - handler: impl Fn(&MessageToClient, &App) + Send + Sync + 'static, + handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static, ) { self.message_to_client_handlers .lock() diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 841be60b0e..f9920623b5 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 8ae5893410..3b4c1fa269 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,11 +3,9 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::Plan; -use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, -}; -use proto::TypedEnvelope; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -82,9 +80,7 @@ impl Global for GlobalRefreshLlmTokenListener {} pub struct RefreshLlmTokenEvent; -pub struct RefreshLlmTokenListener { - _llm_token_subscription: client::Subscription, -} +pub struct RefreshLlmTokenListener; impl EventEmitter for RefreshLlmTokenListener {} @@ -99,17 +95,21 @@ impl RefreshLlmTokenListener { } fn new(client: Arc, cx: &mut Context) -> Self { - Self { - _llm_token_subscription: client - .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), - } + client.add_message_to_client_handler({ + let this = cx.entity(); + move |message, cx| { + Self::handle_refresh_llm_token(this.clone(), message, cx); + } + }); + + Self } - async fn handle_refresh_llm_token( - this: Entity, - _: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) + fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { + match message { + MessageToClient::UserUpdated => { + this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)); + } + } } } From 913e9adf90aef94071059a5ed2c7f635c0a37e92 Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Thu, 7 Aug 2025 16:07:33 -0700 Subject: [PATCH 2/6] Move timing fields into span (#35833) Release Notes: - N/A --- crates/collab/src/rpc.rs | 256 +++++++++++++++++++++++---------------- crates/rpc/src/peer.rs | 15 --- 2 files changed, 149 insertions(+), 122 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index b3603d2619..ec1105b138 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -41,9 +41,11 @@ use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; +use futures::TryFutureExt as _; use reqwest_client::ReqwestClient; use rpc::proto::{MultiLspQuery, split_repository_update}; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; +use tracing::Span; use futures::{ FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture, @@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512; static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0); +const TOTAL_DURATION_MS: &str = "total_duration_ms"; +const PROCESSING_DURATION_MS: &str = "processing_duration_ms"; +const QUEUE_DURATION_MS: &str = "queue_duration_ms"; +const HOST_WAITING_MS: &str = "host_waiting_ms"; + type MessageHandler = - Box, Session) -> BoxFuture<'static, ()>>; + Box, Session, Span) -> BoxFuture<'static, ()>>; pub struct ConnectionGuard; @@ -163,6 +170,42 @@ impl Principal { } } +#[derive(Clone)] +struct MessageContext { + session: Session, + span: tracing::Span, +} + +impl Deref for MessageContext { + type Target = Session; + + fn deref(&self) -> &Self::Target { + &self.session + } +} + +impl MessageContext { + pub fn forward_request( + &self, + receiver_id: ConnectionId, + request: T, + ) -> impl Future> { + let request_start_time = Instant::now(); + let span = self.span.clone(); + tracing::info!("start forwarding request"); + self.peer + .forward_request(self.connection_id, receiver_id, request) + .inspect(move |_| { + span.record( + HOST_WAITING_MS, + request_start_time.elapsed().as_micros() as f64 / 1000.0, + ); + }) + .inspect_err(|_| tracing::error!("error forwarding request")) + .inspect_ok(|_| tracing::info!("finished forwarding request")) + } +} + #[derive(Clone)] struct Session { principal: Principal, @@ -646,40 +689,37 @@ impl Server { fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, + F: 'static + Send + Sync + Fn(TypedEnvelope, MessageContext) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |envelope, session| { + Box::new(move |envelope, session, span| { let envelope = envelope.into_any().downcast::>().unwrap(); let received_at = envelope.received_at; tracing::info!("message received"); let start_time = Instant::now(); - let future = (handler)(*envelope, session); + let future = (handler)( + *envelope, + MessageContext { + session, + span: span.clone(), + }, + ); async move { let result = future.await; let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0; let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0; let queue_duration_ms = total_duration_ms - processing_duration_ms; - + span.record(TOTAL_DURATION_MS, total_duration_ms); + span.record(PROCESSING_DURATION_MS, processing_duration_ms); + span.record(QUEUE_DURATION_MS, queue_duration_ms); match result { Err(error) => { - tracing::error!( - ?error, - total_duration_ms, - processing_duration_ms, - queue_duration_ms, - "error handling message" - ) + tracing::error!(?error, "error handling message") } - Ok(()) => tracing::info!( - total_duration_ms, - processing_duration_ms, - queue_duration_ms, - "finished handling message" - ), + Ok(()) => tracing::info!("finished handling message"), } } .boxed() @@ -693,7 +733,7 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(M, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { @@ -703,7 +743,7 @@ impl Server { fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Response, MessageContext) -> Fut, Fut: Send + Future>, M: RequestMessage, { @@ -889,12 +929,16 @@ impl Server { login=field::Empty, impersonator=field::Empty, multi_lsp_query_request=field::Empty, + { TOTAL_DURATION_MS }=field::Empty, + { PROCESSING_DURATION_MS }=field::Empty, + { QUEUE_DURATION_MS }=field::Empty, + { HOST_WAITING_MS }=field::Empty ); principal.update_span(&span); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(message, session.clone()); + let handle_message = (handler)(message, session.clone(), span.clone()); drop(span_enter); let handle_message = async move { @@ -1386,7 +1430,11 @@ async fn connection_lost( } /// Acknowledges a ping from a client, used to keep the connection alive. -async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { +async fn ping( + _: proto::Ping, + response: Response, + _session: MessageContext, +) -> Result<()> { response.send(proto::Ack {})?; Ok(()) } @@ -1395,7 +1443,7 @@ async fn ping(_: proto::Ping, response: Response, _session: Session async fn create_room( _request: proto::CreateRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let livekit_room = nanoid::nanoid!(30); @@ -1435,7 +1483,7 @@ async fn create_room( async fn join_room( request: proto::JoinRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); @@ -1502,7 +1550,7 @@ async fn join_room( async fn rejoin_room( request: proto::RejoinRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room; let channel; @@ -1679,7 +1727,7 @@ fn notify_rejoined_projects( async fn leave_room( _: proto::LeaveRoom, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { leave_room_for_session(&session, session.connection_id).await?; response.send(proto::Ack {})?; @@ -1690,7 +1738,7 @@ async fn leave_room( async fn set_room_participant_role( request: proto::SetRoomParticipantRole, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_id = UserId::from_proto(request.user_id); let role = ChannelRole::from(request.role()); @@ -1738,7 +1786,7 @@ async fn set_room_participant_role( async fn call( request: proto::Call, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let calling_user_id = session.user_id(); @@ -1807,7 +1855,7 @@ async fn call( async fn cancel_call( request: proto::CancelCall, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let called_user_id = UserId::from_proto(request.called_user_id); let room_id = RoomId::from_proto(request.room_id); @@ -1842,7 +1890,7 @@ async fn cancel_call( } /// Decline an incoming call. -async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { +async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(message.room_id); { let room = session @@ -1877,7 +1925,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( async fn update_participant_location( request: proto::UpdateParticipantLocation, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let location = request.location.context("invalid location")?; @@ -1896,7 +1944,7 @@ async fn update_participant_location( async fn share_project( request: proto::ShareProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let (project_id, room) = &*session .db() @@ -1917,7 +1965,7 @@ async fn share_project( } /// Unshare a project from the room. -async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { +async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); unshare_project_internal(project_id, session.connection_id, &session).await } @@ -1964,7 +2012,7 @@ async fn unshare_project_internal( async fn join_project( request: proto::JoinProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); @@ -2111,7 +2159,7 @@ async fn join_project( } /// Leave someone elses shared project. -async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { +async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> { let sender_id = session.connection_id; let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -2134,7 +2182,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result async fn update_project( request: proto::UpdateProject, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let (room, guest_connection_ids) = &*session @@ -2163,7 +2211,7 @@ async fn update_project( async fn update_worktree( request: proto::UpdateWorktree, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2187,7 +2235,7 @@ async fn update_worktree( async fn update_repository( request: proto::UpdateRepository, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2211,7 +2259,7 @@ async fn update_repository( async fn remove_repository( request: proto::RemoveRepository, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2235,7 +2283,7 @@ async fn remove_repository( /// Updates other participants with changes to the diagnostics async fn update_diagnostic_summary( message: proto::UpdateDiagnosticSummary, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2259,7 +2307,7 @@ async fn update_diagnostic_summary( /// Updates other participants with changes to the worktree settings async fn update_worktree_settings( message: proto::UpdateWorktreeSettings, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2283,7 +2331,7 @@ async fn update_worktree_settings( /// Notify other participants that a language server has started. async fn start_language_server( request: proto::StartLanguageServer, - session: Session, + session: MessageContext, ) -> Result<()> { let guest_connection_ids = session .db() @@ -2306,7 +2354,7 @@ async fn start_language_server( /// Notify other participants that a language server has changed. async fn update_language_server( request: proto::UpdateLanguageServer, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -2339,7 +2387,7 @@ async fn update_language_server( async fn forward_read_only_project_request( request: T, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2350,10 +2398,7 @@ where .await .host_for_read_only_project_request(project_id, session.connection_id) .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; + let payload = session.forward_request(host_connection_id, request).await?; response.send(payload)?; Ok(()) } @@ -2363,7 +2408,7 @@ where async fn forward_mutating_project_request( request: T, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> where T: EntityMessage + RequestMessage, @@ -2375,10 +2420,7 @@ where .await .host_for_mutating_project_request(project_id, session.connection_id) .await?; - let payload = session - .peer - .forward_request(session.connection_id, host_connection_id, request) - .await?; + let payload = session.forward_request(host_connection_id, request).await?; response.send(payload)?; Ok(()) } @@ -2386,7 +2428,7 @@ where async fn multi_lsp_query( request: MultiLspQuery, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { tracing::Span::current().record("multi_lsp_query_request", request.request_str()); tracing::info!("multi_lsp_query message received"); @@ -2396,7 +2438,7 @@ async fn multi_lsp_query( /// Notify other participants that a new buffer has been created async fn create_buffer_for_peer( request: proto::CreateBufferForPeer, - session: Session, + session: MessageContext, ) -> Result<()> { session .db() @@ -2418,7 +2460,7 @@ async fn create_buffer_for_peer( async fn update_buffer( request: proto::UpdateBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let mut capability = Capability::ReadOnly; @@ -2453,17 +2495,14 @@ async fn update_buffer( }; if host != session.connection_id { - session - .peer - .forward_request(session.connection_id, host, request.clone()) - .await?; + session.forward_request(host, request.clone()).await?; } response.send(proto::Ack {})?; Ok(()) } -async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> { +async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); let operation = message.operation.as_ref().context("invalid operation")?; @@ -2508,7 +2547,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu /// Notify other participants that a project has been updated. async fn broadcast_project_message_from_host>( request: T, - session: Session, + session: MessageContext, ) -> Result<()> { let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_connection_ids = session @@ -2533,7 +2572,7 @@ async fn broadcast_project_message_from_host, - session: Session, + session: MessageContext, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); @@ -2546,10 +2585,7 @@ async fn follow( .check_room_participants(room_id, leader_id, session.connection_id) .await?; - let response_payload = session - .peer - .forward_request(session.connection_id, leader_id, request) - .await?; + let response_payload = session.forward_request(leader_id, request).await?; response.send(response_payload)?; if let Some(project_id) = project_id { @@ -2565,7 +2601,7 @@ async fn follow( } /// Stop following another user in a call. -async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { +async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); let leader_id = request.leader_id.context("invalid leader id")?.into(); @@ -2594,7 +2630,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { } /// Notify everyone following you of your current location. -async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { +async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let database = session.db.lock().await; @@ -2629,7 +2665,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) -> async fn get_users( request: proto::GetUsers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_ids = request .user_ids @@ -2657,7 +2693,7 @@ async fn get_users( async fn fuzzy_search_users( request: proto::FuzzySearchUsers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let query = request.query; let users = match query.len() { @@ -2689,7 +2725,7 @@ async fn fuzzy_search_users( async fn request_contact( request: proto::RequestContact, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.responder_id); @@ -2736,7 +2772,7 @@ async fn request_contact( async fn respond_to_contact_request( request: proto::RespondToContactRequest, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let responder_id = session.user_id(); let requester_id = UserId::from_proto(request.requester_id); @@ -2794,7 +2830,7 @@ async fn respond_to_contact_request( async fn remove_contact( request: proto::RemoveContact, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.user_id); @@ -3053,7 +3089,10 @@ async fn update_user_plan(session: &Session) -> Result<()> { Ok(()) } -async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> { +async fn subscribe_to_channels( + _: proto::SubscribeToChannels, + session: MessageContext, +) -> Result<()> { subscribe_user_to_channels(session.user_id(), &session).await?; Ok(()) } @@ -3079,7 +3118,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul async fn create_channel( request: proto::CreateChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -3134,7 +3173,7 @@ async fn create_channel( async fn delete_channel( request: proto::DeleteChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -3162,7 +3201,7 @@ async fn delete_channel( async fn invite_channel_member( request: proto::InviteChannelMember, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3199,7 +3238,7 @@ async fn invite_channel_member( async fn remove_channel_member( request: proto::RemoveChannelMember, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3243,7 +3282,7 @@ async fn remove_channel_member( async fn set_channel_visibility( request: proto::SetChannelVisibility, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3288,7 +3327,7 @@ async fn set_channel_visibility( async fn set_channel_member_role( request: proto::SetChannelMemberRole, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3336,7 +3375,7 @@ async fn set_channel_member_role( async fn rename_channel( request: proto::RenameChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3368,7 +3407,7 @@ async fn rename_channel( async fn move_channel( request: proto::MoveChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let to = ChannelId::from_proto(request.to); @@ -3410,7 +3449,7 @@ async fn move_channel( async fn reorder_channel( request: proto::ReorderChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let direction = request.direction(); @@ -3456,7 +3495,7 @@ async fn reorder_channel( async fn get_channel_members( request: proto::GetChannelMembers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3476,7 +3515,7 @@ async fn get_channel_members( async fn respond_to_channel_invite( request: proto::RespondToChannelInvite, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3517,7 +3556,7 @@ async fn respond_to_channel_invite( async fn join_channel( request: proto::JoinChannel, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); join_channel_internal(channel_id, Box::new(response), session).await @@ -3540,7 +3579,7 @@ impl JoinChannelInternalResponse for Response { async fn join_channel_internal( channel_id: ChannelId, response: Box, - session: Session, + session: MessageContext, ) -> Result<()> { let joined_room = { let mut db = session.db().await; @@ -3635,7 +3674,7 @@ async fn join_channel_internal( async fn join_channel_buffer( request: proto::JoinChannelBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3666,7 +3705,7 @@ async fn join_channel_buffer( /// Edit the channel notes async fn update_channel_buffer( request: proto::UpdateChannelBuffer, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3718,7 +3757,7 @@ async fn update_channel_buffer( async fn rejoin_channel_buffers( request: proto::RejoinChannelBuffers, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let buffers = db @@ -3753,7 +3792,7 @@ async fn rejoin_channel_buffers( async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3815,7 +3854,7 @@ fn send_notifications( async fn send_channel_message( request: proto::SendChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { // Validate the message body. let body = request.body.trim().to_string(); @@ -3908,7 +3947,7 @@ async fn send_channel_message( async fn remove_channel_message( request: proto::RemoveChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3943,7 +3982,7 @@ async fn remove_channel_message( async fn update_channel_message( request: proto::UpdateChannelMessage, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -4027,7 +4066,7 @@ async fn update_channel_message( /// Mark a channel message as read async fn acknowledge_channel_message( request: proto::AckChannelMessage, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -4047,7 +4086,7 @@ async fn acknowledge_channel_message( /// Mark a buffer version as synced async fn acknowledge_buffer_version( request: proto::AckBufferOperation, - session: Session, + session: MessageContext, ) -> Result<()> { let buffer_id = BufferId::from_proto(request.buffer_id); session @@ -4067,7 +4106,7 @@ async fn acknowledge_buffer_version( async fn get_supermaven_api_key( _request: proto::GetSupermavenApiKey, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let user_id: String = session.user_id().to_string(); if !session.is_staff() { @@ -4096,7 +4135,7 @@ async fn get_supermaven_api_key( async fn join_channel_chat( request: proto::JoinChannelChat, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); @@ -4114,7 +4153,10 @@ async fn join_channel_chat( } /// Stop receiving chat updates for a channel -async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { +async fn leave_channel_chat( + request: proto::LeaveChannelChat, + session: MessageContext, +) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); session .db() @@ -4128,7 +4170,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) async fn get_channel_messages( request: proto::GetChannelMessages, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let messages = session @@ -4152,7 +4194,7 @@ async fn get_channel_messages( async fn get_channel_messages_by_id( request: proto::GetChannelMessagesById, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let message_ids = request .message_ids @@ -4175,7 +4217,7 @@ async fn get_channel_messages_by_id( async fn get_notifications( request: proto::GetNotifications, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let notifications = session .db() @@ -4197,7 +4239,7 @@ async fn get_notifications( async fn mark_notification_as_read( request: proto::MarkNotificationRead, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let database = &session.db().await; let notifications = database @@ -4219,7 +4261,7 @@ async fn mark_notification_as_read( async fn get_private_user_info( _request: proto::GetPrivateUserInfo, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -4243,7 +4285,7 @@ async fn get_private_user_info( async fn accept_terms_of_service( _request: proto::AcceptTermsOfService, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; @@ -4267,7 +4309,7 @@ async fn accept_terms_of_service( async fn get_llm_api_token( _request: proto::GetLlmToken, response: Response, - session: Session, + session: MessageContext, ) -> Result<()> { let db = session.db().await; diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 8ddebfb269..80a104641f 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -422,23 +422,8 @@ impl Peer { receiver_id: ConnectionId, request: T, ) -> impl Future> { - let request_start_time = Instant::now(); - let elapsed_time = move || request_start_time.elapsed().as_millis(); - tracing::info!("start forwarding request"); self.request_internal(Some(sender_id), receiver_id, request) .map_ok(|envelope| envelope.payload) - .inspect_err(move |_| { - tracing::error!( - waiting_for_host_ms = elapsed_time(), - "error forwarding request" - ) - }) - .inspect_ok(move |_| { - tracing::info!( - waiting_for_host_ms = elapsed_time(), - "finished forwarding request" - ) - }) } fn request_internal( From 952e3713d7f2675c45cd99aa65bf0e61cb556168 Mon Sep 17 00:00:00 2001 From: Peter Tripp Date: Thu, 7 Aug 2025 19:16:25 -0400 Subject: [PATCH 3/6] ci: Switch to Namespace (#35835) Follow-up to: - https://github.com/zed-industries/zed/pull/35826 Release Notes: - N/A --- .github/actionlint.yml | 27 ++++++++++-------------- .github/workflows/bump_patch_version.yml | 2 +- .github/workflows/ci.yml | 15 ++++++------- .github/workflows/deploy_cloudflare.yml | 2 +- .github/workflows/deploy_collab.yml | 4 ++-- .github/workflows/eval.yml | 2 +- .github/workflows/nix.yml | 2 +- .github/workflows/randomized_tests.yml | 2 +- .github/workflows/release_nightly.yml | 5 ++--- .github/workflows/unit_evals.yml | 2 +- 10 files changed, 28 insertions(+), 35 deletions(-) diff --git a/.github/actionlint.yml b/.github/actionlint.yml index 06b48b9b54..ad09545902 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -13,22 +13,17 @@ self-hosted-runner: - windows-2025-16 - windows-2025-32 - windows-2025-64 - # Buildjet Ubuntu 20.04 - AMD x86_64 - - buildjet-2vcpu-ubuntu-2004 - - buildjet-4vcpu-ubuntu-2004 - - buildjet-8vcpu-ubuntu-2004 - - buildjet-16vcpu-ubuntu-2004 - - buildjet-32vcpu-ubuntu-2004 - # Buildjet Ubuntu 22.04 - AMD x86_64 - - buildjet-2vcpu-ubuntu-2204 - - buildjet-4vcpu-ubuntu-2204 - - buildjet-8vcpu-ubuntu-2204 - - buildjet-16vcpu-ubuntu-2204 - - buildjet-32vcpu-ubuntu-2204 - # Buildjet Ubuntu 22.04 - Graviton aarch64 - - buildjet-8vcpu-ubuntu-2204-arm - - buildjet-16vcpu-ubuntu-2204-arm - - buildjet-32vcpu-ubuntu-2204-arm + # Namespace Ubuntu 20.04 (Release builds) + - namespace-profile-16x32-ubuntu-2004 + - namespace-profile-32x64-ubuntu-2004 + - namespace-profile-16x32-ubuntu-2004-arm + - namespace-profile-32x64-ubuntu-2004-arm + # Namespace Ubuntu 22.04 (Everything else) + - namespace-profile-2x4-ubuntu-2204 + - namespace-profile-4x8-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 + - namespace-profile-32x64-ubuntu-2204 # Self Hosted Runners - self-mini-macos - self-32vcpu-windows-2022 diff --git a/.github/workflows/bump_patch_version.yml b/.github/workflows/bump_patch_version.yml index bc44066ea6..bfaf7a271b 100644 --- a/.github/workflows/bump_patch_version.yml +++ b/.github/workflows/bump_patch_version.yml @@ -16,7 +16,7 @@ jobs: bump_patch_version: if: github.repository_owner == 'zed-industries' runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout code uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8f4f4d2b11..84907351fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -137,7 +137,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - github-8vcpu-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -168,7 +168,7 @@ jobs: needs: [job_spec] if: github.repository_owner == 'zed-industries' runs-on: - - github-8vcpu-ubuntu-2204 + - namespace-profile-4x8-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -221,7 +221,7 @@ jobs: github.repository_owner == 'zed-industries' && (needs.job_spec.outputs.run_tests == 'true' || needs.job_spec.outputs.run_docs == 'true') runs-on: - - github-8vcpu-ubuntu-2204 + - namespace-profile-8x16-ubuntu-2204 steps: - name: Checkout repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -328,7 +328,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" @@ -380,7 +380,7 @@ jobs: github.repository_owner == 'zed-industries' && needs.job_spec.outputs.run_tests == 'true' runs-on: - - github-8vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" @@ -597,8 +597,7 @@ jobs: timeout-minutes: 60 name: Linux x86_x64 release bundle runs-on: - - github-16vcpu-ubuntu-2204 - # - buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc + - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') @@ -651,7 +650,7 @@ jobs: timeout-minutes: 60 name: Linux arm64 release bundle runs-on: - - github-16vcpu-ubuntu-2204-arm + - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') diff --git a/.github/workflows/deploy_cloudflare.yml b/.github/workflows/deploy_cloudflare.yml index 3a294fdc17..df35d44ca9 100644 --- a/.github/workflows/deploy_cloudflare.yml +++ b/.github/workflows/deploy_cloudflare.yml @@ -9,7 +9,7 @@ jobs: deploy-docs: name: Deploy Docs if: github.repository_owner == 'zed-industries' - runs-on: github-16vcpu-ubuntu-2204 + runs-on: namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout repo diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index d1a68a6280..ff2a3589e4 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -61,7 +61,7 @@ jobs: - style - tests runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Install doctl uses: digitalocean/action-doctl@v2 @@ -94,7 +94,7 @@ jobs: needs: - publish runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Checkout repo diff --git a/.github/workflows/eval.yml b/.github/workflows/eval.yml index 196e00519b..b5da9e7b7c 100644 --- a/.github/workflows/eval.yml +++ b/.github/workflows/eval.yml @@ -32,7 +32,7 @@ jobs: github.repository_owner == 'zed-industries' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval')) runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 913d6cfe9f..e682ce5890 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -20,7 +20,7 @@ jobs: matrix: system: - os: x86 Linux - runner: github-16vcpu-ubuntu-2204 + runner: namespace-profile-16x32-ubuntu-2204 install_nix: true - os: arm Mac runner: [macOS, ARM64, test] diff --git a/.github/workflows/randomized_tests.yml b/.github/workflows/randomized_tests.yml index 3a7b476ba0..de96c3df78 100644 --- a/.github/workflows/randomized_tests.yml +++ b/.github/workflows/randomized_tests.yml @@ -20,7 +20,7 @@ jobs: name: Run randomized tests if: github.repository_owner == 'zed-industries' runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Install Node uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4 diff --git a/.github/workflows/release_nightly.yml b/.github/workflows/release_nightly.yml index c5be72fca2..b3500a085b 100644 --- a/.github/workflows/release_nightly.yml +++ b/.github/workflows/release_nightly.yml @@ -128,8 +128,7 @@ jobs: name: Create a Linux *.tar.gz bundle for x86 if: github.repository_owner == 'zed-industries' runs-on: - - github-16vcpu-ubuntu-2204 - # - buildjet-16vcpu-ubuntu-2004 + - namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc needs: tests steps: - name: Checkout repo @@ -169,7 +168,7 @@ jobs: name: Create a Linux *.tar.gz bundle for ARM if: github.repository_owner == 'zed-industries' runs-on: - - github-16vcpu-ubuntu-2204-arm + - namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc needs: tests steps: - name: Checkout repo diff --git a/.github/workflows/unit_evals.yml b/.github/workflows/unit_evals.yml index 225fca558f..2e03fb028f 100644 --- a/.github/workflows/unit_evals.yml +++ b/.github/workflows/unit_evals.yml @@ -23,7 +23,7 @@ jobs: timeout-minutes: 60 name: Run unit evals runs-on: - - github-16vcpu-ubuntu-2204 + - namespace-profile-16x32-ubuntu-2204 steps: - name: Add Rust to the PATH run: echo "$HOME/.cargo/bin" >> "$GITHUB_PATH" From 6912dc8399148dd0caf951ce0bba711de7279f01 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 7 Aug 2025 20:26:19 -0300 Subject: [PATCH 4/6] Fix CC tool state on cancel (#35763) When we stop the generation, CC tells us the tool completed, but it was actually cancelled. Release Notes: - N/A --- crates/agent_servers/src/claude.rs | 148 +++++++++++++++----------- crates/agent_servers/src/e2e_tests.rs | 7 +- 2 files changed, 91 insertions(+), 64 deletions(-) diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 09d08fdcf8..c65508f152 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection { }) .detach(); - let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); + let turn_state = Rc::new(RefCell::new(TurnState::None)); - let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({ - let end_turn_tx = end_turn_tx.clone(); + let turn_state = turn_state.clone(); let mut thread_rx = thread_rx.clone(); - let cancellation_state = pending_cancellation.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { ClaudeAgentSession::handle_message( thread_rx.clone(), message, - end_turn_tx.clone(), - cancellation_state.clone(), + turn_state.clone(), cx, ) .await @@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection { let session = ClaudeAgentSession { outgoing_tx, - end_turn_tx, - pending_cancellation, + turn_state, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; @@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection { ))); }; - let (tx, rx) = oneshot::channel(); - session.end_turn_tx.borrow_mut().replace(tx); + let (end_tx, end_rx) = oneshot::channel(); + session.turn_state.replace(TurnState::InProgress { end_tx }); let mut content = String::new(); for chunk in params.prompt { @@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection { return Task::ready(Err(anyhow!(err))); } - let cancellation_state = session.pending_cancellation.clone(); - cx.foreground_executor().spawn(async move { - let result = rx.await??; - cancellation_state.set(PendingCancellation::None); - Ok(result) - }) + cx.foreground_executor().spawn(async move { end_rx.await? }) } fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { @@ -277,7 +268,15 @@ impl AgentConnection for ClaudeAgentConnection { let request_id = new_request_id(); - session.pending_cancellation.set(PendingCancellation::Sent { + let turn_state = session.turn_state.take(); + let TurnState::InProgress { end_tx } = turn_state else { + // Already cancelled or idle, put it back + session.turn_state.replace(turn_state); + return; + }; + + session.turn_state.replace(TurnState::CancelRequested { + end_tx, request_id: request_id.clone(), }); @@ -349,28 +348,56 @@ fn spawn_claude( struct ClaudeAgentSession { outgoing_tx: UnboundedSender, - end_turn_tx: Rc>>>>, - pending_cancellation: Rc>, + turn_state: Rc>, _mcp_server: Option, _handler_task: Task<()>, } -#[derive(Debug, Default, PartialEq)] -enum PendingCancellation { +#[derive(Debug, Default)] +enum TurnState { #[default] None, - Sent { + InProgress { + end_tx: oneshot::Sender>, + }, + CancelRequested { + end_tx: oneshot::Sender>, request_id: String, }, - Confirmed, + CancelConfirmed { + end_tx: oneshot::Sender>, + }, +} + +impl TurnState { + fn is_cancelled(&self) -> bool { + matches!(self, TurnState::CancelConfirmed { .. }) + } + + fn end_tx(self) -> Option>> { + match self { + TurnState::None => None, + TurnState::InProgress { end_tx, .. } => Some(end_tx), + TurnState::CancelRequested { end_tx, .. } => Some(end_tx), + TurnState::CancelConfirmed { end_tx } => Some(end_tx), + } + } + + fn confirm_cancellation(self, id: &str) -> Self { + match self { + TurnState::CancelRequested { request_id, end_tx } if request_id == id => { + TurnState::CancelConfirmed { end_tx } + } + _ => self, + } + } } impl ClaudeAgentSession { async fn handle_message( mut thread_rx: watch::Receiver>, message: SdkMessage, - end_turn_tx: Rc>>>>, - pending_cancellation: Rc>, + turn_state: Rc>, cx: &mut AsyncApp, ) { match message { @@ -393,15 +420,13 @@ impl ClaudeAgentSession { for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - let state = pending_cancellation.take(); - if state != PendingCancellation::Confirmed { + if !turn_state.borrow().is_cancelled() { thread .update(cx, |thread, cx| { thread.push_user_content_block(text.into(), cx) }) .log_err(); } - pending_cancellation.set(state); } ContentChunk::ToolResult { content, @@ -414,7 +439,12 @@ impl ClaudeAgentSession { acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.into()), fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), + status: if turn_state.borrow().is_cancelled() { + // Do not set to completed if turn was cancelled + None + } else { + Some(acp::ToolCallStatus::Completed) + }, content: (!content.is_empty()) .then(|| vec![content.into()]), ..Default::default() @@ -541,40 +571,38 @@ impl ClaudeAgentSession { result, .. } => { - if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { - if is_error - || (subtype == ResultErrorType::ErrorDuringExecution - && pending_cancellation.take() != PendingCancellation::Confirmed) - { - end_turn_tx - .send(Err(anyhow!( - "Error: {}", - result.unwrap_or_else(|| subtype.to_string()) - ))) - .ok(); - } else { - let stop_reason = match subtype { - ResultErrorType::Success => acp::StopReason::EndTurn, - ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, - ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, - }; - end_turn_tx - .send(Ok(acp::PromptResponse { stop_reason })) - .ok(); - } + let turn_state = turn_state.take(); + let was_cancelled = turn_state.is_cancelled(); + let Some(end_turn_tx) = turn_state.end_tx() else { + debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn"); + return; + }; + + if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution) + { + end_turn_tx + .send(Err(anyhow!( + "Error: {}", + result.unwrap_or_else(|| subtype.to_string()) + ))) + .ok(); + } else { + let stop_reason = match subtype { + ResultErrorType::Success => acp::StopReason::EndTurn, + ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, + ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, + }; + end_turn_tx + .send(Ok(acp::PromptResponse { stop_reason })) + .ok(); } } SdkMessage::ControlResponse { response } => { if matches!(response.subtype, ResultErrorType::Success) { - let pending_cancellation_value = pending_cancellation.take(); - - if let PendingCancellation::Sent { request_id } = &pending_cancellation_value - && request_id == &response.request_id - { - pending_cancellation.set(PendingCancellation::Confirmed); - } else { - pending_cancellation.set(pending_cancellation_value); - } + let new_state = turn_state.take().confirm_cancellation(&response.request_id); + turn_state.replace(new_state); + } else { + log::error!("Control response error: {:?}", response); } } SdkMessage::System { .. } => {} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 05f874bd30..ec6ca29b9d 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { + let _ = thread.update(cx, |thread, cx| { thread.send_raw( r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, @@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon id.clone() }); - let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); - full_turn.await.unwrap(); - thread.read_with(cx, |thread, _| { + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Canceled, .. From 7d4d8b8398b03dedc6eabd71f3474300b947ae68 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Thu, 7 Aug 2025 19:35:41 -0400 Subject: [PATCH 5/6] Add GPT-5 support through OpenAI API (#35822) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (This PR does not add GPT-5 to Zed Pro, but rather adds access if you're using your own OpenAI API key.) Screenshot 2025-08-07 at 2 23 18 PM --- **NOTE:** If your API key is not through a verified organization, you may see this error: Screenshot 2025-08-07 at 2 04 54 PM Even if your org is verified, you still may not have access to GPT-5, in which case you could see this error: Screenshot 2025-08-07 at 2 09 18 PM One way to test if you're in this situation is to visit https://platform.openai.com/chat/edit?models=gpt-5 and see if you get the same "you don't have access to GPT-5" error on OpenAI's official playground. It looks like this: Screenshot 2025-08-07 at 2 15 25 PM Release Notes: - Added GPT-5, as well as its mini and nano variants. To use this, you need to have an OpenAI API key configured via the `OPENAI_API_KEY` environment variable. --- .../language_models/src/provider/open_ai.rs | 4 +++ crates/open_ai/src/open_ai.rs | 26 ++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 5185e979b7..ee74562687 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -674,6 +674,10 @@ pub fn count_open_ai_tokens( | Model::O3 | Model::O3Mini | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), + // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer + Model::Five | Model::FiveMini | Model::FiveNano => { + tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages) + } } .map(|tokens| tokens as u64) }) diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 12a5cf52d2..4697d71ed3 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -74,6 +74,12 @@ pub enum Model { O3, #[serde(rename = "o4-mini")] O4Mini, + #[serde(rename = "gpt-5")] + Five, + #[serde(rename = "gpt-5-mini")] + FiveMini, + #[serde(rename = "gpt-5-nano")] + FiveNano, #[serde(rename = "custom")] Custom { @@ -105,6 +111,9 @@ impl Model { "o3-mini" => Ok(Self::O3Mini), "o3" => Ok(Self::O3), "o4-mini" => Ok(Self::O4Mini), + "gpt-5" => Ok(Self::Five), + "gpt-5-mini" => Ok(Self::FiveMini), + "gpt-5-nano" => Ok(Self::FiveNano), invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"), } } @@ -123,6 +132,9 @@ impl Model { Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", + Self::Five => "gpt-5", + Self::FiveMini => "gpt-5-mini", + Self::FiveNano => "gpt-5-nano", Self::Custom { name, .. } => name, } } @@ -141,6 +153,9 @@ impl Model { Self::O3Mini => "o3-mini", Self::O3 => "o3", Self::O4Mini => "o4-mini", + Self::Five => "gpt-5", + Self::FiveMini => "gpt-5-mini", + Self::FiveNano => "gpt-5-nano", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name), @@ -161,6 +176,9 @@ impl Model { Self::O3Mini => 200_000, Self::O3 => 200_000, Self::O4Mini => 200_000, + Self::Five => 272_000, + Self::FiveMini => 272_000, + Self::FiveNano => 272_000, Self::Custom { max_tokens, .. } => *max_tokens, } } @@ -182,6 +200,9 @@ impl Model { Self::O3Mini => Some(100_000), Self::O3 => Some(100_000), Self::O4Mini => Some(100_000), + Self::Five => Some(128_000), + Self::FiveMini => Some(128_000), + Self::FiveNano => Some(128_000), } } @@ -197,7 +218,10 @@ impl Model { | Self::FourOmniMini | Self::FourPointOne | Self::FourPointOneMini - | Self::FourPointOneNano => true, + | Self::FourPointOneNano + | Self::Five + | Self::FiveMini + | Self::FiveNano => true, Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false, } } From 3d662ee2828ab022058a60dde74d4f99d3f1a15d Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 7 Aug 2025 20:46:47 -0300 Subject: [PATCH 6/6] agent2: Port read_file tool (#35840) Ports the read_file tool from `assistant_tools` to `agent2`. Note: Image support not implemented. Release Notes: - N/A --- Cargo.lock | 2 + crates/agent2/Cargo.toml | 3 + crates/agent2/src/agent.rs | 5 +- crates/agent2/src/thread.rs | 33 +- crates/agent2/src/tools.rs | 2 + crates/agent2/src/tools/read_file_tool.rs | 970 ++++++++++++++++++++++ 6 files changed, 1011 insertions(+), 4 deletions(-) create mode 100644 crates/agent2/src/tools/read_file_tool.rs diff --git a/Cargo.lock b/Cargo.lock index 39ee75f6dd..e63c5e2acf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,10 +172,12 @@ dependencies = [ "gpui_tokio", "handlebars 4.5.0", "indoc", + "itertools 0.14.0", "language", "language_model", "language_models", "log", + "pretty_assertions", "project", "prompt_store", "reqwest_client", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 884378fbcc..a75011a671 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -24,6 +24,8 @@ futures.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } indoc.workspace = true +itertools.workspace = true +language.workspace = true language_model.workspace = true language_models.workspace = true log.workspace = true @@ -55,3 +57,4 @@ project = { workspace = true, "features" = ["test-support"] } reqwest_client.workspace = true settings = { workspace = true, "features" = ["test-support"] } worktree = { workspace = true, "features" = ["test-support"] } +pretty_assertions.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index cb568f04c2..2014d86fb7 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,5 +1,5 @@ use crate::{templates::Templates, AgentResponseEvent, Thread}; -use crate::{FindPathTool, ThinkingTool, ToolCallAuthorization}; +use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization}; use acp_thread::ModelSelector; use agent_client_protocol as acp; use anyhow::{anyhow, Context as _, Result}; @@ -413,9 +413,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; let thread = cx.new(|_| { - let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log, agent.templates.clone(), default_model); + let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); thread.add_tool(ThinkingTool); thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 805ffff1c0..4b8a65655f 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -125,7 +125,7 @@ pub struct Thread { project_context: Rc>, templates: Arc, pub selected_model: Arc, - _action_log: Entity, + action_log: Entity, } impl Thread { @@ -145,7 +145,7 @@ impl Thread { project_context, templates, selected_model: default_model, - _action_log: action_log, + action_log, } } @@ -315,6 +315,10 @@ impl Thread { events_rx } + pub fn action_log(&self) -> &Entity { + &self.action_log + } + pub fn build_system_message(&self) -> AgentMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { @@ -924,3 +928,28 @@ impl ToolCallEventStream { .authorize_tool_call(&self.tool_use_id, title, kind, input) } } + +#[cfg(test)] +pub struct TestToolCallEventStream { + stream: ToolCallEventStream, + _events_rx: mpsc::UnboundedReceiver>, +} + +#[cfg(test)] +impl TestToolCallEventStream { + pub fn new() -> Self { + let (events_tx, events_rx) = + mpsc::unbounded::>(); + + let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx)); + + Self { + stream, + _events_rx: events_rx, + } + } + + pub fn stream(&self) -> ToolCallEventStream { + self.stream.clone() + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index 848fe552ed..240614c263 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1,5 +1,7 @@ mod find_path_tool; +mod read_file_tool; mod thinking_tool; pub use find_path_tool::*; +pub use read_file_tool::*; pub use thinking_tool::*; diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs new file mode 100644 index 0000000000..30794ccdad --- /dev/null +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -0,0 +1,970 @@ +use agent_client_protocol::{self as acp}; +use anyhow::{anyhow, Result}; +use assistant_tool::{outline, ActionLog}; +use gpui::{Entity, Task}; +use indoc::formatdoc; +use language::{Anchor, Point}; +use project::{AgentLocation, Project, WorktreeSettings}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; +use std::sync::Arc; +use ui::{App, SharedString}; + +use crate::{AgentTool, ToolCallEventStream}; + +/// Reads the content of the given file in the project. +/// +/// - Never attempt to read a path that hasn't been previously mentioned. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ReadFileToolInput { + /// The relative path of the file to read. + /// + /// This path should never be absolute, and the first component + /// of the path should always be a root directory in a project. + /// + /// + /// If the project has the following root directories: + /// + /// - /a/b/directory1 + /// - /c/d/directory2 + /// + /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`. + /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`. + /// + pub path: String, + + /// Optional line number to start reading on (1-based index) + #[serde(default)] + pub start_line: Option, + + /// Optional line number to end reading on (1-based index, inclusive) + #[serde(default)] + pub end_line: Option, +} + +pub struct ReadFileTool { + project: Entity, + action_log: Entity, +} + +impl ReadFileTool { + pub fn new(project: Entity, action_log: Entity) -> Self { + Self { + project, + action_log, + } + } +} + +impl AgentTool for ReadFileTool { + type Input = ReadFileToolInput; + + fn name(&self) -> SharedString { + "read_file".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Read + } + + fn initial_title(&self, input: Self::Input) -> SharedString { + let path = &input.path; + match (input.start_line, input.end_line) { + (Some(start), Some(end)) => { + format!( + "[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))", + path, start, end, path, start, end + ) + } + (Some(start), None) => { + format!( + "[Read file `{}` (from line {})](@selection:{}:({}-{}))", + path, start, path, start, start + ) + } + _ => format!("[Read file `{}`](@file:{})", path, path), + } + .into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { + return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); + }; + + // Error out if this path is either excluded or private in global settings + let global_settings = WorktreeSettings::get_global(cx); + if global_settings.is_path_excluded(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the global `file_scan_exclusions` setting: {}", + &input.path + ))); + } + + if global_settings.is_path_private(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the global `private_files` setting: {}", + &input.path + ))); + } + + // Error out if this path is either excluded or private in worktree settings + let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); + if worktree_settings.is_path_excluded(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}", + &input.path + ))); + } + + if worktree_settings.is_path_private(&project_path.path) { + return Task::ready(Err(anyhow!( + "Cannot read file because its path matches the worktree `private_files` setting: {}", + &input.path + ))); + } + + let file_path = input.path.clone(); + + event_stream.send_update(acp::ToolCallUpdateFields { + locations: Some(vec![acp::ToolCallLocation { + path: project_path.path.to_path_buf(), + line: input.start_line, + // TODO (tracked): use full range + }]), + ..Default::default() + }); + + // TODO (tracked): images + // if image_store::is_image_file(&self.project, &project_path, cx) { + // let model = &self.thread.read(cx).selected_model; + + // if !model.supports_images() { + // return Task::ready(Err(anyhow!( + // "Attempted to read an image, but Zed doesn't currently support sending images to {}.", + // model.name().0 + // ))) + // .into(); + // } + + // return cx.spawn(async move |cx| -> Result { + // let image_entity: Entity = cx + // .update(|cx| { + // self.project.update(cx, |project, cx| { + // project.open_image(project_path.clone(), cx) + // }) + // })? + // .await?; + + // let image = + // image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; + + // let language_model_image = cx + // .update(|cx| LanguageModelImage::from_image(image, cx))? + // .await + // .context("processing image")?; + + // Ok(ToolResultOutput { + // content: ToolResultContent::Image(language_model_image), + // output: None, + // }) + // }); + // } + // + + let project = self.project.clone(); + let action_log = self.action_log.clone(); + + cx.spawn(async move |cx| { + let buffer = cx + .update(|cx| { + project.update(cx, |project, cx| project.open_buffer(project_path, cx)) + })? + .await?; + if buffer.read_with(cx, |buffer, _| { + buffer + .file() + .as_ref() + .map_or(true, |file| !file.disk_state().exists()) + })? { + anyhow::bail!("{file_path} not found"); + } + + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: Anchor::MIN, + }), + cx, + ); + })?; + + // Check if specific line ranges are provided + if input.start_line.is_some() || input.end_line.is_some() { + let mut anchor = None; + let result = buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0. + let start = input.start_line.unwrap_or(1).max(1); + let start_row = start - 1; + if start_row <= buffer.max_point().row { + let column = buffer.line_indent_for_row(start_row).raw_len(); + anchor = Some(buffer.anchor_before(Point::new(start_row, column))); + } + + let lines = text.split('\n').skip(start_row as usize); + if let Some(end) = input.end_line { + let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line + itertools::intersperse(lines.take(count as usize), "\n").collect::() + } else { + itertools::intersperse(lines, "\n").collect::() + } + })?; + + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx); + })?; + + if let Some(anchor) = anchor { + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: anchor, + }), + cx, + ); + })?; + } + + Ok(result) + } else { + // No line ranges specified, so check file size to see if it's too big. + let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; + + if file_size <= outline::AUTO_OUTLINE_SIZE { + // File is small enough, so return its contents. + let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + action_log.update(cx, |log, cx| { + log.buffer_read(buffer, cx); + })?; + + Ok(result) + } else { + // File is too big, so return the outline + // and a suggestion to read again with line numbers. + let outline = + outline::file_outline(project, file_path, action_log, None, cx).await?; + Ok(formatdoc! {" + This file was too big to read all at once. + + Here is an outline of its symbols: + + {outline} + + Using the line numbers in this outline, you can call this tool again + while specifying the start_line and end_line fields to see the + implementations of symbols in the outline. + + Alternatively, you can fall back to the `grep` tool (if available) + to search the file for specific content." + }) + } + } + }) + } +} + +#[cfg(test)] +mod test { + use crate::TestToolCallEventStream; + + use super::*; + use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; + use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_read_nonexistent_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let event_stream = TestToolCallEventStream::new(); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/nonexistent_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.run(input, event_stream.stream(), cx) + }) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "root/nonexistent_file.txt not found" + ); + } + #[gpui::test] + async fn test_read_small_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "small_file.txt": "This is a small file content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let event_stream = TestToolCallEventStream::new(); + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/small_file.txt".into(), + start_line: None, + end_line: None, + }; + tool.run(input, event_stream.stream(), cx) + }) + .await; + assert_eq!(result.unwrap(), "This is a small file content"); + } + + #[gpui::test] + async fn test_read_large_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::>().join("\n") + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(Arc::new(rust_lang())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let event_stream = TestToolCallEventStream::new(); + let content = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/large_file.rs".into(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await + .unwrap(); + + assert_eq!( + content.lines().skip(4).take(6).collect::>(), + vec![ + "struct Test0 [L1-4]", + " a [L2]", + " b [L3]", + "struct Test1 [L5-8]", + " a [L6]", + " b [L7]", + ] + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/large_file.rs".into(), + start_line: None, + end_line: None, + }; + tool.run(input, event_stream.stream(), cx) + }) + .await; + let content = result.unwrap(); + let expected_content = (0..1000) + .flat_map(|i| { + vec![ + format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4), + format!(" a [L{}]", i * 4 + 2), + format!(" b [L{}]", i * 4 + 3), + ] + }) + .collect::>(); + pretty_assertions::assert_eq!( + content + .lines() + .skip(4) + .take(expected_content.len()) + .collect::>(), + expected_content + ); + } + + #[gpui::test] + async fn test_read_file_with_line_range(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let event_stream = TestToolCallEventStream::new(); + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(2), + end_line: Some(4), + }; + tool.run(input, event_stream.stream(), cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); + } + + #[gpui::test] + async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let event_stream = TestToolCallEventStream::new(); + + // start_line of 0 should be treated as 1 + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(0), + end_line: Some(2), + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 1\nLine 2"); + + // end_line of 0 should result in at least 1 line + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(1), + end_line: Some(0), + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 1"); + + // when start_line > end_line, should still return at least 1 line + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(3), + end_line: Some(2), + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert_eq!(result.unwrap(), "Line 3"); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query( + r#" + (line_comment) @annotation + + (struct_item + "struct" @context + name: (_) @name) @item + (enum_item + "enum" @context + name: (_) @name) @item + (enum_variant + name: (_) @name) @item + (field_declaration + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @name + "for"? @context + type: (_) @name + body: (_ "{" (_)* "}")) @item + (function_item + "fn" @context + name: (_) @name) @item + (mod_item + "mod" @context + name: (_) @name) @item + "#, + ) + .unwrap() + } + + #[gpui::test] + async fn test_read_file_security(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + path!("/"), + json!({ + "project_root": { + "allowed_file.txt": "This file is in the project", + ".mysecrets": "SECRET_KEY=abc123", + ".secretdir": { + "config": "special configuration" + }, + ".mymetadata": "custom metadata", + "subdir": { + "normal_file.txt": "Normal file content", + "special.privatekey": "private key content", + "data.mysensitive": "sensitive data" + } + }, + "outside_project": { + "sensitive_file.txt": "This file is outside the project" + } + }), + ) + .await; + + cx.update(|cx| { + use gpui::UpdateGlobal; + use project::WorktreeSettings; + use settings::SettingsStore; + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = Some(vec![ + "**/.secretdir".to_string(), + "**/.mymetadata".to_string(), + ]); + settings.private_files = Some(vec![ + "**/.mysecrets".to_string(), + "**/*.privatekey".to_string(), + "**/*.mysensitive".to_string(), + ]); + }); + }); + }); + + let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project, action_log)); + let event_stream = TestToolCallEventStream::new(); + + // Reading a file outside the project worktree should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "/outside_project/sensitive_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read an absolute path outside a worktree" + ); + + // Reading a file within the project should succeed + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/allowed_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_ok(), + "read_file_tool should be able to read files inside worktrees" + ); + + // Reading files that match file_scan_exclusions should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/.secretdir/config".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)" + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/.mymetadata".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)" + ); + + // Reading private files should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/.mysecrets".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .mysecrets (private_files)" + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/subdir/special.privatekey".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .privatekey files (private_files)" + ); + + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/subdir/data.mysensitive".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read .mysensitive files (private_files)" + ); + + // Reading a normal file should still work, even with private_files configured + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/subdir/normal_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + assert!(result.is_ok(), "Should be able to read normal files"); + assert_eq!(result.unwrap(), "Normal file content"); + + // Path traversal attempts with .. should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "project_root/../outside_project/sensitive_file.txt".to_string(), + start_line: None, + end_line: None, + }; + tool.run(input, event_stream.stream(), cx) + }) + .await; + assert!( + result.is_err(), + "read_file_tool should error when attempting to read a relative path that resolves to outside a worktree" + ); + } + + #[gpui::test] + async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + + // Create first worktree with its own private_files setting + fs.insert_tree( + path!("/worktree1"), + json!({ + "src": { + "main.rs": "fn main() { println!(\"Hello from worktree1\"); }", + "secret.rs": "const API_KEY: &str = \"secret_key_1\";", + "config.toml": "[database]\nurl = \"postgres://localhost/db1\"" + }, + "tests": { + "test.rs": "mod tests { fn test_it() {} }", + "fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));" + }, + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/fixture.*"], + "private_files": ["**/secret.rs", "**/config.toml"] + }"# + } + }), + ) + .await; + + // Create second worktree with different private_files setting + fs.insert_tree( + path!("/worktree2"), + json!({ + "lib": { + "public.js": "export function greet() { return 'Hello from worktree2'; }", + "private.js": "const SECRET_TOKEN = \"private_token_2\";", + "data.json": "{\"api_key\": \"json_secret_key\"}" + }, + "docs": { + "README.md": "# Public Documentation", + "internal.md": "# Internal Secrets and Configuration" + }, + ".zed": { + "settings.json": r#"{ + "file_scan_exclusions": ["**/internal.*"], + "private_files": ["**/private.js", "**/data.json"] + }"# + } + }), + ) + .await; + + // Set global settings + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::(cx, |settings| { + settings.file_scan_exclusions = + Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); + settings.private_files = Some(vec!["**/.env".to_string()]); + }); + }); + }); + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone())); + let event_stream = TestToolCallEventStream::new(); + + // Test reading allowed files in worktree1 + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/src/main.rs".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await + .unwrap(); + + assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }"); + + // Test reading private file in worktree1 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/src/secret.rs".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `private_files` setting"), + "Error should mention worktree private_files setting" + ); + + // Test reading excluded file in worktree1 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/tests/fixture.sql".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `file_scan_exclusions` setting"), + "Error should mention worktree file_scan_exclusions setting" + ); + + // Test reading allowed files in worktree2 + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree2/lib/public.js".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await + .unwrap(); + + assert_eq!( + result, + "export function greet() { return 'Hello from worktree2'; }" + ); + + // Test reading private file in worktree2 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree2/lib/private.js".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `private_files` setting"), + "Error should mention worktree private_files setting" + ); + + // Test reading excluded file in worktree2 should fail + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree2/docs/internal.md".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `file_scan_exclusions` setting"), + "Error should mention worktree file_scan_exclusions setting" + ); + + // Test that files allowed in one worktree but not in another are handled correctly + // (e.g., config.toml is private in worktree1 but doesn't exist in worktree2) + let result = cx + .update(|cx| { + let input = ReadFileToolInput { + path: "worktree1/src/config.toml".to_string(), + start_line: None, + end_line: None, + }; + tool.clone().run(input, event_stream.stream(), cx) + }) + .await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("worktree `private_files` setting"), + "Config.toml should be blocked by worktree1's private_files setting" + ); + } +}