
- Renamed Agent struct to NativeAgent to better reflect its native implementation - Renamed AgentConnection to NativeAgentConnection for consistency - Updated all references and implementations - Bumped agent-client-protocol version to 0.0.14
378 lines
12 KiB
Rust
378 lines
12 KiB
Rust
use super::*;
|
|
use crate::templates::Templates;
|
|
use acp_thread::AgentConnection as _;
|
|
use agent_client_protocol as acp;
|
|
use client::{Client, UserStore};
|
|
use gpui::{AppContext, Entity, TestAppContext};
|
|
use language_model::{
|
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
|
LanguageModelRegistry, MessageContent, StopReason,
|
|
};
|
|
use project::Project;
|
|
use reqwest_client::ReqwestClient;
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use smol::stream::StreamExt;
|
|
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
|
|
|
mod test_tools;
|
|
use test_tools::*;
|
|
|
|
#[gpui::test]
|
|
async fn test_echo(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx).await;
|
|
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
|
})
|
|
.collect()
|
|
.await;
|
|
thread.update(cx, |thread, _cx| {
|
|
assert_eq!(
|
|
thread.messages().last().unwrap().content,
|
|
vec![MessageContent::Text("Hello".to_string())]
|
|
);
|
|
});
|
|
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx).await;
|
|
|
|
// Test a tool call that's likely to complete *before* streaming stops.
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(EchoTool);
|
|
thread.send(
|
|
model.clone(),
|
|
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
|
cx,
|
|
)
|
|
})
|
|
.collect()
|
|
.await;
|
|
assert_eq!(
|
|
stop_events(events),
|
|
vec![StopReason::ToolUse, StopReason::EndTurn]
|
|
);
|
|
|
|
// Test a tool calls that's likely to complete *after* streaming stops.
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.remove_tool(&AgentTool::name(&EchoTool));
|
|
thread.add_tool(DelayTool);
|
|
thread.send(
|
|
model.clone(),
|
|
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
|
cx,
|
|
)
|
|
})
|
|
.collect()
|
|
.await;
|
|
assert_eq!(
|
|
stop_events(events),
|
|
vec![StopReason::ToolUse, StopReason::EndTurn]
|
|
);
|
|
thread.update(cx, |thread, _cx| {
|
|
assert!(thread
|
|
.messages()
|
|
.last()
|
|
.unwrap()
|
|
.content
|
|
.iter()
|
|
.any(|content| {
|
|
if let MessageContent::Text(text) = content {
|
|
text.contains("Ding")
|
|
} else {
|
|
false
|
|
}
|
|
}));
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx).await;
|
|
|
|
// Test a tool call that's likely to complete *before* streaming stops.
|
|
let mut events = thread.update(cx, |thread, cx| {
|
|
thread.add_tool(WordListTool);
|
|
thread.send(model.clone(), "Test the word_list tool.", cx)
|
|
});
|
|
|
|
let mut saw_partial_tool_use = false;
|
|
while let Some(event) = events.next().await {
|
|
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
|
|
thread.update(cx, |thread, _cx| {
|
|
// Look for a tool use in the thread's last message
|
|
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
|
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
|
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
|
if tool_use_event.is_input_complete {
|
|
last_tool_use
|
|
.input
|
|
.get("a")
|
|
.expect("'a' has streamed because input is now complete");
|
|
last_tool_use
|
|
.input
|
|
.get("g")
|
|
.expect("'g' has streamed because input is now complete");
|
|
} else {
|
|
if !last_tool_use.is_input_complete
|
|
&& last_tool_use.input.get("g").is_none()
|
|
{
|
|
saw_partial_tool_use = true;
|
|
}
|
|
}
|
|
} else {
|
|
panic!("last content should be a tool use");
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
assert!(
|
|
saw_partial_tool_use,
|
|
"should see at least one partially streamed tool use in the history"
|
|
);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|
let ThreadTest { model, thread, .. } = setup(cx).await;
|
|
|
|
// Test concurrent tool calls with different delay times
|
|
let events = thread
|
|
.update(cx, |thread, cx| {
|
|
thread.add_tool(DelayTool);
|
|
thread.send(
|
|
model.clone(),
|
|
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
|
cx,
|
|
)
|
|
})
|
|
.collect()
|
|
.await;
|
|
|
|
let stop_reasons = stop_events(events);
|
|
if stop_reasons.len() == 2 {
|
|
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
|
|
} else if stop_reasons.len() == 3 {
|
|
assert_eq!(
|
|
stop_reasons,
|
|
vec![
|
|
StopReason::ToolUse,
|
|
StopReason::ToolUse,
|
|
StopReason::EndTurn
|
|
]
|
|
);
|
|
} else {
|
|
panic!("Expected either 1 or 2 tool uses followed by end turn");
|
|
}
|
|
|
|
thread.update(cx, |thread, _cx| {
|
|
let last_message = thread.messages().last().unwrap();
|
|
let text = last_message
|
|
.content
|
|
.iter()
|
|
.filter_map(|content| {
|
|
if let MessageContent::Text(text) = content {
|
|
Some(text.as_str())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect::<String>();
|
|
|
|
assert!(text.contains("Ding"));
|
|
});
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_agent_connection(cx: &mut TestAppContext) {
|
|
cx.executor().allow_parking();
|
|
cx.update(settings::init);
|
|
let templates = Templates::new();
|
|
|
|
// Initialize language model system with test provider
|
|
cx.update(|cx| {
|
|
gpui_tokio::init(cx);
|
|
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
|
cx.set_http_client(Arc::new(http_client));
|
|
|
|
client::init_settings(cx);
|
|
let client = Client::production(cx);
|
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
|
language_model::init(client.clone(), cx);
|
|
language_models::init(user_store.clone(), client.clone(), cx);
|
|
|
|
// Initialize project settings
|
|
Project::init_settings(cx);
|
|
|
|
// Use test registry with fake provider
|
|
LanguageModelRegistry::test(cx);
|
|
});
|
|
|
|
// Create agent and connection
|
|
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
|
|
let connection = NativeAgentConnection(agent.clone());
|
|
|
|
// Test model_selector returns Some
|
|
let selector_opt = connection.model_selector();
|
|
assert!(
|
|
selector_opt.is_some(),
|
|
"agent2 should always support ModelSelector"
|
|
);
|
|
let selector = selector_opt.unwrap();
|
|
|
|
// Test list_models
|
|
let listed_models = cx
|
|
.update(|cx| {
|
|
let mut async_cx = cx.to_async();
|
|
selector.list_models(&mut async_cx)
|
|
})
|
|
.await
|
|
.expect("list_models should succeed");
|
|
assert!(!listed_models.is_empty(), "should have at least one model");
|
|
assert_eq!(listed_models[0].id().0, "fake");
|
|
|
|
// Create a project for new_thread
|
|
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
|
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
|
|
|
// Create a thread using new_thread
|
|
let cwd = Path::new("/test");
|
|
let connection_rc = Rc::new(connection.clone());
|
|
let acp_thread = cx
|
|
.update(|cx| {
|
|
let mut async_cx = cx.to_async();
|
|
connection_rc.new_thread(project, cwd, &mut async_cx)
|
|
})
|
|
.await
|
|
.expect("new_thread should succeed");
|
|
|
|
// Get the session_id from the AcpThread
|
|
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
|
|
|
// Test selected_model returns the default
|
|
let selected = cx
|
|
.update(|cx| {
|
|
let mut async_cx = cx.to_async();
|
|
selector.selected_model(&session_id, &mut async_cx)
|
|
})
|
|
.await
|
|
.expect("selected_model should succeed");
|
|
assert_eq!(selected.id().0, "fake", "should return default model");
|
|
|
|
// The thread was created via prompt with the default model
|
|
// We can verify it through selected_model
|
|
|
|
// Test prompt uses the selected model
|
|
let prompt_request = acp::PromptRequest {
|
|
session_id: session_id.clone(),
|
|
prompt: vec![acp::ContentBlock::Text(acp::TextContent {
|
|
text: "Test prompt".into(),
|
|
annotations: None,
|
|
})],
|
|
};
|
|
|
|
cx.update(|cx| connection.prompt(prompt_request, cx))
|
|
.await
|
|
.expect("prompt should succeed");
|
|
|
|
// The prompt was sent successfully
|
|
|
|
// Test cancel
|
|
cx.update(|cx| connection.cancel(&session_id, cx));
|
|
|
|
// After cancel, selected_model should fail
|
|
let result = cx
|
|
.update(|cx| {
|
|
let mut async_cx = cx.to_async();
|
|
selector.selected_model(&session_id, &mut async_cx)
|
|
})
|
|
.await;
|
|
assert!(result.is_err(), "selected_model should fail after cancel");
|
|
|
|
// Test error case: invalid session
|
|
let invalid_session = acp::SessionId("invalid".into());
|
|
let result = cx
|
|
.update(|cx| {
|
|
let mut async_cx = cx.to_async();
|
|
selector.selected_model(&invalid_session, &mut async_cx)
|
|
})
|
|
.await;
|
|
assert!(result.is_err(), "should fail for invalid session");
|
|
if let Err(e) = result {
|
|
assert!(
|
|
e.to_string().contains("Session not found"),
|
|
"should have correct error message"
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Filters out the stop events for asserting against in tests
|
|
fn stop_events(
|
|
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
|
) -> Vec<StopReason> {
|
|
result_events
|
|
.into_iter()
|
|
.filter_map(|event| match event.unwrap() {
|
|
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
|
|
_ => None,
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
struct ThreadTest {
|
|
model: Arc<dyn LanguageModel>,
|
|
thread: Entity<Thread>,
|
|
}
|
|
|
|
async fn setup(cx: &mut TestAppContext) -> ThreadTest {
|
|
cx.executor().allow_parking();
|
|
cx.update(settings::init);
|
|
let templates = Templates::new();
|
|
|
|
let model = cx
|
|
.update(|cx| {
|
|
gpui_tokio::init(cx);
|
|
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
|
cx.set_http_client(Arc::new(http_client));
|
|
|
|
client::init_settings(cx);
|
|
let client = Client::production(cx);
|
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
|
language_model::init(client.clone(), cx);
|
|
language_models::init(user_store.clone(), client.clone(), cx);
|
|
|
|
let models = LanguageModelRegistry::read_global(cx);
|
|
let model = models
|
|
.available_models(cx)
|
|
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
|
.unwrap();
|
|
|
|
let provider = models.provider(&model.provider_id()).unwrap();
|
|
let authenticated = provider.authenticate(cx);
|
|
|
|
cx.spawn(async move |_cx| {
|
|
authenticated.await.unwrap();
|
|
model
|
|
})
|
|
})
|
|
.await;
|
|
|
|
let thread = cx.new(|_| Thread::new(templates, model.clone()));
|
|
|
|
ThreadTest { model, thread }
|
|
}
|
|
|
|
/// Enables `env_logger` for the test binary, but only when `RUST_LOG` is set.
///
/// `#[ctor::ctor]` runs this before any test starts.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Leave logging disabled unless the caller explicitly asked for it.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
|