Actually run the eval and fix a hang when retrieving outline (#28547)
Release Notes: - Fixed a regression that caused the agent to hang sometimes. --------- Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com> Co-authored-by: Nathan Sobo <nathan@zed.dev> Co-authored-by: Michael Sloan <mgsloan@gmail.com>
This commit is contained in:
parent
c0262cf62f
commit
2440faf4b2
28 changed files with 642 additions and 1862 deletions
|
@ -4,7 +4,7 @@ use crate::thread::{
|
|||
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
|
||||
ThreadEvent, ThreadFeedback,
|
||||
};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::thread_store::{RulesLoadingError, ThreadStore};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
|
||||
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
||||
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
|
||||
|
@ -21,7 +21,7 @@ use gpui::{
|
|||
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
|
||||
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
|
||||
use project::ProjectItem as _;
|
||||
|
@ -668,6 +668,7 @@ impl ActiveThread {
|
|||
let subscriptions = vec![
|
||||
cx.observe(&thread, |_, _, cx| cx.notify()),
|
||||
cx.subscribe_in(&thread, window, Self::handle_thread_event),
|
||||
cx.subscribe(&thread_store, Self::handle_rules_loading_error),
|
||||
];
|
||||
|
||||
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
|
||||
|
@ -833,10 +834,9 @@ impl ActiveThread {
|
|||
| ThreadEvent::SummaryChanged => {
|
||||
self.save_thread(cx);
|
||||
}
|
||||
ThreadEvent::DoneStreaming => {
|
||||
let thread = self.thread.read(cx);
|
||||
|
||||
if !thread.is_generating() {
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
|
||||
let thread = self.thread.read(cx);
|
||||
self.show_notification(
|
||||
if thread.used_tools_since_last_user_message() {
|
||||
"Finished running tools"
|
||||
|
@ -848,7 +848,8 @@ impl ActiveThread {
|
|||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
}
|
||||
|
@ -925,6 +926,19 @@ impl ActiveThread {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_rules_loading_error(
|
||||
&mut self,
|
||||
_thread_store: Entity<ThreadStore>,
|
||||
error: &RulesLoadingError,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.last_error = Some(ThreadError::Message {
|
||||
header: "Error loading rules file".into(),
|
||||
message: error.message.clone(),
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn show_notification(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
|
@ -2701,12 +2715,13 @@ impl ActiveThread {
|
|||
}
|
||||
|
||||
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
|
||||
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
|
||||
else {
|
||||
let project_context = self.thread.read(cx).project_context();
|
||||
let project_context = project_context.borrow();
|
||||
let Some(project_context) = project_context.as_ref() else {
|
||||
return div().into_any();
|
||||
};
|
||||
|
||||
let rules_files = system_prompt_context
|
||||
let rules_files = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.filter_map(|worktree| worktree.rules_file.as_ref())
|
||||
|
@ -2796,12 +2811,13 @@ impl ActiveThread {
|
|||
}
|
||||
|
||||
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
|
||||
else {
|
||||
let project_context = self.thread.read(cx).project_context();
|
||||
let project_context = project_context.borrow();
|
||||
let Some(project_context) = project_context.as_ref() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let abs_paths = system_prompt_context
|
||||
let abs_paths = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.flat_map(|worktree| worktree.rules_file.as_ref())
|
||||
|
|
|
@ -921,15 +921,16 @@ mod tests {
|
|||
})
|
||||
.unwrap();
|
||||
|
||||
let thread_store = cx.update(|cx| {
|
||||
ThreadStore::new(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let thread_store = cx
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
|
||||
|
||||
|
|
|
@ -194,10 +194,12 @@ impl AssistantPanel {
|
|||
) -> Task<Result<Entity<Self>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let thread_store = workspace.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
})??;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
|
||||
let context_store = workspace
|
||||
|
|
|
@ -32,8 +32,8 @@ use crate::profile_selector::ProfileSelector;
|
|||
use crate::thread::{RequestKind, Thread, TokenUsageRatio};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::{
|
||||
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent,
|
||||
ToggleContextPicker, ToggleProfileSelector,
|
||||
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ToggleContextPicker,
|
||||
ToggleProfileSelector,
|
||||
};
|
||||
|
||||
pub struct MessageEditor {
|
||||
|
@ -235,8 +235,6 @@ impl MessageEditor {
|
|||
let refresh_task =
|
||||
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
|
||||
|
||||
let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
|
||||
|
||||
let thread = self.thread.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
@ -245,16 +243,6 @@ impl MessageEditor {
|
|||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
refresh_task.await;
|
||||
let (system_prompt_context, load_error) = system_prompt_context_task.await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_system_prompt_context(system_prompt_context);
|
||||
if let Some(load_error) = load_error {
|
||||
cx.emit(ThreadEvent::ShowError(load_error));
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
|
|
|
@ -3,14 +3,12 @@ use std::io::Write;
|
|||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent_rules::load_worktree_rules_file;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt, StreamExt as _};
|
||||
use git::repository::DiffType;
|
||||
|
@ -21,19 +19,20 @@ use language_model::{
|
|||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::Project;
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, format_context_as_string};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse,
|
||||
SerializedToolUse, SharedProjectContext,
|
||||
};
|
||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
|
||||
|
||||
|
@ -247,7 +246,7 @@ pub struct Thread {
|
|||
next_message_id: MessageId,
|
||||
context: BTreeMap<ContextId, AssistantContext>,
|
||||
context_by_message: HashMap<MessageId, Vec<ContextId>>,
|
||||
system_prompt_context: Option<AssistantSystemPromptContext>,
|
||||
project_context: SharedProjectContext,
|
||||
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
|
@ -269,6 +268,7 @@ impl Thread {
|
|||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
system_prompt: SharedProjectContext,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
@ -281,7 +281,7 @@ impl Thread {
|
|||
next_message_id: MessageId(0),
|
||||
context: BTreeMap::default(),
|
||||
context_by_message: HashMap::default(),
|
||||
system_prompt_context: None,
|
||||
project_context: system_prompt,
|
||||
checkpoints_by_message: HashMap::default(),
|
||||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
|
@ -310,6 +310,7 @@ impl Thread {
|
|||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
project_context: SharedProjectContext,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let next_message_id = MessageId(
|
||||
|
@ -350,7 +351,7 @@ impl Thread {
|
|||
next_message_id,
|
||||
context: BTreeMap::default(),
|
||||
context_by_message: HashMap::default(),
|
||||
system_prompt_context: None,
|
||||
project_context,
|
||||
checkpoints_by_message: HashMap::default(),
|
||||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
|
@ -388,6 +389,10 @@ impl Thread {
|
|||
self.summary.clone()
|
||||
}
|
||||
|
||||
pub fn project_context(&self) -> SharedProjectContext {
|
||||
self.project_context.clone()
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
|
@ -812,86 +817,6 @@ impl Thread {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
|
||||
self.system_prompt_context = Some(context);
|
||||
}
|
||||
|
||||
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
|
||||
&self.system_prompt_context
|
||||
}
|
||||
|
||||
pub fn load_system_prompt_context(
|
||||
&self,
|
||||
cx: &App,
|
||||
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
|
||||
let project = self.project.read(cx);
|
||||
let tasks = project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(
|
||||
project.fs().clone(),
|
||||
worktree.read(cx),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async |_cx| {
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
let mut first_err = None;
|
||||
let worktrees = results
|
||||
.into_iter()
|
||||
.map(|(worktree, err)| {
|
||||
if first_err.is_none() && err.is_some() {
|
||||
first_err = err;
|
||||
}
|
||||
worktree
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(AssistantSystemPromptContext::new(worktrees), first_err)
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_info_for_system_prompt(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
|
||||
let root_name = worktree.root_name().into();
|
||||
let abs_path = worktree.abs_path();
|
||||
|
||||
let rules_task = load_worktree_rules_file(fs, worktree, cx);
|
||||
let Some(rules_task) = rules_task else {
|
||||
return Task::ready((
|
||||
WorktreeInfoForSystemPrompt {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file: None,
|
||||
},
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
let (rules_file, rules_file_error) = match rules_task.await {
|
||||
Ok(rules_file) => (Some(rules_file), None),
|
||||
Err(err) => (
|
||||
None,
|
||||
Some(ThreadError::Message {
|
||||
header: "Error loading rules file".into(),
|
||||
message: format!("{err}").into(),
|
||||
}),
|
||||
),
|
||||
};
|
||||
let worktree_info = WorktreeInfoForSystemPrompt {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file,
|
||||
};
|
||||
(worktree_info, rules_file_error)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn send_to_model(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
@ -941,10 +866,10 @@ impl Thread {
|
|||
temperature: None,
|
||||
};
|
||||
|
||||
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
|
||||
if let Some(project_context) = self.project_context.borrow().as_ref() {
|
||||
if let Some(system_prompt) = self
|
||||
.prompt_builder
|
||||
.generate_assistant_system_prompt(system_prompt_context)
|
||||
.generate_assistant_system_prompt(project_context)
|
||||
.context("failed to generate assistant system prompt")
|
||||
.log_err()
|
||||
{
|
||||
|
@ -955,7 +880,7 @@ impl Thread {
|
|||
});
|
||||
}
|
||||
} else {
|
||||
log::error!("system_prompt_context not set.")
|
||||
log::error!("project_context not set.")
|
||||
}
|
||||
|
||||
for message in &self.messages {
|
||||
|
@ -1215,7 +1140,7 @@ impl Thread {
|
|||
thread.cancel_last_completion(cx);
|
||||
}
|
||||
}
|
||||
cx.emit(ThreadEvent::DoneStreaming);
|
||||
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
|
||||
|
||||
thread.auto_capture_telemetry(cx);
|
||||
|
||||
|
@ -1963,10 +1888,13 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Error)]
|
||||
pub enum ThreadError {
|
||||
#[error("Payment required")]
|
||||
PaymentRequired,
|
||||
#[error("Max monthly spend reached")]
|
||||
MaxMonthlySpendReached,
|
||||
#[error("Message {header}: {message}")]
|
||||
Message {
|
||||
header: SharedString,
|
||||
message: SharedString,
|
||||
|
@ -1979,7 +1907,7 @@ pub enum ThreadEvent {
|
|||
StreamedCompletion,
|
||||
StreamedAssistantText(MessageId, String),
|
||||
StreamedAssistantThinking(MessageId, String),
|
||||
DoneStreaming,
|
||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
MessageDeleted(MessageId),
|
||||
|
@ -2085,9 +2013,9 @@ fn main() {{
|
|||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 1);
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
let expected_full_message = format!("{}Please explain this code", expected_context);
|
||||
assert_eq!(request.messages[0].string_contents(), expected_full_message);
|
||||
assert_eq!(request.messages[1].string_contents(), expected_full_message);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -2178,20 +2106,20 @@ fn main() {{
|
|||
});
|
||||
|
||||
// The request should contain all 3 messages
|
||||
assert_eq!(request.messages.len(), 3);
|
||||
assert_eq!(request.messages.len(), 4);
|
||||
|
||||
// Check that the contexts are properly formatted in each message
|
||||
assert!(request.messages[0].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[0].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[0].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[1].string_contents().contains("file1.rs"));
|
||||
assert!(request.messages[1].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[1].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[1].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[1].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[2].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[2].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[2].string_contents().contains("file3.rs"));
|
||||
assert!(request.messages[2].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[2].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[3].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[3].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[3].string_contents().contains("file3.rs"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -2229,9 +2157,9 @@ fn main() {{
|
|||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 1);
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
assert_eq!(
|
||||
request.messages[0].string_contents(),
|
||||
request.messages[1].string_contents(),
|
||||
"What is the best way to learn Rust?"
|
||||
);
|
||||
|
||||
|
@ -2249,13 +2177,13 @@ fn main() {{
|
|||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
assert_eq!(request.messages.len(), 3);
|
||||
assert_eq!(
|
||||
request.messages[0].string_contents(),
|
||||
request.messages[1].string_contents(),
|
||||
"What is the best way to learn Rust?"
|
||||
);
|
||||
assert_eq!(
|
||||
request.messages[1].string_contents(),
|
||||
request.messages[2].string_contents(),
|
||||
"Are there any good books?"
|
||||
);
|
||||
}
|
||||
|
@ -2376,15 +2304,16 @@ fn main() {{
|
|||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let thread_store = cx.update(|_, cx| {
|
||||
ThreadStore::new(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let thread_store = cx
|
||||
.update(|_, cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
|
|
@ -1,37 +1,57 @@
|
|||
use std::borrow::Cow;
|
||||
use std::path::PathBuf;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use fs::Fs;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use gpui::{
|
||||
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
|
||||
prelude::*,
|
||||
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
|
||||
Subscription, Task, prelude::*,
|
||||
};
|
||||
use heed::Database;
|
||||
use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::thread::{
|
||||
DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
|
||||
};
|
||||
use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
|
||||
|
||||
const RULES_FILE_NAMES: [&'static str; 6] = [
|
||||
".rules",
|
||||
".cursorrules",
|
||||
".windsurfrules",
|
||||
".clinerules",
|
||||
".github/copilot-instructions.md",
|
||||
"CLAUDE.md",
|
||||
];
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
ThreadsDatabase::init(cx);
|
||||
}
|
||||
|
||||
/// A system prompt shared by all threads created by this ThreadStore
|
||||
#[derive(Clone, Default)]
|
||||
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
|
||||
|
||||
impl SharedProjectContext {
|
||||
pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
|
||||
self.0.borrow()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ThreadStore {
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
|
@ -39,43 +59,187 @@ pub struct ThreadStore {
|
|||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
threads: Vec<SerializedThreadMetadata>,
|
||||
project_context: SharedProjectContext,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub struct RulesLoadingError {
|
||||
pub message: SharedString,
|
||||
}
|
||||
|
||||
impl EventEmitter<RulesLoadingError> for ThreadStore {}
|
||||
|
||||
impl ThreadStore {
|
||||
pub fn new(
|
||||
pub fn load(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Result<Entity<Self>> {
|
||||
let this = cx.new(|cx| {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let settings_subscription =
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
});
|
||||
) -> Task<Entity<Self>> {
|
||||
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
|
||||
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
|
||||
cx.foreground_executor().spawn(async move {
|
||||
reload.await;
|
||||
thread_store
|
||||
})
|
||||
}
|
||||
|
||||
let this = Self {
|
||||
project,
|
||||
tools,
|
||||
prompt_builder,
|
||||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
_subscriptions: vec![settings_subscription],
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
|
||||
this
|
||||
fn new(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let settings_subscription =
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
});
|
||||
let project_subscription = cx.subscribe(&project, Self::handle_project_event);
|
||||
|
||||
Ok(this)
|
||||
let this = Self {
|
||||
project,
|
||||
tools,
|
||||
prompt_builder,
|
||||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
project_context: SharedProjectContext::default(),
|
||||
_subscriptions: vec![settings_subscription, project_subscription],
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
event: &project::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
|
||||
self.reload_system_prompt(cx).detach();
|
||||
}
|
||||
project::Event::WorktreeUpdatedEntries(_, items) => {
|
||||
if items.iter().any(|(path, _, _)| {
|
||||
RULES_FILE_NAMES
|
||||
.iter()
|
||||
.any(|name| path.as_ref() == Path::new(name))
|
||||
}) {
|
||||
self.reload_system_prompt(cx).detach();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
|
||||
let project = self.project.read(cx);
|
||||
let tasks = project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(
|
||||
project.fs().clone(),
|
||||
worktree.read(cx),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
let worktrees = results
|
||||
.into_iter()
|
||||
.map(|(worktree, rules_error)| {
|
||||
if let Some(rules_error) = rules_error {
|
||||
this.update(cx, |_, cx| cx.emit(rules_error)).ok();
|
||||
}
|
||||
worktree
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
this.update(cx, |this, _cx| {
|
||||
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_info_for_system_prompt(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
|
||||
let root_name = worktree.root_name().into();
|
||||
let abs_path = worktree.abs_path();
|
||||
|
||||
let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
|
||||
let Some(rules_task) = rules_task else {
|
||||
return Task::ready((
|
||||
WorktreeContext {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file: None,
|
||||
},
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
let (rules_file, rules_file_error) = match rules_task.await {
|
||||
Ok(rules_file) => (Some(rules_file), None),
|
||||
Err(err) => (
|
||||
None,
|
||||
Some(RulesLoadingError {
|
||||
message: format!("{err}").into(),
|
||||
}),
|
||||
),
|
||||
};
|
||||
let worktree_info = WorktreeContext {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file,
|
||||
};
|
||||
(worktree_info, rules_file_error)
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_rules_file(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Option<Task<Result<RulesFileContext>>> {
|
||||
let selected_rules_file = RULES_FILE_NAMES
|
||||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
worktree
|
||||
.entry_for_path(name)
|
||||
.filter(|entry| entry.is_file())
|
||||
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
|
||||
})
|
||||
.next();
|
||||
|
||||
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
||||
// supported. This doesn't seem to occur often in GitHub repositories.
|
||||
selected_rules_file.map(|(path_in_worktree, abs_path)| {
|
||||
let fs = fs.clone();
|
||||
cx.background_spawn(async move {
|
||||
let abs_path = abs_path?;
|
||||
let text = fs.load(&abs_path).await.with_context(|| {
|
||||
format!("Failed to load assistant rules file {:?}", abs_path)
|
||||
})?;
|
||||
anyhow::Ok(RulesFileContext {
|
||||
path_in_worktree,
|
||||
abs_path: abs_path.into(),
|
||||
text: text.trim().to_string(),
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
|
||||
|
@ -107,6 +271,7 @@ impl ThreadStore {
|
|||
self.project.clone(),
|
||||
self.tools.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
self.project_context.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
@ -134,21 +299,12 @@ impl ThreadStore {
|
|||
this.project.clone(),
|
||||
this.tools.clone(),
|
||||
this.prompt_builder.clone(),
|
||||
this.project_context.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
let (system_prompt_context, load_error) = thread
|
||||
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
|
||||
.await;
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_system_prompt_context(system_prompt_context);
|
||||
if let Some(load_error) = load_error {
|
||||
cx.emit(ThreadEvent::ShowError(load_error));
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue