agent: Snapshot context in user message instead of recreating it (#27967)
This makes context essentially work the same way as `read-file`, increasing the likelihood of cache hits. Just like with `read-file`, we'll notify the model when the user makes an edit to one of the tracked files. In the future, we want to send a diff instead of just a list of files, but that's an orthogonal change. Release Notes: - agent: Improved caching of files in context --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
0c82541f0a
commit
315f1bf168
6 changed files with 551 additions and 136 deletions
|
@ -34,7 +34,7 @@ use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip,
|
|||
use util::ResultExt as _;
|
||||
use workspace::{OpenOptions, Workspace};
|
||||
|
||||
use crate::context_store::{ContextStore, refresh_context_store_text};
|
||||
use crate::context_store::ContextStore;
|
||||
|
||||
pub struct ActiveThread {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
|
@ -593,54 +593,14 @@ impl ActiveThread {
|
|||
}
|
||||
|
||||
if self.thread.read(cx).all_tools_finished() {
|
||||
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
|
||||
thread.action_log().update(cx, |action_log, _cx| {
|
||||
action_log.take_stale_buffers_in_context()
|
||||
})
|
||||
});
|
||||
|
||||
let context_update_task = if !pending_refresh_buffers.is_empty() {
|
||||
let refresh_task = refresh_context_store_text(
|
||||
self.context_store.clone(),
|
||||
&pending_refresh_buffers,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let updated_context_ids = refresh_task.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_store.read_with(cx, |context_store, _cx| {
|
||||
context_store
|
||||
.context()
|
||||
.iter()
|
||||
.filter(|context| {
|
||||
updated_context_ids.contains(&context.id())
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
})
|
||||
} else {
|
||||
Task::ready(anyhow::Ok(Vec::new()))
|
||||
};
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(model) = model_registry.active_model() {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let updated_context = context_update_task.await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.thread.update(cx, |thread, cx| {
|
||||
thread.attach_tool_results(updated_context, cx);
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.attach_tool_results(cx);
|
||||
if !canceled {
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
}
|
||||
});
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -146,11 +146,11 @@ pub struct ContextSymbolId {
|
|||
pub range: Range<Anchor>,
|
||||
}
|
||||
|
||||
pub fn attach_context_to_message<'a>(
|
||||
message: &mut LanguageModelRequestMessage,
|
||||
/// Formats a collection of contexts into a string representation
|
||||
pub fn format_context_as_string<'a>(
|
||||
contexts: impl Iterator<Item = &'a AssistantContext>,
|
||||
cx: &App,
|
||||
) {
|
||||
) -> Option<String> {
|
||||
let mut file_context = Vec::new();
|
||||
let mut directory_context = Vec::new();
|
||||
let mut symbol_context = Vec::new();
|
||||
|
@ -167,64 +167,78 @@ pub fn attach_context_to_message<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
let mut context_chunks = Vec::new();
|
||||
if file_context.is_empty()
|
||||
&& directory_context.is_empty()
|
||||
&& symbol_context.is_empty()
|
||||
&& fetch_context.is_empty()
|
||||
&& thread_context.is_empty()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
result.push_str("\n<context>\n\
|
||||
The following items were attached by the user. You don't need to use other tools to read them.\n\n");
|
||||
|
||||
if !file_context.is_empty() {
|
||||
context_chunks.push("<files>\n");
|
||||
result.push_str("<files>\n");
|
||||
for context in file_context {
|
||||
context_chunks.push(&context.context_buffer.text);
|
||||
result.push_str(&context.context_buffer.text);
|
||||
}
|
||||
context_chunks.push("\n</files>\n");
|
||||
result.push_str("</files>\n");
|
||||
}
|
||||
|
||||
if !directory_context.is_empty() {
|
||||
context_chunks.push("<directories>\n");
|
||||
result.push_str("<directories>\n");
|
||||
for context in directory_context {
|
||||
for context_buffer in &context.context_buffers {
|
||||
context_chunks.push(&context_buffer.text);
|
||||
result.push_str(&context_buffer.text);
|
||||
}
|
||||
}
|
||||
context_chunks.push("\n</directories>\n");
|
||||
result.push_str("</directories>\n");
|
||||
}
|
||||
|
||||
if !symbol_context.is_empty() {
|
||||
context_chunks.push("<symbols>\n");
|
||||
result.push_str("<symbols>\n");
|
||||
for context in symbol_context {
|
||||
context_chunks.push(&context.context_symbol.text);
|
||||
result.push_str(&context.context_symbol.text);
|
||||
result.push('\n');
|
||||
}
|
||||
context_chunks.push("\n</symbols>\n");
|
||||
result.push_str("</symbols>\n");
|
||||
}
|
||||
|
||||
if !fetch_context.is_empty() {
|
||||
context_chunks.push("<fetched_urls>\n");
|
||||
result.push_str("<fetched_urls>\n");
|
||||
for context in &fetch_context {
|
||||
context_chunks.push(&context.url);
|
||||
context_chunks.push(&context.text);
|
||||
result.push_str(&context.url);
|
||||
result.push('\n');
|
||||
result.push_str(&context.text);
|
||||
result.push('\n');
|
||||
}
|
||||
context_chunks.push("\n</fetched_urls>\n");
|
||||
result.push_str("</fetched_urls>\n");
|
||||
}
|
||||
|
||||
// Need to own the SharedString for summary so that it can be referenced.
|
||||
let mut thread_context_chunks = Vec::new();
|
||||
if !thread_context.is_empty() {
|
||||
context_chunks.push("<conversation_threads>\n");
|
||||
result.push_str("<conversation_threads>\n");
|
||||
for context in &thread_context {
|
||||
thread_context_chunks.push(context.summary(cx));
|
||||
thread_context_chunks.push(context.text.clone());
|
||||
result.push_str(&context.summary(cx));
|
||||
result.push('\n');
|
||||
result.push_str(&context.text);
|
||||
result.push('\n');
|
||||
}
|
||||
context_chunks.push("\n</conversation_threads>\n");
|
||||
result.push_str("</conversation_threads>\n");
|
||||
}
|
||||
|
||||
for chunk in &thread_context_chunks {
|
||||
context_chunks.push(chunk);
|
||||
}
|
||||
result.push_str("</context>\n");
|
||||
Some(result)
|
||||
}
|
||||
|
||||
if !context_chunks.is_empty() {
|
||||
message.content.push(
|
||||
"\n<context>\n\
|
||||
The following items were attached by the user. You don't need to use other tools to read them.\n\n".into(),
|
||||
);
|
||||
message.content.push(context_chunks.join("\n").into());
|
||||
message.content.push("\n</context>\n".into());
|
||||
pub fn attach_context_to_message<'a>(
|
||||
message: &mut LanguageModelRequestMessage,
|
||||
contexts: impl Iterator<Item = &'a AssistantContext>,
|
||||
cx: &App,
|
||||
) {
|
||||
if let Some(context_string) = format_context_as_string(contexts, cx) {
|
||||
message.content.push(context_string.into());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt, StreamExt as _};
|
||||
|
@ -30,7 +30,7 @@ use settings::Settings;
|
|||
use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, attach_context_to_message};
|
||||
use crate::context::{AssistantContext, ContextId, format_context_as_string};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse,
|
||||
|
@ -82,6 +82,7 @@ pub struct Message {
|
|||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub segments: Vec<MessageSegment>,
|
||||
pub context: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
|
@ -110,6 +111,11 @@ impl Message {
|
|||
|
||||
pub fn to_string(&self) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
if !self.context.is_empty() {
|
||||
result.push_str(&self.context);
|
||||
}
|
||||
|
||||
for segment in &self.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => result.push_str(text),
|
||||
|
@ -120,11 +126,12 @@ impl Message {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum MessageSegment {
|
||||
Text(String),
|
||||
Thinking(String),
|
||||
|
@ -335,6 +342,7 @@ impl Thread {
|
|||
}
|
||||
})
|
||||
.collect(),
|
||||
context: message.context,
|
||||
})
|
||||
.collect(),
|
||||
next_message_id,
|
||||
|
@ -595,15 +603,58 @@ impl Thread {
|
|||
git_checkpoint: Option<GitStoreCheckpoint>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let message_id =
|
||||
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
|
||||
let context_ids = context
|
||||
let text = text.into();
|
||||
|
||||
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
|
||||
|
||||
// Filter out contexts that have already been included in previous messages
|
||||
let new_context: Vec<_> = context
|
||||
.into_iter()
|
||||
.filter(|ctx| !self.context.contains_key(&ctx.id()))
|
||||
.collect();
|
||||
|
||||
if !new_context.is_empty() {
|
||||
if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
|
||||
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
|
||||
message.context = context_string;
|
||||
}
|
||||
}
|
||||
|
||||
self.action_log.update(cx, |log, cx| {
|
||||
// Track all buffers added as context
|
||||
for ctx in &new_context {
|
||||
match ctx {
|
||||
AssistantContext::File(file_ctx) => {
|
||||
log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
|
||||
}
|
||||
AssistantContext::Directory(dir_ctx) => {
|
||||
for context_buffer in &dir_ctx.context_buffers {
|
||||
log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
|
||||
}
|
||||
}
|
||||
AssistantContext::Symbol(symbol_ctx) => {
|
||||
log.buffer_added_as_context(
|
||||
symbol_ctx.context_symbol.buffer.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let context_ids = new_context
|
||||
.iter()
|
||||
.map(|context| context.id())
|
||||
.collect::<Vec<_>>();
|
||||
self.context
|
||||
.extend(context.into_iter().map(|context| (context.id(), context)));
|
||||
self.context.extend(
|
||||
new_context
|
||||
.into_iter()
|
||||
.map(|context| (context.id(), context)),
|
||||
);
|
||||
self.context_by_message.insert(message_id, context_ids);
|
||||
|
||||
if let Some(git_checkpoint) = git_checkpoint {
|
||||
self.pending_checkpoint = Some(ThreadCheckpoint {
|
||||
message_id,
|
||||
|
@ -620,7 +671,12 @@ impl Thread {
|
|||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let id = self.next_message_id.post_inc();
|
||||
self.messages.push(Message { id, role, segments });
|
||||
self.messages.push(Message {
|
||||
id,
|
||||
role,
|
||||
segments,
|
||||
context: String::new(),
|
||||
});
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageAdded(id));
|
||||
id
|
||||
|
@ -726,6 +782,7 @@ impl Thread {
|
|||
content: tool_result.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
context: message.context.clone(),
|
||||
})
|
||||
.collect(),
|
||||
initial_project_snapshot,
|
||||
|
@ -912,8 +969,6 @@ impl Thread {
|
|||
log::error!("system_prompt_context not set.")
|
||||
}
|
||||
|
||||
let mut added_context_ids = HashSet::<ContextId>::default();
|
||||
|
||||
for message in &self.messages {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
|
@ -934,23 +989,6 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
// Attach context to this message if it's the first to reference it
|
||||
if let Some(context_ids) = self.context_by_message.get(&message.id) {
|
||||
let new_context_ids: Vec<_> = context_ids
|
||||
.iter()
|
||||
.filter(|id| !added_context_ids.contains(id))
|
||||
.collect();
|
||||
|
||||
if !new_context_ids.is_empty() {
|
||||
let referenced_context = new_context_ids
|
||||
.iter()
|
||||
.filter_map(|context_id| self.context.get(*context_id));
|
||||
|
||||
attach_context_to_message(&mut request_message, referenced_context, cx);
|
||||
added_context_ids.extend(context_ids.iter());
|
||||
}
|
||||
}
|
||||
|
||||
if !message.segments.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
|
@ -970,11 +1008,9 @@ impl Thread {
|
|||
request.messages.push(request_message);
|
||||
}
|
||||
|
||||
// Set a cache breakpoint at the second-to-last message.
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
let breakpoint_index = request.messages.len() - 2;
|
||||
for (index, message) in request.messages.iter_mut().enumerate() {
|
||||
message.cache = index == breakpoint_index;
|
||||
if let Some(last) = request.messages.last_mut() {
|
||||
last.cache = true;
|
||||
}
|
||||
|
||||
self.attached_tracked_files_state(&mut request.messages, cx);
|
||||
|
@ -999,7 +1035,7 @@ impl Thread {
|
|||
};
|
||||
|
||||
if stale_message.is_empty() {
|
||||
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
|
||||
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
|
||||
}
|
||||
|
||||
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
|
||||
|
@ -1453,17 +1489,7 @@ impl Thread {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(
|
||||
&mut self,
|
||||
updated_context: Vec<AssistantContext>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.context.extend(
|
||||
updated_context
|
||||
.into_iter()
|
||||
.map(|context| (context.id(), context)),
|
||||
);
|
||||
|
||||
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
|
||||
// Insert a user message to contain the tool results.
|
||||
self.insert_user_message(
|
||||
// TODO: Sending up a user message without any content results in the model sending back
|
||||
|
@ -1672,6 +1698,11 @@ impl Thread {
|
|||
Role::System => "System",
|
||||
}
|
||||
)?;
|
||||
|
||||
if !message.context.is_empty() {
|
||||
writeln!(markdown, "{}", message.context)?;
|
||||
}
|
||||
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
|
||||
|
@ -1828,3 +1859,415 @@ struct PendingCompletion {
|
|||
id: usize,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ThreadStore, context_store::ContextStore, thread_store};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use context_server::ContextServerSettings;
|
||||
use editor::EditorSettings;
|
||||
use gpui::TestAppContext;
|
||||
use project::{FakeFs, Project};
|
||||
use prompt_store::PromptBuilder;
|
||||
use serde_json::json;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use theme::ThemeSettings;
|
||||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_message_with_context(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_workspace, _thread_store, thread, context_store) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
add_file_to_context(&project, &context_store, "test/code.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context =
|
||||
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
|
||||
|
||||
// Insert user message with context
|
||||
let message_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Please explain this code", vec![context], None, cx)
|
||||
});
|
||||
|
||||
// Check content and context in message object
|
||||
let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
|
||||
|
||||
// Use different path format strings based on platform for the test
|
||||
#[cfg(windows)]
|
||||
let path_part = r"test\code.rs";
|
||||
#[cfg(not(windows))]
|
||||
let path_part = "test/code.rs";
|
||||
|
||||
let expected_context = format!(
|
||||
r#"
|
||||
<context>
|
||||
The following items were attached by the user. You don't need to use other tools to read them.
|
||||
|
||||
<files>
|
||||
```rs {path_part}
|
||||
fn main() {{
|
||||
println!("Hello, world!");
|
||||
}}
|
||||
```
|
||||
</files>
|
||||
</context>
|
||||
"#
|
||||
);
|
||||
|
||||
assert_eq!(message.role, Role::User);
|
||||
assert_eq!(message.segments.len(), 1);
|
||||
assert_eq!(
|
||||
message.segments[0],
|
||||
MessageSegment::Text("Please explain this code".to_string())
|
||||
);
|
||||
assert_eq!(message.context, expected_context);
|
||||
|
||||
// Check message in request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 1);
|
||||
let expected_full_message = format!("{}Please explain this code", expected_context);
|
||||
assert_eq!(request.messages[0].string_contents(), expected_full_message);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({
|
||||
"file1.rs": "fn function1() {}\n",
|
||||
"file2.rs": "fn function2() {}\n",
|
||||
"file3.rs": "fn function3() {}\n",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_, _thread_store, thread, context_store) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Open files individually
|
||||
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get the context objects
|
||||
let contexts = context_store.update(cx, |store, _| store.context().clone());
|
||||
assert_eq!(contexts.len(), 3);
|
||||
|
||||
// First message with context 1
|
||||
let message1_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
|
||||
});
|
||||
|
||||
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
|
||||
let message2_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"Message 2",
|
||||
vec![contexts[0].clone(), contexts[1].clone()],
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Third message with all three contexts (contexts 1 and 2 should be skipped)
|
||||
let message3_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"Message 3",
|
||||
vec![
|
||||
contexts[0].clone(),
|
||||
contexts[1].clone(),
|
||||
contexts[2].clone(),
|
||||
],
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Check what contexts are included in each message
|
||||
let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
|
||||
(
|
||||
thread.message(message1_id).unwrap().clone(),
|
||||
thread.message(message2_id).unwrap().clone(),
|
||||
thread.message(message3_id).unwrap().clone(),
|
||||
)
|
||||
});
|
||||
|
||||
// First message should include context 1
|
||||
assert!(message1.context.contains("file1.rs"));
|
||||
|
||||
// Second message should include only context 2 (not 1)
|
||||
assert!(!message2.context.contains("file1.rs"));
|
||||
assert!(message2.context.contains("file2.rs"));
|
||||
|
||||
// Third message should include only context 3 (not 1 or 2)
|
||||
assert!(!message3.context.contains("file1.rs"));
|
||||
assert!(!message3.context.contains("file2.rs"));
|
||||
assert!(message3.context.contains("file3.rs"));
|
||||
|
||||
// Check entire request to make sure all contexts are properly included
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
// The request should contain all 3 messages
|
||||
assert_eq!(request.messages.len(), 3);
|
||||
|
||||
// Check that the contexts are properly formatted in each message
|
||||
assert!(request.messages[0].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[0].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[0].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[1].string_contents().contains("file1.rs"));
|
||||
assert!(request.messages[1].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[1].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[2].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[2].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[2].string_contents().contains("file3.rs"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_message_without_files(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_, _thread_store, thread, _context_store) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Insert user message without any context (empty context vector)
|
||||
let message_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
|
||||
});
|
||||
|
||||
// Check content and context in message object
|
||||
let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
|
||||
|
||||
// Context should be empty when no files are included
|
||||
assert_eq!(message.role, Role::User);
|
||||
assert_eq!(message.segments.len(), 1);
|
||||
assert_eq!(
|
||||
message.segments[0],
|
||||
MessageSegment::Text("What is the best way to learn Rust?".to_string())
|
||||
);
|
||||
assert_eq!(message.context, "");
|
||||
|
||||
// Check message in request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 1);
|
||||
assert_eq!(
|
||||
request.messages[0].string_contents(),
|
||||
"What is the best way to learn Rust?"
|
||||
);
|
||||
|
||||
// Add second message, also without context
|
||||
let message2_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Are there any good books?", vec![], None, cx)
|
||||
});
|
||||
|
||||
let message2 =
|
||||
thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
|
||||
assert_eq!(message2.context, "");
|
||||
|
||||
// Check that both messages appear in the request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
assert_eq!(
|
||||
request.messages[0].string_contents(),
|
||||
"What is the best way to learn Rust?"
|
||||
);
|
||||
assert_eq!(
|
||||
request.messages[1].string_contents(),
|
||||
"Are there any good books?"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_workspace, _thread_store, thread, context_store) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Open buffer and add it to context
|
||||
let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context =
|
||||
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
|
||||
|
||||
// Insert user message with the buffer as context
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Explain this code", vec![context], None, cx)
|
||||
});
|
||||
|
||||
// Create a request and check that it doesn't have a stale buffer warning yet
|
||||
let initial_request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
// Make sure we don't have a stale file warning yet
|
||||
let has_stale_warning = initial_request.messages.iter().any(|msg| {
|
||||
msg.string_contents()
|
||||
.contains("These files changed since last read:")
|
||||
});
|
||||
assert!(
|
||||
!has_stale_warning,
|
||||
"Should not have stale buffer warning before buffer is modified"
|
||||
);
|
||||
|
||||
// Modify the buffer
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
// Find a position at the end of line 1
|
||||
buffer.edit(
|
||||
[(1..1, "\n println!(\"Added a new line\");\n")],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Insert another user message without context
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("What does the code do now?", vec![], None, cx)
|
||||
});
|
||||
|
||||
// Create a new request and check for the stale buffer warning
|
||||
let new_request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
// We should have a stale file warning as the last message
|
||||
let last_message = new_request
|
||||
.messages
|
||||
.last()
|
||||
.expect("Request should have messages");
|
||||
|
||||
// The last message should be the stale buffer notification
|
||||
assert_eq!(last_message.role, Role::User);
|
||||
|
||||
// Check the exact content of the message
|
||||
let expected_content = "These files changed since last read:\n- code.rs\n";
|
||||
assert_eq!(
|
||||
last_message.string_contents(),
|
||||
expected_content,
|
||||
"Last message should be exactly the stale buffer notification"
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AssistantSettings::register(cx);
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to create a test project with test files
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
async fn setup_test_environment(
|
||||
cx: &mut TestAppContext,
|
||||
project: Entity<Project>,
|
||||
) -> (
|
||||
Entity<Workspace>,
|
||||
Entity<ThreadStore>,
|
||||
Entity<Thread>,
|
||||
Entity<ContextStore>,
|
||||
) {
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let thread_store = cx.update(|_, cx| {
|
||||
ThreadStore::new(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None));
|
||||
|
||||
(workspace, thread_store, thread, context_store)
|
||||
}
|
||||
|
||||
async fn add_file_to_context(
|
||||
project: &Entity<Project>,
|
||||
context_store: &Entity<ContextStore>,
|
||||
path: &str,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Result<Entity<language::Buffer>> {
|
||||
let buffer_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path(path, cx))
|
||||
.unwrap();
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
context_store
|
||||
.update(cx, |store, cx| {
|
||||
store.add_file_from_buffer(buffer.clone(), cx)
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -374,6 +374,8 @@ pub struct SerializedMessage {
|
|||
pub tool_uses: Vec<SerializedToolUse>,
|
||||
#[serde(default)]
|
||||
pub tool_results: Vec<SerializedToolResult>,
|
||||
#[serde(default)]
|
||||
pub context: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -441,6 +443,7 @@ impl LegacySerializedMessage {
|
|||
segments: vec![SerializedMessageSegment::Text { text: self.text }],
|
||||
tool_uses: self.tool_uses,
|
||||
tool_results: self.tool_results,
|
||||
context: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -158,7 +158,7 @@ impl HeadlessAssistant {
|
|||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(model) = model_registry.active_model() {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.attach_tool_results(vec![], cx);
|
||||
thread.attach_tool_results(cx);
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
});
|
||||
} else {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use anyhow::{Context as _, Result};
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::{BTreeMap, HashSet};
|
||||
use collections::BTreeMap;
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
|
||||
|
@ -10,9 +10,6 @@ use util::RangeExt;
|
|||
|
||||
/// Tracks actions performed by tools in a thread
|
||||
pub struct ActionLog {
|
||||
/// Buffers that user manually added to the context, and whose content has
|
||||
/// changed since the model last saw them.
|
||||
stale_buffers_in_context: HashSet<Entity<Buffer>>,
|
||||
/// Buffers that we want to notify the model about when they change.
|
||||
tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
|
||||
/// Has the model edited a file since it last checked diagnostics?
|
||||
|
@ -23,7 +20,6 @@ impl ActionLog {
|
|||
/// Creates a new, empty action log.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
stale_buffers_in_context: HashSet::default(),
|
||||
tracked_buffers: BTreeMap::default(),
|
||||
edited_since_project_diagnostics_check: false,
|
||||
}
|
||||
|
@ -259,6 +255,11 @@ impl ActionLog {
|
|||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
|
||||
/// Track a buffer that was added as context, so we can notify the model about user edits.
|
||||
pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
|
||||
/// Track a buffer as read, so we can notify the model about user edits.
|
||||
pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer.clone(), true, cx);
|
||||
|
@ -268,7 +269,6 @@ impl ActionLog {
|
|||
/// Mark a buffer as edited, so we can refresh it in the context
|
||||
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
self.stale_buffers_in_context.insert(buffer.clone());
|
||||
|
||||
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
|
||||
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
|
||||
|
@ -391,11 +391,6 @@ impl ActionLog {
|
|||
})
|
||||
.map(|(buffer, _)| buffer)
|
||||
}
|
||||
|
||||
/// Takes and returns the set of buffers pending refresh, clearing internal state.
|
||||
pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
|
||||
std::mem::take(&mut self.stale_buffers_in_context)
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_non_conflicting_edits(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue