Add comprehensive test for AgentConnection with ModelSelector

- Add public session_id() method to AcpThread to enable testing
- Fix ModelSelector methods to use async move closures properly to avoid borrow conflicts
- Add test_agent_connection that verifies:
  - Model selector is available for agent2
  - Can list available models
  - Can create threads with default model
  - Can query selected model for a session
  - Can send prompts using the selected model
  - Can cancel sessions
  - Handles errors for invalid sessions
- Remove unnecessary mut keywords from async closures
This commit is contained in:
Nathan Sobo 2025-08-02 08:45:51 -06:00
parent a4fe8c6972
commit 604a88f6e3
3 changed files with 163 additions and 37 deletions

View file

@ -656,6 +656,10 @@ impl AcpThread {
&self.entries
}
pub fn session_id(&self) -> &acp::SessionId {
&self.session_id
}
pub fn status(&self) -> ThreadStatus {
if self.send_task.is_some() {
if self.waiting_for_tool_confirmation() {

View file

@ -33,16 +33,17 @@ pub struct AgentConnection(pub Entity<Agent>);
impl ModelSelector for AgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
let result = cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
});
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
cx.spawn(async move |cx| {
cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
})?
})
}
fn select_model(
@ -52,17 +53,19 @@ impl ModelSelector for AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<()>> {
let agent = self.0.clone();
let result = agent.update(cx, |agent, cx| {
if let Some(thread) = agent.sessions.get(session_id) {
thread.update(cx, |thread, _| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow::anyhow!("Session not found"))
}
});
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
let session_id = session_id.clone();
cx.spawn(async move |cx| {
agent.update(cx, |agent, cx| {
if let Some(thread) = agent.sessions.get(&session_id) {
thread.update(cx, |thread, _| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow::anyhow!("Session not found"))
}
})?
})
}
fn selected_model(
@ -71,21 +74,14 @@ impl ModelSelector for AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
let thread_result = agent
.read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
.ok()
.flatten()
.ok_or_else(|| anyhow::anyhow!("Session not found"));
match thread_result {
Ok(thread) => {
let selected = thread
.read_with(cx, |thread, _| thread.selected_model.clone())
.unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
Task::ready(Ok(selected))
}
Err(e) => Task::ready(Err(e)),
}
let session_id = session_id.clone();
cx.spawn(async move |cx| {
let thread = agent
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
Ok(selected)
})
}
}

View file

@ -1,16 +1,19 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection as _;
use agent_client_protocol as acp;
use client::{Client, UserStore};
use gpui::{AppContext, Entity, TestAppContext};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, MessageContent, StopReason,
};
use project::Project;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
mod test_tools;
use test_tools::*;
@ -187,6 +190,129 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.executor().allow_parking();
cx.update(settings::init);
let templates = Templates::new();
// Initialize language model system with test provider
cx.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
// Initialize project settings
Project::init_settings(cx);
// Use test registry with fake provider
LanguageModelRegistry::test(cx);
});
// Create agent and connection
let agent = cx.new(|_| Agent::new(templates.clone()));
let connection = AgentConnection(agent.clone());
// Test model_selector returns Some
let selector_opt = connection.model_selector();
assert!(
selector_opt.is_some(),
"agent2 should always support ModelSelector"
);
let selector = selector_opt.unwrap();
// Test list_models
let listed_models = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.list_models(&mut async_cx)
})
.await
.expect("list_models should succeed");
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
// Create a thread using new_thread
let cwd = Path::new("/test");
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
let mut async_cx = cx.to_async();
connection_rc.new_thread(project, cwd, &mut async_cx)
})
.await
.expect("new_thread should succeed");
// Get the session_id from the AcpThread
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
// Test selected_model returns the default
let selected = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.await
.expect("selected_model should succeed");
assert_eq!(selected.id().0, "fake", "should return default model");
// The thread was created via prompt with the default model
// We can verify it through selected_model
// Test prompt uses the selected model
let prompt_request = acp::PromptRequest {
session_id: session_id.clone(),
prompt: vec![acp::ContentBlock::Text(acp::TextContent {
text: "Test prompt".into(),
annotations: None,
})],
};
cx.update(|cx| connection.prompt(prompt_request, cx))
.await
.expect("prompt should succeed");
// The prompt was sent successfully
// Test cancel
cx.update(|cx| connection.cancel(&session_id, cx));
// After cancel, selected_model should fail
let result = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.await;
assert!(result.is_err(), "selected_model should fail after cancel");
// Test error case: invalid session
let invalid_session = acp::SessionId("invalid".into());
let result = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&invalid_session, &mut async_cx)
})
.await;
assert!(result.is_err(), "should fail for invalid session");
if let Err(e) = result {
assert!(
e.to_string().contains("Session not found"),
"should have correct error message"
);
}
}
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,