Add comprehensive test for AgentConnection with ModelSelector
- Add public session_id() method to AcpThread to enable testing - Fix ModelSelector methods to use async move closures properly to avoid borrow conflicts - Add test_agent_connection that verifies: - Model selector is available for agent2 - Can list available models - Can create threads with default model - Can query selected model for a session - Can send prompts using the selected model - Can cancel sessions - Handles errors for invalid sessions - Remove unnecessary mut keywords from async closures
This commit is contained in:
parent
a4fe8c6972
commit
604a88f6e3
3 changed files with 163 additions and 37 deletions
|
@ -656,6 +656,10 @@ impl AcpThread {
|
|||
&self.entries
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> &acp::SessionId {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
pub fn status(&self) -> ThreadStatus {
|
||||
if self.send_task.is_some() {
|
||||
if self.waiting_for_tool_confirmation() {
|
||||
|
|
|
@ -33,16 +33,17 @@ pub struct AgentConnection(pub Entity<Agent>);
|
|||
|
||||
impl ModelSelector for AgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
let result = cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
});
|
||||
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
|
@ -52,17 +53,19 @@ impl ModelSelector for AgentConnection {
|
|||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
let agent = self.0.clone();
|
||||
let result = agent.update(cx, |agent, cx| {
|
||||
if let Some(thread) = agent.sessions.get(session_id) {
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Session not found"))
|
||||
}
|
||||
});
|
||||
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(thread) = agent.sessions.get(&session_id) {
|
||||
thread.update(cx, |thread, _| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
|
@ -71,21 +74,14 @@ impl ModelSelector for AgentConnection {
|
|||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
let thread_result = agent
|
||||
.read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
|
||||
.ok()
|
||||
.flatten()
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"));
|
||||
|
||||
match thread_result {
|
||||
Ok(thread) => {
|
||||
let selected = thread
|
||||
.read_with(cx, |thread, _| thread.selected_model.clone())
|
||||
.unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
|
||||
Task::ready(Ok(selected))
|
||||
}
|
||||
Err(e) => Task::ready(Err(e)),
|
||||
}
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let thread = agent
|
||||
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection as _;
|
||||
use agent_client_protocol as acp;
|
||||
use client::{Client, UserStore};
|
||||
use gpui::{AppContext, Entity, TestAppContext};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, MessageContent, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
@ -187,6 +190,129 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let templates = Templates::new();
|
||||
|
||||
// Initialize language model system with test provider
|
||||
cx.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
// Initialize project settings
|
||||
Project::init_settings(cx);
|
||||
|
||||
// Use test registry with fake provider
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
|
||||
// Create agent and connection
|
||||
let agent = cx.new(|_| Agent::new(templates.clone()));
|
||||
let connection = AgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
let selector_opt = connection.model_selector();
|
||||
assert!(
|
||||
selector_opt.is_some(),
|
||||
"agent2 should always support ModelSelector"
|
||||
);
|
||||
let selector = selector_opt.unwrap();
|
||||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
|
||||
// Create a thread using new_thread
|
||||
let cwd = Path::new("/test");
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
// Get the session_id from the AcpThread
|
||||
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||
|
||||
// Test selected_model returns the default
|
||||
let selected = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
assert_eq!(selected.id().0, "fake", "should return default model");
|
||||
|
||||
// The thread was created via prompt with the default model
|
||||
// We can verify it through selected_model
|
||||
|
||||
// Test prompt uses the selected model
|
||||
let prompt_request = acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec![acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Test prompt".into(),
|
||||
annotations: None,
|
||||
})],
|
||||
};
|
||||
|
||||
cx.update(|cx| connection.prompt(prompt_request, cx))
|
||||
.await
|
||||
.expect("prompt should succeed");
|
||||
|
||||
// The prompt was sent successfully
|
||||
|
||||
// Test cancel
|
||||
cx.update(|cx| connection.cancel(&session_id, cx));
|
||||
|
||||
// After cancel, selected_model should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "selected_model should fail after cancel");
|
||||
|
||||
// Test error case: invalid session
|
||||
let invalid_session = acp::SessionId("invalid".into());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&invalid_session, &mut async_cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "should fail for invalid session");
|
||||
if let Err(e) = result {
|
||||
assert!(
|
||||
e.to_string().contains("Session not found"),
|
||||
"should have correct error message"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue