From 6073d2c93c4bc9ded0c6fc3a22bce814f4ac9716 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Thu, 26 Jun 2025 10:53:33 -0400 Subject: [PATCH] Automatically retry when API is Overloaded or 500s (#33275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Screenshot 2025-06-25 at 2 26 16 PM Screenshot 2025-06-25 at 2 26 08 PM Now we: * Automatically retry up to 3 times on upstream Overloaded or 500 errors (currently for Anthropic only; will add others in future PRs) * Also automatically retry on rate limit errors (using the provided duration to wait, if we were given one) * Give you a notification if you don't have Zed open and we stopped the thread because of an error Still todo in future PRs: * Update collab to report Overloaded and 500 errors differently if collab itself is passing through an upstream error vs not (currently we report these as "Zed's API is overloaded" when actually it's the upstream one!) * Updating providers other than Anthropic to categorize their errors so that they benefit from this * Expanding graceful error handling/retry to other things besides Overloaded and 500 errors (e.g. connection reset) Release Notes: - Automatically retry in Agent Panel instead of erroring out when an upstream AI API is overloaded or 500s - Show a notification when an Agent thread errors out and Zed is not the active window --- Cargo.lock | 1 + crates/agent/Cargo.toml | 1 + crates/agent/src/thread.rs | 1509 +++++++++++++++++++++++++- crates/agent_ui/src/active_thread.rs | 326 +++--- crates/agent_ui/src/agent_diff.rs | 1 + crates/eval/src/example.rs | 3 + 6 files changed, 1655 insertions(+), 186 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f50842dddc..26fce3c46b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,6 +78,7 @@ dependencies = [ "language", "language_model", "log", + "parking_lot", "paths", "postage", "pretty_assertions", diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index f320e58d00..135363ab65 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -72,6 +72,7 @@ gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } +parking_lot.workspace = true pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 33b9209f0c..68624d7c3b 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -39,12 +39,20 @@ use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; -use std::{io::Write, ops::Range, sync::Arc, time::Instant}; +use std::{ + io::Write, + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; use thiserror::Error; use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; +const MAX_RETRY_ATTEMPTS: u8 = 3; +const BASE_RETRY_DELAY_SECS: u64 = 5; + #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, )] @@ -118,6 +126,7 @@ pub struct Message { pub loaded_context: LoadedContext, pub creases: Vec, pub is_hidden: bool, + pub ui_only: bool, } impl Message { @@ -364,6 +373,7 @@ pub struct Thread { exceeded_window_error: Option, tool_use_limit_reached: bool, feedback: Option, + retry_state: Option, message_feedback: HashMap, last_auto_capture_at: Option, last_received_chunk_at: Option, @@ -375,6 +385,13 @@ pub struct Thread { profile: AgentProfile, } +#[derive(Clone, Debug)] +struct RetryState { + attempt: u8, + max_attempts: u8, + intent: CompletionIntent, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum ThreadSummary { Pending, @@ -456,6 +473,7 @@ impl Thread { exceeded_window_error: None, tool_use_limit_reached: false, feedback: None, + retry_state: None, message_feedback: HashMap::default(), last_auto_capture_at: None, last_received_chunk_at: None, @@ -522,6 +540,7 @@ impl Thread { detailed_summary_tx, detailed_summary_rx, completion_mode, + retry_state: None, messages: serialized .messages .into_iter() @@ -557,6 +576,7 @@ impl Thread { }) .collect(), is_hidden: message.is_hidden, + ui_only: false, // UI-only messages are not persisted }) .collect(), next_message_id, @@ -1041,6 +1061,7 @@ impl Thread { loaded_context, creases, is_hidden, + ui_only: false, }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); @@ -1130,6 +1151,7 @@ impl Thread { updated_at: this.updated_at(), messages: this .messages() + .filter(|message| !message.ui_only) .map(|message| SerializedMessage { id: message.id, role: message.role, @@ -1228,7 +1250,7 @@ impl Thread { let request = self.to_completion_request(model.clone(), intent, cx); - self.stream_completion(request, model, window, cx); + self.stream_completion(request, model, intent, window, cx); } pub fn used_tools_since_last_user_message(&self) -> bool { @@ -1303,6 +1325,11 @@ impl Thread { let mut message_ix_to_cache = None; for message in &self.messages { + // ui_only messages are for the UI only, not for the model + if message.ui_only { + continue; + } + let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), @@ -1457,6 +1484,7 @@ impl Thread { &mut self, request: LanguageModelRequest, model: Arc, + intent: CompletionIntent, window: Option, cx: &mut Context, ) { @@ -1770,58 +1798,64 @@ impl Thread { }; let result = stream_completion.await; + let mut retry_scheduled = false; thread .update(cx, |thread, cx| { thread.finalize_pending_checkpoint(cx); match result.as_ref() { - Ok(stop_reason) => match stop_reason { - StopReason::ToolUse => { - let tool_uses = thread.use_pending_tools(window, model.clone(), cx); - cx.emit(ThreadEvent::UsePendingTools { tool_uses }); - } - StopReason::EndTurn | StopReason::MaxTokens => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - } - StopReason::Refusal => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); + Ok(stop_reason) => { + match stop_reason { + StopReason::ToolUse => { + let tool_uses = thread.use_pending_tools(window, model.clone(), cx); + cx.emit(ThreadEvent::UsePendingTools { tool_uses }); + } + StopReason::EndTurn | StopReason::MaxTokens => { + thread.project.update(cx, |project, cx| { + project.set_agent_location(None, cx); + }); + } + StopReason::Refusal => { + thread.project.update(cx, |project, cx| { + project.set_agent_location(None, cx); + }); - // Remove the turn that was refused. - // - // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal - { - let mut messages_to_remove = Vec::new(); + // Remove the turn that was refused. + // + // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal + { + let mut messages_to_remove = Vec::new(); - for (ix, message) in thread.messages.iter().enumerate().rev() { - messages_to_remove.push(message.id); + for (ix, message) in thread.messages.iter().enumerate().rev() { + messages_to_remove.push(message.id); - if message.role == Role::User { - if ix == 0 { - break; - } - - if let Some(prev_message) = thread.messages.get(ix - 1) { - if prev_message.role == Role::Assistant { + if message.role == Role::User { + if ix == 0 { break; } + + if let Some(prev_message) = thread.messages.get(ix - 1) { + if prev_message.role == Role::Assistant { + break; + } + } } } + + for message_id in messages_to_remove { + thread.delete_message(message_id, cx); + } } - for message_id in messages_to_remove { - thread.delete_message(message_id, cx); - } + cx.emit(ThreadEvent::ShowError(ThreadError::Message { + header: "Language model refusal".into(), + message: "Model refused to generate content for safety reasons.".into(), + })); } - - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Language model refusal".into(), - message: "Model refused to generate content for safety reasons.".into(), - })); } + + // We successfully completed, so cancel any remaining retries. + thread.retry_state = None; }, Err(error) => { thread.project.update(cx, |project, cx| { @@ -1859,17 +1893,58 @@ impl Thread { }); cx.notify(); } - LanguageModelKnownError::RateLimitExceeded { .. } => { - // In the future we will report the error to the user, wait retry_after, and then retry. - emit_generic_error(error, cx); + LanguageModelKnownError::RateLimitExceeded { retry_after } => { + let provider_name = model.provider_name(); + let error_message = format!( + "{}'s API rate limit exceeded", + provider_name.0.as_ref() + ); + + thread.handle_rate_limit_error( + &error_message, + *retry_after, + model.clone(), + intent, + window, + cx, + ); + retry_scheduled = true; } LanguageModelKnownError::Overloaded => { - // In the future we will wait and then retry, up to N times. - emit_generic_error(error, cx); + let provider_name = model.provider_name(); + let error_message = format!( + "{}'s API servers are overloaded right now", + provider_name.0.as_ref() + ); + + retry_scheduled = thread.handle_retryable_error( + &error_message, + model.clone(), + intent, + window, + cx, + ); + if !retry_scheduled { + emit_generic_error(error, cx); + } } LanguageModelKnownError::ApiInternalServerError => { - // In the future we will retry the request, but only once. - emit_generic_error(error, cx); + let provider_name = model.provider_name(); + let error_message = format!( + "{}'s API server reported an internal server error", + provider_name.0.as_ref() + ); + + retry_scheduled = thread.handle_retryable_error( + &error_message, + model.clone(), + intent, + window, + cx, + ); + if !retry_scheduled { + emit_generic_error(error, cx); + } } LanguageModelKnownError::ReadResponseError(_) | LanguageModelKnownError::DeserializeResponse(_) | @@ -1882,11 +1957,15 @@ impl Thread { emit_generic_error(error, cx); } - thread.cancel_last_completion(window, cx); + if !retry_scheduled { + thread.cancel_last_completion(window, cx); + } } } - cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); + if !retry_scheduled { + cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); + } if let Some((request_callback, (request, response_events))) = thread .request_callback @@ -2002,6 +2081,146 @@ impl Thread { }); } + fn handle_rate_limit_error( + &mut self, + error_message: &str, + retry_after: Duration, + model: Arc, + intent: CompletionIntent, + window: Option, + cx: &mut Context, + ) { + // For rate limit errors, we only retry once with the specified duration + let retry_message = format!( + "{error_message}. Retrying in {} seconds…", + retry_after.as_secs() + ); + + // Add a UI-only message instead of a regular message + let id = self.next_message_id.post_inc(); + self.messages.push(Message { + id, + role: Role::System, + segments: vec![MessageSegment::Text(retry_message)], + loaded_context: LoadedContext::default(), + creases: Vec::new(), + is_hidden: false, + ui_only: true, + }); + cx.emit(ThreadEvent::MessageAdded(id)); + // Schedule the retry + let thread_handle = cx.entity().downgrade(); + + cx.spawn(async move |_thread, cx| { + cx.background_executor().timer(retry_after).await; + + thread_handle + .update(cx, |thread, cx| { + // Retry the completion + thread.send_to_model(model, intent, window, cx); + }) + .log_err(); + }) + .detach(); + } + + fn handle_retryable_error( + &mut self, + error_message: &str, + model: Arc, + intent: CompletionIntent, + window: Option, + cx: &mut Context, + ) -> bool { + self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx) + } + + fn handle_retryable_error_with_delay( + &mut self, + error_message: &str, + custom_delay: Option, + model: Arc, + intent: CompletionIntent, + window: Option, + cx: &mut Context, + ) -> bool { + let retry_state = self.retry_state.get_or_insert(RetryState { + attempt: 0, + max_attempts: MAX_RETRY_ATTEMPTS, + intent, + }); + + retry_state.attempt += 1; + let attempt = retry_state.attempt; + let max_attempts = retry_state.max_attempts; + let intent = retry_state.intent; + + if attempt <= max_attempts { + // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff + let delay = if let Some(custom_delay) = custom_delay { + custom_delay + } else { + let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + }; + + // Add a transient message to inform the user + let delay_secs = delay.as_secs(); + let retry_message = format!( + "{}. Retrying (attempt {} of {}) in {} seconds...", + error_message, attempt, max_attempts, delay_secs + ); + + // Add a UI-only message instead of a regular message + let id = self.next_message_id.post_inc(); + self.messages.push(Message { + id, + role: Role::System, + segments: vec![MessageSegment::Text(retry_message)], + loaded_context: LoadedContext::default(), + creases: Vec::new(), + is_hidden: false, + ui_only: true, + }); + cx.emit(ThreadEvent::MessageAdded(id)); + + // Schedule the retry + let thread_handle = cx.entity().downgrade(); + + cx.spawn(async move |_thread, cx| { + cx.background_executor().timer(delay).await; + + thread_handle + .update(cx, |thread, cx| { + // Retry the completion + thread.send_to_model(model, intent, window, cx); + }) + .log_err(); + }) + .detach(); + + true + } else { + // Max retries exceeded + self.retry_state = None; + + let notification_text = if max_attempts == 1 { + "Failed after retrying.".into() + } else { + format!("Failed after retrying {} times.", max_attempts).into() + }; + + // Stop generating since we're giving up on retrying. + self.pending_completions.clear(); + + cx.emit(ThreadEvent::RetriesFailed { + message: notification_text, + }); + + false + } + } + pub fn start_generating_detailed_summary_if_needed( &mut self, thread_store: WeakEntity, @@ -2354,7 +2573,9 @@ impl Thread { window: Option, cx: &mut Context, ) -> bool { - let mut canceled = self.pending_completions.pop().is_some(); + let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some(); + + self.retry_state = None; for pending_tool_use in self.tool_use.cancel_pending() { canceled = true; @@ -2943,6 +3164,9 @@ pub enum ThreadEvent { CancelEditing, CompletionCanceled, ProfileChanged, + RetriesFailed { + message: SharedString, + }, } impl EventEmitter for Thread {} @@ -3038,16 +3262,28 @@ mod tests { use crate::{ context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore, }; + + // Test-specific constants + const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::ToolRegistry; + use futures::StreamExt; + use futures::future::BoxFuture; + use futures::stream::BoxStream; use gpui::TestAppContext; use icons::IconName; use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; + use language_model::{ + LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelToolChoice, + }; + use parking_lot::Mutex; use project::{FakeFs, Project}; use prompt_store::PromptBuilder; use serde_json::json; use settings::{Settings, SettingsStore}; use std::sync::Arc; + use std::time::Duration; use theme::ThemeSettings; use util::path; use workspace::Workspace; @@ -3822,6 +4058,1183 @@ fn main() {{ } } + // Helper to create a model that returns errors + enum TestError { + Overloaded, + InternalServerError, + } + + struct ErrorInjector { + inner: Arc, + error_type: TestError, + } + + impl ErrorInjector { + fn new(error_type: TestError) -> Self { + Self { + inner: Arc::new(FakeLanguageModel::default()), + error_type, + } + } + } + + impl LanguageModel for ErrorInjector { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + _request: LanguageModelRequest, + _cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let error = match self.error_type { + TestError::Overloaded => LanguageModelCompletionError::Overloaded, + TestError::InternalServerError => { + LanguageModelCompletionError::ApiInternalServerError + } + }; + async move { + let stream = futures::stream::once(async move { Err(error) }); + Ok(stream.boxed()) + } + .boxed() + } + + fn as_fake(&self) -> &FakeLanguageModel { + &self.inner + } + } + + #[gpui::test] + async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have default max attempts" + ); + }); + + // Check that a retry message was added + thread.read_with(cx, |thread, _| { + let mut messages = thread.messages(); + assert!( + messages.any(|msg| { + msg.role == Role::System + && msg.ui_only + && msg.segments.iter().any(|seg| { + if let MessageSegment::Text(text) = seg { + text.contains("overloaded") + && text + .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + } else { + false + } + }) + }), + "Should have added a system retry message" + ); + }); + + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + + assert_eq!(retry_count, 1, "Should have one retry message"); + } + + #[gpui::test] + async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns internal server error + let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + // Check retry state on thread + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have correct max attempts" + ); + }); + + // Check that a retry message was added with provider name + thread.read_with(cx, |thread, _| { + let mut messages = thread.messages(); + assert!( + messages.any(|msg| { + msg.role == Role::System + && msg.ui_only + && msg.segments.iter().any(|seg| { + if let MessageSegment::Text(text) = seg { + text.contains("internal") + && text.contains("Fake") + && text + .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + } else { + false + } + }) + }), + "Should have added a system retry message with provider name" + ); + }); + + // Count retry messages + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + + assert_eq!(retry_count, 1, "Should have one retry message"); + } + + #[gpui::test] + async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Track retry events and completion count + // Track completion events + let completion_count = Arc::new(Mutex::new(0)); + let completion_count_clone = completion_count.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::NewRequest = event { + *completion_count_clone.lock() += 1; + } + }) + }); + + // First attempt + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + cx.run_until_parked(); + + // Should have scheduled first retry - count retry messages + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!(retry_count, 1, "Should have scheduled first retry"); + + // Check retry state + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + }); + + // Advance clock for first retry + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.run_until_parked(); + + // Should have scheduled second retry - count retry messages + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!(retry_count, 2, "Should have scheduled second retry"); + + // Check retry state updated + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have correct max attempts" + ); + }); + + // Advance clock for second retry (exponential backoff) + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); + cx.run_until_parked(); + + // Should have scheduled third retry + // Count all retry messages now + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!( + retry_count, MAX_RETRY_ATTEMPTS as usize, + "Should have scheduled third retry" + ); + + // Check retry state updated + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!( + retry_state.attempt, MAX_RETRY_ATTEMPTS, + "Should be at max retry attempt" + ); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have correct max attempts" + ); + }); + + // Advance clock for third retry (exponential backoff) + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); + cx.run_until_parked(); + + // No more retries should be scheduled after clock was advanced. + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!( + retry_count, MAX_RETRY_ATTEMPTS as usize, + "Should not exceed max retries" + ); + + // Final completion count should be initial + max retries + assert_eq!( + *completion_count.lock(), + (MAX_RETRY_ATTEMPTS + 1) as usize, + "Should have made initial + max retry attempts" + ); + } + + #[gpui::test] + async fn test_max_retries_exceeded(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Track events + let retries_failed = Arc::new(Mutex::new(false)); + let retries_failed_clone = retries_failed.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::RetriesFailed { .. } = event { + *retries_failed_clone.lock() = true; + } + }) + }); + + // Start initial completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + cx.run_until_parked(); + + // Advance through all retries + for i in 0..MAX_RETRY_ATTEMPTS { + let delay = if i == 0 { + BASE_RETRY_DELAY_SECS + } else { + BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) + }; + cx.executor().advance_clock(Duration::from_secs(delay)); + cx.run_until_parked(); + } + + // After the 3rd retry is scheduled, we need to wait for it to execute and fail + // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) + let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); + cx.executor() + .advance_clock(Duration::from_secs(final_delay)); + cx.run_until_parked(); + + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + + // After max retries, should emit RetriesFailed event + assert_eq!( + retry_count, MAX_RETRY_ATTEMPTS as usize, + "Should have attempted max retries" + ); + assert!( + *retries_failed.lock(), + "Should emit RetriesFailed event after max retries exceeded" + ); + + // Retry state should be cleared + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after max retries" + ); + + // Verify we have the expected number of retry messages + let retry_messages = thread + .messages + .iter() + .filter(|msg| msg.ui_only && msg.role == Role::System) + .count(); + assert_eq!( + retry_messages, MAX_RETRY_ATTEMPTS as usize, + "Should have one retry message per attempt" + ); + }); + } + + #[gpui::test] + async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // We'll use a wrapper to switch behavior after first failure + struct RetryTestModel { + inner: Arc, + failed_once: Arc>, + } + + impl LanguageModel for RetryTestModel { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + if !*self.failed_once.lock() { + *self.failed_once.lock() = true; + // Return error on first attempt + let stream = futures::stream::once(async move { + Err(LanguageModelCompletionError::Overloaded) + }); + async move { Ok(stream.boxed()) }.boxed() + } else { + // Succeed on retry + self.inner.stream_completion(request, cx) + } + } + + fn as_fake(&self) -> &FakeLanguageModel { + &self.inner + } + } + + let model = Arc::new(RetryTestModel { + inner: Arc::new(FakeLanguageModel::default()), + failed_once: Arc::new(Mutex::new(false)), + }); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Track message deletions + // Track when retry completes successfully + let retry_completed = Arc::new(Mutex::new(false)); + let retry_completed_clone = retry_completed.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::StreamedCompletion = event { + *retry_completed_clone.lock() = true; + } + }) + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + cx.run_until_parked(); + + // Get the retry message ID + let retry_message_id = thread.read_with(cx, |thread, _| { + thread + .messages() + .find(|msg| msg.role == Role::System && msg.ui_only) + .map(|msg| msg.id) + .expect("Should have a retry message") + }); + + // Wait for retry + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.run_until_parked(); + + // Stream some successful content + let fake_model = model.as_fake(); + // After the retry, there should be a new pending completion + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "Should have a pending completion after retry" + ); + fake_model.stream_completion_response(&pending[0], "Success!"); + fake_model.end_completion_stream(&pending[0]); + cx.run_until_parked(); + + // Check that the retry completed successfully + assert!( + *retry_completed.lock(), + "Retry should have completed successfully" + ); + + // Retry message should still exist but be marked as ui_only + thread.read_with(cx, |thread, _| { + let retry_msg = thread + .message(retry_message_id) + .expect("Retry message should still exist"); + assert!(retry_msg.ui_only, "Retry message should be ui_only"); + assert_eq!( + retry_msg.role, + Role::System, + "Retry message should have System role" + ); + }); + } + + #[gpui::test] + async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create a model that fails once then succeeds + struct FailOnceModel { + inner: Arc, + failed_once: Arc>, + } + + impl LanguageModel for FailOnceModel { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + if !*self.failed_once.lock() { + *self.failed_once.lock() = true; + // Return error on first attempt + let stream = futures::stream::once(async move { + Err(LanguageModelCompletionError::Overloaded) + }); + async move { Ok(stream.boxed()) }.boxed() + } else { + // Succeed on retry + self.inner.stream_completion(request, cx) + } + } + } + + let fail_once_model = Arc::new(FailOnceModel { + inner: Arc::new(FakeLanguageModel::default()), + failed_once: Arc::new(Mutex::new(false)), + }); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "Test message", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + }); + + // Start completion with fail-once model + thread.update(cx, |thread, cx| { + thread.send_to_model( + fail_once_model.clone(), + CompletionIntent::UserPrompt, + None, + cx, + ); + }); + + cx.run_until_parked(); + + // Verify retry state exists after first failure + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_some(), + "Should have retry state after failure" + ); + }); + + // Wait for retry delay + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.run_until_parked(); + + // The retry should now use our FailOnceModel which should succeed + // We need to help the FakeLanguageModel complete the stream + let inner_fake = fail_once_model.inner.clone(); + + // Wait a bit for the retry to start + cx.run_until_parked(); + + // Check for pending completions and complete them + if let Some(pending) = inner_fake.pending_completions().first() { + inner_fake.stream_completion_response(pending, "Success!"); + inner_fake.end_completion_stream(pending); + } + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after successful completion" + ); + + let has_assistant_message = thread + .messages + .iter() + .any(|msg| msg.role == Role::Assistant && !msg.ui_only); + assert!( + has_assistant_message, + "Should have an assistant message after successful retry" + ); + }); + } + + #[gpui::test] + async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create a model that returns rate limit error with retry_after + struct RateLimitModel { + inner: Arc, + } + + impl LanguageModel for RateLimitModel { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + _request: LanguageModelRequest, + _cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + async move { + let stream = futures::stream::once(async move { + Err(LanguageModelCompletionError::RateLimitExceeded { + retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS), + }) + }); + Ok(stream.boxed()) + } + .boxed() + } + + fn as_fake(&self) -> &FakeLanguageModel { + &self.inner + } + } + + let model = Arc::new(RateLimitModel { + inner: Arc::new(FakeLanguageModel::default()), + }); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("rate limit exceeded") + } else { + false + } + }) + }) + .count() + }); + assert_eq!(retry_count, 1, "Should have scheduled one retry"); + + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Rate limit errors should not set retry_state" + ); + }); + + // Verify we have one retry message + thread.read_with(cx, |thread, _| { + let retry_messages = thread + .messages + .iter() + .filter(|msg| { + msg.ui_only + && msg.segments.iter().any(|seg| { + if let MessageSegment::Text(text) = seg { + text.contains("rate limit exceeded") + } else { + false + } + }) + }) + .count(); + assert_eq!( + retry_messages, 1, + "Should have one rate limit retry message" + ); + }); + + // Check that retry message doesn't include attempt count + thread.read_with(cx, |thread, _| { + let retry_message = thread + .messages + .iter() + .find(|msg| msg.role == Role::System && msg.ui_only) + .expect("Should have a retry message"); + + // Check that the message doesn't contain attempt count + if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { + assert!( + !text.contains("attempt"), + "Rate limit retry message should not contain attempt count" + ); + assert!( + text.contains(&format!( + "Retrying in {} seconds", + TEST_RATE_LIMIT_RETRY_SECS + )), + "Rate limit retry message should contain retry delay" + ); + } + }); + } + + #[gpui::test] + async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await; + + // Insert a regular user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Insert a UI-only message (like our retry notifications) + thread.update(cx, |thread, cx| { + let id = thread.next_message_id.post_inc(); + thread.messages.push(Message { + id, + role: Role::System, + segments: vec![MessageSegment::Text( + "This is a UI-only message that should not be sent to the model".to_string(), + )], + loaded_context: LoadedContext::default(), + creases: Vec::new(), + is_hidden: true, + ui_only: true, + }); + cx.emit(ThreadEvent::MessageAdded(id)); + }); + + // Insert another regular message + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "How are you?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + }); + + // Generate the completion request + let request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) + }); + + // Verify that the request only contains non-UI-only messages + // Should have system prompt + 2 user messages, but not the UI-only message + let user_messages: Vec<_> = request + .messages + .iter() + .filter(|msg| msg.role == Role::User) + .collect(); + assert_eq!( + user_messages.len(), + 2, + "Should have exactly 2 user messages" + ); + + // Verify the UI-only content is not present anywhere in the request + let request_text = request + .messages + .iter() + .flat_map(|msg| &msg.content) + .filter_map(|content| match content { + MessageContent::Text(text) => Some(text.as_str()), + _ => None, + }) + .collect::(); + + assert!( + !request_text.contains("UI-only message"), + "UI-only message content should not be in the request" + ); + + // Verify the thread still has all 3 messages (including UI-only) + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.messages().count(), + 3, + "Thread should have 3 messages" + ); + assert_eq!( + thread.messages().filter(|m| m.ui_only).count(), + 1, + "Thread should have 1 UI-only message" + ); + }); + + // Verify that UI-only messages are not serialized + let serialized = thread + .update(cx, |thread, cx| thread.serialize(cx)) + .await + .unwrap(); + assert_eq!( + serialized.messages.len(), + 2, + "Serialized thread should only have 2 messages (no UI-only)" + ); + } + + #[gpui::test] + async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + // Verify retry was scheduled by checking for retry message + let has_retry_message = thread.read_with(cx, |thread, _| { + thread.messages.iter().any(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + }); + assert!(has_retry_message, "Should have scheduled a retry"); + + // Cancel the completion before the retry happens + thread.update(cx, |thread, cx| { + thread.cancel_last_completion(None, cx); + }); + + cx.run_until_parked(); + + // The retry should not have happened - no pending completions + let fake_model = model.as_fake(); + assert_eq!( + fake_model.pending_completions().len(), + 0, + "Should have no pending completions after cancellation" + ); + + // Verify the retry was cancelled by checking retry state + thread.read_with(cx, |thread, _| { + if let Some(retry_state) = &thread.retry_state { + panic!( + "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}", + retry_state.attempt, retry_state.max_attempts, retry_state.intent + ); + } + }); + } + fn test_summarize_error( model: &Arc, thread: &Entity, diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 0e7ca9aa89..4da959d36e 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -1140,6 +1140,9 @@ impl ActiveThread { self.save_thread(cx); cx.notify(); } + ThreadEvent::RetriesFailed { message } => { + self.show_notification(message, ui::IconName::Warning, window, cx); + } } } @@ -1835,9 +1838,10 @@ impl ActiveThread { .filter(|(id, _)| *id == message_id) .map(|(_, state)| state); - let colors = cx.theme().colors(); - let editor_bg_color = colors.editor_background; - let panel_bg = colors.panel_background; + let (editor_bg_color, panel_bg) = { + let colors = cx.theme().colors(); + (colors.editor_background, colors.panel_background) + }; let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText) .icon_size(IconSize::XSmall) @@ -2025,152 +2029,162 @@ impl ActiveThread { } }); - let styled_message = match message.role { - Role::User => v_flex() - .id(("message-container", ix)) - .pt_2() - .pl_2() - .pr_2p5() - .pb_4() - .child( + let styled_message = if message.ui_only { + self.render_ui_notification(message_content, ix, cx) + } else { + match message.role { + Role::User => { + let colors = cx.theme().colors(); v_flex() - .id(("user-message", ix)) - .bg(editor_bg_color) - .rounded_lg() - .shadow_md() - .border_1() - .border_color(colors.border) - .hover(|hover| hover.border_color(colors.text_accent.opacity(0.5))) + .id(("message-container", ix)) + .pt_2() + .pl_2() + .pr_2p5() + .pb_4() .child( v_flex() - .p_2p5() - .gap_1() - .children(message_content) - .when_some(editing_message_state, |this, state| { - let focus_handle = state.editor.focus_handle(cx).clone(); + .id(("user-message", ix)) + .bg(editor_bg_color) + .rounded_lg() + .shadow_md() + .border_1() + .border_color(colors.border) + .hover(|hover| hover.border_color(colors.text_accent.opacity(0.5))) + .child( + v_flex() + .p_2p5() + .gap_1() + .children(message_content) + .when_some(editing_message_state, |this, state| { + let focus_handle = state.editor.focus_handle(cx).clone(); - this.child( - h_flex() - .w_full() - .gap_1() - .justify_between() - .flex_wrap() - .child( + this.child( h_flex() - .gap_1p5() + .w_full() + .gap_1() + .justify_between() + .flex_wrap() .child( - div() - .opacity(0.8) + h_flex() + .gap_1p5() .child( - Icon::new(IconName::Warning) - .size(IconSize::Indicator) - .color(Color::Warning) + div() + .opacity(0.8) + .child( + Icon::new(IconName::Warning) + .size(IconSize::Indicator) + .color(Color::Warning) + ), + ) + .child( + Label::new("Editing will restart the thread from this point.") + .color(Color::Muted) + .size(LabelSize::XSmall), ), ) .child( - Label::new("Editing will restart the thread from this point.") - .color(Color::Muted) - .size(LabelSize::XSmall), - ), - ) - .child( - h_flex() - .gap_0p5() - .child( - IconButton::new( - "cancel-edit-message", - IconName::Close, - ) - .shape(ui::IconButtonShape::Square) - .icon_color(Color::Error) - .icon_size(IconSize::Small) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Cancel Edit", - &menu::Cancel, - &focus_handle, - window, - cx, + h_flex() + .gap_0p5() + .child( + IconButton::new( + "cancel-edit-message", + IconName::Close, ) - } - }) - .on_click(cx.listener(Self::handle_cancel_click)), + .shape(ui::IconButtonShape::Square) + .icon_color(Color::Error) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Cancel Edit", + &menu::Cancel, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(Self::handle_cancel_click)), + ) + .child( + IconButton::new( + "confirm-edit-message", + IconName::Return, + ) + .disabled(state.editor.read(cx).is_empty(cx)) + .shape(ui::IconButtonShape::Square) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Regenerate", + &menu::Confirm, + &focus_handle, + window, + cx, + ) + } + }) + .on_click( + cx.listener(Self::handle_regenerate_click), + ), + ), ) - .child( - IconButton::new( - "confirm-edit-message", - IconName::Return, - ) - .disabled(state.editor.read(cx).is_empty(cx)) - .shape(ui::IconButtonShape::Square) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Regenerate", - &menu::Confirm, - &focus_handle, - window, - cx, - ) - } - }) - .on_click( - cx.listener(Self::handle_regenerate_click), - ), - ), ) - ) - }), + }), + ) + .on_click(cx.listener({ + let message_creases = message.creases.clone(); + move |this, _, window, cx| { + if let Some(message_text) = + this.thread.read(cx).message(message_id).and_then(|message| { + message.segments.first().and_then(|segment| { + match segment { + MessageSegment::Text(message_text) => { + Some(Into::>::into(message_text.as_str())) + } + _ => { + None + } + } + }) + }) + { + this.start_editing_message( + message_id, + message_text, + &message_creases, + window, + cx, + ); + } + } + })), ) - .on_click(cx.listener({ - let message_creases = message.creases.clone(); - move |this, _, window, cx| { - if let Some(message_text) = - this.thread.read(cx).message(message_id).and_then(|message| { - message.segments.first().and_then(|segment| { - match segment { - MessageSegment::Text(message_text) => { - Some(Into::>::into(message_text.as_str())) - } - _ => { - None - } - } - }) - }) - { - this.start_editing_message( - message_id, - message_text, - &message_creases, - window, - cx, - ); - } - } - })), - ), - Role::Assistant => v_flex() - .id(("message-container", ix)) - .px(RESPONSE_PADDING_X) - .gap_2() - .children(message_content) - .when(has_tool_uses, |parent| { - parent.children(tool_uses.into_iter().map(|tool_use| { - self.render_tool_use(tool_use, window, workspace.clone(), cx) - })) - }), - Role::System => div().id(("message-container", ix)).py_1().px_2().child( - v_flex() - .bg(colors.editor_background) - .rounded_sm() - .child(div().p_4().children(message_content)), - ), + } + Role::Assistant => v_flex() + .id(("message-container", ix)) + .px(RESPONSE_PADDING_X) + .gap_2() + .children(message_content) + .when(has_tool_uses, |parent| { + parent.children(tool_uses.into_iter().map(|tool_use| { + self.render_tool_use(tool_use, window, workspace.clone(), cx) + })) + }), + Role::System => { + let colors = cx.theme().colors(); + div().id(("message-container", ix)).py_1().px_2().child( + v_flex() + .bg(colors.editor_background) + .rounded_sm() + .child(div().p_4().children(message_content)), + ) + } + } }; let after_editing_message = self @@ -2509,6 +2523,42 @@ impl ActiveThread { .blend(cx.theme().colors().editor_foreground.opacity(0.025)) } + fn render_ui_notification( + &self, + message_content: impl IntoIterator, + ix: usize, + cx: &mut Context, + ) -> Stateful
{ + let colors = cx.theme().colors(); + div().id(("message-container", ix)).py_1().px_2().child( + v_flex() + .w_full() + .bg(colors.editor_background) + .rounded_sm() + .child( + h_flex() + .w_full() + .p_2() + .gap_2() + .child( + div().flex_none().child( + Icon::new(IconName::Warning) + .size(IconSize::Small) + .color(Color::Warning), + ), + ) + .child( + v_flex() + .flex_1() + .min_w_0() + .text_size(TextSize::Small.rems(cx)) + .text_color(cx.theme().colors().text_muted) + .children(message_content), + ), + ), + ) + } + fn render_message_thinking_segment( &self, message_id: MessageId, @@ -3763,9 +3813,9 @@ mod tests { // Stream response to user message thread.update(cx, |thread, cx| { - let request = - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx); - thread.stream_completion(request, model, cx.active_window(), cx) + let intent = CompletionIntent::UserPrompt; + let request = thread.to_completion_request(model.clone(), intent, cx); + thread.stream_completion(request, model, intent, cx.active_window(), cx) }); // Follow the agent cx.update(|window, cx| { diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index aa5e49551b..b8e67512e2 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1380,6 +1380,7 @@ impl AgentDiff { | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached | ThreadEvent::CancelEditing + | ThreadEvent::RetriesFailed { .. } | ThreadEvent::ProfileChanged => {} } } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 09770364cb..904eca83e6 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -221,6 +221,9 @@ impl ExampleContext { ThreadEvent::ShowError(thread_error) => { tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); } + ThreadEvent::RetriesFailed { .. } => { + // Ignore retries failed events + } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel();