parent
47e1d4511c
commit
0ea0d466d2
7 changed files with 514 additions and 52 deletions
|
@ -24,6 +24,7 @@ use std::fmt::{Formatter, Write};
|
|||
use std::ops::Range;
|
||||
use std::process::ExitStatus;
|
||||
use std::rc::Rc;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||
use ui::App;
|
||||
use util::ResultExt;
|
||||
|
@ -658,6 +659,15 @@ impl PlanEntry {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryStatus {
|
||||
pub last_error: SharedString,
|
||||
pub attempt: usize,
|
||||
pub max_attempts: usize,
|
||||
pub started_at: Instant,
|
||||
pub duration: Duration,
|
||||
}
|
||||
|
||||
pub struct AcpThread {
|
||||
title: SharedString,
|
||||
entries: Vec<AgentThreadEntry>,
|
||||
|
@ -676,6 +686,7 @@ pub enum AcpThreadEvent {
|
|||
EntryUpdated(usize),
|
||||
EntriesRemoved(Range<usize>),
|
||||
ToolAuthorizationRequired,
|
||||
Retry(RetryStatus),
|
||||
Stopped,
|
||||
Error,
|
||||
ServerExited(ExitStatus),
|
||||
|
@ -916,6 +927,10 @@ impl AcpThread {
|
|||
cx.emit(AcpThreadEvent::NewEntry);
|
||||
}
|
||||
|
||||
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
|
||||
cx.emit(AcpThreadEvent::Retry(status));
|
||||
}
|
||||
|
||||
pub fn update_tool_call(
|
||||
&mut self,
|
||||
update: impl Into<ToolCallUpdate>,
|
||||
|
|
|
@ -546,6 +546,11 @@ impl NativeAgentConnection {
|
|||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Retry(status) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_retry_status(status, cx)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
|
|
|
@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
|
|||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use fs::{FakeFs, Fs};
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
|
||||
use gpui::{
|
||||
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
||||
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
|
||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
|
@ -24,7 +25,6 @@ use schemars::JsonSchema;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
|
@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
let mut retry_events = Vec::new();
|
||||
while let Some(Ok(event)) = events.next().await {
|
||||
match event {
|
||||
AgentResponseEvent::Retry(retry_status) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
AgentResponseEvent::Stop(..) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(retry_events.len(), 0);
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello!
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
});
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
cx.executor().advance_clock(Duration::from_secs(3));
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
let mut retry_events = Vec::new();
|
||||
while let Some(Ok(event)) = events.next().await {
|
||||
match event {
|
||||
AgentResponseEvent::Retry(retry_status) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
AgentResponseEvent::Stop(..) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(retry_events.len(), 1);
|
||||
assert!(matches!(
|
||||
retry_events[0],
|
||||
acp_thread::RetryStatus { attempt: 1, .. }
|
||||
));
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello!
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
||||
fake_model.send_last_completion_stream_error(
|
||||
LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
},
|
||||
);
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.executor().advance_clock(Duration::from_secs(3));
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
||||
let mut errors = Vec::new();
|
||||
let mut retry_events = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(AgentResponseEvent::Retry(retry_status)) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
Ok(AgentResponseEvent::Stop(..)) => break,
|
||||
Err(error) => errors.push(error),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
retry_events.len(),
|
||||
crate::thread::MAX_RETRY_ATTEMPTS as usize
|
||||
);
|
||||
for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
|
||||
assert_eq!(retry_events[i].attempt, i + 1);
|
||||
}
|
||||
assert_eq!(errors.len(), 1);
|
||||
let error = errors[0]
|
||||
.downcast_ref::<LanguageModelCompletionError>()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
error,
|
||||
LanguageModelCompletionError::ServerOverloaded { .. }
|
||||
));
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
||||
result_events
|
||||
|
|
|
@ -12,12 +12,12 @@ use futures::{
|
|||
channel::{mpsc, oneshot},
|
||||
stream::FuturesUnordered,
|
||||
};
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
|
@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, path::Path, sync::Arc};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{fmt::Write, ops::Range};
|
||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||
use uuid::Uuid;
|
||||
|
@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
|
||||
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum RetryStrategy {
|
||||
ExponentialBackoff {
|
||||
initial_delay: Duration,
|
||||
max_attempts: u8,
|
||||
},
|
||||
Fixed {
|
||||
delay: Duration,
|
||||
max_attempts: u8,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Message {
|
||||
User(UserMessage),
|
||||
|
@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
Retry(acp_thread::RetryStatus),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -662,41 +683,18 @@ impl Thread {
|
|||
})??;
|
||||
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
) => {
|
||||
tool_use_limit_reached = true;
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
event => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut tool_uses = Self::stream_completion_with_retries(
|
||||
this.clone(),
|
||||
model.clone(),
|
||||
request,
|
||||
message_ix,
|
||||
&event_stream,
|
||||
&mut tool_use_limit_reached,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let used_tools = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
|
@ -754,10 +752,105 @@ impl Thread {
|
|||
Ok(events_rx)
|
||||
}
|
||||
|
||||
async fn stream_completion_with_retries(
|
||||
this: WeakEntity<Self>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
request: LanguageModelRequest,
|
||||
message_ix: usize,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
tool_use_limit_reached: &mut bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut attempt = None;
|
||||
'retry: loop {
|
||||
let mut events = model.stream_completion(request.clone(), cx).await?;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
)) => {
|
||||
*tool_use_limit_reached = true;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(tool_uses);
|
||||
}
|
||||
}
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
let completion_mode =
|
||||
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
||||
if completion_mode == CompletionMode::Normal {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||
return Err(error.into());
|
||||
};
|
||||
|
||||
let max_attempts = match &strategy {
|
||||
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
|
||||
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
|
||||
};
|
||||
|
||||
let attempt = attempt.get_or_insert(0u8);
|
||||
|
||||
*attempt += 1;
|
||||
|
||||
let attempt = *attempt;
|
||||
if attempt > max_attempts {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let delay = match &strategy {
|
||||
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
|
||||
let delay_secs =
|
||||
initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
|
||||
Duration::from_secs(delay_secs)
|
||||
}
|
||||
RetryStrategy::Fixed { delay, .. } => *delay,
|
||||
};
|
||||
log::debug!("Retry attempt {attempt} with delay {delay:?}");
|
||||
|
||||
event_stream.send_retry(acp_thread::RetryStatus {
|
||||
last_error: error.to_string().into(),
|
||||
attempt: attempt as usize,
|
||||
max_attempts: max_attempts as usize,
|
||||
started_at: Instant::now(),
|
||||
duration: delay,
|
||||
});
|
||||
|
||||
cx.background_executor().timer(delay).await;
|
||||
continue 'retry;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(tool_uses);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
project: &self.project_context.read(cx),
|
||||
project: self.project_context.read(cx),
|
||||
available_tools: self.tools.keys().cloned().collect(),
|
||||
}
|
||||
.render(&self.templates)
|
||||
|
@ -1158,6 +1251,113 @@ impl Thread {
|
|||
fn advance_prompt_id(&mut self) {
|
||||
self.prompt_id = PromptId::new();
|
||||
}
|
||||
|
||||
fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
|
||||
use LanguageModelCompletionError::*;
|
||||
use http_client::StatusCode;
|
||||
|
||||
// General strategy here:
|
||||
// - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
|
||||
// - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
|
||||
// - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
|
||||
match error {
|
||||
HttpResponseError {
|
||||
status_code: StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => Some(RetryStrategy::ExponentialBackoff {
|
||||
initial_delay: BASE_RETRY_DELAY,
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
}),
|
||||
ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
}
|
||||
UpstreamProviderError {
|
||||
status,
|
||||
retry_after,
|
||||
..
|
||||
} => match *status {
|
||||
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
// Internal Server Error could be anything, retry up to 3 times.
|
||||
max_attempts: 3,
|
||||
}),
|
||||
status => {
|
||||
// There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
|
||||
// but we frequently get them in practice. See https://http.dev/529
|
||||
if status.as_u16() == 529 {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
} else {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: 2,
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 3,
|
||||
}),
|
||||
ApiReadResponseError { .. }
|
||||
| HttpSend { .. }
|
||||
| DeserializeResponse { .. }
|
||||
| BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 3,
|
||||
}),
|
||||
// Retrying these errors definitely shouldn't help.
|
||||
HttpResponseError {
|
||||
status_code:
|
||||
StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
|
||||
..
|
||||
}
|
||||
| AuthenticationError { .. }
|
||||
| PermissionError { .. }
|
||||
| NoApiKey { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| PromptTooLarge { .. } => None,
|
||||
// These errors might be transient, so retry them
|
||||
SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 1,
|
||||
}),
|
||||
// Retry all other 4xx and 5xx errors once.
|
||||
HttpResponseError { status_code, .. }
|
||||
if status_code.is_client_error() || status_code.is_server_error() =>
|
||||
{
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 3,
|
||||
})
|
||||
}
|
||||
Other(err)
|
||||
if err.is::<language_model::PaymentRequiredError>()
|
||||
|| err.is::<language_model::ModelRequestLimitReachedError>() =>
|
||||
{
|
||||
// Retrying won't help for Payment Required or Model Request Limit errors (where
|
||||
// the user must upgrade to usage-based billing to get more requests, or else wait
|
||||
// for a significant amount of time for the request limit to reset).
|
||||
None
|
||||
}
|
||||
// Conservatively assume that any other errors are non-retryable
|
||||
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 2,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RunningTurn {
|
||||
|
@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Retry(status)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use acp_thread::{
|
||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
||||
AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
||||
UserMessageId,
|
||||
AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
|
||||
ToolCallStatus, UserMessageId,
|
||||
};
|
||||
use acp_thread::{AgentConnection, Plan};
|
||||
use action_log::ActionLog;
|
||||
|
@ -35,6 +35,7 @@ use prompt_store::PromptId;
|
|||
use rope::Point;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
|
||||
use text::Anchor;
|
||||
use theme::ThemeSettings;
|
||||
|
@ -115,6 +116,7 @@ pub struct AcpThreadView {
|
|||
profile_selector: Option<Entity<ProfileSelector>>,
|
||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
thread_retry_status: Option<RetryStatus>,
|
||||
thread_error: Option<ThreadError>,
|
||||
list_state: ListState,
|
||||
scrollbar_state: ScrollbarState,
|
||||
|
@ -209,6 +211,7 @@ impl AcpThreadView {
|
|||
notification_subscriptions: HashMap::default(),
|
||||
list_state: list_state.clone(),
|
||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||
thread_retry_status: None,
|
||||
thread_error: None,
|
||||
auth_task: None,
|
||||
expanded_tool_calls: HashSet::default(),
|
||||
|
@ -445,6 +448,7 @@ impl AcpThreadView {
|
|||
|
||||
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
|
||||
self.thread_error.take();
|
||||
self.thread_retry_status.take();
|
||||
|
||||
if let Some(thread) = self.thread() {
|
||||
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
|
||||
|
@ -775,7 +779,11 @@ impl AcpThreadView {
|
|||
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
}
|
||||
AcpThreadEvent::Retry(retry) => {
|
||||
self.thread_retry_status = Some(retry.clone());
|
||||
}
|
||||
AcpThreadEvent::Stopped => {
|
||||
self.thread_retry_status.take();
|
||||
let used_tools = thread.read(cx).used_tools_since_last_user_message();
|
||||
self.notify_with_sound(
|
||||
if used_tools {
|
||||
|
@ -789,6 +797,7 @@ impl AcpThreadView {
|
|||
);
|
||||
}
|
||||
AcpThreadEvent::Error => {
|
||||
self.thread_retry_status.take();
|
||||
self.notify_with_sound(
|
||||
"Agent stopped due to an error",
|
||||
IconName::Warning,
|
||||
|
@ -797,6 +806,7 @@ impl AcpThreadView {
|
|||
);
|
||||
}
|
||||
AcpThreadEvent::ServerExited(status) => {
|
||||
self.thread_retry_status.take();
|
||||
self.thread_state = ThreadState::ServerExited { status: *status };
|
||||
}
|
||||
}
|
||||
|
@ -3413,7 +3423,51 @@ impl AcpThreadView {
|
|||
})
|
||||
}
|
||||
|
||||
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
|
||||
fn render_thread_retry_status_callout(
|
||||
&self,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Self>,
|
||||
) -> Option<Callout> {
|
||||
let state = self.thread_retry_status.as_ref()?;
|
||||
|
||||
let next_attempt_in = state
|
||||
.duration
|
||||
.saturating_sub(Instant::now().saturating_duration_since(state.started_at));
|
||||
if next_attempt_in.is_zero() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next_attempt_in_secs = next_attempt_in.as_secs() + 1;
|
||||
|
||||
let retry_message = if state.max_attempts == 1 {
|
||||
if next_attempt_in_secs == 1 {
|
||||
"Retrying. Next attempt in 1 second.".to_string()
|
||||
} else {
|
||||
format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.")
|
||||
}
|
||||
} else {
|
||||
if next_attempt_in_secs == 1 {
|
||||
format!(
|
||||
"Retrying. Next attempt in 1 second (Attempt {} of {}).",
|
||||
state.attempt, state.max_attempts,
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
|
||||
state.attempt, state.max_attempts,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Some(
|
||||
Callout::new()
|
||||
.severity(Severity::Warning)
|
||||
.title(state.last_error.clone())
|
||||
.description(retry_message),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
|
||||
let content = match self.thread_error.as_ref()? {
|
||||
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
|
||||
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
|
||||
|
@ -3678,6 +3732,7 @@ impl Render for AcpThreadView {
|
|||
}
|
||||
_ => this,
|
||||
})
|
||||
.children(self.render_thread_retry_status_callout(window, cx))
|
||||
.children(self.render_thread_error(window, cx))
|
||||
.child(self.render_message_editor(window, cx))
|
||||
}
|
||||
|
|
|
@ -1523,6 +1523,7 @@ impl AgentDiff {
|
|||
AcpThreadEvent::EntriesRemoved(_)
|
||||
| AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Retry(_)
|
||||
| AcpThreadEvent::Error
|
||||
| AcpThreadEvent::ServerExited(_) => {}
|
||||
}
|
||||
|
|
|
@ -4,10 +4,11 @@ use crate::{
|
|||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolChoice,
|
||||
};
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||
use http_client::Result;
|
||||
use parking_lot::Mutex;
|
||||
use smol::stream::StreamExt;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -100,7 +101,9 @@ pub struct FakeLanguageModel {
|
|||
current_completion_txs: Mutex<
|
||||
Vec<(
|
||||
LanguageModelRequest,
|
||||
mpsc::UnboundedSender<LanguageModelCompletionEvent>,
|
||||
mpsc::UnboundedSender<
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
)>,
|
||||
>,
|
||||
}
|
||||
|
@ -150,7 +153,21 @@ impl FakeLanguageModel {
|
|||
.find(|(req, _)| req == request)
|
||||
.map(|(_, tx)| tx)
|
||||
.unwrap();
|
||||
tx.unbounded_send(event.into()).unwrap();
|
||||
tx.unbounded_send(Ok(event.into())).unwrap();
|
||||
}
|
||||
|
||||
pub fn send_completion_stream_error(
|
||||
&self,
|
||||
request: &LanguageModelRequest,
|
||||
error: impl Into<LanguageModelCompletionError>,
|
||||
) {
|
||||
let current_completion_txs = self.current_completion_txs.lock();
|
||||
let tx = current_completion_txs
|
||||
.iter()
|
||||
.find(|(req, _)| req == request)
|
||||
.map(|(_, tx)| tx)
|
||||
.unwrap();
|
||||
tx.unbounded_send(Err(error.into())).unwrap();
|
||||
}
|
||||
|
||||
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
|
||||
|
@ -170,6 +187,13 @@ impl FakeLanguageModel {
|
|||
self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
|
||||
}
|
||||
|
||||
pub fn send_last_completion_stream_error(
|
||||
&self,
|
||||
error: impl Into<LanguageModelCompletionError>,
|
||||
) {
|
||||
self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
|
||||
}
|
||||
|
||||
pub fn end_last_completion_stream(&self) {
|
||||
self.end_completion_stream(self.pending_completions().last().unwrap());
|
||||
}
|
||||
|
@ -229,7 +253,7 @@ impl LanguageModel for FakeLanguageModel {
|
|||
> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
async move { Ok(rx.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn as_fake(&self) -> &Self {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue