agent: Show a notice when reaching consecutive tool use limits (#29833)

This PR adds a notice that is shown when the consecutive tool use limit is
reached while using a model in normal mode.

Here's an example with the limit artificially lowered to 2 consecutive
tool uses:


https://github.com/user-attachments/assets/32da8d38-67de-4d6b-8f24-754d2518e5d4

Release Notes:

- agent: Added a notice that is shown when the consecutive tool use limit is
reached while using a model in normal mode.
Marshall Bowers 2025-05-02 22:09:54 -04:00 committed by GitHub
parent 10a7f2a972
commit f0515d1c34

6 changed files with 134 additions and 25 deletions


@@ -9,11 +9,12 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
+    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+    ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -38,6 +39,7 @@ use zed_llm_client::{
     CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
     EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
     MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };

 use crate::AllLanguageModelSettings;
@@ -511,6 +513,13 @@ pub struct CloudLanguageModel {
     request_limiter: RateLimiter,
 }

+struct PerformLlmCompletionResponse {
+    response: Response<AsyncBody>,
+    usage: Option<RequestUsage>,
+    tool_use_limit_reached: bool,
+    includes_queue_events: bool,
+}
+
 impl CloudLanguageModel {
     const MAX_RETRIES: usize = 3;
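The hunks above and below replace a positional (Response, Option<RequestUsage>, bool) return value with the named PerformLlmCompletionResponse struct, so call sites destructure fields by name rather than by position. A minimal standalone sketch of that pattern, with stand-in types rather than the actual Zed ones:

#[allow(dead_code)]
struct CompletionParts {
    status: u16,        // stand-in for Response<AsyncBody>
    usage: Option<u32>, // stand-in for RequestUsage
    tool_use_limit_reached: bool,
    includes_queue_events: bool,
}

fn perform_completion() -> CompletionParts {
    CompletionParts {
        status: 200,
        usage: Some(42),
        tool_use_limit_reached: false,
        includes_queue_events: true,
    }
}

fn main() {
    // Call sites bind exactly the fields they need, by name; adding a
    // field later cannot silently shift positions the way growing a
    // (response, usage, bool) tuple would.
    let CompletionParts {
        tool_use_limit_reached,
        includes_queue_events,
        ..
    } = perform_completion();
    println!("limit: {tool_use_limit_reached}, queueing: {includes_queue_events}");
}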
@@ -518,7 +527,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
+    ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();

         let mut token = llm_api_token.acquire(&client).await?;
@@ -545,9 +554,18 @@
                     .headers()
                     .get("x-zed-server-supports-queueing")
                     .is_some();
+                let tool_use_limit_reached = response
+                    .headers()
+                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
+                    .is_some();

                 let usage = RequestUsage::from_headers(response.headers()).ok();
-                return Ok((response, usage, includes_queue_events));
+                return Ok(PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                });
             } else if response
                 .headers()
                 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
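The new tool_use_limit_reached flag mirrors the existing queueing check: the server signals the condition purely by the presence of a response header, and the header's value is never inspected. A small sketch of that presence-only check using the http crate; the literal header string below is an assumed placeholder, since the real one is behind zed_llm_client's TOOL_USE_LIMIT_REACHED_HEADER_NAME constant:

use http::{HeaderMap, HeaderValue};

fn main() {
    let mut headers = HeaderMap::new();
    // Hypothetical header name; the real string is whatever
    // TOOL_USE_LIMIT_REACHED_HEADER_NAME resolves to.
    headers.insert(
        "x-zed-tool-use-limit-reached",
        HeaderValue::from_static("true"),
    );

    // Presence-only: only .is_some() matters, the value is ignored.
    let tool_use_limit_reached = headers.get("x-zed-tool-use-limit-reached").is_some();
    assert!(tool_use_limit_reached);
}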
@@ -787,7 +805,12 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -819,7 +842,10 @@
                     let mut mapper = AnthropicEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
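Each provider branch gets the same treatment: the response line stream is chained with a zero-or-one-element trailing stream, so when the flag is set the limit notice is always delivered as the final event, after the response stream is exhausted. A runnable sketch of that ordering guarantee, with plain strings standing in for completion events:

use futures::{stream, StreamExt};

fn main() {
    futures::executor::block_on(async {
        // Stand-in for response_lines(response, includes_queue_events).
        let response_events = stream::iter(vec!["event-1", "event-2"]);
        // Stand-in for tool_use_limit_reached_event(true).
        let trailing = stream::iter(Some("tool-use-limit-reached"));

        // chain() yields everything from the first stream before polling
        // the second, so the notice always arrives last.
        let all: Vec<_> = response_events.chain(trailing).collect().await;
        assert_eq!(all, ["event-1", "event-2", "tool-use-limit-reached"]);
    });
}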
@@ -836,7 +862,12 @@
                 let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -853,7 +884,10 @@
                     let mut mapper = OpenAiEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -870,7 +904,12 @@
                 let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream_with_usage(async move {
-                    let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                    let PerformLlmCompletionResponse {
+                        response,
+                        usage,
+                        includes_queue_events,
+                        tool_use_limit_reached,
+                    } = Self::perform_llm_completion(
                         client.clone(),
                         llm_api_token,
                         CompletionBody {
@@ -883,10 +922,14 @@
                         },
                     )
                     .await?;
+
                     let mut mapper = GoogleEventMapper::new();
                     Ok((
                         map_cloud_completion_events(
-                            Box::pin(response_lines(response, includes_queue_events)),
+                            Box::pin(
+                                response_lines(response, includes_queue_events)
+                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                            ),
                             move |event| mapper.map_event(event),
                         ),
                         usage,
@@ -905,7 +948,7 @@
 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    Queue(QueueState),
+    System(CompletionRequestStatus),
     Event(T),
 }
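The variant rename changes the wire format: with serde's default external tagging plus rename_all = "snake_case", the variant name becomes the JSON key. A sketch of that behavior (requires serde with the derive feature and serde_json; a String stands in for CompletionRequestStatus):

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    // Stand-in payload; the real variant carries CompletionRequestStatus.
    System(String),
    Event(T),
}

fn main() {
    let event: CloudCompletionEvent<u32> =
        CloudCompletionEvent::System("tool_use_limit_reached".into());
    // Prints {"system":"tool_use_limit_reached"}; under the old
    // Queue(QueueState) variant, the same event was keyed as "queue".
    println!("{}", serde_json::to_string(&event).unwrap());
}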
@@ -925,7 +968,7 @@
             Err(error) => {
                 vec![Err(LanguageModelCompletionError::Other(error))]
             }
-            Ok(CloudCompletionEvent::Queue(event)) => {
+            Ok(CloudCompletionEvent::System(event)) => {
                 vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
             }
             Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
@@ -934,6 +977,16 @@
         .boxed()
 }

+fn tool_use_limit_reached_event<T>(
+    tool_use_limit_reached: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(tool_use_limit_reached.then(|| {
+        Ok(CloudCompletionEvent::System(
+            CompletionRequestStatus::ToolUseLimitReached,
+        ))
+    }))
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
     includes_queue_events: bool,
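tool_use_limit_reached_event leans on two idioms: bool::then turns the flag into an Option, and futures::stream::iter over that Option yields a stream of zero or one items. A standalone sketch, with strings standing in for CloudCompletionEvent:

use futures::{stream, Stream, StreamExt};

// Zero items when the flag is false, exactly one when it is true.
fn limit_event(reached: bool) -> impl Stream<Item = &'static str> {
    // bool::then yields Some(..) or None; Option is IntoIterator, so
    // stream::iter turns it into a 0-or-1 element stream.
    stream::iter(reached.then(|| "tool-use-limit-reached"))
}

fn main() {
    futures::executor::block_on(async {
        let hit: Vec<_> = limit_event(true).collect().await;
        let miss: Vec<_> = limit_event(false).collect().await;
        assert_eq!(hit, ["tool-use-limit-reached"]);
        assert!(miss.is_empty());
    });
}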