Finish pending tool calls before retrying or erroring

This commit is contained in:
Antonio Scandurra 2025-08-25 14:37:00 +02:00
parent 5021993703
commit d7c4c9aa1e
2 changed files with 98 additions and 101 deletions

View file

@ -5,6 +5,7 @@ use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use cloud_llm_client::CompletionIntent;
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use fs::{FakeFs, Fs};
use futures::{
@ -2096,6 +2097,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey,");
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
@ -2105,7 +2107,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.send_last_completion_stream_text_chunk("there!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
@ -2135,18 +2137,24 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
## Assistant
Hey!
Hey,
[resume]
## Assistant
there!
"}
)
});
}
#[gpui::test]
async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) {
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
thread
let events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.add_tool(EchoTool);
@ -2162,58 +2170,16 @@ async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) {
input: json!({"text": "test"}),
is_input_complete: true,
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use_1));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
tool_use_1.clone(),
));
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Call the echo tool!
"}
)
});
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
vec![LanguageModelRequestMessage {
role: Role::User,
content: vec!["Call the echo tool!".into()],
cache: true
}]
);
let tool_use_2 = LanguageModelToolUse {
id: "tool_2".into(),
name: EchoTool::name().into(),
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
};
let tool_result_2 = LanguageModelToolResult {
tool_use_id: "tool_2".into(),
tool_name: EchoTool::name().into(),
is_error: false,
content: "test".into(),
output: Some("test".into()),
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
tool_use_2.clone(),
));
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
assert_eq!(
completion.messages[1..],
@ -2225,16 +2191,38 @@ async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) {
},
LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use_2)],
content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
cache: false
},
LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::ToolResult(tool_result_2)],
content: vec![language_model::MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_use_1.id.clone(),
tool_name: tool_use_1.name.clone(),
is_error: false,
content: "test".into(),
output: Some("test".into())
}
)],
cache: true
}
},
]
);
fake_model.send_last_completion_stream_text_chunk("Done");
fake_model.end_last_completion_stream();
cx.run_until_parked();
events.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.last_message(),
Some(Message::Agent(AgentMessage {
content: vec![AgentMessageContent::Text("Done".into())],
tool_results: IndexMap::default()
}))
);
})
}
#[gpui::test]

View file

@ -123,7 +123,7 @@ impl Message {
match self {
Message::User(message) => message.to_markdown(),
Message::Agent(message) => message.to_markdown(),
Message::Resume => "[resumed after tool use limit was reached]".into(),
Message::Resume => "[resume]\n".into(),
}
}
@ -1213,7 +1213,7 @@ impl Thread {
log::debug!("Stream completion started successfully");
let mut attempt = None;
'retry: loop {
loop {
let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
})??;
@ -1236,6 +1236,7 @@ impl Thread {
.await
.map_err(|error| anyhow!(error))?;
let mut tool_results = FuturesUnordered::new();
let mut error = None;
while let Some(event) = events.next().await {
match event {
@ -1245,52 +1246,9 @@ impl Thread {
this.handle_streamed_completion_event(event, event_stream, cx)
})??);
}
Err(error) => {
let completion_mode =
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(anyhow!(error))?;
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(anyhow!(error))?;
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let attempt = attempt.get_or_insert(0u8);
*attempt += 1;
let attempt = *attempt;
if attempt > max_attempts {
return Err(anyhow!(error))?;
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs =
initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
event_stream.send_retry(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
});
this.update(cx, |this, _cx| this.pending_message.take())?;
cx.background_executor().timer(delay).await;
continue 'retry;
Err(err) => {
error = Some(err);
break;
}
}
}
@ -1317,7 +1275,58 @@ impl Thread {
})?;
}
return Ok(());
if let Some(error) = error {
let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(anyhow!(error))?;
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(anyhow!(error))?;
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let attempt = attempt.get_or_insert(0u8);
*attempt += 1;
let attempt = *attempt;
if attempt > max_attempts {
return Err(anyhow!(error))?;
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
event_stream.send_retry(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
});
cx.background_executor().timer(delay).await;
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
if let Some(Message::Agent(message)) = this.messages.last() {
if message.tool_results.is_empty() {
this.messages.push(Message::Resume);
}
}
})?;
} else {
return Ok(());
}
}
}