Actually run the eval and fix a hang when retrieving outline (#28547)
Release Notes: - Fixed a regression that caused the agent to hang sometimes. --------- Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com> Co-authored-by: Nathan Sobo <nathan@zed.dev> Co-authored-by: Michael Sloan <mgsloan@gmail.com>
This commit is contained in:
parent
c0262cf62f
commit
2440faf4b2
28 changed files with 642 additions and 1862 deletions
57
Cargo.lock
generated
57
Cargo.lock
generated
|
@ -52,7 +52,6 @@ dependencies = [
|
|||
name = "agent"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent_rules",
|
||||
"anyhow",
|
||||
"assistant_context_editor",
|
||||
"assistant_settings",
|
||||
|
@ -116,6 +115,7 @@ dependencies = [
|
|||
"terminal_view",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"time_format",
|
||||
"ui",
|
||||
|
@ -127,57 +127,6 @@ dependencies = [
|
|||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_eval"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"clap",
|
||||
"client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"dap",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"node_runtime",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smol",
|
||||
"tempfile",
|
||||
"util",
|
||||
"walkdir",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_rules"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"fs",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"prompt_store",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
"worktree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.8"
|
||||
|
@ -4910,14 +4859,15 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"agent",
|
||||
"anyhow",
|
||||
"assistant_settings",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"dap",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"language",
|
||||
|
@ -4930,7 +4880,6 @@ dependencies = [
|
|||
"reqwest_client",
|
||||
"serde",
|
||||
"settings",
|
||||
"smol",
|
||||
"toml 0.8.20",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
|
|
@ -3,13 +3,11 @@ resolver = "2"
|
|||
members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent_rules",
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/assistant_context_editor",
|
||||
"crates/agent_eval",
|
||||
"crates/assistant_settings",
|
||||
"crates/assistant_slash_command",
|
||||
"crates/assistant_slash_commands",
|
||||
|
@ -211,14 +209,12 @@ edition = "2024"
|
|||
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
agent = { path = "crates/agent" }
|
||||
agent_rules = { path = "crates/agent_rules" }
|
||||
ai = { path = "crates/ai" }
|
||||
anthropic = { path = "crates/anthropic" }
|
||||
askpass = { path = "crates/askpass" }
|
||||
assets = { path = "crates/assets" }
|
||||
assistant = { path = "crates/assistant" }
|
||||
assistant_context_editor = { path = "crates/assistant_context_editor" }
|
||||
assistant_eval = { path = "crates/agent_eval" }
|
||||
assistant_settings = { path = "crates/assistant_settings" }
|
||||
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
|
||||
|
|
|
@ -19,7 +19,6 @@ test-support = [
|
|||
]
|
||||
|
||||
[dependencies]
|
||||
agent_rules.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_context_editor.workspace = true
|
||||
assistant_settings.workspace = true
|
||||
|
@ -81,6 +80,7 @@ terminal.workspace = true
|
|||
terminal_view.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
time_format.workspace = true
|
||||
ui.workspace = true
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::thread::{
|
|||
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
|
||||
ThreadEvent, ThreadFeedback,
|
||||
};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::thread_store::{RulesLoadingError, ThreadStore};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
|
||||
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
||||
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
|
||||
|
@ -21,7 +21,7 @@ use gpui::{
|
|||
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
|
||||
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
|
||||
use project::ProjectItem as _;
|
||||
|
@ -668,6 +668,7 @@ impl ActiveThread {
|
|||
let subscriptions = vec![
|
||||
cx.observe(&thread, |_, _, cx| cx.notify()),
|
||||
cx.subscribe_in(&thread, window, Self::handle_thread_event),
|
||||
cx.subscribe(&thread_store, Self::handle_rules_loading_error),
|
||||
];
|
||||
|
||||
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
|
||||
|
@ -833,10 +834,9 @@ impl ActiveThread {
|
|||
| ThreadEvent::SummaryChanged => {
|
||||
self.save_thread(cx);
|
||||
}
|
||||
ThreadEvent::DoneStreaming => {
|
||||
let thread = self.thread.read(cx);
|
||||
|
||||
if !thread.is_generating() {
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
|
||||
let thread = self.thread.read(cx);
|
||||
self.show_notification(
|
||||
if thread.used_tools_since_last_user_message() {
|
||||
"Finished running tools"
|
||||
|
@ -848,7 +848,8 @@ impl ActiveThread {
|
|||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
}
|
||||
|
@ -925,6 +926,19 @@ impl ActiveThread {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_rules_loading_error(
|
||||
&mut self,
|
||||
_thread_store: Entity<ThreadStore>,
|
||||
error: &RulesLoadingError,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.last_error = Some(ThreadError::Message {
|
||||
header: "Error loading rules file".into(),
|
||||
message: error.message.clone(),
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn show_notification(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
|
@ -2701,12 +2715,13 @@ impl ActiveThread {
|
|||
}
|
||||
|
||||
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
|
||||
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
|
||||
else {
|
||||
let project_context = self.thread.read(cx).project_context();
|
||||
let project_context = project_context.borrow();
|
||||
let Some(project_context) = project_context.as_ref() else {
|
||||
return div().into_any();
|
||||
};
|
||||
|
||||
let rules_files = system_prompt_context
|
||||
let rules_files = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.filter_map(|worktree| worktree.rules_file.as_ref())
|
||||
|
@ -2796,12 +2811,13 @@ impl ActiveThread {
|
|||
}
|
||||
|
||||
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
|
||||
else {
|
||||
let project_context = self.thread.read(cx).project_context();
|
||||
let project_context = project_context.borrow();
|
||||
let Some(project_context) = project_context.as_ref() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let abs_paths = system_prompt_context
|
||||
let abs_paths = project_context
|
||||
.worktrees
|
||||
.iter()
|
||||
.flat_map(|worktree| worktree.rules_file.as_ref())
|
||||
|
|
|
@ -921,15 +921,16 @@ mod tests {
|
|||
})
|
||||
.unwrap();
|
||||
|
||||
let thread_store = cx.update(|cx| {
|
||||
ThreadStore::new(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let thread_store = cx
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
|
||||
|
||||
|
|
|
@ -194,10 +194,12 @@ impl AssistantPanel {
|
|||
) -> Task<Result<Entity<Self>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let thread_store = workspace.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
})??;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
|
||||
let context_store = workspace
|
||||
|
|
|
@ -32,8 +32,8 @@ use crate::profile_selector::ProfileSelector;
|
|||
use crate::thread::{RequestKind, Thread, TokenUsageRatio};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::{
|
||||
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent,
|
||||
ToggleContextPicker, ToggleProfileSelector,
|
||||
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ToggleContextPicker,
|
||||
ToggleProfileSelector,
|
||||
};
|
||||
|
||||
pub struct MessageEditor {
|
||||
|
@ -235,8 +235,6 @@ impl MessageEditor {
|
|||
let refresh_task =
|
||||
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
|
||||
|
||||
let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
|
||||
|
||||
let thread = self.thread.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
@ -245,16 +243,6 @@ impl MessageEditor {
|
|||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
refresh_task.await;
|
||||
let (system_prompt_context, load_error) = system_prompt_context_task.await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.set_system_prompt_context(system_prompt_context);
|
||||
if let Some(load_error) = load_error {
|
||||
cx.emit(ThreadEvent::ShowError(load_error));
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
|
|
|
@ -3,14 +3,12 @@ use std::io::Write;
|
|||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent_rules::load_worktree_rules_file;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt, StreamExt as _};
|
||||
use git::repository::DiffType;
|
||||
|
@ -21,19 +19,20 @@ use language_model::{
|
|||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::Project;
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, format_context_as_string};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse,
|
||||
SerializedToolUse, SharedProjectContext,
|
||||
};
|
||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
|
||||
|
||||
|
@ -247,7 +246,7 @@ pub struct Thread {
|
|||
next_message_id: MessageId,
|
||||
context: BTreeMap<ContextId, AssistantContext>,
|
||||
context_by_message: HashMap<MessageId, Vec<ContextId>>,
|
||||
system_prompt_context: Option<AssistantSystemPromptContext>,
|
||||
project_context: SharedProjectContext,
|
||||
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
|
@ -269,6 +268,7 @@ impl Thread {
|
|||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
system_prompt: SharedProjectContext,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
@ -281,7 +281,7 @@ impl Thread {
|
|||
next_message_id: MessageId(0),
|
||||
context: BTreeMap::default(),
|
||||
context_by_message: HashMap::default(),
|
||||
system_prompt_context: None,
|
||||
project_context: system_prompt,
|
||||
checkpoints_by_message: HashMap::default(),
|
||||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
|
@ -310,6 +310,7 @@ impl Thread {
|
|||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
project_context: SharedProjectContext,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let next_message_id = MessageId(
|
||||
|
@ -350,7 +351,7 @@ impl Thread {
|
|||
next_message_id,
|
||||
context: BTreeMap::default(),
|
||||
context_by_message: HashMap::default(),
|
||||
system_prompt_context: None,
|
||||
project_context,
|
||||
checkpoints_by_message: HashMap::default(),
|
||||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
|
@ -388,6 +389,10 @@ impl Thread {
|
|||
self.summary.clone()
|
||||
}
|
||||
|
||||
pub fn project_context(&self) -> SharedProjectContext {
|
||||
self.project_context.clone()
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
|
@ -812,86 +817,6 @@ impl Thread {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
|
||||
self.system_prompt_context = Some(context);
|
||||
}
|
||||
|
||||
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
|
||||
&self.system_prompt_context
|
||||
}
|
||||
|
||||
pub fn load_system_prompt_context(
|
||||
&self,
|
||||
cx: &App,
|
||||
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
|
||||
let project = self.project.read(cx);
|
||||
let tasks = project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(
|
||||
project.fs().clone(),
|
||||
worktree.read(cx),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async |_cx| {
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
let mut first_err = None;
|
||||
let worktrees = results
|
||||
.into_iter()
|
||||
.map(|(worktree, err)| {
|
||||
if first_err.is_none() && err.is_some() {
|
||||
first_err = err;
|
||||
}
|
||||
worktree
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
(AssistantSystemPromptContext::new(worktrees), first_err)
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_info_for_system_prompt(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
|
||||
let root_name = worktree.root_name().into();
|
||||
let abs_path = worktree.abs_path();
|
||||
|
||||
let rules_task = load_worktree_rules_file(fs, worktree, cx);
|
||||
let Some(rules_task) = rules_task else {
|
||||
return Task::ready((
|
||||
WorktreeInfoForSystemPrompt {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file: None,
|
||||
},
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
let (rules_file, rules_file_error) = match rules_task.await {
|
||||
Ok(rules_file) => (Some(rules_file), None),
|
||||
Err(err) => (
|
||||
None,
|
||||
Some(ThreadError::Message {
|
||||
header: "Error loading rules file".into(),
|
||||
message: format!("{err}").into(),
|
||||
}),
|
||||
),
|
||||
};
|
||||
let worktree_info = WorktreeInfoForSystemPrompt {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file,
|
||||
};
|
||||
(worktree_info, rules_file_error)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn send_to_model(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
@ -941,10 +866,10 @@ impl Thread {
|
|||
temperature: None,
|
||||
};
|
||||
|
||||
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
|
||||
if let Some(project_context) = self.project_context.borrow().as_ref() {
|
||||
if let Some(system_prompt) = self
|
||||
.prompt_builder
|
||||
.generate_assistant_system_prompt(system_prompt_context)
|
||||
.generate_assistant_system_prompt(project_context)
|
||||
.context("failed to generate assistant system prompt")
|
||||
.log_err()
|
||||
{
|
||||
|
@ -955,7 +880,7 @@ impl Thread {
|
|||
});
|
||||
}
|
||||
} else {
|
||||
log::error!("system_prompt_context not set.")
|
||||
log::error!("project_context not set.")
|
||||
}
|
||||
|
||||
for message in &self.messages {
|
||||
|
@ -1215,7 +1140,7 @@ impl Thread {
|
|||
thread.cancel_last_completion(cx);
|
||||
}
|
||||
}
|
||||
cx.emit(ThreadEvent::DoneStreaming);
|
||||
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
|
||||
|
||||
thread.auto_capture_telemetry(cx);
|
||||
|
||||
|
@ -1963,10 +1888,13 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Error)]
|
||||
pub enum ThreadError {
|
||||
#[error("Payment required")]
|
||||
PaymentRequired,
|
||||
#[error("Max monthly spend reached")]
|
||||
MaxMonthlySpendReached,
|
||||
#[error("Message {header}: {message}")]
|
||||
Message {
|
||||
header: SharedString,
|
||||
message: SharedString,
|
||||
|
@ -1979,7 +1907,7 @@ pub enum ThreadEvent {
|
|||
StreamedCompletion,
|
||||
StreamedAssistantText(MessageId, String),
|
||||
StreamedAssistantThinking(MessageId, String),
|
||||
DoneStreaming,
|
||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
MessageDeleted(MessageId),
|
||||
|
@ -2085,9 +2013,9 @@ fn main() {{
|
|||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 1);
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
let expected_full_message = format!("{}Please explain this code", expected_context);
|
||||
assert_eq!(request.messages[0].string_contents(), expected_full_message);
|
||||
assert_eq!(request.messages[1].string_contents(), expected_full_message);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -2178,20 +2106,20 @@ fn main() {{
|
|||
});
|
||||
|
||||
// The request should contain all 3 messages
|
||||
assert_eq!(request.messages.len(), 3);
|
||||
assert_eq!(request.messages.len(), 4);
|
||||
|
||||
// Check that the contexts are properly formatted in each message
|
||||
assert!(request.messages[0].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[0].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[0].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[1].string_contents().contains("file1.rs"));
|
||||
assert!(request.messages[1].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[1].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[1].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[1].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[2].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[2].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[2].string_contents().contains("file3.rs"));
|
||||
assert!(request.messages[2].string_contents().contains("file2.rs"));
|
||||
assert!(!request.messages[2].string_contents().contains("file3.rs"));
|
||||
|
||||
assert!(!request.messages[3].string_contents().contains("file1.rs"));
|
||||
assert!(!request.messages[3].string_contents().contains("file2.rs"));
|
||||
assert!(request.messages[3].string_contents().contains("file3.rs"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -2229,9 +2157,9 @@ fn main() {{
|
|||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 1);
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
assert_eq!(
|
||||
request.messages[0].string_contents(),
|
||||
request.messages[1].string_contents(),
|
||||
"What is the best way to learn Rust?"
|
||||
);
|
||||
|
||||
|
@ -2249,13 +2177,13 @@ fn main() {{
|
|||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
assert_eq!(request.messages.len(), 3);
|
||||
assert_eq!(
|
||||
request.messages[0].string_contents(),
|
||||
request.messages[1].string_contents(),
|
||||
"What is the best way to learn Rust?"
|
||||
);
|
||||
assert_eq!(
|
||||
request.messages[1].string_contents(),
|
||||
request.messages[2].string_contents(),
|
||||
"Are there any good books?"
|
||||
);
|
||||
}
|
||||
|
@ -2376,15 +2304,16 @@ fn main() {{
|
|||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let thread_store = cx.update(|_, cx| {
|
||||
ThreadStore::new(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let thread_store = cx
|
||||
.update(|_, cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
Arc::default(),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
|
|
@ -1,37 +1,57 @@
|
|||
use std::borrow::Cow;
|
||||
use std::path::PathBuf;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use fs::Fs;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use gpui::{
|
||||
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
|
||||
prelude::*,
|
||||
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
|
||||
Subscription, Task, prelude::*,
|
||||
};
|
||||
use heed::Database;
|
||||
use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::thread::{
|
||||
DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
|
||||
};
|
||||
use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
|
||||
|
||||
const RULES_FILE_NAMES: [&'static str; 6] = [
|
||||
".rules",
|
||||
".cursorrules",
|
||||
".windsurfrules",
|
||||
".clinerules",
|
||||
".github/copilot-instructions.md",
|
||||
"CLAUDE.md",
|
||||
];
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
ThreadsDatabase::init(cx);
|
||||
}
|
||||
|
||||
/// A system prompt shared by all threads created by this ThreadStore
|
||||
#[derive(Clone, Default)]
|
||||
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
|
||||
|
||||
impl SharedProjectContext {
|
||||
pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
|
||||
self.0.borrow()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ThreadStore {
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
|
@ -39,43 +59,187 @@ pub struct ThreadStore {
|
|||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
threads: Vec<SerializedThreadMetadata>,
|
||||
project_context: SharedProjectContext,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub struct RulesLoadingError {
|
||||
pub message: SharedString,
|
||||
}
|
||||
|
||||
impl EventEmitter<RulesLoadingError> for ThreadStore {}
|
||||
|
||||
impl ThreadStore {
|
||||
pub fn new(
|
||||
pub fn load(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Result<Entity<Self>> {
|
||||
let this = cx.new(|cx| {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let settings_subscription =
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
});
|
||||
) -> Task<Entity<Self>> {
|
||||
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
|
||||
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
|
||||
cx.foreground_executor().spawn(async move {
|
||||
reload.await;
|
||||
thread_store
|
||||
})
|
||||
}
|
||||
|
||||
let this = Self {
|
||||
project,
|
||||
tools,
|
||||
prompt_builder,
|
||||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
_subscriptions: vec![settings_subscription],
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
|
||||
this
|
||||
fn new(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
|
||||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let settings_subscription =
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
});
|
||||
let project_subscription = cx.subscribe(&project, Self::handle_project_event);
|
||||
|
||||
Ok(this)
|
||||
let this = Self {
|
||||
project,
|
||||
tools,
|
||||
prompt_builder,
|
||||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
project_context: SharedProjectContext::default(),
|
||||
_subscriptions: vec![settings_subscription, project_subscription],
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
event: &project::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
|
||||
self.reload_system_prompt(cx).detach();
|
||||
}
|
||||
project::Event::WorktreeUpdatedEntries(_, items) => {
|
||||
if items.iter().any(|(path, _, _)| {
|
||||
RULES_FILE_NAMES
|
||||
.iter()
|
||||
.any(|name| path.as_ref() == Path::new(name))
|
||||
}) {
|
||||
self.reload_system_prompt(cx).detach();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
|
||||
let project = self.project.read(cx);
|
||||
let tasks = project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(
|
||||
project.fs().clone(),
|
||||
worktree.read(cx),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
let worktrees = results
|
||||
.into_iter()
|
||||
.map(|(worktree, rules_error)| {
|
||||
if let Some(rules_error) = rules_error {
|
||||
this.update(cx, |_, cx| cx.emit(rules_error)).ok();
|
||||
}
|
||||
worktree
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
this.update(cx, |this, _cx| {
|
||||
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_info_for_system_prompt(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
|
||||
let root_name = worktree.root_name().into();
|
||||
let abs_path = worktree.abs_path();
|
||||
|
||||
let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
|
||||
let Some(rules_task) = rules_task else {
|
||||
return Task::ready((
|
||||
WorktreeContext {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file: None,
|
||||
},
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
let (rules_file, rules_file_error) = match rules_task.await {
|
||||
Ok(rules_file) => (Some(rules_file), None),
|
||||
Err(err) => (
|
||||
None,
|
||||
Some(RulesLoadingError {
|
||||
message: format!("{err}").into(),
|
||||
}),
|
||||
),
|
||||
};
|
||||
let worktree_info = WorktreeContext {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file,
|
||||
};
|
||||
(worktree_info, rules_file_error)
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_rules_file(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Option<Task<Result<RulesFileContext>>> {
|
||||
let selected_rules_file = RULES_FILE_NAMES
|
||||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
worktree
|
||||
.entry_for_path(name)
|
||||
.filter(|entry| entry.is_file())
|
||||
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
|
||||
})
|
||||
.next();
|
||||
|
||||
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
||||
// supported. This doesn't seem to occur often in GitHub repositories.
|
||||
selected_rules_file.map(|(path_in_worktree, abs_path)| {
|
||||
let fs = fs.clone();
|
||||
cx.background_spawn(async move {
|
||||
let abs_path = abs_path?;
|
||||
let text = fs.load(&abs_path).await.with_context(|| {
|
||||
format!("Failed to load assistant rules file {:?}", abs_path)
|
||||
})?;
|
||||
anyhow::Ok(RulesFileContext {
|
||||
path_in_worktree,
|
||||
abs_path: abs_path.into(),
|
||||
text: text.trim().to_string(),
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
|
||||
|
@ -107,6 +271,7 @@ impl ThreadStore {
|
|||
self.project.clone(),
|
||||
self.tools.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
self.project_context.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
@ -134,21 +299,12 @@ impl ThreadStore {
|
|||
this.project.clone(),
|
||||
this.tools.clone(),
|
||||
this.prompt_builder.clone(),
|
||||
this.project_context.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
let (system_prompt_context, load_error) = thread
|
||||
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
|
||||
.await;
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_system_prompt_context(system_prompt_context);
|
||||
if let Some(load_error) = load_error {
|
||||
cx.emit(ThreadEvent::ShowError(load_error));
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
[package]
|
||||
name = "agent_eval"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "agent_eval"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
dap.workspace = true
|
||||
env_logger.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
node_runtime.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
release_channel.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
util.workspace = true
|
||||
walkdir.workspace = true
|
||||
workspace-hack.workspace = true
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-GPL
|
|
@ -1,52 +0,0 @@
|
|||
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
|
||||
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
if cfg!(target_os = "macos") {
|
||||
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
|
||||
|
||||
// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
|
||||
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
|
||||
|
||||
// Seems to be required to enable Swift concurrency
|
||||
println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
|
||||
|
||||
// Register exported Objective-C selectors, protocols, etc
|
||||
println!("cargo:rustc-link-arg=-Wl,-ObjC");
|
||||
}
|
||||
|
||||
// Populate git sha environment variable if git is available
|
||||
println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
|
||||
println!(
|
||||
"cargo:rustc-env=TARGET={}",
|
||||
std::env::var("TARGET").unwrap()
|
||||
);
|
||||
if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
|
||||
if output.status.success() {
|
||||
let git_sha = String::from_utf8_lossy(&output.stdout);
|
||||
let git_sha = git_sha.trim();
|
||||
|
||||
println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
|
||||
|
||||
if let Ok(build_profile) = std::env::var("PROFILE") {
|
||||
if build_profile == "release" {
|
||||
// This is currently the best way to make `cargo build ...`'s build script
|
||||
// to print something to stdout without extra verbosity.
|
||||
println!(
|
||||
"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
#[cfg(target_env = "msvc")]
|
||||
{
|
||||
// todo(windows): This is to avoid stack overflow. Remove it when solved.
|
||||
println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,384 +0,0 @@
|
|||
use crate::git_commands::{run_git, setup_temp_repo};
|
||||
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
|
||||
use crate::{get_exercise_language, get_exercise_name};
|
||||
use agent::RequestKind;
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Task};
|
||||
use language_model::{LanguageModel, TokenUsage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fs,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct EvalResult {
|
||||
pub exercise_name: String,
|
||||
pub diff: String,
|
||||
pub assistant_response: String,
|
||||
pub elapsed_time_ms: u128,
|
||||
pub timestamp: u128,
|
||||
// Token usage fields
|
||||
pub input_tokens: usize,
|
||||
pub output_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
pub tool_use_counts: usize,
|
||||
}
|
||||
|
||||
pub struct EvalOutput {
|
||||
pub diff: String,
|
||||
pub last_message: String,
|
||||
pub elapsed_time: Duration,
|
||||
pub assistant_response_count: usize,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub token_usage: TokenUsage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EvalSetup {
|
||||
pub url: String,
|
||||
pub base_sha: String,
|
||||
}
|
||||
|
||||
pub struct Eval {
|
||||
pub repo_path: PathBuf,
|
||||
pub eval_setup: EvalSetup,
|
||||
pub user_prompt: String,
|
||||
}
|
||||
|
||||
impl Eval {
|
||||
// Keep this method for potential future use, but mark it as intentionally unused
|
||||
#[allow(dead_code)]
|
||||
pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
|
||||
let prompt_path = path.join("prompt.txt");
|
||||
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
|
||||
let setup_path = path.join("setup.json");
|
||||
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
|
||||
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
|
||||
|
||||
// Move this internal function inside the load method since it's only used here
|
||||
fn repo_dir_name(url: &str) -> String {
|
||||
url.trim_start_matches("https://")
|
||||
.replace(|c: char| !c.is_alphanumeric(), "_")
|
||||
}
|
||||
|
||||
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
|
||||
|
||||
Ok(Eval {
|
||||
repo_path,
|
||||
eval_setup,
|
||||
user_prompt,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
self,
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<EvalOutput>> {
|
||||
cx.spawn(async move |cx| {
|
||||
run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
|
||||
|
||||
let (assistant, done_rx) =
|
||||
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
|
||||
|
||||
let _worktree = assistant
|
||||
.update(cx, |assistant, cx| {
|
||||
assistant.project.update(cx, |project, cx| {
|
||||
project.create_worktree(&self.repo_path, true, cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let start_time = std::time::SystemTime::now();
|
||||
|
||||
let (system_prompt_context, load_error) = cx
|
||||
.update(|cx| {
|
||||
assistant
|
||||
.read(cx)
|
||||
.thread
|
||||
.read(cx)
|
||||
.load_system_prompt_context(cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
if let Some(load_error) = load_error {
|
||||
return Err(anyhow!("{:?}", load_error));
|
||||
};
|
||||
|
||||
assistant.update(cx, |assistant, cx| {
|
||||
assistant.thread.update(cx, |thread, cx| {
|
||||
let context = vec![];
|
||||
thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
|
||||
thread.set_system_prompt_context(system_prompt_context);
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
});
|
||||
})?;
|
||||
|
||||
done_rx.recv().await??;
|
||||
|
||||
// Add this section to check untracked files
|
||||
println!("Checking for untracked files:");
|
||||
let untracked = run_git(
|
||||
&self.repo_path,
|
||||
&["ls-files", "--others", "--exclude-standard"],
|
||||
)
|
||||
.await?;
|
||||
if untracked.is_empty() {
|
||||
println!("No untracked files found");
|
||||
} else {
|
||||
// Add all files to git so they appear in the diff
|
||||
println!("Adding untracked files to git");
|
||||
run_git(&self.repo_path, &["add", "."]).await?;
|
||||
}
|
||||
|
||||
// get git status
|
||||
let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
|
||||
|
||||
let elapsed_time = start_time.elapsed()?;
|
||||
|
||||
// Get diff of staged changes (the files we just added)
|
||||
let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
|
||||
|
||||
// Get diff of unstaged changes
|
||||
let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
|
||||
|
||||
// Combine both diffs
|
||||
let diff = if unstaged_diff.is_empty() {
|
||||
staged_diff
|
||||
} else if staged_diff.is_empty() {
|
||||
unstaged_diff
|
||||
} else {
|
||||
format!(
|
||||
"# Staged changes\n{}\n\n# Unstaged changes\n{}",
|
||||
staged_diff, unstaged_diff
|
||||
)
|
||||
};
|
||||
|
||||
assistant.update(cx, |assistant, cx| {
|
||||
let thread = assistant.thread.read(cx);
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
if last_message.role != language_model::Role::Assistant {
|
||||
return Err(anyhow!("Last message is not from assistant"));
|
||||
}
|
||||
let assistant_response_count = thread
|
||||
.messages()
|
||||
.filter(|message| message.role == language_model::Role::Assistant)
|
||||
.count();
|
||||
Ok(EvalOutput {
|
||||
diff,
|
||||
last_message: last_message.to_string(),
|
||||
elapsed_time,
|
||||
assistant_response_count,
|
||||
tool_use_counts: assistant.tool_use_counts.clone(),
|
||||
token_usage: thread.cumulative_token_usage(),
|
||||
})
|
||||
})?
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EvalOutput {
|
||||
// Keep this method for potential future use, but mark it as intentionally unused
|
||||
#[allow(dead_code)]
|
||||
pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
|
||||
// Create the output directory if it doesn't exist
|
||||
fs::create_dir_all(&output_dir)?;
|
||||
|
||||
// Save the diff to a file
|
||||
let diff_path = output_dir.join("diff.patch");
|
||||
let mut diff_file = fs::File::create(&diff_path)?;
|
||||
diff_file.write_all(self.diff.as_bytes())?;
|
||||
|
||||
// Save the last message to a file
|
||||
let message_path = output_dir.join("assistant_response.txt");
|
||||
let mut message_file = fs::File::create(&message_path)?;
|
||||
message_file.write_all(self.last_message.as_bytes())?;
|
||||
|
||||
// Current metrics for this run
|
||||
let current_metrics = serde_json::json!({
|
||||
"elapsed_time_ms": self.elapsed_time.as_millis(),
|
||||
"assistant_response_count": self.assistant_response_count,
|
||||
"tool_use_counts": self.tool_use_counts,
|
||||
"token_usage": self.token_usage,
|
||||
"eval_output_value": eval_output_value,
|
||||
});
|
||||
|
||||
// Get current timestamp in milliseconds
|
||||
let timestamp = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)?
|
||||
.as_millis()
|
||||
.to_string();
|
||||
|
||||
// Path to metrics file
|
||||
let metrics_path = output_dir.join("metrics.json");
|
||||
|
||||
// Load existing metrics if the file exists, or create a new object
|
||||
let mut historical_metrics = if metrics_path.exists() {
|
||||
let metrics_content = fs::read_to_string(&metrics_path)?;
|
||||
serde_json::from_str::<serde_json::Value>(&metrics_content)
|
||||
.unwrap_or_else(|_| serde_json::json!({}))
|
||||
} else {
|
||||
serde_json::json!({})
|
||||
};
|
||||
|
||||
// Add new run with timestamp as key
|
||||
if let serde_json::Value::Object(ref mut map) = historical_metrics {
|
||||
map.insert(timestamp, current_metrics);
|
||||
}
|
||||
|
||||
// Write updated metrics back to file
|
||||
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
|
||||
let mut metrics_file = fs::File::create(&metrics_path)?;
|
||||
metrics_file.write_all(metrics_json.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
|
||||
let instructions_path = exercise_path.join(".docs").join("instructions.md");
|
||||
println!("Reading instructions from: {}", instructions_path.display());
|
||||
let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
|
||||
Ok(instructions)
|
||||
}
|
||||
|
||||
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
|
||||
let eval_dir = exercise_path.join("evaluation");
|
||||
fs::create_dir_all(&eval_dir)?;
|
||||
|
||||
let eval_file = eval_dir.join("evals.json");
|
||||
|
||||
println!("Saving evaluation results to: {}", eval_file.display());
|
||||
println!(
|
||||
"Results to save: {} evaluations for exercise path: {}",
|
||||
results.len(),
|
||||
exercise_path.display()
|
||||
);
|
||||
|
||||
// Check file existence before reading/writing
|
||||
if eval_file.exists() {
|
||||
println!("Existing evals.json file found, will update it");
|
||||
} else {
|
||||
println!("No existing evals.json file found, will create new one");
|
||||
}
|
||||
|
||||
// Structure to organize evaluations by test name and timestamp
|
||||
let mut eval_data: serde_json::Value = if eval_file.exists() {
|
||||
let content = fs::read_to_string(&eval_file)?;
|
||||
serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
|
||||
} else {
|
||||
serde_json::json!({})
|
||||
};
|
||||
|
||||
// Get current timestamp for this batch of results
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)?
|
||||
.as_millis()
|
||||
.to_string();
|
||||
|
||||
// Group the new results by test name (exercise name)
|
||||
for result in results {
|
||||
let exercise_name = &result.exercise_name;
|
||||
|
||||
println!("Adding result: exercise={}", exercise_name);
|
||||
|
||||
// Ensure the exercise entry exists
|
||||
if eval_data.get(exercise_name).is_none() {
|
||||
eval_data[exercise_name] = serde_json::json!({});
|
||||
}
|
||||
|
||||
// Ensure the timestamp entry exists as an object
|
||||
if eval_data[exercise_name].get(×tamp).is_none() {
|
||||
eval_data[exercise_name][×tamp] = serde_json::json!({});
|
||||
}
|
||||
|
||||
// Add this result under the timestamp with template name as key
|
||||
eval_data[exercise_name][×tamp] = serde_json::to_value(&result)?;
|
||||
}
|
||||
|
||||
// Write back to file with pretty formatting
|
||||
let json_content = serde_json::to_string_pretty(&eval_data)?;
|
||||
match fs::write(&eval_file, json_content) {
|
||||
Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
|
||||
Err(e) => println!("✗ Failed to write results file: {}", e),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_exercise_eval(
|
||||
exercise_path: PathBuf,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
base_sha: String,
|
||||
_framework_path: PathBuf,
|
||||
cx: gpui::AsyncApp,
|
||||
) -> Result<EvalResult> {
|
||||
let exercise_name = get_exercise_name(&exercise_path);
|
||||
let language = get_exercise_language(&exercise_path)?;
|
||||
let mut instructions = read_instructions(&exercise_path).await?;
|
||||
instructions.push_str(&format!(
|
||||
"\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
|
||||
language
|
||||
));
|
||||
|
||||
println!("Running evaluation for exercise: {}", exercise_name);
|
||||
|
||||
// Create temporary directory with exercise files
|
||||
let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
|
||||
let temp_path = temp_dir.path().to_path_buf();
|
||||
|
||||
let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
|
||||
|
||||
let start_time = SystemTime::now();
|
||||
|
||||
// Create a basic eval struct to work with the existing system
|
||||
let eval = Eval {
|
||||
repo_path: temp_path.clone(),
|
||||
eval_setup: EvalSetup {
|
||||
url: format!("file://{}", temp_path.display()),
|
||||
base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
|
||||
},
|
||||
user_prompt: instructions.clone(),
|
||||
};
|
||||
|
||||
// Run the evaluation
|
||||
let eval_output = cx
|
||||
.update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
// Get diff from git
|
||||
let diff = eval_output.diff.clone();
|
||||
|
||||
let elapsed_time = start_time.elapsed()?;
|
||||
|
||||
// Calculate total tokens as the sum of input and output tokens
|
||||
let input_tokens = eval_output.token_usage.input_tokens;
|
||||
let output_tokens = eval_output.token_usage.output_tokens;
|
||||
let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
|
||||
let total_tokens = input_tokens + output_tokens;
|
||||
|
||||
// Save results to evaluation directory
|
||||
let result = EvalResult {
|
||||
exercise_name: exercise_name.clone(),
|
||||
diff,
|
||||
assistant_response: eval_output.last_message.clone(),
|
||||
elapsed_time_ms: elapsed_time.as_millis(),
|
||||
timestamp: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)?
|
||||
.as_millis(),
|
||||
// Convert u32 token counts to usize
|
||||
input_tokens: input_tokens.try_into().unwrap(),
|
||||
output_tokens: output_tokens.try_into().unwrap(),
|
||||
total_tokens: total_tokens.try_into().unwrap(),
|
||||
tool_use_counts: tool_use_counts.try_into().unwrap(),
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
}
|
|
@ -1,149 +0,0 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use std::{
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
pub fn get_exercise_name(exercise_path: &Path) -> String {
|
||||
exercise_path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
|
||||
// Extract the language from path (data/python/exercises/... => python)
|
||||
let parts: Vec<_> = exercise_path.components().collect();
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if i > 0 && part.as_os_str() == "eval_code" {
|
||||
if i + 1 < parts.len() {
|
||||
let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
|
||||
return Ok(language);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"Could not determine language from path: {:?}",
|
||||
exercise_path
|
||||
))
|
||||
}
|
||||
|
||||
pub fn find_exercises(
|
||||
framework_path: &Path,
|
||||
languages: &[&str],
|
||||
max_per_language: Option<usize>,
|
||||
) -> Result<Vec<PathBuf>> {
|
||||
let mut all_exercises = Vec::new();
|
||||
|
||||
println!("Searching for exercises in languages: {:?}", languages);
|
||||
|
||||
for language in languages {
|
||||
let language_dir = framework_path
|
||||
.join("eval_code")
|
||||
.join(language)
|
||||
.join("exercises")
|
||||
.join("practice");
|
||||
|
||||
println!("Checking language directory: {:?}", language_dir);
|
||||
if !language_dir.exists() {
|
||||
println!("Warning: Language directory not found: {:?}", language_dir);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut exercises = Vec::new();
|
||||
match fs::read_dir(&language_dir) {
|
||||
Ok(entries) => {
|
||||
for entry_result in entries {
|
||||
match entry_result {
|
||||
Ok(entry) => {
|
||||
let path = entry.path();
|
||||
|
||||
if path.is_dir() {
|
||||
// Special handling for "internal" directory
|
||||
if *language == "internal" {
|
||||
// Check for repo_info.json to validate it's an internal exercise
|
||||
let repo_info_path = path.join(".meta").join("repo_info.json");
|
||||
let instructions_path =
|
||||
path.join(".docs").join("instructions.md");
|
||||
|
||||
if repo_info_path.exists() && instructions_path.exists() {
|
||||
exercises.push(path);
|
||||
}
|
||||
} else {
|
||||
// Map the language to the file extension - original code
|
||||
let language_extension = match *language {
|
||||
"python" => "py",
|
||||
"go" => "go",
|
||||
"rust" => "rs",
|
||||
"typescript" => "ts",
|
||||
"javascript" => "js",
|
||||
"ruby" => "rb",
|
||||
"php" => "php",
|
||||
"bash" => "sh",
|
||||
"multi" => "diff",
|
||||
_ => continue, // Skip unsupported languages
|
||||
};
|
||||
|
||||
// Check if this is a valid exercise with instructions and example
|
||||
let instructions_path =
|
||||
path.join(".docs").join("instructions.md");
|
||||
let has_instructions = instructions_path.exists();
|
||||
let example_path = path
|
||||
.join(".meta")
|
||||
.join(format!("example.{}", language_extension));
|
||||
let has_example = example_path.exists();
|
||||
|
||||
if has_instructions && has_example {
|
||||
exercises.push(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => println!("Error reading directory entry: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => println!(
|
||||
"Error reading directory {}: {}",
|
||||
language_dir.display(),
|
||||
err
|
||||
),
|
||||
}
|
||||
|
||||
// Sort exercises by name for consistent selection
|
||||
exercises.sort_by(|a, b| {
|
||||
let a_name = a.file_name().unwrap_or_default().to_string_lossy();
|
||||
let b_name = b.file_name().unwrap_or_default().to_string_lossy();
|
||||
a_name.cmp(&b_name)
|
||||
});
|
||||
|
||||
// Apply the limit if specified
|
||||
if let Some(limit) = max_per_language {
|
||||
if exercises.len() > limit {
|
||||
println!(
|
||||
"Limiting {} exercises to {} for language {}",
|
||||
exercises.len(),
|
||||
limit,
|
||||
language
|
||||
);
|
||||
exercises.truncate(limit);
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Found {} exercises for language {}: {:?}",
|
||||
exercises.len(),
|
||||
language,
|
||||
exercises
|
||||
.iter()
|
||||
.map(|p| p.file_name().unwrap_or_default().to_string_lossy())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
all_exercises.extend(exercises);
|
||||
}
|
||||
|
||||
Ok(all_exercises)
|
||||
}
|
|
@ -1,125 +0,0 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use serde::Deserialize;
|
||||
use std::{fs, path::Path};
|
||||
use tempfile::TempDir;
|
||||
use util::command::new_smol_command;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SetupConfig {
|
||||
#[serde(rename = "base.sha")]
|
||||
pub base_sha: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RepoInfo {
|
||||
pub remote_url: String,
|
||||
pub head_sha: String,
|
||||
}
|
||||
|
||||
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = new_smol_command("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Git command failed: {} with status: {}",
|
||||
args.join(" "),
|
||||
output.status
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
|
||||
let setup_path = framework_path.join("setup.json");
|
||||
let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
|
||||
let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
|
||||
Ok(setup_config.base_sha)
|
||||
}
|
||||
|
||||
pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
|
||||
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
|
||||
println!("Reading repo info from: {}", repo_info_path.display());
|
||||
let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
|
||||
let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;
|
||||
|
||||
// Remove any quotes from the strings
|
||||
let remote_url = repo_info.remote_url.trim_matches('"').to_string();
|
||||
let head_sha = repo_info.head_sha.trim_matches('"').to_string();
|
||||
|
||||
Ok(RepoInfo {
|
||||
remote_url,
|
||||
head_sha,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
|
||||
let temp_dir = TempDir::new()?;
|
||||
|
||||
// Check if this is an internal exercise by looking for repo_info.json
|
||||
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
|
||||
if repo_info_path.exists() {
|
||||
// This is an internal exercise, handle it differently
|
||||
let repo_info = read_repo_info(exercise_path).await?;
|
||||
|
||||
// Clone the repository to the temp directory
|
||||
let url = repo_info.remote_url;
|
||||
let clone_path = temp_dir.path();
|
||||
println!(
|
||||
"Cloning repository from {} to {}",
|
||||
url,
|
||||
clone_path.display()
|
||||
);
|
||||
run_git(
|
||||
&std::env::current_dir()?,
|
||||
&["clone", &url, &clone_path.to_string_lossy()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Checkout the specified commit
|
||||
println!("Checking out commit: {}", repo_info.head_sha);
|
||||
run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;
|
||||
|
||||
println!("Successfully set up internal repository");
|
||||
} else {
|
||||
// Original code for regular exercises
|
||||
// Copy the exercise files to the temp directory, excluding .docs and .meta
|
||||
for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
|
||||
let entry = entry?;
|
||||
let source_path = entry.path();
|
||||
|
||||
// Skip .docs and .meta directories completely
|
||||
if source_path.starts_with(exercise_path.join(".docs"))
|
||||
|| source_path.starts_with(exercise_path.join(".meta"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if source_path.is_file() {
|
||||
let relative_path = source_path.strip_prefix(exercise_path)?;
|
||||
let dest_path = temp_dir.path().join(relative_path);
|
||||
|
||||
// Make sure parent directories exist
|
||||
if let Some(parent) = dest_path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
fs::copy(source_path, dest_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize git repo in the temp directory
|
||||
run_git(temp_dir.path(), &["init"]).await?;
|
||||
run_git(temp_dir.path(), &["add", "."]).await?;
|
||||
run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;
|
||||
|
||||
println!("Created temp repo without .docs and .meta directories");
|
||||
}
|
||||
|
||||
Ok(temp_dir)
|
||||
}
|
|
@ -1,229 +0,0 @@
|
|||
use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
|
||||
use anyhow::anyhow;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use dap::DapRegistry;
|
||||
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use node_runtime::NodeRuntime;
|
||||
use project::{Project, RealFs};
|
||||
use prompt_store::PromptBuilder;
|
||||
use settings::SettingsStore;
|
||||
use smol::channel;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||
pub struct HeadlessAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
|
||||
// Additional fields not present in `workspace::AppState`.
|
||||
pub prompt_builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
pub struct HeadlessAssistant {
|
||||
pub thread: Entity<Thread>,
|
||||
pub project: Entity<Project>,
|
||||
#[allow(dead_code)]
|
||||
pub thread_store: Entity<ThreadStore>,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub done_tx: channel::Sender<anyhow::Result<()>>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl HeadlessAssistant {
|
||||
pub fn new(
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
cx: &mut App,
|
||||
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
|
||||
let env = None;
|
||||
let project = Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
Arc::new(DapRegistry::default()),
|
||||
app_state.fs.clone(),
|
||||
env,
|
||||
cx,
|
||||
);
|
||||
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let thread_store =
|
||||
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
|
||||
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
|
||||
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
|
||||
|
||||
let headless_thread = cx.new(move |cx| Self {
|
||||
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
|
||||
thread,
|
||||
project,
|
||||
thread_store,
|
||||
tool_use_counts: HashMap::default(),
|
||||
done_tx,
|
||||
});
|
||||
|
||||
Ok((headless_thread, done_rx))
|
||||
}
|
||||
|
||||
fn handle_thread_event(
|
||||
&mut self,
|
||||
thread: Entity<Thread>,
|
||||
event: &ThreadEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
ThreadEvent::ShowError(err) => self
|
||||
.done_tx
|
||||
.send_blocking(Err(anyhow!("{:?}", err)))
|
||||
.unwrap(),
|
||||
ThreadEvent::DoneStreaming => {
|
||||
let thread = thread.read(cx);
|
||||
if let Some(message) = thread.messages().last() {
|
||||
println!("Message: {}", message.to_string());
|
||||
}
|
||||
if thread.all_tools_finished() {
|
||||
self.done_tx.send_blocking(Ok(())).unwrap()
|
||||
}
|
||||
}
|
||||
ThreadEvent::UsePendingTools { .. } => {}
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
// Automatically approve all tools that need confirmation in headless mode
|
||||
println!("Tool confirmation needed - automatically approving in headless mode");
|
||||
|
||||
// Get the tools needing confirmation
|
||||
let tools_needing_confirmation: Vec<_> = thread
|
||||
.read(cx)
|
||||
.tools_needing_confirmation()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Run each tool that needs confirmation
|
||||
for tool_use in tools_needing_confirmation {
|
||||
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
|
||||
thread.update(cx, |thread, cx| {
|
||||
println!("Auto-approving tool: {}", tool_use.name);
|
||||
|
||||
// Create a request to send to the tool
|
||||
let request = thread.to_completion_request(RequestKind::Chat, cx);
|
||||
let messages = Arc::new(request.messages);
|
||||
|
||||
// Run the tool
|
||||
thread.run_tool(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
&messages,
|
||||
tool,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
pending_tool_use,
|
||||
..
|
||||
} => {
|
||||
if let Some(pending_tool_use) = pending_tool_use {
|
||||
println!(
|
||||
"Used tool {} with input: {}",
|
||||
pending_tool_use.name, pending_tool_use.input
|
||||
);
|
||||
*self
|
||||
.tool_use_counts
|
||||
.entry(pending_tool_use.name.clone())
|
||||
.or_insert(0) += 1;
|
||||
}
|
||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||
println!("Tool result: {:?}", tool_result);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
settings_store
|
||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||
.unwrap();
|
||||
cx.set_global(settings_store);
|
||||
client::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
|
||||
let client = Client::production(cx);
|
||||
cx.set_http_client(client.http_client().clone());
|
||||
|
||||
let git_binary_path = None;
|
||||
let fs = Arc::new(RealFs::new(
|
||||
git_binary_path,
|
||||
cx.background_executor().clone(),
|
||||
));
|
||||
|
||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
language::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
assistant_tools::init(client.http_client().clone(), cx);
|
||||
context_server::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||
|
||||
Arc::new(HeadlessAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime: NodeRuntime::unavailable(),
|
||||
prompt_builder,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = model_registry
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == model_name);
|
||||
|
||||
let Some(model) = model else {
|
||||
return Err(anyhow!(
|
||||
"No language model named {} was available. Available models: {}",
|
||||
model_name,
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.map(|model| model.id().0.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
};
|
||||
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
pub fn authenticate_model_provider(
|
||||
provider_id: LanguageModelProviderId,
|
||||
cx: &mut App,
|
||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||
model_provider.authenticate(cx)
|
||||
}
|
|
@ -1,205 +0,0 @@
|
|||
mod eval;
|
||||
mod get_exercise;
|
||||
mod git_commands;
|
||||
mod headless_assistant;
|
||||
|
||||
use clap::Parser;
|
||||
use eval::{run_exercise_eval, save_eval_results};
|
||||
use futures::stream::{self, StreamExt};
|
||||
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
|
||||
use git_commands::read_base_sha;
|
||||
use gpui::Application;
|
||||
use headless_assistant::{authenticate_model_provider, find_model};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(
|
||||
name = "agent_eval",
|
||||
disable_version_flag = true,
|
||||
before_help = "Tool eval runner"
|
||||
)]
|
||||
struct Args {
|
||||
/// Match the names of evals to run.
|
||||
#[arg(long)]
|
||||
exercise_names: Vec<String>,
|
||||
/// Runs all exercises, causes the exercise_names to be ignored.
|
||||
#[arg(long)]
|
||||
all: bool,
|
||||
/// Supported language types to evaluate (default: internal).
|
||||
/// Internal is data generated from the agent panel
|
||||
#[arg(long, default_value = "internal")]
|
||||
languages: String,
|
||||
/// Name of the model (default: "claude-3-7-sonnet-latest")
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
model_name: String,
|
||||
/// Name of the editor model (default: value of `--model_name`).
|
||||
#[arg(long)]
|
||||
editor_model_name: Option<String>,
|
||||
/// Number of evaluations to run concurrently (default: 3)
|
||||
#[arg(short, long, default_value = "5")]
|
||||
concurrency: usize,
|
||||
/// Maximum number of exercises to evaluate per language
|
||||
#[arg(long)]
|
||||
max_exercises_per_language: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
let args = Args::parse();
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client.clone());
|
||||
|
||||
// Path to the zed-ace-framework repo
|
||||
let framework_path = PathBuf::from("../zed-ace-framework")
|
||||
.canonicalize()
|
||||
.unwrap();
|
||||
|
||||
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
|
||||
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
|
||||
|
||||
println!("Using zed-ace-framework at: {:?}", framework_path);
|
||||
println!("Evaluating languages: {:?}", languages);
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = headless_assistant::init(cx);
|
||||
|
||||
let model = find_model(&args.model_name, cx).unwrap();
|
||||
let editor_model = if let Some(model_name) = &args.editor_model_name {
|
||||
find_model(model_name, cx).unwrap()
|
||||
} else {
|
||||
model.clone()
|
||||
};
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(Some(model.clone()), cx);
|
||||
});
|
||||
|
||||
let model_provider_id = model.provider_id();
|
||||
let editor_model_provider_id = editor_model.provider_id();
|
||||
|
||||
let framework_path_clone = framework_path.clone();
|
||||
let languages_clone = languages.clone();
|
||||
let exercise_names = args.exercise_names.clone();
|
||||
let all_flag = args.all;
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Authenticate all model providers first
|
||||
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
println!("framework path: {}", framework_path_clone.display());
|
||||
|
||||
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
|
||||
|
||||
println!("base sha: {}", base_sha);
|
||||
|
||||
let all_exercises = find_exercises(
|
||||
&framework_path_clone,
|
||||
&languages_clone
|
||||
.iter()
|
||||
.map(|s| s.as_str())
|
||||
.collect::<Vec<_>>(),
|
||||
args.max_exercises_per_language,
|
||||
)
|
||||
.unwrap();
|
||||
println!("Found {} exercises total", all_exercises.len());
|
||||
|
||||
// Filter exercises if specific ones were requested
|
||||
let exercises_to_run = if !exercise_names.is_empty() {
|
||||
// If exercise names are specified, filter by them regardless of --all flag
|
||||
all_exercises
|
||||
.into_iter()
|
||||
.filter(|path| {
|
||||
let name = get_exercise_name(path);
|
||||
exercise_names.iter().any(|filter| name.contains(filter))
|
||||
})
|
||||
.collect()
|
||||
} else if all_flag {
|
||||
// Only use all_flag if no exercise names are specified
|
||||
all_exercises
|
||||
} else {
|
||||
// Default behavior (no filters)
|
||||
all_exercises
|
||||
};
|
||||
|
||||
println!("Will run {} exercises", exercises_to_run.len());
|
||||
|
||||
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
|
||||
let exercise_tasks: Vec<_> = exercises_to_run
|
||||
.into_iter()
|
||||
.map(|exercise_path| {
|
||||
let exercise_name = get_exercise_name(&exercise_path);
|
||||
let model_clone = model.clone();
|
||||
let app_state_clone = app_state.clone();
|
||||
let base_sha_clone = base_sha.clone();
|
||||
let framework_path_clone = framework_path_clone.clone();
|
||||
let cx_clone = cx.clone();
|
||||
|
||||
async move {
|
||||
println!("Processing exercise: {}", exercise_name);
|
||||
let mut exercise_results = Vec::new();
|
||||
|
||||
match run_exercise_eval(
|
||||
exercise_path.clone(),
|
||||
model_clone.clone(),
|
||||
app_state_clone.clone(),
|
||||
base_sha_clone.clone(),
|
||||
framework_path_clone.clone(),
|
||||
cx_clone.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
println!("Completed {}", exercise_name);
|
||||
exercise_results.push(result);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("Error running {}: {}", exercise_name, err);
|
||||
}
|
||||
}
|
||||
|
||||
// Save results for this exercise
|
||||
if !exercise_results.is_empty() {
|
||||
if let Err(err) =
|
||||
save_eval_results(&exercise_path, exercise_results.clone()).await
|
||||
{
|
||||
println!("Error saving results for {}: {}", exercise_name, err);
|
||||
} else {
|
||||
println!("Saved results for {}", exercise_name);
|
||||
}
|
||||
}
|
||||
|
||||
exercise_results
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!(
|
||||
"Running {} exercises with concurrency: {}",
|
||||
exercise_tasks.len(),
|
||||
args.concurrency
|
||||
);
|
||||
|
||||
// Run exercises concurrently, with each exercise running its templates sequentially
|
||||
let all_results = stream::iter(exercise_tasks)
|
||||
.buffer_unordered(args.concurrency)
|
||||
.flat_map(stream::iter)
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
println!("Completed {} evaluation runs", all_results.len());
|
||||
cx.update(|cx| cx.quit()).unwrap();
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
|
||||
println!("Done running evals");
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
[package]
|
||||
name = "agent_rules"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/agent_rules.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
fs.workspace = true
|
||||
gpui.workspace = true
|
||||
prompt_store.workspace = true
|
||||
util.workspace = true
|
||||
worktree.workspace = true
|
||||
workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" }
|
||||
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-GPL
|
|
@ -1,51 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use fs::Fs;
|
||||
use gpui::{App, AppContext, Task};
|
||||
use prompt_store::SystemPromptRulesFile;
|
||||
use util::maybe;
|
||||
use worktree::Worktree;
|
||||
|
||||
const RULES_FILE_NAMES: [&'static str; 6] = [
|
||||
".rules",
|
||||
".cursorrules",
|
||||
".windsurfrules",
|
||||
".clinerules",
|
||||
".github/copilot-instructions.md",
|
||||
"CLAUDE.md",
|
||||
];
|
||||
|
||||
pub fn load_worktree_rules_file(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
) -> Option<Task<Result<SystemPromptRulesFile>>> {
|
||||
let selected_rules_file = RULES_FILE_NAMES
|
||||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
worktree
|
||||
.entry_for_path(name)
|
||||
.filter(|entry| entry.is_file())
|
||||
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
|
||||
})
|
||||
.next();
|
||||
|
||||
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
||||
// supported. This doesn't seem to occur often in GitHub repositories.
|
||||
selected_rules_file.map(|(path_in_worktree, abs_path)| {
|
||||
let fs = fs.clone();
|
||||
cx.background_spawn(maybe!(async move {
|
||||
let abs_path = abs_path?;
|
||||
let text = fs
|
||||
.load(&abs_path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?;
|
||||
anyhow::Ok(SystemPromptRulesFile {
|
||||
path_in_worktree,
|
||||
abs_path: abs_path.into(),
|
||||
text: text.trim().to_string(),
|
||||
})
|
||||
}))
|
||||
})
|
||||
}
|
|
@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct AssistantSettings {
|
||||
pub enabled: bool,
|
||||
pub button: bool,
|
||||
|
|
|
@ -179,11 +179,9 @@ pub async fn file_outline(
|
|||
|
||||
// Wait until the buffer has been fully parsed, so that we can read its outline.
|
||||
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
|
||||
while parse_status
|
||||
.recv()
|
||||
.await
|
||||
.map_or(false, |status| status != ParseStatus::Idle)
|
||||
{}
|
||||
while *parse_status.borrow() != ParseStatus::Idle {
|
||||
parse_status.changed().await?;
|
||||
}
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let Some(outline) = snapshot.outline(None) else {
|
||||
|
|
|
@ -9,12 +9,13 @@ agent.workspace = true
|
|||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
assistant_settings.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
dap.workspace = true
|
||||
env_logger.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
language.workspace = true
|
||||
|
@ -27,7 +28,6 @@ release_channel.workspace = true
|
|||
reqwest_client.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
toml.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
|
|
|
@ -1,229 +0,0 @@
|
|||
use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
|
||||
use anyhow::anyhow;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use dap::DapRegistry;
|
||||
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use node_runtime::NodeRuntime;
|
||||
use project::{Project, RealFs};
|
||||
use prompt_store::PromptBuilder;
|
||||
use settings::SettingsStore;
|
||||
use smol::channel;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||
pub struct AgentAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
|
||||
// Additional fields not present in `workspace::AppState`.
|
||||
pub prompt_builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
pub struct Agent {
|
||||
// pub thread: Entity<Thread>,
|
||||
// pub project: Entity<Project>,
|
||||
#[allow(dead_code)]
|
||||
pub thread_store: Entity<ThreadStore>,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub done_tx: channel::Sender<anyhow::Result<()>>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn new(
|
||||
app_state: Arc<AgentAppState>,
|
||||
cx: &mut App,
|
||||
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
|
||||
let env = None;
|
||||
let project = Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
Arc::new(DapRegistry::default()),
|
||||
app_state.fs.clone(),
|
||||
env,
|
||||
cx,
|
||||
);
|
||||
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let thread_store =
|
||||
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
|
||||
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
|
||||
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
|
||||
|
||||
let headless_thread = cx.new(move |cx| Self {
|
||||
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
|
||||
// thread,
|
||||
// project,
|
||||
thread_store,
|
||||
tool_use_counts: HashMap::default(),
|
||||
done_tx,
|
||||
});
|
||||
|
||||
Ok((headless_thread, done_rx))
|
||||
}
|
||||
|
||||
fn handle_thread_event(
|
||||
&mut self,
|
||||
thread: Entity<Thread>,
|
||||
event: &ThreadEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
ThreadEvent::ShowError(err) => self
|
||||
.done_tx
|
||||
.send_blocking(Err(anyhow!("{:?}", err)))
|
||||
.unwrap(),
|
||||
ThreadEvent::DoneStreaming => {
|
||||
let thread = thread.read(cx);
|
||||
if let Some(message) = thread.messages().last() {
|
||||
println!("Message: {}", message.to_string());
|
||||
}
|
||||
if thread.all_tools_finished() {
|
||||
self.done_tx.send_blocking(Ok(())).unwrap()
|
||||
}
|
||||
}
|
||||
ThreadEvent::UsePendingTools { .. } => {}
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
// Automatically approve all tools that need confirmation in headless mode
|
||||
println!("Tool confirmation needed - automatically approving in headless mode");
|
||||
|
||||
// Get the tools needing confirmation
|
||||
let tools_needing_confirmation: Vec<_> = thread
|
||||
.read(cx)
|
||||
.tools_needing_confirmation()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Run each tool that needs confirmation
|
||||
for tool_use in tools_needing_confirmation {
|
||||
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
|
||||
thread.update(cx, |thread, cx| {
|
||||
println!("Auto-approving tool: {}", tool_use.name);
|
||||
|
||||
// Create a request to send to the tool
|
||||
let request = thread.to_completion_request(RequestKind::Chat, cx);
|
||||
let messages = Arc::new(request.messages);
|
||||
|
||||
// Run the tool
|
||||
thread.run_tool(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
&messages,
|
||||
tool,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
pending_tool_use,
|
||||
..
|
||||
} => {
|
||||
if let Some(pending_tool_use) = pending_tool_use {
|
||||
println!(
|
||||
"Used tool {} with input: {}",
|
||||
pending_tool_use.name, pending_tool_use.input
|
||||
);
|
||||
*self
|
||||
.tool_use_counts
|
||||
.entry(pending_tool_use.name.clone())
|
||||
.or_insert(0) += 1;
|
||||
}
|
||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||
println!("Tool result: {:?}", tool_result);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
settings_store
|
||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||
.unwrap();
|
||||
cx.set_global(settings_store);
|
||||
client::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
|
||||
let client = Client::production(cx);
|
||||
cx.set_http_client(client.http_client().clone());
|
||||
|
||||
let git_binary_path = None;
|
||||
let fs = Arc::new(RealFs::new(
|
||||
git_binary_path,
|
||||
cx.background_executor().clone(),
|
||||
));
|
||||
|
||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
language::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
assistant_tools::init(client.http_client().clone(), cx);
|
||||
context_server::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||
|
||||
Arc::new(AgentAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime: NodeRuntime::unavailable(),
|
||||
prompt_builder,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = model_registry
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == model_name);
|
||||
|
||||
let Some(model) = model else {
|
||||
return Err(anyhow!(
|
||||
"No language model named {} was available. Available models: {}",
|
||||
model_name,
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.map(|model| model.id().0.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
};
|
||||
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
pub fn authenticate_model_provider(
|
||||
provider_id: LanguageModelProviderId,
|
||||
cx: &mut App,
|
||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||
model_provider.authenticate(cx)
|
||||
}
|
|
@ -1,74 +1,22 @@
|
|||
use agent::Agent;
|
||||
use anyhow::Result;
|
||||
use gpui::Application;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
mod example;
|
||||
|
||||
use assistant_settings::AssistantSettings;
|
||||
use client::{Client, UserStore};
|
||||
pub(crate) use example::*;
|
||||
|
||||
use ::fs::RealFs;
|
||||
use anyhow::anyhow;
|
||||
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
mod agent;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ExampleBase {
|
||||
pub path: PathBuf,
|
||||
pub revision: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Example {
|
||||
pub base: ExampleBase,
|
||||
|
||||
/// Content of the prompt.md file
|
||||
pub prompt: String,
|
||||
|
||||
/// Content of the rubric.md file
|
||||
pub rubric: String,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
|
||||
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
|
||||
let base_path = dir_path.as_ref().join("base.toml");
|
||||
let prompt_path = dir_path.as_ref().join("prompt.md");
|
||||
let rubric_path = dir_path.as_ref().join("rubric.md");
|
||||
|
||||
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
|
||||
base.path = base.path.canonicalize()?;
|
||||
|
||||
Ok(Example {
|
||||
base,
|
||||
prompt: fs::read_to_string(prompt_path)?,
|
||||
rubric: fs::read_to_string(rubric_path)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set up the example by checking out the specified Git revision
|
||||
pub fn setup(&self) -> Result<()> {
|
||||
use std::process::Command;
|
||||
|
||||
// Check if the directory exists
|
||||
let path = Path::new(&self.base.path);
|
||||
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
|
||||
|
||||
// Change to the project directory and checkout the specified revision
|
||||
let output = Command::new("git")
|
||||
.current_dir(&self.base.path)
|
||||
.arg("checkout")
|
||||
.arg(&self.base.revision)
|
||||
.output()?;
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"Failed to checkout revision {}: {}",
|
||||
self.base.revision,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
use node_runtime::NodeRuntime;
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
@ -76,10 +24,9 @@ fn main() {
|
|||
let app = Application::headless().with_http_client(http_client.clone());
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = crate::agent::init(cx);
|
||||
let _agent = Agent::new(app_state, cx);
|
||||
let app_state = init(cx);
|
||||
|
||||
let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
||||
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(Some(model.clone()), cx);
|
||||
|
@ -87,15 +34,112 @@ fn main() {
|
|||
|
||||
let model_provider_id = model.provider_id();
|
||||
|
||||
let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
|
||||
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
cx.spawn(async move |cx| {
|
||||
authenticate.await.unwrap();
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
|
||||
// let example =
|
||||
// Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
||||
// example.setup()?;
|
||||
let example =
|
||||
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
||||
example.setup()?;
|
||||
cx.update(|cx| example.run(model, app_state, cx))?.await?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
}
|
||||
|
||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||
pub struct AgentAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
|
||||
// Additional fields not present in `workspace::AppState`.
|
||||
pub prompt_builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
settings_store
|
||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||
.unwrap();
|
||||
cx.set_global(settings_store);
|
||||
client::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
|
||||
let client = Client::production(cx);
|
||||
cx.set_http_client(client.http_client().clone());
|
||||
|
||||
let git_binary_path = None;
|
||||
let fs = Arc::new(RealFs::new(
|
||||
git_binary_path,
|
||||
cx.background_executor().clone(),
|
||||
));
|
||||
|
||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
language::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
assistant_tools::init(client.http_client().clone(), cx);
|
||||
context_server::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||
|
||||
AssistantSettings::override_global(
|
||||
AssistantSettings {
|
||||
always_allow_tool_actions: true,
|
||||
..AssistantSettings::get_global(cx).clone()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
|
||||
Arc::new(AgentAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime: NodeRuntime::unavailable(),
|
||||
prompt_builder,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = model_registry
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == model_name);
|
||||
|
||||
let Some(model) = model else {
|
||||
return Err(anyhow!(
|
||||
"No language model named {} was available. Available models: {}",
|
||||
model_name,
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.map(|model| model.id().0.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
};
|
||||
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
pub fn authenticate_model_provider(
|
||||
provider_id: LanguageModelProviderId,
|
||||
cx: &mut App,
|
||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||
model_provider.authenticate(cx)
|
||||
}
|
||||
|
|
178
crates/eval/src/example.rs
Normal file
178
crates/eval/src/example.rs
Normal file
|
@ -0,0 +1,178 @@
|
|||
use agent::{RequestKind, ThreadEvent, ThreadStore};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use dap::DapRegistry;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{App, Task};
|
||||
use language_model::{LanguageModel, StopReason};
|
||||
use project::Project;
|
||||
use serde::Deserialize;
|
||||
use std::process::Command;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crate::AgentAppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ExampleBase {
|
||||
pub path: PathBuf,
|
||||
pub revision: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Example {
|
||||
pub base: ExampleBase,
|
||||
|
||||
/// Content of the prompt.md file
|
||||
pub prompt: String,
|
||||
|
||||
/// Content of the rubric.md file
|
||||
pub _rubric: String,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
|
||||
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
|
||||
let base_path = dir_path.as_ref().join("base.toml");
|
||||
let prompt_path = dir_path.as_ref().join("prompt.md");
|
||||
let rubric_path = dir_path.as_ref().join("rubric.md");
|
||||
|
||||
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
|
||||
base.path = base.path.canonicalize()?;
|
||||
|
||||
Ok(Example {
|
||||
base,
|
||||
prompt: fs::read_to_string(prompt_path)?,
|
||||
_rubric: fs::read_to_string(rubric_path)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set up the example by checking out the specified Git revision
|
||||
pub fn setup(&self) -> Result<()> {
|
||||
// Check if the directory exists
|
||||
let path = Path::new(&self.base.path);
|
||||
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
|
||||
|
||||
// Change to the project directory and checkout the specified revision
|
||||
let output = Command::new("git")
|
||||
.current_dir(&self.base.path)
|
||||
.arg("checkout")
|
||||
.arg(&self.base.revision)
|
||||
.output()?;
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"Failed to checkout revision {}: {}",
|
||||
self.base.revision,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
app_state: Arc<AgentAppState>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
let project = Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
Arc::new(DapRegistry::default()),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
|
||||
let worktree = project.update(cx, |project, cx| {
|
||||
project.create_worktree(self.base.path, true, cx)
|
||||
});
|
||||
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let thread_store =
|
||||
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
|
||||
|
||||
println!("USER:");
|
||||
println!("{}", self.prompt);
|
||||
println!("ASSISTANT:");
|
||||
cx.spawn(async move |cx| {
|
||||
worktree.await?;
|
||||
let thread_store = thread_store.await;
|
||||
let thread =
|
||||
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let mut tx = Some(tx);
|
||||
|
||||
let _subscription =
|
||||
cx.subscribe(
|
||||
&thread,
|
||||
move |thread, event: &ThreadEvent, cx| match event {
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn) => {
|
||||
if let Some(tx) = tx.take() {
|
||||
tx.send(Ok(())).ok();
|
||||
}
|
||||
}
|
||||
Ok(StopReason::MaxTokens) => {
|
||||
if let Some(tx) = tx.take() {
|
||||
tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
|
||||
}
|
||||
}
|
||||
Ok(StopReason::ToolUse) => {}
|
||||
Err(error) => {
|
||||
if let Some(tx) = tx.take() {
|
||||
tx.send(Err(anyhow!(error.clone()))).ok();
|
||||
}
|
||||
}
|
||||
},
|
||||
ThreadEvent::ShowError(thread_error) => {
|
||||
if let Some(tx) = tx.take() {
|
||||
tx.send(Err(anyhow!(thread_error.clone()))).ok();
|
||||
}
|
||||
}
|
||||
ThreadEvent::StreamedAssistantText(_, chunk) => {
|
||||
print!("{}", chunk);
|
||||
}
|
||||
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
|
||||
print!("{}", chunk);
|
||||
}
|
||||
ThreadEvent::UsePendingTools { tool_uses } => {
|
||||
println!("\n\nUSING TOOLS:");
|
||||
for tool_use in tool_uses {
|
||||
println!("{}: {}", tool_use.name, tool_use.input);
|
||||
}
|
||||
}
|
||||
ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
pending_tool_use,
|
||||
..
|
||||
} => {
|
||||
if let Some(tool_use) = pending_tool_use {
|
||||
println!("\nTOOL FINISHED: {}", tool_use.name);
|
||||
}
|
||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||
println!("\n{}\n", tool_result.content);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
)?;
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
let context = vec![];
|
||||
thread.insert_user_message(self.prompt.clone(), context, None, cx);
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
})?;
|
||||
|
||||
rx.await??;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
AnyView, AnyWindowHandle, App, AppCell, AppContext, BackgroundExecutor, BorrowAppContext,
|
||||
Entity, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation, Result, Task,
|
||||
VisualContext, Window, WindowHandle,
|
||||
Entity, EventEmitter, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation,
|
||||
Result, Subscription, Task, VisualContext, Window, WindowHandle,
|
||||
};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use derive_more::{Deref, DerefMut};
|
||||
|
@ -154,6 +154,26 @@ impl AsyncApp {
|
|||
Ok(lock.update(f))
|
||||
}
|
||||
|
||||
/// Arrange for the given callback to be invoked whenever the given entity emits an event of a given type.
|
||||
/// The callback is provided a handle to the emitting entity and a reference to the emitted event.
|
||||
pub fn subscribe<T, Event>(
|
||||
&mut self,
|
||||
entity: &Entity<T>,
|
||||
mut on_event: impl FnMut(Entity<T>, &Event, &mut App) + 'static,
|
||||
) -> Result<Subscription>
|
||||
where
|
||||
T: 'static + EventEmitter<Event>,
|
||||
Event: 'static,
|
||||
{
|
||||
let app = self
|
||||
.app
|
||||
.upgrade()
|
||||
.ok_or_else(|| anyhow!("app was released"))?;
|
||||
let mut lock = app.borrow_mut();
|
||||
let subscription = lock.subscribe(entity, on_event);
|
||||
Ok(subscription)
|
||||
}
|
||||
|
||||
/// Open a window with the given options based on the root view returned by the given function.
|
||||
pub fn open_window<V>(
|
||||
&self,
|
||||
|
|
|
@ -16,17 +16,17 @@ use std::{
|
|||
use text::LineEnding;
|
||||
use util::{ResultExt, get_system_shell};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AssistantSystemPromptContext {
|
||||
pub worktrees: Vec<WorktreeInfoForSystemPrompt>,
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ProjectContext {
|
||||
pub worktrees: Vec<WorktreeContext>,
|
||||
pub has_rules: bool,
|
||||
pub os: String,
|
||||
pub arch: String,
|
||||
pub shell: String,
|
||||
}
|
||||
|
||||
impl AssistantSystemPromptContext {
|
||||
pub fn new(worktrees: Vec<WorktreeInfoForSystemPrompt>) -> Self {
|
||||
impl ProjectContext {
|
||||
pub fn new(worktrees: Vec<WorktreeContext>) -> Self {
|
||||
let has_rules = worktrees
|
||||
.iter()
|
||||
.any(|worktree| worktree.rules_file.is_some());
|
||||
|
@ -40,15 +40,15 @@ impl AssistantSystemPromptContext {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeInfoForSystemPrompt {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorktreeContext {
|
||||
pub root_name: String,
|
||||
pub abs_path: Arc<Path>,
|
||||
pub rules_file: Option<SystemPromptRulesFile>,
|
||||
pub rules_file: Option<RulesFileContext>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SystemPromptRulesFile {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RulesFileContext {
|
||||
pub path_in_worktree: Arc<Path>,
|
||||
pub abs_path: Arc<Path>,
|
||||
pub text: String,
|
||||
|
@ -260,7 +260,7 @@ impl PromptBuilder {
|
|||
|
||||
pub fn generate_assistant_system_prompt(
|
||||
&self,
|
||||
context: &AssistantSystemPromptContext,
|
||||
context: &ProjectContext,
|
||||
) -> Result<String, RenderError> {
|
||||
self.handlebars
|
||||
.lock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue