From 66e873942dca65f9f44be3c8ffe033eb56f0b7b9 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 15 May 2024 16:16:10 -0400 Subject: [PATCH] assistant: Factor `RecentBuffersContext` logic out of `AssistantPanel` (#11876) This PR factors some more code related to the `RecentBuffersContext` out of the `AssistantPanel` and into the corresponding module. We're trying to strike a balance between keeping this code easy to evolve as we work on the Assistant, while also having some semblance of separation/structure. This also adds the missing functionality of updating the remaining token count when the `CurrentProjectContext` is enabled/disabled. Release Notes: - N/A --------- Co-authored-by: Max --- crates/assistant/src/ambient_context.rs | 6 + .../src/ambient_context/current_project.rs | 11 +- .../src/ambient_context/recent_buffers.rs | 168 ++++++++++++++- crates/assistant/src/assistant_panel.rs | 191 +++--------------- 4 files changed, 205 insertions(+), 171 deletions(-) diff --git a/crates/assistant/src/ambient_context.rs b/crates/assistant/src/ambient_context.rs index de4195a1db..7a693e8c9f 100644 --- a/crates/assistant/src/ambient_context.rs +++ b/crates/assistant/src/ambient_context.rs @@ -9,3 +9,9 @@ pub struct AmbientContext { pub recent_buffers: RecentBuffersContext, pub current_project: CurrentProjectContext, } + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ContextUpdated { + Updating, + Disabled, +} diff --git a/crates/assistant/src/ambient_context/current_project.rs b/crates/assistant/src/ambient_context/current_project.rs index cc77674023..d183c7fe25 100644 --- a/crates/assistant/src/ambient_context/current_project.rs +++ b/crates/assistant/src/ambient_context/current_project.rs @@ -8,6 +8,7 @@ use gpui::{AsyncAppContext, ModelContext, Task, WeakModel}; use project::{Project, ProjectPath}; use util::ResultExt; +use crate::ambient_context::ContextUpdated; use crate::assistant_panel::Conversation; use crate::{LanguageModelRequestMessage, Role}; @@ -44,12 +45,12 @@ impl CurrentProjectContext { fs: Arc, project: WeakModel, cx: &mut ModelContext, - ) { + ) -> ContextUpdated { if !self.enabled { self.message.clear(); self.pending_message = None; cx.notify(); - return; + return ContextUpdated::Disabled; } self.pending_message = Some(cx.spawn(|conversation, mut cx| async move { @@ -74,12 +75,16 @@ impl CurrentProjectContext { if let Some(message) = message_task.await.log_err() { conversation - .update(&mut cx, |conversation, _cx| { + .update(&mut cx, |conversation, cx| { conversation.ambient_context.current_project.message = message; + conversation.count_remaining_tokens(cx); + cx.notify(); }) .log_err(); } })); + + ContextUpdated::Updating } async fn build_message(fs: Arc, path_to_cargo_toml: &Path) -> Result { diff --git a/crates/assistant/src/ambient_context/recent_buffers.rs b/crates/assistant/src/ambient_context/recent_buffers.rs index eceecd6056..a63074906b 100644 --- a/crates/assistant/src/ambient_context/recent_buffers.rs +++ b/crates/assistant/src/ambient_context/recent_buffers.rs @@ -1,6 +1,13 @@ -use gpui::{Subscription, Task, WeakModel}; -use language::Buffer; +use std::fmt::Write; +use std::iter; +use std::path::PathBuf; +use std::time::Duration; +use gpui::{ModelContext, Subscription, Task, WeakModel}; +use language::{Buffer, BufferSnapshot, DiagnosticEntry, Point}; + +use crate::ambient_context::ContextUpdated; +use crate::assistant_panel::Conversation; use crate::{LanguageModelRequestMessage, Role}; pub struct RecentBuffersContext { @@ -34,4 +41,161 @@ impl RecentBuffersContext { content: self.message.clone(), }) } + + pub fn update(&mut self, cx: &mut ModelContext) -> ContextUpdated { + let buffers = self + .buffers + .iter() + .filter_map(|recent| { + recent + .buffer + .read_with(cx, |buffer, cx| { + ( + buffer.file().map(|file| file.full_path(cx)), + buffer.snapshot(), + ) + }) + .ok() + }) + .collect::>(); + + if !self.enabled || buffers.is_empty() { + self.message.clear(); + self.pending_message = None; + cx.notify(); + ContextUpdated::Disabled + } else { + self.pending_message = Some(cx.spawn(|this, mut cx| async move { + const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100); + cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; + + let message = cx + .background_executor() + .spawn(async move { Self::build_message(&buffers) }) + .await; + this.update(&mut cx, |conversation, cx| { + conversation.ambient_context.recent_buffers.message = message; + conversation.count_remaining_tokens(cx); + cx.notify(); + }) + .ok(); + })); + + ContextUpdated::Updating + } + } + + fn build_message(buffers: &[(Option, BufferSnapshot)]) -> String { + let mut message = String::new(); + writeln!( + message, + "The following is a list of recent buffers that the user has opened." + ) + .unwrap(); + writeln!( + message, + "For every line in the buffer, I will include a row number that line corresponds to." + ) + .unwrap(); + writeln!( + message, + "Lines that don't have a number correspond to errors and warnings. For example:" + ) + .unwrap(); + writeln!(message, "path/to/file.md").unwrap(); + writeln!(message, "```markdown").unwrap(); + writeln!(message, "1 The quick brown fox").unwrap(); + writeln!(message, "2 jumps over one active").unwrap(); + writeln!(message, " --- error: should be 'the'").unwrap(); + writeln!(message, " ------ error: should be 'lazy'").unwrap(); + writeln!(message, "3 dog").unwrap(); + writeln!(message, "```").unwrap(); + + message.push('\n'); + writeln!(message, "Here's the actual recent buffer list:").unwrap(); + for (path, buffer) in buffers { + if let Some(path) = path { + writeln!(message, "{}", path.display()).unwrap(); + } else { + writeln!(message, "untitled").unwrap(); + } + + if let Some(language) = buffer.language() { + writeln!(message, "```{}", language.name().to_lowercase()).unwrap(); + } else { + writeln!(message, "```").unwrap(); + } + + let mut diagnostics = buffer + .diagnostics_in_range::<_, Point>( + language::Anchor::MIN..language::Anchor::MAX, + false, + ) + .peekable(); + + let mut active_diagnostics = Vec::>::new(); + const GUTTER_PADDING: usize = 4; + let gutter_width = + ((buffer.max_point().row + 1) as f32).log10() as usize + 1 + GUTTER_PADDING; + for buffer_row in 0..=buffer.max_point().row { + let display_row = buffer_row + 1; + active_diagnostics.retain(|diagnostic| { + (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row) + }); + while diagnostics.peek().map_or(false, |diagnostic| { + (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row) + }) { + active_diagnostics.push(diagnostics.next().unwrap()); + } + + let row_width = (display_row as f32).log10() as usize + 1; + write!(message, "{}", display_row).unwrap(); + if row_width < gutter_width { + message.extend(iter::repeat(' ').take(gutter_width - row_width)); + } + + for chunk in buffer.text_for_range( + Point::new(buffer_row, 0)..Point::new(buffer_row, buffer.line_len(buffer_row)), + ) { + message.push_str(chunk); + } + message.push('\n'); + + for diagnostic in &active_diagnostics { + message.extend(iter::repeat(' ').take(gutter_width)); + + let start_column = if diagnostic.range.start.row == buffer_row { + message + .extend(iter::repeat(' ').take(diagnostic.range.start.column as usize)); + diagnostic.range.start.column + } else { + 0 + }; + let end_column = if diagnostic.range.end.row == buffer_row { + diagnostic.range.end.column + } else { + buffer.line_len(buffer_row) + }; + + message.extend(iter::repeat('-').take((end_column - start_column) as usize)); + writeln!(message, " {}", diagnostic.diagnostic.message).unwrap(); + } + } + + message.push('\n'); + } + + writeln!( + message, + "When quoting the above code, mention which rows the code occurs at." + ) + .unwrap(); + writeln!( + message, + "Never include rows in the quoted code itself and only report lines that didn't start with a row number." + ) + .unwrap(); + + message + } } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9f65558808..215e6d8f23 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,4 +1,4 @@ -use crate::ambient_context::{AmbientContext, RecentBuffer}; +use crate::ambient_context::{AmbientContext, ContextUpdated, RecentBuffer}; use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel}, codegen::{self, Codegen, CodegenKind}, @@ -31,10 +31,7 @@ use gpui::{ Subscription, Task, TextStyle, UniformListScrollHandle, View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext, }; -use language::{ - language_settings::SoftWrap, Buffer, BufferSnapshot, DiagnosticEntry, LanguageRegistry, Point, - ToOffset as _, -}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Point, ToOffset as _}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use project::Project; @@ -1519,7 +1516,12 @@ impl Conversation { fn toggle_recent_buffers(&mut self, cx: &mut ModelContext) { self.ambient_context.recent_buffers.enabled = !self.ambient_context.recent_buffers.enabled; - self.update_recent_buffers_context(cx); + match self.ambient_context.recent_buffers.update(cx) { + ContextUpdated::Updating => {} + ContextUpdated::Disabled => { + self.count_remaining_tokens(cx); + } + } } fn toggle_current_project_context( @@ -1530,7 +1532,12 @@ impl Conversation { ) { self.ambient_context.current_project.enabled = !self.ambient_context.current_project.enabled; - self.ambient_context.current_project.update(fs, project, cx); + match self.ambient_context.current_project.update(fs, project, cx) { + ContextUpdated::Updating => {} + ContextUpdated::Disabled => { + self.count_remaining_tokens(cx); + } + } } fn set_recent_buffers( @@ -1545,170 +1552,22 @@ impl Conversation { .extend(buffers.into_iter().map(|buffer| RecentBuffer { buffer: buffer.downgrade(), _subscription: cx.observe(&buffer, |this, _, cx| { - this.update_recent_buffers_context(cx); + match this.ambient_context.recent_buffers.update(cx) { + ContextUpdated::Updating => {} + ContextUpdated::Disabled => { + this.count_remaining_tokens(cx); + } + } }), })); - self.update_recent_buffers_context(cx); - } - - fn update_recent_buffers_context(&mut self, cx: &mut ModelContext) { - let buffers = self - .ambient_context - .recent_buffers - .buffers - .iter() - .filter_map(|recent| { - recent - .buffer - .read_with(cx, |buffer, cx| { - ( - buffer.file().map(|file| file.full_path(cx)), - buffer.snapshot(), - ) - }) - .ok() - }) - .collect::>(); - - if !self.ambient_context.recent_buffers.enabled || buffers.is_empty() { - self.ambient_context.recent_buffers.message.clear(); - self.ambient_context.recent_buffers.pending_message = None; - self.count_remaining_tokens(cx); - cx.notify(); - } else { - self.ambient_context.recent_buffers.pending_message = - Some(cx.spawn(|this, mut cx| async move { - const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100); - cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; - - let message = cx - .background_executor() - .spawn(async move { Self::message_for_recent_buffers(&buffers) }) - .await; - this.update(&mut cx, |this, cx| { - this.ambient_context.recent_buffers.message = message; - this.count_remaining_tokens(cx); - cx.notify(); - }) - .ok(); - })); + match self.ambient_context.recent_buffers.update(cx) { + ContextUpdated::Updating => {} + ContextUpdated::Disabled => { + self.count_remaining_tokens(cx); + } } } - fn message_for_recent_buffers(buffers: &[(Option, BufferSnapshot)]) -> String { - let mut message = String::new(); - writeln!( - message, - "The following is a list of recent buffers that the user has opened." - ) - .unwrap(); - writeln!( - message, - "For every line in the buffer, I will include a row number that line corresponds to." - ) - .unwrap(); - writeln!( - message, - "Lines that don't have a number correspond to errors and warnings. For example:" - ) - .unwrap(); - writeln!(message, "path/to/file.md").unwrap(); - writeln!(message, "```markdown").unwrap(); - writeln!(message, "1 The quick brown fox").unwrap(); - writeln!(message, "2 jumps over one active").unwrap(); - writeln!(message, " --- error: should be 'the'").unwrap(); - writeln!(message, " ------ error: should be 'lazy'").unwrap(); - writeln!(message, "3 dog").unwrap(); - writeln!(message, "```").unwrap(); - - message.push('\n'); - writeln!(message, "Here's the actual recent buffer list:").unwrap(); - for (path, buffer) in buffers { - if let Some(path) = path { - writeln!(message, "{}", path.display()).unwrap(); - } else { - writeln!(message, "untitled").unwrap(); - } - - if let Some(language) = buffer.language() { - writeln!(message, "```{}", language.name().to_lowercase()).unwrap(); - } else { - writeln!(message, "```").unwrap(); - } - - let mut diagnostics = buffer - .diagnostics_in_range::<_, Point>( - language::Anchor::MIN..language::Anchor::MAX, - false, - ) - .peekable(); - - let mut active_diagnostics = Vec::>::new(); - const GUTTER_PADDING: usize = 4; - let gutter_width = - ((buffer.max_point().row + 1) as f32).log10() as usize + 1 + GUTTER_PADDING; - for buffer_row in 0..=buffer.max_point().row { - let display_row = buffer_row + 1; - active_diagnostics.retain(|diagnostic| { - (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row) - }); - while diagnostics.peek().map_or(false, |diagnostic| { - (diagnostic.range.start.row..=diagnostic.range.end.row).contains(&buffer_row) - }) { - active_diagnostics.push(diagnostics.next().unwrap()); - } - - let row_width = (display_row as f32).log10() as usize + 1; - write!(message, "{}", display_row).unwrap(); - if row_width < gutter_width { - message.extend(iter::repeat(' ').take(gutter_width - row_width)); - } - - for chunk in buffer.text_for_range( - Point::new(buffer_row, 0)..Point::new(buffer_row, buffer.line_len(buffer_row)), - ) { - message.push_str(chunk); - } - message.push('\n'); - - for diagnostic in &active_diagnostics { - message.extend(iter::repeat(' ').take(gutter_width)); - - let start_column = if diagnostic.range.start.row == buffer_row { - message - .extend(iter::repeat(' ').take(diagnostic.range.start.column as usize)); - diagnostic.range.start.column - } else { - 0 - }; - let end_column = if diagnostic.range.end.row == buffer_row { - diagnostic.range.end.column - } else { - buffer.line_len(buffer_row) - }; - - message.extend(iter::repeat('-').take((end_column - start_column) as usize)); - writeln!(message, " {}", diagnostic.diagnostic.message).unwrap(); - } - } - - message.push('\n'); - } - - writeln!( - message, - "When quoting the above code, mention which rows the code occurs at." - ) - .unwrap(); - writeln!( - message, - "Never include rows in the quoted code itself and only report lines that didn't start with a row number." - ) - .unwrap(); - - message - } - fn handle_buffer_event( &mut self, _: Model,