diff --git a/Cargo.lock b/Cargo.lock index f50842dddc..26fce3c46b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,6 +78,7 @@ dependencies = [ "language", "language_model", "log", + "parking_lot", "paths", "postage", "pretty_assertions", diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index f320e58d00..135363ab65 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -72,6 +72,7 @@ gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } +parking_lot.workspace = true pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 33b9209f0c..68624d7c3b 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -39,12 +39,20 @@ use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; -use std::{io::Write, ops::Range, sync::Arc, time::Instant}; +use std::{ + io::Write, + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; use thiserror::Error; use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; +const MAX_RETRY_ATTEMPTS: u8 = 3; +const BASE_RETRY_DELAY_SECS: u64 = 5; + #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, )] @@ -118,6 +126,7 @@ pub struct Message { pub loaded_context: LoadedContext, pub creases: Vec, pub is_hidden: bool, + pub ui_only: bool, } impl Message { @@ -364,6 +373,7 @@ pub struct Thread { exceeded_window_error: Option, tool_use_limit_reached: bool, feedback: Option, + retry_state: Option, message_feedback: HashMap, last_auto_capture_at: Option, last_received_chunk_at: Option, @@ -375,6 +385,13 @@ pub 
struct Thread { profile: AgentProfile, } +#[derive(Clone, Debug)] +struct RetryState { + attempt: u8, + max_attempts: u8, + intent: CompletionIntent, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum ThreadSummary { Pending, @@ -456,6 +473,7 @@ impl Thread { exceeded_window_error: None, tool_use_limit_reached: false, feedback: None, + retry_state: None, message_feedback: HashMap::default(), last_auto_capture_at: None, last_received_chunk_at: None, @@ -522,6 +540,7 @@ impl Thread { detailed_summary_tx, detailed_summary_rx, completion_mode, + retry_state: None, messages: serialized .messages .into_iter() @@ -557,6 +576,7 @@ impl Thread { }) .collect(), is_hidden: message.is_hidden, + ui_only: false, // UI-only messages are not persisted }) .collect(), next_message_id, @@ -1041,6 +1061,7 @@ impl Thread { loaded_context, creases, is_hidden, + ui_only: false, }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); @@ -1130,6 +1151,7 @@ impl Thread { updated_at: this.updated_at(), messages: this .messages() + .filter(|message| !message.ui_only) .map(|message| SerializedMessage { id: message.id, role: message.role, @@ -1228,7 +1250,7 @@ impl Thread { let request = self.to_completion_request(model.clone(), intent, cx); - self.stream_completion(request, model, window, cx); + self.stream_completion(request, model, intent, window, cx); } pub fn used_tools_since_last_user_message(&self) -> bool { @@ -1303,6 +1325,11 @@ impl Thread { let mut message_ix_to_cache = None; for message in &self.messages { + // ui_only messages are for the UI only, not for the model + if message.ui_only { + continue; + } + let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), @@ -1457,6 +1484,7 @@ impl Thread { &mut self, request: LanguageModelRequest, model: Arc, + intent: CompletionIntent, window: Option, cx: &mut Context, ) { @@ -1770,58 +1798,64 @@ impl Thread { }; let result = stream_completion.await; + let mut retry_scheduled = false; 
thread .update(cx, |thread, cx| { thread.finalize_pending_checkpoint(cx); match result.as_ref() { - Ok(stop_reason) => match stop_reason { - StopReason::ToolUse => { - let tool_uses = thread.use_pending_tools(window, model.clone(), cx); - cx.emit(ThreadEvent::UsePendingTools { tool_uses }); - } - StopReason::EndTurn | StopReason::MaxTokens => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - } - StopReason::Refusal => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); + Ok(stop_reason) => { + match stop_reason { + StopReason::ToolUse => { + let tool_uses = thread.use_pending_tools(window, model.clone(), cx); + cx.emit(ThreadEvent::UsePendingTools { tool_uses }); + } + StopReason::EndTurn | StopReason::MaxTokens => { + thread.project.update(cx, |project, cx| { + project.set_agent_location(None, cx); + }); + } + StopReason::Refusal => { + thread.project.update(cx, |project, cx| { + project.set_agent_location(None, cx); + }); - // Remove the turn that was refused. - // - // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal - { - let mut messages_to_remove = Vec::new(); + // Remove the turn that was refused. 
+ // + // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal + { + let mut messages_to_remove = Vec::new(); - for (ix, message) in thread.messages.iter().enumerate().rev() { - messages_to_remove.push(message.id); + for (ix, message) in thread.messages.iter().enumerate().rev() { + messages_to_remove.push(message.id); - if message.role == Role::User { - if ix == 0 { - break; - } - - if let Some(prev_message) = thread.messages.get(ix - 1) { - if prev_message.role == Role::Assistant { + if message.role == Role::User { + if ix == 0 { break; } + + if let Some(prev_message) = thread.messages.get(ix - 1) { + if prev_message.role == Role::Assistant { + break; + } + } } } + + for message_id in messages_to_remove { + thread.delete_message(message_id, cx); + } } - for message_id in messages_to_remove { - thread.delete_message(message_id, cx); - } + cx.emit(ThreadEvent::ShowError(ThreadError::Message { + header: "Language model refusal".into(), + message: "Model refused to generate content for safety reasons.".into(), + })); } - - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Language model refusal".into(), - message: "Model refused to generate content for safety reasons.".into(), - })); } + + // We successfully completed, so cancel any remaining retries. + thread.retry_state = None; }, Err(error) => { thread.project.update(cx, |project, cx| { @@ -1859,17 +1893,58 @@ impl Thread { }); cx.notify(); } - LanguageModelKnownError::RateLimitExceeded { .. } => { - // In the future we will report the error to the user, wait retry_after, and then retry. 
- emit_generic_error(error, cx); + LanguageModelKnownError::RateLimitExceeded { retry_after } => { + let provider_name = model.provider_name(); + let error_message = format!( + "{}'s API rate limit exceeded", + provider_name.0.as_ref() + ); + + thread.handle_rate_limit_error( + &error_message, + *retry_after, + model.clone(), + intent, + window, + cx, + ); + retry_scheduled = true; } LanguageModelKnownError::Overloaded => { - // In the future we will wait and then retry, up to N times. - emit_generic_error(error, cx); + let provider_name = model.provider_name(); + let error_message = format!( + "{}'s API servers are overloaded right now", + provider_name.0.as_ref() + ); + + retry_scheduled = thread.handle_retryable_error( + &error_message, + model.clone(), + intent, + window, + cx, + ); + if !retry_scheduled { + emit_generic_error(error, cx); + } } LanguageModelKnownError::ApiInternalServerError => { - // In the future we will retry the request, but only once. - emit_generic_error(error, cx); + let provider_name = model.provider_name(); + let error_message = format!( + "{}'s API server reported an internal server error", + provider_name.0.as_ref() + ); + + retry_scheduled = thread.handle_retryable_error( + &error_message, + model.clone(), + intent, + window, + cx, + ); + if !retry_scheduled { + emit_generic_error(error, cx); + } } LanguageModelKnownError::ReadResponseError(_) | LanguageModelKnownError::DeserializeResponse(_) | @@ -1882,11 +1957,15 @@ impl Thread { emit_generic_error(error, cx); } - thread.cancel_last_completion(window, cx); + if !retry_scheduled { + thread.cancel_last_completion(window, cx); + } } } - cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); + if !retry_scheduled { + cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); + } if let Some((request_callback, (request, response_events))) = thread .request_callback @@ -2002,6 +2081,146 @@ impl Thread { }); } + fn handle_rate_limit_error( + &mut self, + error_message: &str, + 
retry_after: Duration, + model: Arc, + intent: CompletionIntent, + window: Option, + cx: &mut Context, + ) { + // For rate limit errors, we only retry once with the specified duration + let retry_message = format!( + "{error_message}. Retrying in {} seconds…", + retry_after.as_secs() + ); + + // Add a UI-only message instead of a regular message + let id = self.next_message_id.post_inc(); + self.messages.push(Message { + id, + role: Role::System, + segments: vec![MessageSegment::Text(retry_message)], + loaded_context: LoadedContext::default(), + creases: Vec::new(), + is_hidden: false, + ui_only: true, + }); + cx.emit(ThreadEvent::MessageAdded(id)); + // Schedule the retry + let thread_handle = cx.entity().downgrade(); + + cx.spawn(async move |_thread, cx| { + cx.background_executor().timer(retry_after).await; + + thread_handle + .update(cx, |thread, cx| { + // Retry the completion + thread.send_to_model(model, intent, window, cx); + }) + .log_err(); + }) + .detach(); + } + + fn handle_retryable_error( + &mut self, + error_message: &str, + model: Arc, + intent: CompletionIntent, + window: Option, + cx: &mut Context, + ) -> bool { + self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx) + } + + fn handle_retryable_error_with_delay( + &mut self, + error_message: &str, + custom_delay: Option, + model: Arc, + intent: CompletionIntent, + window: Option, + cx: &mut Context, + ) -> bool { + let retry_state = self.retry_state.get_or_insert(RetryState { + attempt: 0, + max_attempts: MAX_RETRY_ATTEMPTS, + intent, + }); + + retry_state.attempt += 1; + let attempt = retry_state.attempt; + let max_attempts = retry_state.max_attempts; + let intent = retry_state.intent; + + if attempt <= max_attempts { + // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff + let delay = if let Some(custom_delay) = custom_delay { + custom_delay + } else { + let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32); + 
Duration::from_secs(delay_secs) + }; + + // Add a transient message to inform the user + let delay_secs = delay.as_secs(); + let retry_message = format!( + "{}. Retrying (attempt {} of {}) in {} seconds...", + error_message, attempt, max_attempts, delay_secs + ); + + // Add a UI-only message instead of a regular message + let id = self.next_message_id.post_inc(); + self.messages.push(Message { + id, + role: Role::System, + segments: vec![MessageSegment::Text(retry_message)], + loaded_context: LoadedContext::default(), + creases: Vec::new(), + is_hidden: false, + ui_only: true, + }); + cx.emit(ThreadEvent::MessageAdded(id)); + + // Schedule the retry + let thread_handle = cx.entity().downgrade(); + + cx.spawn(async move |_thread, cx| { + cx.background_executor().timer(delay).await; + + thread_handle + .update(cx, |thread, cx| { + // Retry the completion + thread.send_to_model(model, intent, window, cx); + }) + .log_err(); + }) + .detach(); + + true + } else { + // Max retries exceeded + self.retry_state = None; + + let notification_text = if max_attempts == 1 { + "Failed after retrying.".into() + } else { + format!("Failed after retrying {} times.", max_attempts).into() + }; + + // Stop generating since we're giving up on retrying. 
+ self.pending_completions.clear(); + + cx.emit(ThreadEvent::RetriesFailed { + message: notification_text, + }); + + false + } + } + pub fn start_generating_detailed_summary_if_needed( &mut self, thread_store: WeakEntity, @@ -2354,7 +2573,9 @@ impl Thread { window: Option, cx: &mut Context, ) -> bool { - let mut canceled = self.pending_completions.pop().is_some(); + let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some(); + + self.retry_state = None; for pending_tool_use in self.tool_use.cancel_pending() { canceled = true; @@ -2943,6 +3164,9 @@ pub enum ThreadEvent { CancelEditing, CompletionCanceled, ProfileChanged, + RetriesFailed { + message: SharedString, + }, } impl EventEmitter for Thread {} @@ -3038,16 +3262,28 @@ mod tests { use crate::{ context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore, }; + + // Test-specific constants + const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::ToolRegistry; + use futures::StreamExt; + use futures::future::BoxFuture; + use futures::stream::BoxStream; use gpui::TestAppContext; use icons::IconName; use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; + use language_model::{ + LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelToolChoice, + }; + use parking_lot::Mutex; use project::{FakeFs, Project}; use prompt_store::PromptBuilder; use serde_json::json; use settings::{Settings, SettingsStore}; use std::sync::Arc; + use std::time::Duration; use theme::ThemeSettings; use util::path; use workspace::Workspace; @@ -3822,6 +4058,1183 @@ fn main() {{ } } + // Helper to create a model that returns errors + enum TestError { + Overloaded, + InternalServerError, + } + + struct ErrorInjector { + inner: Arc, + error_type: TestError, + } + + impl ErrorInjector { + fn new(error_type: 
TestError) -> Self { + Self { + inner: Arc::new(FakeLanguageModel::default()), + error_type, + } + } + } + + impl LanguageModel for ErrorInjector { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + _request: LanguageModelRequest, + _cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let error = match self.error_type { + TestError::Overloaded => LanguageModelCompletionError::Overloaded, + TestError::InternalServerError => { + LanguageModelCompletionError::ApiInternalServerError + } + }; + async move { + let stream = futures::stream::once(async move { Err(error) }); + Ok(stream.boxed()) + } + .boxed() + } + + fn as_fake(&self) -> &FakeLanguageModel { + &self.inner + } + } + + #[gpui::test] + async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = 
Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have default max attempts" + ); + }); + + // Check that a retry message was added + thread.read_with(cx, |thread, _| { + let mut messages = thread.messages(); + assert!( + messages.any(|msg| { + msg.role == Role::System + && msg.ui_only + && msg.segments.iter().any(|seg| { + if let MessageSegment::Text(text) = seg { + text.contains("overloaded") + && text + .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + } else { + false + } + }) + }), + "Should have added a system retry message" + ); + }); + + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + + assert_eq!(retry_count, 1, "Should have one retry message"); + } + + #[gpui::test] + async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns internal server error + let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); + + // Insert a user message + 
thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + + cx.run_until_parked(); + + // Check retry state on thread + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have correct max attempts" + ); + }); + + // Check that a retry message was added with provider name + thread.read_with(cx, |thread, _| { + let mut messages = thread.messages(); + assert!( + messages.any(|msg| { + msg.role == Role::System + && msg.ui_only + && msg.segments.iter().any(|seg| { + if let MessageSegment::Text(text) = seg { + text.contains("internal") + && text.contains("Fake") + && text + .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) + } else { + false + } + }) + }), + "Should have added a system retry message with provider name" + ); + }); + + // Count retry messages + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + + assert_eq!(retry_count, 1, "Should have one retry message"); + } + + #[gpui::test] + async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert 
a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Track retry events and completion count + // Track completion events + let completion_count = Arc::new(Mutex::new(0)); + let completion_count_clone = completion_count.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::NewRequest = event { + *completion_count_clone.lock() += 1; + } + }) + }); + + // First attempt + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + cx.run_until_parked(); + + // Should have scheduled first retry - count retry messages + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!(retry_count, 1, "Should have scheduled first retry"); + + // Check retry state + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); + }); + + // Advance clock for first retry + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.run_until_parked(); + + // Should have scheduled second retry - count retry messages + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!(retry_count, 2, "Should have scheduled second retry"); + + // Check retry state updated + 
thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have correct max attempts" + ); + }); + + // Advance clock for second retry (exponential backoff) + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); + cx.run_until_parked(); + + // Should have scheduled third retry + // Count all retry messages now + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!( + retry_count, MAX_RETRY_ATTEMPTS as usize, + "Should have scheduled third retry" + ); + + // Check retry state updated + thread.read_with(cx, |thread, _| { + assert!(thread.retry_state.is_some(), "Should have retry state"); + let retry_state = thread.retry_state.as_ref().unwrap(); + assert_eq!( + retry_state.attempt, MAX_RETRY_ATTEMPTS, + "Should be at max retry attempt" + ); + assert_eq!( + retry_state.max_attempts, MAX_RETRY_ATTEMPTS, + "Should have correct max attempts" + ); + }); + + // Advance clock for third retry (exponential backoff) + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); + cx.run_until_parked(); + + // No more retries should be scheduled after clock was advanced. 
+ let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + assert_eq!( + retry_count, MAX_RETRY_ATTEMPTS as usize, + "Should not exceed max retries" + ); + + // Final completion count should be initial + max retries + assert_eq!( + *completion_count.lock(), + (MAX_RETRY_ATTEMPTS + 1) as usize, + "Should have made initial + max retry attempts" + ); + } + + #[gpui::test] + async fn test_max_retries_exceeded(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create model that returns overloaded error + let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Track events + let retries_failed = Arc::new(Mutex::new(false)); + let retries_failed_clone = retries_failed.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::RetriesFailed { .. 
} = event { + *retries_failed_clone.lock() = true; + } + }) + }); + + // Start initial completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + cx.run_until_parked(); + + // Advance through all retries + for i in 0..MAX_RETRY_ATTEMPTS { + let delay = if i == 0 { + BASE_RETRY_DELAY_SECS + } else { + BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) + }; + cx.executor().advance_clock(Duration::from_secs(delay)); + cx.run_until_parked(); + } + + // After the 3rd retry is scheduled, we need to wait for it to execute and fail + // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) + let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); + cx.executor() + .advance_clock(Duration::from_secs(final_delay)); + cx.run_until_parked(); + + let retry_count = thread.update(cx, |thread, _| { + thread + .messages + .iter() + .filter(|m| { + m.ui_only + && m.segments.iter().any(|s| { + if let MessageSegment::Text(text) = s { + text.contains("Retrying") && text.contains("seconds") + } else { + false + } + }) + }) + .count() + }); + + // After max retries, should emit RetriesFailed event + assert_eq!( + retry_count, MAX_RETRY_ATTEMPTS as usize, + "Should have attempted max retries" + ); + assert!( + *retries_failed.lock(), + "Should emit RetriesFailed event after max retries exceeded" + ); + + // Retry state should be cleared + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after max retries" + ); + + // Verify we have the expected number of retry messages + let retry_messages = thread + .messages + .iter() + .filter(|msg| msg.ui_only && msg.role == Role::System) + .count(); + assert_eq!( + retry_messages, MAX_RETRY_ATTEMPTS as usize, + "Should have one retry message per attempt" + ); + }); + } + + #[gpui::test] + async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) { + 
init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // We'll use a wrapper to switch behavior after first failure + struct RetryTestModel { + inner: Arc, + failed_once: Arc>, + } + + impl LanguageModel for RetryTestModel { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + if !*self.failed_once.lock() { + *self.failed_once.lock() = true; + // Return error on first attempt + let stream = futures::stream::once(async move { + Err(LanguageModelCompletionError::Overloaded) + }); + async move { Ok(stream.boxed()) }.boxed() + } else { + // Succeed on retry + self.inner.stream_completion(request, cx) + } + } + + fn as_fake(&self) -> &FakeLanguageModel { + &self.inner + } + } + + let model = Arc::new(RetryTestModel { + inner: Arc::new(FakeLanguageModel::default()), + failed_once: Arc::new(Mutex::new(false)), + }); + + // Insert a user 
message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); + }); + + // Track message deletions + // Track when retry completes successfully + let retry_completed = Arc::new(Mutex::new(false)); + let retry_completed_clone = retry_completed.clone(); + + let _subscription = thread.update(cx, |_, cx| { + cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { + if let ThreadEvent::StreamedCompletion = event { + *retry_completed_clone.lock() = true; + } + }) + }); + + // Start completion + thread.update(cx, |thread, cx| { + thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); + }); + cx.run_until_parked(); + + // Get the retry message ID + let retry_message_id = thread.read_with(cx, |thread, _| { + thread + .messages() + .find(|msg| msg.role == Role::System && msg.ui_only) + .map(|msg| msg.id) + .expect("Should have a retry message") + }); + + // Wait for retry + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.run_until_parked(); + + // Stream some successful content + let fake_model = model.as_fake(); + // After the retry, there should be a new pending completion + let pending = fake_model.pending_completions(); + assert!( + !pending.is_empty(), + "Should have a pending completion after retry" + ); + fake_model.stream_completion_response(&pending[0], "Success!"); + fake_model.end_completion_stream(&pending[0]); + cx.run_until_parked(); + + // Check that the retry completed successfully + assert!( + *retry_completed.lock(), + "Retry should have completed successfully" + ); + + // Retry message should still exist but be marked as ui_only + thread.read_with(cx, |thread, _| { + let retry_msg = thread + .message(retry_message_id) + .expect("Retry message should still exist"); + assert!(retry_msg.ui_only, "Retry message should be ui_only"); + assert_eq!( + retry_msg.role, + Role::System, + "Retry message should have System role" + ); + }); + } 
+ + #[gpui::test] + async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create a model that fails once then succeeds + struct FailOnceModel { + inner: Arc, + failed_once: Arc>, + } + + impl LanguageModel for FailOnceModel { + fn id(&self) -> LanguageModelId { + self.inner.id() + } + + fn name(&self) -> LanguageModelName { + self.inner.name() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } + + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } + + fn supports_images(&self) -> bool { + self.inner.supports_images() + } + + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } + + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + if !*self.failed_once.lock() { + *self.failed_once.lock() = true; + // Return error on first attempt + let stream = futures::stream::once(async move { + Err(LanguageModelCompletionError::Overloaded) + }); + async move { Ok(stream.boxed()) }.boxed() + } else { + // Succeed on retry + self.inner.stream_completion(request, cx) + } + } + } + + let fail_once_model = Arc::new(FailOnceModel { + inner: Arc::new(FakeLanguageModel::default()), + failed_once: Arc::new(Mutex::new(false)), 
+ }); + + // Insert a user message + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "Test message", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + }); + + // Start completion with fail-once model + thread.update(cx, |thread, cx| { + thread.send_to_model( + fail_once_model.clone(), + CompletionIntent::UserPrompt, + None, + cx, + ); + }); + + cx.run_until_parked(); + + // Verify retry state exists after first failure + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_some(), + "Should have retry state after failure" + ); + }); + + // Wait for retry delay + cx.executor() + .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); + cx.run_until_parked(); + + // The retry should now use our FailOnceModel which should succeed + // We need to help the FakeLanguageModel complete the stream + let inner_fake = fail_once_model.inner.clone(); + + // Wait a bit for the retry to start + cx.run_until_parked(); + + // Check for pending completions and complete them + if let Some(pending) = inner_fake.pending_completions().first() { + inner_fake.stream_completion_response(pending, "Success!"); + inner_fake.end_completion_stream(pending); + } + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!( + thread.retry_state.is_none(), + "Retry state should be cleared after successful completion" + ); + + let has_assistant_message = thread + .messages + .iter() + .any(|msg| msg.role == Role::Assistant && !msg.ui_only); + assert!( + has_assistant_message, + "Should have an assistant message after successful retry" + ); + }); + } + + #[gpui::test] + async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + + // Create a model that returns rate limit error with retry_after + struct RateLimitModel { + inner: Arc, + } + + impl 
LanguageModel for RateLimitModel {
+            fn id(&self) -> LanguageModelId {
+                self.inner.id()
+            }
+
+            fn name(&self) -> LanguageModelName {
+                self.inner.name()
+            }
+
+            fn provider_id(&self) -> LanguageModelProviderId {
+                self.inner.provider_id()
+            }
+
+            fn provider_name(&self) -> LanguageModelProviderName {
+                self.inner.provider_name()
+            }
+
+            fn supports_tools(&self) -> bool {
+                self.inner.supports_tools()
+            }
+
+            fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+                self.inner.supports_tool_choice(choice)
+            }
+
+            fn supports_images(&self) -> bool {
+                self.inner.supports_images()
+            }
+
+            fn telemetry_id(&self) -> String {
+                self.inner.telemetry_id()
+            }
+
+            fn max_token_count(&self) -> u64 {
+                self.inner.max_token_count()
+            }
+
+            fn count_tokens(
+                &self,
+                request: LanguageModelRequest,
+                cx: &App,
+            ) -> BoxFuture<'static, Result<u64>> {
+                self.inner.count_tokens(request, cx)
+            }
+
+            // Every completion yields a stream whose only item is a
+            // `RateLimitExceeded` error carrying the test retry delay.
+            fn stream_completion(
+                &self,
+                _request: LanguageModelRequest,
+                _cx: &AsyncApp,
+            ) -> BoxFuture<
+                'static,
+                Result<
+                    BoxStream<
+                        'static,
+                        Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+                    >,
+                    LanguageModelCompletionError,
+                >,
+            > {
+                async move {
+                    let stream = futures::stream::once(async move {
+                        Err(LanguageModelCompletionError::RateLimitExceeded {
+                            retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
+                        })
+                    });
+                    Ok(stream.boxed())
+                }
+                .boxed()
+            }
+
+            fn as_fake(&self) -> &FakeLanguageModel {
+                &self.inner
+            }
+        }
+
+        let model = Arc::new(RateLimitModel {
+            inner: Arc::new(FakeLanguageModel::default()),
+        });
+
+        // Insert a user message
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
+        });
+
+        // Start completion
+        thread.update(cx, |thread, cx| {
+            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
+        });
+
+        cx.run_until_parked();
+
+        thread.read_with(cx, |thread, _| {
+            // Exactly one UI-only retry notice should have been added. This is
+            // counted once, through a read-only borrow of the thread.
+            let retry_messages = thread
+                .messages
+                .iter()
+                .filter(|msg| {
+                    msg.ui_only
+                        && msg.segments.iter().any(|seg| {
+                            matches!(
+                                seg,
+                                MessageSegment::Text(text) if text.contains("rate limit exceeded")
+                            )
+                        })
+                })
+                .count();
+            assert_eq!(
+                retry_messages, 1,
+                "Should have one rate limit retry message"
+            );
+
+            assert!(
+                thread.retry_state.is_none(),
+                "Rate limit errors should not set retry_state"
+            );
+
+            // Check that retry message doesn't include attempt count
+            let retry_message = thread
+                .messages
+                .iter()
+                .find(|msg| msg.role == Role::System && msg.ui_only)
+                .expect("Should have a retry message");
+
+            // Require a text segment, so the content assertions below cannot be
+            // skipped silently when the message has an unexpected shape.
+            let text = retry_message
+                .segments
+                .iter()
+                .find_map(|seg| match seg {
+                    MessageSegment::Text(text) => Some(text),
+                    _ => None,
+                })
+                .expect("Retry message should contain a text segment");
+            assert!(
+                !text.contains("attempt"),
+                "Rate limit retry message should not contain attempt count"
+            );
+            assert!(
+                text.contains(&format!(
+                    "Retrying in {} seconds",
+                    TEST_RATE_LIMIT_RETRY_SECS
+                )),
+                "Rate limit retry message should contain retry delay"
+            );
+        });
+    }
+
+    #[gpui::test]
+    async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(cx, json!({})).await;
+        let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await;
+
+        // Insert a regular user message
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
+        });
+
+        // Insert a UI-only message (like our retry notifications)
+        thread.update(cx, |thread, cx| {
+            let id = thread.next_message_id.post_inc();
+
thread.messages.push(Message {
+                id,
+                role: Role::System,
+                segments: vec![MessageSegment::Text(
+                    "This is a UI-only message that should not be sent to the model".to_string(),
+                )],
+                loaded_context: LoadedContext::default(),
+                creases: Vec::new(),
+                is_hidden: true,
+                ui_only: true,
+            });
+            cx.emit(ThreadEvent::MessageAdded(id));
+        });
+
+        // Insert another regular message
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message(
+                "How are you?",
+                ContextLoadResult::default(),
+                None,
+                vec![],
+                cx,
+            );
+        });
+
+        // Generate the completion request
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
+        });
+
+        // Verify that the request only contains non-UI-only messages
+        // Should have system prompt + 2 user messages, but not the UI-only message
+        let user_messages: Vec<_> = request
+            .messages
+            .iter()
+            .filter(|msg| msg.role == Role::User)
+            .collect();
+        assert_eq!(
+            user_messages.len(),
+            2,
+            "Should have exactly 2 user messages"
+        );
+
+        // Verify the UI-only content is not present anywhere in the request
+        let request_text = request
+            .messages
+            .iter()
+            .flat_map(|msg| &msg.content)
+            .filter_map(|content| match content {
+                MessageContent::Text(text) => Some(text.as_str()),
+                _ => None,
+            })
+            .collect::<String>();
+
+        assert!(
+            !request_text.contains("UI-only message"),
+            "UI-only message content should not be in the request"
+        );
+
+        // Verify the thread still has all 3 messages (including UI-only)
+        thread.read_with(cx, |thread, _| {
+            assert_eq!(
+                thread.messages().count(),
+                3,
+                "Thread should have 3 messages"
+            );
+            assert_eq!(
+                thread.messages().filter(|m| m.ui_only).count(),
+                1,
+                "Thread should have 1 UI-only message"
+            );
+        });
+
+        // Verify that UI-only messages are not serialized
+        let serialized = thread
+            .update(cx, |thread, cx| thread.serialize(cx))
+            .await
+            .unwrap();
+        assert_eq!(
+            serialized.messages.len(),
+            2,
+            "Serialized thread should only have 2 messages (no UI-only)"
+        );
+    }
+
+    /// Verifies that cancelling the last completion while a retry is still
+    /// waiting leaves no pending completion on the model and no `retry_state`
+    /// on the thread.
+    #[gpui::test]
+    async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(cx, json!({})).await;
+        let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await;
+
+        // Create model that returns overloaded error
+        let model = Arc::new(ErrorInjector::new(TestError::Overloaded));
+
+        // Insert a user message
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx);
+        });
+
+        // Start completion
+        thread.update(cx, |thread, cx| {
+            thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
+        });
+
+        cx.run_until_parked();
+
+        // Verify retry was scheduled by checking for retry message
+        let has_retry_message = thread.read_with(cx, |thread, _| {
+            thread.messages.iter().any(|m| {
+                m.ui_only
+                    && m.segments.iter().any(|s| {
+                        if let MessageSegment::Text(text) = s {
+                            text.contains("Retrying") && text.contains("seconds")
+                        } else {
+                            false
+                        }
+                    })
+            })
+        });
+        assert!(has_retry_message, "Should have scheduled a retry");
+
+        // Cancel the completion before the retry happens
+        thread.update(cx, |thread, cx| {
+            thread.cancel_last_completion(None, cx);
+        });
+
+        cx.run_until_parked();
+
+        // The retry should not have happened - no pending completions
+        let fake_model = model.as_fake();
+        assert_eq!(
+            fake_model.pending_completions().len(),
+            0,
+            "Should have no pending completions after cancellation"
+        );
+
+        // Verify the retry was cancelled by checking retry state. On failure the
+        // leftover state is printed through its `Debug` implementation.
+        thread.read_with(cx, |thread, _| {
+            assert!(
+                thread.retry_state.is_none(),
+                "retry_state should be cleared after cancellation, but found: {:?}",
+                thread.retry_state
+            );
+        });
+    }
+
     fn test_summarize_error(
         model: &Arc,
         thread: &Entity,
diff --git
a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 0e7ca9aa89..4da959d36e 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -1140,6 +1140,9 @@ impl ActiveThread { self.save_thread(cx); cx.notify(); } + ThreadEvent::RetriesFailed { message } => { + self.show_notification(message, ui::IconName::Warning, window, cx); + } } } @@ -1835,9 +1838,10 @@ impl ActiveThread { .filter(|(id, _)| *id == message_id) .map(|(_, state)| state); - let colors = cx.theme().colors(); - let editor_bg_color = colors.editor_background; - let panel_bg = colors.panel_background; + let (editor_bg_color, panel_bg) = { + let colors = cx.theme().colors(); + (colors.editor_background, colors.panel_background) + }; let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText) .icon_size(IconSize::XSmall) @@ -2025,152 +2029,162 @@ impl ActiveThread { } }); - let styled_message = match message.role { - Role::User => v_flex() - .id(("message-container", ix)) - .pt_2() - .pl_2() - .pr_2p5() - .pb_4() - .child( + let styled_message = if message.ui_only { + self.render_ui_notification(message_content, ix, cx) + } else { + match message.role { + Role::User => { + let colors = cx.theme().colors(); v_flex() - .id(("user-message", ix)) - .bg(editor_bg_color) - .rounded_lg() - .shadow_md() - .border_1() - .border_color(colors.border) - .hover(|hover| hover.border_color(colors.text_accent.opacity(0.5))) + .id(("message-container", ix)) + .pt_2() + .pl_2() + .pr_2p5() + .pb_4() .child( v_flex() - .p_2p5() - .gap_1() - .children(message_content) - .when_some(editing_message_state, |this, state| { - let focus_handle = state.editor.focus_handle(cx).clone(); + .id(("user-message", ix)) + .bg(editor_bg_color) + .rounded_lg() + .shadow_md() + .border_1() + .border_color(colors.border) + .hover(|hover| hover.border_color(colors.text_accent.opacity(0.5))) + .child( + v_flex() + .p_2p5() + .gap_1() + 
.children(message_content) + .when_some(editing_message_state, |this, state| { + let focus_handle = state.editor.focus_handle(cx).clone(); - this.child( - h_flex() - .w_full() - .gap_1() - .justify_between() - .flex_wrap() - .child( + this.child( h_flex() - .gap_1p5() + .w_full() + .gap_1() + .justify_between() + .flex_wrap() .child( - div() - .opacity(0.8) + h_flex() + .gap_1p5() .child( - Icon::new(IconName::Warning) - .size(IconSize::Indicator) - .color(Color::Warning) + div() + .opacity(0.8) + .child( + Icon::new(IconName::Warning) + .size(IconSize::Indicator) + .color(Color::Warning) + ), + ) + .child( + Label::new("Editing will restart the thread from this point.") + .color(Color::Muted) + .size(LabelSize::XSmall), ), ) .child( - Label::new("Editing will restart the thread from this point.") - .color(Color::Muted) - .size(LabelSize::XSmall), - ), - ) - .child( - h_flex() - .gap_0p5() - .child( - IconButton::new( - "cancel-edit-message", - IconName::Close, - ) - .shape(ui::IconButtonShape::Square) - .icon_color(Color::Error) - .icon_size(IconSize::Small) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Cancel Edit", - &menu::Cancel, - &focus_handle, - window, - cx, + h_flex() + .gap_0p5() + .child( + IconButton::new( + "cancel-edit-message", + IconName::Close, ) - } - }) - .on_click(cx.listener(Self::handle_cancel_click)), + .shape(ui::IconButtonShape::Square) + .icon_color(Color::Error) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Cancel Edit", + &menu::Cancel, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(Self::handle_cancel_click)), + ) + .child( + IconButton::new( + "confirm-edit-message", + IconName::Return, + ) + .disabled(state.editor.read(cx).is_empty(cx)) + .shape(ui::IconButtonShape::Square) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle 
= focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Regenerate", + &menu::Confirm, + &focus_handle, + window, + cx, + ) + } + }) + .on_click( + cx.listener(Self::handle_regenerate_click), + ), + ), ) - .child( - IconButton::new( - "confirm-edit-message", - IconName::Return, - ) - .disabled(state.editor.read(cx).is_empty(cx)) - .shape(ui::IconButtonShape::Square) - .icon_color(Color::Muted) - .icon_size(IconSize::Small) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in( - "Regenerate", - &menu::Confirm, - &focus_handle, - window, - cx, - ) - } - }) - .on_click( - cx.listener(Self::handle_regenerate_click), - ), - ), ) - ) - }), + }), + ) + .on_click(cx.listener({ + let message_creases = message.creases.clone(); + move |this, _, window, cx| { + if let Some(message_text) = + this.thread.read(cx).message(message_id).and_then(|message| { + message.segments.first().and_then(|segment| { + match segment { + MessageSegment::Text(message_text) => { + Some(Into::>::into(message_text.as_str())) + } + _ => { + None + } + } + }) + }) + { + this.start_editing_message( + message_id, + message_text, + &message_creases, + window, + cx, + ); + } + } + })), ) - .on_click(cx.listener({ - let message_creases = message.creases.clone(); - move |this, _, window, cx| { - if let Some(message_text) = - this.thread.read(cx).message(message_id).and_then(|message| { - message.segments.first().and_then(|segment| { - match segment { - MessageSegment::Text(message_text) => { - Some(Into::>::into(message_text.as_str())) - } - _ => { - None - } - } - }) - }) - { - this.start_editing_message( - message_id, - message_text, - &message_creases, - window, - cx, - ); - } - } - })), - ), - Role::Assistant => v_flex() - .id(("message-container", ix)) - .px(RESPONSE_PADDING_X) - .gap_2() - .children(message_content) - .when(has_tool_uses, |parent| { - parent.children(tool_uses.into_iter().map(|tool_use| { - 
self.render_tool_use(tool_use, window, workspace.clone(), cx)
-                    }))
-                }),
-            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
-                v_flex()
-                    .bg(colors.editor_background)
-                    .rounded_sm()
-                    .child(div().p_4().children(message_content)),
-            ),
+                }
+                Role::Assistant => v_flex()
+                    .id(("message-container", ix))
+                    .px(RESPONSE_PADDING_X)
+                    .gap_2()
+                    .children(message_content)
+                    .when(has_tool_uses, |parent| {
+                        parent.children(tool_uses.into_iter().map(|tool_use| {
+                            self.render_tool_use(tool_use, window, workspace.clone(), cx)
+                        }))
+                    }),
+                Role::System => {
+                    let colors = cx.theme().colors();
+                    div().id(("message-container", ix)).py_1().px_2().child(
+                        v_flex()
+                            .bg(colors.editor_background)
+                            .rounded_sm()
+                            .child(div().p_4().children(message_content)),
+                    )
+                }
+            }
         };
 
         let after_editing_message = self
@@ -2509,6 +2523,42 @@ impl ActiveThread {
             .blend(cx.theme().colors().editor_foreground.opacity(0.025))
     }
 
+    /// Renders a UI-only message as a compact warning notice.
+    ///
+    /// The result is a container with the id `("message-container", ix)`. It
+    /// holds a rounded box on the editor background, with a small
+    /// warning-colored `Warning` icon on the left and `message_content` on the
+    /// right, laid out in a column with small, muted text.
+    fn render_ui_notification(
+        &self,
+        message_content: impl IntoIterator,
+        ix: usize,
+        cx: &mut Context,
+    ) -> Stateful
{
+        let colors = cx.theme().colors();
+        div().id(("message-container", ix)).py_1().px_2().child(
+            v_flex()
+                .w_full()
+                .bg(colors.editor_background)
+                .rounded_sm()
+                .child(
+                    h_flex()
+                        .w_full()
+                        .p_2()
+                        .gap_2()
+                        .child(
+                            div().flex_none().child(
+                                Icon::new(IconName::Warning)
+                                    .size(IconSize::Small)
+                                    .color(Color::Warning),
+                            ),
+                        )
+                        .child(
+                            v_flex()
+                                .flex_1()
+                                .min_w_0()
+                                .text_size(TextSize::Small.rems(cx))
+                                .text_color(cx.theme().colors().text_muted)
+                                .children(message_content),
+                        ),
+                ),
+        )
+    }
+
     fn render_message_thinking_segment(
         &self,
         message_id: MessageId,
@@ -3763,9 +3813,9 @@ mod tests {
         // Stream response to user message
         thread.update(cx, |thread, cx| {
-            let request =
-                thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx);
-            thread.stream_completion(request, model, cx.active_window(), cx)
+            let intent = CompletionIntent::UserPrompt;
+            let request = thread.to_completion_request(model.clone(), intent, cx);
+            thread.stream_completion(request, model, intent, cx.active_window(), cx)
         });
         // Follow the agent
         cx.update(|window, cx| {
diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs
index aa5e49551b..b8e67512e2 100644
--- a/crates/agent_ui/src/agent_diff.rs
+++ b/crates/agent_ui/src/agent_diff.rs
@@ -1380,6 +1380,7 @@ impl AgentDiff {
                 | ThreadEvent::ToolConfirmationNeeded
                 | ThreadEvent::ToolUseLimitReached
                 | ThreadEvent::CancelEditing
+                | ThreadEvent::RetriesFailed { .. }
                 | ThreadEvent::ProfileChanged => {}
             }
         }
diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs
index 09770364cb..904eca83e6 100644
--- a/crates/eval/src/example.rs
+++ b/crates/eval/src/example.rs
@@ -221,6 +221,9 @@ impl ExampleContext {
                     ThreadEvent::ShowError(thread_error) => {
                         tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
                     }
+                    ThreadEvent::RetriesFailed { ..
} => { + // Ignore retries failed events + } ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { tx.close_channel();