diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index 12453f21b0..ecc0dfd09b 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -14,5 +14,19 @@ Be concise and direct in your responses. The user has opened a project that contains the following root directories/files: {{#each worktrees}} -- {{root_name}} (absolute path: {{abs_path}}) +- `{{root_name}}` (absolute path: `{{abs_path}}`) {{/each}} +{{#if has_rules}} + +There are rules that apply to these root directories: +{{#each worktrees}} +{{#if rules_file}} + +`{{root_name}}/{{rules_file.rel_path}}`: + +`````` +{{{rules_file.text}}} +`````` +{{/if}} +{{/each}} +{{/if}} diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 3d4ae49928..bb98f71949 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -8,7 +8,7 @@ use gpui::{ list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation, - UnderlineStyle, + UnderlineStyle, WeakEntity, }; use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; @@ -18,9 +18,9 @@ use settings::Settings as _; use std::sync::Arc; use std::time::Duration; use theme::ThemeSettings; -use ui::Color; use ui::{prelude::*, Disclosure, KeyBinding}; use util::ResultExt as _; +use workspace::{OpenOptions, Workspace}; use crate::context_store::{refresh_context_store_text, ContextStore}; @@ -29,6 +29,7 @@ pub struct ActiveThread { thread_store: Entity, thread: Entity, context_store: Entity, + workspace: WeakEntity, save_thread_task: Option>, messages: Vec, list_state: ListState, @@ -50,6 +51,7 @@ impl ActiveThread { thread_store: Entity, language_registry: Arc, context_store: Entity, + workspace: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -63,6 +65,7 @@ impl ActiveThread { thread_store, thread: thread.clone(), context_store, + workspace, save_thread_task: None, messages: Vec::new(), rendered_messages_by_id: HashMap::default(), @@ -736,6 +739,7 @@ impl ActiveThread { }; v_flex() + .when(ix == 0, |parent| parent.child(self.render_rules_item(cx))) .when_some(checkpoint, |parent, checkpoint| { parent.child( h_flex().pl_2().child( @@ -1042,6 +1046,86 @@ impl ActiveThread { }), ) } + + fn render_rules_item(&self, cx: &Context) -> AnyElement { + let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref() + else { + return div().into_any(); + }; + + let rules_files = system_prompt_context + .worktrees + .iter() + .filter_map(|worktree| worktree.rules_file.as_ref()) + .collect::>(); + + let label_text = match rules_files.as_slice() { + &[] => return div().into_any(), + &[rules_file] => { + format!("Using {:?} file", rules_file.rel_path) + } + rules_files => { + format!("Using {} rules files", rules_files.len()) + } + }; + + div() + .pt_1() + .px_2p5() + .child( + h_flex() + .group("rules-item") + .w_full() + .gap_2() + .justify_between() + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::File) + .size(IconSize::XSmall) + .color(Color::Disabled), + ) + .child( + Label::new(label_text) + .size(LabelSize::XSmall) + .color(Color::Muted) + .buffer_font(cx), + ), + ) + .child( + div().visible_on_hover("rules-item").child( + Button::new("open-rules", "Open Rules") + .label_size(LabelSize::XSmall) + .on_click(cx.listener(Self::handle_open_rules)), + ), + ), + ) + .into_any() + } + + fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { + let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref() + else { + return; + }; + + let abs_paths = system_prompt_context + .worktrees + .iter() + .flat_map(|worktree| worktree.rules_file.as_ref()) + .map(|rules_file| rules_file.abs_path.to_path_buf()) + .collect::>(); + + if let Ok(task) = self.workspace.update(cx, move |workspace, cx| { + // TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules + // files clear. For example, if rules file 1 is already open but rules file 2 is not, + // this would open and focus rules file 2 in a tab that is not next to rules file 1. + workspace.open_paths(abs_paths, OpenOptions::default(), None, window, cx) + }) { + task.detach(); + } + } } impl Render for ActiveThread { diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index a963a0f538..f91600256f 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -174,6 +174,7 @@ impl AssistantPanel { thread_store.clone(), language_registry.clone(), message_editor_context_store.clone(), + workspace.clone(), window, cx, ) @@ -252,6 +253,7 @@ impl AssistantPanel { self.thread_store.clone(), self.language_registry.clone(), message_editor_context_store.clone(), + self.workspace.clone(), window, cx, ) @@ -389,6 +391,7 @@ impl AssistantPanel { this.thread_store.clone(), this.language_registry.clone(), message_editor_context_store.clone(), + this.workspace.clone(), window, cx, ) @@ -922,8 +925,8 @@ impl AssistantPanel { ThreadError::MaxMonthlySpendReached => { self.render_max_monthly_spend_reached_error(cx) } - ThreadError::Message(error_message) => { - self.render_error_message(&error_message, cx) + ThreadError::Message { header, message } => { + self.render_error_message(header, message, cx) } }) .into_any(), @@ -1026,7 +1029,8 @@ impl AssistantPanel { fn render_error_message( &self, - error_message: &SharedString, + header: SharedString, + message: SharedString, cx: &mut Context, ) -> AnyElement { v_flex() @@ -1036,17 +1040,14 @@ impl AssistantPanel { .gap_1p5() .items_center() .child(Icon::new(IconName::XCircle).color(Color::Error)) - .child( - Label::new("Error interacting with language model") - .weight(FontWeight::MEDIUM), - ), + .child(Label::new(header).weight(FontWeight::MEDIUM)), ) .child( div() .id("error-message") .max_h_32() .overflow_y_scroll() - .child(Label::new(error_message.clone())), + .child(Label::new(message)), ) .child( h_flex() diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index ae986743a3..8048875495 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -33,7 +33,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::thread::{RequestKind, Thread}; use crate::thread_store::ThreadStore; use crate::tool_selector::ToolSelector; -use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker}; +use crate::{Chat, ChatMode, RemoveAllContext, ThreadEvent, ToggleContextPicker}; pub struct MessageEditor { thread: Entity, @@ -206,12 +206,23 @@ impl MessageEditor { let refresh_task = refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); + let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx); + let thread = self.thread.clone(); let context_store = self.context_store.clone(); let git_store = self.project.read(cx).git_store(); let checkpoint = git_store.read(cx).checkpoint(cx); cx.spawn(async move |_, cx| { refresh_task.await; + let (system_prompt_context, load_error) = system_prompt_context_task.await; + thread + .update(cx, |thread, cx| { + thread.set_system_prompt_context(system_prompt_context); + if let Some(load_error) = load_error { + cx.emit(ThreadEvent::ShowError(load_error)); + } + }) + .ok(); let checkpoint = checkpoint.await.log_err(); thread .update(cx, |thread, cx| { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 98daa6fbcb..0257ff40ed 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -6,6 +6,7 @@ use anyhow::{Context as _, Result}; use assistant_tool::{ActionLog, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; +use fs::Fs; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; use git; @@ -17,11 +18,13 @@ use language_model::{ Role, StopReason, TokenUsage, }; use project::git::GitStoreCheckpoint; -use project::Project; -use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder}; +use project::{Project, Worktree}; +use prompt_store::{ + AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt, +}; use scripting_tool::{ScriptingSession, ScriptingTool}; use serde::{Deserialize, Serialize}; -use util::{post_inc, ResultExt, TryFutureExt as _}; +use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _}; use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; @@ -106,6 +109,7 @@ pub struct Thread { next_message_id: MessageId, context: BTreeMap, context_by_message: HashMap>, + system_prompt_context: Option, checkpoints_by_message: HashMap, completion_count: usize, pending_completions: Vec, @@ -136,6 +140,7 @@ impl Thread { next_message_id: MessageId(0), context: BTreeMap::default(), context_by_message: HashMap::default(), + system_prompt_context: None, checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -197,6 +202,7 @@ impl Thread { next_message_id, context: BTreeMap::default(), context_by_message: HashMap::default(), + system_prompt_context: None, checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -478,6 +484,116 @@ impl Thread { }) } + pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) { + self.system_prompt_context = Some(context); + } + + pub fn system_prompt_context(&self) -> &Option { + &self.system_prompt_context + } + + pub fn load_system_prompt_context( + &self, + cx: &App, + ) -> Task<(AssistantSystemPromptContext, Option)> { + let project = self.project.read(cx); + let tasks = project + .visible_worktrees(cx) + .map(|worktree| { + Self::load_worktree_info_for_system_prompt( + project.fs().clone(), + worktree.read(cx), + cx, + ) + }) + .collect::>(); + + cx.spawn(async |_cx| { + let results = futures::future::join_all(tasks).await; + let mut first_err = None; + let worktrees = results + .into_iter() + .map(|(worktree, err)| { + if first_err.is_none() && err.is_some() { + first_err = err; + } + worktree + }) + .collect::>(); + (AssistantSystemPromptContext::new(worktrees), first_err) + }) + } + + fn load_worktree_info_for_system_prompt( + fs: Arc, + worktree: &Worktree, + cx: &App, + ) -> Task<(WorktreeInfoForSystemPrompt, Option)> { + let root_name = worktree.root_name().into(); + let abs_path = worktree.abs_path(); + + // Note that Cline supports `.clinerules` being a directory, but that is not currently + // supported. This doesn't seem to occur often in GitHub repositories. + const RULES_FILE_NAMES: [&'static str; 5] = [ + ".rules", + ".cursorrules", + ".windsurfrules", + ".clinerules", + "CLAUDE.md", + ]; + let selected_rules_file = RULES_FILE_NAMES + .into_iter() + .filter_map(|name| { + worktree + .entry_for_path(name) + .filter(|entry| entry.is_file()) + .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path))) + }) + .next(); + + if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file { + cx.spawn(async move |_| { + let rules_file_result = maybe!(async move { + let abs_rules_path = abs_rules_path?; + let text = fs.load(&abs_rules_path).await.with_context(|| { + format!("Failed to load assistant rules file {:?}", abs_rules_path) + })?; + anyhow::Ok(RulesFile { + rel_path: rel_rules_path, + abs_path: abs_rules_path.into(), + text: text.trim().to_string(), + }) + }) + .await; + let (rules_file, rules_file_error) = match rules_file_result { + Ok(rules_file) => (Some(rules_file), None), + Err(err) => ( + None, + Some(ThreadError::Message { + header: "Error loading rules file".into(), + message: format!("{err}").into(), + }), + ), + }; + let worktree_info = WorktreeInfoForSystemPrompt { + root_name, + abs_path, + rules_file, + }; + (worktree_info, rules_file_error) + }) + } else { + Task::ready(( + WorktreeInfoForSystemPrompt { + root_name, + abs_path, + rules_file: None, + }, + None, + )) + } + } + pub fn send_to_model( &mut self, model: Arc, @@ -515,36 +631,30 @@ impl Thread { request_kind: RequestKind, cx: &App, ) -> LanguageModelRequest { - let worktree_root_names = self - .project - .read(cx) - .visible_worktrees(cx) - .map(|worktree| { - let worktree = worktree.read(cx); - AssistantSystemPromptWorktree { - root_name: worktree.root_name().into(), - abs_path: worktree.abs_path(), - } - }) - .collect::>(); - let system_prompt = self - .prompt_builder - .generate_assistant_system_prompt(worktree_root_names) - .context("failed to generate assistant system prompt") - .log_err() - .unwrap_or_default(); - let mut request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text(system_prompt)], - cache: true, - }], + messages: vec![], tools: Vec::new(), stop: Vec::new(), temperature: None, }; + if let Some(system_prompt_context) = self.system_prompt_context.as_ref() { + if let Some(system_prompt) = self + .prompt_builder + .generate_assistant_system_prompt(system_prompt_context) + .context("failed to generate assistant system prompt") + .log_err() + { + request.messages.push(LanguageModelRequestMessage { + role: Role::System, + content: vec![MessageContent::Text(system_prompt)], + cache: true, + }); + } + } else { + log::error!("system_prompt_context not set.") + } + let mut referenced_context_ids = HashSet::default(); for message in &self.messages { @@ -757,9 +867,10 @@ impl Thread { .map(|err| err.to_string()) .collect::>() .join("\n"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message( - SharedString::from(error_message.clone()), - ))); + cx.emit(ThreadEvent::ShowError(ThreadError::Message { + header: "Error interacting with language model".into(), + message: SharedString::from(error_message.clone()), + })); } thread.cancel_last_completion(cx); @@ -1204,7 +1315,10 @@ impl Thread { pub enum ThreadError { PaymentRequired, MaxMonthlySpendReached, - Message(SharedString), + Message { + header: SharedString, + message: SharedString, + }, } #[derive(Debug, Clone)] diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index 7468cc1f36..cfdeb674d9 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -20,7 +20,7 @@ use prompt_store::PromptBuilder; use serde::{Deserialize, Serialize}; use util::ResultExt as _; -use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId}; +use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId}; pub fn init(cx: &mut App) { ThreadsDatabase::init(cx); @@ -113,7 +113,7 @@ impl ThreadStore { .await? .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?; - this.update(cx, |this, cx| { + let thread = this.update(cx, |this, cx| { cx.new(|cx| { Thread::deserialize( id.clone(), @@ -124,7 +124,19 @@ impl ThreadStore { cx, ) }) - }) + })?; + + let (system_prompt_context, load_error) = thread + .update(cx, |thread, cx| thread.load_system_prompt_context(cx))? + .await; + thread.update(cx, |thread, cx| { + thread.set_system_prompt_context(system_prompt_context); + if let Some(load_error) = load_error { + cx.emit(ThreadEvent::ShowError(load_error)); + } + })?; + + Ok(thread) }) } diff --git a/crates/assistant_eval/src/eval.rs b/crates/assistant_eval/src/eval.rs index 8f5def88e3..2268cf78ab 100644 --- a/crates/assistant_eval/src/eval.rs +++ b/crates/assistant_eval/src/eval.rs @@ -79,10 +79,25 @@ impl Eval { let start_time = std::time::SystemTime::now(); + let (system_prompt_context, load_error) = cx + .update(|cx| { + assistant + .read(cx) + .thread + .read(cx) + .load_system_prompt_context(cx) + })? + .await; + + if let Some(load_error) = load_error { + return Err(anyhow!("{:?}", load_error)); + }; + assistant.update(cx, |assistant, cx| { assistant.thread.update(cx, |thread, cx| { let context = vec![]; thread.insert_user_message(self.user_prompt.clone(), context, None, cx); + thread.set_system_prompt_context(system_prompt_context); thread.send_to_model(model, RequestKind::Chat, cx); }); })?; diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 4c7630b6a1..d7ae16a8ea 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -18,13 +18,34 @@ use util::ResultExt; #[derive(Serialize)] pub struct AssistantSystemPromptContext { - pub worktrees: Vec, + pub worktrees: Vec, + pub has_rules: bool, +} + +impl AssistantSystemPromptContext { + pub fn new(worktrees: Vec) -> Self { + let has_rules = worktrees + .iter() + .any(|worktree| worktree.rules_file.is_some()); + Self { + worktrees, + has_rules, + } + } } #[derive(Serialize)] -pub struct AssistantSystemPromptWorktree { +pub struct WorktreeInfoForSystemPrompt { pub root_name: String, pub abs_path: Arc, + pub rules_file: Option, +} + +#[derive(Serialize)] +pub struct RulesFile { + pub rel_path: Arc, + pub abs_path: Arc, + pub text: String, } #[derive(Serialize)] @@ -234,12 +255,11 @@ impl PromptBuilder { pub fn generate_assistant_system_prompt( &self, - worktrees: Vec, + context: &AssistantSystemPromptContext, ) -> Result { - let prompt = AssistantSystemPromptContext { worktrees }; self.handlebars .lock() - .render("assistant_system_prompt", &prompt) + .render("assistant_system_prompt", context) } pub fn generate_inline_transformation_prompt(