agent: Show a notice when reaching consecutive tool use limits (#29833)

This PR adds a notice that is shown when the consecutive tool use limit is
reached while using a model in normal mode.

Here's an example with the limit artificially lowered to 2 consecutive
tool uses:


https://github.com/user-attachments/assets/32da8d38-67de-4d6b-8f24-754d2518e5d4

Release Notes:

- agent: Added a notice that is shown when the consecutive tool use limit is
reached while using a model in normal mode.
Marshall Bowers 2025-05-02 22:09:54 -04:00 committed by GitHub
parent 10a7f2a972
commit f0515d1c34
6 changed files with 134 additions and 25 deletions

Cargo.lock (generated)

@@ -18826,9 +18826,9 @@ dependencies = [
 [[package]]
 name = "zed_llm_client"
-version = "0.7.1"
+version = "0.7.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
+checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml

@@ -611,7 +611,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.1"
+zed_llm_client = "0.7.2"
 zstd = "0.11"

 [workspace.dependencies.async-stripe]

@@ -1957,6 +1957,41 @@ impl AssistantPanel {
         Some(UsageBanner::new(plan, usage).into_any_element())
     }

+    fn render_tool_use_limit_reached(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+        let tool_use_limit_reached = self
+            .thread
+            .read(cx)
+            .thread()
+            .read(cx)
+            .tool_use_limit_reached();
+
+        if !tool_use_limit_reached {
+            return None;
+        }
+
+        let model = self
+            .thread
+            .read(cx)
+            .thread()
+            .read(cx)
+            .configured_model()?
+            .model;
+
+        let max_mode_upsell = if model.supports_max_mode() {
+            " Enable max mode for unlimited tool use."
+        } else {
+            ""
+        };
+
+        Some(
+            Banner::new()
+                .severity(ui::Severity::Info)
+                .children(h_flex().child(Label::new(format!(
+                    "Consecutive tool use limit reached.{max_mode_upsell}"
+                ))))
+                .into_any_element(),
+        )
+    }
+
     fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
         let last_error = self.thread.read(cx).last_error()?;

@@ -2238,6 +2273,7 @@ impl Render for AssistantPanel {
             .map(|parent| match &self.active_view {
                 ActiveView::Thread { .. } => parent
                     .child(self.render_active_thread_or_empty_state(window, cx))
+                    .children(self.render_tool_use_limit_reached(cx))
                     .children(self.render_usage_banner(cx))
                     .child(h_flex().child(self.message_editor.clone()))
                     .children(self.render_last_error(cx)),
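
In the panel diff above, the banner is attached with `.children(self.render_tool_use_limit_reached(cx))` even though that method returns an `Option<AnyElement>`. This works because `Option<T>` implements `IntoIterator`, so `None` simply contributes no child. A std-only sketch of the idiom; the `Container` type and its `children` method here are stand-ins, not gpui's real API:

```rust
// Std-only sketch: a builder whose `children` accepts anything iterable,
// so passing an Option contributes either one child or none.
struct Container {
    children: Vec<String>,
}

impl Container {
    fn children(mut self, items: impl IntoIterator<Item = String>) -> Self {
        self.children.extend(items);
        self
    }
}

fn main() {
    let banner: Option<String> = None; // e.g. the tool use limit was not reached
    let container = Container { children: Vec::new() }
        .children(Some("thread".to_string()))
        .children(banner); // `None` adds nothing
    assert_eq!(container.children, vec!["thread".to_string()]);
}
```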

@@ -355,6 +355,7 @@ pub struct Thread {
     request_token_usage: Vec<TokenUsage>,
     cumulative_token_usage: TokenUsage,
     exceeded_window_error: Option<ExceededWindowError>,
+    tool_use_limit_reached: bool,
     feedback: Option<ThreadFeedback>,
     message_feedback: HashMap<MessageId, ThreadFeedback>,
     last_auto_capture_at: Option<Instant>,
@@ -417,6 +418,7 @@ impl Thread {
             request_token_usage: Vec::new(),
             cumulative_token_usage: TokenUsage::default(),
             exceeded_window_error: None,
+            tool_use_limit_reached: false,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -524,6 +526,7 @@ impl Thread {
             request_token_usage: serialized.request_token_usage,
             cumulative_token_usage: serialized.cumulative_token_usage,
             exceeded_window_error: None,
+            tool_use_limit_reached: false,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
@@ -814,6 +817,10 @@ impl Thread {
             .unwrap_or(false)
     }

+    pub fn tool_use_limit_reached(&self) -> bool {
+        self.tool_use_limit_reached
+    }
+
     /// Returns whether all of the tool uses have finished running.
     pub fn all_tools_finished(&self) -> bool {
         // If the only pending tool uses left are the ones with errors, then
@@ -1331,6 +1338,8 @@ impl Thread {
         window: Option<AnyWindowHandle>,
         cx: &mut Context<Self>,
     ) {
+        self.tool_use_limit_reached = false;
+
         let pending_completion_id = post_inc(&mut self.completion_count);

         let mut request_callback_parameters = if self.request_callback.is_some() {
             Some((request.clone(), Vec::new()))
@@ -1506,17 +1515,27 @@ impl Thread {
                                 });
                             }
                         }
-                        LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+                        LanguageModelCompletionEvent::QueueUpdate(status) => {
                             if let Some(completion) = thread
                                 .pending_completions
                                 .iter_mut()
                                 .find(|completion| completion.id == pending_completion_id)
                             {
-                                completion.queue_state = match queue_event {
-                                    language_model::QueueState::Queued { position } => {
-                                        QueueState::Queued { position }
-                                    }
-                                    language_model::QueueState::Started => QueueState::Started,
-                                };
+                                let queue_state = match status {
+                                    language_model::CompletionRequestStatus::Queued {
+                                        position,
+                                    } => Some(QueueState::Queued { position }),
+                                    language_model::CompletionRequestStatus::Started => {
+                                        Some(QueueState::Started)
+                                    }
+                                    language_model::CompletionRequestStatus::ToolUseLimitReached => {
+                                        thread.tool_use_limit_reached = true;
+                                        None
+                                    }
+                                };
+
+                                if let Some(queue_state) = queue_state {
+                                    completion.queue_state = queue_state;
+                                }
                             }
                         }
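
Taken together, the `Thread` changes give the new flag a simple per-request lifecycle: cleared when a completion stream starts, set when a `ToolUseLimitReached` status arrives, and read by the panel afterward. A self-contained sketch of that lifecycle, using simplified stand-in types rather than the real `Thread`:

```rust
// Simplified stand-ins; not the real zed `Thread` or `language_model` types.
#[derive(Debug, PartialEq)]
enum CompletionRequestStatus {
    #[allow(dead_code)]
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

#[derive(Default)]
struct ThreadSketch {
    tool_use_limit_reached: bool,
}

impl ThreadSketch {
    // Mirrors stream_completion(): every new request clears the flag first.
    fn begin_request(&mut self) {
        self.tool_use_limit_reached = false;
    }

    // Mirrors the QueueUpdate handling: only the limit status flips the flag;
    // Queued/Started update the queue state instead.
    fn handle_status(&mut self, status: CompletionRequestStatus) {
        if status == CompletionRequestStatus::ToolUseLimitReached {
            self.tool_use_limit_reached = true;
        }
    }

    // Mirrors Thread::tool_use_limit_reached(), which the panel reads.
    fn tool_use_limit_reached(&self) -> bool {
        self.tool_use_limit_reached
    }
}

fn main() {
    let mut thread = ThreadSketch::default();
    thread.begin_request();
    thread.handle_status(CompletionRequestStatus::Started);
    assert!(!thread.tool_use_limit_reached());
    thread.handle_status(CompletionRequestStatus::ToolUseLimitReached);
    assert!(thread.tool_use_limit_reached());
}
```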

@@ -66,15 +66,16 @@ pub struct LanguageModelCacheConfiguration {

 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(tag = "status", rename_all = "snake_case")]
-pub enum QueueState {
+pub enum CompletionRequestStatus {
     Queued { position: usize },
     Started,
+    ToolUseLimitReached,
 }

 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    QueueUpdate(QueueState),
+    QueueUpdate(CompletionRequestStatus),
     Stop(StopReason),
     Text(String),
     Thinking {
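
Since `CompletionRequestStatus` keeps the `#[serde(tag = "status", rename_all = "snake_case")]` attributes that `QueueState` already had, the new variant appears on the wire as just another tagged status object. A quick round-trip sketch, assuming `serde` and `serde_json` are available; the enum below is a local copy, not an import from `language_model`:

```rust
use serde::{Deserialize, Serialize};

// Local copy of the enum, mirroring the attributes in language_model.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum CompletionRequestStatus {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

fn main() -> serde_json::Result<()> {
    // Internally tagged: the variant name becomes the "status" field.
    assert_eq!(
        serde_json::to_string(&CompletionRequestStatus::Queued { position: 3 })?,
        r#"{"status":"queued","position":3}"#
    );
    assert_eq!(
        serde_json::to_string(&CompletionRequestStatus::ToolUseLimitReached)?,
        r#"{"status":"tool_use_limit_reached"}"#
    );
    // And the tagged form round-trips back into the enum.
    let parsed: CompletionRequestStatus = serde_json::from_str(r#"{"status":"started"}"#)?;
    assert_eq!(parsed, CompletionRequestStatus::Started);
    Ok(())
}
```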

@@ -9,11 +9,12 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
+    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+    ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -38,6 +39,7 @@ use zed_llm_client::{
     CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
     EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
     MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };

 use crate::AllLanguageModelSettings;
@@ -511,6 +513,13 @@ pub struct CloudLanguageModel {
     request_limiter: RateLimiter,
 }

+struct PerformLlmCompletionResponse {
+    response: Response<AsyncBody>,
+    usage: Option<RequestUsage>,
+    tool_use_limit_reached: bool,
+    includes_queue_events: bool,
+}
+
 impl CloudLanguageModel {
     const MAX_RETRIES: usize = 3;
@@ -518,7 +527,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
+    ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();

         let mut token = llm_api_token.acquire(&client).await?;
@@ -545,9 +554,18 @@ impl CloudLanguageModel {
                     .headers()
                     .get("x-zed-server-supports-queueing")
                     .is_some();
+                let tool_use_limit_reached = response
+                    .headers()
+                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
+                    .is_some();
                 let usage = RequestUsage::from_headers(response.headers()).ok();

-                return Ok((response, usage, includes_queue_events));
+                return Ok(PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                });
             } else if response
                 .headers()
                 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
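
Note that `perform_llm_completion` treats the new header as a presence-only signal: its value is never read. A hypothetical sketch of that check using the `http` crate; the literal header name below is only a placeholder, since the real constant comes from `zed_llm_client::TOOL_USE_LIMIT_REACHED_HEADER_NAME`:

```rust
use http::HeaderMap;

// Placeholder name: the real constant lives in zed_llm_client as
// TOOL_USE_LIMIT_REACHED_HEADER_NAME, and its actual value is not shown here.
const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-tool-use-limit-reached";

fn tool_use_limit_reached(headers: &HeaderMap) -> bool {
    // The header's value is ignored; its presence alone is the signal.
    headers.get(TOOL_USE_LIMIT_REACHED_HEADER_NAME).is_some()
}

fn main() {
    let mut headers = HeaderMap::new();
    assert!(!tool_use_limit_reached(&headers));

    headers.insert(TOOL_USE_LIMIT_REACHED_HEADER_NAME, "true".parse().unwrap());
    assert!(tool_use_limit_reached(&headers));
}
```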
@@ -787,7 +805,12 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -819,7 +842,10 @@ impl LanguageModel for CloudLanguageModel {
                     let mut mapper = AnthropicEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -836,7 +862,12 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -853,7 +884,10 @@ impl LanguageModel for CloudLanguageModel {
                     let mut mapper = OpenAiEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -870,7 +904,12 @@ impl LanguageModel for CloudLanguageModel {
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -883,10 +922,14 @@ impl LanguageModel for CloudLanguageModel {
                     },
                 )
                 .await?;
+
                 let mut mapper = GoogleEventMapper::new();
                 Ok((
                     map_cloud_completion_events(
-                        Box::pin(response_lines(response, includes_queue_events)),
+                        Box::pin(
+                            response_lines(response, includes_queue_events)
+                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                        ),
                         move |event| mapper.map_event(event),
                     ),
                     usage,
@@ -905,7 +948,7 @@ impl LanguageModel for CloudLanguageModel {

 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    Queue(QueueState),
+    System(CompletionRequestStatus),
     Event(T),
 }
@@ -925,7 +968,7 @@ where
             Err(error) => {
                 vec![Err(LanguageModelCompletionError::Other(error))]
             }
-            Ok(CloudCompletionEvent::Queue(event)) => {
+            Ok(CloudCompletionEvent::System(event)) => {
                 vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
             }
             Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
@@ -934,6 +977,16 @@
     .boxed()
 }

+fn tool_use_limit_reached_event<T>(
+    tool_use_limit_reached: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(tool_use_limit_reached.then(|| {
+        Ok(CloudCompletionEvent::System(
+            CompletionRequestStatus::ToolUseLimitReached,
+        ))
+    }))
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
     includes_queue_events: bool,
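
`tool_use_limit_reached_event` turns the header flag into a stream of zero or one items, and each provider branch `.chain`s it onto the SSE line stream, so the synthetic `ToolUseLimitReached` status is delivered after all real completion events. A small sketch of that pattern with the `futures` crate, using plain integers in place of completion events:

```rust
use futures::{stream, Stream, StreamExt};

// Zero-or-one-item stream, like tool_use_limit_reached_event.
fn trailing_marker(limit_reached: bool) -> impl Stream<Item = i32> {
    stream::iter(limit_reached.then(|| -1))
}

fn main() {
    futures::executor::block_on(async {
        // The marker (if any) comes after every real event.
        let with_marker: Vec<i32> = stream::iter(vec![1, 2, 3])
            .chain(trailing_marker(true))
            .collect()
            .await;
        assert_eq!(with_marker, vec![1, 2, 3, -1]);

        // With the flag unset, chaining adds nothing.
        let without_marker: Vec<i32> = stream::iter(vec![1, 2])
            .chain(trailing_marker(false))
            .collect()
            .await;
        assert_eq!(without_marker, vec![1, 2]);
    });
}
```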