Actually run the eval and fix a hang when retrieving outline (#28547)
Release Notes:

- Fixed a regression that caused the agent to hang sometimes.

---------

Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Michael Sloan <mgsloan@gmail.com>
parent c0262cf62f
commit 2440faf4b2
28 changed files with 642 additions and 1862 deletions
Cargo.lock (generated): 57 changes

@@ -52,7 +52,6 @@ dependencies = [
 name = "agent"
 version = "0.1.0"
 dependencies = [
- "agent_rules",
  "anyhow",
  "assistant_context_editor",
  "assistant_settings",
@@ -116,6 +115,7 @@ dependencies = [
  "terminal_view",
  "text",
  "theme",
+ "thiserror 2.0.12",
  "time",
  "time_format",
  "ui",
@@ -127,57 +127,6 @@ dependencies = [
  "zed_actions",
 ]
-
-[[package]]
-name = "agent_eval"
-version = "0.1.0"
-dependencies = [
- "agent",
- "anyhow",
- "assistant_tool",
- "assistant_tools",
- "clap",
- "client",
- "collections",
- "context_server",
- "dap",
- "env_logger 0.11.8",
- "fs",
- "futures 0.3.31",
- "gpui",
- "gpui_tokio",
- "language",
- "language_model",
- "language_models",
- "node_runtime",
- "project",
- "prompt_store",
- "release_channel",
- "reqwest_client",
- "serde",
- "serde_json",
- "serde_json_lenient",
- "settings",
- "smol",
- "tempfile",
- "util",
- "walkdir",
- "workspace-hack",
-]
-
-[[package]]
-name = "agent_rules"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "fs",
- "gpui",
- "indoc",
- "prompt_store",
- "util",
- "workspace-hack",
- "worktree",
-]
-
 [[package]]
 name = "ahash"
 version = "0.7.8"
@@ -4910,14 +4859,15 @@ version = "0.1.0"
 dependencies = [
  "agent",
  "anyhow",
+ "assistant_settings",
  "assistant_tool",
  "assistant_tools",
  "client",
- "collections",
  "context_server",
  "dap",
  "env_logger 0.11.8",
  "fs",
+ "futures 0.3.31",
  "gpui",
  "gpui_tokio",
  "language",
@@ -4930,7 +4880,6 @@ dependencies = [
  "reqwest_client",
  "serde",
  "settings",
- "smol",
  "toml 0.8.20",
  "workspace-hack",
 ]

@@ -3,13 +3,11 @@ resolver = "2"
 members = [
 "crates/activity_indicator",
 "crates/agent",
-"crates/agent_rules",
 "crates/anthropic",
 "crates/askpass",
 "crates/assets",
 "crates/assistant",
 "crates/assistant_context_editor",
-"crates/agent_eval",
 "crates/assistant_settings",
 "crates/assistant_slash_command",
 "crates/assistant_slash_commands",
@@ -211,14 +209,12 @@ edition = "2024"

 activity_indicator = { path = "crates/activity_indicator" }
 agent = { path = "crates/agent" }
-agent_rules = { path = "crates/agent_rules" }
 ai = { path = "crates/ai" }
 anthropic = { path = "crates/anthropic" }
 askpass = { path = "crates/askpass" }
 assets = { path = "crates/assets" }
 assistant = { path = "crates/assistant" }
 assistant_context_editor = { path = "crates/assistant_context_editor" }
-assistant_eval = { path = "crates/agent_eval" }
 assistant_settings = { path = "crates/assistant_settings" }
 assistant_slash_command = { path = "crates/assistant_slash_command" }
 assistant_slash_commands = { path = "crates/assistant_slash_commands" }

@@ -19,7 +19,6 @@ test-support = [
 ]

 [dependencies]
-agent_rules.workspace = true
 anyhow.workspace = true
 assistant_context_editor.workspace = true
 assistant_settings.workspace = true
@@ -81,6 +80,7 @@ terminal.workspace = true
 terminal_view.workspace = true
 text.workspace = true
 theme.workspace = true
+thiserror.workspace = true
 time.workspace = true
 time_format.workspace = true
 ui.workspace = true

@@ -4,7 +4,7 @@ use crate::thread::{
 LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
 ThreadEvent, ThreadFeedback,
 };
-use crate::thread_store::ThreadStore;
+use crate::thread_store::{RulesLoadingError, ThreadStore};
 use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
 use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
 use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
@@ -21,7 +21,7 @@ use gpui::{
 linear_color_stop, linear_gradient, list, percentage, pulsating_between,
 };
 use language::{Buffer, LanguageRegistry};
-use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
+use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
 use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
 use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
 use project::ProjectItem as _;
@@ -668,6 +668,7 @@ impl ActiveThread {
 let subscriptions = vec![
 cx.observe(&thread, |_, _, cx| cx.notify()),
 cx.subscribe_in(&thread, window, Self::handle_thread_event),
+cx.subscribe(&thread_store, Self::handle_rules_loading_error),
 ];

 let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
@@ -833,10 +834,9 @@ impl ActiveThread {
 | ThreadEvent::SummaryChanged => {
 self.save_thread(cx);
 }
-ThreadEvent::DoneStreaming => {
-let thread = self.thread.read(cx);
-if !thread.is_generating() {
+ThreadEvent::Stopped(reason) => match reason {
+Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
+let thread = self.thread.read(cx);
 self.show_notification(
 if thread.used_tools_since_last_user_message() {
 "Finished running tools"
@@ -848,7 +848,8 @@ impl ActiveThread {
 cx,
 );
 }
-}
+_ => {}
+},
 ThreadEvent::ToolConfirmationNeeded => {
 self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
 }
@@ -925,6 +926,19 @@ impl ActiveThread {
 }
 }

+fn handle_rules_loading_error(
+&mut self,
+_thread_store: Entity<ThreadStore>,
+error: &RulesLoadingError,
+cx: &mut Context<Self>,
+) {
+self.last_error = Some(ThreadError::Message {
+header: "Error loading rules file".into(),
+message: error.message.clone(),
+});
+cx.notify();
+}
+
 fn show_notification(
 &mut self,
 caption: impl Into<SharedString>,
@@ -2701,12 +2715,13 @@ impl ActiveThread {
 }

 fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
-let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
-else {
+let project_context = self.thread.read(cx).project_context();
+let project_context = project_context.borrow();
+let Some(project_context) = project_context.as_ref() else {
 return div().into_any();
 };

-let rules_files = system_prompt_context
+let rules_files = project_context
 .worktrees
 .iter()
 .filter_map(|worktree| worktree.rules_file.as_ref())
@@ -2796,12 +2811,13 @@ impl ActiveThread {
 }

 fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
-let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
-else {
+let project_context = self.thread.read(cx).project_context();
+let project_context = project_context.borrow();
+let Some(project_context) = project_context.as_ref() else {
 return;
 };

-let abs_paths = system_prompt_context
+let abs_paths = project_context
 .worktrees
 .iter()
 .flat_map(|worktree| worktree.rules_file.as_ref())

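The hunks above replace the agent's bare `ThreadEvent::DoneStreaming` marker with a `Stopped` event that carries the reason the stream ended, so the panel only notifies the user when a turn actually completed. A minimal, self-contained sketch of that shape, with illustrative names and a stand-in error type rather than the crate's real API:

use std::sync::Arc;

// Why a completion stopped; `EndTurn` and `MaxTokens` appear in the diff above,
// `ToolUse` is an assumed extra variant added only for illustration.
#[derive(Debug, Clone, Copy)]
enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

// The event now carries success/failure instead of a bare "done" marker.
#[derive(Debug)]
enum ThreadEvent {
    Stopped(Result<StopReason, Arc<std::io::Error>>),
}

fn handle(event: &ThreadEvent) {
    match event {
        ThreadEvent::Stopped(Ok(StopReason::EndTurn | StopReason::MaxTokens)) => {
            // Only a finished turn should raise the "done generating" notification.
            println!("show notification");
        }
        // Tool-use stops and errors fall through without notifying.
        _ => {}
    }
}

fn main() {
    handle(&ThreadEvent::Stopped(Ok(StopReason::EndTurn)));
    handle(&ThreadEvent::Stopped(Ok(StopReason::ToolUse)));
}
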
@@ -921,15 +921,16 @@ mod tests {
 })
 .unwrap();

-let thread_store = cx.update(|cx| {
-ThreadStore::new(
-project.clone(),
-Arc::default(),
-Arc::new(PromptBuilder::new(None).unwrap()),
-cx,
-)
-.unwrap()
-});
+let thread_store = cx
+.update(|cx| {
+ThreadStore::load(
+project.clone(),
+Arc::default(),
+Arc::new(PromptBuilder::new(None).unwrap()),
+cx,
+)
+})
+.await;
 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

@@ -194,10 +194,12 @@ impl AssistantPanel {
 ) -> Task<Result<Entity<Self>>> {
 cx.spawn(async move |cx| {
 let tools = Arc::new(ToolWorkingSet::default());
-let thread_store = workspace.update(cx, |workspace, cx| {
-let project = workspace.project().clone();
-ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
-})??;
+let thread_store = workspace
+.update(cx, |workspace, cx| {
+let project = workspace.project().clone();
+ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
+})?
+.await;

 let slash_commands = Arc::new(SlashCommandWorkingSet::default());
 let context_store = workspace

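The call-site change above follows from `ThreadStore::new` (synchronous and fallible) becoming `ThreadStore::load`, which returns a task the caller awaits. A rough, simplified sketch of that construct-then-await-the-initial-load shape, using the `futures` crate's executor and illustrative names rather than Zed's gpui API:

use futures::executor::block_on;

struct ThreadStoreLike {
    system_prompt: Option<String>,
}

impl ThreadStoreLike {
    // Build the store, then finish the initial async refresh before handing it out,
    // so callers never observe a store whose system prompt is still empty.
    async fn load() -> Self {
        let mut store = Self { system_prompt: None };
        store.reload_system_prompt().await;
        store
    }

    // Stand-in for scanning worktrees and reading rules files.
    async fn reload_system_prompt(&mut self) {
        self.system_prompt = Some("rules and worktree context".to_string());
    }
}

fn main() {
    let store = block_on(ThreadStoreLike::load());
    assert!(store.system_prompt.is_some());
}
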
@@ -32,8 +32,8 @@ use crate::profile_selector::ProfileSelector;
 use crate::thread::{RequestKind, Thread, TokenUsageRatio};
 use crate::thread_store::ThreadStore;
 use crate::{
-AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent,
-ToggleContextPicker, ToggleProfileSelector,
+AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ToggleContextPicker,
+ToggleProfileSelector,
 };

 pub struct MessageEditor {
@@ -235,8 +235,6 @@ impl MessageEditor {
 let refresh_task =
 refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);

-let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
-
 let thread = self.thread.clone();
 let context_store = self.context_store.clone();
 let git_store = self.project.read(cx).git_store().clone();
@@ -245,16 +243,6 @@ impl MessageEditor {
 cx.spawn(async move |this, cx| {
 let checkpoint = checkpoint.await.ok();
 refresh_task.await;
-let (system_prompt_context, load_error) = system_prompt_context_task.await;
-
-thread
-.update(cx, |thread, cx| {
-thread.set_system_prompt_context(system_prompt_context);
-if let Some(load_error) = load_error {
-cx.emit(ThreadEvent::ShowError(load_error));
-}
-})
-.log_err();
-
 thread
 .update(cx, |thread, cx| {

@@ -3,14 +3,12 @@ use std::io::Write;
 use std::ops::Range;
 use std::sync::Arc;

-use agent_rules::load_worktree_rules_file;
 use anyhow::{Context as _, Result, anyhow};
 use assistant_settings::AssistantSettings;
 use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
 use chrono::{DateTime, Utc};
 use collections::{BTreeMap, HashMap};
 use feature_flags::{self, FeatureFlagAppExt};
-use fs::Fs;
 use futures::future::Shared;
 use futures::{FutureExt, StreamExt as _};
 use git::repository::DiffType;
@@ -21,19 +19,20 @@ use language_model::{
 LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
 PaymentRequiredError, Role, StopReason, TokenUsage,
 };
+use project::Project;
 use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
-use project::{Project, Worktree};
-use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
+use prompt_store::PromptBuilder;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Settings;
+use thiserror::Error;
 use util::{ResultExt as _, TryFutureExt as _, post_inc};
 use uuid::Uuid;

 use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::thread_store::{
 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
-SerializedToolUse,
+SerializedToolUse, SharedProjectContext,
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
@@ -247,7 +246,7 @@ pub struct Thread {
 next_message_id: MessageId,
 context: BTreeMap<ContextId, AssistantContext>,
 context_by_message: HashMap<MessageId, Vec<ContextId>>,
-system_prompt_context: Option<AssistantSystemPromptContext>,
+project_context: SharedProjectContext,
 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 completion_count: usize,
 pending_completions: Vec<PendingCompletion>,
@@ -269,6 +268,7 @@ impl Thread {
 project: Entity<Project>,
 tools: Arc<ToolWorkingSet>,
 prompt_builder: Arc<PromptBuilder>,
+system_prompt: SharedProjectContext,
 cx: &mut Context<Self>,
 ) -> Self {
 Self {
@@ -281,7 +281,7 @@ impl Thread {
 next_message_id: MessageId(0),
 context: BTreeMap::default(),
 context_by_message: HashMap::default(),
-system_prompt_context: None,
+project_context: system_prompt,
 checkpoints_by_message: HashMap::default(),
 completion_count: 0,
 pending_completions: Vec::new(),
@@ -310,6 +310,7 @@ impl Thread {
 project: Entity<Project>,
 tools: Arc<ToolWorkingSet>,
 prompt_builder: Arc<PromptBuilder>,
+project_context: SharedProjectContext,
 cx: &mut Context<Self>,
 ) -> Self {
 let next_message_id = MessageId(
@@ -350,7 +351,7 @@ impl Thread {
 next_message_id,
 context: BTreeMap::default(),
 context_by_message: HashMap::default(),
-system_prompt_context: None,
+project_context,
 checkpoints_by_message: HashMap::default(),
 completion_count: 0,
 pending_completions: Vec::new(),
@@ -388,6 +389,10 @@ impl Thread {
 self.summary.clone()
 }

+pub fn project_context(&self) -> SharedProjectContext {
+self.project_context.clone()
+}
+
 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

 pub fn summary_or_default(&self) -> SharedString {
@@ -812,86 +817,6 @@ impl Thread {
 })
 }

-pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
-self.system_prompt_context = Some(context);
-}
-
-pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
-&self.system_prompt_context
-}
-
-pub fn load_system_prompt_context(
-&self,
-cx: &App,
-) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
-let project = self.project.read(cx);
-let tasks = project
-.visible_worktrees(cx)
-.map(|worktree| {
-Self::load_worktree_info_for_system_prompt(
-project.fs().clone(),
-worktree.read(cx),
-cx,
-)
-})
-.collect::<Vec<_>>();
-
-cx.spawn(async |_cx| {
-let results = futures::future::join_all(tasks).await;
-let mut first_err = None;
-let worktrees = results
-.into_iter()
-.map(|(worktree, err)| {
-if first_err.is_none() && err.is_some() {
-first_err = err;
-}
-worktree
-})
-.collect::<Vec<_>>();
-(AssistantSystemPromptContext::new(worktrees), first_err)
-})
-}
-
-fn load_worktree_info_for_system_prompt(
-fs: Arc<dyn Fs>,
-worktree: &Worktree,
-cx: &App,
-) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
-let root_name = worktree.root_name().into();
-let abs_path = worktree.abs_path();
-
-let rules_task = load_worktree_rules_file(fs, worktree, cx);
-let Some(rules_task) = rules_task else {
-return Task::ready((
-WorktreeInfoForSystemPrompt {
-root_name,
-abs_path,
-rules_file: None,
-},
-None,
-));
-};
-
-cx.spawn(async move |_| {
-let (rules_file, rules_file_error) = match rules_task.await {
-Ok(rules_file) => (Some(rules_file), None),
-Err(err) => (
-None,
-Some(ThreadError::Message {
-header: "Error loading rules file".into(),
-message: format!("{err}").into(),
-}),
-),
-};
-let worktree_info = WorktreeInfoForSystemPrompt {
-root_name,
-abs_path,
-rules_file,
-};
-(worktree_info, rules_file_error)
-})
-}
-
 pub fn send_to_model(
 &mut self,
 model: Arc<dyn LanguageModel>,
@@ -941,10 +866,10 @@ impl Thread {
 temperature: None,
 };

-if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
+if let Some(project_context) = self.project_context.borrow().as_ref() {
 if let Some(system_prompt) = self
 .prompt_builder
-.generate_assistant_system_prompt(system_prompt_context)
+.generate_assistant_system_prompt(project_context)
 .context("failed to generate assistant system prompt")
 .log_err()
 {
@@ -955,7 +880,7 @@ impl Thread {
 });
 }
 } else {
-log::error!("system_prompt_context not set.")
+log::error!("project_context not set.")
 }

 for message in &self.messages {
@@ -1215,7 +1140,7 @@ impl Thread {
 thread.cancel_last_completion(cx);
 }
 }
-cx.emit(ThreadEvent::DoneStreaming);
+cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

 thread.auto_capture_telemetry(cx);

@@ -1963,10 +1888,13 @@ impl Thread {
 }
 }

-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Error)]
 pub enum ThreadError {
+#[error("Payment required")]
 PaymentRequired,
+#[error("Max monthly spend reached")]
 MaxMonthlySpendReached,
+#[error("Message {header}: {message}")]
 Message {
 header: SharedString,
 message: SharedString,
@@ -1979,7 +1907,7 @@ pub enum ThreadEvent {
 StreamedCompletion,
 StreamedAssistantText(MessageId, String),
 StreamedAssistantThinking(MessageId, String),
-DoneStreaming,
+Stopped(Result<StopReason, Arc<anyhow::Error>>),
 MessageAdded(MessageId),
 MessageEdited(MessageId),
 MessageDeleted(MessageId),
@@ -2085,9 +2013,9 @@ fn main() {{
 thread.to_completion_request(RequestKind::Chat, cx)
 });

-assert_eq!(request.messages.len(), 1);
+assert_eq!(request.messages.len(), 2);
 let expected_full_message = format!("{}Please explain this code", expected_context);
-assert_eq!(request.messages[0].string_contents(), expected_full_message);
+assert_eq!(request.messages[1].string_contents(), expected_full_message);
 }

 #[gpui::test]
@@ -2178,20 +2106,20 @@ fn main() {{
 });

 // The request should contain all 3 messages
-assert_eq!(request.messages.len(), 3);
+assert_eq!(request.messages.len(), 4);

 // Check that the contexts are properly formatted in each message
-assert!(request.messages[0].string_contents().contains("file1.rs"));
-assert!(!request.messages[0].string_contents().contains("file2.rs"));
-assert!(!request.messages[0].string_contents().contains("file3.rs"));
-
-assert!(!request.messages[1].string_contents().contains("file1.rs"));
-assert!(request.messages[1].string_contents().contains("file2.rs"));
+assert!(request.messages[1].string_contents().contains("file1.rs"));
+assert!(!request.messages[1].string_contents().contains("file2.rs"));
 assert!(!request.messages[1].string_contents().contains("file3.rs"));

 assert!(!request.messages[2].string_contents().contains("file1.rs"));
-assert!(!request.messages[2].string_contents().contains("file2.rs"));
-assert!(request.messages[2].string_contents().contains("file3.rs"));
+assert!(request.messages[2].string_contents().contains("file2.rs"));
+assert!(!request.messages[2].string_contents().contains("file3.rs"));

+assert!(!request.messages[3].string_contents().contains("file1.rs"));
+assert!(!request.messages[3].string_contents().contains("file2.rs"));
+assert!(request.messages[3].string_contents().contains("file3.rs"));
 }

 #[gpui::test]
@@ -2229,9 +2157,9 @@ fn main() {{
 thread.to_completion_request(RequestKind::Chat, cx)
 });

-assert_eq!(request.messages.len(), 1);
+assert_eq!(request.messages.len(), 2);
 assert_eq!(
-request.messages[0].string_contents(),
+request.messages[1].string_contents(),
 "What is the best way to learn Rust?"
 );

@@ -2249,13 +2177,13 @@ fn main() {{
 thread.to_completion_request(RequestKind::Chat, cx)
 });

-assert_eq!(request.messages.len(), 2);
+assert_eq!(request.messages.len(), 3);
 assert_eq!(
-request.messages[0].string_contents(),
+request.messages[1].string_contents(),
 "What is the best way to learn Rust?"
 );
 assert_eq!(
-request.messages[1].string_contents(),
+request.messages[2].string_contents(),
 "Are there any good books?"
 );
 }
@@ -2376,15 +2304,16 @@ fn main() {{
 let (workspace, cx) =
 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

-let thread_store = cx.update(|_, cx| {
-ThreadStore::new(
-project.clone(),
-Arc::default(),
-Arc::new(PromptBuilder::new(None).unwrap()),
-cx,
-)
-.unwrap()
-});
+let thread_store = cx
+.update(|_, cx| {
+ThreadStore::load(
+project.clone(),
+Arc::default(),
+Arc::new(PromptBuilder::new(None).unwrap()),
+cx,
+)
+})
+.await;

 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

@@ -1,37 +1,57 @@
 use std::borrow::Cow;
-use std::path::PathBuf;
+use std::cell::{Ref, RefCell};
+use std::path::{Path, PathBuf};
+use std::rc::Rc;
 use std::sync::Arc;

-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
 use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
 use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 use chrono::{DateTime, Utc};
 use collections::HashMap;
 use context_server::manager::ContextServerManager;
 use context_server::{ContextServerFactoryRegistry, ContextServerTool};
+use fs::Fs;
 use futures::FutureExt as _;
 use futures::future::{self, BoxFuture, Shared};
 use gpui::{
-App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
-prelude::*,
+App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
+Subscription, Task, prelude::*,
 };
 use heed::Database;
 use heed::types::SerdeBincode;
 use language_model::{LanguageModelToolUseId, Role, TokenUsage};
-use project::Project;
-use prompt_store::PromptBuilder;
+use project::{Project, Worktree};
+use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
 use serde::{Deserialize, Serialize};
 use settings::{Settings as _, SettingsStore};
 use util::ResultExt as _;

-use crate::thread::{
-DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
-};
+use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
+
+const RULES_FILE_NAMES: [&'static str; 6] = [
+".rules",
+".cursorrules",
+".windsurfrules",
+".clinerules",
+".github/copilot-instructions.md",
+"CLAUDE.md",
+];

 pub fn init(cx: &mut App) {
 ThreadsDatabase::init(cx);
 }

+/// A system prompt shared by all threads created by this ThreadStore
+#[derive(Clone, Default)]
+pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
+
+impl SharedProjectContext {
+pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
+self.0.borrow()
+}
+}
+
 pub struct ThreadStore {
 project: Entity<Project>,
 tools: Arc<ToolWorkingSet>,
@@ -39,43 +59,187 @@ pub struct ThreadStore {
 context_server_manager: Entity<ContextServerManager>,
 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 threads: Vec<SerializedThreadMetadata>,
+project_context: SharedProjectContext,
 _subscriptions: Vec<Subscription>,
 }

+pub struct RulesLoadingError {
+pub message: SharedString,
+}
+
+impl EventEmitter<RulesLoadingError> for ThreadStore {}
+
 impl ThreadStore {
-pub fn new(
+pub fn load(
 project: Entity<Project>,
 tools: Arc<ToolWorkingSet>,
 prompt_builder: Arc<PromptBuilder>,
 cx: &mut App,
-) -> Result<Entity<Self>> {
-let this = cx.new(|cx| {
-let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
-let context_server_manager = cx.new(|cx| {
-ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
-});
-let settings_subscription =
-cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
-this.load_default_profile(cx);
-});
-
-let this = Self {
-project,
-tools,
-prompt_builder,
-context_server_manager,
-context_server_tool_ids: HashMap::default(),
-threads: Vec::new(),
-_subscriptions: vec![settings_subscription],
-};
-this.load_default_profile(cx);
-this.register_context_server_handlers(cx);
-this.reload(cx).detach_and_log_err(cx);
-
-this
+) -> Task<Entity<Self>> {
+let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
+let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
+cx.foreground_executor().spawn(async move {
+reload.await;
+thread_store
+})
+}
+
+fn new(
+project: Entity<Project>,
+tools: Arc<ToolWorkingSet>,
+prompt_builder: Arc<PromptBuilder>,
+cx: &mut Context<Self>,
+) -> Self {
+let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
+let context_server_manager = cx.new(|cx| {
+ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 });
-Ok(this)
+let settings_subscription =
+cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
+this.load_default_profile(cx);
+});
+let project_subscription = cx.subscribe(&project, Self::handle_project_event);
+
+let this = Self {
+project,
+tools,
+prompt_builder,
+context_server_manager,
+context_server_tool_ids: HashMap::default(),
+threads: Vec::new(),
+project_context: SharedProjectContext::default(),
+_subscriptions: vec![settings_subscription, project_subscription],
+};
+this.load_default_profile(cx);
+this.register_context_server_handlers(cx);
+this.reload(cx).detach_and_log_err(cx);
+this
+}
+
+fn handle_project_event(
+&mut self,
+_project: Entity<Project>,
+event: &project::Event,
+cx: &mut Context<Self>,
+) {
+match event {
+project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
+self.reload_system_prompt(cx).detach();
+}
+project::Event::WorktreeUpdatedEntries(_, items) => {
+if items.iter().any(|(path, _, _)| {
+RULES_FILE_NAMES
+.iter()
+.any(|name| path.as_ref() == Path::new(name))
+}) {
+self.reload_system_prompt(cx).detach();
+}
+}
+_ => {}
+}
+}
+
+pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
+let project = self.project.read(cx);
+let tasks = project
+.visible_worktrees(cx)
+.map(|worktree| {
+Self::load_worktree_info_for_system_prompt(
+project.fs().clone(),
+worktree.read(cx),
+cx,
+)
+})
+.collect::<Vec<_>>();
+
+cx.spawn(async move |this, cx| {
+let results = futures::future::join_all(tasks).await;
+let worktrees = results
+.into_iter()
+.map(|(worktree, rules_error)| {
+if let Some(rules_error) = rules_error {
+this.update(cx, |_, cx| cx.emit(rules_error)).ok();
+}
+worktree
+})
+.collect::<Vec<_>>();
+this.update(cx, |this, _cx| {
+*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
+})
+.ok();
+})
+}
+
+fn load_worktree_info_for_system_prompt(
+fs: Arc<dyn Fs>,
+worktree: &Worktree,
+cx: &App,
+) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
+let root_name = worktree.root_name().into();
+let abs_path = worktree.abs_path();
+
+let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
+let Some(rules_task) = rules_task else {
+return Task::ready((
+WorktreeContext {
+root_name,
+abs_path,
+rules_file: None,
+},
+None,
+));
+};
+
+cx.spawn(async move |_| {
+let (rules_file, rules_file_error) = match rules_task.await {
+Ok(rules_file) => (Some(rules_file), None),
+Err(err) => (
+None,
+Some(RulesLoadingError {
+message: format!("{err}").into(),
+}),
+),
+};
+let worktree_info = WorktreeContext {
+root_name,
+abs_path,
+rules_file,
+};
+(worktree_info, rules_file_error)
+})
+}
+
+fn load_worktree_rules_file(
+fs: Arc<dyn Fs>,
+worktree: &Worktree,
+cx: &App,
+) -> Option<Task<Result<RulesFileContext>>> {
+let selected_rules_file = RULES_FILE_NAMES
+.into_iter()
+.filter_map(|name| {
+worktree
+.entry_for_path(name)
+.filter(|entry| entry.is_file())
+.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
+})
+.next();
+
+// Note that Cline supports `.clinerules` being a directory, but that is not currently
+// supported. This doesn't seem to occur often in GitHub repositories.
+selected_rules_file.map(|(path_in_worktree, abs_path)| {
+let fs = fs.clone();
+cx.background_spawn(async move {
+let abs_path = abs_path?;
+let text = fs.load(&abs_path).await.with_context(|| {
+format!("Failed to load assistant rules file {:?}", abs_path)
+})?;
+anyhow::Ok(RulesFileContext {
+path_in_worktree,
+abs_path: abs_path.into(),
+text: text.trim().to_string(),
+})
+})
+})
 }

 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
@@ -107,6 +271,7 @@ impl ThreadStore {
 self.project.clone(),
 self.tools.clone(),
 self.prompt_builder.clone(),
+self.project_context.clone(),
 cx,
 )
 })
@@ -134,21 +299,12 @@ impl ThreadStore {
 this.project.clone(),
 this.tools.clone(),
 this.prompt_builder.clone(),
+this.project_context.clone(),
 cx,
 )
 })
 })?;

-let (system_prompt_context, load_error) = thread
-.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
-.await;
-thread.update(cx, |thread, cx| {
-thread.set_system_prompt_context(system_prompt_context);
-if let Some(load_error) = load_error {
-cx.emit(ThreadEvent::ShowError(load_error));
-}
-})?;
-
 Ok(thread)
 })
 }

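The `SharedProjectContext` introduced above is an `Rc<RefCell<Option<ProjectContext>>>` cloned into the store and every thread, so a rules-file change only has to refresh one slot for all of them. A minimal sketch of that interior-mutability pattern, with illustrative names rather than the crate's types:

use std::cell::{Ref, RefCell};
use std::rc::Rc;

#[derive(Clone, Default)]
struct SharedSlot<T>(Rc<RefCell<Option<T>>>);

impl<T> SharedSlot<T> {
    // The owner refreshes the slot in place ...
    fn set(&self, value: T) {
        *self.0.borrow_mut() = Some(value);
    }
    // ... and every clone reads the latest value on demand.
    fn borrow(&self) -> Ref<'_, Option<T>> {
        self.0.borrow()
    }
}

fn main() {
    let store_side = SharedSlot::<String>::default();
    let thread_side = store_side.clone(); // same underlying cell

    // The store reloads the context later (e.g. when a rules file changes) ...
    store_side.set("project context".to_string());

    // ... and the thread sees the refreshed value without being handed a new copy.
    if let Some(ctx) = thread_side.borrow().as_ref() {
        println!("system prompt uses: {ctx}");
    }
}
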
@@ -1,46 +0,0 @@
-[package]
-name = "agent_eval"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[[bin]]
-name = "agent_eval"
-path = "src/main.rs"
-
-[dependencies]
-agent.workspace = true
-anyhow.workspace = true
-assistant_tool.workspace = true
-assistant_tools.workspace = true
-clap.workspace = true
-client.workspace = true
-collections.workspace = true
-context_server.workspace = true
-dap.workspace = true
-env_logger.workspace = true
-fs.workspace = true
-futures.workspace = true
-gpui.workspace = true
-gpui_tokio.workspace = true
-language.workspace = true
-language_model.workspace = true
-language_models.workspace = true
-node_runtime.workspace = true
-project.workspace = true
-prompt_store.workspace = true
-release_channel.workspace = true
-reqwest_client.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-serde_json_lenient.workspace = true
-settings.workspace = true
-smol.workspace = true
-tempfile.workspace = true
-util.workspace = true
-walkdir.workspace = true
-workspace-hack.workspace = true

@@ -1 +0,0 @@
-../../LICENSE-GPL

@@ -1,52 +0,0 @@
-// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
-
-use std::process::Command;
-
-fn main() {
-if cfg!(target_os = "macos") {
-println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
-
-// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
-println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
-
-// Seems to be required to enable Swift concurrency
-println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
-
-// Register exported Objective-C selectors, protocols, etc
-println!("cargo:rustc-link-arg=-Wl,-ObjC");
-}
-
-// Populate git sha environment variable if git is available
-println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
-println!(
-"cargo:rustc-env=TARGET={}",
-std::env::var("TARGET").unwrap()
-);
-if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
-if output.status.success() {
-let git_sha = String::from_utf8_lossy(&output.stdout);
-let git_sha = git_sha.trim();
-
-println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
-
-if let Ok(build_profile) = std::env::var("PROFILE") {
-if build_profile == "release" {
-// This is currently the best way to make `cargo build ...`'s build script
-// to print something to stdout without extra verbosity.
-println!(
-"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
-);
-}
-}
-}
-}
-
-#[cfg(target_os = "windows")]
-{
-#[cfg(target_env = "msvc")]
-{
-// todo(windows): This is to avoid stack overflow. Remove it when solved.
-println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
-}
-}
-}

@@ -1,384 +0,0 @@
-use crate::git_commands::{run_git, setup_temp_repo};
-use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
-use crate::{get_exercise_language, get_exercise_name};
-use agent::RequestKind;
-use anyhow::{Result, anyhow};
-use collections::HashMap;
-use gpui::{App, Task};
-use language_model::{LanguageModel, TokenUsage};
-use serde::{Deserialize, Serialize};
-use std::{
-fs,
-io::Write,
-path::{Path, PathBuf},
-sync::Arc,
-time::{Duration, SystemTime},
-};
-
-#[derive(Debug, Serialize, Deserialize, Clone)]
-pub struct EvalResult {
-pub exercise_name: String,
-pub diff: String,
-pub assistant_response: String,
-pub elapsed_time_ms: u128,
-pub timestamp: u128,
-// Token usage fields
-pub input_tokens: usize,
-pub output_tokens: usize,
-pub total_tokens: usize,
-pub tool_use_counts: usize,
-}
-
-pub struct EvalOutput {
-pub diff: String,
-pub last_message: String,
-pub elapsed_time: Duration,
-pub assistant_response_count: usize,
-pub tool_use_counts: HashMap<Arc<str>, u32>,
-pub token_usage: TokenUsage,
-}
-
-#[derive(Deserialize)]
-pub struct EvalSetup {
-pub url: String,
-pub base_sha: String,
-}
-
-pub struct Eval {
-pub repo_path: PathBuf,
-pub eval_setup: EvalSetup,
-pub user_prompt: String,
-}
-
-impl Eval {
-// Keep this method for potential future use, but mark it as intentionally unused
-#[allow(dead_code)]
-pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
-let prompt_path = path.join("prompt.txt");
-let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
-let setup_path = path.join("setup.json");
-let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
-let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
-
-// Move this internal function inside the load method since it's only used here
-fn repo_dir_name(url: &str) -> String {
-url.trim_start_matches("https://")
-.replace(|c: char| !c.is_alphanumeric(), "_")
-}
-
-let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
-
-Ok(Eval {
-repo_path,
-eval_setup,
-user_prompt,
-})
-}
-
-pub fn run(
-self,
-app_state: Arc<HeadlessAppState>,
-model: Arc<dyn LanguageModel>,
-cx: &mut App,
-) -> Task<Result<EvalOutput>> {
-cx.spawn(async move |cx| {
-run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
-
-let (assistant, done_rx) =
-cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
-
-let _worktree = assistant
-.update(cx, |assistant, cx| {
-assistant.project.update(cx, |project, cx| {
-project.create_worktree(&self.repo_path, true, cx)
-})
-})?
-.await?;
-
-let start_time = std::time::SystemTime::now();
-
-let (system_prompt_context, load_error) = cx
-.update(|cx| {
-assistant
-.read(cx)
-.thread
-.read(cx)
-.load_system_prompt_context(cx)
-})?
-.await;
-
-if let Some(load_error) = load_error {
-return Err(anyhow!("{:?}", load_error));
-};
-
-assistant.update(cx, |assistant, cx| {
-assistant.thread.update(cx, |thread, cx| {
-let context = vec![];
-thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
-thread.set_system_prompt_context(system_prompt_context);
-thread.send_to_model(model, RequestKind::Chat, cx);
-});
-})?;
-
-done_rx.recv().await??;
-
-// Add this section to check untracked files
-println!("Checking for untracked files:");
-let untracked = run_git(
-&self.repo_path,
-&["ls-files", "--others", "--exclude-standard"],
-)
-.await?;
-if untracked.is_empty() {
-println!("No untracked files found");
-} else {
-// Add all files to git so they appear in the diff
-println!("Adding untracked files to git");
-run_git(&self.repo_path, &["add", "."]).await?;
-}
-
-// get git status
-let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
-
-let elapsed_time = start_time.elapsed()?;
-
-// Get diff of staged changes (the files we just added)
-let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
-
-// Get diff of unstaged changes
-let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
-
-// Combine both diffs
-let diff = if unstaged_diff.is_empty() {
-staged_diff
-} else if staged_diff.is_empty() {
-unstaged_diff
-} else {
-format!(
-"# Staged changes\n{}\n\n# Unstaged changes\n{}",
-staged_diff, unstaged_diff
-)
-};
-
-assistant.update(cx, |assistant, cx| {
-let thread = assistant.thread.read(cx);
-let last_message = thread.messages().last().unwrap();
-if last_message.role != language_model::Role::Assistant {
-return Err(anyhow!("Last message is not from assistant"));
-}
-let assistant_response_count = thread
-.messages()
-.filter(|message| message.role == language_model::Role::Assistant)
-.count();
-Ok(EvalOutput {
-diff,
-last_message: last_message.to_string(),
-elapsed_time,
-assistant_response_count,
-tool_use_counts: assistant.tool_use_counts.clone(),
-token_usage: thread.cumulative_token_usage(),
-})
-})?
-})
-}
-}
-
-impl EvalOutput {
-// Keep this method for potential future use, but mark it as intentionally unused
-#[allow(dead_code)]
-pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
-// Create the output directory if it doesn't exist
-fs::create_dir_all(&output_dir)?;
-
-// Save the diff to a file
-let diff_path = output_dir.join("diff.patch");
-let mut diff_file = fs::File::create(&diff_path)?;
-diff_file.write_all(self.diff.as_bytes())?;
-
-// Save the last message to a file
-let message_path = output_dir.join("assistant_response.txt");
-let mut message_file = fs::File::create(&message_path)?;
-message_file.write_all(self.last_message.as_bytes())?;
-
-// Current metrics for this run
-let current_metrics = serde_json::json!({
-"elapsed_time_ms": self.elapsed_time.as_millis(),
-"assistant_response_count": self.assistant_response_count,
-"tool_use_counts": self.tool_use_counts,
-"token_usage": self.token_usage,
-"eval_output_value": eval_output_value,
-});
-
-// Get current timestamp in milliseconds
-let timestamp = std::time::SystemTime::now()
-.duration_since(std::time::UNIX_EPOCH)?
-.as_millis()
-.to_string();
-
-// Path to metrics file
-let metrics_path = output_dir.join("metrics.json");
-
-// Load existing metrics if the file exists, or create a new object
-let mut historical_metrics = if metrics_path.exists() {
-let metrics_content = fs::read_to_string(&metrics_path)?;
-serde_json::from_str::<serde_json::Value>(&metrics_content)
-.unwrap_or_else(|_| serde_json::json!({}))
-} else {
-serde_json::json!({})
-};
-
-// Add new run with timestamp as key
-if let serde_json::Value::Object(ref mut map) = historical_metrics {
-map.insert(timestamp, current_metrics);
-}
-
-// Write updated metrics back to file
-let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
-let mut metrics_file = fs::File::create(&metrics_path)?;
-metrics_file.write_all(metrics_json.as_bytes())?;
-
-Ok(())
-}
-}
-
-pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
-let instructions_path = exercise_path.join(".docs").join("instructions.md");
-println!("Reading instructions from: {}", instructions_path.display());
-let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
-Ok(instructions)
-}
-
-pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
-let eval_dir = exercise_path.join("evaluation");
-fs::create_dir_all(&eval_dir)?;
-
-let eval_file = eval_dir.join("evals.json");
-
-println!("Saving evaluation results to: {}", eval_file.display());
-println!(
-"Results to save: {} evaluations for exercise path: {}",
-results.len(),
-exercise_path.display()
-);
-
-// Check file existence before reading/writing
-if eval_file.exists() {
-println!("Existing evals.json file found, will update it");
-} else {
-println!("No existing evals.json file found, will create new one");
-}
-
-// Structure to organize evaluations by test name and timestamp
-let mut eval_data: serde_json::Value = if eval_file.exists() {
-let content = fs::read_to_string(&eval_file)?;
-serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
-} else {
-serde_json::json!({})
-};
-
-// Get current timestamp for this batch of results
-let timestamp = SystemTime::now()
-.duration_since(SystemTime::UNIX_EPOCH)?
-.as_millis()
-.to_string();
-
-// Group the new results by test name (exercise name)
-for result in results {
-let exercise_name = &result.exercise_name;
-
-println!("Adding result: exercise={}", exercise_name);
-
-// Ensure the exercise entry exists
|
|
||||||
if eval_data.get(exercise_name).is_none() {
|
|
||||||
eval_data[exercise_name] = serde_json::json!({});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the timestamp entry exists as an object
|
|
||||||
if eval_data[exercise_name].get(×tamp).is_none() {
|
|
||||||
eval_data[exercise_name][×tamp] = serde_json::json!({});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add this result under the timestamp with template name as key
|
|
||||||
eval_data[exercise_name][×tamp] = serde_json::to_value(&result)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write back to file with pretty formatting
|
|
||||||
let json_content = serde_json::to_string_pretty(&eval_data)?;
|
|
||||||
match fs::write(&eval_file, json_content) {
|
|
||||||
Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
|
|
||||||
Err(e) => println!("✗ Failed to write results file: {}", e),
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run_exercise_eval(
|
|
||||||
exercise_path: PathBuf,
|
|
||||||
model: Arc<dyn LanguageModel>,
|
|
||||||
app_state: Arc<HeadlessAppState>,
|
|
||||||
base_sha: String,
|
|
||||||
_framework_path: PathBuf,
|
|
||||||
cx: gpui::AsyncApp,
|
|
||||||
) -> Result<EvalResult> {
|
|
||||||
let exercise_name = get_exercise_name(&exercise_path);
|
|
||||||
let language = get_exercise_language(&exercise_path)?;
|
|
||||||
let mut instructions = read_instructions(&exercise_path).await?;
|
|
||||||
instructions.push_str(&format!(
|
|
||||||
"\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
|
|
||||||
language
|
|
||||||
));
|
|
||||||
|
|
||||||
println!("Running evaluation for exercise: {}", exercise_name);
|
|
||||||
|
|
||||||
// Create temporary directory with exercise files
|
|
||||||
let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
|
|
||||||
let temp_path = temp_dir.path().to_path_buf();
|
|
||||||
|
|
||||||
let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
|
|
||||||
|
|
||||||
let start_time = SystemTime::now();
|
|
||||||
|
|
||||||
// Create a basic eval struct to work with the existing system
|
|
||||||
let eval = Eval {
|
|
||||||
repo_path: temp_path.clone(),
|
|
||||||
eval_setup: EvalSetup {
|
|
||||||
url: format!("file://{}", temp_path.display()),
|
|
||||||
base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
|
|
||||||
},
|
|
||||||
user_prompt: instructions.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Run the evaluation
|
|
||||||
let eval_output = cx
|
|
||||||
.update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Get diff from git
|
|
||||||
let diff = eval_output.diff.clone();
|
|
||||||
|
|
||||||
let elapsed_time = start_time.elapsed()?;
|
|
||||||
|
|
||||||
// Calculate total tokens as the sum of input and output tokens
|
|
||||||
let input_tokens = eval_output.token_usage.input_tokens;
|
|
||||||
let output_tokens = eval_output.token_usage.output_tokens;
|
|
||||||
let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
|
|
||||||
let total_tokens = input_tokens + output_tokens;
|
|
||||||
|
|
||||||
// Save results to evaluation directory
|
|
||||||
let result = EvalResult {
|
|
||||||
exercise_name: exercise_name.clone(),
|
|
||||||
diff,
|
|
||||||
assistant_response: eval_output.last_message.clone(),
|
|
||||||
elapsed_time_ms: elapsed_time.as_millis(),
|
|
||||||
timestamp: SystemTime::now()
|
|
||||||
.duration_since(SystemTime::UNIX_EPOCH)?
|
|
||||||
.as_millis(),
|
|
||||||
// Convert u32 token counts to usize
|
|
||||||
input_tokens: input_tokens.try_into().unwrap(),
|
|
||||||
output_tokens: output_tokens.try_into().unwrap(),
|
|
||||||
total_tokens: total_tokens.try_into().unwrap(),
|
|
||||||
tool_use_counts: tool_use_counts.try_into().unwrap(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
|
@ -1,149 +0,0 @@
use anyhow::{Result, anyhow};
use std::{
    fs,
    path::{Path, PathBuf},
};

pub fn get_exercise_name(exercise_path: &Path) -> String {
    exercise_path
        .file_name()
        .unwrap_or_default()
        .to_string_lossy()
        .to_string()
}

pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
    // Extract the language from path (data/python/exercises/... => python)
    let parts: Vec<_> = exercise_path.components().collect();

    for (i, part) in parts.iter().enumerate() {
        if i > 0 && part.as_os_str() == "eval_code" {
            if i + 1 < parts.len() {
                let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
                return Ok(language);
            }
        }
    }

    Err(anyhow!(
        "Could not determine language from path: {:?}",
        exercise_path
    ))
}

pub fn find_exercises(
    framework_path: &Path,
    languages: &[&str],
    max_per_language: Option<usize>,
) -> Result<Vec<PathBuf>> {
    let mut all_exercises = Vec::new();

    println!("Searching for exercises in languages: {:?}", languages);

    for language in languages {
        let language_dir = framework_path
            .join("eval_code")
            .join(language)
            .join("exercises")
            .join("practice");

        println!("Checking language directory: {:?}", language_dir);
        if !language_dir.exists() {
            println!("Warning: Language directory not found: {:?}", language_dir);
            continue;
        }

        let mut exercises = Vec::new();
        match fs::read_dir(&language_dir) {
            Ok(entries) => {
                for entry_result in entries {
                    match entry_result {
                        Ok(entry) => {
                            let path = entry.path();

                            if path.is_dir() {
                                // Special handling for "internal" directory
                                if *language == "internal" {
                                    // Check for repo_info.json to validate it's an internal exercise
                                    let repo_info_path = path.join(".meta").join("repo_info.json");
                                    let instructions_path =
                                        path.join(".docs").join("instructions.md");

                                    if repo_info_path.exists() && instructions_path.exists() {
                                        exercises.push(path);
                                    }
                                } else {
                                    // Map the language to the file extension - original code
                                    let language_extension = match *language {
                                        "python" => "py",
                                        "go" => "go",
                                        "rust" => "rs",
                                        "typescript" => "ts",
                                        "javascript" => "js",
                                        "ruby" => "rb",
                                        "php" => "php",
                                        "bash" => "sh",
                                        "multi" => "diff",
                                        _ => continue, // Skip unsupported languages
                                    };

                                    // Check if this is a valid exercise with instructions and example
                                    let instructions_path =
                                        path.join(".docs").join("instructions.md");
                                    let has_instructions = instructions_path.exists();
                                    let example_path = path
                                        .join(".meta")
                                        .join(format!("example.{}", language_extension));
                                    let has_example = example_path.exists();

                                    if has_instructions && has_example {
                                        exercises.push(path);
                                    }
                                }
                            }
                        }
                        Err(err) => println!("Error reading directory entry: {}", err),
                    }
                }
            }
            Err(err) => println!(
                "Error reading directory {}: {}",
                language_dir.display(),
                err
            ),
        }

        // Sort exercises by name for consistent selection
        exercises.sort_by(|a, b| {
            let a_name = a.file_name().unwrap_or_default().to_string_lossy();
            let b_name = b.file_name().unwrap_or_default().to_string_lossy();
            a_name.cmp(&b_name)
        });

        // Apply the limit if specified
        if let Some(limit) = max_per_language {
            if exercises.len() > limit {
                println!(
                    "Limiting {} exercises to {} for language {}",
                    exercises.len(),
                    limit,
                    language
                );
                exercises.truncate(limit);
            }
        }

        println!(
            "Found {} exercises for language {}: {:?}",
            exercises.len(),
            language,
            exercises
                .iter()
                .map(|p| p.file_name().unwrap_or_default().to_string_lossy())
                .collect::<Vec<_>>()
        );
        all_exercises.extend(exercises);
    }

    Ok(all_exercises)
}
@ -1,125 +0,0 @@
use anyhow::{Result, anyhow};
use serde::Deserialize;
use std::{fs, path::Path};
use tempfile::TempDir;
use util::command::new_smol_command;
use walkdir::WalkDir;

#[derive(Debug, Deserialize)]
pub struct SetupConfig {
    #[serde(rename = "base.sha")]
    pub base_sha: String,
}

#[derive(Debug, Deserialize)]
pub struct RepoInfo {
    pub remote_url: String,
    pub head_sha: String,
}

pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
    let output = new_smol_command("git")
        .current_dir(repo_path)
        .args(args)
        .output()
        .await?;

    if output.status.success() {
        Ok(String::from_utf8(output.stdout)?.trim().to_string())
    } else {
        Err(anyhow!(
            "Git command failed: {} with status: {}",
            args.join(" "),
            output.status
        ))
    }
}

pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
    let setup_path = framework_path.join("setup.json");
    let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
    let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
    Ok(setup_config.base_sha)
}

pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
    let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
    println!("Reading repo info from: {}", repo_info_path.display());
    let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
    let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;

    // Remove any quotes from the strings
    let remote_url = repo_info.remote_url.trim_matches('"').to_string();
    let head_sha = repo_info.head_sha.trim_matches('"').to_string();

    Ok(RepoInfo {
        remote_url,
        head_sha,
    })
}

pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
    let temp_dir = TempDir::new()?;

    // Check if this is an internal exercise by looking for repo_info.json
    let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
    if repo_info_path.exists() {
        // This is an internal exercise, handle it differently
        let repo_info = read_repo_info(exercise_path).await?;

        // Clone the repository to the temp directory
        let url = repo_info.remote_url;
        let clone_path = temp_dir.path();
        println!(
            "Cloning repository from {} to {}",
            url,
            clone_path.display()
        );
        run_git(
            &std::env::current_dir()?,
            &["clone", &url, &clone_path.to_string_lossy()],
        )
        .await?;

        // Checkout the specified commit
        println!("Checking out commit: {}", repo_info.head_sha);
        run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;

        println!("Successfully set up internal repository");
    } else {
        // Original code for regular exercises
        // Copy the exercise files to the temp directory, excluding .docs and .meta
        for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
            let entry = entry?;
            let source_path = entry.path();

            // Skip .docs and .meta directories completely
            if source_path.starts_with(exercise_path.join(".docs"))
                || source_path.starts_with(exercise_path.join(".meta"))
            {
                continue;
            }

            if source_path.is_file() {
                let relative_path = source_path.strip_prefix(exercise_path)?;
                let dest_path = temp_dir.path().join(relative_path);

                // Make sure parent directories exist
                if let Some(parent) = dest_path.parent() {
                    fs::create_dir_all(parent)?;
                }

                fs::copy(source_path, dest_path)?;
            }
        }

        // Initialize git repo in the temp directory
        run_git(temp_dir.path(), &["init"]).await?;
        run_git(temp_dir.path(), &["add", "."]).await?;
        run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;

        println!("Created temp repo without .docs and .meta directories");
    }

    Ok(temp_dir)
}
@ -1,229 +0,0 @@
|
||||||
use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
|
|
||||||
use anyhow::anyhow;
|
|
||||||
use assistant_tool::ToolWorkingSet;
|
|
||||||
use client::{Client, UserStore};
|
|
||||||
use collections::HashMap;
|
|
||||||
use dap::DapRegistry;
|
|
||||||
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
|
|
||||||
use language::LanguageRegistry;
|
|
||||||
use language_model::{
|
|
||||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
|
||||||
};
|
|
||||||
use node_runtime::NodeRuntime;
|
|
||||||
use project::{Project, RealFs};
|
|
||||||
use prompt_store::PromptBuilder;
|
|
||||||
use settings::SettingsStore;
|
|
||||||
use smol::channel;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
|
||||||
pub struct HeadlessAppState {
|
|
||||||
pub languages: Arc<LanguageRegistry>,
|
|
||||||
pub client: Arc<Client>,
|
|
||||||
pub user_store: Entity<UserStore>,
|
|
||||||
pub fs: Arc<dyn fs::Fs>,
|
|
||||||
pub node_runtime: NodeRuntime,
|
|
||||||
|
|
||||||
// Additional fields not present in `workspace::AppState`.
|
|
||||||
pub prompt_builder: Arc<PromptBuilder>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct HeadlessAssistant {
|
|
||||||
pub thread: Entity<Thread>,
|
|
||||||
pub project: Entity<Project>,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub thread_store: Entity<ThreadStore>,
|
|
||||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
|
||||||
pub done_tx: channel::Sender<anyhow::Result<()>>,
|
|
||||||
_subscription: Subscription,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl HeadlessAssistant {
|
|
||||||
pub fn new(
|
|
||||||
app_state: Arc<HeadlessAppState>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
|
|
||||||
let env = None;
|
|
||||||
let project = Project::local(
|
|
||||||
app_state.client.clone(),
|
|
||||||
app_state.node_runtime.clone(),
|
|
||||||
app_state.user_store.clone(),
|
|
||||||
app_state.languages.clone(),
|
|
||||||
Arc::new(DapRegistry::default()),
|
|
||||||
app_state.fs.clone(),
|
|
||||||
env,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
|
|
||||||
let tools = Arc::new(ToolWorkingSet::default());
|
|
||||||
let thread_store =
|
|
||||||
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
|
|
||||||
|
|
||||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
|
||||||
|
|
||||||
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
|
|
||||||
|
|
||||||
let headless_thread = cx.new(move |cx| Self {
|
|
||||||
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
|
|
||||||
thread,
|
|
||||||
project,
|
|
||||||
thread_store,
|
|
||||||
tool_use_counts: HashMap::default(),
|
|
||||||
done_tx,
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok((headless_thread, done_rx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_thread_event(
|
|
||||||
&mut self,
|
|
||||||
thread: Entity<Thread>,
|
|
||||||
event: &ThreadEvent,
|
|
||||||
cx: &mut Context<Self>,
|
|
||||||
) {
|
|
||||||
match event {
|
|
||||||
ThreadEvent::ShowError(err) => self
|
|
||||||
.done_tx
|
|
||||||
.send_blocking(Err(anyhow!("{:?}", err)))
|
|
||||||
.unwrap(),
|
|
||||||
ThreadEvent::DoneStreaming => {
|
|
||||||
let thread = thread.read(cx);
|
|
||||||
if let Some(message) = thread.messages().last() {
|
|
||||||
println!("Message: {}", message.to_string());
|
|
||||||
}
|
|
||||||
if thread.all_tools_finished() {
|
|
||||||
self.done_tx.send_blocking(Ok(())).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ThreadEvent::UsePendingTools { .. } => {}
|
|
||||||
ThreadEvent::ToolConfirmationNeeded => {
|
|
||||||
// Automatically approve all tools that need confirmation in headless mode
|
|
||||||
println!("Tool confirmation needed - automatically approving in headless mode");
|
|
||||||
|
|
||||||
// Get the tools needing confirmation
|
|
||||||
let tools_needing_confirmation: Vec<_> = thread
|
|
||||||
.read(cx)
|
|
||||||
.tools_needing_confirmation()
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Run each tool that needs confirmation
|
|
||||||
for tool_use in tools_needing_confirmation {
|
|
||||||
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
|
|
||||||
thread.update(cx, |thread, cx| {
|
|
||||||
println!("Auto-approving tool: {}", tool_use.name);
|
|
||||||
|
|
||||||
// Create a request to send to the tool
|
|
||||||
let request = thread.to_completion_request(RequestKind::Chat, cx);
|
|
||||||
let messages = Arc::new(request.messages);
|
|
||||||
|
|
||||||
// Run the tool
|
|
||||||
thread.run_tool(
|
|
||||||
tool_use.id.clone(),
|
|
||||||
tool_use.ui_text.clone(),
|
|
||||||
tool_use.input.clone(),
|
|
||||||
&messages,
|
|
||||||
tool,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ThreadEvent::ToolFinished {
|
|
||||||
tool_use_id,
|
|
||||||
pending_tool_use,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
if let Some(pending_tool_use) = pending_tool_use {
|
|
||||||
println!(
|
|
||||||
"Used tool {} with input: {}",
|
|
||||||
pending_tool_use.name, pending_tool_use.input
|
|
||||||
);
|
|
||||||
*self
|
|
||||||
.tool_use_counts
|
|
||||||
.entry(pending_tool_use.name.clone())
|
|
||||||
.or_insert(0) += 1;
|
|
||||||
}
|
|
||||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
|
||||||
println!("Tool result: {:?}", tool_result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
|
|
||||||
release_channel::init(SemanticVersion::default(), cx);
|
|
||||||
gpui_tokio::init(cx);
|
|
||||||
|
|
||||||
let mut settings_store = SettingsStore::new(cx);
|
|
||||||
settings_store
|
|
||||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
|
||||||
.unwrap();
|
|
||||||
cx.set_global(settings_store);
|
|
||||||
client::init_settings(cx);
|
|
||||||
Project::init_settings(cx);
|
|
||||||
|
|
||||||
let client = Client::production(cx);
|
|
||||||
cx.set_http_client(client.http_client().clone());
|
|
||||||
|
|
||||||
let git_binary_path = None;
|
|
||||||
let fs = Arc::new(RealFs::new(
|
|
||||||
git_binary_path,
|
|
||||||
cx.background_executor().clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
|
||||||
|
|
||||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
|
||||||
|
|
||||||
language::init(cx);
|
|
||||||
language_model::init(client.clone(), cx);
|
|
||||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
|
||||||
assistant_tools::init(client.http_client().clone(), cx);
|
|
||||||
context_server::init(cx);
|
|
||||||
let stdout_is_a_pty = false;
|
|
||||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
|
||||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
|
||||||
|
|
||||||
Arc::new(HeadlessAppState {
|
|
||||||
languages,
|
|
||||||
client,
|
|
||||||
user_store,
|
|
||||||
fs,
|
|
||||||
node_runtime: NodeRuntime::unavailable(),
|
|
||||||
prompt_builder,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
|
||||||
let model = model_registry
|
|
||||||
.available_models(cx)
|
|
||||||
.find(|model| model.id().0 == model_name);
|
|
||||||
|
|
||||||
let Some(model) = model else {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"No language model named {} was available. Available models: {}",
|
|
||||||
model_name,
|
|
||||||
model_registry
|
|
||||||
.available_models(cx)
|
|
||||||
.map(|model| model.id().0.clone())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ")
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn authenticate_model_provider(
|
|
||||||
provider_id: LanguageModelProviderId,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
|
||||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
|
||||||
model_provider.authenticate(cx)
|
|
||||||
}
|
|
|
@ -1,205 +0,0 @@
|
||||||
mod eval;
|
|
||||||
mod get_exercise;
|
|
||||||
mod git_commands;
|
|
||||||
mod headless_assistant;
|
|
||||||
|
|
||||||
use clap::Parser;
|
|
||||||
use eval::{run_exercise_eval, save_eval_results};
|
|
||||||
use futures::stream::{self, StreamExt};
|
|
||||||
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
|
|
||||||
use git_commands::read_base_sha;
|
|
||||||
use gpui::Application;
|
|
||||||
use headless_assistant::{authenticate_model_provider, find_model};
|
|
||||||
use language_model::LanguageModelRegistry;
|
|
||||||
use reqwest_client::ReqwestClient;
|
|
||||||
use std::{path::PathBuf, sync::Arc};
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(
|
|
||||||
name = "agent_eval",
|
|
||||||
disable_version_flag = true,
|
|
||||||
before_help = "Tool eval runner"
|
|
||||||
)]
|
|
||||||
struct Args {
|
|
||||||
/// Match the names of evals to run.
|
|
||||||
#[arg(long)]
|
|
||||||
exercise_names: Vec<String>,
|
|
||||||
/// Runs all exercises, causes the exercise_names to be ignored.
|
|
||||||
#[arg(long)]
|
|
||||||
all: bool,
|
|
||||||
/// Supported language types to evaluate (default: internal).
|
|
||||||
/// Internal is data generated from the agent panel
|
|
||||||
#[arg(long, default_value = "internal")]
|
|
||||||
languages: String,
|
|
||||||
/// Name of the model (default: "claude-3-7-sonnet-latest")
|
|
||||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
|
||||||
model_name: String,
|
|
||||||
/// Name of the editor model (default: value of `--model_name`).
|
|
||||||
#[arg(long)]
|
|
||||||
editor_model_name: Option<String>,
|
|
||||||
/// Number of evaluations to run concurrently (default: 3)
|
|
||||||
#[arg(short, long, default_value = "5")]
|
|
||||||
concurrency: usize,
|
|
||||||
/// Maximum number of exercises to evaluate per language
|
|
||||||
#[arg(long)]
|
|
||||||
max_exercises_per_language: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
env_logger::init();
|
|
||||||
let args = Args::parse();
|
|
||||||
let http_client = Arc::new(ReqwestClient::new());
|
|
||||||
let app = Application::headless().with_http_client(http_client.clone());
|
|
||||||
|
|
||||||
// Path to the zed-ace-framework repo
|
|
||||||
let framework_path = PathBuf::from("../zed-ace-framework")
|
|
||||||
.canonicalize()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
|
|
||||||
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
|
|
||||||
|
|
||||||
println!("Using zed-ace-framework at: {:?}", framework_path);
|
|
||||||
println!("Evaluating languages: {:?}", languages);
|
|
||||||
|
|
||||||
app.run(move |cx| {
|
|
||||||
let app_state = headless_assistant::init(cx);
|
|
||||||
|
|
||||||
let model = find_model(&args.model_name, cx).unwrap();
|
|
||||||
let editor_model = if let Some(model_name) = &args.editor_model_name {
|
|
||||||
find_model(model_name, cx).unwrap()
|
|
||||||
} else {
|
|
||||||
model.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
|
||||||
registry.set_default_model(Some(model.clone()), cx);
|
|
||||||
});
|
|
||||||
|
|
||||||
let model_provider_id = model.provider_id();
|
|
||||||
let editor_model_provider_id = editor_model.provider_id();
|
|
||||||
|
|
||||||
let framework_path_clone = framework_path.clone();
|
|
||||||
let languages_clone = languages.clone();
|
|
||||||
let exercise_names = args.exercise_names.clone();
|
|
||||||
let all_flag = args.all;
|
|
||||||
|
|
||||||
cx.spawn(async move |cx| {
|
|
||||||
// Authenticate all model providers first
|
|
||||||
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
|
|
||||||
.unwrap()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
|
|
||||||
.unwrap()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
println!("framework path: {}", framework_path_clone.display());
|
|
||||||
|
|
||||||
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
|
|
||||||
|
|
||||||
println!("base sha: {}", base_sha);
|
|
||||||
|
|
||||||
let all_exercises = find_exercises(
|
|
||||||
&framework_path_clone,
|
|
||||||
&languages_clone
|
|
||||||
.iter()
|
|
||||||
.map(|s| s.as_str())
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
args.max_exercises_per_language,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
println!("Found {} exercises total", all_exercises.len());
|
|
||||||
|
|
||||||
// Filter exercises if specific ones were requested
|
|
||||||
let exercises_to_run = if !exercise_names.is_empty() {
|
|
||||||
// If exercise names are specified, filter by them regardless of --all flag
|
|
||||||
all_exercises
|
|
||||||
.into_iter()
|
|
||||||
.filter(|path| {
|
|
||||||
let name = get_exercise_name(path);
|
|
||||||
exercise_names.iter().any(|filter| name.contains(filter))
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
} else if all_flag {
|
|
||||||
// Only use all_flag if no exercise names are specified
|
|
||||||
all_exercises
|
|
||||||
} else {
|
|
||||||
// Default behavior (no filters)
|
|
||||||
all_exercises
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("Will run {} exercises", exercises_to_run.len());
|
|
||||||
|
|
||||||
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
|
|
||||||
let exercise_tasks: Vec<_> = exercises_to_run
|
|
||||||
.into_iter()
|
|
||||||
.map(|exercise_path| {
|
|
||||||
let exercise_name = get_exercise_name(&exercise_path);
|
|
||||||
let model_clone = model.clone();
|
|
||||||
let app_state_clone = app_state.clone();
|
|
||||||
let base_sha_clone = base_sha.clone();
|
|
||||||
let framework_path_clone = framework_path_clone.clone();
|
|
||||||
let cx_clone = cx.clone();
|
|
||||||
|
|
||||||
async move {
|
|
||||||
println!("Processing exercise: {}", exercise_name);
|
|
||||||
let mut exercise_results = Vec::new();
|
|
||||||
|
|
||||||
match run_exercise_eval(
|
|
||||||
exercise_path.clone(),
|
|
||||||
model_clone.clone(),
|
|
||||||
app_state_clone.clone(),
|
|
||||||
base_sha_clone.clone(),
|
|
||||||
framework_path_clone.clone(),
|
|
||||||
cx_clone.clone(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(result) => {
|
|
||||||
println!("Completed {}", exercise_name);
|
|
||||||
exercise_results.push(result);
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
println!("Error running {}: {}", exercise_name, err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save results for this exercise
|
|
||||||
if !exercise_results.is_empty() {
|
|
||||||
if let Err(err) =
|
|
||||||
save_eval_results(&exercise_path, exercise_results.clone()).await
|
|
||||||
{
|
|
||||||
println!("Error saving results for {}: {}", exercise_name, err);
|
|
||||||
} else {
|
|
||||||
println!("Saved results for {}", exercise_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
exercise_results
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
println!(
|
|
||||||
"Running {} exercises with concurrency: {}",
|
|
||||||
exercise_tasks.len(),
|
|
||||||
args.concurrency
|
|
||||||
);
|
|
||||||
|
|
||||||
// Run exercises concurrently, with each exercise running its templates sequentially
|
|
||||||
let all_results = stream::iter(exercise_tasks)
|
|
||||||
.buffer_unordered(args.concurrency)
|
|
||||||
.flat_map(stream::iter)
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
println!("Completed {} evaluation runs", all_results.len());
|
|
||||||
cx.update(|cx| cx.quit()).unwrap();
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
});
|
|
||||||
|
|
||||||
println!("Done running evals");
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
[package]
name = "agent_rules"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"

[lints]
workspace = true

[lib]
path = "src/agent_rules.rs"
doctest = false

[dependencies]
anyhow.workspace = true
fs.workspace = true
gpui.workspace = true
prompt_store.workspace = true
util.workspace = true
worktree.workspace = true
workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" }

[dev-dependencies]
indoc.workspace = true
@ -1 +0,0 @@
../../LICENSE-GPL
@ -1,51 +0,0 @@
use std::sync::Arc;

use anyhow::{Context as _, Result};
use fs::Fs;
use gpui::{App, AppContext, Task};
use prompt_store::SystemPromptRulesFile;
use util::maybe;
use worktree::Worktree;

const RULES_FILE_NAMES: [&'static str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];

pub fn load_worktree_rules_file(
    fs: Arc<dyn Fs>,
    worktree: &Worktree,
    cx: &App,
) -> Option<Task<Result<SystemPromptRulesFile>>> {
    let selected_rules_file = RULES_FILE_NAMES
        .into_iter()
        .filter_map(|name| {
            worktree
                .entry_for_path(name)
                .filter(|entry| entry.is_file())
                .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
        })
        .next();

    // Note that Cline supports `.clinerules` being a directory, but that is not currently
    // supported. This doesn't seem to occur often in GitHub repositories.
    selected_rules_file.map(|(path_in_worktree, abs_path)| {
        let fs = fs.clone();
        cx.background_spawn(maybe!(async move {
            let abs_path = abs_path?;
            let text = fs
                .load(&abs_path)
                .await
                .with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?;
            anyhow::Ok(SystemPromptRulesFile {
                path_in_worktree,
                abs_path: abs_path.into(),
                text: text.trim().to_string(),
            })
        }))
    })
}
@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
     },
 }

-#[derive(Debug, Default)]
+#[derive(Clone, Debug, Default)]
 pub struct AssistantSettings {
     pub enabled: bool,
     pub button: bool,

@ -179,11 +179,9 @@ pub async fn file_outline(

     // Wait until the buffer has been fully parsed, so that we can read its outline.
     let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
-    while parse_status
-        .recv()
-        .await
-        .map_or(false, |status| status != ParseStatus::Idle)
-    {}
+    while *parse_status.borrow() != ParseStatus::Idle {
+        parse_status.changed().await?;
+    }

     let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
     let Some(outline) = snapshot.outline(None) else {
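The hunk above is the hang fix mentioned in the release notes: the old loop awaited `recv()` on the parse-status channel unconditionally, so when the buffer was already fully parsed (or no further status update ever arrived) the outline tool could wait forever; the new loop reads the current value with `borrow()` first and only awaits `changed()` while parsing is still in progress. Below is a minimal, self-contained sketch of the same wait pattern. It uses `tokio::sync::watch` as a stand-in for the editor's parse-status receiver, which is an assumption for illustration only; the real receiver type comes from Zed's `language` crate.

// Sketch of the corrected "wait until parsed" loop, using tokio's watch channel
// as a hypothetical stand-in for the editor's parse-status receiver.
use tokio::sync::watch;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum ParseStatus {
    Parsing,
    Idle,
}

async fn wait_until_parsed(mut parse_status: watch::Receiver<ParseStatus>) -> anyhow::Result<()> {
    // Check the current value before awaiting, so an already-idle buffer
    // returns immediately instead of blocking on a notification that may never come.
    while *parse_status.borrow() != ParseStatus::Idle {
        parse_status.changed().await?;
    }
    Ok(())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let (tx, rx) = watch::channel(ParseStatus::Parsing);
    tokio::spawn(async move {
        // Simulate the parser finishing shortly after the wait starts.
        tx.send(ParseStatus::Idle).ok();
    });
    wait_until_parsed(rx).await?;
    println!("buffer parsed");
    Ok(())
}

The design point carried over from the fix is the order of operations: inspect the latest value first, then await a change only if needed.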
@ -9,12 +9,13 @@ agent.workspace = true
 anyhow.workspace = true
 assistant_tool.workspace = true
 assistant_tools.workspace = true
+assistant_settings.workspace = true
 client.workspace = true
-collections.workspace = true
 context_server.workspace = true
 dap.workspace = true
 env_logger.workspace = true
 fs.workspace = true
+futures.workspace = true
 gpui.workspace = true
 gpui_tokio.workspace = true
 language.workspace = true
@ -27,7 +28,6 @@ release_channel.workspace = true
 reqwest_client.workspace = true
 serde.workspace = true
 settings.workspace = true
-smol.workspace = true
 toml.workspace = true
 workspace-hack.workspace = true

@ -1,229 +0,0 @@
|
||||||
use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
|
|
||||||
use anyhow::anyhow;
|
|
||||||
use assistant_tool::ToolWorkingSet;
|
|
||||||
use client::{Client, UserStore};
|
|
||||||
use collections::HashMap;
|
|
||||||
use dap::DapRegistry;
|
|
||||||
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
|
|
||||||
use language::LanguageRegistry;
|
|
||||||
use language_model::{
|
|
||||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
|
||||||
};
|
|
||||||
use node_runtime::NodeRuntime;
|
|
||||||
use project::{Project, RealFs};
|
|
||||||
use prompt_store::PromptBuilder;
|
|
||||||
use settings::SettingsStore;
|
|
||||||
use smol::channel;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
|
||||||
pub struct AgentAppState {
|
|
||||||
pub languages: Arc<LanguageRegistry>,
|
|
||||||
pub client: Arc<Client>,
|
|
||||||
pub user_store: Entity<UserStore>,
|
|
||||||
pub fs: Arc<dyn fs::Fs>,
|
|
||||||
pub node_runtime: NodeRuntime,
|
|
||||||
|
|
||||||
// Additional fields not present in `workspace::AppState`.
|
|
||||||
pub prompt_builder: Arc<PromptBuilder>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Agent {
|
|
||||||
// pub thread: Entity<Thread>,
|
|
||||||
// pub project: Entity<Project>,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub thread_store: Entity<ThreadStore>,
|
|
||||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
|
||||||
pub done_tx: channel::Sender<anyhow::Result<()>>,
|
|
||||||
_subscription: Subscription,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Agent {
|
|
||||||
pub fn new(
|
|
||||||
app_state: Arc<AgentAppState>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
|
|
||||||
let env = None;
|
|
||||||
let project = Project::local(
|
|
||||||
app_state.client.clone(),
|
|
||||||
app_state.node_runtime.clone(),
|
|
||||||
app_state.user_store.clone(),
|
|
||||||
app_state.languages.clone(),
|
|
||||||
Arc::new(DapRegistry::default()),
|
|
||||||
app_state.fs.clone(),
|
|
||||||
env,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
|
|
||||||
let tools = Arc::new(ToolWorkingSet::default());
|
|
||||||
let thread_store =
|
|
||||||
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
|
|
||||||
|
|
||||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
|
||||||
|
|
||||||
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
|
|
||||||
|
|
||||||
let headless_thread = cx.new(move |cx| Self {
|
|
||||||
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
|
|
||||||
// thread,
|
|
||||||
// project,
|
|
||||||
thread_store,
|
|
||||||
tool_use_counts: HashMap::default(),
|
|
||||||
done_tx,
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok((headless_thread, done_rx))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_thread_event(
|
|
||||||
&mut self,
|
|
||||||
thread: Entity<Thread>,
|
|
||||||
event: &ThreadEvent,
|
|
||||||
cx: &mut Context<Self>,
|
|
||||||
) {
|
|
||||||
match event {
|
|
||||||
ThreadEvent::ShowError(err) => self
|
|
||||||
.done_tx
|
|
||||||
.send_blocking(Err(anyhow!("{:?}", err)))
|
|
||||||
.unwrap(),
|
|
||||||
ThreadEvent::DoneStreaming => {
|
|
||||||
let thread = thread.read(cx);
|
|
||||||
if let Some(message) = thread.messages().last() {
|
|
||||||
println!("Message: {}", message.to_string());
|
|
||||||
}
|
|
||||||
if thread.all_tools_finished() {
|
|
||||||
self.done_tx.send_blocking(Ok(())).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ThreadEvent::UsePendingTools { .. } => {}
|
|
||||||
ThreadEvent::ToolConfirmationNeeded => {
|
|
||||||
// Automatically approve all tools that need confirmation in headless mode
|
|
||||||
println!("Tool confirmation needed - automatically approving in headless mode");
|
|
||||||
|
|
||||||
// Get the tools needing confirmation
|
|
||||||
let tools_needing_confirmation: Vec<_> = thread
|
|
||||||
.read(cx)
|
|
||||||
.tools_needing_confirmation()
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Run each tool that needs confirmation
|
|
||||||
for tool_use in tools_needing_confirmation {
|
|
||||||
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
|
|
||||||
thread.update(cx, |thread, cx| {
|
|
||||||
println!("Auto-approving tool: {}", tool_use.name);
|
|
||||||
|
|
||||||
// Create a request to send to the tool
|
|
||||||
let request = thread.to_completion_request(RequestKind::Chat, cx);
|
|
||||||
let messages = Arc::new(request.messages);
|
|
||||||
|
|
||||||
// Run the tool
|
|
||||||
thread.run_tool(
|
|
||||||
tool_use.id.clone(),
|
|
||||||
tool_use.ui_text.clone(),
|
|
||||||
tool_use.input.clone(),
|
|
||||||
&messages,
|
|
||||||
tool,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ThreadEvent::ToolFinished {
|
|
||||||
tool_use_id,
|
|
||||||
pending_tool_use,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
if let Some(pending_tool_use) = pending_tool_use {
|
|
||||||
println!(
|
|
||||||
"Used tool {} with input: {}",
|
|
||||||
pending_tool_use.name, pending_tool_use.input
|
|
||||||
);
|
|
||||||
*self
|
|
||||||
.tool_use_counts
|
|
||||||
.entry(pending_tool_use.name.clone())
|
|
||||||
.or_insert(0) += 1;
|
|
||||||
}
|
|
||||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
|
||||||
println!("Tool result: {:?}", tool_result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|
||||||
release_channel::init(SemanticVersion::default(), cx);
|
|
||||||
gpui_tokio::init(cx);
|
|
||||||
|
|
||||||
let mut settings_store = SettingsStore::new(cx);
|
|
||||||
settings_store
|
|
||||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
|
||||||
.unwrap();
|
|
||||||
cx.set_global(settings_store);
|
|
||||||
client::init_settings(cx);
|
|
||||||
Project::init_settings(cx);
|
|
||||||
|
|
||||||
let client = Client::production(cx);
|
|
||||||
cx.set_http_client(client.http_client().clone());
|
|
||||||
|
|
||||||
let git_binary_path = None;
|
|
||||||
let fs = Arc::new(RealFs::new(
|
|
||||||
git_binary_path,
|
|
||||||
cx.background_executor().clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
|
||||||
|
|
||||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
|
||||||
|
|
||||||
language::init(cx);
|
|
||||||
language_model::init(client.clone(), cx);
|
|
||||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
|
||||||
assistant_tools::init(client.http_client().clone(), cx);
|
|
||||||
context_server::init(cx);
|
|
||||||
let stdout_is_a_pty = false;
|
|
||||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
|
||||||
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
|
||||||
|
|
||||||
Arc::new(AgentAppState {
|
|
||||||
languages,
|
|
||||||
client,
|
|
||||||
user_store,
|
|
||||||
fs,
|
|
||||||
node_runtime: NodeRuntime::unavailable(),
|
|
||||||
prompt_builder,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
|
||||||
let model = model_registry
|
|
||||||
.available_models(cx)
|
|
||||||
.find(|model| model.id().0 == model_name);
|
|
||||||
|
|
||||||
let Some(model) = model else {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"No language model named {} was available. Available models: {}",
|
|
||||||
model_name,
|
|
||||||
model_registry
|
|
||||||
.available_models(cx)
|
|
||||||
.map(|model| model.id().0.clone())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(", ")
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn authenticate_model_provider(
|
|
||||||
provider_id: LanguageModelProviderId,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
|
||||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
|
||||||
model_provider.authenticate(cx)
|
|
||||||
}
|
|
|
@ -1,74 +1,22 @@
|
||||||
use agent::Agent;
|
mod example;
|
||||||
use anyhow::Result;
|
|
||||||
use gpui::Application;
|
use assistant_settings::AssistantSettings;
|
||||||
use language_model::LanguageModelRegistry;
|
use client::{Client, UserStore};
|
||||||
use reqwest_client::ReqwestClient;
|
pub(crate) use example::*;
|
||||||
use serde::Deserialize;
|
|
||||||
use std::{
|
use ::fs::RealFs;
|
||||||
fs,
|
use anyhow::anyhow;
|
||||||
path::{Path, PathBuf},
|
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
|
||||||
sync::Arc,
|
use language::LanguageRegistry;
|
||||||
|
use language_model::{
|
||||||
|
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||||
};
|
};
|
||||||
mod agent;
|
use node_runtime::NodeRuntime;
|
||||||
|
use project::Project;
|
||||||
#[derive(Debug, Deserialize)]
|
use prompt_store::PromptBuilder;
|
||||||
pub struct ExampleBase {
|
use reqwest_client::ReqwestClient;
|
||||||
pub path: PathBuf,
|
use settings::{Settings, SettingsStore};
|
||||||
pub revision: String,
|
use std::sync::Arc;
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Example {
|
|
||||||
pub base: ExampleBase,
|
|
||||||
|
|
||||||
/// Content of the prompt.md file
|
|
||||||
pub prompt: String,
|
|
||||||
|
|
||||||
/// Content of the rubric.md file
|
|
||||||
pub rubric: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Example {
|
|
||||||
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
|
|
||||||
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
|
|
||||||
let base_path = dir_path.as_ref().join("base.toml");
|
|
||||||
let prompt_path = dir_path.as_ref().join("prompt.md");
|
|
||||||
let rubric_path = dir_path.as_ref().join("rubric.md");
|
|
||||||
|
|
||||||
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
|
|
||||||
base.path = base.path.canonicalize()?;
|
|
||||||
|
|
||||||
Ok(Example {
|
|
||||||
base,
|
|
||||||
prompt: fs::read_to_string(prompt_path)?,
|
|
||||||
rubric: fs::read_to_string(rubric_path)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set up the example by checking out the specified Git revision
|
|
||||||
pub fn setup(&self) -> Result<()> {
|
|
||||||
use std::process::Command;
|
|
||||||
|
|
||||||
// Check if the directory exists
|
|
||||||
let path = Path::new(&self.base.path);
|
|
||||||
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
|
|
||||||
|
|
||||||
// Change to the project directory and checkout the specified revision
|
|
||||||
let output = Command::new("git")
|
|
||||||
.current_dir(&self.base.path)
|
|
||||||
.arg("checkout")
|
|
||||||
.arg(&self.base.revision)
|
|
||||||
.output()?;
|
|
||||||
anyhow::ensure!(
|
|
||||||
output.status.success(),
|
|
||||||
"Failed to checkout revision {}: {}",
|
|
||||||
self.base.revision,
|
|
||||||
String::from_utf8_lossy(&output.stderr),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
|
@ -76,10 +24,9 @@ fn main() {
|
||||||
let app = Application::headless().with_http_client(http_client.clone());
|
let app = Application::headless().with_http_client(http_client.clone());
|
||||||
|
|
||||||
app.run(move |cx| {
|
app.run(move |cx| {
|
||||||
let app_state = crate::agent::init(cx);
|
let app_state = init(cx);
|
||||||
let _agent = Agent::new(app_state, cx);
|
|
||||||
|
|
||||||
let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
||||||
|
|
||||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||||
registry.set_default_model(Some(model.clone()), cx);
|
registry.set_default_model(Some(model.clone()), cx);
|
||||||
|
@ -87,15 +34,112 @@ fn main() {
|
||||||
|
|
||||||
let model_provider_id = model.provider_id();
|
let model_provider_id = model.provider_id();
|
||||||
|
|
||||||
let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
|
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
|
||||||
|
|
||||||
cx.spawn(async move |_cx| {
|
cx.spawn(async move |cx| {
|
||||||
authenticate.await.unwrap();
|
authenticate.await.unwrap();
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
});
|
|
||||||
|
|
||||||
// let example =
|
let example =
|
||||||
// Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
||||||
// example.setup()?;
|
example.setup()?;
|
||||||
|
cx.update(|cx| example.run(model, app_state, cx))?.await?;
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||||
|
pub struct AgentAppState {
|
||||||
|
pub languages: Arc<LanguageRegistry>,
|
||||||
|
pub client: Arc<Client>,
|
||||||
|
pub user_store: Entity<UserStore>,
|
||||||
|
pub fs: Arc<dyn fs::Fs>,
|
||||||
|
pub node_runtime: NodeRuntime,
|
||||||
|
|
||||||
|
// Additional fields not present in `workspace::AppState`.
|
||||||
|
pub prompt_builder: Arc<PromptBuilder>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||||
|
release_channel::init(SemanticVersion::default(), cx);
|
||||||
|
gpui_tokio::init(cx);
|
||||||
|
|
||||||
|
let mut settings_store = SettingsStore::new(cx);
|
||||||
|
settings_store
|
||||||
|
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||||
|
.unwrap();
|
||||||
|
cx.set_global(settings_store);
|
||||||
|
client::init_settings(cx);
|
||||||
|
Project::init_settings(cx);
|
||||||
|
|
||||||
|
let client = Client::production(cx);
|
||||||
|
cx.set_http_client(client.http_client().clone());
|
||||||
|
|
||||||
|
let git_binary_path = None;
|
||||||
|
let fs = Arc::new(RealFs::new(
|
||||||
|
git_binary_path,
|
||||||
|
cx.background_executor().clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||||
|
|
||||||
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||||
|
|
||||||
|
language::init(cx);
|
||||||
|
language_model::init(client.clone(), cx);
|
||||||
|
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||||
|
assistant_tools::init(client.http_client().clone(), cx);
|
||||||
|
context_server::init(cx);
|
||||||
|
let stdout_is_a_pty = false;
|
||||||
|
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||||
|
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||||
|
|
||||||
|
AssistantSettings::override_global(
|
||||||
|
AssistantSettings {
|
||||||
|
always_allow_tool_actions: true,
|
||||||
|
..AssistantSettings::get_global(cx).clone()
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
Arc::new(AgentAppState {
|
||||||
|
languages,
|
||||||
|
client,
|
||||||
|
user_store,
|
||||||
|
fs,
|
||||||
|
node_runtime: NodeRuntime::unavailable(),
|
||||||
|
prompt_builder,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let model = model_registry
|
||||||
|
.available_models(cx)
|
||||||
|
.find(|model| model.id().0 == model_name);
|
||||||
|
|
||||||
|
let Some(model) = model else {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"No language model named {} was available. Available models: {}",
|
||||||
|
model_name,
|
||||||
|
model_registry
|
||||||
|
.available_models(cx)
|
||||||
|
.map(|model| model.id().0.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authenticate_model_provider(
|
||||||
|
provider_id: LanguageModelProviderId,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||||
|
model_provider.authenticate(cx)
|
||||||
}
|
}
|
||||||
|
|
178
crates/eval/src/example.rs
Normal file
@ -0,0 +1,178 @@
use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Result, anyhow};
use assistant_tool::ToolWorkingSet;
use dap::DapRegistry;
use futures::channel::oneshot;
use gpui::{App, Task};
use language_model::{LanguageModel, StopReason};
use project::Project;
use serde::Deserialize;
use std::process::Command;
use std::sync::Arc;
use std::{
    fs,
    path::{Path, PathBuf},
};

use crate::AgentAppState;

#[derive(Debug, Deserialize)]
pub struct ExampleBase {
    pub path: PathBuf,
    pub revision: String,
}

#[derive(Debug)]
pub struct Example {
    pub base: ExampleBase,

    /// Content of the prompt.md file
    pub prompt: String,

    /// Content of the rubric.md file
    pub _rubric: String,
}

impl Example {
    /// Load an example from a directory containing base.toml, prompt.md, and rubric.md
    pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
        let base_path = dir_path.as_ref().join("base.toml");
        let prompt_path = dir_path.as_ref().join("prompt.md");
        let rubric_path = dir_path.as_ref().join("rubric.md");

        let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
        base.path = base.path.canonicalize()?;

        Ok(Example {
            base,
            prompt: fs::read_to_string(prompt_path)?,
            _rubric: fs::read_to_string(rubric_path)?,
        })
    }

    /// Set up the example by checking out the specified Git revision
    pub fn setup(&self) -> Result<()> {
        // Check if the directory exists
        let path = Path::new(&self.base.path);
        anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);

        // Change to the project directory and checkout the specified revision
        let output = Command::new("git")
            .current_dir(&self.base.path)
            .arg("checkout")
            .arg(&self.base.revision)
            .output()?;
        anyhow::ensure!(
            output.status.success(),
            "Failed to checkout revision {}: {}",
            self.base.revision,
            String::from_utf8_lossy(&output.stderr),
        );

        Ok(())
    }

    pub fn run(
        self,
        model: Arc<dyn LanguageModel>,
        app_state: Arc<AgentAppState>,
        cx: &mut App,
    ) -> Task<Result<()>> {
        let project = Project::local(
            app_state.client.clone(),
            app_state.node_runtime.clone(),
            app_state.user_store.clone(),
            app_state.languages.clone(),
            Arc::new(DapRegistry::default()),
            app_state.fs.clone(),
            None,
            cx,
        );

        let worktree = project.update(cx, |project, cx| {
            project.create_worktree(self.base.path, true, cx)
        });

        let tools = Arc::new(ToolWorkingSet::default());
        let thread_store =
            ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);

        println!("USER:");
        println!("{}", self.prompt);
        println!("ASSISTANT:");
        cx.spawn(async move |cx| {
            worktree.await?;
            let thread_store = thread_store.await;
            let thread =
                thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;

            let (tx, rx) = oneshot::channel();
            let mut tx = Some(tx);

            let _subscription =
                cx.subscribe(
                    &thread,
                    move |thread, event: &ThreadEvent, cx| match event {
                        ThreadEvent::Stopped(reason) => match reason {
                            Ok(StopReason::EndTurn) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Ok(())).ok();
                                }
                            }
                            Ok(StopReason::MaxTokens) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
                                }
                            }
                            Ok(StopReason::ToolUse) => {}
                            Err(error) => {
                                if let Some(tx) = tx.take() {
                                    tx.send(Err(anyhow!(error.clone()))).ok();
                                }
                            }
                        },
                        ThreadEvent::ShowError(thread_error) => {
                            if let Some(tx) = tx.take() {
                                tx.send(Err(anyhow!(thread_error.clone()))).ok();
                            }
                        }
                        ThreadEvent::StreamedAssistantText(_, chunk) => {
                            print!("{}", chunk);
                        }
                        ThreadEvent::StreamedAssistantThinking(_, chunk) => {
                            print!("{}", chunk);
                        }
                        ThreadEvent::UsePendingTools { tool_uses } => {
                            println!("\n\nUSING TOOLS:");
                            for tool_use in tool_uses {
                                println!("{}: {}", tool_use.name, tool_use.input);
                            }
                        }
                        ThreadEvent::ToolFinished {
                            tool_use_id,
                            pending_tool_use,
                            ..
                        } => {
                            if let Some(tool_use) = pending_tool_use {
                                println!("\nTOOL FINISHED: {}", tool_use.name);
                            }
                            if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
                                println!("\n{}\n", tool_result.content);
                            }
                        }
                        _ => {}
                    },
                )?;

            thread.update(cx, |thread, cx| {
                let context = vec![];
                thread.insert_user_message(self.prompt.clone(), context, None, cx);
                thread.send_to_model(model, RequestKind::Chat, cx);
            })?;

            rx.await??;

            Ok(())
        })
    }
}
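For orientation, the sketch below shows how an Example directory might be driven end to end; it is not part of this diff. The directory path, the expected base.toml contents, and the surrounding scope (a cx: &mut App plus the model and app_state produced by the setup code earlier in this commit) are all assumptions.

// Illustrative driver, assuming `model: Arc<dyn LanguageModel>`, `app_state: Arc<AgentAppState>`,
// and `cx: &mut App` are already in scope.
//
// base.toml is expected to provide the `ExampleBase` fields, e.g.:
//     path = "../some-repo-to-evaluate"   # placeholder
//     revision = "abc1234"                # placeholder
let example = Example::load_from_directory("./examples/hello-world")?; // placeholder path
example.setup()?; // runs `git checkout <revision>` in the example's repository
let task = example.run(model.clone(), app_state.clone(), cx);
cx.spawn(async move |_cx| {
    match task.await {
        Ok(()) => println!("example finished"),
        Err(error) => eprintln!("example failed: {:?}", error),
    }
})
.detach();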
@@ -1,7 +1,7 @@
 use crate::{
     AnyView, AnyWindowHandle, App, AppCell, AppContext, BackgroundExecutor, BorrowAppContext,
-    Entity, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation, Result, Task,
-    VisualContext, Window, WindowHandle,
+    Entity, EventEmitter, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation,
+    Result, Subscription, Task, VisualContext, Window, WindowHandle,
 };
 use anyhow::{Context as _, anyhow};
 use derive_more::{Deref, DerefMut};

@@ -154,6 +154,26 @@ impl AsyncApp {
         Ok(lock.update(f))
     }
 
+    /// Arrange for the given callback to be invoked whenever the given entity emits an event of a given type.
+    /// The callback is provided a handle to the emitting entity and a reference to the emitted event.
+    pub fn subscribe<T, Event>(
+        &mut self,
+        entity: &Entity<T>,
+        mut on_event: impl FnMut(Entity<T>, &Event, &mut App) + 'static,
+    ) -> Result<Subscription>
+    where
+        T: 'static + EventEmitter<Event>,
+        Event: 'static,
+    {
+        let app = self
+            .app
+            .upgrade()
+            .ok_or_else(|| anyhow!("app was released"))?;
+        let mut lock = app.borrow_mut();
+        let subscription = lock.subscribe(entity, on_event);
+        Ok(subscription)
+    }
+
     /// Open a window with the given options based on the root view returned by the given function.
     pub fn open_window<V>(
         &self,
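This AsyncApp::subscribe addition is what lets the eval observe ThreadEvents from inside cx.spawn in example.rs above, where only an AsyncApp handle is available. A small sketch of the call shape follows; some_entity and SomeEvent are placeholder names, and the entity's type is assumed to implement EventEmitter<SomeEvent>.

// Inside an async block whose `cx` is an `AsyncApp` handle.
let _subscription = cx.subscribe(&some_entity, |entity, event: &SomeEvent, cx| {
    // `entity` is an Entity<T> handle to the emitter and `cx` is `&mut App`;
    // read state or forward the event from here.
    let _ = (entity, event, cx);
})?;
// The call returns Result<Subscription>, failing only if the App was already
// released; keep the Subscription alive for as long as events should be observed.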
@@ -16,17 +16,17 @@ use std::{
 use text::LineEnding;
 use util::{ResultExt, get_system_shell};
 
-#[derive(Serialize)]
-pub struct AssistantSystemPromptContext {
-    pub worktrees: Vec<WorktreeInfoForSystemPrompt>,
+#[derive(Debug, Clone, Serialize)]
+pub struct ProjectContext {
+    pub worktrees: Vec<WorktreeContext>,
     pub has_rules: bool,
     pub os: String,
     pub arch: String,
     pub shell: String,
 }
 
-impl AssistantSystemPromptContext {
-    pub fn new(worktrees: Vec<WorktreeInfoForSystemPrompt>) -> Self {
+impl ProjectContext {
+    pub fn new(worktrees: Vec<WorktreeContext>) -> Self {
         let has_rules = worktrees
             .iter()
             .any(|worktree| worktree.rules_file.is_some());

@@ -40,15 +40,15 @@ impl AssistantSystemPromptContext
     }
 }
 
-#[derive(Serialize)]
-pub struct WorktreeInfoForSystemPrompt {
+#[derive(Debug, Clone, Serialize)]
+pub struct WorktreeContext {
     pub root_name: String,
     pub abs_path: Arc<Path>,
-    pub rules_file: Option<SystemPromptRulesFile>,
+    pub rules_file: Option<RulesFileContext>,
 }
 
-#[derive(Serialize)]
-pub struct SystemPromptRulesFile {
+#[derive(Debug, Clone, Serialize)]
+pub struct RulesFileContext {
     pub path_in_worktree: Arc<Path>,
     pub abs_path: Arc<Path>,
     pub text: String,

@@ -260,7 +260,7 @@ impl PromptBuilder {
 
     pub fn generate_assistant_system_prompt(
         &self,
-        context: &AssistantSystemPromptContext,
+        context: &ProjectContext,
     ) -> Result<String, RenderError> {
         self.handlebars
             .lock()
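To illustrate the renamed types, here is a sketch of building a ProjectContext and rendering the system prompt; it is not part of this diff. The worktree path and root name are placeholders, rules_file is left as None because the full RulesFileContext definition is not shown in this hunk, and prompt_builder is assumed to be a PromptBuilder obtained elsewhere (for example via PromptBuilder::load).

use std::path::Path;
use std::sync::Arc;

// Placeholder worktree; `rules_file` would carry the contents of a rules file when present.
let worktree = WorktreeContext {
    root_name: "my-project".into(),
    abs_path: Arc::from(Path::new("/home/me/my-project")),
    rules_file: None,
};
let context = ProjectContext::new(vec![worktree]);
assert!(!context.has_rules); // `new` derives `has_rules` from the worktrees' rules files
let system_prompt = prompt_builder.generate_assistant_system_prompt(&context)?;
println!("{system_prompt}");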