From 0ea0d466d289ff2c57bdac3ab4b20f61c9ab7494 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Tue, 19 Aug 2025 11:41:55 +0200 Subject: [PATCH] agent2: Port retry logic (#36421) Release Notes: - N/A --- crates/acp_thread/src/acp_thread.rs | 15 ++ crates/agent2/src/agent.rs | 5 + crates/agent2/src/tests/mod.rs | 166 +++++++++++- crates/agent2/src/thread.rs | 286 ++++++++++++++++++--- crates/agent_ui/src/acp/thread_view.rs | 61 ++++- crates/agent_ui/src/agent_diff.rs | 1 + crates/language_model/src/fake_provider.rs | 32 ++- 7 files changed, 514 insertions(+), 52 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 8bc0635475..916f48cbe0 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -24,6 +24,7 @@ use std::fmt::{Formatter, Write}; use std::ops::Range; use std::process::ExitStatus; use std::rc::Rc; +use std::time::{Duration, Instant}; use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; use util::ResultExt; @@ -658,6 +659,15 @@ impl PlanEntry { } } +#[derive(Debug, Clone)] +pub struct RetryStatus { + pub last_error: SharedString, + pub attempt: usize, + pub max_attempts: usize, + pub started_at: Instant, + pub duration: Duration, +} + pub struct AcpThread { title: SharedString, entries: Vec, @@ -676,6 +686,7 @@ pub enum AcpThreadEvent { EntryUpdated(usize), EntriesRemoved(Range), ToolAuthorizationRequired, + Retry(RetryStatus), Stopped, Error, ServerExited(ExitStatus), @@ -916,6 +927,10 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } + pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context) { + cx.emit(AcpThreadEvent::Retry(status)); + } + pub fn update_tool_call( &mut self, update: impl Into, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 6347f5f9a4..480b2baa95 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -546,6 +546,11 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } + AgentResponseEvent::Retry(status) => { + acp_thread.update(cx, |thread, cx| { + thread.update_retry_status(status, cx) + })?; + } AgentResponseEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 13b37fbaa2..c83479f2cf 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -6,15 +6,16 @@ use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; use fs::{FakeFs, Fs}; -use futures::channel::mpsc::UnboundedReceiver; +use futures::{StreamExt, channel::mpsc::UnboundedReceiver}; use gpui::{ App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, }; use indoc::indoc; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, - LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent, - Role, StopReason, fake_provider::FakeLanguageModel, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage, + LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason, + fake_provider::FakeLanguageModel, }; use pretty_assertions::assert_eq; use project::Project; @@ -24,7 +25,6 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; use settings::SettingsStore; -use smol::stream::StreamExt; use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; @@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_send_no_retry_on_success(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn); + thread.send(UserMessageId::new(), ["Hello!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.end_last_completion_stream(); + + let mut retry_events = Vec::new(); + while let Some(Ok(event)) = events.next().await { + match event { + AgentResponseEvent::Retry(retry_status) => { + retry_events.push(retry_status); + } + AgentResponseEvent::Stop(..) => break, + _ => {} + } + } + + assert_eq!(retry_events.len(), 0); + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello! + + ## Assistant + + Hey! + "} + ) + }); +} + +#[gpui::test] +async fn test_send_retry_on_error(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn); + thread.send(UserMessageId::new(), ["Hello!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { + provider: LanguageModelProviderName::new("Anthropic"), + retry_after: Some(Duration::from_secs(3)), + }); + fake_model.end_last_completion_stream(); + + cx.executor().advance_clock(Duration::from_secs(3)); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.end_last_completion_stream(); + + let mut retry_events = Vec::new(); + while let Some(Ok(event)) = events.next().await { + match event { + AgentResponseEvent::Retry(retry_status) => { + retry_events.push(retry_status); + } + AgentResponseEvent::Stop(..) => break, + _ => {} + } + } + + assert_eq!(retry_events.len(), 1); + assert!(matches!( + retry_events[0], + acp_thread::RetryStatus { attempt: 1, .. } + )); + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello! + + ## Assistant + + Hey! + "} + ) + }); +} + +#[gpui::test] +async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn); + thread.send(UserMessageId::new(), ["Hello!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 { + fake_model.send_last_completion_stream_error( + LanguageModelCompletionError::ServerOverloaded { + provider: LanguageModelProviderName::new("Anthropic"), + retry_after: Some(Duration::from_secs(3)), + }, + ); + fake_model.end_last_completion_stream(); + cx.executor().advance_clock(Duration::from_secs(3)); + cx.run_until_parked(); + } + + let mut errors = Vec::new(); + let mut retry_events = Vec::new(); + while let Some(event) = events.next().await { + match event { + Ok(AgentResponseEvent::Retry(retry_status)) => { + retry_events.push(retry_status); + } + Ok(AgentResponseEvent::Stop(..)) => break, + Err(error) => errors.push(error), + _ => {} + } + } + + assert_eq!( + retry_events.len(), + crate::thread::MAX_RETRY_ATTEMPTS as usize + ); + for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize { + assert_eq!(retry_events[i].attempt, i + 1); + } + assert_eq!(errors.len(), 1); + let error = errors[0] + .downcast_ref::() + .unwrap(); + assert!(matches!( + error, + LanguageModelCompletionError::ServerOverloaded { .. } + )); +} + /// Filters out the stop events for asserting against in tests fn stop_events(result_events: Vec>) -> Vec { result_events diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 7f0465f5ce..beb780850c 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -12,12 +12,12 @@ use futures::{ channel::{mpsc, oneshot}, stream::FuturesUnordered, }; -use gpui::{App, Context, Entity, SharedString, Task}; +use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, - LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, - LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, + LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; use project::Project; use prompt_store::ProjectContext; @@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; -use std::{collections::BTreeMap, path::Path, sync::Arc}; +use std::{ + collections::BTreeMap, + path::Path, + sync::Arc, + time::{Duration, Instant}, +}; use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; use uuid::Uuid; @@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId { } } +pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4; +pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); + +#[derive(Debug, Clone)] +enum RetryStrategy { + ExponentialBackoff { + initial_delay: Duration, + max_attempts: u8, + }, + Fixed { + delay: Duration, + max_attempts: u8, + }, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum Message { User(UserMessage), @@ -455,6 +475,7 @@ pub enum AgentResponseEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + Retry(acp_thread::RetryStatus), Stop(acp::StopReason), } @@ -662,41 +683,18 @@ impl Thread { })??; log::info!("Calling model.stream_completion"); - let mut events = model.stream_completion(request, cx).await?; - log::debug!("Stream completion started successfully"); let mut tool_use_limit_reached = false; - let mut tool_uses = FuturesUnordered::new(); - while let Some(event) = events.next().await { - match event? { - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::ToolUseLimitReached, - ) => { - tool_use_limit_reached = true; - } - LanguageModelCompletionEvent::Stop(reason) => { - event_stream.send_stop(reason); - if reason == StopReason::Refusal { - this.update(cx, |this, _cx| { - this.flush_pending_message(); - this.messages.truncate(message_ix); - })?; - return Ok(()); - } - } - event => { - log::trace!("Received completion event: {:?}", event); - this.update(cx, |this, cx| { - tool_uses.extend(this.handle_streamed_completion_event( - event, - &event_stream, - cx, - )); - }) - .ok(); - } - } - } + let mut tool_uses = Self::stream_completion_with_retries( + this.clone(), + model.clone(), + request, + message_ix, + &event_stream, + &mut tool_use_limit_reached, + cx, + ) + .await?; let used_tools = tool_uses.is_empty(); while let Some(tool_result) = tool_uses.next().await { @@ -754,10 +752,105 @@ impl Thread { Ok(events_rx) } + async fn stream_completion_with_retries( + this: WeakEntity, + model: Arc, + request: LanguageModelRequest, + message_ix: usize, + event_stream: &AgentResponseEventStream, + tool_use_limit_reached: &mut bool, + cx: &mut AsyncApp, + ) -> Result>> { + log::debug!("Stream completion started successfully"); + + let mut attempt = None; + 'retry: loop { + let mut events = model.stream_completion(request.clone(), cx).await?; + let mut tool_uses = FuturesUnordered::new(); + while let Some(event) = events.next().await { + match event { + Ok(LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::ToolUseLimitReached, + )) => { + *tool_use_limit_reached = true; + } + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + event_stream.send_stop(reason); + if reason == StopReason::Refusal { + this.update(cx, |this, _cx| { + this.flush_pending_message(); + this.messages.truncate(message_ix); + })?; + return Ok(tool_uses); + } + } + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + this.update(cx, |this, cx| { + tool_uses.extend(this.handle_streamed_completion_event( + event, + event_stream, + cx, + )); + }) + .ok(); + } + Err(error) => { + let completion_mode = + this.read_with(cx, |thread, _cx| thread.completion_mode())?; + if completion_mode == CompletionMode::Normal { + return Err(error.into()); + } + + let Some(strategy) = Self::retry_strategy_for(&error) else { + return Err(error.into()); + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + + let attempt = attempt.get_or_insert(0u8); + + *attempt += 1; + + let attempt = *attempt; + if attempt > max_attempts { + return Err(error.into()); + } + + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = + initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, + }; + log::debug!("Retry attempt {attempt} with delay {delay:?}"); + + event_stream.send_retry(acp_thread::RetryStatus { + last_error: error.to_string().into(), + attempt: attempt as usize, + max_attempts: max_attempts as usize, + started_at: Instant::now(), + duration: delay, + }); + + cx.background_executor().timer(delay).await; + continue 'retry; + } + } + } + return Ok(tool_uses); + } + } + pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { - project: &self.project_context.read(cx), + project: self.project_context.read(cx), available_tools: self.tools.keys().cloned().collect(), } .render(&self.templates) @@ -1158,6 +1251,113 @@ impl Thread { fn advance_prompt_id(&mut self) { self.prompt_id = PromptId::new(); } + + fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option { + use LanguageModelCompletionError::*; + use http_client::StatusCode; + + // General strategy here: + // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all. + // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff. + // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times. + match error { + HttpResponseError { + status_code: StatusCode::TOO_MANY_REQUESTS, + .. + } => Some(RetryStrategy::ExponentialBackoff { + initial_delay: BASE_RETRY_DELAY, + max_attempts: MAX_RETRY_ATTEMPTS, + }), + ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + UpstreamProviderError { + status, + retry_after, + .. + } => match *status { + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } + StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + // Internal Server Error could be anything, retry up to 3 times. + max_attempts: 3, + }), + status => { + // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"), + // but we frequently get them in practice. See https://http.dev/529 + if status.as_u16() == 529 { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: MAX_RETRY_ATTEMPTS, + }) + } else { + Some(RetryStrategy::Fixed { + delay: retry_after.unwrap_or(BASE_RETRY_DELAY), + max_attempts: 2, + }) + } + } + }, + ApiInternalServerError { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + ApiReadResponseError { .. } + | HttpSend { .. } + | DeserializeResponse { .. } + | BadRequestFormat { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }), + // Retrying these errors definitely shouldn't help. + HttpResponseError { + status_code: + StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED, + .. + } + | AuthenticationError { .. } + | PermissionError { .. } + | NoApiKey { .. } + | ApiEndpointNotFound { .. } + | PromptTooLarge { .. } => None, + // These errors might be transient, so retry them + SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 1, + }), + // Retry all other 4xx and 5xx errors once. + HttpResponseError { status_code, .. } + if status_code.is_client_error() || status_code.is_server_error() => + { + Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 3, + }) + } + Other(err) + if err.is::() + || err.is::() => + { + // Retrying won't help for Payment Required or Model Request Limit errors (where + // the user must upgrade to usage-based billing to get more requests, or else wait + // for a significant amount of time for the request limit to reset). + None + } + // Conservatively assume that any other errors are non-retryable + HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { + delay: BASE_RETRY_DELAY, + max_attempts: 2, + }), + } + } } struct RunningTurn { @@ -1367,6 +1567,12 @@ impl AgentResponseEventStream { .ok(); } + fn send_retry(&self, status: acp_thread::RetryStatus) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Retry(status))) + .ok(); + } + fn send_stop(&self, reason: StopReason) { match reason { StopReason::EndTurn => { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2fffe1b179..370dae53e4 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,7 +1,7 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, - UserMessageId, + AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent, + ToolCallStatus, UserMessageId, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -35,6 +35,7 @@ use prompt_store::PromptId; use rope::Point; use settings::{Settings as _, SettingsStore}; use std::sync::Arc; +use std::time::Instant; use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration}; use text::Anchor; use theme::ThemeSettings; @@ -115,6 +116,7 @@ pub struct AcpThreadView { profile_selector: Option>, notifications: Vec>, notification_subscriptions: HashMap, Vec>, + thread_retry_status: Option, thread_error: Option, list_state: ListState, scrollbar_state: ScrollbarState, @@ -209,6 +211,7 @@ impl AcpThreadView { notification_subscriptions: HashMap::default(), list_state: list_state.clone(), scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), + thread_retry_status: None, thread_error: None, auth_task: None, expanded_tool_calls: HashSet::default(), @@ -445,6 +448,7 @@ impl AcpThreadView { pub fn cancel_generation(&mut self, cx: &mut Context) { self.thread_error.take(); + self.thread_retry_status.take(); if let Some(thread) = self.thread() { self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); @@ -775,7 +779,11 @@ impl AcpThreadView { AcpThreadEvent::ToolAuthorizationRequired => { self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); } + AcpThreadEvent::Retry(retry) => { + self.thread_retry_status = Some(retry.clone()); + } AcpThreadEvent::Stopped => { + self.thread_retry_status.take(); let used_tools = thread.read(cx).used_tools_since_last_user_message(); self.notify_with_sound( if used_tools { @@ -789,6 +797,7 @@ impl AcpThreadView { ); } AcpThreadEvent::Error => { + self.thread_retry_status.take(); self.notify_with_sound( "Agent stopped due to an error", IconName::Warning, @@ -797,6 +806,7 @@ impl AcpThreadView { ); } AcpThreadEvent::ServerExited(status) => { + self.thread_retry_status.take(); self.thread_state = ThreadState::ServerExited { status: *status }; } } @@ -3413,7 +3423,51 @@ impl AcpThreadView { }) } - fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option
{ + fn render_thread_retry_status_callout( + &self, + _window: &mut Window, + _cx: &mut Context, + ) -> Option { + let state = self.thread_retry_status.as_ref()?; + + let next_attempt_in = state + .duration + .saturating_sub(Instant::now().saturating_duration_since(state.started_at)); + if next_attempt_in.is_zero() { + return None; + } + + let next_attempt_in_secs = next_attempt_in.as_secs() + 1; + + let retry_message = if state.max_attempts == 1 { + if next_attempt_in_secs == 1 { + "Retrying. Next attempt in 1 second.".to_string() + } else { + format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.") + } + } else { + if next_attempt_in_secs == 1 { + format!( + "Retrying. Next attempt in 1 second (Attempt {} of {}).", + state.attempt, state.max_attempts, + ) + } else { + format!( + "Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).", + state.attempt, state.max_attempts, + ) + } + }; + + Some( + Callout::new() + .severity(Severity::Warning) + .title(state.last_error.clone()) + .description(retry_message), + ) + } + + fn render_thread_error(&self, window: &mut Window, cx: &mut Context) -> Option
{ let content = match self.thread_error.as_ref()? { ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), ThreadError::PaymentRequired => self.render_payment_required_error(cx), @@ -3678,6 +3732,7 @@ impl Render for AcpThreadView { } _ => this, }) + .children(self.render_thread_retry_status_callout(window, cx)) .children(self.render_thread_error(window, cx)) .child(self.render_message_editor(window, cx)) } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 3522a0c9ab..b0b06583a4 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1523,6 +1523,7 @@ impl AgentDiff { AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired + | AcpThreadEvent::Retry(_) | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {} } diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index 67fba44887..ebfd37d16c 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -4,10 +4,11 @@ use crate::{ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, }; -use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use http_client::Result; use parking_lot::Mutex; +use smol::stream::StreamExt; use std::sync::Arc; #[derive(Clone)] @@ -100,7 +101,9 @@ pub struct FakeLanguageModel { current_completion_txs: Mutex< Vec<( LanguageModelRequest, - mpsc::UnboundedSender, + mpsc::UnboundedSender< + Result, + >, )>, >, } @@ -150,7 +153,21 @@ impl FakeLanguageModel { .find(|(req, _)| req == request) .map(|(_, tx)| tx) .unwrap(); - tx.unbounded_send(event.into()).unwrap(); + tx.unbounded_send(Ok(event.into())).unwrap(); + } + + pub fn send_completion_stream_error( + &self, + request: &LanguageModelRequest, + error: impl Into, + ) { + let current_completion_txs = self.current_completion_txs.lock(); + let tx = current_completion_txs + .iter() + .find(|(req, _)| req == request) + .map(|(_, tx)| tx) + .unwrap(); + tx.unbounded_send(Err(error.into())).unwrap(); } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { @@ -170,6 +187,13 @@ impl FakeLanguageModel { self.send_completion_stream_event(self.pending_completions().last().unwrap(), event); } + pub fn send_last_completion_stream_error( + &self, + error: impl Into, + ) { + self.send_completion_stream_error(self.pending_completions().last().unwrap(), error); + } + pub fn end_last_completion_stream(&self) { self.end_completion_stream(self.pending_completions().last().unwrap()); } @@ -229,7 +253,7 @@ impl LanguageModel for FakeLanguageModel { > { let (tx, rx) = mpsc::unbounded(); self.current_completion_txs.lock().push((request, tx)); - async move { Ok(rx.map(Ok).boxed()) }.boxed() + async move { Ok(rx.boxed()) }.boxed() } fn as_fake(&self) -> &Self {