diff --git a/Cargo.lock b/Cargo.lock index 8e64dde937..13ad82070a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -462,6 +462,7 @@ dependencies = [ "async-watch", "chrono", "client", + "clock", "collections", "command_palette_hooks", "context_server", @@ -476,6 +477,7 @@ dependencies = [ "html_to_markdown", "http_client", "indoc", + "itertools 0.13.0", "language", "language_model", "language_model_selector", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index c67674b437..8b52a946a7 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -19,6 +19,7 @@ assets.workspace = true assistant_tool.workspace = true async-watch.workspace = true client.workspace = true +clock.workspace = true chrono.workspace = true collections.workspace = true command_palette_hooks.workspace = true @@ -33,6 +34,7 @@ gpui.workspace = true handlebars.workspace = true html_to_markdown.workspace = true http_client.workspace = true +itertools.workspace = true language.workspace = true language_model.workspace = true language_model_selector.workspace = true diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 75da3bb11f..bcbc1e8431 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -282,11 +282,13 @@ impl ActiveThread { .child(div().p_2p5().text_ui(cx).child(markdown.clone())) .when_some(context, |parent, context| { if !context.is_empty() { - parent.child(h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children( - context.iter().map(|context| { - ContextPill::new_added(context.clone(), false, None) - }), - )) + parent.child( + h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children( + context.into_iter().map(|context| { + ContextPill::new_added(context, false, None) + }), + ), + ) } else { parent } diff --git a/crates/assistant2/src/buffer_codegen.rs b/crates/assistant2/src/buffer_codegen.rs index aacbcccddf..90d830dea3 100644 --- a/crates/assistant2/src/buffer_codegen.rs +++ b/crates/assistant2/src/buffer_codegen.rs @@ -421,8 +421,7 @@ impl CodegenAlternative { }; if let Some(context_store) = &self.context_store { - let context = context_store.update(cx, |this, _cx| this.context().clone()); - attach_context_to_message(&mut request_message, context); + attach_context_to_message(&mut request_message, context_store.read(cx).snapshot(cx)); } request_message.content.push(prompt.into()); @@ -1053,7 +1052,7 @@ mod tests { stream::{self}, Stream, }; - use gpui::{Context, TestAppContext}; + use gpui::TestAppContext; use indoc::indoc; use language::{ language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, diff --git a/crates/assistant2/src/context.rs b/crates/assistant2/src/context.rs index 0620707e47..55a0b71ba7 100644 --- a/crates/assistant2/src/context.rs +++ b/crates/assistant2/src/context.rs @@ -1,8 +1,17 @@ -use gpui::SharedString; +use std::path::Path; +use std::rc::Rc; +use std::sync::Arc; + +use collections::BTreeMap; +use gpui::{AppContext, Model, SharedString}; +use language::Buffer; use language_model::{LanguageModelRequestMessage, MessageContent}; use serde::{Deserialize, Serialize}; +use text::BufferId; use util::post_inc; +use crate::thread::Thread; + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] pub struct ContextId(pub(crate) usize); @@ -14,16 +23,17 @@ impl ContextId { /// Some context attached to a message in a thread. #[derive(Debug, Clone)] -pub struct Context { +pub struct ContextSnapshot { pub id: ContextId, pub name: SharedString, pub parent: Option, pub tooltip: Option, pub kind: ContextKind, + /// Text to send to the model. This is not refreshed by `snapshot`. pub text: SharedString, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ContextKind { File, Directory, @@ -31,18 +41,139 @@ pub enum ContextKind { Thread, } +#[derive(Debug)] +pub enum Context { + File(FileContext), + Directory(DirectoryContext), + FetchedUrl(FetchedUrlContext), + Thread(ThreadContext), +} + +impl Context { + pub fn id(&self) -> ContextId { + match self { + Self::File(file) => file.id, + Self::Directory(directory) => directory.snapshot.id, + Self::FetchedUrl(url) => url.id, + Self::Thread(thread) => thread.id, + } + } +} + +// TODO: Model holds onto the buffer even if the file is deleted and closed. Should remove +// the context from the message editor in this case. + +#[derive(Debug)] +pub struct FileContext { + pub id: ContextId, + pub buffer: Model, + #[allow(unused)] + pub version: clock::Global, + pub text: SharedString, +} + +#[derive(Debug)] +pub struct DirectoryContext { + #[allow(unused)] + pub path: Rc, + // TODO: The choice to make this a BTreeMap was a result of use in a version of + // ContextStore::will_include_buffer before I realized that the path logic should be used there + // too. + #[allow(unused)] + pub buffers: BTreeMap, clock::Global)>, + pub snapshot: ContextSnapshot, +} + +#[derive(Debug)] +pub struct FetchedUrlContext { + pub id: ContextId, + pub url: SharedString, + pub text: SharedString, +} + +// TODO: Model holds onto the thread even if the thread is deleted. Can either handle this +// explicitly or have a WeakModel and remove during snapshot. + +#[derive(Debug)] +pub struct ThreadContext { + pub id: ContextId, + pub thread: Model, + pub text: SharedString, +} + +impl Context { + pub fn snapshot(&self, cx: &AppContext) -> Option { + match &self { + Self::File(file_context) => { + let path = file_context.path(cx)?; + let full_path: SharedString = path.to_string_lossy().into_owned().into(); + let name = match path.file_name() { + Some(name) => name.to_string_lossy().into_owned().into(), + None => full_path.clone(), + }; + let parent = path + .parent() + .and_then(|p| p.file_name()) + .map(|p| p.to_string_lossy().into_owned().into()); + + Some(ContextSnapshot { + id: self.id(), + name, + parent, + tooltip: Some(full_path), + kind: ContextKind::File, + text: file_context.text.clone(), + }) + } + Self::Directory(DirectoryContext { snapshot, .. }) => Some(snapshot.clone()), + Self::FetchedUrl(FetchedUrlContext { url, text, id }) => Some(ContextSnapshot { + id: *id, + name: url.clone(), + parent: None, + tooltip: None, + kind: ContextKind::FetchedUrl, + text: text.clone(), + }), + Self::Thread(thread_context) => { + let thread = thread_context.thread.read(cx); + + Some(ContextSnapshot { + id: self.id(), + name: thread.summary().unwrap_or("New thread".into()), + parent: None, + tooltip: None, + kind: ContextKind::Thread, + text: thread_context.text.clone(), + }) + } + } + } +} + +impl FileContext { + pub fn path(&self, cx: &AppContext) -> Option> { + let buffer = self.buffer.read(cx); + if let Some(file) = buffer.file() { + Some(file.path().clone()) + } else { + log::error!("Buffer that had a path unexpectedly no longer has a path."); + None + } + } +} + pub fn attach_context_to_message( message: &mut LanguageModelRequestMessage, - context: impl IntoIterator, + contexts: impl Iterator, ) { let mut file_context = String::new(); let mut directory_context = String::new(); let mut fetch_context = String::new(); let mut thread_context = String::new(); - for context in context.into_iter() { + for context in contexts { match context.kind { - ContextKind::File { .. } => { + ContextKind::File => { file_context.push_str(&context.text); file_context.push('\n'); } @@ -56,7 +187,7 @@ pub fn attach_context_to_message( fetch_context.push_str(&context.text); fetch_context.push('\n'); } - ContextKind::Thread => { + ContextKind::Thread { .. } => { thread_context.push_str(&context.name); thread_context.push('\n'); thread_context.push_str(&context.text); diff --git a/crates/assistant2/src/context_picker/directory_context_picker.rs b/crates/assistant2/src/context_picker/directory_context_picker.rs index a8c76a1554..2a6d0a7fa8 100644 --- a/crates/assistant2/src/context_picker/directory_context_picker.rs +++ b/crates/assistant2/src/context_picker/directory_context_picker.rs @@ -240,7 +240,7 @@ impl PickerDelegate for DirectoryContextPickerDelegate { let added = self.context_store.upgrade().map_or(false, |context_store| { context_store .read(cx) - .included_directory(&path_match.path) + .includes_directory(&path_match.path) .is_some() }); diff --git a/crates/assistant2/src/context_picker/fetch_context_picker.rs b/crates/assistant2/src/context_picker/fetch_context_picker.rs index 24bee6ea2f..6704150096 100644 --- a/crates/assistant2/src/context_picker/fetch_context_picker.rs +++ b/crates/assistant2/src/context_picker/fetch_context_picker.rs @@ -82,10 +82,12 @@ impl FetchContextPickerDelegate { } async fn build_message(http_client: Arc, url: &str) -> Result { - let mut url = url.to_owned(); - if !url.starts_with("https://") && !url.starts_with("http://") { - url = format!("https://{url}"); - } + let prefixed_url = if !url.starts_with("https://") && !url.starts_with("http://") { + Some(format!("https://{url}")) + } else { + None + }; + let url = prefixed_url.as_deref().unwrap_or(url); let mut response = http_client.get(&url, AsyncBody::default(), true).await?; @@ -200,7 +202,7 @@ impl PickerDelegate for FetchContextPickerDelegate { this.delegate .context_store .update(cx, |context_store, _cx| { - if context_store.included_url(&url).is_none() { + if context_store.includes_url(&url).is_none() { context_store.insert_fetched_url(url, text); } })?; @@ -234,7 +236,7 @@ impl PickerDelegate for FetchContextPickerDelegate { cx: &mut ViewContext>, ) -> Option { let added = self.context_store.upgrade().map_or(false, |context_store| { - context_store.read(cx).included_url(&self.url).is_some() + context_store.read(cx).includes_url(&self.url).is_some() }); Some( diff --git a/crates/assistant2/src/context_picker/file_context_picker.rs b/crates/assistant2/src/context_picker/file_context_picker.rs index 8b75433569..288483bf2f 100644 --- a/crates/assistant2/src/context_picker/file_context_picker.rs +++ b/crates/assistant2/src/context_picker/file_context_picker.rs @@ -11,7 +11,7 @@ use util::ResultExt as _; use workspace::Workspace; use crate::context_picker::{ConfirmBehavior, ContextPicker}; -use crate::context_store::{ContextStore, IncludedFile}; +use crate::context_store::{ContextStore, FileInclusion}; pub struct FileContextPicker { picker: View>, @@ -275,10 +275,11 @@ impl PickerDelegate for FileContextPickerDelegate { (file_name, Some(directory)) }; - let added = self - .context_store - .upgrade() - .and_then(|context_store| context_store.read(cx).included_file(&path_match.path)); + let added = self.context_store.upgrade().and_then(|context_store| { + context_store + .read(cx) + .will_include_file_path(&path_match.path, cx) + }); Some( ListItem::new(ix) @@ -295,7 +296,7 @@ impl PickerDelegate for FileContextPickerDelegate { })), ) .when_some(added, |el, added| match added { - IncludedFile::Direct(_) => el.end_slot( + FileInclusion::Direct(_) => el.end_slot( h_flex() .gap_1() .child( @@ -305,7 +306,7 @@ impl PickerDelegate for FileContextPickerDelegate { ) .child(Label::new("Added").size(LabelSize::Small)), ), - IncludedFile::InDirectory(dir_name) => { + FileInclusion::InDirectory(dir_name) => { let dir_name = dir_name.to_string_lossy().into_owned(); el.end_slot( diff --git a/crates/assistant2/src/context_picker/thread_context_picker.rs b/crates/assistant2/src/context_picker/thread_context_picker.rs index 86578cb8f7..db09082bda 100644 --- a/crates/assistant2/src/context_picker/thread_context_picker.rs +++ b/crates/assistant2/src/context_picker/thread_context_picker.rs @@ -194,7 +194,7 @@ impl PickerDelegate for ThreadContextPickerDelegate { let thread = &self.matches[ix]; let added = self.context_store.upgrade().map_or(false, |context_store| { - context_store.read(cx).included_thread(&thread.id).is_some() + context_store.read(cx).includes_thread(&thread.id).is_some() }); Some( diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index 3ae7bfb9b4..84b4d0b97b 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -3,23 +3,25 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use anyhow::{anyhow, bail, Result}; -use collections::{HashMap, HashSet}; -use gpui::{Model, ModelContext, SharedString, Task, WeakView}; +use collections::{BTreeMap, HashMap}; +use gpui::{AppContext, Model, ModelContext, SharedString, Task, WeakView}; use language::Buffer; use project::{ProjectPath, Worktree}; +use text::BufferId; use workspace::Workspace; -use crate::thread::Thread; -use crate::{ - context::{Context, ContextId, ContextKind}, - thread::ThreadId, +use crate::context::{ + Context, ContextId, ContextKind, ContextSnapshot, DirectoryContext, FetchedUrlContext, + FileContext, ThreadContext, }; +use crate::thread::{Thread, ThreadId}; pub struct ContextStore { workspace: WeakView, context: Vec, + // TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId. next_context_id: ContextId, - files: HashMap, + files: BTreeMap, directories: HashMap, threads: HashMap, fetched_urls: HashMap, @@ -31,13 +33,22 @@ impl ContextStore { workspace, context: Vec::new(), next_context_id: ContextId(0), - files: HashMap::default(), + files: BTreeMap::default(), directories: HashMap::default(), threads: HashMap::default(), fetched_urls: HashMap::default(), } } + pub fn snapshot<'a>( + &'a self, + cx: &'a AppContext, + ) -> impl Iterator + 'a { + self.context() + .iter() + .flat_map(|context| context.snapshot(cx)) + } + pub fn context(&self) -> &Vec { &self.context } @@ -63,64 +74,54 @@ impl ContextStore { return Task::ready(Err(anyhow!("failed to read project"))); }; - let already_included = match self.included_file(&project_path.path) { - Some(IncludedFile::Direct(context_id)) => { - self.remove_context(&context_id); - true - } - Some(IncludedFile::InDirectory(_)) => true, - None => false, - }; - if already_included { - return Task::ready(Ok(())); - } - cx.spawn(|this, mut cx| async move { - let open_buffer_task = - project.update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?; + let open_buffer_task = project.update(&mut cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + })?; let buffer = open_buffer_task.await?; + let buffer_id = buffer.update(&mut cx, |buffer, _cx| buffer.remote_id())?; + + let already_included = this.update(&mut cx, |this, _cx| { + match this.will_include_buffer(buffer_id, &project_path.path) { + Some(FileInclusion::Direct(context_id)) => { + this.remove_context(context_id); + true + } + Some(FileInclusion::InDirectory(_)) => true, + None => false, + } + })?; + + if already_included { + return anyhow::Ok(()); + } + this.update(&mut cx, |this, cx| { - this.insert_file(buffer.read(cx)); + this.insert_file(buffer, cx); })?; anyhow::Ok(()) }) } - pub fn insert_file(&mut self, buffer: &Buffer) { + pub fn insert_file(&mut self, buffer_model: Model, cx: &AppContext) { + let buffer = buffer_model.read(cx); let Some(file) = buffer.file() else { return; }; - let path = file.path(); + let mut text = String::new(); + push_fenced_codeblock(file.path(), buffer.text(), &mut text); let id = self.next_context_id.post_inc(); - self.files.insert(path.to_path_buf(), id); - - let full_path: SharedString = path.to_string_lossy().into_owned().into(); - - let name = match path.file_name() { - Some(name) => name.to_string_lossy().into_owned().into(), - None => full_path.clone(), - }; - - let parent = path - .parent() - .and_then(|p| p.file_name()) - .map(|p| p.to_string_lossy().into_owned().into()); - - let mut text = String::new(); - push_fenced_codeblock(path, buffer.text(), &mut text); - - self.context.push(Context { + self.files.insert(buffer.remote_id(), id); + self.context.push(Context::File(FileContext { id, - name, - parent, - tooltip: Some(full_path), - kind: ContextKind::File, + buffer: buffer_model, + version: buffer.version.clone(), text: text.into(), - }); + })); } pub fn add_directory( @@ -136,9 +137,9 @@ impl ContextStore { return Task::ready(Err(anyhow!("failed to read project"))); }; - let already_included = if let Some(context_id) = self.included_directory(&project_path.path) + let already_included = if let Some(context_id) = self.includes_directory(&project_path.path) { - self.remove_context(&context_id); + self.remove_context(context_id); true } else { false @@ -178,23 +179,24 @@ impl ContextStore { this.update(&mut cx, |this, cx| { let mut text = String::new(); - let mut added_files = 0; - - for buffer in buffers.into_iter().flatten() { - let buffer = buffer.read(cx); + let mut directory_buffers = BTreeMap::new(); + for buffer_model in buffers { + let buffer_model = buffer_model?; + let buffer = buffer_model.read(cx); let path = buffer.file().map_or(&project_path.path, |file| file.path()); push_fenced_codeblock(&path, buffer.text(), &mut text); - added_files += 1; + directory_buffers + .insert(buffer.remote_id(), (buffer_model, buffer.version.clone())); } - if added_files == 0 { + if directory_buffers.is_empty() { bail!( "could not read any text files from {}", &project_path.path.display() ); } - this.insert_directory(&project_path.path, text); + this.insert_directory(&project_path.path, directory_buffers, text); anyhow::Ok(()) })??; @@ -203,7 +205,12 @@ impl ContextStore { }) } - pub fn insert_directory(&mut self, path: &Path, text: impl Into) { + pub fn insert_directory( + &mut self, + path: &Path, + buffers: BTreeMap, clock::Global)>, + text: impl Into, + ) { let id = self.next_context_id.post_inc(); self.directories.insert(path.to_path_buf(), id); @@ -219,78 +226,104 @@ impl ContextStore { .and_then(|p| p.file_name()) .map(|p| p.to_string_lossy().into_owned().into()); - self.context.push(Context { - id, - name, - parent, - tooltip: Some(full_path), - kind: ContextKind::Directory, - text: text.into(), - }); + self.context.push(Context::Directory(DirectoryContext { + path: path.into(), + buffers, + snapshot: ContextSnapshot { + id, + name, + parent, + tooltip: Some(full_path), + kind: ContextKind::Directory, + text: text.into(), + }, + })); } pub fn add_thread(&mut self, thread: Model, cx: &mut ModelContext) { - if let Some(context_id) = self.included_thread(&thread.read(cx).id()) { - self.remove_context(&context_id); + if let Some(context_id) = self.includes_thread(&thread.read(cx).id()) { + self.remove_context(context_id); } else { - self.insert_thread(thread.read(cx)); + self.insert_thread(thread, cx); } } - pub fn insert_thread(&mut self, thread: &Thread) { - let context_id = self.next_context_id.post_inc(); - self.threads.insert(thread.id().clone(), context_id); + pub fn insert_thread(&mut self, thread: Model, cx: &AppContext) { + let id = self.next_context_id.post_inc(); + let thread_ref = thread.read(cx); + let text = thread_ref.text().into(); - self.context.push(Context { - id: context_id, - name: thread.summary().unwrap_or("New thread".into()), - parent: None, - tooltip: None, - kind: ContextKind::Thread, - text: thread.text().into(), - }); + self.threads.insert(thread_ref.id().clone(), id); + self.context + .push(Context::Thread(ThreadContext { id, thread, text })); } pub fn insert_fetched_url(&mut self, url: String, text: impl Into) { - let context_id = self.next_context_id.post_inc(); - self.fetched_urls.insert(url.clone(), context_id); + let id = self.next_context_id.post_inc(); - self.context.push(Context { - id: context_id, - name: url.into(), - parent: None, - tooltip: None, - kind: ContextKind::FetchedUrl, + self.fetched_urls.insert(url.clone(), id); + self.context.push(Context::FetchedUrl(FetchedUrlContext { + id, + url: url.into(), text: text.into(), - }); + })); } - pub fn remove_context(&mut self, id: &ContextId) { - let Some(ix) = self.context.iter().position(|context| context.id == *id) else { + pub fn remove_context(&mut self, id: ContextId) { + let Some(ix) = self.context.iter().position(|context| context.id() == id) else { return; }; - match self.context.remove(ix).kind { - ContextKind::File => { - self.files.retain(|_, context_id| context_id != id); + match self.context.remove(ix) { + Context::File(_) => { + self.files.retain(|_, context_id| *context_id != id); } - ContextKind::Directory => { - self.directories.retain(|_, context_id| context_id != id); + Context::Directory(_) => { + self.directories.retain(|_, context_id| *context_id != id); } - ContextKind::FetchedUrl => { - self.fetched_urls.retain(|_, context_id| context_id != id); + Context::FetchedUrl(_) => { + self.fetched_urls.retain(|_, context_id| *context_id != id); } - ContextKind::Thread => { - self.threads.retain(|_, context_id| context_id != id); + Context::Thread(_) => { + self.threads.retain(|_, context_id| *context_id != id); } } } - pub fn included_file(&self, path: &Path) -> Option { - if let Some(id) = self.files.get(path) { - return Some(IncludedFile::Direct(*id)); + /// Returns whether the buffer is already included directly in the context, or if it will be + /// included in the context via a directory. Directory inclusion is based on paths rather than + /// buffer IDs as the directory will be re-scanned. + pub fn will_include_buffer(&self, buffer_id: BufferId, path: &Path) -> Option { + if let Some(context_id) = self.files.get(&buffer_id) { + return Some(FileInclusion::Direct(*context_id)); } + self.will_include_file_path_via_directory(path) + } + + /// Returns whether this file path is already included directly in the context, or if it will be + /// included in the context via a directory. + pub fn will_include_file_path(&self, path: &Path, cx: &AppContext) -> Option { + if !self.files.is_empty() { + let found_file_context = self.context.iter().find(|context| match &context { + Context::File(file_context) => { + if let Some(file_path) = file_context.path(cx) { + *file_path == *path + } else { + false + } + } + _ => false, + }); + if let Some(context) = found_file_context { + return Some(FileInclusion::Direct(context.id())); + } + } + + self.will_include_file_path_via_directory(path) + } + + fn will_include_file_path_via_directory(&self, path: &Path) -> Option { if self.directories.is_empty() { return None; } @@ -299,40 +332,27 @@ impl ContextStore { while buf.pop() { if let Some(_) = self.directories.get(&buf) { - return Some(IncludedFile::InDirectory(buf)); + return Some(FileInclusion::InDirectory(buf)); } } None } - pub fn included_directory(&self, path: &Path) -> Option { + pub fn includes_directory(&self, path: &Path) -> Option { self.directories.get(path).copied() } - pub fn included_thread(&self, thread_id: &ThreadId) -> Option { + pub fn includes_thread(&self, thread_id: &ThreadId) -> Option { self.threads.get(thread_id).copied() } - pub fn included_url(&self, url: &str) -> Option { + pub fn includes_url(&self, url: &str) -> Option { self.fetched_urls.get(url).copied() } - - pub fn duplicated_names(&self) -> HashSet { - let mut seen = HashSet::default(); - let mut dupes = HashSet::default(); - - for context in self.context().iter() { - if !seen.insert(&context.name) { - dupes.insert(context.name.clone()); - } - } - - dupes - } } -pub enum IncludedFile { +pub enum FileInclusion { Direct(ContextId), InDirectory(PathBuf), } diff --git a/crates/assistant2/src/context_strip.rs b/crates/assistant2/src/context_strip.rs index ded54edc82..f62d799169 100644 --- a/crates/assistant2/src/context_strip.rs +++ b/crates/assistant2/src/context_strip.rs @@ -1,10 +1,12 @@ use std::rc::Rc; +use collections::HashSet; use editor::Editor; use gpui::{ AppContext, DismissEvent, EventEmitter, FocusHandle, Model, Subscription, View, WeakModel, WeakView, }; +use itertools::Itertools; use language::Buffer; use ui::{prelude::*, KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip}; use workspace::Workspace; @@ -73,11 +75,17 @@ impl ContextStrip { let active_item = workspace.read(cx).active_item(cx)?; let editor = active_item.to_any().downcast::().ok()?.read(cx); - let active_buffer = editor.buffer().read(cx).as_singleton()?; + let active_buffer_model = editor.buffer().read(cx).as_singleton()?; + let active_buffer = active_buffer_model.read(cx); - let path = active_buffer.read(cx).file()?.path(); + let path = active_buffer.file()?.path(); - if self.context_store.read(cx).included_file(path).is_some() { + if self + .context_store + .read(cx) + .will_include_buffer(active_buffer.remote_id(), path) + .is_some() + { return None; } @@ -88,7 +96,7 @@ impl ContextStrip { Some(SuggestedContext::File { name, - buffer: active_buffer.downgrade(), + buffer: active_buffer_model.downgrade(), }) } @@ -106,7 +114,7 @@ impl ContextStrip { if self .context_store .read(cx) - .included_thread(active_thread.id()) + .includes_thread(active_thread.id()) .is_some() { return None; @@ -131,13 +139,24 @@ impl ContextStrip { impl Render for ContextStrip { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { let context_store = self.context_store.read(cx); - let context = context_store.context().clone(); + let context = context_store + .context() + .iter() + .flat_map(|context| context.snapshot(cx)) + .collect::>(); let context_picker = self.context_picker.clone(); let focus_handle = self.focus_handle.clone(); let suggested_context = self.suggested_context(cx); - let dupe_names = context_store.duplicated_names(); + let dupe_names = context + .iter() + .map(|context| context.name.clone()) + .sorted() + .tuple_windows() + .filter(|(a, b)| a == b) + .map(|(a, _)| a) + .collect::>(); h_flex() .flex_wrap() @@ -194,11 +213,11 @@ impl Render for ContextStrip { context.clone(), dupe_names.contains(&context.name), Some({ - let context = context.clone(); + let id = context.id; let context_store = self.context_store.clone(); Rc::new(cx.listener(move |_this, _event, cx| { context_store.update(cx, |this, _cx| { - this.remove_context(&context.id); + this.remove_context(id); }); cx.notify(); })) @@ -284,12 +303,12 @@ impl SuggestedContext { match self { Self::File { buffer, name: _ } => { if let Some(buffer) = buffer.upgrade() { - context_store.insert_file(buffer.read(cx)); + context_store.insert_file(buffer, cx); }; } Self::Thread { thread, name: _ } => { if let Some(thread) = thread.upgrade() { - context_store.insert_thread(thread.read(cx)); + context_store.insert_thread(thread, cx); }; } } diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 5621032b8f..e9a7b4fc8e 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -147,11 +147,10 @@ impl MessageEditor { editor.clear(cx); text }); - let context = self - .context_store - .update(cx, |this, _cx| this.context().clone()); - self.thread.update(cx, |thread, cx| { + let thread = self.thread.clone(); + thread.update(cx, |thread, cx| { + let context = self.context_store.read(cx).snapshot(cx).collect::>(); thread.insert_user_message(user_message, context, cx); let mut request = thread.to_completion_request(request_kind, cx); diff --git a/crates/assistant2/src/terminal_inline_assistant.rs b/crates/assistant2/src/terminal_inline_assistant.rs index d3773de0a2..1b75dc2c3e 100644 --- a/crates/assistant2/src/terminal_inline_assistant.rs +++ b/crates/assistant2/src/terminal_inline_assistant.rs @@ -245,10 +245,10 @@ impl TerminalInlineAssistant { cache: false, }; - let context = assist - .context_store - .update(cx, |this, _cx| this.context().clone()); - attach_context_to_message(&mut request_message, context); + attach_context_to_message( + &mut request_message, + assist.context_store.read(cx).snapshot(cx), + ); request_message.content.push(prompt.into()); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index f3de36f8a7..b2949680c4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Result; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; -use collections::{HashMap, HashSet}; +use collections::{BTreeMap, HashMap, HashSet}; use futures::future::Shared; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task}; @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; -use crate::context::{attach_context_to_message, Context, ContextId}; +use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -64,7 +64,7 @@ pub struct Thread { pending_summary: Task>, messages: Vec, next_message_id: MessageId, - context: HashMap, + context: BTreeMap, context_by_message: HashMap>, completion_count: usize, pending_completions: Vec, @@ -83,7 +83,7 @@ impl Thread { pending_summary: Task::ready(None), messages: Vec::new(), next_message_id: MessageId(0), - context: HashMap::default(), + context: BTreeMap::default(), context_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -131,7 +131,7 @@ impl Thread { &self.tools } - pub fn context_for_message(&self, id: MessageId) -> Option> { + pub fn context_for_message(&self, id: MessageId) -> Option> { let context = self.context_by_message.get(&id)?; Some( context @@ -149,7 +149,7 @@ impl Thread { pub fn insert_user_message( &mut self, text: impl Into, - context: Vec, + context: Vec, cx: &mut ModelContext, ) { let message_id = self.insert_message(Role::User, text, cx); diff --git a/crates/assistant2/src/ui/context_pill.rs b/crates/assistant2/src/ui/context_pill.rs index c6d2d3e7b0..e70169ed5f 100644 --- a/crates/assistant2/src/ui/context_pill.rs +++ b/crates/assistant2/src/ui/context_pill.rs @@ -3,12 +3,12 @@ use std::rc::Rc; use gpui::ClickEvent; use ui::{prelude::*, IconButtonShape, Tooltip}; -use crate::context::{Context, ContextKind}; +use crate::context::{ContextKind, ContextSnapshot}; #[derive(IntoElement)] pub enum ContextPill { Added { - context: Context, + context: ContextSnapshot, dupe_name: bool, on_remove: Option>, }, @@ -21,7 +21,7 @@ pub enum ContextPill { impl ContextPill { pub fn new_added( - context: Context, + context: ContextSnapshot, dupe_name: bool, on_remove: Option>, ) -> Self { @@ -49,10 +49,10 @@ impl ContextPill { } } - pub fn kind(&self) -> &ContextKind { + pub fn kind(&self) -> ContextKind { match self { - Self::Added { context, .. } => &context.kind, - Self::Suggested { kind, .. } => kind, + Self::Added { context, .. } => context.kind, + Self::Suggested { kind, .. } => *kind, } } }