acp: Refactor agent2 send
to have a clearer control flow (#36689)
Release Notes: - N/A
This commit is contained in:
parent
132daef9f6
commit
190217a43b
3 changed files with 134 additions and 163 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -244,6 +244,7 @@ dependencies = [
|
||||||
"terminal",
|
"terminal",
|
||||||
"text",
|
"text",
|
||||||
"theme",
|
"theme",
|
||||||
|
"thiserror 2.0.12",
|
||||||
"tree-sitter-rust",
|
"tree-sitter-rust",
|
||||||
"ui",
|
"ui",
|
||||||
"unindent",
|
"unindent",
|
||||||
|
|
|
@ -61,6 +61,7 @@ sqlez.workspace = true
|
||||||
task.workspace = true
|
task.workspace = true
|
||||||
telemetry.workspace = true
|
telemetry.workspace = true
|
||||||
terminal.workspace = true
|
terminal.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
text.workspace = true
|
text.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
|
|
|
@ -499,6 +499,16 @@ pub struct ToolCallAuthorization {
|
||||||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
enum CompletionError {
|
||||||
|
#[error("max tokens")]
|
||||||
|
MaxTokens,
|
||||||
|
#[error("refusal")]
|
||||||
|
Refusal,
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Thread {
|
pub struct Thread {
|
||||||
id: acp::SessionId,
|
id: acp::SessionId,
|
||||||
prompt_id: PromptId,
|
prompt_id: PromptId,
|
||||||
|
@ -1077,101 +1087,62 @@ impl Thread {
|
||||||
_task: cx.spawn(async move |this, cx| {
|
_task: cx.spawn(async move |this, cx| {
|
||||||
log::info!("Starting agent turn execution");
|
log::info!("Starting agent turn execution");
|
||||||
let mut update_title = None;
|
let mut update_title = None;
|
||||||
let turn_result: Result<StopReason> = async {
|
let turn_result: Result<()> = async {
|
||||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
let mut intent = CompletionIntent::UserPrompt;
|
||||||
loop {
|
loop {
|
||||||
log::debug!(
|
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
|
||||||
"Building completion request with intent: {:?}",
|
|
||||||
completion_intent
|
|
||||||
);
|
|
||||||
let request = this.update(cx, |this, cx| {
|
|
||||||
this.build_completion_request(completion_intent, cx)
|
|
||||||
})??;
|
|
||||||
|
|
||||||
log::info!("Calling model.stream_completion");
|
|
||||||
|
|
||||||
let mut tool_use_limit_reached = false;
|
|
||||||
let mut refused = false;
|
|
||||||
let mut reached_max_tokens = false;
|
|
||||||
let mut tool_uses = Self::stream_completion_with_retries(
|
|
||||||
this.clone(),
|
|
||||||
model.clone(),
|
|
||||||
request,
|
|
||||||
&event_stream,
|
|
||||||
&mut tool_use_limit_reached,
|
|
||||||
&mut refused,
|
|
||||||
&mut reached_max_tokens,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if refused {
|
|
||||||
return Ok(StopReason::Refusal);
|
|
||||||
} else if reached_max_tokens {
|
|
||||||
return Ok(StopReason::MaxTokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
let end_turn = tool_uses.is_empty();
|
|
||||||
while let Some(tool_result) = tool_uses.next().await {
|
|
||||||
log::info!("Tool finished {:?}", tool_result);
|
|
||||||
|
|
||||||
event_stream.update_tool_call_fields(
|
|
||||||
&tool_result.tool_use_id,
|
|
||||||
acp::ToolCallUpdateFields {
|
|
||||||
status: Some(if tool_result.is_error {
|
|
||||||
acp::ToolCallStatus::Failed
|
|
||||||
} else {
|
|
||||||
acp::ToolCallStatus::Completed
|
|
||||||
}),
|
|
||||||
raw_output: tool_result.output.clone(),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
);
|
|
||||||
this.update(cx, |this, _cx| {
|
|
||||||
this.pending_message()
|
|
||||||
.tool_results
|
|
||||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
let mut end_turn = true;
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
|
// Generate title if needed.
|
||||||
if this.title.is_none() && update_title.is_none() {
|
if this.title.is_none() && update_title.is_none() {
|
||||||
update_title = Some(this.update_title(&event_stream, cx));
|
update_title = Some(this.update_title(&event_stream, cx));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// End the turn if the model didn't use tools.
|
||||||
|
let message = this.pending_message.as_ref();
|
||||||
|
end_turn =
|
||||||
|
message.map_or(true, |message| message.tool_results.is_empty());
|
||||||
|
this.flush_pending_message(cx);
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if tool_use_limit_reached {
|
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
|
||||||
log::info!("Tool use limit reached, completing turn");
|
log::info!("Tool use limit reached, completing turn");
|
||||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
|
||||||
return Err(language_model::ToolUseLimitReachedError.into());
|
return Err(language_model::ToolUseLimitReachedError.into());
|
||||||
} else if end_turn {
|
} else if end_turn {
|
||||||
log::info!("No tool uses found, completing turn");
|
log::info!("No tool uses found, completing turn");
|
||||||
return Ok(StopReason::EndTurn);
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
this.update(cx, |this, cx| this.flush_pending_message(cx))?;
|
intent = CompletionIntent::ToolResults;
|
||||||
completion_intent = CompletionIntent::ToolResults;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
|
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
|
||||||
|
|
||||||
|
if let Some(update_title) = update_title {
|
||||||
|
update_title.await.context("update title failed").log_err();
|
||||||
|
}
|
||||||
|
|
||||||
match turn_result {
|
match turn_result {
|
||||||
Ok(reason) => {
|
Ok(()) => {
|
||||||
log::info!("Turn execution completed: {:?}", reason);
|
log::info!("Turn execution completed");
|
||||||
|
event_stream.send_stop(acp::StopReason::EndTurn);
|
||||||
if let Some(update_title) = update_title {
|
|
||||||
update_title.await.context("update title failed").log_err();
|
|
||||||
}
|
|
||||||
|
|
||||||
event_stream.send_stop(reason);
|
|
||||||
if reason == StopReason::Refusal {
|
|
||||||
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
log::error!("Turn execution failed: {:?}", error);
|
log::error!("Turn execution failed: {:?}", error);
|
||||||
event_stream.send_error(error);
|
match error.downcast::<CompletionError>() {
|
||||||
|
Ok(CompletionError::Refusal) => {
|
||||||
|
event_stream.send_stop(acp::StopReason::Refusal);
|
||||||
|
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
|
||||||
|
}
|
||||||
|
Ok(CompletionError::MaxTokens) => {
|
||||||
|
event_stream.send_stop(acp::StopReason::MaxTokens);
|
||||||
|
}
|
||||||
|
Ok(CompletionError::Other(error)) | Err(error) => {
|
||||||
|
event_stream.send_error(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1181,17 +1152,17 @@ impl Thread {
|
||||||
Ok(events_rx)
|
Ok(events_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_completion_with_retries(
|
async fn stream_completion(
|
||||||
this: WeakEntity<Self>,
|
this: &WeakEntity<Self>,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: &Arc<dyn LanguageModel>,
|
||||||
request: LanguageModelRequest,
|
completion_intent: CompletionIntent,
|
||||||
event_stream: &ThreadEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
tool_use_limit_reached: &mut bool,
|
|
||||||
refusal: &mut bool,
|
|
||||||
max_tokens_reached: &mut bool,
|
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
|
) -> Result<()> {
|
||||||
log::debug!("Stream completion started successfully");
|
log::debug!("Stream completion started successfully");
|
||||||
|
let request = this.update(cx, |this, cx| {
|
||||||
|
this.build_completion_request(completion_intent, cx)
|
||||||
|
})??;
|
||||||
|
|
||||||
let mut attempt = None;
|
let mut attempt = None;
|
||||||
'retry: loop {
|
'retry: loop {
|
||||||
|
@ -1204,68 +1175,33 @@ impl Thread {
|
||||||
attempt
|
attempt
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut events = model.stream_completion(request.clone(), cx).await?;
|
log::info!(
|
||||||
let mut tool_uses = FuturesUnordered::new();
|
"Calling model.stream_completion, attempt {}",
|
||||||
|
attempt.unwrap_or(0)
|
||||||
|
);
|
||||||
|
let mut events = model
|
||||||
|
.stream_completion(request.clone(), cx)
|
||||||
|
.await
|
||||||
|
.map_err(|error| anyhow!(error))?;
|
||||||
|
let mut tool_results = FuturesUnordered::new();
|
||||||
|
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event {
|
match event {
|
||||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
|
||||||
CompletionRequestStatus::ToolUseLimitReached,
|
|
||||||
)) => {
|
|
||||||
*tool_use_limit_reached = true;
|
|
||||||
}
|
|
||||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
|
||||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
|
||||||
)) => {
|
|
||||||
this.update(cx, |this, cx| {
|
|
||||||
this.update_model_request_usage(amount, limit, cx)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
|
|
||||||
telemetry::event!(
|
|
||||||
"Agent Thread Completion Usage Updated",
|
|
||||||
thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
|
|
||||||
prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
|
|
||||||
model = model.telemetry_id(),
|
|
||||||
model_provider = model.provider_id().to_string(),
|
|
||||||
attempt,
|
|
||||||
input_tokens = usage.input_tokens,
|
|
||||||
output_tokens = usage.output_tokens,
|
|
||||||
cache_creation_input_tokens = usage.cache_creation_input_tokens,
|
|
||||||
cache_read_input_tokens = usage.cache_read_input_tokens,
|
|
||||||
);
|
|
||||||
|
|
||||||
this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
|
|
||||||
}
|
|
||||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
|
||||||
*refusal = true;
|
|
||||||
return Ok(FuturesUnordered::default());
|
|
||||||
}
|
|
||||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
|
|
||||||
*max_tokens_reached = true;
|
|
||||||
return Ok(FuturesUnordered::default());
|
|
||||||
}
|
|
||||||
Ok(LanguageModelCompletionEvent::Stop(
|
|
||||||
StopReason::ToolUse | StopReason::EndTurn,
|
|
||||||
)) => break,
|
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
log::trace!("Received completion event: {:?}", event);
|
log::trace!("Received completion event: {:?}", event);
|
||||||
this.update(cx, |this, cx| {
|
tool_results.extend(this.update(cx, |this, cx| {
|
||||||
tool_uses.extend(this.handle_streamed_completion_event(
|
this.handle_streamed_completion_event(event, event_stream, cx)
|
||||||
event,
|
})??);
|
||||||
event_stream,
|
|
||||||
cx,
|
|
||||||
));
|
|
||||||
})?;
|
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
let completion_mode =
|
let completion_mode =
|
||||||
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
||||||
if completion_mode == CompletionMode::Normal {
|
if completion_mode == CompletionMode::Normal {
|
||||||
return Err(error.into());
|
return Err(anyhow!(error))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||||
return Err(error.into());
|
return Err(anyhow!(error))?;
|
||||||
};
|
};
|
||||||
|
|
||||||
let max_attempts = match &strategy {
|
let max_attempts = match &strategy {
|
||||||
|
@ -1279,7 +1215,7 @@ impl Thread {
|
||||||
|
|
||||||
let attempt = *attempt;
|
let attempt = *attempt;
|
||||||
if attempt > max_attempts {
|
if attempt > max_attempts {
|
||||||
return Err(error.into());
|
return Err(anyhow!(error))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let delay = match &strategy {
|
let delay = match &strategy {
|
||||||
|
@ -1306,7 +1242,29 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Ok(tool_uses);
|
while let Some(tool_result) = tool_results.next().await {
|
||||||
|
log::info!("Tool finished {:?}", tool_result);
|
||||||
|
|
||||||
|
event_stream.update_tool_call_fields(
|
||||||
|
&tool_result.tool_use_id,
|
||||||
|
acp::ToolCallUpdateFields {
|
||||||
|
status: Some(if tool_result.is_error {
|
||||||
|
acp::ToolCallStatus::Failed
|
||||||
|
} else {
|
||||||
|
acp::ToolCallStatus::Completed
|
||||||
|
}),
|
||||||
|
raw_output: tool_result.output.clone(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
);
|
||||||
|
this.update(cx, |this, _cx| {
|
||||||
|
this.pending_message()
|
||||||
|
.tool_results
|
||||||
|
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1328,14 +1286,14 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A helper method that's called on every streamed completion event.
|
/// A helper method that's called on every streamed completion event.
|
||||||
/// Returns an optional tool result task, which the main agentic loop in
|
/// Returns an optional tool result task, which the main agentic loop will
|
||||||
/// send will send back to the model when it resolves.
|
/// send back to the model when it resolves.
|
||||||
fn handle_streamed_completion_event(
|
fn handle_streamed_completion_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
event: LanguageModelCompletionEvent,
|
event: LanguageModelCompletionEvent,
|
||||||
event_stream: &ThreadEventStream,
|
event_stream: &ThreadEventStream,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Option<Task<LanguageModelToolResult>> {
|
) -> Result<Option<Task<LanguageModelToolResult>>> {
|
||||||
log::trace!("Handling streamed completion event: {:?}", event);
|
log::trace!("Handling streamed completion event: {:?}", event);
|
||||||
use LanguageModelCompletionEvent::*;
|
use LanguageModelCompletionEvent::*;
|
||||||
|
|
||||||
|
@ -1350,7 +1308,7 @@ impl Thread {
|
||||||
}
|
}
|
||||||
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
|
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
|
||||||
ToolUse(tool_use) => {
|
ToolUse(tool_use) => {
|
||||||
return self.handle_tool_use_event(tool_use, event_stream, cx);
|
return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
|
||||||
}
|
}
|
||||||
ToolUseJsonParseError {
|
ToolUseJsonParseError {
|
||||||
id,
|
id,
|
||||||
|
@ -1358,18 +1316,46 @@ impl Thread {
|
||||||
raw_input,
|
raw_input,
|
||||||
json_parse_error,
|
json_parse_error,
|
||||||
} => {
|
} => {
|
||||||
return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
|
return Ok(Some(Task::ready(
|
||||||
id,
|
self.handle_tool_use_json_parse_error_event(
|
||||||
tool_name,
|
id,
|
||||||
raw_input,
|
tool_name,
|
||||||
json_parse_error,
|
raw_input,
|
||||||
|
json_parse_error,
|
||||||
|
),
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
StatusUpdate(_) => {}
|
UsageUpdate(usage) => {
|
||||||
UsageUpdate(_) | Stop(_) => unreachable!(),
|
telemetry::event!(
|
||||||
|
"Agent Thread Completion Usage Updated",
|
||||||
|
thread_id = self.id.to_string(),
|
||||||
|
prompt_id = self.prompt_id.to_string(),
|
||||||
|
model = self.model.as_ref().map(|m| m.telemetry_id()),
|
||||||
|
model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
|
||||||
|
input_tokens = usage.input_tokens,
|
||||||
|
output_tokens = usage.output_tokens,
|
||||||
|
cache_creation_input_tokens = usage.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens = usage.cache_read_input_tokens,
|
||||||
|
);
|
||||||
|
self.update_token_usage(usage, cx);
|
||||||
|
}
|
||||||
|
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
|
||||||
|
self.update_model_request_usage(amount, limit, cx);
|
||||||
|
}
|
||||||
|
StatusUpdate(
|
||||||
|
CompletionRequestStatus::Started
|
||||||
|
| CompletionRequestStatus::Queued { .. }
|
||||||
|
| CompletionRequestStatus::Failed { .. },
|
||||||
|
) => {}
|
||||||
|
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
|
||||||
|
self.tool_use_limit_reached = true;
|
||||||
|
}
|
||||||
|
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
|
||||||
|
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
|
||||||
|
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
None
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_text_event(
|
fn handle_text_event(
|
||||||
|
@ -2225,25 +2211,8 @@ impl ThreadEventStream {
|
||||||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_stop(&self, reason: StopReason) {
|
fn send_stop(&self, reason: acp::StopReason) {
|
||||||
match reason {
|
self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
|
||||||
StopReason::EndTurn => {
|
|
||||||
self.0
|
|
||||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
StopReason::MaxTokens => {
|
|
||||||
self.0
|
|
||||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
StopReason::Refusal => {
|
|
||||||
self.0
|
|
||||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
StopReason::ToolUse => {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_canceled(&self) {
|
fn send_canceled(&self) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue