diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index ac790c8498..f3d9a35c2b 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -17,17 +17,17 @@ use test_tools::*; #[gpui::test] async fn test_echo(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; - let events = agent - .update(cx, |agent, cx| { - agent.send(model.clone(), "Testing: Reply with 'Hello'", cx) + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: Reply with 'Hello'", cx) }) .collect() .await; - agent.update(cx, |agent, _cx| { + thread.update(cx, |thread, _cx| { assert_eq!( - agent.messages().last().unwrap().content, + thread.messages().last().unwrap().content, vec![MessageContent::Text("Hello".to_string())] ); }); @@ -36,13 +36,13 @@ async fn test_echo(cx: &mut TestAppContext) { #[gpui::test] async fn test_basic_tool_calls(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; // Test a tool call that's likely to complete *before* streaming stops. - let events = agent - .update(cx, |agent, cx| { - agent.add_tool(EchoTool); - agent.send( + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send( model.clone(), "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", cx, @@ -56,11 +56,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { ); // Test a tool calls that's likely to complete *after* streaming stops. - let events = agent - .update(cx, |agent, cx| { - agent.remove_tool(&AgentTool::name(&EchoTool)); - agent.add_tool(DelayTool); - agent.send( + let events = thread + .update(cx, |thread, cx| { + thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.add_tool(DelayTool); + thread.send( model.clone(), "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", cx, @@ -72,8 +72,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { stop_events(events), vec![StopReason::ToolUse, StopReason::EndTurn] ); - agent.update(cx, |agent, _cx| { - assert!(agent + thread.update(cx, |thread, _cx| { + assert!(thread .messages() .last() .unwrap() @@ -91,20 +91,20 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { #[gpui::test] async fn test_streaming_tool_calls(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; // Test a tool call that's likely to complete *before* streaming stops. - let mut events = agent.update(cx, |agent, cx| { - agent.add_tool(WordListTool); - agent.send(model.clone(), "Test the word_list tool.", cx) + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(model.clone(), "Test the word_list tool.", cx) }); let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event { - agent.update(cx, |agent, _cx| { - // Look for a tool use in the agent's last message - let last_content = agent.messages().last().unwrap().content.last().unwrap(); + thread.update(cx, |thread, _cx| { + // Look for a tool use in the thread's last message + let last_content = thread.messages().last().unwrap().content.last().unwrap(); if let MessageContent::ToolUse(last_tool_use) = last_content { assert_eq!(last_tool_use.name.as_ref(), "word_list"); if tool_use_event.is_input_complete { @@ -138,13 +138,13 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { #[gpui::test] async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; // Test concurrent tool calls with different delay times - let events = agent - .update(cx, |agent, cx| { - agent.add_tool(DelayTool); - agent.send( + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(DelayTool); + thread.send( model.clone(), "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", cx, @@ -169,8 +169,8 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { panic!("Expected either 1 or 2 tool uses followed by end turn"); } - agent.update(cx, |agent, _cx| { - let last_message = agent.messages().last().unwrap(); + thread.update(cx, |thread, _cx| { + let last_message = thread.messages().last().unwrap(); let text = last_message .content .iter() @@ -200,16 +200,16 @@ fn stop_events( .collect() } -struct AgentTest { +struct ThreadTest { model: Arc, - agent: Entity, + thread: Entity, } -async fn setup(cx: &mut TestAppContext) -> AgentTest { +async fn setup(cx: &mut TestAppContext) -> ThreadTest { cx.executor().allow_parking(); cx.update(settings::init); let templates = Templates::new(); - let agent = cx.new(|_| Thread::new(templates)); + let thread = cx.new(|_| Thread::new(templates)); let model = cx .update(|cx| { @@ -239,7 +239,7 @@ async fn setup(cx: &mut TestAppContext) -> AgentTest { }) .await; - AgentTest { model, agent } + ThreadTest { model, thread } } #[cfg(test)]