agent2: Initial infra for checkpoints and message editing (#36120)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f4b0332f78
commit
23cd5b59b2
17 changed files with 1374 additions and 582 deletions
|
@ -1,8 +1,9 @@
|
|||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||
use crate::{
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
|
||||
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
||||
WebSearchTool,
|
||||
};
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
|
@ -637,9 +638,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
fn prompt(
|
||||
&self,
|
||||
id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let id = id.expect("UserMessageId is required");
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
|
@ -660,13 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
let message: Vec<MessageContent> = params
|
||||
let content: Vec<UserMessageContent> = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>();
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
log::debug!("Message content: {:?}", message);
|
||||
log::info!("Converted prompt to message: {} chars", content.len());
|
||||
log::debug!("Message id: {:?}", id);
|
||||
log::debug!("Message content: {:?}", content);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
|
@ -674,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
|
||||
let mut response_stream =
|
||||
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
|
@ -768,6 +773,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
session_id: &agent_client_protocol::SessionId,
|
||||
cx: &mut App,
|
||||
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
|
||||
self.0.update(cx, |agent, _cx| {
|
||||
agent
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||
|
||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use super::*;
|
||||
use crate::MessageContent;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_settings::AgentProfileId;
|
||||
|
@ -38,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
|
|||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send("Testing: Reply with 'Hello'", cx)
|
||||
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## Assistant
|
||||
|
||||
Hello
|
||||
"}
|
||||
)
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
}
|
||||
|
@ -59,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
|||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
indoc! {"
|
||||
UserMessageId::new(),
|
||||
[indoc! {"
|
||||
Testing:
|
||||
|
||||
Generate a thinking step where you just think the word 'Think',
|
||||
and have your final answer be 'Hello'
|
||||
"},
|
||||
"}],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
@ -72,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
|||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().to_markdown(),
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## assistant
|
||||
## Assistant
|
||||
|
||||
<think>Think</think>
|
||||
Hello
|
||||
"}
|
||||
|
@ -95,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
|||
|
||||
project_context.borrow_mut().shell = "test-shell".into();
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread.update(cx, |thread, cx| thread.send("abc", cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(
|
||||
|
@ -132,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
UserMessageId::new(),
|
||||
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
@ -146,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
UserMessageId::new(),
|
||||
[
|
||||
"Now call the delay tool with 200ms.",
|
||||
"When the timer goes off, then you echo the output of the tool.",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
@ -156,13 +168,14 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|||
thread.update(cx, |thread, _cx| {
|
||||
assert!(
|
||||
thread
|
||||
.messages()
|
||||
.last()
|
||||
.last_message()
|
||||
.unwrap()
|
||||
.as_agent_message()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
if let AgentMessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
|
@ -182,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
|||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(WordListTool);
|
||||
thread.send("Test the word_list tool.", cx)
|
||||
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
|
@ -190,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
|||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
let message = thread.last_message().unwrap();
|
||||
let agent_message = message.as_agent_message().unwrap();
|
||||
let last_content = agent_message.content.last().unwrap();
|
||||
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_call.status == acp::ToolCallStatus::Pending {
|
||||
if !last_tool_use.is_input_complete
|
||||
|
@ -229,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
|||
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.send("abc", cx)
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
|
@ -357,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
|||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
|
@ -449,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
UserMessageId::new(),
|
||||
[
|
||||
"Call the delay tool twice in the same message.",
|
||||
"Once with 100ms. Once with 300ms.",
|
||||
"When both timers are complete, describe the outputs.",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
@ -460,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|||
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
let text = last_message
|
||||
let last_message = thread.last_message().unwrap();
|
||||
let agent_message = last_message.as_agent_message().unwrap();
|
||||
let text = agent_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
if let AgentMessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
|
@ -521,7 +544,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
|||
// Test that test-1 profile (default) has echo and delay tools
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-1".into()));
|
||||
thread.send("test", cx);
|
||||
thread.send(UserMessageId::new(), ["test"], cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
|
@ -539,7 +562,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
|||
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-2".into()));
|
||||
thread.send("test2", cx)
|
||||
thread.send(UserMessageId::new(), ["test2"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
|
@ -562,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
thread.add_tool(InfiniteTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
"Call the echo tool and then call the infinite tool, then explain their output",
|
||||
UserMessageId::new(),
|
||||
["Call the echo tool, then call the infinite tool, then explain their output"],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -607,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
|||
// Ensure we can still send a new message after cancellation.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send("Testing: reply with 'Hello' then stop.", cx)
|
||||
thread.send(
|
||||
UserMessageId::new(),
|
||||
["Testing: reply with 'Hello' then stop."],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let message = thread.last_message().unwrap();
|
||||
let agent_message = message.as_agent_message().unwrap();
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
agent_message.content,
|
||||
vec![AgentMessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
|
@ -625,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
|||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
## User
|
||||
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
|
@ -643,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
|||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
## User
|
||||
|
||||
Hello
|
||||
## assistant
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
|
@ -661,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_truncate(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let message_id = UserMessageId::new();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(message_id.clone(), ["Hello"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.to_markdown(), "");
|
||||
});
|
||||
|
||||
// Ensure we can still send a new message after truncation.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hi"], cx)
|
||||
});
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hi
|
||||
"}
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hi
|
||||
|
||||
## Assistant
|
||||
|
||||
Ahoy!
|
||||
"}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.update(settings::init);
|
||||
|
@ -774,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
let result = cx
|
||||
.update(|cx| {
|
||||
connection.prompt(
|
||||
Some(acp_thread::UserMessageId::new()),
|
||||
acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec!["ghi".into()],
|
||||
|
@ -796,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Think"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate streaming partial input.
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||
use acp_thread::MentionUri;
|
||||
use acp_thread::{MentionUri, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, AgentSettings};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use collections::HashMap;
|
||||
use collections::IndexMap;
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
|
@ -19,7 +19,6 @@ use language_model::{
|
|||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
};
|
||||
use log;
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
|
@ -30,49 +29,199 @@ use std::fmt::Write;
|
|||
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
|
||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Message {
|
||||
User(UserMessage),
|
||||
Agent(AgentMessage),
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn as_agent_message(&self) -> Option<&AgentMessage> {
|
||||
match self {
|
||||
Message::Agent(agent_message) => Some(agent_message),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
match self {
|
||||
Message::User(message) => message.to_markdown(),
|
||||
Message::Agent(message) => message.to_markdown(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum MessageContent {
|
||||
pub struct UserMessage {
|
||||
pub id: UserMessageId,
|
||||
pub content: Vec<UserMessageContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum UserMessageContent {
|
||||
Text(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
Mention {
|
||||
uri: MentionUri,
|
||||
content: String,
|
||||
},
|
||||
RedactedThinking(String),
|
||||
Mention { uri: MentionUri, content: String },
|
||||
Image(LanguageModelImage),
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolResult(LanguageModelToolResult),
|
||||
}
|
||||
|
||||
impl UserMessage {
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut markdown = String::from("## User\n\n");
|
||||
|
||||
for content in &self.content {
|
||||
match content {
|
||||
UserMessageContent::Text(text) => {
|
||||
markdown.push_str(text);
|
||||
markdown.push('\n');
|
||||
}
|
||||
UserMessageContent::Image(_) => {
|
||||
markdown.push_str("<image />\n");
|
||||
}
|
||||
UserMessageContent::Mention { uri, content } => {
|
||||
if !content.is_empty() {
|
||||
markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content));
|
||||
} else {
|
||||
markdown.push_str(&format!("{}\n", uri.to_link()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
markdown
|
||||
}
|
||||
|
||||
fn to_request(&self) -> LanguageModelRequestMessage {
|
||||
let mut message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::with_capacity(self.content.len()),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
const OPEN_CONTEXT: &str = "<context>\n\
|
||||
The following items were attached by the user. \
|
||||
They are up-to-date and don't need to be re-read.\n\n";
|
||||
|
||||
const OPEN_FILES_TAG: &str = "<files>";
|
||||
const OPEN_SYMBOLS_TAG: &str = "<symbols>";
|
||||
const OPEN_THREADS_TAG: &str = "<threads>";
|
||||
const OPEN_RULES_TAG: &str =
|
||||
"<rules>\nThe user has specified the following rules that should be applied:\n";
|
||||
|
||||
let mut file_context = OPEN_FILES_TAG.to_string();
|
||||
let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
|
||||
let mut thread_context = OPEN_THREADS_TAG.to_string();
|
||||
let mut rules_context = OPEN_RULES_TAG.to_string();
|
||||
|
||||
for chunk in &self.content {
|
||||
let chunk = match chunk {
|
||||
UserMessageContent::Text(text) => {
|
||||
language_model::MessageContent::Text(text.clone())
|
||||
}
|
||||
UserMessageContent::Image(value) => {
|
||||
language_model::MessageContent::Image(value.clone())
|
||||
}
|
||||
UserMessageContent::Mention { uri, content } => {
|
||||
match uri {
|
||||
MentionUri::File(path) | MentionUri::Symbol(path, _) => {
|
||||
write!(
|
||||
&mut symbol_context,
|
||||
"\n{}",
|
||||
MarkdownCodeBlock {
|
||||
tag: &codeblock_tag(&path),
|
||||
text: &content.to_string(),
|
||||
}
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
MentionUri::Thread(_session_id) => {
|
||||
write!(&mut thread_context, "\n{}\n", content).ok();
|
||||
}
|
||||
MentionUri::Rule(_user_prompt_id) => {
|
||||
write!(
|
||||
&mut rules_context,
|
||||
"\n{}",
|
||||
MarkdownCodeBlock {
|
||||
tag: "",
|
||||
text: &content
|
||||
}
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
language_model::MessageContent::Text(uri.to_link())
|
||||
}
|
||||
};
|
||||
|
||||
message.content.push(chunk);
|
||||
}
|
||||
|
||||
let len_before_context = message.content.len();
|
||||
|
||||
if file_context.len() > OPEN_FILES_TAG.len() {
|
||||
file_context.push_str("</files>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(file_context));
|
||||
}
|
||||
|
||||
if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
|
||||
symbol_context.push_str("</symbols>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(symbol_context));
|
||||
}
|
||||
|
||||
if thread_context.len() > OPEN_THREADS_TAG.len() {
|
||||
thread_context.push_str("</threads>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(thread_context));
|
||||
}
|
||||
|
||||
if rules_context.len() > OPEN_RULES_TAG.len() {
|
||||
rules_context.push_str("</user_rules>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(rules_context));
|
||||
}
|
||||
|
||||
if message.content.len() > len_before_context {
|
||||
message.content.insert(
|
||||
len_before_context,
|
||||
language_model::MessageContent::Text(OPEN_CONTEXT.into()),
|
||||
);
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text("</context>".into()));
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentMessage {
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut markdown = format!("## {}\n", self.role);
|
||||
let mut markdown = String::from("## Assistant\n\n");
|
||||
|
||||
for content in &self.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
AgentMessageContent::Text(text) => {
|
||||
markdown.push_str(text);
|
||||
markdown.push('\n');
|
||||
}
|
||||
MessageContent::Thinking { text, .. } => {
|
||||
AgentMessageContent::Thinking { text, .. } => {
|
||||
markdown.push_str("<think>");
|
||||
markdown.push_str(text);
|
||||
markdown.push_str("</think>\n");
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
|
||||
MessageContent::Image(_) => {
|
||||
AgentMessageContent::RedactedThinking(_) => {
|
||||
markdown.push_str("<redacted_thinking />\n")
|
||||
}
|
||||
AgentMessageContent::Image(_) => {
|
||||
markdown.push_str("<image />\n");
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
AgentMessageContent::ToolUse(tool_use) => {
|
||||
markdown.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
tool_use.name, tool_use.id
|
||||
|
@ -85,41 +234,106 @@ impl AgentMessage {
|
|||
}
|
||||
));
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
markdown.push_str(&format!(
|
||||
"**Tool Result**: {} (ID: {})\n\n",
|
||||
tool_result.tool_name, tool_result.tool_use_id
|
||||
));
|
||||
if tool_result.is_error {
|
||||
markdown.push_str("**ERROR:**\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
writeln!(markdown, "{text}\n").ok();
|
||||
}
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
writeln!(markdown, "<image />\n").ok();
|
||||
}
|
||||
}
|
||||
for tool_result in self.tool_results.values() {
|
||||
markdown.push_str(&format!(
|
||||
"**Tool Result**: {} (ID: {})\n\n",
|
||||
tool_result.tool_name, tool_result.tool_use_id
|
||||
));
|
||||
if tool_result.is_error {
|
||||
markdown.push_str("**ERROR:**\n");
|
||||
}
|
||||
|
||||
if let Some(output) = tool_result.output.as_ref() {
|
||||
writeln!(
|
||||
markdown,
|
||||
"**Debug Output**:\n\n```json\n{}\n```\n",
|
||||
serde_json::to_string_pretty(output).unwrap()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
writeln!(markdown, "{text}\n").ok();
|
||||
}
|
||||
MessageContent::Mention { uri, .. } => {
|
||||
write!(markdown, "{}", uri.to_link()).ok();
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
writeln!(markdown, "<image />\n").ok();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(output) = tool_result.output.as_ref() {
|
||||
writeln!(
|
||||
markdown,
|
||||
"**Debug Output**:\n\n```json\n{}\n```\n",
|
||||
serde_json::to_string_pretty(output).unwrap()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
markdown
|
||||
}
|
||||
|
||||
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
let mut content = Vec::with_capacity(self.content.len());
|
||||
for chunk in &self.content {
|
||||
let chunk = match chunk {
|
||||
AgentMessageContent::Text(text) => {
|
||||
language_model::MessageContent::Text(text.clone())
|
||||
}
|
||||
AgentMessageContent::Thinking { text, signature } => {
|
||||
language_model::MessageContent::Thinking {
|
||||
text: text.clone(),
|
||||
signature: signature.clone(),
|
||||
}
|
||||
}
|
||||
AgentMessageContent::RedactedThinking(value) => {
|
||||
language_model::MessageContent::RedactedThinking(value.clone())
|
||||
}
|
||||
AgentMessageContent::ToolUse(value) => {
|
||||
language_model::MessageContent::ToolUse(value.clone())
|
||||
}
|
||||
AgentMessageContent::Image(value) => {
|
||||
language_model::MessageContent::Image(value.clone())
|
||||
}
|
||||
};
|
||||
content.push(chunk);
|
||||
}
|
||||
|
||||
let mut messages = vec![LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
cache: false,
|
||||
}];
|
||||
|
||||
if !self.tool_results.is_empty() {
|
||||
let mut tool_results = Vec::with_capacity(self.tool_results.len());
|
||||
for tool_result in self.tool_results.values() {
|
||||
tool_results.push(language_model::MessageContent::ToolResult(
|
||||
tool_result.clone(),
|
||||
));
|
||||
}
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: tool_results,
|
||||
cache: false,
|
||||
});
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentMessage {
|
||||
pub content: Vec<AgentMessageContent>,
|
||||
pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AgentMessageContent {
|
||||
Text(String),
|
||||
Thinking {
|
||||
text: String,
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking(String),
|
||||
Image(LanguageModelImage),
|
||||
ToolUse(LanguageModelToolUse),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -140,13 +354,13 @@ pub struct ToolCallAuthorization {
|
|||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
messages: Vec<Message>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
||||
pending_agent_message: Option<AgentMessage>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
|
@ -172,7 +386,7 @@ impl Thread {
|
|||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
pending_tool_uses: HashMap::default(),
|
||||
pending_agent_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
|
@ -196,8 +410,13 @@ impl Thread {
|
|||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[AgentMessage] {
|
||||
&self.messages
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn last_message(&self) -> Option<Message> {
|
||||
if let Some(message) = self.pending_agent_message.clone() {
|
||||
Some(Message::Agent(message))
|
||||
} else {
|
||||
self.messages.last().cloned()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
|
@ -213,35 +432,36 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
// TODO: do we need to emit a stop::cancel for ACP?
|
||||
self.running_turn.take();
|
||||
self.flush_pending_agent_message();
|
||||
}
|
||||
|
||||
let tool_results = self
|
||||
.pending_tool_uses
|
||||
.drain()
|
||||
.map(|(tool_use_id, tool_use)| {
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id,
|
||||
tool_name: tool_use.name.clone(),
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
|
||||
output: None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
self.last_user_message().content.extend(tool_results);
|
||||
pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
|
||||
self.cancel();
|
||||
let Some(position) = self.messages.iter().position(
|
||||
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
|
||||
) else {
|
||||
return Err(anyhow!("Message not found"));
|
||||
};
|
||||
self.messages.truncate(position);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
pub fn send<T>(
|
||||
&mut self,
|
||||
content: impl Into<UserMessage>,
|
||||
message_id: UserMessageId,
|
||||
content: impl IntoIterator<Item = T>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||
let content = content.into().0;
|
||||
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
|
||||
where
|
||||
T: Into<UserMessageContent>,
|
||||
{
|
||||
let model = self.selected_model.clone();
|
||||
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
||||
log::info!("Thread::send called with model: {:?}", model.name());
|
||||
log::debug!("Thread::send content: {:?}", content);
|
||||
|
||||
|
@ -251,10 +471,10 @@ impl Thread {
|
|||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
|
||||
let user_message_ix = self.messages.len();
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
self.messages.push(Message::User(UserMessage {
|
||||
id: message_id,
|
||||
content,
|
||||
});
|
||||
}));
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
|
@ -270,15 +490,11 @@ impl Thread {
|
|||
thread.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// println!(
|
||||
// "request: {}",
|
||||
// serde_json::to_string_pretty(&request).unwrap()
|
||||
// );
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
|
@ -286,6 +502,7 @@ impl Thread {
|
|||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.pending_agent_message = None;
|
||||
thread.messages.truncate(user_message_ix);
|
||||
})?;
|
||||
break 'outer;
|
||||
|
@ -338,15 +555,16 @@ impl Thread {
|
|||
);
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
||||
thread
|
||||
.last_user_message()
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result));
|
||||
.pending_agent_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;
|
||||
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
|
||||
|
@ -354,6 +572,10 @@ impl Thread {
|
|||
}
|
||||
.await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.flush_pending_agent_message())
|
||||
.ok();
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
event_stream.send_error(error);
|
||||
|
@ -364,7 +586,7 @@ impl Thread {
|
|||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> AgentMessage {
|
||||
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
project: &self.project_context.borrow(),
|
||||
|
@ -374,9 +596,10 @@ impl Thread {
|
|||
.context("failed to build system prompt")
|
||||
.expect("Invalid template");
|
||||
log::debug!("System message built");
|
||||
AgentMessage {
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![prompt.as_str().into()],
|
||||
content: vec![prompt.into()],
|
||||
cache: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -394,10 +617,7 @@ impl Thread {
|
|||
|
||||
match event {
|
||||
StartMessage { .. } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
self.messages.push(Message::Agent(AgentMessage::default()));
|
||||
}
|
||||
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
|
||||
Thinking { text, signature } => {
|
||||
|
@ -435,11 +655,13 @@ impl Thread {
|
|||
) {
|
||||
events_stream.send_text(&new_text);
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
let last_message = self.pending_agent_message();
|
||||
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
last_message
|
||||
.content
|
||||
.push(AgentMessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
|
@ -454,13 +676,14 @@ impl Thread {
|
|||
) {
|
||||
event_stream.send_thinking(&new_text);
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
|
||||
let last_message = self.pending_agent_message();
|
||||
if let Some(AgentMessageContent::Thinking { text, signature }) =
|
||||
last_message.content.last_mut()
|
||||
{
|
||||
text.push_str(&new_text);
|
||||
*signature = new_signature.or(signature.take());
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Thinking {
|
||||
last_message.content.push(AgentMessageContent::Thinking {
|
||||
text: new_text,
|
||||
signature: new_signature,
|
||||
});
|
||||
|
@ -470,10 +693,10 @@ impl Thread {
|
|||
}
|
||||
|
||||
fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.last_assistant_message();
|
||||
let last_message = self.pending_agent_message();
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::RedactedThinking(data));
|
||||
.push(AgentMessageContent::RedactedThinking(data));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
|
@ -486,14 +709,17 @@ impl Thread {
|
|||
cx.notify();
|
||||
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||
|
||||
self.pending_tool_uses
|
||||
.insert(tool_use.id.clone(), tool_use.clone());
|
||||
let last_message = self.last_assistant_message();
|
||||
let mut title = SharedString::from(&tool_use.name);
|
||||
let mut kind = acp::ToolKind::Other;
|
||||
if let Some(tool) = tool.as_ref() {
|
||||
title = tool.initial_title(tool_use.input.clone());
|
||||
kind = tool.kind();
|
||||
}
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let last_message = self.pending_agent_message();
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
||||
if let AgentMessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
*last_tool_use = tool_use.clone();
|
||||
false
|
||||
|
@ -505,18 +731,11 @@ impl Thread {
|
|||
}
|
||||
});
|
||||
|
||||
let mut title = SharedString::from(&tool_use.name);
|
||||
let mut kind = acp::ToolKind::Other;
|
||||
if let Some(tool) = tool.as_ref() {
|
||||
title = tool.initial_title(tool_use.input.clone());
|
||||
kind = tool.kind();
|
||||
}
|
||||
|
||||
if push_new_tool_use {
|
||||
event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
.push(AgentMessageContent::ToolUse(tool_use.clone()));
|
||||
} else {
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_use.id,
|
||||
|
@ -601,30 +820,37 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
||||
if self
|
||||
.messages
|
||||
.last()
|
||||
.map_or(true, |m| m.role != Role::Assistant)
|
||||
{
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
fn pending_agent_message(&mut self) -> &mut AgentMessage {
|
||||
self.pending_agent_message.get_or_insert_default()
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the user and returns a mutable reference.
|
||||
fn last_user_message(&mut self) -> &mut AgentMessage {
|
||||
if self.messages.last().map_or(true, |m| m.role != Role::User) {
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
});
|
||||
fn flush_pending_agent_message(&mut self) {
|
||||
let Some(mut message) = self.pending_agent_message.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
for content in &message.content {
|
||||
let AgentMessageContent::ToolUse(tool_use) = content else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !message.tool_results.contains_key(&tool_use.id) {
|
||||
message.tool_results.insert(
|
||||
tool_use.id.clone(),
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_use.name.clone(),
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(
|
||||
"Tool canceled by user".into(),
|
||||
),
|
||||
output: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
|
||||
self.messages.push(Message::Agent(message));
|
||||
}
|
||||
|
||||
pub(crate) fn build_completion_request(
|
||||
|
@ -712,49 +938,39 @@ impl Thread {
|
|||
"Building request messages from {} thread messages",
|
||||
self.messages.len()
|
||||
);
|
||||
let mut messages = vec![self.build_system_message()];
|
||||
for message in &self.messages {
|
||||
match message {
|
||||
Message::User(message) => messages.push(message.to_request()),
|
||||
Message::Agent(message) => messages.extend(message.to_request()),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = self.pending_agent_message.as_ref() {
|
||||
messages.extend(message.to_request());
|
||||
}
|
||||
|
||||
let messages = Some(self.build_system_message())
|
||||
.iter()
|
||||
.chain(self.messages.iter())
|
||||
.map(|message| {
|
||||
log::trace!(
|
||||
" - {} message with {} content items",
|
||||
match message.role {
|
||||
Role::System => "System",
|
||||
Role::User => "User",
|
||||
Role::Assistant => "Assistant",
|
||||
},
|
||||
message.content.len()
|
||||
);
|
||||
message.to_request()
|
||||
})
|
||||
.collect();
|
||||
messages
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut markdown = String::new();
|
||||
for message in &self.messages {
|
||||
for (ix, message) in self.messages.iter().enumerate() {
|
||||
if ix > 0 {
|
||||
markdown.push('\n');
|
||||
}
|
||||
markdown.push_str(&message.to_markdown());
|
||||
}
|
||||
|
||||
if let Some(message) = self.pending_agent_message.as_ref() {
|
||||
markdown.push('\n');
|
||||
markdown.push_str(&message.to_markdown());
|
||||
}
|
||||
|
||||
markdown
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UserMessage(Vec<MessageContent>);
|
||||
|
||||
impl From<Vec<MessageContent>> for UserMessage {
|
||||
fn from(content: Vec<MessageContent>) -> Self {
|
||||
UserMessage(content)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Into<MessageContent>> From<T> for UserMessage {
|
||||
fn from(content: T) -> Self {
|
||||
UserMessage(vec![content.into()])
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
|
@ -1151,130 +1367,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver {
|
|||
}
|
||||
}
|
||||
|
||||
impl AgentMessage {
|
||||
fn to_request(&self) -> language_model::LanguageModelRequestMessage {
|
||||
let mut message = LanguageModelRequestMessage {
|
||||
role: self.role,
|
||||
content: Vec::with_capacity(self.content.len()),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
const OPEN_CONTEXT: &str = "<context>\n\
|
||||
The following items were attached by the user. \
|
||||
They are up-to-date and don't need to be re-read.\n\n";
|
||||
|
||||
const OPEN_FILES_TAG: &str = "<files>";
|
||||
const OPEN_SYMBOLS_TAG: &str = "<symbols>";
|
||||
const OPEN_THREADS_TAG: &str = "<threads>";
|
||||
const OPEN_RULES_TAG: &str =
|
||||
"<rules>\nThe user has specified the following rules that should be applied:\n";
|
||||
|
||||
let mut file_context = OPEN_FILES_TAG.to_string();
|
||||
let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
|
||||
let mut thread_context = OPEN_THREADS_TAG.to_string();
|
||||
let mut rules_context = OPEN_RULES_TAG.to_string();
|
||||
|
||||
for chunk in &self.content {
|
||||
let chunk = match chunk {
|
||||
MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
|
||||
MessageContent::Thinking { text, signature } => {
|
||||
language_model::MessageContent::Thinking {
|
||||
text: text.clone(),
|
||||
signature: signature.clone(),
|
||||
}
|
||||
}
|
||||
MessageContent::RedactedThinking(value) => {
|
||||
language_model::MessageContent::RedactedThinking(value.clone())
|
||||
}
|
||||
MessageContent::ToolUse(value) => {
|
||||
language_model::MessageContent::ToolUse(value.clone())
|
||||
}
|
||||
MessageContent::ToolResult(value) => {
|
||||
language_model::MessageContent::ToolResult(value.clone())
|
||||
}
|
||||
MessageContent::Image(value) => {
|
||||
language_model::MessageContent::Image(value.clone())
|
||||
}
|
||||
MessageContent::Mention { uri, content } => {
|
||||
match uri {
|
||||
MentionUri::File(path) | MentionUri::Symbol(path, _) => {
|
||||
write!(
|
||||
&mut symbol_context,
|
||||
"\n{}",
|
||||
MarkdownCodeBlock {
|
||||
tag: &codeblock_tag(&path),
|
||||
text: &content.to_string(),
|
||||
}
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
MentionUri::Thread(_session_id) => {
|
||||
write!(&mut thread_context, "\n{}\n", content).ok();
|
||||
}
|
||||
MentionUri::Rule(_user_prompt_id) => {
|
||||
write!(
|
||||
&mut rules_context,
|
||||
"\n{}",
|
||||
MarkdownCodeBlock {
|
||||
tag: "",
|
||||
text: &content
|
||||
}
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
language_model::MessageContent::Text(uri.to_link())
|
||||
}
|
||||
};
|
||||
|
||||
message.content.push(chunk);
|
||||
}
|
||||
|
||||
let len_before_context = message.content.len();
|
||||
|
||||
if file_context.len() > OPEN_FILES_TAG.len() {
|
||||
file_context.push_str("</files>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(file_context));
|
||||
}
|
||||
|
||||
if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
|
||||
symbol_context.push_str("</symbols>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(symbol_context));
|
||||
}
|
||||
|
||||
if thread_context.len() > OPEN_THREADS_TAG.len() {
|
||||
thread_context.push_str("</threads>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(thread_context));
|
||||
}
|
||||
|
||||
if rules_context.len() > OPEN_RULES_TAG.len() {
|
||||
rules_context.push_str("</user_rules>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(rules_context));
|
||||
}
|
||||
|
||||
if message.content.len() > len_before_context {
|
||||
message.content.insert(
|
||||
len_before_context,
|
||||
language_model::MessageContent::Text(OPEN_CONTEXT.into()),
|
||||
);
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text("</context>".into()));
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
}
|
||||
|
||||
fn codeblock_tag(full_path: &Path) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
|
@ -1287,16 +1379,20 @@ fn codeblock_tag(full_path: &Path) -> String {
|
|||
result
|
||||
}
|
||||
|
||||
impl From<acp::ContentBlock> for MessageContent {
|
||||
impl From<&str> for UserMessageContent {
|
||||
fn from(text: &str) -> Self {
|
||||
Self::Text(text.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<acp::ContentBlock> for UserMessageContent {
|
||||
fn from(value: acp::ContentBlock) -> Self {
|
||||
match value {
|
||||
acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
|
||||
acp::ContentBlock::Image(image_content) => {
|
||||
MessageContent::Image(convert_image(image_content))
|
||||
}
|
||||
acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
|
||||
acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
|
||||
acp::ContentBlock::Audio(_) => {
|
||||
// TODO
|
||||
MessageContent::Text("[audio]".to_string())
|
||||
Self::Text("[audio]".to_string())
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(resource_link) => {
|
||||
match MentionUri::parse(&resource_link.uri) {
|
||||
|
@ -1306,10 +1402,7 @@ impl From<acp::ContentBlock> for MessageContent {
|
|||
},
|
||||
Err(err) => {
|
||||
log::error!("Failed to parse mention link: {}", err);
|
||||
MessageContent::Text(format!(
|
||||
"[{}]({})",
|
||||
resource_link.name, resource_link.uri
|
||||
))
|
||||
Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1322,7 +1415,7 @@ impl From<acp::ContentBlock> for MessageContent {
|
|||
},
|
||||
Err(err) => {
|
||||
log::error!("Failed to parse mention link: {}", err);
|
||||
MessageContent::Text(
|
||||
Self::Text(
|
||||
MarkdownCodeBlock {
|
||||
tag: &resource.uri,
|
||||
text: &resource.text,
|
||||
|
@ -1334,7 +1427,7 @@ impl From<acp::ContentBlock> for MessageContent {
|
|||
}
|
||||
acp::EmbeddedResourceResource::BlobResourceContents(_) => {
|
||||
// TODO
|
||||
MessageContent::Text("[blob]".to_string())
|
||||
Self::Text("[blob]".to_string())
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -1348,9 +1441,3 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
|
|||
size: gpui::Size::new(0.into(), 0.into()),
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for MessageContent {
|
||||
fn from(text: &str) -> Self {
|
||||
MessageContent::Text(text.into())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue