agent: Show a notice when reaching consecutive tool use limits (#29833)
This PR adds a notice when the consecutive tool use limit is reached while using a model in normal mode. Here's an example with the limit artificially lowered to 2 consecutive tool uses: https://github.com/user-attachments/assets/32da8d38-67de-4d6b-8f24-754d2518e5d4

Release Notes:

- agent: Added a notice when reaching consecutive tool use limits when using a model in normal mode.
This commit is contained in:
parent
10a7f2a972
commit
f0515d1c34
6 changed files with 134 additions and 25 deletions
Cargo.lock (generated): 4 changes
```diff
@@ -18826,9 +18826,9 @@ dependencies = [
 [[package]]
 name = "zed_llm_client"
-version = "0.7.1"
+version = "0.7.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
+checksum = "226e0b479b3aed072d83db276866d54bce631e3a8600fcdf4f309d73389af9c7"
 dependencies = [
  "anyhow",
  "serde",
```
```diff
@@ -611,7 +611,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.1"
+zed_llm_client = "0.7.2"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]
```
```diff
@@ -1957,6 +1957,41 @@ impl AssistantPanel {
         Some(UsageBanner::new(plan, usage).into_any_element())
     }
 
+    fn render_tool_use_limit_reached(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+        let tool_use_limit_reached = self
+            .thread
+            .read(cx)
+            .thread()
+            .read(cx)
+            .tool_use_limit_reached();
+        if !tool_use_limit_reached {
+            return None;
+        }
+
+        let model = self
+            .thread
+            .read(cx)
+            .thread()
+            .read(cx)
+            .configured_model()?
+            .model;
+
+        let max_mode_upsell = if model.supports_max_mode() {
+            " Enable max mode for unlimited tool use."
+        } else {
+            ""
+        };
+
+        Some(
+            Banner::new()
+                .severity(ui::Severity::Info)
+                .children(h_flex().child(Label::new(format!(
+                    "Consecutive tool use limit reached.{max_mode_upsell}"
+                ))))
+                .into_any_element(),
+        )
+    }
+
     fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
         let last_error = self.thread.read(cx).last_error()?;
```
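The helper above returns `Option<AnyElement>` and chooses its banner text based on whether the model supports max mode. A minimal standalone sketch of that text selection, with the capability passed in as a plain bool since the real check (`model.supports_max_mode()`) lives on the model:

```rust
// Sketch only: stubs the model capability as a bool so the snippet runs standalone.
fn banner_text(supports_max_mode: bool) -> String {
    let max_mode_upsell = if supports_max_mode {
        " Enable max mode for unlimited tool use."
    } else {
        ""
    };
    format!("Consecutive tool use limit reached.{max_mode_upsell}")
}

fn main() {
    assert_eq!(banner_text(false), "Consecutive tool use limit reached.");
    assert_eq!(
        banner_text(true),
        "Consecutive tool use limit reached. Enable max mode for unlimited tool use."
    );
}
```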
```diff
@@ -2238,6 +2273,7 @@ impl Render for AssistantPanel {
             .map(|parent| match &self.active_view {
                 ActiveView::Thread { .. } => parent
                     .child(self.render_active_thread_or_empty_state(window, cx))
+                    .children(self.render_tool_use_limit_reached(cx))
                     .children(self.render_usage_banner(cx))
                     .child(h_flex().child(self.message_editor.clone()))
                     .children(self.render_last_error(cx)),
```
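The new banner slots in via `.children(...)` rather than `.child(...)` because the helper returns an `Option`. If gpui's `ParentElement` works the way these call sites suggest, `children` accepts any `IntoIterator` of elements, and `Option` iterates zero or one items, so a `None` simply renders nothing. A generic sketch of that mechanism, using plain strings in place of gpui elements:

```rust
// Option<T> as a zero-or-one IntoIterator: None contributes no children.
fn children<I: IntoIterator<Item = &'static str>>(items: I) -> Vec<&'static str> {
    items.into_iter().collect()
}

fn main() {
    assert_eq!(children(Some("limit banner")), vec!["limit banner"]);
    assert!(children(None).is_empty());
}
```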
```diff
@@ -355,6 +355,7 @@ pub struct Thread {
     request_token_usage: Vec<TokenUsage>,
     cumulative_token_usage: TokenUsage,
     exceeded_window_error: Option<ExceededWindowError>,
+    tool_use_limit_reached: bool,
     feedback: Option<ThreadFeedback>,
     message_feedback: HashMap<MessageId, ThreadFeedback>,
     last_auto_capture_at: Option<Instant>,
```
```diff
@@ -417,6 +418,7 @@ impl Thread {
             request_token_usage: Vec::new(),
             cumulative_token_usage: TokenUsage::default(),
             exceeded_window_error: None,
+            tool_use_limit_reached: false,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
```
```diff
@@ -524,6 +526,7 @@ impl Thread {
             request_token_usage: serialized.request_token_usage,
             cumulative_token_usage: serialized.cumulative_token_usage,
             exceeded_window_error: None,
+            tool_use_limit_reached: false,
             feedback: None,
             message_feedback: HashMap::default(),
             last_auto_capture_at: None,
```
```diff
@@ -814,6 +817,10 @@ impl Thread {
             .unwrap_or(false)
     }
 
+    pub fn tool_use_limit_reached(&self) -> bool {
+        self.tool_use_limit_reached
+    }
+
     /// Returns whether all of the tool uses have finished running.
     pub fn all_tools_finished(&self) -> bool {
         // If the only pending tool uses left are the ones with errors, then
```
```diff
@@ -1331,6 +1338,8 @@ impl Thread {
         window: Option<AnyWindowHandle>,
         cx: &mut Context<Self>,
     ) {
+        self.tool_use_limit_reached = false;
+
         let pending_completion_id = post_inc(&mut self.completion_count);
         let mut request_callback_parameters = if self.request_callback.is_some() {
            Some((request.clone(), Vec::new()))
```
```diff
@@ -1506,17 +1515,27 @@ impl Thread {
                                 });
                             }
                         }
-                        LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+                        LanguageModelCompletionEvent::QueueUpdate(status) => {
                             if let Some(completion) = thread
                                 .pending_completions
                                 .iter_mut()
                                 .find(|completion| completion.id == pending_completion_id)
                             {
-                                completion.queue_state = match queue_event {
-                                    language_model::QueueState::Queued { position } => {
-                                        QueueState::Queued { position }
+                                let queue_state = match status {
+                                    language_model::CompletionRequestStatus::Queued {
+                                        position,
+                                    } => Some(QueueState::Queued { position }),
+                                    language_model::CompletionRequestStatus::Started => {
+                                        Some(QueueState::Started)
                                     }
-                                    language_model::QueueState::Started => QueueState::Started,
+                                    language_model::CompletionRequestStatus::ToolUseLimitReached => {
+                                        thread.tool_use_limit_reached = true;
+                                        None
+                                    }
                                 };
+
+                                if let Some(queue_state) = queue_state {
+                                    completion.queue_state = queue_state;
+                                }
                             }
                         }
```
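The reworked match maps only queue-related statuses onto a new `QueueState`; `ToolUseLimitReached` instead flips the thread's flag and yields `None`, leaving `completion.queue_state` untouched. A self-contained sketch of that shape, using local copies of the two enums rather than the real `language_model` types:

```rust
#[derive(Debug, PartialEq)]
enum CompletionRequestStatus {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

#[derive(Debug, PartialEq)]
enum QueueState {
    Queued { position: usize },
    Started,
}

// Queue-progress statuses map to Some(state); the limit notification only
// sets the flag and leaves the queue state alone.
fn apply(status: CompletionRequestStatus, limit_reached: &mut bool) -> Option<QueueState> {
    match status {
        CompletionRequestStatus::Queued { position } => Some(QueueState::Queued { position }),
        CompletionRequestStatus::Started => Some(QueueState::Started),
        CompletionRequestStatus::ToolUseLimitReached => {
            *limit_reached = true;
            None
        }
    }
}

fn main() {
    let mut limit_reached = false;
    assert_eq!(
        apply(CompletionRequestStatus::Queued { position: 4 }, &mut limit_reached),
        Some(QueueState::Queued { position: 4 })
    );
    assert!(!limit_reached);
    assert_eq!(apply(CompletionRequestStatus::ToolUseLimitReached, &mut limit_reached), None);
    assert!(limit_reached);
}
```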
```diff
@@ -66,15 +66,16 @@ pub struct LanguageModelCacheConfiguration {
 
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 #[serde(tag = "status", rename_all = "snake_case")]
-pub enum QueueState {
+pub enum CompletionRequestStatus {
     Queued { position: usize },
     Started,
+    ToolUseLimitReached,
 }
 
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    QueueUpdate(QueueState),
+    QueueUpdate(CompletionRequestStatus),
     Stop(StopReason),
     Text(String),
     Thinking {
```
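Given the `#[serde(tag = "status", rename_all = "snake_case")]` attribute, the new variant should serialize as `{"status": "tool_use_limit_reached"}` on the wire. A quick check of that representation (assumes `serde` with the `derive` feature and `serde_json` are available):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

fn main() {
    // Internally tagged: the variant name becomes the "status" field.
    assert_eq!(
        serde_json::to_string(&CompletionRequestStatus::Queued { position: 3 }).unwrap(),
        r#"{"status":"queued","position":3}"#
    );
    assert_eq!(
        serde_json::to_string(&CompletionRequestStatus::ToolUseLimitReached).unwrap(),
        r#"{"status":"tool_use_limit_reached"}"#
    );
}
```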
```diff
@@ -9,11 +9,12 @@ use futures::{
 use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
-    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
-    ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+    AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
+    LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+    ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
```
```diff
@@ -38,6 +39,7 @@ use zed_llm_client::{
     CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
     EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
     MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    TOOL_USE_LIMIT_REACHED_HEADER_NAME,
 };
 
 use crate::AllLanguageModelSettings;
```
```diff
@@ -511,6 +513,13 @@ pub struct CloudLanguageModel {
     request_limiter: RateLimiter,
 }
 
+struct PerformLlmCompletionResponse {
+    response: Response<AsyncBody>,
+    usage: Option<RequestUsage>,
+    tool_use_limit_reached: bool,
+    includes_queue_events: bool,
+}
+
 impl CloudLanguageModel {
     const MAX_RETRIES: usize = 3;
```
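Replacing the old `(Response, Option<RequestUsage>, bool)` tuple return with a named struct is what makes this change cheap: a fourth positional `bool` next to `includes_queue_events` would have been ambiguous, while struct fields stay self-describing at every call site. A toy illustration of the call-site ergonomics, with hypothetical field types:

```rust
struct PerformLlmCompletionResponse {
    usage: Option<u32>,
    tool_use_limit_reached: bool,
    includes_queue_events: bool,
}

fn perform() -> PerformLlmCompletionResponse {
    PerformLlmCompletionResponse {
        usage: None,
        tool_use_limit_reached: true,
        includes_queue_events: true,
    }
}

fn main() {
    // Field names, not positions, carry the meaning; adding a field can't
    // silently swap two bools the way extending a tuple could.
    let PerformLlmCompletionResponse { usage, tool_use_limit_reached, includes_queue_events } =
        perform();
    assert!(usage.is_none() && tool_use_limit_reached && includes_queue_events);
}
```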
```diff
@@ -518,7 +527,7 @@ impl CloudLanguageModel {
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
         body: CompletionBody,
-    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
+    ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
```
```diff
@@ -545,9 +554,18 @@ impl CloudLanguageModel {
                     .headers()
                     .get("x-zed-server-supports-queueing")
                     .is_some();
+                let tool_use_limit_reached = response
+                    .headers()
+                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
+                    .is_some();
                 let usage = RequestUsage::from_headers(response.headers()).ok();
 
-                return Ok((response, usage, includes_queue_events));
+                return Ok(PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                });
             } else if response
                 .headers()
                 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
```
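Like the queueing capability above it, the limit is signaled by mere header presence: `.get(...).is_some()` ignores the header's value entirely. A sketch using the `http` crate's `HeaderMap`; note the header name below is a made-up placeholder, since the real one is the `TOOL_USE_LIMIT_REACHED_HEADER_NAME` constant from `zed_llm_client`:

```rust
use http::header::{HeaderMap, HeaderValue};

// Placeholder name; the actual constant lives in zed_llm_client.
const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert(TOOL_USE_LIMIT_REACHED_HEADER_NAME, HeaderValue::from_static("true"));

    // Presence-only check: any value at all means "limit reached".
    let tool_use_limit_reached = headers.get(TOOL_USE_LIMIT_REACHED_HEADER_NAME).is_some();
    assert!(tool_use_limit_reached);
}
```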
```diff
@@ -787,7 +805,12 @@ impl LanguageModel for CloudLanguageModel {
             let client = self.client.clone();
             let llm_api_token = self.llm_api_token.clone();
             let future = self.request_limiter.stream_with_usage(async move {
-                let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                let PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                } = Self::perform_llm_completion(
                     client.clone(),
                     llm_api_token,
                     CompletionBody {
```
```diff
@@ -819,7 +842,10 @@ impl LanguageModel for CloudLanguageModel {
             let mut mapper = AnthropicEventMapper::new();
             Ok((
                 map_cloud_completion_events(
-                    Box::pin(response_lines(response, includes_queue_events)),
+                    Box::pin(
+                        response_lines(response, includes_queue_events)
+                            .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                    ),
                     move |event| mapper.map_event(event),
                 ),
                 usage,
```
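The `.chain(...)` ordering matters here: a chained stream is only polled after the first stream is exhausted, so the synthetic `ToolUseLimitReached` event always arrives after the provider's real SSE events rather than interleaving with them. A minimal sketch of that ordering:

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    let sse_events = stream::iter(vec!["text", "stop"]);
    let tail = stream::iter(true.then(|| "tool_use_limit_reached"));

    // chain() drains sse_events fully before yielding anything from tail.
    let all: Vec<_> = block_on(sse_events.chain(tail).collect::<Vec<_>>());
    assert_eq!(all, vec!["text", "stop", "tool_use_limit_reached"]);
}
```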
```diff
@@ -836,7 +862,12 @@ impl LanguageModel for CloudLanguageModel {
             let request = into_open_ai(request, model, model.max_output_tokens());
             let llm_api_token = self.llm_api_token.clone();
             let future = self.request_limiter.stream_with_usage(async move {
-                let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                let PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                } = Self::perform_llm_completion(
                     client.clone(),
                     llm_api_token,
                     CompletionBody {
```
```diff
@@ -853,7 +884,10 @@ impl LanguageModel for CloudLanguageModel {
             let mut mapper = OpenAiEventMapper::new();
             Ok((
                 map_cloud_completion_events(
-                    Box::pin(response_lines(response, includes_queue_events)),
+                    Box::pin(
+                        response_lines(response, includes_queue_events)
+                            .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                    ),
                     move |event| mapper.map_event(event),
                 ),
                 usage,
```
```diff
@@ -870,7 +904,12 @@ impl LanguageModel for CloudLanguageModel {
             let request = into_google(request, model.id().into());
             let llm_api_token = self.llm_api_token.clone();
             let future = self.request_limiter.stream_with_usage(async move {
-                let (response, usage, includes_queue_events) = Self::perform_llm_completion(
+                let PerformLlmCompletionResponse {
+                    response,
+                    usage,
+                    includes_queue_events,
+                    tool_use_limit_reached,
+                } = Self::perform_llm_completion(
                     client.clone(),
                     llm_api_token,
                     CompletionBody {
```
```diff
@@ -883,10 +922,14 @@ impl LanguageModel for CloudLanguageModel {
                 },
             )
             .await?;
 
             let mut mapper = GoogleEventMapper::new();
             Ok((
                 map_cloud_completion_events(
-                    Box::pin(response_lines(response, includes_queue_events)),
+                    Box::pin(
+                        response_lines(response, includes_queue_events)
+                            .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
+                    ),
                     move |event| mapper.map_event(event),
                 ),
                 usage,
```
```diff
@@ -905,7 +948,7 @@ impl LanguageModel for CloudLanguageModel {
 #[derive(Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub enum CloudCompletionEvent<T> {
-    Queue(QueueState),
+    System(CompletionRequestStatus),
     Event(T),
 }
```
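One wire-format consequence of the `Queue` to `System` rename worth noting: `CloudCompletionEvent` uses serde's default externally tagged representation, so the variant name itself is the JSON key, and these frames should now read `{"system": ...}` instead of `{"queue": ...}`. A sketch of what a serialized frame looks like given the derives above (again assuming `serde`/`serde_json`):

```rust
use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum CompletionRequestStatus {
    Started,
}

#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum CloudCompletionEvent<T> {
    System(CompletionRequestStatus),
    Event(T),
}

fn main() {
    // Externally tagged: the variant name wraps the payload.
    let frame: CloudCompletionEvent<u32> = CloudCompletionEvent::System(CompletionRequestStatus::Started);
    assert_eq!(
        serde_json::to_string(&frame).unwrap(),
        r#"{"system":{"status":"started"}}"#
    );

    let frame: CloudCompletionEvent<u32> = CloudCompletionEvent::Event(7);
    assert_eq!(serde_json::to_string(&frame).unwrap(), r#"{"event":7}"#);
}
```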
```diff
@@ -925,7 +968,7 @@ where
             Err(error) => {
                 vec![Err(LanguageModelCompletionError::Other(error))]
             }
-            Ok(CloudCompletionEvent::Queue(event)) => {
+            Ok(CloudCompletionEvent::System(event)) => {
                 vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
             }
             Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
```
```diff
@@ -934,6 +977,16 @@ where
         .boxed()
 }
 
+fn tool_use_limit_reached_event<T>(
+    tool_use_limit_reached: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+    futures::stream::iter(tool_use_limit_reached.then(|| {
+        Ok(CloudCompletionEvent::System(
+            CompletionRequestStatus::ToolUseLimitReached,
+        ))
+    }))
+}
+
 fn response_lines<T: DeserializeOwned>(
     response: Response<AsyncBody>,
     includes_queue_events: bool,
```
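`tool_use_limit_reached_event` builds a zero-or-one element stream: `bool::then` produces an `Option`, and `futures::stream::iter` accepts any `IntoIterator`, which `Option` implements. When the flag is false the stream is simply empty, so the `.chain` calls above become no-ops. A sketch of the trick:

```rust
use futures::{executor::block_on, stream, Stream, StreamExt};

fn maybe_event(flag: bool) -> impl Stream<Item = &'static str> {
    // false.then(..) is None (empty stream); true.then(..) is Some (one item).
    stream::iter(flag.then(|| "tool_use_limit_reached"))
}

fn main() {
    assert_eq!(block_on(maybe_event(true).collect::<Vec<_>>()).len(), 1);
    assert!(block_on(maybe_event(false).collect::<Vec<_>>()).is_empty());
}
```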