Add system prompt and tool permission to agent2 (#35781)
Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
4dbd24d75f
commit
03876d076e
21 changed files with 1111 additions and 304 deletions
|
@ -1,30 +1,34 @@
|
|||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use assistant_tool::ActionLog;
|
||||
use client::{Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
|
||||
StopReason,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
|
@ -44,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
|
|||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_thinking(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
|
||||
|
||||
|
@ -77,7 +81,46 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
|||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
project_context.borrow_mut().shell = "test-shell".into();
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(
|
||||
pending_completions.len(),
|
||||
1,
|
||||
"unexpected pending completions: {:?}",
|
||||
pending_completions
|
||||
);
|
||||
|
||||
let pending_completion = pending_completions.pop().unwrap();
|
||||
assert_eq!(pending_completion.messages[0].role, Role::System);
|
||||
|
||||
let system_message = &pending_completion.messages[0];
|
||||
let system_prompt = system_message.content[0].to_str().unwrap();
|
||||
assert!(
|
||||
system_prompt.contains("test-shell"),
|
||||
"unexpected system message: {:?}",
|
||||
system_message
|
||||
);
|
||||
assert!(
|
||||
system_prompt.contains("## Fixing Diagnostics"),
|
||||
"unexpected system message: {:?}",
|
||||
system_message
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
|
@ -127,7 +170,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
|
@ -175,7 +218,104 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
|||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.send(model.clone(), "abc", cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_id_1".into(),
|
||||
name: ToolRequiringPermission.name().into(),
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_id_2".into(),
|
||||
name: ToolRequiringPermission.name().into(),
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
|
||||
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
|
||||
|
||||
// Approve the first
|
||||
tool_call_auth_1
|
||||
.response
|
||||
.send(tool_call_auth_1.options[1].id.clone())
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Reject the second
|
||||
tool_call_auth_2
|
||||
.response
|
||||
.send(tool_call_auth_1.options[2].id.clone())
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
let message = completion.messages.last().unwrap();
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
|
||||
tool_name: tool_call_auth_1.tool_call.title.into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: None
|
||||
}),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||
tool_name: tool_call_auth_2.tool_call.title.into(),
|
||||
is_error: true,
|
||||
content: "Permission to run tool denied by user".into(),
|
||||
output: None
|
||||
})
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
async fn next_tool_call_authorization(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> ToolCallAuthorization {
|
||||
loop {
|
||||
let event = events
|
||||
.next()
|
||||
.await
|
||||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||
let permission_kinds = tool_call_authorization
|
||||
.options
|
||||
.iter()
|
||||
.map(|o| o.kind)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
permission_kinds,
|
||||
vec![
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
acp::PermissionOptionKind::AllowOnce,
|
||||
acp::PermissionOptionKind::RejectOnce,
|
||||
]
|
||||
);
|
||||
return tool_call_authorization;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
|
@ -214,7 +354,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
|
@ -281,12 +421,10 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
|
||||
#[gpui::test]
|
||||
async fn test_refusal(cx: &mut TestAppContext) {
|
||||
let fake_model = Arc::new(FakeLanguageModel::default());
|
||||
let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(fake_model.clone(), "Hello", cx)
|
||||
});
|
||||
let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
|
@ -343,8 +481,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
});
|
||||
cx.executor().forbid_parking();
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
let cwd = Path::new("/test");
|
||||
|
||||
// Create agent and connection
|
||||
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
|
||||
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
|
@ -366,12 +512,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
|
||||
// Create a thread using new_thread
|
||||
let cwd = Path::new("/test");
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
|
@ -457,12 +598,13 @@ fn stop_events(
|
|||
struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
}
|
||||
|
||||
enum TestModel {
|
||||
Sonnet4,
|
||||
Sonnet4Thinking,
|
||||
Fake(Arc<FakeLanguageModel>),
|
||||
Fake,
|
||||
}
|
||||
|
||||
impl TestModel {
|
||||
|
@ -470,7 +612,7 @@ impl TestModel {
|
|||
match self {
|
||||
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
|
||||
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
|
||||
TestModel::Fake(fake_model) => fake_model.id(),
|
||||
TestModel::Fake => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -499,8 +641,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
if let TestModel::Fake(model) = model {
|
||||
Task::ready(model as Arc<_>)
|
||||
if let TestModel::Fake = model {
|
||||
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
||||
} else {
|
||||
let model_id = model.id();
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
|
@ -520,9 +662,22 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
})
|
||||
.await;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
|
||||
|
||||
ThreadTest { model, thread }
|
||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_| {
|
||||
Thread::new(
|
||||
project,
|
||||
project_context.clone(),
|
||||
action_log,
|
||||
templates,
|
||||
model.clone(),
|
||||
)
|
||||
});
|
||||
ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue