Add system prompt and tool permission to agent2 (#35781)

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-07 15:40:12 +02:00 committed by GitHub
parent 4dbd24d75f
commit 03876d076e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1111 additions and 304 deletions

View file

@ -1,9 +1,13 @@
use crate::{prompts::BasePrompt, templates::Templates};
use crate::templates::{SystemPromptTemplate, Template, Templates};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::ActionLog;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use collections::HashMap;
use futures::{channel::mpsc, stream::FuturesUnordered};
use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
};
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
@ -13,10 +17,11 @@ use language_model::{
};
use log;
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::Deserialize;
use smol::stream::StreamExt;
use std::{collections::BTreeMap, fmt::Write, sync::Arc};
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
use util::{markdown::MarkdownCodeBlock, ResultExt};
#[derive(Debug, Clone)]
@ -97,11 +102,15 @@ pub enum AgentResponseEvent {
Thinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
Stop(acp::StopReason),
}
pub trait Prompt {
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
#[derive(Debug)]
pub struct ToolCallAuthorization {
pub tool_call: acp::ToolCall,
pub options: Vec<acp::PermissionOption>,
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
pub struct Thread {
@ -112,28 +121,31 @@ pub struct Thread {
/// we run tools, report their results.
running_turn: Option<Task<()>>,
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
system_prompts: Vec<Arc<dyn Prompt>>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
// action_log: Entity<ActionLog>,
_action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
project: Entity<Project>,
_project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
) -> Self {
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
system_prompts: vec![Arc::new(BasePrompt::new(project))],
running_turn: None,
pending_tool_uses: HashMap::default(),
tools: BTreeMap::default(),
project_context,
templates,
selected_model: default_model,
_action_log: action_log,
}
}
@ -188,6 +200,7 @@ impl Thread {
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let event_stream = AgentResponseEventStream(events_tx);
let user_message_ix = self.messages.len();
self.messages.push(AgentMessage {
@ -222,12 +235,7 @@ impl Thread {
while let Some(event) = events.next().await {
match event {
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
if let Some(reason) = to_acp_stop_reason(reason) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Stop(reason)))
.ok();
}
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
thread.update(cx, |thread, _cx| {
thread.messages.truncate(user_message_ix);
@ -240,14 +248,16 @@ impl Thread {
thread
.update(cx, |thread, cx| {
tool_uses.extend(thread.handle_streamed_completion_event(
event, &events_tx, cx,
event,
&event_stream,
cx,
));
})
.ok();
}
Err(error) => {
log::error!("Error in completion stream: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
event_stream.send_error(error);
break;
}
}
@ -266,11 +276,7 @@ impl Thread {
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
to_acp_tool_call_update(&tool_result),
)))
.ok();
event_stream.send_tool_call_result(&tool_result);
thread
.update(cx, |thread, _cx| {
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
@ -291,7 +297,7 @@ impl Thread {
if let Err(error) = turn_result {
log::error!("Turn execution failed: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
event_stream.send_error(error);
} else {
log::info!("Turn execution completed successfully");
}
@ -299,24 +305,20 @@ impl Thread {
events_rx
}
pub fn build_system_message(&self, cx: &App) -> Option<AgentMessage> {
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let mut system_message = AgentMessage {
role: Role::System,
content: Vec::new(),
};
for prompt in &self.system_prompts {
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
system_message
.content
.push(MessageContent::Text(rendered_prompt));
}
let prompt = SystemPromptTemplate {
project: &self.project_context.borrow(),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
.context("failed to build system prompt")
.expect("Invalid template");
log::debug!("System message built");
AgentMessage {
role: Role::System,
content: vec![prompt.into()],
}
let result = (!system_message.content.is_empty()).then_some(system_message);
log::debug!("System message built: {}", result.is_some());
result
}
/// A helper method that's called on every streamed completion event.
@ -325,7 +327,7 @@ impl Thread {
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event);
@ -338,13 +340,13 @@ impl Thread {
content: Vec::new(),
});
}
Text(new_text) => self.handle_text_event(new_text, events_tx, cx),
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
Thinking { text, signature } => {
self.handle_thinking_event(text, signature, events_tx, cx)
self.handle_thinking_event(text, signature, event_stream, cx)
}
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
ToolUse(tool_use) => {
return self.handle_tool_use_event(tool_use, events_tx, cx);
return self.handle_tool_use_event(tool_use, event_stream, cx);
}
ToolUseJsonParseError {
id,
@ -369,12 +371,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone())))
.ok();
events_stream.send_text(&new_text);
let last_message = self.last_assistant_message();
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
@ -390,12 +390,10 @@ impl Thread {
&mut self,
new_text: String,
new_signature: Option<String>,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone())))
.ok();
event_stream.send_thinking(&new_text);
let last_message = self.last_assistant_message();
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
@ -423,7 +421,7 @@ impl Thread {
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
@ -446,32 +444,18 @@ impl Thread {
}
});
if push_new_tool_use {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
})))
.ok();
event_stream.send_tool_call(&tool_use);
last_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
} else {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use.id.to_string().into()),
fields: acp::ToolCallUpdateFields {
raw_input: Some(tool_use.input.clone()),
..Default::default()
},
},
)))
.ok();
event_stream.send_tool_call_update(
&tool_use.id,
acp::ToolCallUpdateFields {
raw_input: Some(tool_use.input.clone()),
..Default::default()
},
);
}
if !tool_use.is_input_complete {
@ -479,22 +463,10 @@ impl Thread {
}
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use.id.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
},
},
)))
.ok();
let pending_tool_result = tool.clone().run(tool_use.input, cx);
let tool_result =
self.run_tool(tool.clone(), tool_use.clone(), event_stream.clone(), cx);
Some(cx.foreground_executor().spawn(async move {
match pending_tool_result.await {
match tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
@ -523,6 +495,30 @@ impl Thread {
}
}
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
let needs_authorization = tool.needs_authorization(tool_use.input.clone(), cx);
cx.spawn(async move |_this, cx| {
if needs_authorization? {
event_stream.authorize_tool_call(&tool_use).await?;
}
event_stream.send_tool_call_update(
&tool_use.id,
acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
},
);
cx.update(|cx| tool.run(tool_use.input, cx))?.await
})
}
fn handle_tool_use_json_parse_error_event(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -575,7 +571,7 @@ impl Thread {
log::debug!("Completion intent: {:?}", completion_intent);
log::debug!("Completion mode: {:?}", self.completion_mode);
let messages = self.build_request_messages(cx);
let messages = self.build_request_messages();
log::info!("Request will include {} messages", messages.len());
let tools: Vec<LanguageModelRequestTool> = self
@ -613,14 +609,13 @@ impl Thread {
request
}
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
log::trace!(
"Building request messages from {} thread messages",
self.messages.len()
);
let messages = self
.build_system_message(cx)
let messages = Some(self.build_system_message())
.iter()
.chain(self.messages.iter())
.map(|message| {
@ -674,6 +669,10 @@ where
schemars::schema_for!(Self::Input)
}
/// Returns true if the tool needs the users's authorization
/// before running.
fn needs_authorization(&self, input: Self::Input, cx: &App) -> bool;
/// Runs the tool with the provided input.
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
@ -688,6 +687,7 @@ pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool>;
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
@ -707,6 +707,14 @@ where
Ok(serde_json::to_value(self.0.input_schema(format))?)
}
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => Ok(self.0.needs_authorization(input, cx)),
Err(error) => Err(anyhow!(error)),
}
}
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
@ -716,39 +724,153 @@ where
}
}
fn to_acp_stop_reason(reason: StopReason) -> Option<acp::StopReason> {
match reason {
StopReason::EndTurn => Some(acp::StopReason::EndTurn),
StopReason::MaxTokens => Some(acp::StopReason::MaxTokens),
StopReason::Refusal => Some(acp::StopReason::Refusal),
StopReason::ToolUse => None,
}
}
#[derive(Clone)]
struct AgentResponseEventStream(
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate {
let status = if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
};
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
acp::ToolCallContent::Content {
content: acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data: source.to_string(),
mime_type: ImageFormat::Png.mime_type().to_string(),
}),
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
.ok();
}
fn send_thinking(&self, text: &str) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
.ok();
}
fn authorize_tool_call(
&self,
tool_use: &LanguageModelToolUse,
) -> impl use<> + Future<Output = Result<()>> {
let (response_tx, response_rx) = oneshot::channel();
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
ToolCallAuthorization {
tool_call: acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
},
options: vec![
acp::PermissionOption {
id: acp::PermissionOptionId("always_allow".into()),
name: "Always Allow".into(),
kind: acp::PermissionOptionKind::AllowAlways,
},
acp::PermissionOption {
id: acp::PermissionOptionId("allow".into()),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: acp::PermissionOptionId("deny".into()),
name: "Deny".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
response: response_tx,
},
)))
.ok();
async move {
match response_rx.await?.0.as_ref() {
"allow" | "always_allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
}
}
};
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(status),
content: Some(vec![content]),
..Default::default()
},
}
fn send_tool_call(&self, tool_use: &LanguageModelToolUse) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
})))
.ok();
}
fn send_tool_call_update(
&self,
tool_use_id: &LanguageModelToolUseId,
fields: acp::ToolCallUpdateFields,
) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.to_string().into()),
fields,
},
)))
.ok();
}
fn send_tool_call_result(&self, tool_result: &LanguageModelToolResult) {
let status = if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
};
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
acp::ToolCallContent::Content {
content: acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data: source.to_string(),
mime_type: ImageFormat::Png.mime_type().to_string(),
}),
}
}
};
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(status),
content: Some(vec![content]),
..Default::default()
},
},
)))
.ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
.ok();
}
StopReason::MaxTokens => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
.ok();
}
StopReason::Refusal => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
.ok();
}
StopReason::ToolUse => {}
}
}
fn send_error(&self, error: LanguageModelCompletionError) {
self.0.unbounded_send(Err(error)).ok();
}
}