parent
47e1d4511c
commit
0ea0d466d2
7 changed files with 514 additions and 52 deletions
|
@ -12,12 +12,12 @@ use futures::{
|
|||
channel::{mpsc, oneshot},
|
||||
stream::FuturesUnordered,
|
||||
};
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
|
@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, path::Path, sync::Arc};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{fmt::Write, ops::Range};
|
||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||
use uuid::Uuid;
|
||||
|
@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
|
||||
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum RetryStrategy {
|
||||
ExponentialBackoff {
|
||||
initial_delay: Duration,
|
||||
max_attempts: u8,
|
||||
},
|
||||
Fixed {
|
||||
delay: Duration,
|
||||
max_attempts: u8,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Message {
|
||||
User(UserMessage),
|
||||
|
@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
Retry(acp_thread::RetryStatus),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -662,41 +683,18 @@ impl Thread {
|
|||
})??;
|
||||
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
) => {
|
||||
tool_use_limit_reached = true;
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
event => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut tool_uses = Self::stream_completion_with_retries(
|
||||
this.clone(),
|
||||
model.clone(),
|
||||
request,
|
||||
message_ix,
|
||||
&event_stream,
|
||||
&mut tool_use_limit_reached,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let used_tools = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
|
@ -754,10 +752,105 @@ impl Thread {
|
|||
Ok(events_rx)
|
||||
}
|
||||
|
||||
async fn stream_completion_with_retries(
|
||||
this: WeakEntity<Self>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
request: LanguageModelRequest,
|
||||
message_ix: usize,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
tool_use_limit_reached: &mut bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut attempt = None;
|
||||
'retry: loop {
|
||||
let mut events = model.stream_completion(request.clone(), cx).await?;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
)) => {
|
||||
*tool_use_limit_reached = true;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(tool_uses);
|
||||
}
|
||||
}
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
let completion_mode =
|
||||
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
||||
if completion_mode == CompletionMode::Normal {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||
return Err(error.into());
|
||||
};
|
||||
|
||||
let max_attempts = match &strategy {
|
||||
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
|
||||
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
|
||||
};
|
||||
|
||||
let attempt = attempt.get_or_insert(0u8);
|
||||
|
||||
*attempt += 1;
|
||||
|
||||
let attempt = *attempt;
|
||||
if attempt > max_attempts {
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let delay = match &strategy {
|
||||
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
|
||||
let delay_secs =
|
||||
initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
|
||||
Duration::from_secs(delay_secs)
|
||||
}
|
||||
RetryStrategy::Fixed { delay, .. } => *delay,
|
||||
};
|
||||
log::debug!("Retry attempt {attempt} with delay {delay:?}");
|
||||
|
||||
event_stream.send_retry(acp_thread::RetryStatus {
|
||||
last_error: error.to_string().into(),
|
||||
attempt: attempt as usize,
|
||||
max_attempts: max_attempts as usize,
|
||||
started_at: Instant::now(),
|
||||
duration: delay,
|
||||
});
|
||||
|
||||
cx.background_executor().timer(delay).await;
|
||||
continue 'retry;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(tool_uses);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
project: &self.project_context.read(cx),
|
||||
project: self.project_context.read(cx),
|
||||
available_tools: self.tools.keys().cloned().collect(),
|
||||
}
|
||||
.render(&self.templates)
|
||||
|
@ -1158,6 +1251,113 @@ impl Thread {
|
|||
fn advance_prompt_id(&mut self) {
|
||||
self.prompt_id = PromptId::new();
|
||||
}
|
||||
|
||||
fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
|
||||
use LanguageModelCompletionError::*;
|
||||
use http_client::StatusCode;
|
||||
|
||||
// General strategy here:
|
||||
// - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
|
||||
// - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
|
||||
// - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
|
||||
match error {
|
||||
HttpResponseError {
|
||||
status_code: StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => Some(RetryStrategy::ExponentialBackoff {
|
||||
initial_delay: BASE_RETRY_DELAY,
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
}),
|
||||
ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
}
|
||||
UpstreamProviderError {
|
||||
status,
|
||||
retry_after,
|
||||
..
|
||||
} => match *status {
|
||||
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
// Internal Server Error could be anything, retry up to 3 times.
|
||||
max_attempts: 3,
|
||||
}),
|
||||
status => {
|
||||
// There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
|
||||
// but we frequently get them in practice. See https://http.dev/529
|
||||
if status.as_u16() == 529 {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: MAX_RETRY_ATTEMPTS,
|
||||
})
|
||||
} else {
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
|
||||
max_attempts: 2,
|
||||
})
|
||||
}
|
||||
}
|
||||
},
|
||||
ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 3,
|
||||
}),
|
||||
ApiReadResponseError { .. }
|
||||
| HttpSend { .. }
|
||||
| DeserializeResponse { .. }
|
||||
| BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 3,
|
||||
}),
|
||||
// Retrying these errors definitely shouldn't help.
|
||||
HttpResponseError {
|
||||
status_code:
|
||||
StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
|
||||
..
|
||||
}
|
||||
| AuthenticationError { .. }
|
||||
| PermissionError { .. }
|
||||
| NoApiKey { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| PromptTooLarge { .. } => None,
|
||||
// These errors might be transient, so retry them
|
||||
SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 1,
|
||||
}),
|
||||
// Retry all other 4xx and 5xx errors once.
|
||||
HttpResponseError { status_code, .. }
|
||||
if status_code.is_client_error() || status_code.is_server_error() =>
|
||||
{
|
||||
Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 3,
|
||||
})
|
||||
}
|
||||
Other(err)
|
||||
if err.is::<language_model::PaymentRequiredError>()
|
||||
|| err.is::<language_model::ModelRequestLimitReachedError>() =>
|
||||
{
|
||||
// Retrying won't help for Payment Required or Model Request Limit errors (where
|
||||
// the user must upgrade to usage-based billing to get more requests, or else wait
|
||||
// for a significant amount of time for the request limit to reset).
|
||||
None
|
||||
}
|
||||
// Conservatively assume that any other errors are non-retryable
|
||||
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
max_attempts: 2,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RunningTurn {
|
||||
|
@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Retry(status)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue