From f0515d1c34c55ecee5b041bd694786072cd46dc7 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 2 May 2025 22:09:54 -0400 Subject: [PATCH] agent: Show a notice when reaching consecutive tool use limits (#29833) This PR adds a notice when reaching consecutive tool use limits when using normal mode. Here's an example with the limit artificially lowered to 2 consecutive tool uses: https://github.com/user-attachments/assets/32da8d38-67de-4d6b-8f24-754d2518e5d4 Release Notes: - agent: Added a notice when reaching consecutive tool use limits when using a model in normal mode. --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/agent/src/assistant_panel.rs | 36 +++++++++ crates/agent/src/thread.rs | 29 +++++-- crates/language_model/src/language_model.rs | 5 +- crates/language_models/src/provider/cloud.rs | 83 ++++++++++++++++---- 6 files changed, 134 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6a7e05d563..6a76f96cd9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18826,9 +18826,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1" +checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7" dependencies = [ "anyhow", "serde", diff --git a/Cargo.toml b/Cargo.toml index 1fc04f838e..df699ef5a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -611,7 +611,7 @@ wasmtime-wasi = "29" which = "6.0.0" wit-component = "0.221" workspace-hack = "0.1.0" -zed_llm_client = "0.7.1" +zed_llm_client = "0.7.2" zstd = "0.11" [workspace.dependencies.async-stripe] diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 8a7b5676c0..b641f76792 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -1957,6 +1957,41 @@ impl AssistantPanel { Some(UsageBanner::new(plan, usage).into_any_element()) } + fn render_tool_use_limit_reached(&self, cx: &mut Context) -> Option { + let tool_use_limit_reached = self + .thread + .read(cx) + .thread() + .read(cx) + .tool_use_limit_reached(); + if !tool_use_limit_reached { + return None; + } + + let model = self + .thread + .read(cx) + .thread() + .read(cx) + .configured_model()? + .model; + + let max_mode_upsell = if model.supports_max_mode() { + " Enable max mode for unlimited tool use." + } else { + "" + }; + + Some( + Banner::new() + .severity(ui::Severity::Info) + .children(h_flex().child(Label::new(format!( + "Consecutive tool use limit reached.{max_mode_upsell}" + )))) + .into_any_element(), + ) + } + fn render_last_error(&self, cx: &mut Context) -> Option { let last_error = self.thread.read(cx).last_error()?; @@ -2238,6 +2273,7 @@ impl Render for AssistantPanel { .map(|parent| match &self.active_view { ActiveView::Thread { .. } => parent .child(self.render_active_thread_or_empty_state(window, cx)) + .children(self.render_tool_use_limit_reached(cx)) .children(self.render_usage_banner(cx)) .child(h_flex().child(self.message_editor.clone())) .children(self.render_last_error(cx)), diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index d1611efc7c..4a90916888 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -355,6 +355,7 @@ pub struct Thread { request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, + tool_use_limit_reached: bool, feedback: Option, message_feedback: HashMap, last_auto_capture_at: Option, @@ -417,6 +418,7 @@ impl Thread { request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, + tool_use_limit_reached: false, feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, @@ -524,6 +526,7 @@ impl Thread { request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, + tool_use_limit_reached: false, feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, @@ -814,6 +817,10 @@ impl Thread { .unwrap_or(false) } + pub fn tool_use_limit_reached(&self) -> bool { + self.tool_use_limit_reached + } + /// Returns whether all of the tool uses have finished running. pub fn all_tools_finished(&self) -> bool { // If the only pending tool uses left are the ones with errors, then @@ -1331,6 +1338,8 @@ impl Thread { window: Option, cx: &mut Context, ) { + self.tool_use_limit_reached = false; + let pending_completion_id = post_inc(&mut self.completion_count); let mut request_callback_parameters = if self.request_callback.is_some() { Some((request.clone(), Vec::new())) @@ -1506,17 +1515,27 @@ impl Thread { }); } } - LanguageModelCompletionEvent::QueueUpdate(queue_event) => { + LanguageModelCompletionEvent::QueueUpdate(status) => { if let Some(completion) = thread .pending_completions .iter_mut() .find(|completion| completion.id == pending_completion_id) { - completion.queue_state = match queue_event { - language_model::QueueState::Queued { position } => { - QueueState::Queued { position } + let queue_state = match status { + language_model::CompletionRequestStatus::Queued { + position, + } => Some(QueueState::Queued { position }), + language_model::CompletionRequestStatus::Started => { + Some(QueueState::Started) } - language_model::QueueState::Started => QueueState::Started, + language_model::CompletionRequestStatus::ToolUseLimitReached => { + thread.tool_use_limit_reached = true; + None + } + }; + + if let Some(queue_state) = queue_state { + completion.queue_state = queue_state; } } } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 1146bbc137..7ff3fc86ec 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -66,15 +66,16 @@ pub struct LanguageModelCacheConfiguration { #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] #[serde(tag = "status", rename_all = "snake_case")] -pub enum QueueState { +pub enum CompletionRequestStatus { Queued { position: usize }, Started, + ToolUseLimitReached, } /// A completion event from a language model. #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub enum LanguageModelCompletionEvent { - QueueUpdate(QueueState), + QueueUpdate(CompletionRequestStatus), Stop(StopReason), Text(String), Thinking { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 556be2c75d..68c494234e 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -9,11 +9,12 @@ use futures::{ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ - AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, - LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat, - ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID, + AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel, + LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, + LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, + LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, + ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, @@ -38,6 +39,7 @@ use zed_llm_client::{ CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, + TOOL_USE_LIMIT_REACHED_HEADER_NAME, }; use crate::AllLanguageModelSettings; @@ -511,6 +513,13 @@ pub struct CloudLanguageModel { request_limiter: RateLimiter, } +struct PerformLlmCompletionResponse { + response: Response, + usage: Option, + tool_use_limit_reached: bool, + includes_queue_events: bool, +} + impl CloudLanguageModel { const MAX_RETRIES: usize = 3; @@ -518,7 +527,7 @@ impl CloudLanguageModel { client: Arc, llm_api_token: LlmApiToken, body: CompletionBody, - ) -> Result<(Response, Option, bool)> { + ) -> Result { let http_client = &client.http_client(); let mut token = llm_api_token.acquire(&client).await?; @@ -545,9 +554,18 @@ impl CloudLanguageModel { .headers() .get("x-zed-server-supports-queueing") .is_some(); + let tool_use_limit_reached = response + .headers() + .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME) + .is_some(); let usage = RequestUsage::from_headers(response.headers()).ok(); - return Ok((response, usage, includes_queue_events)); + return Ok(PerformLlmCompletionResponse { + response, + usage, + includes_queue_events, + tool_use_limit_reached, + }); } else if response .headers() .get(EXPIRED_LLM_TOKEN_HEADER_NAME) @@ -787,7 +805,12 @@ impl LanguageModel for CloudLanguageModel { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream_with_usage(async move { - let (response, usage, includes_queue_events) = Self::perform_llm_completion( + let PerformLlmCompletionResponse { + response, + usage, + includes_queue_events, + tool_use_limit_reached, + } = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -819,7 +842,10 @@ impl LanguageModel for CloudLanguageModel { let mut mapper = AnthropicEventMapper::new(); Ok(( map_cloud_completion_events( - Box::pin(response_lines(response, includes_queue_events)), + Box::pin( + response_lines(response, includes_queue_events) + .chain(tool_use_limit_reached_event(tool_use_limit_reached)), + ), move |event| mapper.map_event(event), ), usage, @@ -836,7 +862,12 @@ impl LanguageModel for CloudLanguageModel { let request = into_open_ai(request, model, model.max_output_tokens()); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream_with_usage(async move { - let (response, usage, includes_queue_events) = Self::perform_llm_completion( + let PerformLlmCompletionResponse { + response, + usage, + includes_queue_events, + tool_use_limit_reached, + } = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -853,7 +884,10 @@ impl LanguageModel for CloudLanguageModel { let mut mapper = OpenAiEventMapper::new(); Ok(( map_cloud_completion_events( - Box::pin(response_lines(response, includes_queue_events)), + Box::pin( + response_lines(response, includes_queue_events) + .chain(tool_use_limit_reached_event(tool_use_limit_reached)), + ), move |event| mapper.map_event(event), ), usage, @@ -870,7 +904,12 @@ impl LanguageModel for CloudLanguageModel { let request = into_google(request, model.id().into()); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream_with_usage(async move { - let (response, usage, includes_queue_events) = Self::perform_llm_completion( + let PerformLlmCompletionResponse { + response, + usage, + includes_queue_events, + tool_use_limit_reached, + } = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -883,10 +922,14 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; + let mut mapper = GoogleEventMapper::new(); Ok(( map_cloud_completion_events( - Box::pin(response_lines(response, includes_queue_events)), + Box::pin( + response_lines(response, includes_queue_events) + .chain(tool_use_limit_reached_event(tool_use_limit_reached)), + ), move |event| mapper.map_event(event), ), usage, @@ -905,7 +948,7 @@ impl LanguageModel for CloudLanguageModel { #[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum CloudCompletionEvent { - Queue(QueueState), + System(CompletionRequestStatus), Event(T), } @@ -925,7 +968,7 @@ where Err(error) => { vec![Err(LanguageModelCompletionError::Other(error))] } - Ok(CloudCompletionEvent::Queue(event)) => { + Ok(CloudCompletionEvent::System(event)) => { vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))] } Ok(CloudCompletionEvent::Event(event)) => map_callback(event), @@ -934,6 +977,16 @@ where .boxed() } +fn tool_use_limit_reached_event( + tool_use_limit_reached: bool, +) -> impl Stream>> { + futures::stream::iter(tool_use_limit_reached.then(|| { + Ok(CloudCompletionEvent::System( + CompletionRequestStatus::ToolUseLimitReached, + )) + })) +} + fn response_lines( response: Response, includes_queue_events: bool,