acp: Simplify control flow for native agent loop (#36868)
Release Notes: - N/A Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent
db949546cf
commit
69127d2bea
1 changed files with 73 additions and 91 deletions
|
@ -1142,37 +1142,7 @@ impl Thread {
|
||||||
_task: cx.spawn(async move |this, cx| {
|
_task: cx.spawn(async move |this, cx| {
|
||||||
log::debug!("Starting agent turn execution");
|
log::debug!("Starting agent turn execution");
|
||||||
|
|
||||||
let turn_result: Result<()> = async {
|
let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await;
|
||||||
let mut intent = CompletionIntent::UserPrompt;
|
|
||||||
loop {
|
|
||||||
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
|
|
||||||
|
|
||||||
let mut end_turn = true;
|
|
||||||
this.update(cx, |this, cx| {
|
|
||||||
// Generate title if needed.
|
|
||||||
if this.title.is_none() && this.pending_title_generation.is_none() {
|
|
||||||
this.generate_title(cx);
|
|
||||||
}
|
|
||||||
|
|
||||||
// End the turn if the model didn't use tools.
|
|
||||||
let message = this.pending_message.as_ref();
|
|
||||||
end_turn =
|
|
||||||
message.map_or(true, |message| message.tool_results.is_empty());
|
|
||||||
this.flush_pending_message(cx);
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
|
|
||||||
log::info!("Tool use limit reached, completing turn");
|
|
||||||
return Err(language_model::ToolUseLimitReachedError.into());
|
|
||||||
} else if end_turn {
|
|
||||||
log::debug!("No tool uses found, completing turn");
|
|
||||||
return Ok(());
|
|
||||||
} else {
|
|
||||||
intent = CompletionIntent::ToolResults;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.await;
|
|
||||||
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
|
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
|
||||||
|
|
||||||
match turn_result {
|
match turn_result {
|
||||||
|
@ -1203,20 +1173,17 @@ impl Thread {
|
||||||
Ok(events_rx)
|
Ok(events_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_completion(
|
async fn run_turn_internal(
|
||||||
this: &WeakEntity<Self>,
|
this: &WeakEntity<Self>,
|
||||||
model: &Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
completion_intent: CompletionIntent,
|
|
||||||
event_stream: &ThreadEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
log::debug!("Stream completion started successfully");
|
let mut attempt = 0;
|
||||||
|
let mut intent = CompletionIntent::UserPrompt;
|
||||||
let mut attempt = None;
|
|
||||||
loop {
|
loop {
|
||||||
let request = this.update(cx, |this, cx| {
|
let request =
|
||||||
this.build_completion_request(completion_intent, cx)
|
this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;
|
||||||
})??;
|
|
||||||
|
|
||||||
telemetry::event!(
|
telemetry::event!(
|
||||||
"Agent Thread Completion",
|
"Agent Thread Completion",
|
||||||
|
@ -1227,23 +1194,19 @@ impl Thread {
|
||||||
attempt
|
attempt
|
||||||
);
|
);
|
||||||
|
|
||||||
log::debug!(
|
log::debug!("Calling model.stream_completion, attempt {}", attempt);
|
||||||
"Calling model.stream_completion, attempt {}",
|
|
||||||
attempt.unwrap_or(0)
|
|
||||||
);
|
|
||||||
let mut events = model
|
let mut events = model
|
||||||
.stream_completion(request, cx)
|
.stream_completion(request, cx)
|
||||||
.await
|
.await
|
||||||
.map_err(|error| anyhow!(error))?;
|
.map_err(|error| anyhow!(error))?;
|
||||||
let mut tool_results = FuturesUnordered::new();
|
let mut tool_results = FuturesUnordered::new();
|
||||||
let mut error = None;
|
let mut error = None;
|
||||||
|
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
|
log::trace!("Received completion event: {:?}", event);
|
||||||
match event {
|
match event {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
log::trace!("Received completion event: {:?}", event);
|
|
||||||
tool_results.extend(this.update(cx, |this, cx| {
|
tool_results.extend(this.update(cx, |this, cx| {
|
||||||
this.handle_streamed_completion_event(event, event_stream, cx)
|
this.handle_completion_event(event, event_stream, cx)
|
||||||
})??);
|
})??);
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
@ -1253,6 +1216,7 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let end_turn = tool_results.is_empty();
|
||||||
while let Some(tool_result) = tool_results.next().await {
|
while let Some(tool_result) = tool_results.next().await {
|
||||||
log::debug!("Tool finished {:?}", tool_result);
|
log::debug!("Tool finished {:?}", tool_result);
|
||||||
|
|
||||||
|
@ -1275,65 +1239,83 @@ impl Thread {
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.flush_pending_message(cx);
|
||||||
|
if this.title.is_none() && this.pending_title_generation.is_none() {
|
||||||
|
this.generate_title(cx);
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
if let Some(error) = error {
|
if let Some(error) = error {
|
||||||
let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
attempt += 1;
|
||||||
if completion_mode == CompletionMode::Normal {
|
let retry =
|
||||||
return Err(anyhow!(error))?;
|
this.update(cx, |this, _| this.handle_completion_error(error, attempt))??;
|
||||||
}
|
let timer = cx.background_executor().timer(retry.duration);
|
||||||
|
event_stream.send_retry(retry);
|
||||||
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
timer.await;
|
||||||
return Err(anyhow!(error))?;
|
this.update(cx, |this, _cx| {
|
||||||
};
|
|
||||||
|
|
||||||
let max_attempts = match &strategy {
|
|
||||||
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
|
|
||||||
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
|
|
||||||
};
|
|
||||||
|
|
||||||
let attempt = attempt.get_or_insert(0u8);
|
|
||||||
|
|
||||||
*attempt += 1;
|
|
||||||
|
|
||||||
let attempt = *attempt;
|
|
||||||
if attempt > max_attempts {
|
|
||||||
return Err(anyhow!(error))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let delay = match &strategy {
|
|
||||||
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
|
|
||||||
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
|
|
||||||
Duration::from_secs(delay_secs)
|
|
||||||
}
|
|
||||||
RetryStrategy::Fixed { delay, .. } => *delay,
|
|
||||||
};
|
|
||||||
log::debug!("Retry attempt {attempt} with delay {delay:?}");
|
|
||||||
|
|
||||||
event_stream.send_retry(acp_thread::RetryStatus {
|
|
||||||
last_error: error.to_string().into(),
|
|
||||||
attempt: attempt as usize,
|
|
||||||
max_attempts: max_attempts as usize,
|
|
||||||
started_at: Instant::now(),
|
|
||||||
duration: delay,
|
|
||||||
});
|
|
||||||
cx.background_executor().timer(delay).await;
|
|
||||||
this.update(cx, |this, cx| {
|
|
||||||
this.flush_pending_message(cx);
|
|
||||||
if let Some(Message::Agent(message)) = this.messages.last() {
|
if let Some(Message::Agent(message)) = this.messages.last() {
|
||||||
if message.tool_results.is_empty() {
|
if message.tool_results.is_empty() {
|
||||||
|
intent = CompletionIntent::UserPrompt;
|
||||||
this.messages.push(Message::Resume);
|
this.messages.push(Message::Resume);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})?;
|
})?;
|
||||||
} else {
|
} else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
|
||||||
|
return Err(language_model::ToolUseLimitReachedError.into());
|
||||||
|
} else if end_turn {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
} else {
|
||||||
|
intent = CompletionIntent::ToolResults;
|
||||||
|
attempt = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_completion_error(
|
||||||
|
&mut self,
|
||||||
|
error: LanguageModelCompletionError,
|
||||||
|
attempt: u8,
|
||||||
|
) -> Result<acp_thread::RetryStatus> {
|
||||||
|
if self.completion_mode == CompletionMode::Normal {
|
||||||
|
return Err(anyhow!(error));
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||||
|
return Err(anyhow!(error));
|
||||||
|
};
|
||||||
|
|
||||||
|
let max_attempts = match &strategy {
|
||||||
|
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
|
||||||
|
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
|
||||||
|
};
|
||||||
|
|
||||||
|
if attempt > max_attempts {
|
||||||
|
return Err(anyhow!(error));
|
||||||
|
}
|
||||||
|
|
||||||
|
let delay = match &strategy {
|
||||||
|
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
|
||||||
|
let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
|
||||||
|
Duration::from_secs(delay_secs)
|
||||||
|
}
|
||||||
|
RetryStrategy::Fixed { delay, .. } => *delay,
|
||||||
|
};
|
||||||
|
log::debug!("Retry attempt {attempt} with delay {delay:?}");
|
||||||
|
|
||||||
|
Ok(acp_thread::RetryStatus {
|
||||||
|
last_error: error.to_string().into(),
|
||||||
|
attempt: attempt as usize,
|
||||||
|
max_attempts: max_attempts as usize,
|
||||||
|
started_at: Instant::now(),
|
||||||
|
duration: delay,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// A helper method that's called on every streamed completion event.
|
/// A helper method that's called on every streamed completion event.
|
||||||
/// Returns an optional tool result task, which the main agentic loop will
|
/// Returns an optional tool result task, which the main agentic loop will
|
||||||
/// send back to the model when it resolves.
|
/// send back to the model when it resolves.
|
||||||
fn handle_streamed_completion_event(
|
fn handle_completion_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
event: LanguageModelCompletionEvent,
|
event: LanguageModelCompletionEvent,
|
||||||
event_stream: &ThreadEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue