agent: Snapshot context in user message instead of recreating it (#27967)

This makes context essentially work the same way as `read-file`,
increasing the likelihood of cache hits.

Just like with `read-file`, we'll notify the model when the user makes
an edit to one of the tracked files. In the future, we want to send a
diff instead of just a list of files, but that's an orthogonal change.


Release Notes:
- agent: Improved caching of files in context

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Agus Zubiaga 2025-04-03 15:52:28 -03:00 committed by GitHub
parent 0c82541f0a
commit 315f1bf168
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 551 additions and 136 deletions

View file

@ -34,7 +34,7 @@ use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip,
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_store::ContextStore;
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
@ -593,54 +593,14 @@ impl ActiveThread {
}
if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| {
action_log.take_stale_buffers_in_context()
})
});
let context_update_task = if !pending_refresh_buffers.is_empty() {
let refresh_task = refresh_context_store_text(
self.context_store.clone(),
&pending_refresh_buffers,
cx,
);
cx.spawn(async move |this, cx| {
let updated_context_ids = refresh_task.await;
this.update(cx, |this, cx| {
this.context_store.read_with(cx, |context_store, _cx| {
context_store
.context()
.iter()
.filter(|context| {
updated_context_ids.contains(&context.id())
})
.cloned()
.collect()
})
})
})
} else {
Task::ready(anyhow::Ok(Vec::new()))
};
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
cx.spawn(async move |this, cx| {
let updated_context = context_update_task.await?;
this.update(cx, |this, cx| {
this.thread.update(cx, |thread, cx| {
thread.attach_tool_results(updated_context, cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
})
})
.detach();
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
}
}
}

View file

@ -146,11 +146,11 @@ pub struct ContextSymbolId {
pub range: Range<Anchor>,
}
pub fn attach_context_to_message<'a>(
message: &mut LanguageModelRequestMessage,
/// Formats a collection of contexts into a string representation
pub fn format_context_as_string<'a>(
contexts: impl Iterator<Item = &'a AssistantContext>,
cx: &App,
) {
) -> Option<String> {
let mut file_context = Vec::new();
let mut directory_context = Vec::new();
let mut symbol_context = Vec::new();
@ -167,64 +167,78 @@ pub fn attach_context_to_message<'a>(
}
}
let mut context_chunks = Vec::new();
if file_context.is_empty()
&& directory_context.is_empty()
&& symbol_context.is_empty()
&& fetch_context.is_empty()
&& thread_context.is_empty()
{
return None;
}
let mut result = String::new();
result.push_str("\n<context>\n\
The following items were attached by the user. You don't need to use other tools to read them.\n\n");
if !file_context.is_empty() {
context_chunks.push("<files>\n");
result.push_str("<files>\n");
for context in file_context {
context_chunks.push(&context.context_buffer.text);
result.push_str(&context.context_buffer.text);
}
context_chunks.push("\n</files>\n");
result.push_str("</files>\n");
}
if !directory_context.is_empty() {
context_chunks.push("<directories>\n");
result.push_str("<directories>\n");
for context in directory_context {
for context_buffer in &context.context_buffers {
context_chunks.push(&context_buffer.text);
result.push_str(&context_buffer.text);
}
}
context_chunks.push("\n</directories>\n");
result.push_str("</directories>\n");
}
if !symbol_context.is_empty() {
context_chunks.push("<symbols>\n");
result.push_str("<symbols>\n");
for context in symbol_context {
context_chunks.push(&context.context_symbol.text);
result.push_str(&context.context_symbol.text);
result.push('\n');
}
context_chunks.push("\n</symbols>\n");
result.push_str("</symbols>\n");
}
if !fetch_context.is_empty() {
context_chunks.push("<fetched_urls>\n");
result.push_str("<fetched_urls>\n");
for context in &fetch_context {
context_chunks.push(&context.url);
context_chunks.push(&context.text);
result.push_str(&context.url);
result.push('\n');
result.push_str(&context.text);
result.push('\n');
}
context_chunks.push("\n</fetched_urls>\n");
result.push_str("</fetched_urls>\n");
}
// Need to own the SharedString for summary so that it can be referenced.
let mut thread_context_chunks = Vec::new();
if !thread_context.is_empty() {
context_chunks.push("<conversation_threads>\n");
result.push_str("<conversation_threads>\n");
for context in &thread_context {
thread_context_chunks.push(context.summary(cx));
thread_context_chunks.push(context.text.clone());
result.push_str(&context.summary(cx));
result.push('\n');
result.push_str(&context.text);
result.push('\n');
}
context_chunks.push("\n</conversation_threads>\n");
result.push_str("</conversation_threads>\n");
}
for chunk in &thread_context_chunks {
context_chunks.push(chunk);
}
result.push_str("</context>\n");
Some(result)
}
if !context_chunks.is_empty() {
message.content.push(
"\n<context>\n\
The following items were attached by the user. You don't need to use other tools to read them.\n\n".into(),
);
message.content.push(context_chunks.join("\n").into());
message.content.push("\n</context>\n".into());
/// Formats `contexts` with `format_context_as_string` and, when that yields
/// a string, appends it to `message.content` as a single entry.
///
/// The message is left untouched when the formatter returns `None`.
pub fn attach_context_to_message<'a>(
    message: &mut LanguageModelRequestMessage,
    contexts: impl Iterator<Item = &'a AssistantContext>,
    cx: &App,
) {
    let Some(formatted) = format_context_as_string(contexts, cx) else {
        return;
    };
    message.content.push(formatted.into());
}

View file

@ -7,7 +7,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use collections::{BTreeMap, HashMap};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
@ -30,7 +30,7 @@ use settings::Settings;
use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, attach_context_to_message};
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
@ -82,6 +82,7 @@ pub struct Message {
pub id: MessageId,
pub role: Role,
pub segments: Vec<MessageSegment>,
pub context: String,
}
impl Message {
@ -110,6 +111,11 @@ impl Message {
pub fn to_string(&self) -> String {
let mut result = String::new();
if !self.context.is_empty() {
result.push_str(&self.context);
}
for segment in &self.segments {
match segment {
MessageSegment::Text(text) => result.push_str(text),
@ -120,11 +126,12 @@ impl Message {
}
}
}
result
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
Text(String),
Thinking(String),
@ -335,6 +342,7 @@ impl Thread {
}
})
.collect(),
context: message.context,
})
.collect(),
next_message_id,
@ -595,15 +603,58 @@ impl Thread {
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let message_id =
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context
let text = text.into();
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
// Filter out contexts that have already been included in previous messages
let new_context: Vec<_> = context
.into_iter()
.filter(|ctx| !self.context.contains_key(&ctx.id()))
.collect();
if !new_context.is_empty() {
if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.context = context_string;
}
}
self.action_log.update(cx, |log, cx| {
// Track all buffers added as context
for ctx in &new_context {
match ctx {
AssistantContext::File(file_ctx) => {
log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
}
AssistantContext::Directory(dir_ctx) => {
for context_buffer in &dir_ctx.context_buffers {
log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
}
}
AssistantContext::Symbol(symbol_ctx) => {
log.buffer_added_as_context(
symbol_ctx.context_symbol.buffer.clone(),
cx,
);
}
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
}
}
});
}
let context_ids = new_context
.iter()
.map(|context| context.id())
.collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id(), context)));
self.context.extend(
new_context
.into_iter()
.map(|context| (context.id(), context)),
);
self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
message_id,
@ -620,7 +671,12 @@ impl Thread {
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
self.messages.push(Message { id, role, segments });
self.messages.push(Message {
id,
role,
segments,
context: String::new(),
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
id
@ -726,6 +782,7 @@ impl Thread {
content: tool_result.content.clone(),
})
.collect(),
context: message.context.clone(),
})
.collect(),
initial_project_snapshot,
@ -912,8 +969,6 @@ impl Thread {
log::error!("system_prompt_context not set.")
}
let mut added_context_ids = HashSet::<ContextId>::default();
for message in &self.messages {
let mut request_message = LanguageModelRequestMessage {
role: message.role,
@ -934,23 +989,6 @@ impl Thread {
}
}
// Attach context to this message if it's the first to reference it
if let Some(context_ids) = self.context_by_message.get(&message.id) {
let new_context_ids: Vec<_> = context_ids
.iter()
.filter(|id| !added_context_ids.contains(id))
.collect();
if !new_context_ids.is_empty() {
let referenced_context = new_context_ids
.iter()
.filter_map(|context_id| self.context.get(*context_id));
attach_context_to_message(&mut request_message, referenced_context, cx);
added_context_ids.extend(context_ids.iter());
}
}
if !message.segments.is_empty() {
request_message
.content
@ -970,11 +1008,9 @@ impl Thread {
request.messages.push(request_message);
}
// Set a cache breakpoint at the second-to-last message.
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
let breakpoint_index = request.messages.len() - 2;
for (index, message) in request.messages.iter_mut().enumerate() {
message.cache = index == breakpoint_index;
if let Some(last) = request.messages.last_mut() {
last.cache = true;
}
self.attached_tracked_files_state(&mut request.messages, cx);
@ -999,7 +1035,7 @@ impl Thread {
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
@ -1453,17 +1489,7 @@ impl Thread {
})
}
pub fn attach_tool_results(
&mut self,
updated_context: Vec<AssistantContext>,
cx: &mut Context<Self>,
) {
self.context.extend(
updated_context
.into_iter()
.map(|context| (context.id(), context)),
);
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
@ -1672,6 +1698,11 @@ impl Thread {
Role::System => "System",
}
)?;
if !message.context.is_empty() {
writeln!(markdown, "{}", message.context)?;
}
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
@ -1828,3 +1859,415 @@ struct PendingCompletion {
id: usize,
_task: Task<()>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{ThreadStore, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use prompt_store::PromptBuilder;
use serde_json::json;
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use theme::ThemeSettings;
use util::path;
use workspace::Workspace;
// Checks that context handed to `insert_user_message` is stored on the
// message as a formatted string, and that the completion request contains
// that string immediately followed by the message text.
#[gpui::test]
async fn test_message_with_context(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
// Open the file and register it with the context store.
add_file_to_context(&project, &context_store, "test/code.rs", cx)
.await
.unwrap();
// Take the first context entry from the store.
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Please explain this code", vec![context], None, cx)
});
// Check content and context in message object
let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
// Use different path format strings based on platform for the test
#[cfg(windows)]
let path_part = r"test\code.rs";
#[cfg(not(windows))]
let path_part = "test/code.rs";
// This literal is compared with `assert_eq!` against `message.context`
// below, so every character of it, including whitespace and newlines,
// is significant.
let expected_context = format!(
r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.
<files>
```rs {path_part}
fn main() {{
println!("Hello, world!");
}}
```
</files>
</context>
"#
);
// The user text stays in its own segment; the context is stored separately.
assert_eq!(message.role, Role::User);
assert_eq!(message.segments.len(), 1);
assert_eq!(
message.segments[0],
MessageSegment::Text("Please explain this code".to_string())
);
assert_eq!(message.context, expected_context);
// Check message in request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
// In the request, the context string comes first and the user text follows.
let expected_full_message = format!("{}Please explain this code", expected_context);
assert_eq!(request.messages[0].string_contents(), expected_full_message);
}
// Checks that a context already attached to an earlier message is not
// attached again: each message, and each request message built from it,
// mentions only the file that was new at the time the message was inserted.
#[gpui::test]
async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
    init_test_settings(cx);

    let project = create_test_project(
        cx,
        json!({
            "file1.rs": "fn function1() {}\n",
            "file2.rs": "fn function2() {}\n",
            "file3.rs": "fn function3() {}\n",
        }),
    )
    .await;

    let (_, _thread_store, thread, context_store) =
        setup_test_environment(cx, project.clone()).await;

    // Register the three files one at a time, in order.
    for path in ["test/file1.rs", "test/file2.rs", "test/file3.rs"] {
        add_file_to_context(&project, &context_store, path, cx)
            .await
            .unwrap();
    }

    let contexts = context_store.update(cx, |store, _| store.context().clone());
    assert_eq!(contexts.len(), 3);

    // Message 1 carries context 1 only.
    let first_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message("Message 1", contexts[..1].to_vec(), None, cx)
    });

    // Message 2 is given contexts 1 and 2; context 1 was already attached.
    let second_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message("Message 2", contexts[..2].to_vec(), None, cx)
    });

    // Message 3 is given all three; contexts 1 and 2 were already attached.
    let third_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message("Message 3", contexts.to_vec(), None, cx)
    });

    let [first, second, third] = thread.read_with(cx, |thread, _| {
        [first_id, second_id, third_id].map(|id| thread.message(id).unwrap().clone())
    });

    // Stored context per message.
    assert!(first.context.contains("file1.rs"));

    assert!(!second.context.contains("file1.rs"));
    assert!(second.context.contains("file2.rs"));

    assert!(!third.context.contains("file1.rs"));
    assert!(!third.context.contains("file2.rs"));
    assert!(third.context.contains("file3.rs"));

    // The request mirrors the stored messages one-to-one.
    let request = thread.read_with(cx, |thread, cx| {
        thread.to_completion_request(RequestKind::Chat, cx)
    });
    assert_eq!(request.messages.len(), 3);

    // Request message N must mention file N and neither of the other two.
    let file_names = ["file1.rs", "file2.rs", "file3.rs"];
    for (message_ix, request_message) in request.messages.iter().enumerate() {
        let contents = request_message.string_contents();
        for (file_ix, file_name) in file_names.iter().enumerate() {
            if message_ix == file_ix {
                assert!(contents.contains(*file_name));
            } else {
                assert!(!contents.contains(*file_name));
            }
        }
    }
}
// Checks that messages inserted with an empty context list keep an empty
// `context` string and appear in the completion request as their bare text.
#[gpui::test]
async fn test_message_without_files(cx: &mut TestAppContext) {
    const FIRST_PROMPT: &str = "What is the best way to learn Rust?";
    const SECOND_PROMPT: &str = "Are there any good books?";

    init_test_settings(cx);

    let project = create_test_project(
        cx,
        json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
    )
    .await;

    let (_, _thread_store, thread, _context_store) =
        setup_test_environment(cx, project.clone()).await;

    // First message: no context attached.
    let first_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message(FIRST_PROMPT, vec![], None, cx)
    });
    let first = thread.read_with(cx, |thread, _| thread.message(first_id).unwrap().clone());

    assert_eq!(first.role, Role::User);
    assert_eq!(first.segments.len(), 1);
    assert_eq!(
        first.segments[0],
        MessageSegment::Text(FIRST_PROMPT.to_string())
    );
    assert_eq!(first.context, "");

    // The request holds exactly that one message, text only.
    let single_request = thread.read_with(cx, |thread, cx| {
        thread.to_completion_request(RequestKind::Chat, cx)
    });
    assert_eq!(single_request.messages.len(), 1);
    assert_eq!(single_request.messages[0].string_contents(), FIRST_PROMPT);

    // Second message: again no context.
    let second_id = thread.update(cx, |thread, cx| {
        thread.insert_user_message(SECOND_PROMPT, vec![], None, cx)
    });
    let second = thread.read_with(cx, |thread, _| thread.message(second_id).unwrap().clone());
    assert_eq!(second.context, "");

    // Both messages appear in the request, in insertion order.
    let full_request = thread.read_with(cx, |thread, cx| {
        thread.to_completion_request(RequestKind::Chat, cx)
    });
    assert_eq!(full_request.messages.len(), 2);
    assert_eq!(full_request.messages[0].string_contents(), FIRST_PROMPT);
    assert_eq!(full_request.messages[1].string_contents(), SECOND_PROMPT);
}
// Checks that editing a buffer after it was attached as context makes the
// next completion request end with a user message listing the changed file,
// and that no such message is present before the edit.
#[gpui::test]
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
// Open buffer and add it to context
let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
// Insert user message with the buffer as context
thread.update(cx, |thread, cx| {
thread.insert_user_message("Explain this code", vec![context], None, cx)
});
// Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
// Make sure we don't have a stale file warning yet
let has_stale_warning = initial_request.messages.iter().any(|msg| {
msg.string_contents()
.contains("These files changed since last read:")
});
assert!(
!has_stale_warning,
"Should not have stale buffer warning before buffer is modified"
);
// Modify the buffer
buffer.update(cx, |buffer, cx| {
// Insert text at the empty range `1..1`.
// NOTE(review): the exact position looks irrelevant here; presumably any
// edit after the buffer was attached would mark it as changed. Confirm.
buffer.edit(
[(1..1, "\n println!(\"Added a new line\");\n")],
None,
cx,
);
});
// Insert another user message without context
thread.update(cx, |thread, cx| {
thread.insert_user_message("What does the code do now?", vec![], None, cx)
});
// Create a new request and check for the stale buffer warning
let new_request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
// We should have a stale file warning as the last message
let last_message = new_request
.messages
.last()
.expect("Request should have messages");
// The last message should be the stale buffer notification
assert_eq!(last_message.role, Role::User);
// Check the exact content of the message: a header line followed by one
// "- <path>" line for the edited file.
let expected_content = "These files changed since last read:\n- code.rs\n";
assert_eq!(
last_message.string_contents(),
expected_content,
"Last message should be exactly the stale buffer notification"
);
}
// Installs a test `SettingsStore` as a global, then runs the init/register
// calls for the language, project, assistant, thread store, workspace, theme,
// context server and editor settings used by the tests in this module.
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
// The settings store is set as a global before any of the calls below.
// NOTE(review): presumably they depend on it being present — confirm
// before reordering.
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AssistantSettings::register(cx);
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
});
}
// Builds a `FakeFs`, inserts `files` as a tree under the `/test` root, and
// returns a test project whose only root path is that directory.
async fn create_test_project(
    cx: &mut TestAppContext,
    files: serde_json::Value,
) -> Entity<Project> {
    let fake_fs = FakeFs::new(cx.executor());
    fake_fs.insert_tree(path!("/test"), files).await;

    let project_roots = [path!("/test").as_ref()];
    Project::test(fake_fs, project_roots, cx).await
}
// Creates the entities the tests in this module operate on and returns them
// as `(workspace, thread_store, thread, context_store)`:
// - a test `Workspace` for `project`, added as a window view,
// - a `ThreadStore` for the same project,
// - a `Thread` created by that store,
// - a `ContextStore` holding a downgraded handle to the workspace.
async fn setup_test_environment(
cx: &mut TestAppContext,
project: Entity<Project>,
) -> (
Entity<Workspace>,
Entity<ThreadStore>,
Entity<Thread>,
Entity<ContextStore>,
) {
// `cx` is rebound here to the second value returned by `add_window_view`;
// every call below uses that context, not the `TestAppContext` parameter.
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
// Both the `PromptBuilder` construction and `ThreadStore::new` are
// unwrapped, so a failure in either panics the test.
let thread_store = cx.update(|_, cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None));
(workspace, thread_store, thread, context_store)
}
async fn add_file_to_context(
project: &Entity<Project>,
context_store: &Entity<ContextStore>,
path: &str,
cx: &mut TestAppContext,
) -> Result<Entity<language::Buffer>> {
let buffer_path = project
.read_with(cx, |project, cx| project.find_project_path(path, cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
.await
.unwrap();
context_store
.update(cx, |store, cx| {
store.add_file_from_buffer(buffer.clone(), cx)
})
.await?;
Ok(buffer)
}
}

View file

@ -374,6 +374,8 @@ pub struct SerializedMessage {
pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
#[serde(default)]
pub context: String,
}
#[derive(Debug, Serialize, Deserialize)]
@ -441,6 +443,7 @@ impl LegacySerializedMessage {
segments: vec![SerializedMessageSegment::Text { text: self.text }],
tool_uses: self.tool_uses,
tool_results: self.tool_results,
context: String::new(),
}
}
}