agent: Use default prompts from prompt library in system prompt (#28915)
Related to #28490. - Default prompts from the prompt library are now included as "user rules" in the system prompt. - Presence of these user rules is shown at the beginning of the thread in the UI. _ Now uses an `Entity<PromptStore>` instead of an `Arc<PromptStore>`. Motivation for this is emitting a `PromptsUpdatedEvent`. - Now disallows concurrent reloading of the system prompt. Before this change it was possible for reloads to race. Release Notes: - agent: Added support for including default prompts from the Prompt Library as "user rules" in the system prompt. --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
parent
eea6cfb383
commit
502a0f6535
12 changed files with 433 additions and 165 deletions
|
@ -42,6 +42,7 @@ use ui::{
|
|||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{OpenOptions, Workspace};
|
||||
use zed_actions::assistant::OpenPromptLibrary;
|
||||
|
||||
use crate::context_store::ContextStore;
|
||||
|
||||
|
@ -2948,53 +2949,106 @@ impl ActiveThread {
|
|||
return div().into_any();
|
||||
};
|
||||
|
||||
let default_user_rules_text = if project_context.default_user_rules.is_empty() {
|
||||
None
|
||||
} else if project_context.default_user_rules.len() == 1 {
|
||||
let user_rules = &project_context.default_user_rules[0];
|
||||
|
||||
match user_rules.title.as_ref() {
|
||||
Some(title) => Some(format!("Using \"{title}\" user rule")),
|
||||
None => Some("Using user rule".into()),
|
||||
}
|
||||
} else {
|
||||
Some(format!(
|
||||
"Using {} user rules",
|
||||
project_context.default_user_rules.len()
|
||||
))
|
||||
};
|
||||
|
||||
let rules_files = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.filter_map(|worktree| worktree.rules_file.as_ref())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let label_text = match rules_files.as_slice() {
|
||||
&[] => return div().into_any(),
|
||||
&[rules_file] => {
|
||||
format!("Using {:?} file", rules_file.path_in_worktree)
|
||||
}
|
||||
rules_files => {
|
||||
format!("Using {} rules files", rules_files.len())
|
||||
}
|
||||
let rules_file_text = match rules_files.as_slice() {
|
||||
&[] => None,
|
||||
&[rules_file] => Some(format!(
|
||||
"Using project {:?} file",
|
||||
rules_file.path_in_worktree
|
||||
)),
|
||||
rules_files => Some(format!("Using {} project rules files", rules_files.len())),
|
||||
};
|
||||
|
||||
div()
|
||||
if default_user_rules_text.is_none() && rules_file_text.is_none() {
|
||||
return div().into_any();
|
||||
}
|
||||
|
||||
v_flex()
|
||||
.pt_2()
|
||||
.px_2p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
.gap_1()
|
||||
.when_some(
|
||||
default_user_rules_text,
|
||||
|parent, default_user_rules_text| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.w_full()
|
||||
.child(
|
||||
Icon::new(IconName::File)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Disabled),
|
||||
)
|
||||
.child(
|
||||
Label::new(label_text)
|
||||
Label::new(default_user_rules_text)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx),
|
||||
.truncate()
|
||||
.buffer_font(cx)
|
||||
.ml_1p5()
|
||||
.mr_0p5(),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
|
||||
.tooltip(Tooltip::text("View User Rules"))
|
||||
.on_click(|_event, window, cx| {
|
||||
window.dispatch_action(Box::new(OpenPromptLibrary), cx)
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("open-rule", IconName::ArrowUpRightAlt)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
.on_click(cx.listener(Self::handle_open_rules))
|
||||
.tooltip(Tooltip::text("View Rules")),
|
||||
),
|
||||
},
|
||||
)
|
||||
.when_some(rules_file_text, |parent, rules_file_text| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.child(
|
||||
Icon::new(IconName::File)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Disabled),
|
||||
)
|
||||
.child(
|
||||
Label::new(rules_file_text)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
.buffer_font(cx)
|
||||
.ml_1p5()
|
||||
.mr_0p5(),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("open-rule", IconName::ArrowUpRightAlt)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
.on_click(cx.listener(Self::handle_open_rules))
|
||||
.tooltip(Tooltip::text("View Rules")),
|
||||
),
|
||||
)
|
||||
})
|
||||
.into_any()
|
||||
}
|
||||
|
||||
|
|
|
@ -922,6 +922,7 @@ mod tests {
|
|||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AssistantSettings::register(cx);
|
||||
prompt_store::init(cx);
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
|
@ -951,7 +952,8 @@ mod tests {
|
|||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
.await
|
||||
.unwrap();
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
|
||||
|
||||
|
|
|
@ -213,7 +213,7 @@ impl AssistantPanel {
|
|||
let project = workspace.project().clone();
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
})?
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
|
||||
let context_store = workspace
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::ops::Range;
|
|||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
@ -939,7 +939,7 @@ impl Thread {
|
|||
pub fn to_completion_request(
|
||||
&self,
|
||||
request_kind: RequestKind,
|
||||
cx: &App,
|
||||
cx: &mut Context<Self>,
|
||||
) -> LanguageModelRequest {
|
||||
let mut request = LanguageModelRequest {
|
||||
messages: vec![],
|
||||
|
@ -949,20 +949,33 @@ impl Thread {
|
|||
};
|
||||
|
||||
if let Some(project_context) = self.project_context.borrow().as_ref() {
|
||||
if let Some(system_prompt) = self
|
||||
match self
|
||||
.prompt_builder
|
||||
.generate_assistant_system_prompt(project_context)
|
||||
.context("failed to generate assistant system prompt")
|
||||
.log_err()
|
||||
{
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![MessageContent::Text(system_prompt)],
|
||||
cache: true,
|
||||
});
|
||||
Err(err) => {
|
||||
let message = format!("{err:?}").into();
|
||||
log::error!("{message}");
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
|
||||
header: "Error generating system prompt".into(),
|
||||
message,
|
||||
}));
|
||||
}
|
||||
Ok(system_prompt) => {
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![MessageContent::Text(system_prompt)],
|
||||
cache: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::error!("project_context not set.")
|
||||
let message = "Context for system prompt unexpectedly not ready.".into();
|
||||
log::error!("{message}");
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
|
||||
header: "Error generating system prompt".into(),
|
||||
message,
|
||||
}));
|
||||
}
|
||||
|
||||
for message in &self.messages {
|
||||
|
@ -2163,7 +2176,7 @@ fn main() {{
|
|||
assert_eq!(message.context, expected_context);
|
||||
|
||||
// Check message in request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
|
@ -2255,7 +2268,7 @@ fn main() {{
|
|||
assert!(message3.context.contains("file3.rs"));
|
||||
|
||||
// Check entire request to make sure all contexts are properly included
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
|
@ -2307,7 +2320,7 @@ fn main() {{
|
|||
assert_eq!(message.context, "");
|
||||
|
||||
// Check message in request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
|
@ -2327,7 +2340,7 @@ fn main() {{
|
|||
assert_eq!(message2.context, "");
|
||||
|
||||
// Check that both messages appear in the request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
|
@ -2369,7 +2382,7 @@ fn main() {{
|
|||
});
|
||||
|
||||
// Create a request and check that it doesn't have a stale buffer warning yet
|
||||
let initial_request = thread.read_with(cx, |thread, cx| {
|
||||
let initial_request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
|
@ -2399,7 +2412,7 @@ fn main() {{
|
|||
});
|
||||
|
||||
// Create a new request and check for the stale buffer warning
|
||||
let new_request = thread.read_with(cx, |thread, cx| {
|
||||
let new_request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
|
@ -2428,6 +2441,7 @@ fn main() {{
|
|||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AssistantSettings::register(cx);
|
||||
prompt_store::init(cx);
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
|
@ -2467,7 +2481,8 @@ fn main() {{
|
|||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
|
|
@ -12,8 +12,9 @@ use collections::HashMap;
|
|||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use fs::Fs;
|
||||
use futures::FutureExt as _;
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use gpui::{
|
||||
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
|
||||
Subscription, Task, prelude::*,
|
||||
|
@ -22,7 +23,10 @@ use heed::Database;
|
|||
use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
|
||||
use prompt_store::{
|
||||
DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent,
|
||||
RulesFileContext, WorktreeContext,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
@ -62,6 +66,8 @@ pub struct ThreadStore {
|
|||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
threads: Vec<SerializedThreadMetadata>,
|
||||
project_context: SharedProjectContext,
|
||||
reload_system_prompt_tx: mpsc::Sender<()>,
|
||||
_reload_system_prompt_task: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
|
@ -77,12 +83,22 @@ impl ThreadStore {
|
|||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Task<Entity<Self>> {
|
||||
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
|
||||
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
|
||||
cx.foreground_executor().spawn(async move {
|
||||
reload.await;
|
||||
thread_store
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
let (thread_store, ready_rx) = cx.update(|cx| {
|
||||
let mut option_ready_rx = None;
|
||||
let thread_store = cx.new(|cx| {
|
||||
let (thread_store, ready_rx) =
|
||||
Self::new(project, tools, prompt_builder, prompt_store, cx);
|
||||
option_ready_rx = Some(ready_rx);
|
||||
thread_store
|
||||
});
|
||||
(thread_store, option_ready_rx.take().unwrap())
|
||||
})?;
|
||||
ready_rx.await?;
|
||||
Ok(thread_store)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -90,17 +106,53 @@ impl ThreadStore {
|
|||
project: Entity<Project>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
) -> (Self, oneshot::Receiver<()>) {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let settings_subscription =
|
||||
|
||||
let mut subscriptions = vec![
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
});
|
||||
let project_subscription = cx.subscribe(&project, Self::handle_project_event);
|
||||
}),
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
];
|
||||
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
subscriptions.push(cx.subscribe(
|
||||
prompt_store,
|
||||
|this, _prompt_store, PromptsUpdatedEvent, _cx| {
|
||||
this.enqueue_system_prompt_reload();
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
// This channel and task prevent concurrent and redundant loading of the system prompt.
|
||||
let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
|
||||
let (ready_tx, ready_rx) = oneshot::channel();
|
||||
let mut ready_tx = Some(ready_tx);
|
||||
let reload_system_prompt_task = cx.spawn({
|
||||
async move |thread_store, cx| {
|
||||
loop {
|
||||
let Some(reload_task) = thread_store
|
||||
.update(cx, |thread_store, cx| {
|
||||
thread_store.reload_system_prompt(prompt_store.clone(), cx)
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
reload_task.await;
|
||||
if let Some(ready_tx) = ready_tx.take() {
|
||||
ready_tx.send(()).ok();
|
||||
}
|
||||
reload_system_prompt_rx.next().await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let this = Self {
|
||||
project,
|
||||
|
@ -110,23 +162,25 @@ impl ThreadStore {
|
|||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
project_context: SharedProjectContext::default(),
|
||||
_subscriptions: vec![settings_subscription, project_subscription],
|
||||
reload_system_prompt_tx,
|
||||
_reload_system_prompt_task: reload_system_prompt_task,
|
||||
_subscriptions: subscriptions,
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
this
|
||||
(this, ready_rx)
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
event: &project::Event,
|
||||
cx: &mut Context<Self>,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
|
||||
self.reload_system_prompt(cx).detach();
|
||||
self.enqueue_system_prompt_reload();
|
||||
}
|
||||
project::Event::WorktreeUpdatedEntries(_, items) => {
|
||||
if items.iter().any(|(path, _, _)| {
|
||||
|
@ -134,16 +188,25 @@ impl ThreadStore {
|
|||
.iter()
|
||||
.any(|name| path.as_ref() == Path::new(name))
|
||||
}) {
|
||||
self.reload_system_prompt(cx).detach();
|
||||
self.enqueue_system_prompt_reload();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
|
||||
fn enqueue_system_prompt_reload(&mut self) {
|
||||
self.reload_system_prompt_tx.try_send(()).ok();
|
||||
}
|
||||
|
||||
// Note that this should only be called from `reload_system_prompt_task`.
|
||||
fn reload_system_prompt(
|
||||
&self,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<()> {
|
||||
let project = self.project.read(cx);
|
||||
let tasks = project
|
||||
let worktree_tasks = project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(
|
||||
|
@ -153,10 +216,23 @@ impl ThreadStore {
|
|||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let default_user_rules_task = match prompt_store {
|
||||
None => Task::ready(vec![]),
|
||||
Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
|
||||
let prompts = prompt_store.default_prompt_metadata();
|
||||
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
|
||||
let contents = prompt_store.load(prompt_metadata.id, cx);
|
||||
async move { (contents.await, prompt_metadata) }
|
||||
});
|
||||
cx.background_spawn(future::join_all(load_tasks))
|
||||
}),
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
let worktrees = results
|
||||
let (worktrees, default_user_rules) =
|
||||
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
|
||||
|
||||
let worktrees = worktrees
|
||||
.into_iter()
|
||||
.map(|(worktree, rules_error)| {
|
||||
if let Some(rules_error) = rules_error {
|
||||
|
@ -165,8 +241,29 @@ impl ThreadStore {
|
|||
worktree
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let default_user_rules = default_user_rules
|
||||
.into_iter()
|
||||
.flat_map(|(contents, prompt_metadata)| match contents {
|
||||
Ok(contents) => Some(DefaultUserRulesContext {
|
||||
title: prompt_metadata.title.map(|title| title.to_string()),
|
||||
contents,
|
||||
}),
|
||||
Err(err) => {
|
||||
this.update(cx, |_, cx| {
|
||||
cx.emit(RulesLoadingError {
|
||||
message: format!("{err:?}").into(),
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
|
||||
*this.project_context.0.borrow_mut() =
|
||||
Some(ProjectContext::new(worktrees, default_user_rules));
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue