acp: Refactor agent2 `send` to have a clearer control flow (#36689)
Release Notes:
- N/A
This commit is contained in:
parent
132daef9f6
commit
190217a43b
3 changed files with 134 additions and 163 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -244,6 +244,7 @@ dependencies = [
|
|||
"terminal",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"tree-sitter-rust",
|
||||
"ui",
|
||||
"unindent",
|
||||
|
|
|
@ -61,6 +61,7 @@ sqlez.workspace = true
|
|||
task.workspace = true
|
||||
telemetry.workspace = true
|
||||
terminal.workspace = true
|
||||
thiserror.workspace = true
|
||||
text.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
|
|
|
@ -499,6 +499,16 @@ pub struct ToolCallAuthorization {
|
|||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
/// Ways a streamed completion can end a turn other than normally.
///
/// `handle_streamed_completion_event` returns `Refusal` / `MaxTokens` when it
/// sees the matching `Stop` event, and the turn task downcasts the resulting
/// error back to this type to decide between sending a stop reason and
/// sending an error on the event stream.
#[derive(Debug, thiserror::Error)]
|
||||
enum CompletionError {
|
||||
/// The model stopped with `StopReason::MaxTokens`; the turn task reports
/// this as `acp::StopReason::MaxTokens`.
#[error("max tokens")]
|
||||
MaxTokens,
|
||||
/// The model stopped with `StopReason::Refusal`; the turn task reports
/// this as `acp::StopReason::Refusal` and truncates the thread's messages
/// to `message_ix`.
#[error("refusal")]
|
||||
Refusal,
|
||||
/// Any other failure, wrapped transparently; the turn task forwards it
/// with `send_error`.
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
id: acp::SessionId,
|
||||
prompt_id: PromptId,
|
||||
|
@ -1077,101 +1087,62 @@ impl Thread {
|
|||
_task: cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let mut update_title = None;
|
||||
let turn_result: Result<StopReason> = async {
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
let turn_result: Result<()> = async {
|
||||
let mut intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})??;
|
||||
|
||||
log::info!("Calling model.stream_completion");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut refused = false;
|
||||
let mut reached_max_tokens = false;
|
||||
let mut tool_uses = Self::stream_completion_with_retries(
|
||||
this.clone(),
|
||||
model.clone(),
|
||||
request,
|
||||
&event_stream,
|
||||
&mut tool_use_limit_reached,
|
||||
&mut refused,
|
||||
&mut reached_max_tokens,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if refused {
|
||||
return Ok(StopReason::Refusal);
|
||||
} else if reached_max_tokens {
|
||||
return Ok(StopReason::MaxTokens);
|
||||
}
|
||||
|
||||
let end_turn = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})?;
|
||||
}
|
||||
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
|
||||
|
||||
let mut end_turn = true;
|
||||
this.update(cx, |this, cx| {
|
||||
// Generate title if needed.
|
||||
if this.title.is_none() && update_title.is_none() {
|
||||
update_title = Some(this.update_title(&event_stream, cx));
|
||||
}
|
||||
|
||||
// End the turn if the model didn't use tools.
|
||||
let message = this.pending_message.as_ref();
|
||||
end_turn =
|
||||
message.map_or(true, |message| message.tool_results.is_empty());
|
||||
this.flush_pending_message(cx);
|
||||
})?;
|
||||
|
||||
if tool_use_limit_reached {
|
||||
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
|
||||
log::info!("Tool use limit reached, completing turn");
|
||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||
return Err(language_model::ToolUseLimitReachedError.into());
|
||||
} else if end_turn {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(StopReason::EndTurn);
|
||||
return Ok(());
|
||||
} else {
|
||||
this.update(cx, |this, cx| this.flush_pending_message(cx))?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
|
||||
|
||||
if let Some(update_title) = update_title {
|
||||
update_title.await.context("update title failed").log_err();
|
||||
}
|
||||
|
||||
match turn_result {
|
||||
Ok(reason) => {
|
||||
log::info!("Turn execution completed: {:?}", reason);
|
||||
|
||||
if let Some(update_title) = update_title {
|
||||
update_title.await.context("update title failed").log_err();
|
||||
}
|
||||
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
|
||||
}
|
||||
Ok(()) => {
|
||||
log::info!("Turn execution completed");
|
||||
event_stream.send_stop(acp::StopReason::EndTurn);
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
event_stream.send_error(error);
|
||||
match error.downcast::<CompletionError>() {
|
||||
Ok(CompletionError::Refusal) => {
|
||||
event_stream.send_stop(acp::StopReason::Refusal);
|
||||
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
|
||||
}
|
||||
Ok(CompletionError::MaxTokens) => {
|
||||
event_stream.send_stop(acp::StopReason::MaxTokens);
|
||||
}
|
||||
Ok(CompletionError::Other(error)) | Err(error) => {
|
||||
event_stream.send_error(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1181,17 +1152,17 @@ impl Thread {
|
|||
Ok(events_rx)
|
||||
}
|
||||
|
||||
async fn stream_completion_with_retries(
|
||||
this: WeakEntity<Self>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
request: LanguageModelRequest,
|
||||
async fn stream_completion(
|
||||
this: &WeakEntity<Self>,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
completion_intent: CompletionIntent,
|
||||
event_stream: &ThreadEventStream,
|
||||
tool_use_limit_reached: &mut bool,
|
||||
refusal: &mut bool,
|
||||
max_tokens_reached: &mut bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
|
||||
) -> Result<()> {
|
||||
log::debug!("Stream completion started successfully");
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})??;
|
||||
|
||||
let mut attempt = None;
|
||||
'retry: loop {
|
||||
|
@ -1204,68 +1175,33 @@ impl Thread {
|
|||
attempt
|
||||
);
|
||||
|
||||
let mut events = model.stream_completion(request.clone(), cx).await?;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
log::info!(
|
||||
"Calling model.stream_completion, attempt {}",
|
||||
attempt.unwrap_or(0)
|
||||
);
|
||||
let mut events = model
|
||||
.stream_completion(request.clone(), cx)
|
||||
.await
|
||||
.map_err(|error| anyhow!(error))?;
|
||||
let mut tool_results = FuturesUnordered::new();
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
)) => {
|
||||
*tool_use_limit_reached = true;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
||||
)) => {
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_model_request_usage(amount, limit, cx)
|
||||
})?;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
|
||||
telemetry::event!(
|
||||
"Agent Thread Completion Usage Updated",
|
||||
thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
|
||||
prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
attempt,
|
||||
input_tokens = usage.input_tokens,
|
||||
output_tokens = usage.output_tokens,
|
||||
cache_creation_input_tokens = usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens = usage.cache_read_input_tokens,
|
||||
);
|
||||
|
||||
this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
||||
*refusal = true;
|
||||
return Ok(FuturesUnordered::default());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
|
||||
*max_tokens_reached = true;
|
||||
return Ok(FuturesUnordered::default());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(
|
||||
StopReason::ToolUse | StopReason::EndTurn,
|
||||
)) => break,
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
event_stream,
|
||||
cx,
|
||||
));
|
||||
})?;
|
||||
tool_results.extend(this.update(cx, |this, cx| {
|
||||
this.handle_streamed_completion_event(event, event_stream, cx)
|
||||
})??);
|
||||
}
|
||||
Err(error) => {
|
||||
let completion_mode =
|
||||
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
||||
if completion_mode == CompletionMode::Normal {
|
||||
return Err(error.into());
|
||||
return Err(anyhow!(error))?;
|
||||
}
|
||||
|
||||
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||
return Err(error.into());
|
||||
return Err(anyhow!(error))?;
|
||||
};
|
||||
|
||||
let max_attempts = match &strategy {
|
||||
|
@ -1279,7 +1215,7 @@ impl Thread {
|
|||
|
||||
let attempt = *attempt;
|
||||
if attempt > max_attempts {
|
||||
return Err(error.into());
|
||||
return Err(anyhow!(error))?;
|
||||
}
|
||||
|
||||
let delay = match &strategy {
|
||||
|
@ -1306,7 +1242,29 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
return Ok(tool_uses);
|
||||
while let Some(tool_result) = tool_results.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1328,14 +1286,14 @@ impl Thread {
|
|||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
/// Returns an optional tool result task, which the main agentic loop will
|
||||
/// send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
event_stream: &ThreadEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
) -> Result<Option<Task<LanguageModelToolResult>>> {
|
||||
log::trace!("Handling streamed completion event: {:?}", event);
|
||||
use LanguageModelCompletionEvent::*;
|
||||
|
||||
|
@ -1350,7 +1308,7 @@ impl Thread {
|
|||
}
|
||||
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, event_stream, cx);
|
||||
return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
|
||||
}
|
||||
ToolUseJsonParseError {
|
||||
id,
|
||||
|
@ -1358,18 +1316,46 @@ impl Thread {
|
|||
raw_input,
|
||||
json_parse_error,
|
||||
} => {
|
||||
return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
return Ok(Some(Task::ready(
|
||||
self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
),
|
||||
)));
|
||||
}
|
||||
StatusUpdate(_) => {}
|
||||
UsageUpdate(_) | Stop(_) => unreachable!(),
|
||||
UsageUpdate(usage) => {
|
||||
telemetry::event!(
|
||||
"Agent Thread Completion Usage Updated",
|
||||
thread_id = self.id.to_string(),
|
||||
prompt_id = self.prompt_id.to_string(),
|
||||
model = self.model.as_ref().map(|m| m.telemetry_id()),
|
||||
model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
|
||||
input_tokens = usage.input_tokens,
|
||||
output_tokens = usage.output_tokens,
|
||||
cache_creation_input_tokens = usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens = usage.cache_read_input_tokens,
|
||||
);
|
||||
self.update_token_usage(usage, cx);
|
||||
}
|
||||
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
|
||||
self.update_model_request_usage(amount, limit, cx);
|
||||
}
|
||||
StatusUpdate(
|
||||
CompletionRequestStatus::Started
|
||||
| CompletionRequestStatus::Queued { .. }
|
||||
| CompletionRequestStatus::Failed { .. },
|
||||
) => {}
|
||||
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
|
||||
self.tool_use_limit_reached = true;
|
||||
}
|
||||
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
|
||||
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
|
||||
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
|
||||
}
|
||||
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn handle_text_event(
|
||||
|
@ -2225,25 +2211,8 @@ impl ThreadEventStream {
|
|||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::MaxTokens => {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::Refusal => {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::ToolUse => {}
|
||||
}
|
||||
/// Sends `ThreadEvent::Stop(reason)` on the event channel.
///
/// The result of the send is discarded with `.ok()`, so a failed send is
/// silently ignored.
fn send_stop(&self, reason: acp::StopReason) {
|
||||
self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
|
||||
}
|
||||
|
||||
fn send_canceled(&self) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue