assistant2: Implement refresh of context on message editor send (#22944)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-01-10 01:09:47 -07:00 committed by GitHub
parent 0b105ba8b7
commit 767f44bd27
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 254 additions and 96 deletions

View file

@ -1,6 +1,5 @@
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc;
use file_icons::FileIcons; use file_icons::FileIcons;
use gpui::{AppContext, Model, SharedString}; use gpui::{AppContext, Model, SharedString};
@ -11,7 +10,7 @@ use text::BufferId;
use ui::IconName; use ui::IconName;
use util::post_inc; use util::post_inc;
use crate::thread::Thread; use crate::{context_store::buffer_path_log_err, thread::Thread};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize); pub struct ContextId(pub(crate) usize);
@ -76,7 +75,7 @@ impl Context {
#[derive(Debug)] #[derive(Debug)]
pub struct FileContext { pub struct FileContext {
pub id: ContextId, pub id: ContextId,
pub buffer: ContextBuffer, pub context_buffer: ContextBuffer,
} }
#[derive(Debug)] #[derive(Debug)]
@ -84,7 +83,7 @@ pub struct DirectoryContext {
#[allow(unused)] #[allow(unused)]
pub path: Rc<Path>, pub path: Rc<Path>,
#[allow(unused)] #[allow(unused)]
pub buffers: Vec<ContextBuffer>, pub context_buffers: Vec<ContextBuffer>,
pub snapshot: ContextSnapshot, pub snapshot: ContextSnapshot,
} }
@ -108,7 +107,7 @@ pub struct ThreadContext {
// TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove // TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove
// the context from the message editor in this case. // the context from the message editor in this case.
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct ContextBuffer { pub struct ContextBuffer {
#[allow(unused)] #[allow(unused)]
pub id: BufferId, pub id: BufferId,
@ -130,18 +129,9 @@ impl Context {
} }
impl FileContext { impl FileContext {
pub fn path(&self, cx: &AppContext) -> Option<Arc<Path>> {
let buffer = self.buffer.buffer.read(cx);
if let Some(file) = buffer.file() {
Some(file.path().clone())
} else {
log::error!("Buffer that had a path unexpectedly no longer has a path.");
None
}
}
pub fn snapshot(&self, cx: &AppContext) -> Option<ContextSnapshot> { pub fn snapshot(&self, cx: &AppContext) -> Option<ContextSnapshot> {
let path = self.path(cx)?; let buffer = self.context_buffer.buffer.read(cx);
let path = buffer_path_log_err(buffer)?;
let full_path: SharedString = path.to_string_lossy().into_owned().into(); let full_path: SharedString = path.to_string_lossy().into_owned().into();
let name = match path.file_name() { let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(), Some(name) => name.to_string_lossy().into_owned().into(),
@ -161,12 +151,51 @@ impl FileContext {
tooltip: Some(full_path), tooltip: Some(full_path),
icon_path, icon_path,
kind: ContextKind::File, kind: ContextKind::File,
text: Box::new([self.buffer.text.clone()]), text: Box::new([self.context_buffer.text.clone()]),
}) })
} }
} }
impl DirectoryContext { impl DirectoryContext {
pub fn new(
id: ContextId,
path: &Path,
context_buffers: Vec<ContextBuffer>,
) -> DirectoryContext {
let full_path: SharedString = path.to_string_lossy().into_owned().into();
let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => full_path.clone(),
};
let parent = path
.parent()
.and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into());
// TODO: include directory path in text?
let text = context_buffers
.iter()
.map(|b| b.text.clone())
.collect::<Vec<_>>()
.into();
DirectoryContext {
path: path.into(),
context_buffers,
snapshot: ContextSnapshot {
id,
name,
parent,
tooltip: Some(full_path),
icon_path: None,
kind: ContextKind::Directory,
text,
},
}
}
pub fn snapshot(&self) -> ContextSnapshot { pub fn snapshot(&self) -> ContextSnapshot {
self.snapshot.clone() self.snapshot.clone()
} }

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use anyhow::{anyhow, bail, Result}; use anyhow::{anyhow, bail, Result};
use collections::{BTreeMap, HashMap}; use collections::{BTreeMap, HashMap};
use futures::{self, future, Future, FutureExt};
use gpui::{AppContext, AsyncAppContext, Model, ModelContext, SharedString, Task, WeakView}; use gpui::{AppContext, AsyncAppContext, Model, ModelContext, SharedString, Task, WeakView};
use language::Buffer; use language::Buffer;
use project::{ProjectPath, Worktree}; use project::{ProjectPath, Worktree};
@ -11,8 +12,8 @@ use text::BufferId;
use workspace::Workspace; use workspace::Workspace;
use crate::context::{ use crate::context::{
Context, ContextBuffer, ContextId, ContextKind, ContextSnapshot, DirectoryContext, Context, ContextBuffer, ContextId, ContextSnapshot, DirectoryContext, FetchedUrlContext,
FetchedUrlContext, FileContext, ThreadContext, FileContext, ThreadContext,
}; };
use crate::thread::{Thread, ThreadId}; use crate::thread::{Thread, ThreadId};
@ -104,7 +105,7 @@ impl ContextStore {
project_path.path.clone(), project_path.path.clone(),
buffer_model, buffer_model,
buffer, buffer,
&cx.to_async(), cx.to_async(),
) )
})?; })?;
@ -133,7 +134,7 @@ impl ContextStore {
file.path().clone(), file.path().clone(),
buffer_model, buffer_model,
buffer, buffer,
&cx.to_async(), cx.to_async(),
)) ))
})??; })??;
@ -150,10 +151,8 @@ impl ContextStore {
pub fn insert_file(&mut self, context_buffer: ContextBuffer) { pub fn insert_file(&mut self, context_buffer: ContextBuffer) {
let id = self.next_context_id.post_inc(); let id = self.next_context_id.post_inc();
self.files.insert(context_buffer.id, id); self.files.insert(context_buffer.id, id);
self.context.push(Context::File(FileContext { self.context
id, .push(Context::File(FileContext { id, context_buffer }));
buffer: context_buffer,
}));
} }
pub fn add_directory( pub fn add_directory(
@ -207,7 +206,7 @@ impl ContextStore {
.collect::<Vec<_>>() .collect::<Vec<_>>()
})?; })?;
let buffers = futures::future::join_all(open_buffer_tasks).await; let buffers = future::join_all(open_buffer_tasks).await;
let mut buffer_infos = Vec::new(); let mut buffer_infos = Vec::new();
let mut text_tasks = Vec::new(); let mut text_tasks = Vec::new();
@ -216,68 +215,41 @@ impl ContextStore {
let buffer_model = buffer_model?; let buffer_model = buffer_model?;
let buffer = buffer_model.read(cx); let buffer = buffer_model.read(cx);
let (buffer_info, text_task) = let (buffer_info, text_task) =
collect_buffer_info_and_text(path, buffer_model, buffer, &cx.to_async()); collect_buffer_info_and_text(path, buffer_model, buffer, cx.to_async());
buffer_infos.push(buffer_info); buffer_infos.push(buffer_info);
text_tasks.push(text_task); text_tasks.push(text_task);
} }
anyhow::Ok(()) anyhow::Ok(())
})??; })??;
let buffer_texts = futures::future::join_all(text_tasks).await; let buffer_texts = future::join_all(text_tasks).await;
let directory_buffers = buffer_infos let context_buffers = buffer_infos
.into_iter() .into_iter()
.zip(buffer_texts.iter()) .zip(buffer_texts)
.map(|(info, text)| make_context_buffer(info, text.clone())) .map(|(info, text)| make_context_buffer(info, text))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if directory_buffers.is_empty() { if context_buffers.is_empty() {
bail!("No text files found in {}", &project_path.path.display()); bail!("No text files found in {}", &project_path.path.display());
} }
// TODO: include directory path in text?
this.update(&mut cx, |this, _| { this.update(&mut cx, |this, _| {
this.insert_directory(&project_path.path, directory_buffers, buffer_texts.into()); this.insert_directory(&project_path.path, context_buffers);
})?; })?;
anyhow::Ok(()) anyhow::Ok(())
}) })
} }
pub fn insert_directory( pub fn insert_directory(&mut self, path: &Path, context_buffers: Vec<ContextBuffer>) {
&mut self,
path: &Path,
buffers: Vec<ContextBuffer>,
text: Box<[SharedString]>,
) {
let id = self.next_context_id.post_inc(); let id = self.next_context_id.post_inc();
self.directories.insert(path.to_path_buf(), id); self.directories.insert(path.to_path_buf(), id);
let full_path: SharedString = path.to_string_lossy().into_owned().into(); self.context.push(Context::Directory(DirectoryContext::new(
id,
let name = match path.file_name() { path,
Some(name) => name.to_string_lossy().into_owned().into(), context_buffers,
None => full_path.clone(), )));
};
let parent = path
.parent()
.and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into());
self.context.push(Context::Directory(DirectoryContext {
path: path.into(),
buffers,
snapshot: ContextSnapshot {
id,
name,
parent,
tooltip: Some(full_path),
icon_path: None,
kind: ContextKind::Directory,
text,
},
}));
} }
pub fn add_thread(&mut self, thread: Model<Thread>, cx: &mut ModelContext<Self>) { pub fn add_thread(&mut self, thread: Model<Thread>, cx: &mut ModelContext<Self>) {
@ -347,7 +319,8 @@ impl ContextStore {
if !self.files.is_empty() { if !self.files.is_empty() {
let found_file_context = self.context.iter().find(|context| match &context { let found_file_context = self.context.iter().find(|context| match &context {
Context::File(file_context) => { Context::File(file_context) => {
if let Some(file_path) = file_context.path(cx) { let buffer = file_context.context_buffer.buffer.read(cx);
if let Some(file_path) = buffer_path_log_err(buffer) {
*file_path == *path *file_path == *path
} else { } else {
false false
@ -390,6 +363,17 @@ impl ContextStore {
pub fn includes_url(&self, url: &str) -> Option<ContextId> { pub fn includes_url(&self, url: &str) -> Option<ContextId> {
self.fetched_urls.get(url).copied() self.fetched_urls.get(url).copied()
} }
/// Replaces the context that matches the ID of the new context, if any match.
fn replace_context(&mut self, new_context: Context) {
let id = new_context.id();
for context in self.context.iter_mut() {
if context.id() == id {
*context = new_context;
break;
}
}
}
} }
pub enum FileInclusion { pub enum FileInclusion {
@ -417,7 +401,7 @@ fn collect_buffer_info_and_text(
path: Arc<Path>, path: Arc<Path>,
buffer_model: Model<Buffer>, buffer_model: Model<Buffer>,
buffer: &Buffer, buffer: &Buffer,
cx: &AsyncAppContext, cx: AsyncAppContext,
) -> (BufferInfo, Task<SharedString>) { ) -> (BufferInfo, Task<SharedString>) {
let buffer_info = BufferInfo { let buffer_info = BufferInfo {
id: buffer.remote_id(), id: buffer.remote_id(),
@ -432,6 +416,15 @@ fn collect_buffer_info_and_text(
(buffer_info, text_task) (buffer_info, text_task)
} }
pub fn buffer_path_log_err(buffer: &Buffer) -> Option<Arc<Path>> {
if let Some(file) = buffer.file() {
Some(file.path().clone())
} else {
log::error!("Buffer that had a path unexpectedly no longer has a path.");
None
}
}
fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString { fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString {
let path_extension = path.extension().and_then(|ext| ext.to_str()); let path_extension = path.extension().and_then(|ext| ext.to_str());
let path_string = path.to_string_lossy(); let path_string = path.to_string_lossy();
@ -485,3 +478,133 @@ fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
files files
} }
pub fn refresh_context_store_text(
context_store: Model<ContextStore>,
cx: &AppContext,
) -> impl Future<Output = ()> {
let mut tasks = Vec::new();
let context_store_ref = context_store.read(cx);
for context in &context_store_ref.context {
match context {
Context::File(file_context) => {
let context_store = context_store.clone();
if let Some(task) = refresh_file_text(context_store, file_context, cx) {
tasks.push(task);
}
}
Context::Directory(directory_context) => {
let context_store = context_store.clone();
if let Some(task) = refresh_directory_text(context_store, directory_context, cx) {
tasks.push(task);
}
}
Context::Thread(thread_context) => {
let context_store = context_store.clone();
tasks.push(refresh_thread_text(context_store, thread_context, cx));
}
// Intentionally omit refreshing fetched URLs as it doesn't seem all that useful,
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
Context::FetchedUrl(_) => {}
}
}
future::join_all(tasks).map(|_| ())
}
fn refresh_file_text(
context_store: Model<ContextStore>,
file_context: &FileContext,
cx: &AppContext,
) -> Option<Task<()>> {
let id = file_context.id;
let task = refresh_context_buffer(&file_context.context_buffer, cx);
if let Some(task) = task {
Some(cx.spawn(|mut cx| async move {
let context_buffer = task.await;
context_store
.update(&mut cx, |context_store, _| {
let new_file_context = FileContext { id, context_buffer };
context_store.replace_context(Context::File(new_file_context));
})
.ok();
}))
} else {
None
}
}
fn refresh_directory_text(
context_store: Model<ContextStore>,
directory_context: &DirectoryContext,
cx: &AppContext,
) -> Option<Task<()>> {
let mut stale = false;
let futures = directory_context
.context_buffers
.iter()
.map(|context_buffer| {
if let Some(refresh_task) = refresh_context_buffer(context_buffer, cx) {
stale = true;
future::Either::Left(refresh_task)
} else {
future::Either::Right(future::ready((*context_buffer).clone()))
}
})
.collect::<Vec<_>>();
if !stale {
return None;
}
let context_buffers = future::join_all(futures);
let id = directory_context.snapshot.id;
let path = directory_context.path.clone();
Some(cx.spawn(|mut cx| async move {
let context_buffers = context_buffers.await;
context_store
.update(&mut cx, |context_store, _| {
let new_directory_context = DirectoryContext::new(id, &path, context_buffers);
context_store.replace_context(Context::Directory(new_directory_context));
})
.ok();
}))
}
fn refresh_thread_text(
context_store: Model<ContextStore>,
thread_context: &ThreadContext,
cx: &AppContext,
) -> Task<()> {
let id = thread_context.id;
let thread = thread_context.thread.clone();
cx.spawn(move |mut cx| async move {
context_store
.update(&mut cx, |context_store, cx| {
let text = thread.read(cx).text().into();
context_store.replace_context(Context::Thread(ThreadContext { id, thread, text }));
})
.ok();
})
}
fn refresh_context_buffer(
context_buffer: &ContextBuffer,
cx: &AppContext,
) -> Option<impl Future<Output = ContextBuffer>> {
let buffer = context_buffer.buffer.read(cx);
let path = buffer_path_log_err(buffer)?;
if buffer.version.changed_since(&context_buffer.version) {
let (buffer_info, text_task) = collect_buffer_info_and_text(
path,
context_buffer.buffer.clone(),
buffer,
cx.to_async(),
);
Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
} else {
None
}
}

View file

@ -19,7 +19,7 @@ use workspace::Workspace;
use crate::assistant_model_selector::AssistantModelSelector; use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ConfirmBehavior, ContextPicker}; use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore; use crate::context_store::{refresh_context_store_text, ContextStore};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{RequestKind, Thread}; use crate::thread::{RequestKind, Thread};
use crate::thread_store::ThreadStore; use crate::thread_store::ThreadStore;
@ -125,22 +125,20 @@ impl MessageEditor {
self.send_to_model(RequestKind::Chat, cx); self.send_to_model(RequestKind::Chat, cx);
} }
fn send_to_model( fn send_to_model(&mut self, request_kind: RequestKind, cx: &mut ViewContext<Self>) {
&mut self,
request_kind: RequestKind,
cx: &mut ViewContext<Self>,
) -> Option<()> {
let provider = LanguageModelRegistry::read_global(cx).active_provider(); let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider if provider
.as_ref() .as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx)) .map_or(false, |provider| provider.must_accept_terms(cx))
{ {
cx.notify(); cx.notify();
return None; return;
} }
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry.active_model()?; let Some(model) = model_registry.active_model() else {
return;
};
let user_message = self.editor.update(cx, |editor, cx| { let user_message = self.editor.update(cx, |editor, cx| {
let text = editor.text(cx); let text = editor.text(cx);
@ -148,29 +146,37 @@ impl MessageEditor {
text text
}); });
let refresh_task = refresh_context_store_text(self.context_store.clone(), cx);
let thread = self.thread.clone(); let thread = self.thread.clone();
thread.update(cx, |thread, cx| { let context_store = self.context_store.clone();
let context = self.context_store.read(cx).snapshot(cx).collect::<Vec<_>>(); let use_tools = self.use_tools;
thread.insert_user_message(user_message, context, cx); cx.spawn(move |_, mut cx| async move {
let mut request = thread.to_completion_request(request_kind, cx); refresh_task.await;
thread
.update(&mut cx, |thread, cx| {
let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
thread.insert_user_message(user_message, context, cx);
let mut request = thread.to_completion_request(request_kind, cx);
if self.use_tools { if use_tools {
request.tools = thread request.tools = thread
.tools() .tools()
.tools(cx) .tools(cx)
.into_iter() .into_iter()
.map(|tool| LanguageModelRequestTool { .map(|tool| LanguageModelRequestTool {
name: tool.name(), name: tool.name(),
description: tool.description(), description: tool.description(),
input_schema: tool.input_schema(), input_schema: tool.input_schema(),
}) })
.collect(); .collect();
} }
thread.stream_completion(request, model, cx) thread.stream_completion(request, model, cx)
}); })
.ok();
None })
.detach();
} }
fn handle_editor_event( fn handle_editor_event(