Finish pending tools before retrying/erroring
parent 5021993703
commit d7c4c9aa1e
2 changed files with 98 additions and 101 deletions
@@ -5,6 +5,7 @@ use agent_settings::AgentProfileId;
 use anyhow::Result;
 use client::{Client, UserStore};
 use cloud_llm_client::CompletionIntent;
+use collections::IndexMap;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use fs::{FakeFs, Fs};
 use futures::{
@@ -2096,6 +2097,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
         .unwrap();
     cx.run_until_parked();

+    fake_model.send_last_completion_stream_text_chunk("Hey,");
     fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
         provider: LanguageModelProviderName::new("Anthropic"),
         retry_after: Some(Duration::from_secs(3)),
@@ -2105,7 +2107,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
     cx.executor().advance_clock(Duration::from_secs(3));
     cx.run_until_parked();

-    fake_model.send_last_completion_stream_text_chunk("Hey!");
+    fake_model.send_last_completion_stream_text_chunk("there!");
     fake_model.end_last_completion_stream();
     cx.run_until_parked();

@@ -2135,18 +2137,24 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {

             ## Assistant

-            Hey!
+            Hey,
+
+            [resume]
+
+            ## Assistant
+
+            there!
             "}
         )
     });
 }

 #[gpui::test]
-async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) {
+async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
     let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();

-    thread
+    let events = thread
         .update(cx, |thread, cx| {
             thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
             thread.add_tool(EchoTool);
@@ -2162,58 +2170,16 @@ async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) {
         input: json!({"text": "test"}),
         is_input_complete: true,
     };
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use_1));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        tool_use_1.clone(),
+    ));
     fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
         provider: LanguageModelProviderName::new("Anthropic"),
         retry_after: Some(Duration::from_secs(3)),
     });
     fake_model.end_last_completion_stream();
-    cx.run_until_parked();
-
-    thread.read_with(cx, |thread, _cx| {
-        assert_eq!(
-            thread.to_markdown(),
-            indoc! {"
-                ## User
-
-                Call the echo tool!
-            "}
-        )
-    });
-
     cx.executor().advance_clock(Duration::from_secs(3));
-    cx.run_until_parked();
-    let completion = fake_model.pending_completions().pop().unwrap();
-    assert_eq!(
-        completion.messages[1..],
-        vec![LanguageModelRequestMessage {
-            role: Role::User,
-            content: vec!["Call the echo tool!".into()],
-            cache: true
-        }]
-    );
-
-    let tool_use_2 = LanguageModelToolUse {
-        id: "tool_2".into(),
-        name: EchoTool::name().into(),
-        raw_input: json!({"text": "test"}).to_string(),
-        input: json!({"text": "test"}),
-        is_input_complete: true,
-    };
-    let tool_result_2 = LanguageModelToolResult {
-        tool_use_id: "tool_2".into(),
-        tool_name: EchoTool::name().into(),
-        is_error: false,
-        content: "test".into(),
-        output: Some("test".into()),
-    };
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
-        tool_use_2.clone(),
-    ));
-    fake_model
-        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
-    fake_model.end_last_completion_stream();
-    cx.run_until_parked();

     let completion = fake_model.pending_completions().pop().unwrap();
     assert_eq!(
         completion.messages[1..],
@@ -2225,16 +2191,38 @@ async fn test_send_retry_cancels_tool_calls_on_error(cx: &mut TestAppContext) {
             },
             LanguageModelRequestMessage {
                 role: Role::Assistant,
-                content: vec![MessageContent::ToolUse(tool_use_2)],
+                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                 cache: false
             },
             LanguageModelRequestMessage {
                 role: Role::User,
-                content: vec![MessageContent::ToolResult(tool_result_2)],
+                content: vec![language_model::MessageContent::ToolResult(
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use_1.id.clone(),
+                        tool_name: tool_use_1.name.clone(),
+                        is_error: false,
+                        content: "test".into(),
+                        output: Some("test".into())
+                    }
+                )],
                 cache: true
-            }
+            },
         ]
     );
+
+    fake_model.send_last_completion_stream_text_chunk("Done");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+    events.collect::<Vec<_>>().await;
+    thread.read_with(cx, |thread, _cx| {
+        assert_eq!(
+            thread.last_message(),
+            Some(Message::Agent(AgentMessage {
+                content: vec![AgentMessageContent::Text("Done".into())],
+                tool_results: IndexMap::default()
+            }))
+        );
+    })
 }

 #[gpui::test]
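Note on the test changes above: the renamed test now expects the pending echo-tool call to be finished rather than cancelled, so the retried completion request replays the user prompt, the assistant's tool use, and a user message carrying the tool result. The standalone Rust sketch below mirrors only that expected message shape; the Content and Message types, the retry_request helper, and the literal values are illustrative stand-ins, not Zed's real language_model types.

// Illustrative stand-ins for the request/message types asserted in the test above.
#[derive(Debug, PartialEq)]
enum Content {
    Text(String),
    ToolUse { id: String, input: String },
    ToolResult { tool_use_id: String, output: String },
}

#[derive(Debug, PartialEq)]
struct Message {
    role: &'static str,
    content: Vec<Content>,
}

// Shape of the retried request: the user prompt, the assistant's pending tool call,
// and a user message carrying the finished tool result.
fn retry_request(tool_id: &str, tool_output: &str) -> Vec<Message> {
    vec![
        Message {
            role: "user",
            content: vec![Content::Text("Call the echo tool!".into())],
        },
        Message {
            role: "assistant",
            content: vec![Content::ToolUse {
                id: tool_id.into(),
                input: "test".into(),
            }],
        },
        Message {
            role: "user",
            content: vec![Content::ToolResult {
                tool_use_id: tool_id.into(),
                output: tool_output.into(),
            }],
        },
    ]
}

fn main() {
    // Nothing from the interrupted turn is dropped: the tool call and its result
    // travel together into the retried request.
    let messages = retry_request("tool_1", "test");
    assert_eq!(messages.len(), 3);
    println!("{messages:#?}");
}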
@@ -123,7 +123,7 @@ impl Message {
         match self {
             Message::User(message) => message.to_markdown(),
             Message::Agent(message) => message.to_markdown(),
-            Message::Resume => "[resumed after tool use limit was reached]".into(),
+            Message::Resume => "[resume]\n".into(),
         }
     }

@@ -1213,7 +1213,7 @@ impl Thread {
         log::debug!("Stream completion started successfully");

         let mut attempt = None;
-        'retry: loop {
+        loop {
             let request = this.update(cx, |this, cx| {
                 this.build_completion_request(completion_intent, cx)
             })??;
@@ -1236,6 +1236,7 @@ impl Thread {
                 .await
                 .map_err(|error| anyhow!(error))?;
             let mut tool_results = FuturesUnordered::new();
+            let mut error = None;

             while let Some(event) = events.next().await {
                 match event {
@@ -1245,52 +1246,9 @@ impl Thread {
                             this.handle_streamed_completion_event(event, event_stream, cx)
                         })??);
                     }
-                    Err(error) => {
-                        let completion_mode =
-                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
-                        if completion_mode == CompletionMode::Normal {
-                            return Err(anyhow!(error))?;
-                        }
-
-                        let Some(strategy) = Self::retry_strategy_for(&error) else {
-                            return Err(anyhow!(error))?;
-                        };
-
-                        let max_attempts = match &strategy {
-                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
-                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
-                        };
-
-                        let attempt = attempt.get_or_insert(0u8);
-
-                        *attempt += 1;
-
-                        let attempt = *attempt;
-                        if attempt > max_attempts {
-                            return Err(anyhow!(error))?;
-                        }
-
-                        let delay = match &strategy {
-                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
-                                let delay_secs =
-                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
-                                Duration::from_secs(delay_secs)
-                            }
-                            RetryStrategy::Fixed { delay, .. } => *delay,
-                        };
-                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
-
-                        event_stream.send_retry(acp_thread::RetryStatus {
-                            last_error: error.to_string().into(),
-                            attempt: attempt as usize,
-                            max_attempts: max_attempts as usize,
-                            started_at: Instant::now(),
-                            duration: delay,
-                        });
-                        this.update(cx, |this, _cx| this.pending_message.take())?;
-
-                        cx.background_executor().timer(delay).await;
-                        continue 'retry;
+                    Err(err) => {
+                        error = Some(err);
+                        break;
                     }
                 }
             }
@@ -1317,7 +1275,58 @@ impl Thread {
                 })?;
             }

-            return Ok(());
+            if let Some(error) = error {
+                let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
+                if completion_mode == CompletionMode::Normal {
+                    return Err(anyhow!(error))?;
+                }
+
+                let Some(strategy) = Self::retry_strategy_for(&error) else {
+                    return Err(anyhow!(error))?;
+                };
+
+                let max_attempts = match &strategy {
+                    RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+                    RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+                };
+
+                let attempt = attempt.get_or_insert(0u8);
+
+                *attempt += 1;
+
+                let attempt = *attempt;
+                if attempt > max_attempts {
+                    return Err(anyhow!(error))?;
+                }
+
+                let delay = match &strategy {
+                    RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+                        let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+                        Duration::from_secs(delay_secs)
+                    }
+                    RetryStrategy::Fixed { delay, .. } => *delay,
+                };
+                log::debug!("Retry attempt {attempt} with delay {delay:?}");
+
+                event_stream.send_retry(acp_thread::RetryStatus {
+                    last_error: error.to_string().into(),
+                    attempt: attempt as usize,
+                    max_attempts: max_attempts as usize,
+                    started_at: Instant::now(),
+                    duration: delay,
+                });
+                cx.background_executor().timer(delay).await;
+                this.update(cx, |this, cx| {
+                    this.flush_pending_message(cx);
+                    if let Some(Message::Agent(message)) = this.messages.last() {
+                        if message.tool_results.is_empty() {
+                            this.messages.push(Message::Resume);
+                        }
+                    }
+                })?;
+            } else {
+                return Ok(());
+            }
         }
     }
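For reference, the retry handling that moved into the `if let Some(error)` block computes its delay as initial_delay * 2^(attempt - 1) for the exponential strategy and uses a constant delay for the fixed one. Below is a standalone sketch of that schedule; the RetryStrategy enum only mirrors the two variants named in the diff, and the 3-second initial delay and 4 max attempts are assumed values for illustration, not Zed's defaults.

use std::time::Duration;

// Mirrors only the two strategy variants referenced in the diff; field values used
// in main() are assumptions for illustration.
#[allow(dead_code)] // `Fixed` is never constructed in this tiny demo.
enum RetryStrategy {
    ExponentialBackoff { initial_delay: Duration, max_attempts: u8 },
    Fixed { delay: Duration, max_attempts: u8 },
}

// Same arithmetic as the `let delay = match &strategy { .. }` block above.
fn delay_for(strategy: &RetryStrategy, attempt: u8) -> Duration {
    match strategy {
        RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
            Duration::from_secs(initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32))
        }
        RetryStrategy::Fixed { delay, .. } => *delay,
    }
}

fn main() {
    let strategy = RetryStrategy::ExponentialBackoff {
        initial_delay: Duration::from_secs(3), // assumed value
        max_attempts: 4,                       // assumed value
    };
    let max_attempts = match &strategy {
        RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
        RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
    };
    // With the assumed values this prints waits of 3s, 6s, 12s, and 24s.
    for attempt in 1..=max_attempts {
        println!("attempt {attempt}: wait {:?}", delay_for(&strategy, attempt));
    }
}

As the commit title says, the key behavioral change is that this wait now happens only after pending tool work has finished: the stream error is stashed, the tool-result futures drain first, and a Message::Resume is pushed when the last agent message produced no tool results.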