agent2: Port retry logic (#36421)

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-08-19 11:41:55 +02:00 committed by GitHub
parent 47e1d4511c
commit 0ea0d466d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 514 additions and 52 deletions

View file

@ -24,6 +24,7 @@ use std::fmt::{Formatter, Write};
use std::ops::Range;
use std::process::ExitStatus;
use std::rc::Rc;
use std::time::{Duration, Instant};
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
use ui::App;
use util::ResultExt;
@ -658,6 +659,15 @@ impl PlanEntry {
}
}
#[derive(Debug, Clone)]
pub struct RetryStatus {
pub last_error: SharedString,
pub attempt: usize,
pub max_attempts: usize,
pub started_at: Instant,
pub duration: Duration,
}
pub struct AcpThread {
title: SharedString,
entries: Vec<AgentThreadEntry>,
@ -676,6 +686,7 @@ pub enum AcpThreadEvent {
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
Retry(RetryStatus),
Stopped,
Error,
ServerExited(ExitStatus),
@ -916,6 +927,10 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry);
}
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status));
}
pub fn update_tool_call(
&mut self,
update: impl Into<ToolCallUpdate>,

View file

@ -546,6 +546,11 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Retry(status) => {
acp_thread.update(cx, |thread, cx| {
thread.update_retry_status(status, cx)
})?;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });

View file

@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
use futures::channel::mpsc::UnboundedReceiver;
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
@ -24,7 +25,6 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
AgentResponseEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
AgentResponseEvent::Stop(..) => break,
_ => {}
}
}
assert_eq!(retry_events.len(), 0);
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello!
## Assistant
Hey!
"}
)
});
}
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
AgentResponseEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
AgentResponseEvent::Stop(..) => break,
_ => {}
}
}
assert_eq!(retry_events.len(), 1);
assert!(matches!(
retry_events[0],
acp_thread::RetryStatus { attempt: 1, .. }
));
thread.read_with(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello!
## Assistant
Hey!
"}
)
});
}
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
cx.run_until_parked();
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
fake_model.send_last_completion_stream_error(
LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
},
);
fake_model.end_last_completion_stream();
cx.executor().advance_clock(Duration::from_secs(3));
cx.run_until_parked();
}
let mut errors = Vec::new();
let mut retry_events = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(AgentResponseEvent::Retry(retry_status)) => {
retry_events.push(retry_status);
}
Ok(AgentResponseEvent::Stop(..)) => break,
Err(error) => errors.push(error),
_ => {}
}
}
assert_eq!(
retry_events.len(),
crate::thread::MAX_RETRY_ATTEMPTS as usize
);
for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
assert_eq!(retry_events[i].attempt, i + 1);
}
assert_eq!(errors.len(), 1);
let error = errors[0]
.downcast_ref::<LanguageModelCompletionError>()
.unwrap();
assert!(matches!(
error,
LanguageModelCompletionError::ServerOverloaded { .. }
));
}
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
result_events

View file

@ -12,12 +12,12 @@ use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
};
use gpui::{App, Context, Entity, SharedString, Task};
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{collections::BTreeMap, path::Path, sync::Arc};
use std::{
collections::BTreeMap,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
}
}
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
#[derive(Debug, Clone)]
enum RetryStrategy {
ExponentialBackoff {
initial_delay: Duration,
max_attempts: u8,
},
Fixed {
delay: Duration,
max_attempts: u8,
},
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
User(UserMessage),
@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
}
@ -662,41 +683,18 @@ impl Thread {
})??;
log::info!("Calling model.stream_completion");
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event? {
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
) => {
tool_use_limit_reached = true;
}
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(());
}
}
event => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
tool_uses.extend(this.handle_streamed_completion_event(
event,
&event_stream,
cx,
));
})
.ok();
}
}
}
let mut tool_uses = Self::stream_completion_with_retries(
this.clone(),
model.clone(),
request,
message_ix,
&event_stream,
&mut tool_use_limit_reached,
cx,
)
.await?;
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
@ -754,10 +752,105 @@ impl Thread {
Ok(events_rx)
}
async fn stream_completion_with_retries(
this: WeakEntity<Self>,
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
message_ix: usize,
event_stream: &AgentResponseEventStream,
tool_use_limit_reached: &mut bool,
cx: &mut AsyncApp,
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
log::debug!("Stream completion started successfully");
let mut attempt = None;
'retry: loop {
let mut events = model.stream_completion(request.clone(), cx).await?;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event {
Ok(LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
)) => {
*tool_use_limit_reached = true;
}
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(tool_uses);
}
}
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
tool_uses.extend(this.handle_streamed_completion_event(
event,
event_stream,
cx,
));
})
.ok();
}
Err(error) => {
let completion_mode =
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(error.into());
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(error.into());
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let attempt = attempt.get_or_insert(0u8);
*attempt += 1;
let attempt = *attempt;
if attempt > max_attempts {
return Err(error.into());
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs =
initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
event_stream.send_retry(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
});
cx.background_executor().timer(delay).await;
continue 'retry;
}
}
}
return Ok(tool_uses);
}
}
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
project: &self.project_context.read(cx),
project: self.project_context.read(cx),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
@ -1158,6 +1251,113 @@ impl Thread {
fn advance_prompt_id(&mut self) {
self.prompt_id = PromptId::new();
}
fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
use LanguageModelCompletionError::*;
use http_client::StatusCode;
// General strategy here:
// - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
// - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
// - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
match error {
HttpResponseError {
status_code: StatusCode::TOO_MANY_REQUESTS,
..
} => Some(RetryStrategy::ExponentialBackoff {
initial_delay: BASE_RETRY_DELAY,
max_attempts: MAX_RETRY_ATTEMPTS,
}),
ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: MAX_RETRY_ATTEMPTS,
})
}
UpstreamProviderError {
status,
retry_after,
..
} => match *status {
StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: MAX_RETRY_ATTEMPTS,
})
}
StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
// Internal Server Error could be anything, retry up to 3 times.
max_attempts: 3,
}),
status => {
// There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
// but we frequently get them in practice. See https://http.dev/529
if status.as_u16() == 529 {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: MAX_RETRY_ATTEMPTS,
})
} else {
Some(RetryStrategy::Fixed {
delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
max_attempts: 2,
})
}
}
},
ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 3,
}),
ApiReadResponseError { .. }
| HttpSend { .. }
| DeserializeResponse { .. }
| BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 3,
}),
// Retrying these errors definitely shouldn't help.
HttpResponseError {
status_code:
StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
..
}
| AuthenticationError { .. }
| PermissionError { .. }
| NoApiKey { .. }
| ApiEndpointNotFound { .. }
| PromptTooLarge { .. } => None,
// These errors might be transient, so retry them
SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 1,
}),
// Retry all other 4xx and 5xx errors once.
HttpResponseError { status_code, .. }
if status_code.is_client_error() || status_code.is_server_error() =>
{
Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 3,
})
}
Other(err)
if err.is::<language_model::PaymentRequiredError>()
|| err.is::<language_model::ModelRequestLimitReachedError>() =>
{
// Retrying won't help for Payment Required or Model Request Limit errors (where
// the user must upgrade to usage-based billing to get more requests, or else wait
// for a significant amount of time for the request limit to reset).
None
}
// Conservatively assume that any other errors are non-retryable
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,
max_attempts: 2,
}),
}
}
}
struct RunningTurn {
@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
.ok();
}
fn send_retry(&self, status: acp_thread::RetryStatus) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Retry(status)))
.ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {

View file

@ -1,7 +1,7 @@
use acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
UserMessageId,
AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
ToolCallStatus, UserMessageId,
};
use acp_thread::{AgentConnection, Plan};
use action_log::ActionLog;
@ -35,6 +35,7 @@ use prompt_store::PromptId;
use rope::Point;
use settings::{Settings as _, SettingsStore};
use std::sync::Arc;
use std::time::Instant;
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
use text::Anchor;
use theme::ThemeSettings;
@ -115,6 +116,7 @@ pub struct AcpThreadView {
profile_selector: Option<Entity<ProfileSelector>>,
notifications: Vec<WindowHandle<AgentNotification>>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
thread_retry_status: Option<RetryStatus>,
thread_error: Option<ThreadError>,
list_state: ListState,
scrollbar_state: ScrollbarState,
@ -209,6 +211,7 @@ impl AcpThreadView {
notification_subscriptions: HashMap::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
thread_retry_status: None,
thread_error: None,
auth_task: None,
expanded_tool_calls: HashSet::default(),
@ -445,6 +448,7 @@ impl AcpThreadView {
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
self.thread_error.take();
self.thread_retry_status.take();
if let Some(thread) = self.thread() {
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
@ -775,7 +779,11 @@ impl AcpThreadView {
AcpThreadEvent::ToolAuthorizationRequired => {
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
}
AcpThreadEvent::Retry(retry) => {
self.thread_retry_status = Some(retry.clone());
}
AcpThreadEvent::Stopped => {
self.thread_retry_status.take();
let used_tools = thread.read(cx).used_tools_since_last_user_message();
self.notify_with_sound(
if used_tools {
@ -789,6 +797,7 @@ impl AcpThreadView {
);
}
AcpThreadEvent::Error => {
self.thread_retry_status.take();
self.notify_with_sound(
"Agent stopped due to an error",
IconName::Warning,
@ -797,6 +806,7 @@ impl AcpThreadView {
);
}
AcpThreadEvent::ServerExited(status) => {
self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status };
}
}
@ -3413,7 +3423,51 @@ impl AcpThreadView {
})
}
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
fn render_thread_retry_status_callout(
&self,
_window: &mut Window,
_cx: &mut Context<Self>,
) -> Option<Callout> {
let state = self.thread_retry_status.as_ref()?;
let next_attempt_in = state
.duration
.saturating_sub(Instant::now().saturating_duration_since(state.started_at));
if next_attempt_in.is_zero() {
return None;
}
let next_attempt_in_secs = next_attempt_in.as_secs() + 1;
let retry_message = if state.max_attempts == 1 {
if next_attempt_in_secs == 1 {
"Retrying. Next attempt in 1 second.".to_string()
} else {
format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.")
}
} else {
if next_attempt_in_secs == 1 {
format!(
"Retrying. Next attempt in 1 second (Attempt {} of {}).",
state.attempt, state.max_attempts,
)
} else {
format!(
"Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
state.attempt, state.max_attempts,
)
}
};
Some(
Callout::new()
.severity(Severity::Warning)
.title(state.last_error.clone())
.description(retry_message),
)
}
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
let content = match self.thread_error.as_ref()? {
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
@ -3678,6 +3732,7 @@ impl Render for AcpThreadView {
}
_ => this,
})
.children(self.render_thread_retry_status_callout(window, cx))
.children(self.render_thread_error(window, cx))
.child(self.render_message_editor(window, cx))
}

View file

@ -1523,6 +1523,7 @@ impl AgentDiff {
AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Retry(_)
| AcpThreadEvent::Error
| AcpThreadEvent::ServerExited(_) => {}
}

View file

@ -4,10 +4,11 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use smol::stream::StreamExt;
use std::sync::Arc;
#[derive(Clone)]
@ -100,7 +101,9 @@ pub struct FakeLanguageModel {
current_completion_txs: Mutex<
Vec<(
LanguageModelRequest,
mpsc::UnboundedSender<LanguageModelCompletionEvent>,
mpsc::UnboundedSender<
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
)>,
>,
}
@ -150,7 +153,21 @@ impl FakeLanguageModel {
.find(|(req, _)| req == request)
.map(|(_, tx)| tx)
.unwrap();
tx.unbounded_send(event.into()).unwrap();
tx.unbounded_send(Ok(event.into())).unwrap();
}
pub fn send_completion_stream_error(
&self,
request: &LanguageModelRequest,
error: impl Into<LanguageModelCompletionError>,
) {
let current_completion_txs = self.current_completion_txs.lock();
let tx = current_completion_txs
.iter()
.find(|(req, _)| req == request)
.map(|(_, tx)| tx)
.unwrap();
tx.unbounded_send(Err(error.into())).unwrap();
}
pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
@ -170,6 +187,13 @@ impl FakeLanguageModel {
self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
}
pub fn send_last_completion_stream_error(
&self,
error: impl Into<LanguageModelCompletionError>,
) {
self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
}
pub fn end_last_completion_stream(&self) {
self.end_completion_stream(self.pending_completions().last().unwrap());
}
@ -229,7 +253,7 @@ impl LanguageModel for FakeLanguageModel {
> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.map(Ok).boxed()) }.boxed()
async move { Ok(rx.boxed()) }.boxed()
}
fn as_fake(&self) -> &Self {