diff --git a/Cargo.lock b/Cargo.lock index 76f45bd28e..e034212748 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,6 +160,7 @@ dependencies = [ "agent_servers", "anyhow", "client", + "clock", "cloud_llm_client", "collections", "ctor", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 70779bba74..74aa2993dd 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -42,6 +42,7 @@ workspace-hack.workspace = true [dev-dependencies] ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } +clock = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index bd4f82200b..305a31fc98 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -2,7 +2,7 @@ use acp_thread::ModelSelector; use agent_client_protocol as acp; use anyhow::{anyhow, Result}; use futures::StreamExt; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity}; use language_model::{LanguageModel, LanguageModelRegistry}; use project::Project; use std::collections::HashMap; @@ -17,7 +17,8 @@ struct Session { /// The internal thread that processes messages thread: Entity, /// The ACP thread that handles protocol communication - acp_thread: Entity, + acp_thread: WeakEntity, + _subscription: Subscription, } pub struct NativeAgent { @@ -162,12 +163,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; // Store the session - agent.update(cx, |agent, _cx| { + agent.update(cx, |agent, cx| { agent.sessions.insert( session_id, Session { thread, - acp_thread: acp_thread.clone(), + acp_thread: acp_thread.downgrade(), + _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }) }, ); })?; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 2e2b25a119..330d04b60c 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,10 +1,10 @@ use super::*; use crate::templates::Templates; -use acp_thread::AgentConnection as _; +use acp_thread::AgentConnection; use agent_client_protocol as acp; use client::{Client, UserStore}; use fs::FakeFs; -use gpui::{AppContext, Entity, Task, TestAppContext}; +use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext}; use indoc::indoc; use language_model::{ fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError, @@ -322,31 +322,26 @@ async fn test_refusal(cx: &mut TestAppContext) { }); } -#[ignore = "temporarily disabled until it can be run on CI"] #[gpui::test] async fn test_agent_connection(cx: &mut TestAppContext) { - cx.executor().allow_parking(); cx.update(settings::init); let templates = Templates::new(); // Initialize language model system with test provider cx.update(|cx| { gpui_tokio::init(cx); - let http_client = ReqwestClient::user_agent("agent tests").unwrap(); - cx.set_http_client(Arc::new(http_client)); - client::init_settings(cx); - let client = Client::production(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); - - // Initialize project settings Project::init_settings(cx); - - // Use test registry with fake provider LanguageModelRegistry::test(cx); }); + cx.executor().forbid_parking(); // Create agent and connection let agent = cx.new(|_| NativeAgent::new(templates.clone())); @@ -390,34 +385,60 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); // Test selected_model returns the default - let selected = cx + let model = cx .update(|cx| { let mut async_cx = cx.to_async(); selector.selected_model(&session_id, &mut async_cx) }) .await .expect("selected_model should succeed"); - assert_eq!(selected.id().0, "fake", "should return default model"); + let model = model.as_fake(); + assert_eq!(model.id().0, "fake", "should return default model"); - // The thread was created via prompt with the default model - // We can verify it through selected_model + let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("def"); + cx.run_until_parked(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User - // Test prompt uses the selected model - let prompt_request = acp::PromptRequest { - session_id: session_id.clone(), - prompt: vec![acp::ContentBlock::Text(acp::TextContent { - text: "Test prompt".into(), - annotations: None, - })], - }; + abc - let request = cx.update(|cx| connection.prompt(prompt_request, cx)); - let request = cx.background_spawn(request); - smol::Timer::after(Duration::from_millis(100)).await; + ## Assistant + + def + + "} + ) + }); // Test cancel cx.update(|cx| connection.cancel(&session_id, cx)); request.await.expect("prompt should fail gracefully"); + + // Ensure that dropping the ACP thread causes the native thread to be + // dropped as well. + cx.update(|_| drop(acp_thread)); + let result = cx + .update(|cx| { + connection.prompt( + acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec!["ghi".into()], + }, + cx, + ) + }) + .await; + assert_eq!( + result.as_ref().unwrap_err().to_string(), + "Session not found", + "unexpected result: {:?}", + result + ); } /// Filters out the stop events for asserting against in tests