use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection as _;
use agent_client_protocol as acp;
use client::{Client, UserStore};
use gpui::{AppContext, Entity, TestAppContext};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelRegistry, MessageContent, StopReason,
};
use project::Project;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};

mod test_tools;
use test_tools::*;

#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
        })
        .collect()
        .await;

    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.messages().last().unwrap().content,
            vec![MessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
}

#[gpui::test]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                model.clone(),
                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
                cx,
            )
        })
        .collect()
        .await;
    assert_eq!(
        stop_events(events),
        vec![StopReason::ToolUse, StopReason::EndTurn]
    );

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                model.clone(),
                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
                cx,
            )
        })
        .collect()
        .await;
    assert_eq!(
        stop_events(events),
        vec![StopReason::ToolUse, StopReason::EndTurn]
    );
    thread.update(cx, |thread, _cx| {
        assert!(thread
            .messages()
            .last()
            .unwrap()
            .content
            .iter()
            .any(|content| {
                if let MessageContent::Text(text) = content {
                    text.contains("Ding")
                } else {
                    false
                }
            }));
    });
}

#[gpui::test]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(WordListTool);
        thread.send(model.clone(), "Test the word_list tool.", cx)
    });

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message.
                let last_content = thread.messages().last().unwrap().content.last().unwrap();
                if let MessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_use_event.is_input_complete {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    } else if !last_tool_use.is_input_complete
                        && last_tool_use.input.get("g").is_none()
                    {
                        saw_partial_tool_use = true;
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

#[gpui::test]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx).await;

    // Test concurrent tool calls with different delay times.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                model.clone(),
                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
                cx,
            )
        })
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    if stop_reasons.len() == 2 {
        assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
    } else if stop_reasons.len() == 3 {
        assert_eq!(
            stop_reasons,
            vec![
                StopReason::ToolUse,
                StopReason::ToolUse,
                StopReason::EndTurn
            ]
        );
    } else {
        panic!("Expected either 1 or 2 tool uses followed by an end turn");
    }

    thread.update(cx, |thread, _cx| {
        let last_message = thread.messages().last().unwrap();
        let text = last_message
            .content
            .iter()
            .filter_map(|content| {
                if let MessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();
        assert!(text.contains("Ding"));
    });
}

#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.executor().allow_parking();
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize the language model system with a test provider.
    cx.update(|cx| {
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        // Initialize project settings.
        Project::init_settings(cx);

        // Use the test registry with a fake provider.
        LanguageModelRegistry::test(cx);
    });

    // Create the agent and connection.
    let agent = cx.new(|_| NativeAgent::new(templates.clone()));
    let connection = NativeAgentConnection(agent.clone());

    // Test that model_selector returns Some.
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models.
    let listed_models = cx
        .update(|cx| {
            let mut async_cx = cx.to_async();
            selector.list_models(&mut async_cx)
        })
        .await
        .expect("list_models should succeed");
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(listed_models[0].id().0, "fake");

    // Create a project for new_thread.
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;

    // Create a thread using new_thread.
    let cwd = Path::new("/test");
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| {
            let mut async_cx = cx.to_async();
            connection_rc.new_thread(project, cwd, &mut async_cx)
        })
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread.
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test that selected_model returns the default model.
    let selected = cx
        .update(|cx| {
            let mut async_cx = cx.to_async();
            selector.selected_model(&session_id, &mut async_cx)
        })
        .await
        .expect("selected_model should succeed");
    assert_eq!(selected.id().0, "fake", "should return default model");

    // The thread was created via prompt with the default model;
    // we can verify it through selected_model.

    // Test that prompt uses the selected model.
    let prompt_request = acp::PromptRequest {
        session_id: session_id.clone(),
        prompt: vec![acp::ContentBlock::Text(acp::TextContent {
            text: "Test prompt".into(),
            annotations: None,
        })],
    };
    cx.update(|cx| connection.prompt(prompt_request, cx))
        .await
        .expect("prompt should succeed");

    // Test cancel.
    cx.update(|cx| connection.cancel(&session_id, cx));

    // After cancel, selected_model should fail.
    let result = cx
        .update(|cx| {
            let mut async_cx = cx.to_async();
            selector.selected_model(&session_id, &mut async_cx)
        })
        .await;
    assert!(result.is_err(), "selected_model should fail after cancel");

    // Test the error case: an invalid session.
    let invalid_session = acp::SessionId("invalid".into());
    let result = cx
        .update(|cx| {
            let mut async_cx = cx.to_async();
            selector.selected_model(&invalid_session, &mut async_cx)
        })
        .await;
    assert!(result.is_err(), "should fail for invalid session");
    if let Err(e) = result {
        assert!(
            e.to_string().contains("Session not found"),
            "should have correct error message"
        );
    }
}

/// Extracts just the stop events, for asserting against in tests.
fn stop_events(
    result_events: Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
) -> Vec<StopReason> {
    result_events
        .into_iter()
        .filter_map(|event| match event.unwrap() {
            LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
            _ => None,
        })
        .collect()
}

struct ThreadTest {
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
}

async fn setup(cx: &mut TestAppContext) -> ThreadTest {
    cx.executor().allow_parking();
    cx.update(settings::init);
    let templates = Templates::new();

    let model = cx
        .update(|cx| {
            gpui_tokio::init(cx);
            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
            cx.set_http_client(Arc::new(http_client));

            client::init_settings(cx);
            let client = Client::production(cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(client.clone(), cx);
            language_models::init(user_store.clone(), client.clone(), cx);

            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.id().0 == "claude-3-7-sonnet-latest")
                .unwrap();

            // Authenticate with the provider before handing the model to the test.
            let provider = models.provider(&model.provider_id()).unwrap();
            let authenticated = provider.authenticate(cx);

            cx.spawn(async move |_cx| {
                authenticated.await.unwrap();
                model
            })
        })
        .await;

    let thread = cx.new(|_| Thread::new(templates, model.clone()));
    ThreadTest { model, thread }
}

#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}