parent
47e1d4511c
commit
0ea0d466d2
7 changed files with 514 additions and 52 deletions
|
@ -24,6 +24,7 @@ use std::fmt::{Formatter, Write};
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::process::ExitStatus;
|
use std::process::ExitStatus;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||||
use ui::App;
|
use ui::App;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
@ -658,6 +659,15 @@ impl PlanEntry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetryStatus {
|
||||||
|
pub last_error: SharedString,
|
||||||
|
pub attempt: usize,
|
||||||
|
pub max_attempts: usize,
|
||||||
|
pub started_at: Instant,
|
||||||
|
pub duration: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AcpThread {
|
pub struct AcpThread {
|
||||||
title: SharedString,
|
title: SharedString,
|
||||||
entries: Vec<AgentThreadEntry>,
|
entries: Vec<AgentThreadEntry>,
|
||||||
|
@ -676,6 +686,7 @@ pub enum AcpThreadEvent {
|
||||||
EntryUpdated(usize),
|
EntryUpdated(usize),
|
||||||
EntriesRemoved(Range<usize>),
|
EntriesRemoved(Range<usize>),
|
||||||
ToolAuthorizationRequired,
|
ToolAuthorizationRequired,
|
||||||
|
Retry(RetryStatus),
|
||||||
Stopped,
|
Stopped,
|
||||||
Error,
|
Error,
|
||||||
ServerExited(ExitStatus),
|
ServerExited(ExitStatus),
|
||||||
|
@ -916,6 +927,10 @@ impl AcpThread {
|
||||||
cx.emit(AcpThreadEvent::NewEntry);
|
cx.emit(AcpThreadEvent::NewEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Broadcasts `status` to this thread's subscribers as an
/// [`AcpThreadEvent::Retry`] event. No thread state is modified.
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
    let retry_event = AcpThreadEvent::Retry(status);
    cx.emit(retry_event);
}
|
||||||
|
|
||||||
pub fn update_tool_call(
|
pub fn update_tool_call(
|
||||||
&mut self,
|
&mut self,
|
||||||
update: impl Into<ToolCallUpdate>,
|
update: impl Into<ToolCallUpdate>,
|
||||||
|
|
|
@ -546,6 +546,11 @@ impl NativeAgentConnection {
|
||||||
thread.update_tool_call(update, cx)
|
thread.update_tool_call(update, cx)
|
||||||
})??;
|
})??;
|
||||||
}
|
}
|
||||||
|
AgentResponseEvent::Retry(status) => {
|
||||||
|
acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.update_retry_status(status, cx)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
AgentResponseEvent::Stop(stop_reason) => {
|
AgentResponseEvent::Stop(stop_reason) => {
|
||||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||||
return Ok(acp::PromptResponse { stop_reason });
|
return Ok(acp::PromptResponse { stop_reason });
|
||||||
|
|
|
@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
use fs::{FakeFs, Fs};
|
use fs::{FakeFs, Fs};
|
||||||
use futures::channel::mpsc::UnboundedReceiver;
|
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||||
};
|
};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
|
||||||
Role, StopReason, fake_provider::FakeLanguageModel,
|
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||||
|
fake_provider::FakeLanguageModel,
|
||||||
};
|
};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -24,7 +25,6 @@ use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt;
|
|
||||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
|
@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||||
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
let mut retry_events = Vec::new();
|
||||||
|
while let Some(Ok(event)) = events.next().await {
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::Retry(retry_status) => {
|
||||||
|
retry_events.push(retry_status);
|
||||||
|
}
|
||||||
|
AgentResponseEvent::Stop(..) => break,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(retry_events.len(), 0);
|
||||||
|
thread.read_with(cx, |thread, _cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Hello!
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Hey!
|
||||||
|
"}
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||||
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||||
|
provider: LanguageModelProviderName::new("Anthropic"),
|
||||||
|
retry_after: Some(Duration::from_secs(3)),
|
||||||
|
});
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
cx.executor().advance_clock(Duration::from_secs(3));
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
let mut retry_events = Vec::new();
|
||||||
|
while let Some(Ok(event)) = events.next().await {
|
||||||
|
match event {
|
||||||
|
AgentResponseEvent::Retry(retry_status) => {
|
||||||
|
retry_events.push(retry_status);
|
||||||
|
}
|
||||||
|
AgentResponseEvent::Stop(..) => break,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(retry_events.len(), 1);
|
||||||
|
assert!(matches!(
|
||||||
|
retry_events[0],
|
||||||
|
acp_thread::RetryStatus { attempt: 1, .. }
|
||||||
|
));
|
||||||
|
thread.read_with(cx, |thread, _cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Hello!
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Hey!
|
||||||
|
"}
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let mut events = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||||
|
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
||||||
|
fake_model.send_last_completion_stream_error(
|
||||||
|
LanguageModelCompletionError::ServerOverloaded {
|
||||||
|
provider: LanguageModelProviderName::new("Anthropic"),
|
||||||
|
retry_after: Some(Duration::from_secs(3)),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.executor().advance_clock(Duration::from_secs(3));
|
||||||
|
cx.run_until_parked();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
let mut retry_events = Vec::new();
|
||||||
|
while let Some(event) = events.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(AgentResponseEvent::Retry(retry_status)) => {
|
||||||
|
retry_events.push(retry_status);
|
||||||
|
}
|
||||||
|
Ok(AgentResponseEvent::Stop(..)) => break,
|
||||||
|
Err(error) => errors.push(error),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
retry_events.len(),
|
||||||
|
crate::thread::MAX_RETRY_ATTEMPTS as usize
|
||||||
|
);
|
||||||
|
for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
|
||||||
|
assert_eq!(retry_events[i].attempt, i + 1);
|
||||||
|
}
|
||||||
|
assert_eq!(errors.len(), 1);
|
||||||
|
let error = errors[0]
|
||||||
|
.downcast_ref::<LanguageModelCompletionError>()
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(
|
||||||
|
error,
|
||||||
|
LanguageModelCompletionError::ServerOverloaded { .. }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
/// Filters out the stop events for asserting against in tests
|
/// Filters out the stop events for asserting against in tests
|
||||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
||||||
result_events
|
result_events
|
||||||
|
|
|
@ -12,12 +12,12 @@ use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
stream::FuturesUnordered,
|
stream::FuturesUnordered,
|
||||||
};
|
};
|
||||||
use gpui::{App, Context, Entity, SharedString, Task};
|
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
|
@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, update_settings_file};
|
use settings::{Settings, update_settings_file};
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::{collections::BTreeMap, path::Path, sync::Arc};
|
use std::{
|
||||||
|
collections::BTreeMap,
|
||||||
|
path::Path,
|
||||||
|
sync::Arc,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
use std::{fmt::Write, ops::Range};
|
use std::{fmt::Write, ops::Range};
|
||||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
|
||||||
|
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
/// How a failed completion request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Wait `initial_delay` before the first retry and double the wait on
    /// every subsequent retry.
    ExponentialBackoff {
        /// Wait before the first retry.
        initial_delay: Duration,
        /// Number of retries permitted before the error is surfaced.
        max_attempts: u8,
    },
    /// Wait the same `delay` before every retry.
    Fixed {
        /// Wait before each retry.
        delay: Duration,
        /// Number of retries permitted before the error is surfaced.
        max_attempts: u8,
    },
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum Message {
|
pub enum Message {
|
||||||
User(UserMessage),
|
User(UserMessage),
|
||||||
|
@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
|
||||||
ToolCall(acp::ToolCall),
|
ToolCall(acp::ToolCall),
|
||||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||||
ToolCallAuthorization(ToolCallAuthorization),
|
ToolCallAuthorization(ToolCallAuthorization),
|
||||||
|
Retry(acp_thread::RetryStatus),
|
||||||
Stop(acp::StopReason),
|
Stop(acp::StopReason),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -662,41 +683,18 @@ impl Thread {
|
||||||
})??;
|
})??;
|
||||||
|
|
||||||
log::info!("Calling model.stream_completion");
|
log::info!("Calling model.stream_completion");
|
||||||
let mut events = model.stream_completion(request, cx).await?;
|
|
||||||
log::debug!("Stream completion started successfully");
|
|
||||||
|
|
||||||
let mut tool_use_limit_reached = false;
|
let mut tool_use_limit_reached = false;
|
||||||
let mut tool_uses = FuturesUnordered::new();
|
let mut tool_uses = Self::stream_completion_with_retries(
|
||||||
while let Some(event) = events.next().await {
|
this.clone(),
|
||||||
match event? {
|
model.clone(),
|
||||||
LanguageModelCompletionEvent::StatusUpdate(
|
request,
|
||||||
CompletionRequestStatus::ToolUseLimitReached,
|
message_ix,
|
||||||
) => {
|
&event_stream,
|
||||||
tool_use_limit_reached = true;
|
&mut tool_use_limit_reached,
|
||||||
}
|
cx,
|
||||||
LanguageModelCompletionEvent::Stop(reason) => {
|
)
|
||||||
event_stream.send_stop(reason);
|
.await?;
|
||||||
if reason == StopReason::Refusal {
|
|
||||||
this.update(cx, |this, _cx| {
|
|
||||||
this.flush_pending_message();
|
|
||||||
this.messages.truncate(message_ix);
|
|
||||||
})?;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
event => {
|
|
||||||
log::trace!("Received completion event: {:?}", event);
|
|
||||||
this.update(cx, |this, cx| {
|
|
||||||
tool_uses.extend(this.handle_streamed_completion_event(
|
|
||||||
event,
|
|
||||||
&event_stream,
|
|
||||||
cx,
|
|
||||||
));
|
|
||||||
})
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let used_tools = tool_uses.is_empty();
|
let used_tools = tool_uses.is_empty();
|
||||||
while let Some(tool_result) = tool_uses.next().await {
|
while let Some(tool_result) = tool_uses.next().await {
|
||||||
|
@ -754,10 +752,105 @@ impl Thread {
|
||||||
Ok(events_rx)
|
Ok(events_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn stream_completion_with_retries(
|
||||||
|
this: WeakEntity<Self>,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
message_ix: usize,
|
||||||
|
event_stream: &AgentResponseEventStream,
|
||||||
|
tool_use_limit_reached: &mut bool,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
|
||||||
|
log::debug!("Stream completion started successfully");
|
||||||
|
|
||||||
|
let mut attempt = None;
|
||||||
|
'retry: loop {
|
||||||
|
let mut events = model.stream_completion(request.clone(), cx).await?;
|
||||||
|
let mut tool_uses = FuturesUnordered::new();
|
||||||
|
while let Some(event) = events.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||||
|
CompletionRequestStatus::ToolUseLimitReached,
|
||||||
|
)) => {
|
||||||
|
*tool_use_limit_reached = true;
|
||||||
|
}
|
||||||
|
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||||
|
event_stream.send_stop(reason);
|
||||||
|
if reason == StopReason::Refusal {
|
||||||
|
this.update(cx, |this, _cx| {
|
||||||
|
this.flush_pending_message();
|
||||||
|
this.messages.truncate(message_ix);
|
||||||
|
})?;
|
||||||
|
return Ok(tool_uses);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(event) => {
|
||||||
|
log::trace!("Received completion event: {:?}", event);
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
tool_uses.extend(this.handle_streamed_completion_event(
|
||||||
|
event,
|
||||||
|
event_stream,
|
||||||
|
cx,
|
||||||
|
));
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
let completion_mode =
|
||||||
|
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
||||||
|
if completion_mode == CompletionMode::Normal {
|
||||||
|
return Err(error.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||||
|
return Err(error.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
let max_attempts = match &strategy {
|
||||||
|
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
|
||||||
|
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
|
||||||
|
};
|
||||||
|
|
||||||
|
let attempt = attempt.get_or_insert(0u8);
|
||||||
|
|
||||||
|
*attempt += 1;
|
||||||
|
|
||||||
|
let attempt = *attempt;
|
||||||
|
if attempt > max_attempts {
|
||||||
|
return Err(error.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let delay = match &strategy {
|
||||||
|
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
|
||||||
|
let delay_secs =
|
||||||
|
initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
|
||||||
|
Duration::from_secs(delay_secs)
|
||||||
|
}
|
||||||
|
RetryStrategy::Fixed { delay, .. } => *delay,
|
||||||
|
};
|
||||||
|
log::debug!("Retry attempt {attempt} with delay {delay:?}");
|
||||||
|
|
||||||
|
event_stream.send_retry(acp_thread::RetryStatus {
|
||||||
|
last_error: error.to_string().into(),
|
||||||
|
attempt: attempt as usize,
|
||||||
|
max_attempts: max_attempts as usize,
|
||||||
|
started_at: Instant::now(),
|
||||||
|
duration: delay,
|
||||||
|
});
|
||||||
|
|
||||||
|
cx.background_executor().timer(delay).await;
|
||||||
|
continue 'retry;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Ok(tool_uses);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
|
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
|
||||||
log::debug!("Building system message");
|
log::debug!("Building system message");
|
||||||
let prompt = SystemPromptTemplate {
|
let prompt = SystemPromptTemplate {
|
||||||
project: &self.project_context.read(cx),
|
project: self.project_context.read(cx),
|
||||||
available_tools: self.tools.keys().cloned().collect(),
|
available_tools: self.tools.keys().cloned().collect(),
|
||||||
}
|
}
|
||||||
.render(&self.templates)
|
.render(&self.templates)
|
||||||
|
@ -1158,6 +1251,113 @@ impl Thread {
|
||||||
fn advance_prompt_id(&mut self) {
|
fn advance_prompt_id(&mut self) {
|
||||||
self.prompt_id = PromptId::new();
|
self.prompt_id = PromptId::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
|
||||||
|
use LanguageModelCompletionError::*;
|
||||||
|
use http_client::StatusCode;
|
||||||
|
|
||||||
|
// General strategy here:
|
||||||
|
// - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
|
||||||
|
// - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
|
||||||
|
// - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
|
||||||
|
match error {
|
||||||
|
HttpResponseError {
|
||||||
|
status_code: StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
..
|
||||||
|
} => Some(RetryStrategy::ExponentialBackoff {
|
||||||
|
initial_delay: BASE_RETRY_DELAY,
|
||||||
|
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||||
|
}),
|
||||||
|
ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
|
||||||
|
Some(RetryStrategy::Fixed {
|
||||||
|
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||||
|
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
UpstreamProviderError {
|
||||||
|
status,
|
||||||
|
retry_after,
|
||||||
|
..
|
||||||
|
} => match *status {
|
||||||
|
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
|
||||||
|
Some(RetryStrategy::Fixed {
|
||||||
|
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||||
|
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
|
||||||
|
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||||
|
// Internal Server Error could be anything, retry up to 3 times.
|
||||||
|
max_attempts: 3,
|
||||||
|
}),
|
||||||
|
status => {
|
||||||
|
// There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
|
||||||
|
// but we frequently get them in practice. See https://http.dev/529
|
||||||
|
if status.as_u16() == 529 {
|
||||||
|
Some(RetryStrategy::Fixed {
|
||||||
|
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||||
|
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Some(RetryStrategy::Fixed {
|
||||||
|
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||||
|
max_attempts: 2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
|
||||||
|
delay: BASE_RETRY_DELAY,
|
||||||
|
max_attempts: 3,
|
||||||
|
}),
|
||||||
|
ApiReadResponseError { .. }
|
||||||
|
| HttpSend { .. }
|
||||||
|
| DeserializeResponse { .. }
|
||||||
|
| BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
|
||||||
|
delay: BASE_RETRY_DELAY,
|
||||||
|
max_attempts: 3,
|
||||||
|
}),
|
||||||
|
// Retrying these errors definitely shouldn't help.
|
||||||
|
HttpResponseError {
|
||||||
|
status_code:
|
||||||
|
StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| AuthenticationError { .. }
|
||||||
|
| PermissionError { .. }
|
||||||
|
| NoApiKey { .. }
|
||||||
|
| ApiEndpointNotFound { .. }
|
||||||
|
| PromptTooLarge { .. } => None,
|
||||||
|
// These errors might be transient, so retry them
|
||||||
|
SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
|
||||||
|
delay: BASE_RETRY_DELAY,
|
||||||
|
max_attempts: 1,
|
||||||
|
}),
|
||||||
|
// Retry all other 4xx and 5xx errors once.
|
||||||
|
HttpResponseError { status_code, .. }
|
||||||
|
if status_code.is_client_error() || status_code.is_server_error() =>
|
||||||
|
{
|
||||||
|
Some(RetryStrategy::Fixed {
|
||||||
|
delay: BASE_RETRY_DELAY,
|
||||||
|
max_attempts: 3,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Other(err)
|
||||||
|
if err.is::<language_model::PaymentRequiredError>()
|
||||||
|
|| err.is::<language_model::ModelRequestLimitReachedError>() =>
|
||||||
|
{
|
||||||
|
// Retrying won't help for Payment Required or Model Request Limit errors (where
|
||||||
|
// the user must upgrade to usage-based billing to get more requests, or else wait
|
||||||
|
// for a significant amount of time for the request limit to reset).
|
||||||
|
None
|
||||||
|
}
|
||||||
|
// Conservatively assume that any other errors are non-retryable
|
||||||
|
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
|
||||||
|
delay: BASE_RETRY_DELAY,
|
||||||
|
max_attempts: 2,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RunningTurn {
|
struct RunningTurn {
|
||||||
|
@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
||||||
|
self.0
|
||||||
|
.unbounded_send(Ok(AgentResponseEvent::Retry(status)))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
fn send_stop(&self, reason: StopReason) {
|
fn send_stop(&self, reason: StopReason) {
|
||||||
match reason {
|
match reason {
|
||||||
StopReason::EndTurn => {
|
StopReason::EndTurn => {
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use acp_thread::{
|
use acp_thread::{
|
||||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
||||||
AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
|
||||||
UserMessageId,
|
ToolCallStatus, UserMessageId,
|
||||||
};
|
};
|
||||||
use acp_thread::{AgentConnection, Plan};
|
use acp_thread::{AgentConnection, Plan};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
|
@ -35,6 +35,7 @@ use prompt_store::PromptId;
|
||||||
use rope::Point;
|
use rope::Point;
|
||||||
use settings::{Settings as _, SettingsStore};
|
use settings::{Settings as _, SettingsStore};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
|
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
|
||||||
use text::Anchor;
|
use text::Anchor;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
|
@ -115,6 +116,7 @@ pub struct AcpThreadView {
|
||||||
profile_selector: Option<Entity<ProfileSelector>>,
|
profile_selector: Option<Entity<ProfileSelector>>,
|
||||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||||
|
thread_retry_status: Option<RetryStatus>,
|
||||||
thread_error: Option<ThreadError>,
|
thread_error: Option<ThreadError>,
|
||||||
list_state: ListState,
|
list_state: ListState,
|
||||||
scrollbar_state: ScrollbarState,
|
scrollbar_state: ScrollbarState,
|
||||||
|
@ -209,6 +211,7 @@ impl AcpThreadView {
|
||||||
notification_subscriptions: HashMap::default(),
|
notification_subscriptions: HashMap::default(),
|
||||||
list_state: list_state.clone(),
|
list_state: list_state.clone(),
|
||||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||||
|
thread_retry_status: None,
|
||||||
thread_error: None,
|
thread_error: None,
|
||||||
auth_task: None,
|
auth_task: None,
|
||||||
expanded_tool_calls: HashSet::default(),
|
expanded_tool_calls: HashSet::default(),
|
||||||
|
@ -445,6 +448,7 @@ impl AcpThreadView {
|
||||||
|
|
||||||
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
|
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
|
||||||
self.thread_error.take();
|
self.thread_error.take();
|
||||||
|
self.thread_retry_status.take();
|
||||||
|
|
||||||
if let Some(thread) = self.thread() {
|
if let Some(thread) = self.thread() {
|
||||||
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
|
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
|
||||||
|
@ -775,7 +779,11 @@ impl AcpThreadView {
|
||||||
AcpThreadEvent::ToolAuthorizationRequired => {
|
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||||
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||||
}
|
}
|
||||||
|
AcpThreadEvent::Retry(retry) => {
|
||||||
|
self.thread_retry_status = Some(retry.clone());
|
||||||
|
}
|
||||||
AcpThreadEvent::Stopped => {
|
AcpThreadEvent::Stopped => {
|
||||||
|
self.thread_retry_status.take();
|
||||||
let used_tools = thread.read(cx).used_tools_since_last_user_message();
|
let used_tools = thread.read(cx).used_tools_since_last_user_message();
|
||||||
self.notify_with_sound(
|
self.notify_with_sound(
|
||||||
if used_tools {
|
if used_tools {
|
||||||
|
@ -789,6 +797,7 @@ impl AcpThreadView {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::Error => {
|
AcpThreadEvent::Error => {
|
||||||
|
self.thread_retry_status.take();
|
||||||
self.notify_with_sound(
|
self.notify_with_sound(
|
||||||
"Agent stopped due to an error",
|
"Agent stopped due to an error",
|
||||||
IconName::Warning,
|
IconName::Warning,
|
||||||
|
@ -797,6 +806,7 @@ impl AcpThreadView {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::ServerExited(status) => {
|
AcpThreadEvent::ServerExited(status) => {
|
||||||
|
self.thread_retry_status.take();
|
||||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3413,7 +3423,51 @@ impl AcpThreadView {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
|
fn render_thread_retry_status_callout(
|
||||||
|
&self,
|
||||||
|
_window: &mut Window,
|
||||||
|
_cx: &mut Context<Self>,
|
||||||
|
) -> Option<Callout> {
|
||||||
|
let state = self.thread_retry_status.as_ref()?;
|
||||||
|
|
||||||
|
let next_attempt_in = state
|
||||||
|
.duration
|
||||||
|
.saturating_sub(Instant::now().saturating_duration_since(state.started_at));
|
||||||
|
if next_attempt_in.is_zero() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next_attempt_in_secs = next_attempt_in.as_secs() + 1;
|
||||||
|
|
||||||
|
let retry_message = if state.max_attempts == 1 {
|
||||||
|
if next_attempt_in_secs == 1 {
|
||||||
|
"Retrying. Next attempt in 1 second.".to_string()
|
||||||
|
} else {
|
||||||
|
format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if next_attempt_in_secs == 1 {
|
||||||
|
format!(
|
||||||
|
"Retrying. Next attempt in 1 second (Attempt {} of {}).",
|
||||||
|
state.attempt, state.max_attempts,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
format!(
|
||||||
|
"Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
|
||||||
|
state.attempt, state.max_attempts,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(
|
||||||
|
Callout::new()
|
||||||
|
.severity(Severity::Warning)
|
||||||
|
.title(state.last_error.clone())
|
||||||
|
.description(retry_message),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
|
||||||
let content = match self.thread_error.as_ref()? {
|
let content = match self.thread_error.as_ref()? {
|
||||||
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
|
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
|
||||||
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
|
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
|
||||||
|
@ -3678,6 +3732,7 @@ impl Render for AcpThreadView {
|
||||||
}
|
}
|
||||||
_ => this,
|
_ => this,
|
||||||
})
|
})
|
||||||
|
.children(self.render_thread_retry_status_callout(window, cx))
|
||||||
.children(self.render_thread_error(window, cx))
|
.children(self.render_thread_error(window, cx))
|
||||||
.child(self.render_message_editor(window, cx))
|
.child(self.render_message_editor(window, cx))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1523,6 +1523,7 @@ impl AgentDiff {
|
||||||
AcpThreadEvent::EntriesRemoved(_)
|
AcpThreadEvent::EntriesRemoved(_)
|
||||||
| AcpThreadEvent::Stopped
|
| AcpThreadEvent::Stopped
|
||||||
| AcpThreadEvent::ToolAuthorizationRequired
|
| AcpThreadEvent::ToolAuthorizationRequired
|
||||||
|
| AcpThreadEvent::Retry(_)
|
||||||
| AcpThreadEvent::Error
|
| AcpThreadEvent::Error
|
||||||
| AcpThreadEvent::ServerExited(_) => {}
|
| AcpThreadEvent::ServerExited(_) => {}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,11 @@ use crate::{
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
LanguageModelRequest, LanguageModelToolChoice,
|
LanguageModelRequest, LanguageModelToolChoice,
|
||||||
};
|
};
|
||||||
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||||
use http_client::Result;
|
use http_client::Result;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use smol::stream::StreamExt;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -100,7 +101,9 @@ pub struct FakeLanguageModel {
|
||||||
current_completion_txs: Mutex<
|
current_completion_txs: Mutex<
|
||||||
Vec<(
|
Vec<(
|
||||||
LanguageModelRequest,
|
LanguageModelRequest,
|
||||||
mpsc::UnboundedSender<LanguageModelCompletionEvent>,
|
mpsc::UnboundedSender<
|
||||||
|
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||||
|
>,
|
||||||
)>,
|
)>,
|
||||||
>,
|
>,
|
||||||
}
|
}
|
||||||
|
@ -150,7 +153,21 @@ impl FakeLanguageModel {
|
||||||
.find(|(req, _)| req == request)
|
.find(|(req, _)| req == request)
|
||||||
.map(|(_, tx)| tx)
|
.map(|(_, tx)| tx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
tx.unbounded_send(event.into()).unwrap();
|
tx.unbounded_send(Ok(event.into())).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_completion_stream_error(
|
||||||
|
&self,
|
||||||
|
request: &LanguageModelRequest,
|
||||||
|
error: impl Into<LanguageModelCompletionError>,
|
||||||
|
) {
|
||||||
|
let current_completion_txs = self.current_completion_txs.lock();
|
||||||
|
let tx = current_completion_txs
|
||||||
|
.iter()
|
||||||
|
.find(|(req, _)| req == request)
|
||||||
|
.map(|(_, tx)| tx)
|
||||||
|
.unwrap();
|
||||||
|
tx.unbounded_send(Err(error.into())).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
|
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
|
||||||
|
@ -170,6 +187,13 @@ impl FakeLanguageModel {
|
||||||
self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
|
self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn send_last_completion_stream_error(
|
||||||
|
&self,
|
||||||
|
error: impl Into<LanguageModelCompletionError>,
|
||||||
|
) {
|
||||||
|
self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn end_last_completion_stream(&self) {
|
pub fn end_last_completion_stream(&self) {
|
||||||
self.end_completion_stream(self.pending_completions().last().unwrap());
|
self.end_completion_stream(self.pending_completions().last().unwrap());
|
||||||
}
|
}
|
||||||
|
@ -229,7 +253,7 @@ impl LanguageModel for FakeLanguageModel {
|
||||||
> {
|
> {
|
||||||
let (tx, rx) = mpsc::unbounded();
|
let (tx, rx) = mpsc::unbounded();
|
||||||
self.current_completion_txs.lock().push((request, tx));
|
self.current_completion_txs.lock().push((request, tx));
|
||||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
async move { Ok(rx.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_fake(&self) -> &Self {
|
fn as_fake(&self) -> &Self {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue