diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 029d175054..d9a7a2582a 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1373,6 +1373,10 @@ impl AcpThread { }) } + pub fn can_resume(&self, cx: &App) -> bool { + self.connection.resume(&self.session_id, cx).is_some() + } + pub fn resume(&mut self, cx: &mut Context) -> BoxFuture<'static, Result<()>> { self.run_turn(cx, async move |this, cx| { this.update(cx, |this, cx| { @@ -2659,7 +2663,7 @@ mod tests { fn truncate( &self, session_id: &acp::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { Some(Rc::new(FakeAgentSessionEditor { _session_id: session_id.clone(), diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 91e46dbac1..5f5032e588 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -43,7 +43,7 @@ pub trait AgentConnection { fn resume( &self, _session_id: &acp::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { None } @@ -53,7 +53,7 @@ pub trait AgentConnection { fn truncate( &self, _session_id: &acp::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { None } @@ -61,7 +61,7 @@ pub trait AgentConnection { fn set_title( &self, _session_id: &acp::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { None } @@ -439,7 +439,7 @@ mod test_support { fn truncate( &self, _session_id: &agent_client_protocol::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { Some(Rc::new(StubAgentSessionEditor)) } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 4eaf87e218..415933b7d1 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -936,7 +936,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn resume( &self, session_id: &acp::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { Some(Rc::new(NativeAgentSessionResume { connection: self.clone(), @@ -956,9 +956,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn truncate( &self, session_id: &agent_client_protocol::SessionId, - cx: &mut App, + cx: &App, ) -> Option> { - self.0.update(cx, |agent, _cx| { + self.0.read_with(cx, |agent, _cx| { agent.sessions.get(session_id).map(|session| { Rc::new(NativeAgentSessionEditor { thread: session.thread.clone(), @@ -971,7 +971,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn set_title( &self, session_id: &acp::SessionId, - _cx: &mut App, + _cx: &App, ) -> Option> { Some(Rc::new(NativeAgentSessionSetTitle { connection: self.clone(), diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 5b935dae4c..87ecc1037c 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -5,6 +5,7 @@ use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; use cloud_llm_client::CompletionIntent; +use collections::IndexMap; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use fs::{FakeFs, Fs}; use futures::{ @@ -673,15 +674,6 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) { "} ) }); - - // Ensure we error if calling resume when tool use limit was *not* reached. - let error = thread - .update(cx, |thread, cx| thread.resume(cx)) - .unwrap_err(); - assert_eq!( - error.to_string(), - "can only resume after tool use limit is reached" - ) } #[gpui::test] @@ -2105,6 +2097,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { .unwrap(); cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey,"); fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { provider: LanguageModelProviderName::new("Anthropic"), retry_after: Some(Duration::from_secs(3)), @@ -2114,8 +2107,9 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { cx.executor().advance_clock(Duration::from_secs(3)); cx.run_until_parked(); - fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.send_last_completion_stream_text_chunk("there!"); fake_model.end_last_completion_stream(); + cx.run_until_parked(); let mut retry_events = Vec::new(); while let Some(Ok(event)) = events.next().await { @@ -2143,12 +2137,94 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { ## Assistant - Hey! + Hey, + + [resume] + + ## Assistant + + there! "} ) }); } +#[gpui::test] +async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events = thread + .update(cx, |thread, cx| { + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); + thread.add_tool(EchoTool); + thread.send(UserMessageId::new(), ["Call the echo tool!"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use_1 = LanguageModelToolUse { + id: "tool_1".into(), + name: EchoTool::name().into(), + raw_input: json!({"text": "test"}).to_string(), + input: json!({"text": "test"}), + is_input_complete: true, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + tool_use_1.clone(), + )); + fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { + provider: LanguageModelProviderName::new("Anthropic"), + retry_after: Some(Duration::from_secs(3)), + }); + fake_model.end_last_completion_stream(); + + cx.executor().advance_clock(Duration::from_secs(3)); + let completion = fake_model.pending_completions().pop().unwrap(); + assert_eq!( + completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Call the echo tool!".into()], + cache: false + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())], + cache: false + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_use_1.id.clone(), + tool_name: tool_use_1.name.clone(), + is_error: false, + content: "test".into(), + output: Some("test".into()) + } + )], + cache: true + }, + ] + ); + + fake_model.send_last_completion_stream_text_chunk("Done"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + events.collect::>().await; + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.last_message(), + Some(Message::Agent(AgentMessage { + content: vec![AgentMessageContent::Text("Done".into())], + tool_results: IndexMap::default() + })) + ); + }) +} + #[gpui::test] async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index c000027368..43f391ca64 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -123,7 +123,7 @@ impl Message { match self { Message::User(message) => message.to_markdown(), Message::Agent(message) => message.to_markdown(), - Message::Resume => "[resumed after tool use limit was reached]".into(), + Message::Resume => "[resume]\n".into(), } } @@ -1085,11 +1085,6 @@ impl Thread { &mut self, cx: &mut Context, ) -> Result>> { - anyhow::ensure!( - self.tool_use_limit_reached, - "can only resume after tool use limit is reached" - ); - self.messages.push(Message::Resume); cx.notify(); @@ -1216,12 +1211,13 @@ impl Thread { cx: &mut AsyncApp, ) -> Result<()> { log::debug!("Stream completion started successfully"); - let request = this.update(cx, |this, cx| { - this.build_completion_request(completion_intent, cx) - })??; let mut attempt = None; - 'retry: loop { + loop { + let request = this.update(cx, |this, cx| { + this.build_completion_request(completion_intent, cx) + })??; + telemetry::event!( "Agent Thread Completion", thread_id = this.read_with(cx, |this, _| this.id.to_string())?, @@ -1236,10 +1232,11 @@ impl Thread { attempt.unwrap_or(0) ); let mut events = model - .stream_completion(request.clone(), cx) + .stream_completion(request, cx) .await .map_err(|error| anyhow!(error))?; let mut tool_results = FuturesUnordered::new(); + let mut error = None; while let Some(event) = events.next().await { match event { @@ -1249,51 +1246,9 @@ impl Thread { this.handle_streamed_completion_event(event, event_stream, cx) })??); } - Err(error) => { - let completion_mode = - this.read_with(cx, |thread, _cx| thread.completion_mode())?; - if completion_mode == CompletionMode::Normal { - return Err(anyhow!(error))?; - } - - let Some(strategy) = Self::retry_strategy_for(&error) else { - return Err(anyhow!(error))?; - }; - - let max_attempts = match &strategy { - RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, - RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, - }; - - let attempt = attempt.get_or_insert(0u8); - - *attempt += 1; - - let attempt = *attempt; - if attempt > max_attempts { - return Err(anyhow!(error))?; - } - - let delay = match &strategy { - RetryStrategy::ExponentialBackoff { initial_delay, .. } => { - let delay_secs = - initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) - } - RetryStrategy::Fixed { delay, .. } => *delay, - }; - log::debug!("Retry attempt {attempt} with delay {delay:?}"); - - event_stream.send_retry(acp_thread::RetryStatus { - last_error: error.to_string().into(), - attempt: attempt as usize, - max_attempts: max_attempts as usize, - started_at: Instant::now(), - duration: delay, - }); - - cx.background_executor().timer(delay).await; - continue 'retry; + Err(err) => { + error = Some(err); + break; } } } @@ -1320,7 +1275,58 @@ impl Thread { })?; } - return Ok(()); + if let Some(error) = error { + let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?; + if completion_mode == CompletionMode::Normal { + return Err(anyhow!(error))?; + } + + let Some(strategy) = Self::retry_strategy_for(&error) else { + return Err(anyhow!(error))?; + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + + let attempt = attempt.get_or_insert(0u8); + + *attempt += 1; + + let attempt = *attempt; + if attempt > max_attempts { + return Err(anyhow!(error))?; + } + + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, + }; + log::debug!("Retry attempt {attempt} with delay {delay:?}"); + + event_stream.send_retry(acp_thread::RetryStatus { + last_error: error.to_string().into(), + attempt: attempt as usize, + max_attempts: max_attempts as usize, + started_at: Instant::now(), + duration: delay, + }); + cx.background_executor().timer(delay).await; + this.update(cx, |this, cx| { + this.flush_pending_message(cx); + if let Some(Message::Agent(message)) = this.messages.last() { + if message.tool_results.is_empty() { + this.messages.push(Message::Resume); + } + } + })?; + } else { + return Ok(()); + } } } @@ -1737,6 +1743,10 @@ impl Thread { return; }; + if message.content.is_empty() { + return; + } + for content in &message.content { let AgentMessageContent::ToolUse(tool_use) = content else { continue; diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 0b987e25b6..5674b15c98 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -820,6 +820,9 @@ impl AcpThreadView { let Some(thread) = self.thread() else { return; }; + if !thread.read(cx).can_resume(cx) { + return; + } let task = thread.update(cx, |thread, cx| thread.resume(cx)); cx.spawn(async move |this, cx| { @@ -4459,12 +4462,53 @@ impl AcpThreadView { } fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout { + let can_resume = self + .thread() + .map_or(false, |thread| thread.read(cx).can_resume(cx)); + + let can_enable_burn_mode = self.as_native_thread(cx).map_or(false, |thread| { + let thread = thread.read(cx); + let supports_burn_mode = thread + .model() + .map_or(false, |model| model.supports_burn_mode()); + supports_burn_mode && thread.completion_mode() == CompletionMode::Normal + }); + Callout::new() .severity(Severity::Error) .title("Error") .icon(IconName::XCircle) .description(error.clone()) - .actions_slot(self.create_copy_button(error.to_string())) + .actions_slot( + h_flex() + .gap_0p5() + .when(can_resume && can_enable_burn_mode, |this| { + this.child( + Button::new("enable-burn-mode-and-retry", "Enable Burn Mode and Retry") + .icon(IconName::ZedBurnMode) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, window, cx| { + this.toggle_burn_mode(&ToggleBurnMode, window, cx); + this.resume_chat(cx); + })), + ) + }) + .when(can_resume, |this| { + this.child( + Button::new("retry", "Retry") + .icon(IconName::RotateCw) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .label_size(LabelSize::Small) + .on_click(cx.listener(|this, _, _window, cx| { + this.resume_chat(cx); + })), + ) + }) + .child(self.create_copy_button(error.to_string())), + ) .dismiss_action(self.dismiss_error_button(cx)) }