assistant2: Add live context type and use in message editor (#22865)

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Michael Sloan 2025-01-08 14:47:58 -07:00 committed by GitHub
parent 5d8ef94c86
commit a0fca24e3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 359 additions and 182 deletions

2
Cargo.lock generated
View file

@ -462,6 +462,7 @@ dependencies = [
"async-watch", "async-watch",
"chrono", "chrono",
"client", "client",
"clock",
"collections", "collections",
"command_palette_hooks", "command_palette_hooks",
"context_server", "context_server",
@ -476,6 +477,7 @@ dependencies = [
"html_to_markdown", "html_to_markdown",
"http_client", "http_client",
"indoc", "indoc",
"itertools 0.13.0",
"language", "language",
"language_model", "language_model",
"language_model_selector", "language_model_selector",

View file

@ -19,6 +19,7 @@ assets.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true
async-watch.workspace = true async-watch.workspace = true
client.workspace = true client.workspace = true
clock.workspace = true
chrono.workspace = true chrono.workspace = true
collections.workspace = true collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
@ -33,6 +34,7 @@ gpui.workspace = true
handlebars.workspace = true handlebars.workspace = true
html_to_markdown.workspace = true html_to_markdown.workspace = true
http_client.workspace = true http_client.workspace = true
itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true language_model.workspace = true
language_model_selector.workspace = true language_model_selector.workspace = true

View file

@ -282,11 +282,13 @@ impl ActiveThread {
.child(div().p_2p5().text_ui(cx).child(markdown.clone())) .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
.when_some(context, |parent, context| { .when_some(context, |parent, context| {
if !context.is_empty() { if !context.is_empty() {
parent.child(h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children( parent.child(
context.iter().map(|context| { h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
ContextPill::new_added(context.clone(), false, None) context.into_iter().map(|context| {
}), ContextPill::new_added(context, false, None)
)) }),
),
)
} else { } else {
parent parent
} }

View file

@ -421,8 +421,7 @@ impl CodegenAlternative {
}; };
if let Some(context_store) = &self.context_store { if let Some(context_store) = &self.context_store {
let context = context_store.update(cx, |this, _cx| this.context().clone()); attach_context_to_message(&mut request_message, context_store.read(cx).snapshot(cx));
attach_context_to_message(&mut request_message, context);
} }
request_message.content.push(prompt.into()); request_message.content.push(prompt.into());
@ -1053,7 +1052,7 @@ mod tests {
stream::{self}, stream::{self},
Stream, Stream,
}; };
use gpui::{Context, TestAppContext}; use gpui::TestAppContext;
use indoc::indoc; use indoc::indoc;
use language::{ use language::{
language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,

View file

@ -1,8 +1,17 @@
use gpui::SharedString; use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use collections::BTreeMap;
use gpui::{AppContext, Model, SharedString};
use language::Buffer;
use language_model::{LanguageModelRequestMessage, MessageContent}; use language_model::{LanguageModelRequestMessage, MessageContent};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use text::BufferId;
use util::post_inc; use util::post_inc;
use crate::thread::Thread;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize); pub struct ContextId(pub(crate) usize);
@ -14,16 +23,17 @@ impl ContextId {
/// Some context attached to a message in a thread. /// Some context attached to a message in a thread.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Context { pub struct ContextSnapshot {
pub id: ContextId, pub id: ContextId,
pub name: SharedString, pub name: SharedString,
pub parent: Option<SharedString>, pub parent: Option<SharedString>,
pub tooltip: Option<SharedString>, pub tooltip: Option<SharedString>,
pub kind: ContextKind, pub kind: ContextKind,
/// Text to send to the model. This is not refreshed by `snapshot`.
pub text: SharedString, pub text: SharedString,
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextKind { pub enum ContextKind {
File, File,
Directory, Directory,
@ -31,18 +41,139 @@ pub enum ContextKind {
Thread, Thread,
} }
#[derive(Debug)]
pub enum Context {
File(FileContext),
Directory(DirectoryContext),
FetchedUrl(FetchedUrlContext),
Thread(ThreadContext),
}
impl Context {
pub fn id(&self) -> ContextId {
match self {
Self::File(file) => file.id,
Self::Directory(directory) => directory.snapshot.id,
Self::FetchedUrl(url) => url.id,
Self::Thread(thread) => thread.id,
}
}
}
// TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove
// the context from the message editor in this case.
#[derive(Debug)]
pub struct FileContext {
pub id: ContextId,
pub buffer: Model<Buffer>,
#[allow(unused)]
pub version: clock::Global,
pub text: SharedString,
}
#[derive(Debug)]
pub struct DirectoryContext {
#[allow(unused)]
pub path: Rc<Path>,
// TODO: The choice to make this a BTreeMap was a result of use in a version of
// ContextStore::will_include_buffer before I realized that the path logic should be used there
// too.
#[allow(unused)]
pub buffers: BTreeMap<BufferId, (Model<Buffer>, clock::Global)>,
pub snapshot: ContextSnapshot,
}
#[derive(Debug)]
pub struct FetchedUrlContext {
pub id: ContextId,
pub url: SharedString,
pub text: SharedString,
}
// TODO: Model<Thread> holds onto the thread even if the thread is deleted. Can either handle this
// explicitly or have a WeakModel<Thread> and remove during snapshot.
#[derive(Debug)]
pub struct ThreadContext {
pub id: ContextId,
pub thread: Model<Thread>,
pub text: SharedString,
}
impl Context {
pub fn snapshot(&self, cx: &AppContext) -> Option<ContextSnapshot> {
match &self {
Self::File(file_context) => {
let path = file_context.path(cx)?;
let full_path: SharedString = path.to_string_lossy().into_owned().into();
let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => full_path.clone(),
};
let parent = path
.parent()
.and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into());
Some(ContextSnapshot {
id: self.id(),
name,
parent,
tooltip: Some(full_path),
kind: ContextKind::File,
text: file_context.text.clone(),
})
}
Self::Directory(DirectoryContext { snapshot, .. }) => Some(snapshot.clone()),
Self::FetchedUrl(FetchedUrlContext { url, text, id }) => Some(ContextSnapshot {
id: *id,
name: url.clone(),
parent: None,
tooltip: None,
kind: ContextKind::FetchedUrl,
text: text.clone(),
}),
Self::Thread(thread_context) => {
let thread = thread_context.thread.read(cx);
Some(ContextSnapshot {
id: self.id(),
name: thread.summary().unwrap_or("New thread".into()),
parent: None,
tooltip: None,
kind: ContextKind::Thread,
text: thread_context.text.clone(),
})
}
}
}
}
impl FileContext {
pub fn path(&self, cx: &AppContext) -> Option<Arc<Path>> {
let buffer = self.buffer.read(cx);
if let Some(file) = buffer.file() {
Some(file.path().clone())
} else {
log::error!("Buffer that had a path unexpectedly no longer has a path.");
None
}
}
}
pub fn attach_context_to_message( pub fn attach_context_to_message(
message: &mut LanguageModelRequestMessage, message: &mut LanguageModelRequestMessage,
context: impl IntoIterator<Item = Context>, contexts: impl Iterator<Item = ContextSnapshot>,
) { ) {
let mut file_context = String::new(); let mut file_context = String::new();
let mut directory_context = String::new(); let mut directory_context = String::new();
let mut fetch_context = String::new(); let mut fetch_context = String::new();
let mut thread_context = String::new(); let mut thread_context = String::new();
for context in context.into_iter() { for context in contexts {
match context.kind { match context.kind {
ContextKind::File { .. } => { ContextKind::File => {
file_context.push_str(&context.text); file_context.push_str(&context.text);
file_context.push('\n'); file_context.push('\n');
} }
@ -56,7 +187,7 @@ pub fn attach_context_to_message(
fetch_context.push_str(&context.text); fetch_context.push_str(&context.text);
fetch_context.push('\n'); fetch_context.push('\n');
} }
ContextKind::Thread => { ContextKind::Thread { .. } => {
thread_context.push_str(&context.name); thread_context.push_str(&context.name);
thread_context.push('\n'); thread_context.push('\n');
thread_context.push_str(&context.text); thread_context.push_str(&context.text);

View file

@ -240,7 +240,7 @@ impl PickerDelegate for DirectoryContextPickerDelegate {
let added = self.context_store.upgrade().map_or(false, |context_store| { let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store context_store
.read(cx) .read(cx)
.included_directory(&path_match.path) .includes_directory(&path_match.path)
.is_some() .is_some()
}); });

View file

@ -82,10 +82,12 @@ impl FetchContextPickerDelegate {
} }
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> { async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
let mut url = url.to_owned(); let prefixed_url = if !url.starts_with("https://") && !url.starts_with("http://") {
if !url.starts_with("https://") && !url.starts_with("http://") { Some(format!("https://{url}"))
url = format!("https://{url}"); } else {
} None
};
let url = prefixed_url.as_deref().unwrap_or(url);
let mut response = http_client.get(&url, AsyncBody::default(), true).await?; let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
@ -200,7 +202,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
this.delegate this.delegate
.context_store .context_store
.update(cx, |context_store, _cx| { .update(cx, |context_store, _cx| {
if context_store.included_url(&url).is_none() { if context_store.includes_url(&url).is_none() {
context_store.insert_fetched_url(url, text); context_store.insert_fetched_url(url, text);
} }
})?; })?;
@ -234,7 +236,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
cx: &mut ViewContext<Picker<Self>>, cx: &mut ViewContext<Picker<Self>>,
) -> Option<Self::ListItem> { ) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| { let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store.read(cx).included_url(&self.url).is_some() context_store.read(cx).includes_url(&self.url).is_some()
}); });
Some( Some(

View file

@ -11,7 +11,7 @@ use util::ResultExt as _;
use workspace::Workspace; use workspace::Workspace;
use crate::context_picker::{ConfirmBehavior, ContextPicker}; use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::{ContextStore, IncludedFile}; use crate::context_store::{ContextStore, FileInclusion};
pub struct FileContextPicker { pub struct FileContextPicker {
picker: View<Picker<FileContextPickerDelegate>>, picker: View<Picker<FileContextPickerDelegate>>,
@ -275,10 +275,11 @@ impl PickerDelegate for FileContextPickerDelegate {
(file_name, Some(directory)) (file_name, Some(directory))
}; };
let added = self let added = self.context_store.upgrade().and_then(|context_store| {
.context_store context_store
.upgrade() .read(cx)
.and_then(|context_store| context_store.read(cx).included_file(&path_match.path)); .will_include_file_path(&path_match.path, cx)
});
Some( Some(
ListItem::new(ix) ListItem::new(ix)
@ -295,7 +296,7 @@ impl PickerDelegate for FileContextPickerDelegate {
})), })),
) )
.when_some(added, |el, added| match added { .when_some(added, |el, added| match added {
IncludedFile::Direct(_) => el.end_slot( FileInclusion::Direct(_) => el.end_slot(
h_flex() h_flex()
.gap_1() .gap_1()
.child( .child(
@ -305,7 +306,7 @@ impl PickerDelegate for FileContextPickerDelegate {
) )
.child(Label::new("Added").size(LabelSize::Small)), .child(Label::new("Added").size(LabelSize::Small)),
), ),
IncludedFile::InDirectory(dir_name) => { FileInclusion::InDirectory(dir_name) => {
let dir_name = dir_name.to_string_lossy().into_owned(); let dir_name = dir_name.to_string_lossy().into_owned();
el.end_slot( el.end_slot(

View file

@ -194,7 +194,7 @@ impl PickerDelegate for ThreadContextPickerDelegate {
let thread = &self.matches[ix]; let thread = &self.matches[ix];
let added = self.context_store.upgrade().map_or(false, |context_store| { let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store.read(cx).included_thread(&thread.id).is_some() context_store.read(cx).includes_thread(&thread.id).is_some()
}); });
Some( Some(

View file

@ -3,23 +3,25 @@ use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, bail, Result}; use anyhow::{anyhow, bail, Result};
use collections::{HashMap, HashSet}; use collections::{BTreeMap, HashMap};
use gpui::{Model, ModelContext, SharedString, Task, WeakView}; use gpui::{AppContext, Model, ModelContext, SharedString, Task, WeakView};
use language::Buffer; use language::Buffer;
use project::{ProjectPath, Worktree}; use project::{ProjectPath, Worktree};
use text::BufferId;
use workspace::Workspace; use workspace::Workspace;
use crate::thread::Thread; use crate::context::{
use crate::{ Context, ContextId, ContextKind, ContextSnapshot, DirectoryContext, FetchedUrlContext,
context::{Context, ContextId, ContextKind}, FileContext, ThreadContext,
thread::ThreadId,
}; };
use crate::thread::{Thread, ThreadId};
pub struct ContextStore { pub struct ContextStore {
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
context: Vec<Context>, context: Vec<Context>,
// TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId.
next_context_id: ContextId, next_context_id: ContextId,
files: HashMap<PathBuf, ContextId>, files: BTreeMap<BufferId, ContextId>,
directories: HashMap<PathBuf, ContextId>, directories: HashMap<PathBuf, ContextId>,
threads: HashMap<ThreadId, ContextId>, threads: HashMap<ThreadId, ContextId>,
fetched_urls: HashMap<String, ContextId>, fetched_urls: HashMap<String, ContextId>,
@ -31,13 +33,22 @@ impl ContextStore {
workspace, workspace,
context: Vec::new(), context: Vec::new(),
next_context_id: ContextId(0), next_context_id: ContextId(0),
files: HashMap::default(), files: BTreeMap::default(),
directories: HashMap::default(), directories: HashMap::default(),
threads: HashMap::default(), threads: HashMap::default(),
fetched_urls: HashMap::default(), fetched_urls: HashMap::default(),
} }
} }
pub fn snapshot<'a>(
&'a self,
cx: &'a AppContext,
) -> impl Iterator<Item = ContextSnapshot> + 'a {
self.context()
.iter()
.flat_map(|context| context.snapshot(cx))
}
pub fn context(&self) -> &Vec<Context> { pub fn context(&self) -> &Vec<Context> {
&self.context &self.context
} }
@ -63,64 +74,54 @@ impl ContextStore {
return Task::ready(Err(anyhow!("failed to read project"))); return Task::ready(Err(anyhow!("failed to read project")));
}; };
let already_included = match self.included_file(&project_path.path) {
Some(IncludedFile::Direct(context_id)) => {
self.remove_context(&context_id);
true
}
Some(IncludedFile::InDirectory(_)) => true,
None => false,
};
if already_included {
return Task::ready(Ok(()));
}
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let open_buffer_task = let open_buffer_task = project.update(&mut cx, |project, cx| {
project.update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?; project.open_buffer(project_path.clone(), cx)
})?;
let buffer = open_buffer_task.await?; let buffer = open_buffer_task.await?;
let buffer_id = buffer.update(&mut cx, |buffer, _cx| buffer.remote_id())?;
let already_included = this.update(&mut cx, |this, _cx| {
match this.will_include_buffer(buffer_id, &project_path.path) {
Some(FileInclusion::Direct(context_id)) => {
this.remove_context(context_id);
true
}
Some(FileInclusion::InDirectory(_)) => true,
None => false,
}
})?;
if already_included {
return anyhow::Ok(());
}
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.insert_file(buffer.read(cx)); this.insert_file(buffer, cx);
})?; })?;
anyhow::Ok(()) anyhow::Ok(())
}) })
} }
pub fn insert_file(&mut self, buffer: &Buffer) { pub fn insert_file(&mut self, buffer_model: Model<Buffer>, cx: &AppContext) {
let buffer = buffer_model.read(cx);
let Some(file) = buffer.file() else { let Some(file) = buffer.file() else {
return; return;
}; };
let path = file.path(); let mut text = String::new();
push_fenced_codeblock(file.path(), buffer.text(), &mut text);
let id = self.next_context_id.post_inc(); let id = self.next_context_id.post_inc();
self.files.insert(path.to_path_buf(), id); self.files.insert(buffer.remote_id(), id);
self.context.push(Context::File(FileContext {
let full_path: SharedString = path.to_string_lossy().into_owned().into();
let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => full_path.clone(),
};
let parent = path
.parent()
.and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into());
let mut text = String::new();
push_fenced_codeblock(path, buffer.text(), &mut text);
self.context.push(Context {
id, id,
name, buffer: buffer_model,
parent, version: buffer.version.clone(),
tooltip: Some(full_path),
kind: ContextKind::File,
text: text.into(), text: text.into(),
}); }));
} }
pub fn add_directory( pub fn add_directory(
@ -136,9 +137,9 @@ impl ContextStore {
return Task::ready(Err(anyhow!("failed to read project"))); return Task::ready(Err(anyhow!("failed to read project")));
}; };
let already_included = if let Some(context_id) = self.included_directory(&project_path.path) let already_included = if let Some(context_id) = self.includes_directory(&project_path.path)
{ {
self.remove_context(&context_id); self.remove_context(context_id);
true true
} else { } else {
false false
@ -178,23 +179,24 @@ impl ContextStore {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
let mut text = String::new(); let mut text = String::new();
let mut added_files = 0; let mut directory_buffers = BTreeMap::new();
for buffer_model in buffers {
for buffer in buffers.into_iter().flatten() { let buffer_model = buffer_model?;
let buffer = buffer.read(cx); let buffer = buffer_model.read(cx);
let path = buffer.file().map_or(&project_path.path, |file| file.path()); let path = buffer.file().map_or(&project_path.path, |file| file.path());
push_fenced_codeblock(&path, buffer.text(), &mut text); push_fenced_codeblock(&path, buffer.text(), &mut text);
added_files += 1; directory_buffers
.insert(buffer.remote_id(), (buffer_model, buffer.version.clone()));
} }
if added_files == 0 { if directory_buffers.is_empty() {
bail!( bail!(
"could not read any text files from {}", "could not read any text files from {}",
&project_path.path.display() &project_path.path.display()
); );
} }
this.insert_directory(&project_path.path, text); this.insert_directory(&project_path.path, directory_buffers, text);
anyhow::Ok(()) anyhow::Ok(())
})??; })??;
@ -203,7 +205,12 @@ impl ContextStore {
}) })
} }
pub fn insert_directory(&mut self, path: &Path, text: impl Into<SharedString>) { pub fn insert_directory(
&mut self,
path: &Path,
buffers: BTreeMap<BufferId, (Model<Buffer>, clock::Global)>,
text: impl Into<SharedString>,
) {
let id = self.next_context_id.post_inc(); let id = self.next_context_id.post_inc();
self.directories.insert(path.to_path_buf(), id); self.directories.insert(path.to_path_buf(), id);
@ -219,78 +226,104 @@ impl ContextStore {
.and_then(|p| p.file_name()) .and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into()); .map(|p| p.to_string_lossy().into_owned().into());
self.context.push(Context { self.context.push(Context::Directory(DirectoryContext {
id, path: path.into(),
name, buffers,
parent, snapshot: ContextSnapshot {
tooltip: Some(full_path), id,
kind: ContextKind::Directory, name,
text: text.into(), parent,
}); tooltip: Some(full_path),
kind: ContextKind::Directory,
text: text.into(),
},
}));
} }
pub fn add_thread(&mut self, thread: Model<Thread>, cx: &mut ModelContext<Self>) { pub fn add_thread(&mut self, thread: Model<Thread>, cx: &mut ModelContext<Self>) {
if let Some(context_id) = self.included_thread(&thread.read(cx).id()) { if let Some(context_id) = self.includes_thread(&thread.read(cx).id()) {
self.remove_context(&context_id); self.remove_context(context_id);
} else { } else {
self.insert_thread(thread.read(cx)); self.insert_thread(thread, cx);
} }
} }
pub fn insert_thread(&mut self, thread: &Thread) { pub fn insert_thread(&mut self, thread: Model<Thread>, cx: &AppContext) {
let context_id = self.next_context_id.post_inc(); let id = self.next_context_id.post_inc();
self.threads.insert(thread.id().clone(), context_id); let thread_ref = thread.read(cx);
let text = thread_ref.text().into();
self.context.push(Context { self.threads.insert(thread_ref.id().clone(), id);
id: context_id, self.context
name: thread.summary().unwrap_or("New thread".into()), .push(Context::Thread(ThreadContext { id, thread, text }));
parent: None,
tooltip: None,
kind: ContextKind::Thread,
text: thread.text().into(),
});
} }
pub fn insert_fetched_url(&mut self, url: String, text: impl Into<SharedString>) { pub fn insert_fetched_url(&mut self, url: String, text: impl Into<SharedString>) {
let context_id = self.next_context_id.post_inc(); let id = self.next_context_id.post_inc();
self.fetched_urls.insert(url.clone(), context_id);
self.context.push(Context { self.fetched_urls.insert(url.clone(), id);
id: context_id, self.context.push(Context::FetchedUrl(FetchedUrlContext {
name: url.into(), id,
parent: None, url: url.into(),
tooltip: None,
kind: ContextKind::FetchedUrl,
text: text.into(), text: text.into(),
}); }));
} }
pub fn remove_context(&mut self, id: &ContextId) { pub fn remove_context(&mut self, id: ContextId) {
let Some(ix) = self.context.iter().position(|context| context.id == *id) else { let Some(ix) = self.context.iter().position(|context| context.id() == id) else {
return; return;
}; };
match self.context.remove(ix).kind { match self.context.remove(ix) {
ContextKind::File => { Context::File(_) => {
self.files.retain(|_, context_id| context_id != id); self.files.retain(|_, context_id| *context_id != id);
} }
ContextKind::Directory => { Context::Directory(_) => {
self.directories.retain(|_, context_id| context_id != id); self.directories.retain(|_, context_id| *context_id != id);
} }
ContextKind::FetchedUrl => { Context::FetchedUrl(_) => {
self.fetched_urls.retain(|_, context_id| context_id != id); self.fetched_urls.retain(|_, context_id| *context_id != id);
} }
ContextKind::Thread => { Context::Thread(_) => {
self.threads.retain(|_, context_id| context_id != id); self.threads.retain(|_, context_id| *context_id != id);
} }
} }
} }
pub fn included_file(&self, path: &Path) -> Option<IncludedFile> { /// Returns whether the buffer is already included directly in the context, or if it will be
if let Some(id) = self.files.get(path) { /// included in the context via a directory. Directory inclusion is based on paths rather than
return Some(IncludedFile::Direct(*id)); /// buffer IDs as the directory will be re-scanned.
pub fn will_include_buffer(&self, buffer_id: BufferId, path: &Path) -> Option<FileInclusion> {
if let Some(context_id) = self.files.get(&buffer_id) {
return Some(FileInclusion::Direct(*context_id));
} }
self.will_include_file_path_via_directory(path)
}
/// Returns whether this file path is already included directly in the context, or if it will be
/// included in the context via a directory.
pub fn will_include_file_path(&self, path: &Path, cx: &AppContext) -> Option<FileInclusion> {
if !self.files.is_empty() {
let found_file_context = self.context.iter().find(|context| match &context {
Context::File(file_context) => {
if let Some(file_path) = file_context.path(cx) {
*file_path == *path
} else {
false
}
}
_ => false,
});
if let Some(context) = found_file_context {
return Some(FileInclusion::Direct(context.id()));
}
}
self.will_include_file_path_via_directory(path)
}
fn will_include_file_path_via_directory(&self, path: &Path) -> Option<FileInclusion> {
if self.directories.is_empty() { if self.directories.is_empty() {
return None; return None;
} }
@ -299,40 +332,27 @@ impl ContextStore {
while buf.pop() { while buf.pop() {
if let Some(_) = self.directories.get(&buf) { if let Some(_) = self.directories.get(&buf) {
return Some(IncludedFile::InDirectory(buf)); return Some(FileInclusion::InDirectory(buf));
} }
} }
None None
} }
pub fn included_directory(&self, path: &Path) -> Option<ContextId> { pub fn includes_directory(&self, path: &Path) -> Option<ContextId> {
self.directories.get(path).copied() self.directories.get(path).copied()
} }
pub fn included_thread(&self, thread_id: &ThreadId) -> Option<ContextId> { pub fn includes_thread(&self, thread_id: &ThreadId) -> Option<ContextId> {
self.threads.get(thread_id).copied() self.threads.get(thread_id).copied()
} }
pub fn included_url(&self, url: &str) -> Option<ContextId> { pub fn includes_url(&self, url: &str) -> Option<ContextId> {
self.fetched_urls.get(url).copied() self.fetched_urls.get(url).copied()
} }
pub fn duplicated_names(&self) -> HashSet<SharedString> {
let mut seen = HashSet::default();
let mut dupes = HashSet::default();
for context in self.context().iter() {
if !seen.insert(&context.name) {
dupes.insert(context.name.clone());
}
}
dupes
}
} }
pub enum IncludedFile { pub enum FileInclusion {
Direct(ContextId), Direct(ContextId),
InDirectory(PathBuf), InDirectory(PathBuf),
} }

View file

@ -1,10 +1,12 @@
use std::rc::Rc; use std::rc::Rc;
use collections::HashSet;
use editor::Editor; use editor::Editor;
use gpui::{ use gpui::{
AppContext, DismissEvent, EventEmitter, FocusHandle, Model, Subscription, View, WeakModel, AppContext, DismissEvent, EventEmitter, FocusHandle, Model, Subscription, View, WeakModel,
WeakView, WeakView,
}; };
use itertools::Itertools;
use language::Buffer; use language::Buffer;
use ui::{prelude::*, KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip}; use ui::{prelude::*, KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip};
use workspace::Workspace; use workspace::Workspace;
@ -73,11 +75,17 @@ impl ContextStrip {
let active_item = workspace.read(cx).active_item(cx)?; let active_item = workspace.read(cx).active_item(cx)?;
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx); let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
let active_buffer = editor.buffer().read(cx).as_singleton()?; let active_buffer_model = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_model.read(cx);
let path = active_buffer.read(cx).file()?.path(); let path = active_buffer.file()?.path();
if self.context_store.read(cx).included_file(path).is_some() { if self
.context_store
.read(cx)
.will_include_buffer(active_buffer.remote_id(), path)
.is_some()
{
return None; return None;
} }
@ -88,7 +96,7 @@ impl ContextStrip {
Some(SuggestedContext::File { Some(SuggestedContext::File {
name, name,
buffer: active_buffer.downgrade(), buffer: active_buffer_model.downgrade(),
}) })
} }
@ -106,7 +114,7 @@ impl ContextStrip {
if self if self
.context_store .context_store
.read(cx) .read(cx)
.included_thread(active_thread.id()) .includes_thread(active_thread.id())
.is_some() .is_some()
{ {
return None; return None;
@ -131,13 +139,24 @@ impl ContextStrip {
impl Render for ContextStrip { impl Render for ContextStrip {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement { fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let context_store = self.context_store.read(cx); let context_store = self.context_store.read(cx);
let context = context_store.context().clone(); let context = context_store
.context()
.iter()
.flat_map(|context| context.snapshot(cx))
.collect::<Vec<_>>();
let context_picker = self.context_picker.clone(); let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone(); let focus_handle = self.focus_handle.clone();
let suggested_context = self.suggested_context(cx); let suggested_context = self.suggested_context(cx);
let dupe_names = context_store.duplicated_names(); let dupe_names = context
.iter()
.map(|context| context.name.clone())
.sorted()
.tuple_windows()
.filter(|(a, b)| a == b)
.map(|(a, _)| a)
.collect::<HashSet<SharedString>>();
h_flex() h_flex()
.flex_wrap() .flex_wrap()
@ -194,11 +213,11 @@ impl Render for ContextStrip {
context.clone(), context.clone(),
dupe_names.contains(&context.name), dupe_names.contains(&context.name),
Some({ Some({
let context = context.clone(); let id = context.id;
let context_store = self.context_store.clone(); let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, cx| { Rc::new(cx.listener(move |_this, _event, cx| {
context_store.update(cx, |this, _cx| { context_store.update(cx, |this, _cx| {
this.remove_context(&context.id); this.remove_context(id);
}); });
cx.notify(); cx.notify();
})) }))
@ -284,12 +303,12 @@ impl SuggestedContext {
match self { match self {
Self::File { buffer, name: _ } => { Self::File { buffer, name: _ } => {
if let Some(buffer) = buffer.upgrade() { if let Some(buffer) = buffer.upgrade() {
context_store.insert_file(buffer.read(cx)); context_store.insert_file(buffer, cx);
}; };
} }
Self::Thread { thread, name: _ } => { Self::Thread { thread, name: _ } => {
if let Some(thread) = thread.upgrade() { if let Some(thread) = thread.upgrade() {
context_store.insert_thread(thread.read(cx)); context_store.insert_thread(thread, cx);
}; };
} }
} }

View file

@ -147,11 +147,10 @@ impl MessageEditor {
editor.clear(cx); editor.clear(cx);
text text
}); });
let context = self
.context_store
.update(cx, |this, _cx| this.context().clone());
self.thread.update(cx, |thread, cx| { let thread = self.thread.clone();
thread.update(cx, |thread, cx| {
let context = self.context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
thread.insert_user_message(user_message, context, cx); thread.insert_user_message(user_message, context, cx);
let mut request = thread.to_completion_request(request_kind, cx); let mut request = thread.to_completion_request(request_kind, cx);

View file

@ -245,10 +245,10 @@ impl TerminalInlineAssistant {
cache: false, cache: false,
}; };
let context = assist attach_context_to_message(
.context_store &mut request_message,
.update(cx, |this, _cx| this.context().clone()); assist.context_store.read(cx).snapshot(cx),
attach_context_to_message(&mut request_message, context); );
request_message.content.push(prompt.into()); request_message.content.push(prompt.into());

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use futures::future::Shared; use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _}; use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task}; use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _}; use util::{post_inc, TryFutureExt as _};
use uuid::Uuid; use uuid::Uuid;
use crate::context::{attach_context_to_message, Context, ContextId}; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum RequestKind { pub enum RequestKind {
@ -64,7 +64,7 @@ pub struct Thread {
pending_summary: Task<Option<()>>, pending_summary: Task<Option<()>>,
messages: Vec<Message>, messages: Vec<Message>,
next_message_id: MessageId, next_message_id: MessageId,
context: HashMap<ContextId, Context>, context: BTreeMap<ContextId, ContextSnapshot>,
context_by_message: HashMap<MessageId, Vec<ContextId>>, context_by_message: HashMap<MessageId, Vec<ContextId>>,
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
@ -83,7 +83,7 @@ impl Thread {
pending_summary: Task::ready(None), pending_summary: Task::ready(None),
messages: Vec::new(), messages: Vec::new(),
next_message_id: MessageId(0), next_message_id: MessageId(0),
context: HashMap::default(), context: BTreeMap::default(),
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
@ -131,7 +131,7 @@ impl Thread {
&self.tools &self.tools
} }
pub fn context_for_message(&self, id: MessageId) -> Option<Vec<Context>> { pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
let context = self.context_by_message.get(&id)?; let context = self.context_by_message.get(&id)?;
Some( Some(
context context
@ -149,7 +149,7 @@ impl Thread {
pub fn insert_user_message( pub fn insert_user_message(
&mut self, &mut self,
text: impl Into<String>, text: impl Into<String>,
context: Vec<Context>, context: Vec<ContextSnapshot>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) { ) {
let message_id = self.insert_message(Role::User, text, cx); let message_id = self.insert_message(Role::User, text, cx);

View file

@ -3,12 +3,12 @@ use std::rc::Rc;
use gpui::ClickEvent; use gpui::ClickEvent;
use ui::{prelude::*, IconButtonShape, Tooltip}; use ui::{prelude::*, IconButtonShape, Tooltip};
use crate::context::{Context, ContextKind}; use crate::context::{ContextKind, ContextSnapshot};
#[derive(IntoElement)] #[derive(IntoElement)]
pub enum ContextPill { pub enum ContextPill {
Added { Added {
context: Context, context: ContextSnapshot,
dupe_name: bool, dupe_name: bool,
on_remove: Option<Rc<dyn Fn(&ClickEvent, &mut WindowContext)>>, on_remove: Option<Rc<dyn Fn(&ClickEvent, &mut WindowContext)>>,
}, },
@ -21,7 +21,7 @@ pub enum ContextPill {
impl ContextPill { impl ContextPill {
pub fn new_added( pub fn new_added(
context: Context, context: ContextSnapshot,
dupe_name: bool, dupe_name: bool,
on_remove: Option<Rc<dyn Fn(&ClickEvent, &mut WindowContext)>>, on_remove: Option<Rc<dyn Fn(&ClickEvent, &mut WindowContext)>>,
) -> Self { ) -> Self {
@ -49,10 +49,10 @@ impl ContextPill {
} }
} }
pub fn kind(&self) -> &ContextKind { pub fn kind(&self) -> ContextKind {
match self { match self {
Self::Added { context, .. } => &context.kind, Self::Added { context, .. } => context.kind,
Self::Suggested { kind, .. } => kind, Self::Suggested { kind, .. } => *kind,
} }
} }
} }