diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index c05c667450..87ecc1037c 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -5,6 +5,7 @@ use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; use cloud_llm_client::CompletionIntent; +use collections::IndexMap; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use fs::{FakeFs, Fs}; use futures::{ @@ -2096,6 +2097,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { .unwrap(); cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey,"); fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { provider: LanguageModelProviderName::new("Anthropic"), retry_after: Some(Duration::from_secs(3)), @@ -2105,7 +2107,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { cx.executor().advance_clock(Duration::from_secs(3)); cx.run_until_parked(); - fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.send_last_completion_stream_text_chunk("there!"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -2135,18 +2137,24 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { ## Assistant - Hey! + Hey, + + [resume] + + ## Assistant + + there! "} ) }); } #[gpui::test] -async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) { +async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { let ThreadTest { thread, model, .. 
} = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - thread + let events = thread .update(cx, |thread, cx| { thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.add_tool(EchoTool); @@ -2162,58 +2170,16 @@ async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) { input: json!({"text": "test"}), is_input_complete: true, }; - fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use_1)); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + tool_use_1.clone(), + )); fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { provider: LanguageModelProviderName::new("Anthropic"), retry_after: Some(Duration::from_secs(3)), }); fake_model.end_last_completion_stream(); - cx.run_until_parked(); - thread.read_with(cx, |thread, _cx| { - assert_eq!( - thread.to_markdown(), - indoc! {" - ## User - - Call the echo tool! - "} - ) - }); cx.executor().advance_clock(Duration::from_secs(3)); - cx.run_until_parked(); - let completion = fake_model.pending_completions().pop().unwrap(); - assert_eq!( - completion.messages[1..], - vec![LanguageModelRequestMessage { - role: Role::User, - content: vec!["Call the echo tool!".into()], - cache: true - }] - ); - - let tool_use_2 = LanguageModelToolUse { - id: "tool_2".into(), - name: EchoTool::name().into(), - raw_input: json!({"text": "test"}).to_string(), - input: json!({"text": "test"}), - is_input_complete: true, - }; - let tool_result_2 = LanguageModelToolResult { - tool_use_id: "tool_2".into(), - tool_name: EchoTool::name().into(), - is_error: false, - content: "test".into(), - output: Some("test".into()), - }; - fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( - tool_use_2.clone(), - )); - fake_model - .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); - fake_model.end_last_completion_stream(); 
- cx.run_until_parked(); - let completion = fake_model.pending_completions().pop().unwrap(); assert_eq!( completion.messages[1..], @@ -2225,16 +2191,38 @@ async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) { }, LanguageModelRequestMessage { role: Role::Assistant, - content: vec![MessageContent::ToolUse(tool_use_2)], + content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())], cache: false }, LanguageModelRequestMessage { role: Role::User, - content: vec![MessageContent::ToolResult(tool_result_2)], + content: vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_use_1.id.clone(), + tool_name: tool_use_1.name.clone(), + is_error: false, + content: "test".into(), + output: Some("test".into()) + } + )], cache: true - } + }, ] ); + + fake_model.send_last_completion_stream_text_chunk("Done"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + events.collect::<Vec<_>>().await; + thread.read_with(cx, |thread, _cx| { + assert_eq!( + thread.last_message(), + Some(Message::Agent(AgentMessage { + content: vec![AgentMessageContent::Text("Done".into())], + tool_results: IndexMap::default() + })) + ); + }) } #[gpui::test] diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index f3af555b21..43f391ca64 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -123,7 +123,7 @@ impl Message { match self { Message::User(message) => message.to_markdown(), Message::Agent(message) => message.to_markdown(), - Message::Resume => "[resumed after tool use limit was reached]".into(), + Message::Resume => "[resume]\n".into(), } } @@ -1213,7 +1213,7 @@ impl Thread { log::debug!("Stream completion started successfully"); let mut attempt = None; - 'retry: loop { + loop { let request = this.update(cx, |this, cx| { this.build_completion_request(completion_intent, cx) })??; @@ -1236,6 +1236,7 @@ impl Thread { .await .map_err(|error| anyhow!(error))?; let mut
tool_results = FuturesUnordered::new(); + let mut error = None; while let Some(event) = events.next().await { match event { @@ -1245,52 +1246,9 @@ impl Thread { this.handle_streamed_completion_event(event, event_stream, cx) })??); } - Err(error) => { - let completion_mode = - this.read_with(cx, |thread, _cx| thread.completion_mode())?; - if completion_mode == CompletionMode::Normal { - return Err(anyhow!(error))?; - } - - let Some(strategy) = Self::retry_strategy_for(&error) else { - return Err(anyhow!(error))?; - }; - - let max_attempts = match &strategy { - RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, - RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, - }; - - let attempt = attempt.get_or_insert(0u8); - - *attempt += 1; - - let attempt = *attempt; - if attempt > max_attempts { - return Err(anyhow!(error))?; - } - - let delay = match &strategy { - RetryStrategy::ExponentialBackoff { initial_delay, .. } => { - let delay_secs = - initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) - } - RetryStrategy::Fixed { delay, .. 
} => *delay, - }; - log::debug!("Retry attempt {attempt} with delay {delay:?}"); - - event_stream.send_retry(acp_thread::RetryStatus { - last_error: error.to_string().into(), - attempt: attempt as usize, - max_attempts: max_attempts as usize, - started_at: Instant::now(), - duration: delay, - }); - this.update(cx, |this, _cx| this.pending_message.take())?; - - cx.background_executor().timer(delay).await; - continue 'retry; + Err(err) => { + error = Some(err); + break; } } } @@ -1317,7 +1275,58 @@ impl Thread { })?; } - return Ok(()); + if let Some(error) = error { + let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?; + if completion_mode == CompletionMode::Normal { + return Err(anyhow!(error))?; + } + + let Some(strategy) = Self::retry_strategy_for(&error) else { + return Err(anyhow!(error))?; + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + + let attempt = attempt.get_or_insert(0u8); + + *attempt += 1; + + let attempt = *attempt; + if attempt > max_attempts { + return Err(anyhow!(error))?; + } + + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. 
} => *delay, + }; + log::debug!("Retry attempt {attempt} with delay {delay:?}"); + + event_stream.send_retry(acp_thread::RetryStatus { + last_error: error.to_string().into(), + attempt: attempt as usize, + max_attempts: max_attempts as usize, + started_at: Instant::now(), + duration: delay, + }); + cx.background_executor().timer(delay).await; + this.update(cx, |this, cx| { + this.flush_pending_message(cx); + if let Some(Message::Agent(message)) = this.messages.last() { + if message.tool_results.is_empty() { + this.messages.push(Message::Resume); + } + } + })?; + } else { + return Ok(()); + } } }