Add system prompt and tool permission to agent2 (#35781)

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-07 15:40:12 +02:00 committed by GitHub
parent 4dbd24d75f
commit 03876d076e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1111 additions and 304 deletions

View file

@ -1,30 +1,34 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use assistant_tool::ActionLog;
use client::{Client, UserStore};
use fs::FakeFs;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
use indoc::indoc;
use language_model::{
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
StopReason,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
mod test_tools;
use test_tools::*;
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -44,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
@ -77,7 +81,46 @@ async fn test_thinking(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_system_prompt(cx: &mut TestAppContext) {
let ThreadTest {
model,
thread,
project_context,
..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
pending_completions.len(),
1,
"unexpected pending completions: {:?}",
pending_completions
);
let pending_completion = pending_completions.pop().unwrap();
assert_eq!(pending_completion.messages[0].role, Role::System);
let system_message = &pending_completion.messages[0];
let system_prompt = system_message.content[0].to_str().unwrap();
assert!(
system_prompt.contains("test-shell"),
"unexpected system message: {:?}",
system_message
);
assert!(
system_prompt.contains("## Fixing Diagnostics"),
"unexpected system message: {:?}",
system_message
);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -127,7 +170,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -175,7 +218,104 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_tool_authorization(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
thread.send(model.clone(), "abc", cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_1".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_2".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
// Approve the first
tool_call_auth_1
.response
.send(tool_call_auth_1.options[1].id.clone())
.unwrap();
cx.run_until_parked();
// Reject the second
tool_call_auth_2
.response
.send(tool_call_auth_1.options[2].id.clone())
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
tool_name: tool_call_auth_1.tool_call.title.into(),
is_error: false,
content: "Allowed".into(),
output: None
}),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
tool_name: tool_call_auth_2.tool_call.title.into(),
is_error: true,
content: "Permission to run tool denied by user".into(),
output: None
})
]
);
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> ToolCallAuthorization {
loop {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
.map(|o| o.kind)
.collect::<Vec<_>>();
assert_eq!(
permission_kinds,
vec![
acp::PermissionOptionKind::AllowAlways,
acp::PermissionOptionKind::AllowOnce,
acp::PermissionOptionKind::RejectOnce,
]
);
return tool_call_authorization;
}
}
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -214,7 +354,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -281,12 +421,10 @@ async fn test_cancellation(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
let fake_model = Arc::new(FakeLanguageModel::default());
let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.send(fake_model.clone(), "Hello", cx)
});
let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@ -343,8 +481,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
});
cx.executor().forbid_parking();
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
@ -366,12 +512,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
// Create a thread using new_thread
let cwd = Path::new("/test");
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
@ -457,12 +598,13 @@ fn stop_events(
struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
}
enum TestModel {
Sonnet4,
Sonnet4Thinking,
Fake(Arc<FakeLanguageModel>),
Fake,
}
impl TestModel {
@ -470,7 +612,7 @@ impl TestModel {
match self {
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
TestModel::Fake(fake_model) => fake_model.id(),
TestModel::Fake => unreachable!(),
}
}
}
@ -499,8 +641,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
if let TestModel::Fake(model) = model {
Task::ready(model as Arc<_>)
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
let model_id = model.id();
let models = LanguageModelRegistry::read_global(cx);
@ -520,9 +662,22 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
})
.await;
let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
ThreadTest { model, thread }
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_| {
Thread::new(
project,
project_context.clone(),
action_log,
templates,
model.clone(),
)
});
ThreadTest {
model,
thread,
project_context,
}
}
#[cfg(test)]