parent
47e1d4511c
commit
0ea0d466d2
7 changed files with 514 additions and 52 deletions
|
@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
|
|||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use fs::{FakeFs, Fs};
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
|
||||
use gpui::{
|
||||
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
||||
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
|
||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
|
@ -24,7 +25,6 @@ use schemars::JsonSchema;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
|
@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
let mut retry_events = Vec::new();
|
||||
while let Some(Ok(event)) = events.next().await {
|
||||
match event {
|
||||
AgentResponseEvent::Retry(retry_status) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
AgentResponseEvent::Stop(..) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(retry_events.len(), 0);
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello!
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
});
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
cx.executor().advance_clock(Duration::from_secs(3));
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
let mut retry_events = Vec::new();
|
||||
while let Some(Ok(event)) = events.next().await {
|
||||
match event {
|
||||
AgentResponseEvent::Retry(retry_status) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
AgentResponseEvent::Stop(..) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(retry_events.len(), 1);
|
||||
assert!(matches!(
|
||||
retry_events[0],
|
||||
acp_thread::RetryStatus { attempt: 1, .. }
|
||||
));
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello!
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
|
||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
||||
fake_model.send_last_completion_stream_error(
|
||||
LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
},
|
||||
);
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.executor().advance_clock(Duration::from_secs(3));
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
||||
let mut errors = Vec::new();
|
||||
let mut retry_events = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(AgentResponseEvent::Retry(retry_status)) => {
|
||||
retry_events.push(retry_status);
|
||||
}
|
||||
Ok(AgentResponseEvent::Stop(..)) => break,
|
||||
Err(error) => errors.push(error),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
retry_events.len(),
|
||||
crate::thread::MAX_RETRY_ATTEMPTS as usize
|
||||
);
|
||||
for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
|
||||
assert_eq!(retry_events[i].attempt, i + 1);
|
||||
}
|
||||
assert_eq!(errors.len(), 1);
|
||||
let error = errors[0]
|
||||
.downcast_ref::<LanguageModelCompletionError>()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
error,
|
||||
LanguageModelCompletionError::ServerOverloaded { .. }
|
||||
));
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
||||
result_events
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue