Initial support for AI assistant rules files (#27168)

Release Notes:

- N/A

---------

Co-authored-by: Danilo <danilo@zed.dev>
Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Thomas <thomas@zed.dev>
This commit is contained in:
Michael Sloan 2025-03-20 02:30:04 -06:00 committed by GitHub
parent 14920ab910
commit 1180b6fbc7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 322 additions and 51 deletions

View file

@ -14,5 +14,19 @@ Be concise and direct in your responses.
The user has opened a project that contains the following root directories/files: The user has opened a project that contains the following root directories/files:
{{#each worktrees}} {{#each worktrees}}
- {{root_name}} (absolute path: {{abs_path}}) - `{{root_name}}` (absolute path: `{{abs_path}}`)
{{/each}} {{/each}}
{{#if has_rules}}
There are rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
`{{root_name}}/{{rules_file.rel_path}}`:
``````
{{{rules_file.text}}}
``````
{{/if}}
{{/each}}
{{/if}}

View file

@ -8,7 +8,7 @@ use gpui::{
list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset,
ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation,
UnderlineStyle, UnderlineStyle, WeakEntity,
}; };
use language::{Buffer, LanguageRegistry}; use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@ -18,9 +18,9 @@ use settings::Settings as _;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::Color;
use ui::{prelude::*, Disclosure, KeyBinding}; use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _; use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
use crate::context_store::{refresh_context_store_text, ContextStore}; use crate::context_store::{refresh_context_store_text, ContextStore};
@ -29,6 +29,7 @@ pub struct ActiveThread {
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
thread: Entity<Thread>, thread: Entity<Thread>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
save_thread_task: Option<Task<()>>, save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>, messages: Vec<MessageId>,
list_state: ListState, list_state: ListState,
@ -50,6 +51,7 @@ impl ActiveThread {
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -63,6 +65,7 @@ impl ActiveThread {
thread_store, thread_store,
thread: thread.clone(), thread: thread.clone(),
context_store, context_store,
workspace,
save_thread_task: None, save_thread_task: None,
messages: Vec::new(), messages: Vec::new(),
rendered_messages_by_id: HashMap::default(), rendered_messages_by_id: HashMap::default(),
@ -736,6 +739,7 @@ impl ActiveThread {
}; };
v_flex() v_flex()
.when(ix == 0, |parent| parent.child(self.render_rules_item(cx)))
.when_some(checkpoint, |parent, checkpoint| { .when_some(checkpoint, |parent, checkpoint| {
parent.child( parent.child(
h_flex().pl_2().child( h_flex().pl_2().child(
@ -1042,6 +1046,86 @@ impl ActiveThread {
}), }),
) )
} }
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
return div().into_any();
};
let rules_files = system_prompt_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
.collect::<Vec<_>>();
let label_text = match rules_files.as_slice() {
&[] => return div().into_any(),
&[rules_file] => {
format!("Using {:?} file", rules_file.rel_path)
}
rules_files => {
format!("Using {} rules files", rules_files.len())
}
};
div()
.pt_1()
.px_2p5()
.child(
h_flex()
.group("rules-item")
.w_full()
.gap_2()
.justify_between()
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::File)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(label_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
)
.child(
div().visible_on_hover("rules-item").child(
Button::new("open-rules", "Open Rules")
.label_size(LabelSize::XSmall)
.on_click(cx.listener(Self::handle_open_rules)),
),
),
)
.into_any()
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
return;
};
let abs_paths = system_prompt_context
.worktrees
.iter()
.flat_map(|worktree| worktree.rules_file.as_ref())
.map(|rules_file| rules_file.abs_path.to_path_buf())
.collect::<Vec<_>>();
if let Ok(task) = self.workspace.update(cx, move |workspace, cx| {
// TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules
// files clear. For example, if rules file 1 is already open but rules file 2 is not,
// this would open and focus rules file 2 in a tab that is not next to rules file 1.
workspace.open_paths(abs_paths, OpenOptions::default(), None, window, cx)
}) {
task.detach();
}
}
} }
impl Render for ActiveThread { impl Render for ActiveThread {

View file

@ -174,6 +174,7 @@ impl AssistantPanel {
thread_store.clone(), thread_store.clone(),
language_registry.clone(), language_registry.clone(),
message_editor_context_store.clone(), message_editor_context_store.clone(),
workspace.clone(),
window, window,
cx, cx,
) )
@ -252,6 +253,7 @@ impl AssistantPanel {
self.thread_store.clone(), self.thread_store.clone(),
self.language_registry.clone(), self.language_registry.clone(),
message_editor_context_store.clone(), message_editor_context_store.clone(),
self.workspace.clone(),
window, window,
cx, cx,
) )
@ -389,6 +391,7 @@ impl AssistantPanel {
this.thread_store.clone(), this.thread_store.clone(),
this.language_registry.clone(), this.language_registry.clone(),
message_editor_context_store.clone(), message_editor_context_store.clone(),
this.workspace.clone(),
window, window,
cx, cx,
) )
@ -922,8 +925,8 @@ impl AssistantPanel {
ThreadError::MaxMonthlySpendReached => { ThreadError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx) self.render_max_monthly_spend_reached_error(cx)
} }
ThreadError::Message(error_message) => { ThreadError::Message { header, message } => {
self.render_error_message(&error_message, cx) self.render_error_message(header, message, cx)
} }
}) })
.into_any(), .into_any(),
@ -1026,7 +1029,8 @@ impl AssistantPanel {
fn render_error_message( fn render_error_message(
&self, &self,
error_message: &SharedString, header: SharedString,
message: SharedString,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> AnyElement { ) -> AnyElement {
v_flex() v_flex()
@ -1036,17 +1040,14 @@ impl AssistantPanel {
.gap_1p5() .gap_1p5()
.items_center() .items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error)) .child(Icon::new(IconName::XCircle).color(Color::Error))
.child( .child(Label::new(header).weight(FontWeight::MEDIUM)),
Label::new("Error interacting with language model")
.weight(FontWeight::MEDIUM),
),
) )
.child( .child(
div() div()
.id("error-message") .id("error-message")
.max_h_32() .max_h_32()
.overflow_y_scroll() .overflow_y_scroll()
.child(Label::new(error_message.clone())), .child(Label::new(message)),
) )
.child( .child(
h_flex() h_flex()

View file

@ -33,7 +33,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{RequestKind, Thread}; use crate::thread::{RequestKind, Thread};
use crate::thread_store::ThreadStore; use crate::thread_store::ThreadStore;
use crate::tool_selector::ToolSelector; use crate::tool_selector::ToolSelector;
use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker}; use crate::{Chat, ChatMode, RemoveAllContext, ThreadEvent, ToggleContextPicker};
pub struct MessageEditor { pub struct MessageEditor {
thread: Entity<Thread>, thread: Entity<Thread>,
@ -206,12 +206,23 @@ impl MessageEditor {
let refresh_task = let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
let thread = self.thread.clone(); let thread = self.thread.clone();
let context_store = self.context_store.clone(); let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store(); let git_store = self.project.read(cx).git_store();
let checkpoint = git_store.read(cx).checkpoint(cx); let checkpoint = git_store.read(cx).checkpoint(cx);
cx.spawn(async move |_, cx| { cx.spawn(async move |_, cx| {
refresh_task.await; refresh_task.await;
let (system_prompt_context, load_error) = system_prompt_context_task.await;
thread
.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})
.ok();
let checkpoint = checkpoint.await.log_err(); let checkpoint = checkpoint.await.log_err();
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {

View file

@ -6,6 +6,7 @@ use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, ToolWorkingSet}; use assistant_tool::{ActionLog, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::future::Shared; use futures::future::Shared;
use futures::{FutureExt, StreamExt as _}; use futures::{FutureExt, StreamExt as _};
use git; use git;
@ -17,11 +18,13 @@ use language_model::{
Role, StopReason, TokenUsage, Role, StopReason, TokenUsage,
}; };
use project::git::GitStoreCheckpoint; use project::git::GitStoreCheckpoint;
use project::Project; use project::{Project, Worktree};
use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder}; use prompt_store::{
AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
};
use scripting_tool::{ScriptingSession, ScriptingTool}; use scripting_tool::{ScriptingSession, ScriptingTool};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::{post_inc, ResultExt, TryFutureExt as _}; use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
use uuid::Uuid; use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
@ -106,6 +109,7 @@ pub struct Thread {
next_message_id: MessageId, next_message_id: MessageId,
context: BTreeMap<ContextId, ContextSnapshot>, context: BTreeMap<ContextId, ContextSnapshot>,
context_by_message: HashMap<MessageId, Vec<ContextId>>, context_by_message: HashMap<MessageId, Vec<ContextId>>,
system_prompt_context: Option<AssistantSystemPromptContext>,
checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>, checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
@ -136,6 +140,7 @@ impl Thread {
next_message_id: MessageId(0), next_message_id: MessageId(0),
context: BTreeMap::default(), context: BTreeMap::default(),
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
system_prompt_context: None,
checkpoints_by_message: HashMap::default(), checkpoints_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
@ -197,6 +202,7 @@ impl Thread {
next_message_id, next_message_id,
context: BTreeMap::default(), context: BTreeMap::default(),
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
system_prompt_context: None,
checkpoints_by_message: HashMap::default(), checkpoints_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
@ -478,6 +484,116 @@ impl Thread {
}) })
} }
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
self.system_prompt_context = Some(context);
}
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
&self.system_prompt_context
}
pub fn load_system_prompt_context(
&self,
cx: &App,
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async |_cx| {
let results = futures::future::join_all(tasks).await;
let mut first_err = None;
let worktrees = results
.into_iter()
.map(|(worktree, err)| {
if first_err.is_none() && err.is_some() {
first_err = err;
}
worktree
})
.collect::<Vec<_>>();
(AssistantSystemPromptContext::new(worktrees), first_err)
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
const RULES_FILE_NAMES: [&'static str; 5] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
"CLAUDE.md",
];
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
cx.spawn(async move |_| {
let rules_file_result = maybe!(async move {
let abs_rules_path = abs_rules_path?;
let text = fs.load(&abs_rules_path).await.with_context(|| {
format!("Failed to load assistant rules file {:?}", abs_rules_path)
})?;
anyhow::Ok(RulesFile {
rel_path: rel_rules_path,
abs_path: abs_rules_path.into(),
text: text.trim().to_string(),
})
})
.await;
let (rules_file, rules_file_error) = match rules_file_result {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
} else {
Task::ready((
WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file: None,
},
None,
))
}
}
pub fn send_to_model( pub fn send_to_model(
&mut self, &mut self,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
@ -515,36 +631,30 @@ impl Thread {
request_kind: RequestKind, request_kind: RequestKind,
cx: &App, cx: &App,
) -> LanguageModelRequest { ) -> LanguageModelRequest {
let worktree_root_names = self
.project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| {
let worktree = worktree.read(cx);
AssistantSystemPromptWorktree {
root_name: worktree.root_name().into(),
abs_path: worktree.abs_path(),
}
})
.collect::<Vec<_>>();
let system_prompt = self
.prompt_builder
.generate_assistant_system_prompt(worktree_root_names)
.context("failed to generate assistant system prompt")
.log_err()
.unwrap_or_default();
let mut request = LanguageModelRequest { let mut request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage { messages: vec![],
role: Role::System,
content: vec![MessageContent::Text(system_prompt)],
cache: true,
}],
tools: Vec::new(), tools: Vec::new(),
stop: Vec::new(), stop: Vec::new(),
temperature: None, temperature: None,
}; };
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
if let Some(system_prompt) = self
.prompt_builder
.generate_assistant_system_prompt(system_prompt_context)
.context("failed to generate assistant system prompt")
.log_err()
{
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![MessageContent::Text(system_prompt)],
cache: true,
});
}
} else {
log::error!("system_prompt_context not set.")
}
let mut referenced_context_ids = HashSet::default(); let mut referenced_context_ids = HashSet::default();
for message in &self.messages { for message in &self.messages {
@ -757,9 +867,10 @@ impl Thread {
.map(|err| err.to_string()) .map(|err| err.to_string())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
cx.emit(ThreadEvent::ShowError(ThreadError::Message( cx.emit(ThreadEvent::ShowError(ThreadError::Message {
SharedString::from(error_message.clone()), header: "Error interacting with language model".into(),
))); message: SharedString::from(error_message.clone()),
}));
} }
thread.cancel_last_completion(cx); thread.cancel_last_completion(cx);
@ -1204,7 +1315,10 @@ impl Thread {
pub enum ThreadError { pub enum ThreadError {
PaymentRequired, PaymentRequired,
MaxMonthlySpendReached, MaxMonthlySpendReached,
Message(SharedString), Message {
header: SharedString,
message: SharedString,
},
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View file

@ -20,7 +20,7 @@ use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::ResultExt as _; use util::ResultExt as _;
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId}; use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
pub fn init(cx: &mut App) { pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx); ThreadsDatabase::init(cx);
@ -113,7 +113,7 @@ impl ThreadStore {
.await? .await?
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?; .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
this.update(cx, |this, cx| { let thread = this.update(cx, |this, cx| {
cx.new(|cx| { cx.new(|cx| {
Thread::deserialize( Thread::deserialize(
id.clone(), id.clone(),
@ -124,7 +124,19 @@ impl ThreadStore {
cx, cx,
) )
}) })
}) })?;
let (system_prompt_context, load_error) = thread
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
.await;
thread.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})?;
Ok(thread)
}) })
} }

View file

@ -79,10 +79,25 @@ impl Eval {
let start_time = std::time::SystemTime::now(); let start_time = std::time::SystemTime::now();
let (system_prompt_context, load_error) = cx
.update(|cx| {
assistant
.read(cx)
.thread
.read(cx)
.load_system_prompt_context(cx)
})?
.await;
if let Some(load_error) = load_error {
return Err(anyhow!("{:?}", load_error));
};
assistant.update(cx, |assistant, cx| { assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| { assistant.thread.update(cx, |thread, cx| {
let context = vec![]; let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, None, cx); thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.set_system_prompt_context(system_prompt_context);
thread.send_to_model(model, RequestKind::Chat, cx); thread.send_to_model(model, RequestKind::Chat, cx);
}); });
})?; })?;

View file

@ -18,13 +18,34 @@ use util::ResultExt;
#[derive(Serialize)] #[derive(Serialize)]
pub struct AssistantSystemPromptContext { pub struct AssistantSystemPromptContext {
pub worktrees: Vec<AssistantSystemPromptWorktree>, pub worktrees: Vec<WorktreeInfoForSystemPrompt>,
pub has_rules: bool,
}
impl AssistantSystemPromptContext {
pub fn new(worktrees: Vec<WorktreeInfoForSystemPrompt>) -> Self {
let has_rules = worktrees
.iter()
.any(|worktree| worktree.rules_file.is_some());
Self {
worktrees,
has_rules,
}
}
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct AssistantSystemPromptWorktree { pub struct WorktreeInfoForSystemPrompt {
pub root_name: String, pub root_name: String,
pub abs_path: Arc<Path>, pub abs_path: Arc<Path>,
pub rules_file: Option<RulesFile>,
}
#[derive(Serialize)]
pub struct RulesFile {
pub rel_path: Arc<Path>,
pub abs_path: Arc<Path>,
pub text: String,
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -234,12 +255,11 @@ impl PromptBuilder {
pub fn generate_assistant_system_prompt( pub fn generate_assistant_system_prompt(
&self, &self,
worktrees: Vec<AssistantSystemPromptWorktree>, context: &AssistantSystemPromptContext,
) -> Result<String, RenderError> { ) -> Result<String, RenderError> {
let prompt = AssistantSystemPromptContext { worktrees };
self.handlebars self.handlebars
.lock() .lock()
.render("assistant_system_prompt", &prompt) .render("assistant_system_prompt", context)
} }
pub fn generate_inline_transformation_prompt( pub fn generate_inline_transformation_prompt(