diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 6765710f77..933f7bd0b7 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -34,7 +34,7 @@ use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip, use util::ResultExt as _; use workspace::{OpenOptions, Workspace}; -use crate::context_store::{ContextStore, refresh_context_store_text}; +use crate::context_store::ContextStore; pub struct ActiveThread { language_registry: Arc, @@ -593,54 +593,14 @@ impl ActiveThread { } if self.thread.read(cx).all_tools_finished() { - let pending_refresh_buffers = self.thread.update(cx, |thread, cx| { - thread.action_log().update(cx, |action_log, _cx| { - action_log.take_stale_buffers_in_context() - }) - }); - - let context_update_task = if !pending_refresh_buffers.is_empty() { - let refresh_task = refresh_context_store_text( - self.context_store.clone(), - &pending_refresh_buffers, - cx, - ); - - cx.spawn(async move |this, cx| { - let updated_context_ids = refresh_task.await; - - this.update(cx, |this, cx| { - this.context_store.read_with(cx, |context_store, _cx| { - context_store - .context() - .iter() - .filter(|context| { - updated_context_ids.contains(&context.id()) - }) - .cloned() - .collect() - }) - }) - }) - } else { - Task::ready(anyhow::Ok(Vec::new())) - }; - let model_registry = LanguageModelRegistry::read_global(cx); if let Some(model) = model_registry.active_model() { - cx.spawn(async move |this, cx| { - let updated_context = context_update_task.await?; - - this.update(cx, |this, cx| { - this.thread.update(cx, |thread, cx| { - thread.attach_tool_results(updated_context, cx); - if !canceled { - thread.send_to_model(model, RequestKind::Chat, cx); - } - }); - }) - }) - .detach(); + self.thread.update(cx, |thread, cx| { + thread.attach_tool_results(cx); + if !canceled { + thread.send_to_model(model, RequestKind::Chat, cx); + } + }); } } } diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index 57169de0af..9133ca3ffa 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -146,11 +146,11 @@ pub struct ContextSymbolId { pub range: Range, } -pub fn attach_context_to_message<'a>( - message: &mut LanguageModelRequestMessage, +/// Formats a collection of contexts into a string representation +pub fn format_context_as_string<'a>( contexts: impl Iterator, cx: &App, -) { +) -> Option { let mut file_context = Vec::new(); let mut directory_context = Vec::new(); let mut symbol_context = Vec::new(); @@ -167,64 +167,78 @@ pub fn attach_context_to_message<'a>( } } - let mut context_chunks = Vec::new(); + if file_context.is_empty() + && directory_context.is_empty() + && symbol_context.is_empty() + && fetch_context.is_empty() + && thread_context.is_empty() + { + return None; + } + + let mut result = String::new(); + result.push_str("\n\n\ + The following items were attached by the user. You don't need to use other tools to read them.\n\n"); if !file_context.is_empty() { - context_chunks.push("\n"); + result.push_str("\n"); for context in file_context { - context_chunks.push(&context.context_buffer.text); + result.push_str(&context.context_buffer.text); } - context_chunks.push("\n\n"); + result.push_str("\n"); } if !directory_context.is_empty() { - context_chunks.push("\n"); + result.push_str("\n"); for context in directory_context { for context_buffer in &context.context_buffers { - context_chunks.push(&context_buffer.text); + result.push_str(&context_buffer.text); } } - context_chunks.push("\n\n"); + result.push_str("\n"); } if !symbol_context.is_empty() { - context_chunks.push("\n"); + result.push_str("\n"); for context in symbol_context { - context_chunks.push(&context.context_symbol.text); + result.push_str(&context.context_symbol.text); + result.push('\n'); } - context_chunks.push("\n\n"); + result.push_str("\n"); } if !fetch_context.is_empty() { - context_chunks.push("\n"); + result.push_str("\n"); for context in &fetch_context { - context_chunks.push(&context.url); - context_chunks.push(&context.text); + result.push_str(&context.url); + result.push('\n'); + result.push_str(&context.text); + result.push('\n'); } - context_chunks.push("\n\n"); + result.push_str("\n"); } - // Need to own the SharedString for summary so that it can be referenced. - let mut thread_context_chunks = Vec::new(); if !thread_context.is_empty() { - context_chunks.push("\n"); + result.push_str("\n"); for context in &thread_context { - thread_context_chunks.push(context.summary(cx)); - thread_context_chunks.push(context.text.clone()); + result.push_str(&context.summary(cx)); + result.push('\n'); + result.push_str(&context.text); + result.push('\n'); } - context_chunks.push("\n\n"); + result.push_str("\n"); } - for chunk in &thread_context_chunks { - context_chunks.push(chunk); - } + result.push_str("\n"); + Some(result) +} - if !context_chunks.is_empty() { - message.content.push( - "\n\n\ - The following items were attached by the user. You don't need to use other tools to read them.\n\n".into(), - ); - message.content.push(context_chunks.join("\n").into()); - message.content.push("\n\n".into()); +pub fn attach_context_to_message<'a>( + message: &mut LanguageModelRequestMessage, + contexts: impl Iterator, + cx: &App, +) { + if let Some(context_string) = format_context_as_string(contexts, cx) { + message.content.push(context_string.into()); } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 8b5941dff6..207f47cae0 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -7,7 +7,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_settings::AssistantSettings; use assistant_tool::{ActionLog, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; -use collections::{BTreeMap, HashMap, HashSet}; +use collections::{BTreeMap, HashMap}; use fs::Fs; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; @@ -30,7 +30,7 @@ use settings::Settings; use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc}; use uuid::Uuid; -use crate::context::{AssistantContext, ContextId, attach_context_to_message}; +use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedToolUse, @@ -82,6 +82,7 @@ pub struct Message { pub id: MessageId, pub role: Role, pub segments: Vec, + pub context: String, } impl Message { @@ -110,6 +111,11 @@ impl Message { pub fn to_string(&self) -> String { let mut result = String::new(); + + if !self.context.is_empty() { + result.push_str(&self.context); + } + for segment in &self.segments { match segment { MessageSegment::Text(text) => result.push_str(text), @@ -120,11 +126,12 @@ impl Message { } } } + result } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum MessageSegment { Text(String), Thinking(String), @@ -335,6 +342,7 @@ impl Thread { } }) .collect(), + context: message.context, }) .collect(), next_message_id, @@ -595,15 +603,58 @@ impl Thread { git_checkpoint: Option, cx: &mut Context, ) -> MessageId { - let message_id = - self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx); - let context_ids = context + let text = text.into(); + + let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx); + + // Filter out contexts that have already been included in previous messages + let new_context: Vec<_> = context + .into_iter() + .filter(|ctx| !self.context.contains_key(&ctx.id())) + .collect(); + + if !new_context.is_empty() { + if let Some(context_string) = format_context_as_string(new_context.iter(), cx) { + if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) { + message.context = context_string; + } + } + + self.action_log.update(cx, |log, cx| { + // Track all buffers added as context + for ctx in &new_context { + match ctx { + AssistantContext::File(file_ctx) => { + log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx); + } + AssistantContext::Directory(dir_ctx) => { + for context_buffer in &dir_ctx.context_buffers { + log.buffer_added_as_context(context_buffer.buffer.clone(), cx); + } + } + AssistantContext::Symbol(symbol_ctx) => { + log.buffer_added_as_context( + symbol_ctx.context_symbol.buffer.clone(), + cx, + ); + } + AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {} + } + } + }); + } + + let context_ids = new_context .iter() .map(|context| context.id()) .collect::>(); - self.context - .extend(context.into_iter().map(|context| (context.id(), context))); + self.context.extend( + new_context + .into_iter() + .map(|context| (context.id(), context)), + ); self.context_by_message.insert(message_id, context_ids); + if let Some(git_checkpoint) = git_checkpoint { self.pending_checkpoint = Some(ThreadCheckpoint { message_id, @@ -620,7 +671,12 @@ impl Thread { cx: &mut Context, ) -> MessageId { let id = self.next_message_id.post_inc(); - self.messages.push(Message { id, role, segments }); + self.messages.push(Message { + id, + role, + segments, + context: String::new(), + }); self.touch_updated_at(); cx.emit(ThreadEvent::MessageAdded(id)); id @@ -726,6 +782,7 @@ impl Thread { content: tool_result.content.clone(), }) .collect(), + context: message.context.clone(), }) .collect(), initial_project_snapshot, @@ -912,8 +969,6 @@ impl Thread { log::error!("system_prompt_context not set.") } - let mut added_context_ids = HashSet::::default(); - for message in &self.messages { let mut request_message = LanguageModelRequestMessage { role: message.role, @@ -934,23 +989,6 @@ impl Thread { } } - // Attach context to this message if it's the first to reference it - if let Some(context_ids) = self.context_by_message.get(&message.id) { - let new_context_ids: Vec<_> = context_ids - .iter() - .filter(|id| !added_context_ids.contains(id)) - .collect(); - - if !new_context_ids.is_empty() { - let referenced_context = new_context_ids - .iter() - .filter_map(|context_id| self.context.get(*context_id)); - - attach_context_to_message(&mut request_message, referenced_context, cx); - added_context_ids.extend(context_ids.iter()); - } - } - if !message.segments.is_empty() { request_message .content @@ -970,11 +1008,9 @@ impl Thread { request.messages.push(request_message); } - // Set a cache breakpoint at the second-to-last message. // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching - let breakpoint_index = request.messages.len() - 2; - for (index, message) in request.messages.iter_mut().enumerate() { - message.cache = index == breakpoint_index; + if let Some(last) = request.messages.last_mut() { + last.cache = true; } self.attached_tracked_files_state(&mut request.messages, cx); @@ -999,7 +1035,7 @@ impl Thread { }; if stale_message.is_empty() { - write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok(); + write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok(); } writeln!(&mut stale_message, "- {}", file.path().display()).ok(); @@ -1453,17 +1489,7 @@ impl Thread { }) } - pub fn attach_tool_results( - &mut self, - updated_context: Vec, - cx: &mut Context, - ) { - self.context.extend( - updated_context - .into_iter() - .map(|context| (context.id(), context)), - ); - + pub fn attach_tool_results(&mut self, cx: &mut Context) { // Insert a user message to contain the tool results. self.insert_user_message( // TODO: Sending up a user message without any content results in the model sending back @@ -1672,6 +1698,11 @@ impl Thread { Role::System => "System", } )?; + + if !message.context.is_empty() { + writeln!(markdown, "{}", message.context)?; + } + for segment in &message.segments { match segment { MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, @@ -1828,3 +1859,415 @@ struct PendingCompletion { id: usize, _task: Task<()>, } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ThreadStore, context_store::ContextStore, thread_store}; + use assistant_settings::AssistantSettings; + use context_server::ContextServerSettings; + use editor::EditorSettings; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use prompt_store::PromptBuilder; + use serde_json::json; + use settings::{Settings, SettingsStore}; + use std::sync::Arc; + use theme::ThemeSettings; + use util::path; + use workspace::Workspace; + + #[gpui::test] + async fn test_message_with_context(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, _thread_store, thread, context_store) = + setup_test_environment(cx, project.clone()).await; + + add_file_to_context(&project, &context_store, "test/code.rs", cx) + .await + .unwrap(); + + let context = + context_store.update(cx, |store, _| store.context().first().cloned().unwrap()); + + // Insert user message with context + let message_id = thread.update(cx, |thread, cx| { + thread.insert_user_message("Please explain this code", vec![context], None, cx) + }); + + // Check content and context in message object + let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); + + // Use different path format strings based on platform for the test + #[cfg(windows)] + let path_part = r"test\code.rs"; + #[cfg(not(windows))] + let path_part = "test/code.rs"; + + let expected_context = format!( + r#" + +The following items were attached by the user. You don't need to use other tools to read them. + + +```rs {path_part} +fn main() {{ + println!("Hello, world!"); +}} +``` + + +"# + ); + + assert_eq!(message.role, Role::User); + assert_eq!(message.segments.len(), 1); + assert_eq!( + message.segments[0], + MessageSegment::Text("Please explain this code".to_string()) + ); + assert_eq!(message.context, expected_context); + + // Check message in request + let request = thread.read_with(cx, |thread, cx| { + thread.to_completion_request(RequestKind::Chat, cx) + }); + + assert_eq!(request.messages.len(), 1); + let expected_full_message = format!("{}Please explain this code", expected_context); + assert_eq!(request.messages[0].string_contents(), expected_full_message); + } + + #[gpui::test] + async fn test_only_include_new_contexts(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({ + "file1.rs": "fn function1() {}\n", + "file2.rs": "fn function2() {}\n", + "file3.rs": "fn function3() {}\n", + }), + ) + .await; + + let (_, _thread_store, thread, context_store) = + setup_test_environment(cx, project.clone()).await; + + // Open files individually + add_file_to_context(&project, &context_store, "test/file1.rs", cx) + .await + .unwrap(); + add_file_to_context(&project, &context_store, "test/file2.rs", cx) + .await + .unwrap(); + add_file_to_context(&project, &context_store, "test/file3.rs", cx) + .await + .unwrap(); + + // Get the context objects + let contexts = context_store.update(cx, |store, _| store.context().clone()); + assert_eq!(contexts.len(), 3); + + // First message with context 1 + let message1_id = thread.update(cx, |thread, cx| { + thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx) + }); + + // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included) + let message2_id = thread.update(cx, |thread, cx| { + thread.insert_user_message( + "Message 2", + vec![contexts[0].clone(), contexts[1].clone()], + None, + cx, + ) + }); + + // Third message with all three contexts (contexts 1 and 2 should be skipped) + let message3_id = thread.update(cx, |thread, cx| { + thread.insert_user_message( + "Message 3", + vec![ + contexts[0].clone(), + contexts[1].clone(), + contexts[2].clone(), + ], + None, + cx, + ) + }); + + // Check what contexts are included in each message + let (message1, message2, message3) = thread.read_with(cx, |thread, _| { + ( + thread.message(message1_id).unwrap().clone(), + thread.message(message2_id).unwrap().clone(), + thread.message(message3_id).unwrap().clone(), + ) + }); + + // First message should include context 1 + assert!(message1.context.contains("file1.rs")); + + // Second message should include only context 2 (not 1) + assert!(!message2.context.contains("file1.rs")); + assert!(message2.context.contains("file2.rs")); + + // Third message should include only context 3 (not 1 or 2) + assert!(!message3.context.contains("file1.rs")); + assert!(!message3.context.contains("file2.rs")); + assert!(message3.context.contains("file3.rs")); + + // Check entire request to make sure all contexts are properly included + let request = thread.read_with(cx, |thread, cx| { + thread.to_completion_request(RequestKind::Chat, cx) + }); + + // The request should contain all 3 messages + assert_eq!(request.messages.len(), 3); + + // Check that the contexts are properly formatted in each message + assert!(request.messages[0].string_contents().contains("file1.rs")); + assert!(!request.messages[0].string_contents().contains("file2.rs")); + assert!(!request.messages[0].string_contents().contains("file3.rs")); + + assert!(!request.messages[1].string_contents().contains("file1.rs")); + assert!(request.messages[1].string_contents().contains("file2.rs")); + assert!(!request.messages[1].string_contents().contains("file3.rs")); + + assert!(!request.messages[2].string_contents().contains("file1.rs")); + assert!(!request.messages[2].string_contents().contains("file2.rs")); + assert!(request.messages[2].string_contents().contains("file3.rs")); + } + + #[gpui::test] + async fn test_message_without_files(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_, _thread_store, thread, _context_store) = + setup_test_environment(cx, project.clone()).await; + + // Insert user message without any context (empty context vector) + let message_id = thread.update(cx, |thread, cx| { + thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx) + }); + + // Check content and context in message object + let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); + + // Context should be empty when no files are included + assert_eq!(message.role, Role::User); + assert_eq!(message.segments.len(), 1); + assert_eq!( + message.segments[0], + MessageSegment::Text("What is the best way to learn Rust?".to_string()) + ); + assert_eq!(message.context, ""); + + // Check message in request + let request = thread.read_with(cx, |thread, cx| { + thread.to_completion_request(RequestKind::Chat, cx) + }); + + assert_eq!(request.messages.len(), 1); + assert_eq!( + request.messages[0].string_contents(), + "What is the best way to learn Rust?" + ); + + // Add second message, also without context + let message2_id = thread.update(cx, |thread, cx| { + thread.insert_user_message("Are there any good books?", vec![], None, cx) + }); + + let message2 = + thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone()); + assert_eq!(message2.context, ""); + + // Check that both messages appear in the request + let request = thread.read_with(cx, |thread, cx| { + thread.to_completion_request(RequestKind::Chat, cx) + }); + + assert_eq!(request.messages.len(), 2); + assert_eq!( + request.messages[0].string_contents(), + "What is the best way to learn Rust?" + ); + assert_eq!( + request.messages[1].string_contents(), + "Are there any good books?" + ); + } + + #[gpui::test] + async fn test_stale_buffer_notification(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, _thread_store, thread, context_store) = + setup_test_environment(cx, project.clone()).await; + + // Open buffer and add it to context + let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx) + .await + .unwrap(); + + let context = + context_store.update(cx, |store, _| store.context().first().cloned().unwrap()); + + // Insert user message with the buffer as context + thread.update(cx, |thread, cx| { + thread.insert_user_message("Explain this code", vec![context], None, cx) + }); + + // Create a request and check that it doesn't have a stale buffer warning yet + let initial_request = thread.read_with(cx, |thread, cx| { + thread.to_completion_request(RequestKind::Chat, cx) + }); + + // Make sure we don't have a stale file warning yet + let has_stale_warning = initial_request.messages.iter().any(|msg| { + msg.string_contents() + .contains("These files changed since last read:") + }); + assert!( + !has_stale_warning, + "Should not have stale buffer warning before buffer is modified" + ); + + // Modify the buffer + buffer.update(cx, |buffer, cx| { + // Find a position at the end of line 1 + buffer.edit( + [(1..1, "\n println!(\"Added a new line\");\n")], + None, + cx, + ); + }); + + // Insert another user message without context + thread.update(cx, |thread, cx| { + thread.insert_user_message("What does the code do now?", vec![], None, cx) + }); + + // Create a new request and check for the stale buffer warning + let new_request = thread.read_with(cx, |thread, cx| { + thread.to_completion_request(RequestKind::Chat, cx) + }); + + // We should have a stale file warning as the last message + let last_message = new_request + .messages + .last() + .expect("Request should have messages"); + + // The last message should be the stale buffer notification + assert_eq!(last_message.role, Role::User); + + // Check the exact content of the message + let expected_content = "These files changed since last read:\n- code.rs\n"; + assert_eq!( + last_message.string_contents(), + expected_content, + "Last message should be exactly the stale buffer notification" + ); + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AssistantSettings::register(cx); + thread_store::init(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + ContextServerSettings::register(cx); + EditorSettings::register(cx); + }); + } + + // Helper to create a test project with test files + async fn create_test_project( + cx: &mut TestAppContext, + files: serde_json::Value, + ) -> Entity { + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), files).await; + Project::test(fs, [path!("/test").as_ref()], cx).await + } + + async fn setup_test_environment( + cx: &mut TestAppContext, + project: Entity, + ) -> ( + Entity, + Entity, + Entity, + Entity, + ) { + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_store = cx.update(|_, cx| { + ThreadStore::new( + project.clone(), + Arc::default(), + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + .unwrap() + }); + + let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None)); + + (workspace, thread_store, thread, context_store) + } + + async fn add_file_to_context( + project: &Entity, + context_store: &Entity, + path: &str, + cx: &mut TestAppContext, + ) -> Result> { + let buffer_path = project + .read_with(cx, |project, cx| project.find_project_path(path, cx)) + .unwrap(); + + let buffer = project + .update(cx, |project, cx| project.open_buffer(buffer_path, cx)) + .await + .unwrap(); + + context_store + .update(cx, |store, cx| { + store.add_file_from_buffer(buffer.clone(), cx) + }) + .await?; + + Ok(buffer) + } +} diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index bd976b104b..2a838d2866 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -374,6 +374,8 @@ pub struct SerializedMessage { pub tool_uses: Vec, #[serde(default)] pub tool_results: Vec, + #[serde(default)] + pub context: String, } #[derive(Debug, Serialize, Deserialize)] @@ -441,6 +443,7 @@ impl LegacySerializedMessage { segments: vec![SerializedMessageSegment::Text { text: self.text }], tool_uses: self.tool_uses, tool_results: self.tool_results, + context: String::new(), } } } diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs index 26ed900e47..4f82604d48 100644 --- a/crates/assistant_eval/src/headless_assistant.rs +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -158,7 +158,7 @@ impl HeadlessAssistant { let model_registry = LanguageModelRegistry::read_global(cx); if let Some(model) = model_registry.active_model() { thread.update(cx, |thread, cx| { - thread.attach_tool_results(vec![], cx); + thread.attach_tool_results(cx); thread.send_to_model(model, RequestKind::Chat, cx); }); } else { diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index d754fe284e..b8dd354b09 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; use buffer_diff::BufferDiff; -use collections::{BTreeMap, HashSet}; +use collections::BTreeMap; use futures::{StreamExt, channel::mpsc}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity}; use language::{Anchor, Buffer, BufferEvent, DiskState, Point}; @@ -10,9 +10,6 @@ use util::RangeExt; /// Tracks actions performed by tools in a thread pub struct ActionLog { - /// Buffers that user manually added to the context, and whose content has - /// changed since the model last saw them. - stale_buffers_in_context: HashSet>, /// Buffers that we want to notify the model about when they change. tracked_buffers: BTreeMap, TrackedBuffer>, /// Has the model edited a file since it last checked diagnostics? @@ -23,7 +20,6 @@ impl ActionLog { /// Creates a new, empty action log. pub fn new() -> Self { Self { - stale_buffers_in_context: HashSet::default(), tracked_buffers: BTreeMap::default(), edited_since_project_diagnostics_check: false, } @@ -259,6 +255,11 @@ impl ActionLog { self.track_buffer(buffer, false, cx); } + /// Track a buffer that was added as context, so we can notify the model about user edits. + pub fn buffer_added_as_context(&mut self, buffer: Entity, cx: &mut Context) { + self.track_buffer(buffer, false, cx); + } + /// Track a buffer as read, so we can notify the model about user edits. pub fn will_create_buffer(&mut self, buffer: Entity, cx: &mut Context) { self.track_buffer(buffer.clone(), true, cx); @@ -268,7 +269,6 @@ impl ActionLog { /// Mark a buffer as edited, so we can refresh it in the context pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { self.edited_since_project_diagnostics_check = true; - self.stale_buffers_in_context.insert(buffer.clone()); let tracked_buffer = self.track_buffer(buffer.clone(), false, cx); if let TrackedBufferStatus::Deleted = tracked_buffer.status { @@ -391,11 +391,6 @@ impl ActionLog { }) .map(|(buffer, _)| buffer) } - - /// Takes and returns the set of buffers pending refresh, clearing internal state. - pub fn take_stale_buffers_in_context(&mut self) -> HashSet> { - std::mem::take(&mut self.stale_buffers_in_context) - } } fn apply_non_conflicting_edits(