Lay the groundwork for a Rust-based eval (#28488)
Also, we moved the logic for driving the agentic loop into `Thread` so that we don't have to re-implement it. Release Notes: - N/A --------- Co-authored-by: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
55760295d9
commit
8ac378b86e
13 changed files with 455 additions and 65 deletions
31
Cargo.lock
generated
31
Cargo.lock
generated
|
@ -4901,6 +4901,37 @@ dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "eval"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"agent",
|
||||||
|
"anyhow",
|
||||||
|
"assistant_tool",
|
||||||
|
"assistant_tools",
|
||||||
|
"client",
|
||||||
|
"collections",
|
||||||
|
"context_server",
|
||||||
|
"dap",
|
||||||
|
"env_logger 0.11.8",
|
||||||
|
"fs",
|
||||||
|
"gpui",
|
||||||
|
"gpui_tokio",
|
||||||
|
"language",
|
||||||
|
"language_model",
|
||||||
|
"language_models",
|
||||||
|
"node_runtime",
|
||||||
|
"project",
|
||||||
|
"prompt_store",
|
||||||
|
"release_channel",
|
||||||
|
"reqwest_client",
|
||||||
|
"serde",
|
||||||
|
"settings",
|
||||||
|
"smol",
|
||||||
|
"toml 0.8.20",
|
||||||
|
"workspace-hack",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "evals"
|
name = "evals"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
|
@ -47,6 +47,7 @@ members = [
|
||||||
"crates/diagnostics",
|
"crates/diagnostics",
|
||||||
"crates/docs_preprocessor",
|
"crates/docs_preprocessor",
|
||||||
"crates/editor",
|
"crates/editor",
|
||||||
|
"crates/eval",
|
||||||
"crates/evals",
|
"crates/evals",
|
||||||
"crates/extension",
|
"crates/extension",
|
||||||
"crates/extension_api",
|
"crates/extension_api",
|
||||||
|
|
|
@ -21,7 +21,7 @@ use gpui::{
|
||||||
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
||||||
};
|
};
|
||||||
use language::{Buffer, LanguageRegistry};
|
use language::{Buffer, LanguageRegistry};
|
||||||
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
|
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||||
use markdown::parser::CodeBlockKind;
|
use markdown::parser::CodeBlockKind;
|
||||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
|
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
|
||||||
use project::ProjectItem as _;
|
use project::ProjectItem as _;
|
||||||
|
@ -897,11 +897,7 @@ impl ActiveThread {
|
||||||
self.save_thread(cx);
|
self.save_thread(cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
ThreadEvent::UsePendingTools => {
|
ThreadEvent::UsePendingTools { tool_uses } => {
|
||||||
let tool_uses = self
|
|
||||||
.thread
|
|
||||||
.update(cx, |thread, cx| thread.use_pending_tools(cx));
|
|
||||||
|
|
||||||
for tool_use in tool_uses {
|
for tool_use in tool_uses {
|
||||||
self.render_tool_use_markdown(
|
self.render_tool_use_markdown(
|
||||||
tool_use.id.clone(),
|
tool_use.id.clone(),
|
||||||
|
@ -913,11 +909,8 @@ impl ActiveThread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ThreadEvent::ToolFinished {
|
ThreadEvent::ToolFinished {
|
||||||
pending_tool_use,
|
pending_tool_use, ..
|
||||||
canceled,
|
|
||||||
..
|
|
||||||
} => {
|
} => {
|
||||||
let canceled = *canceled;
|
|
||||||
if let Some(tool_use) = pending_tool_use {
|
if let Some(tool_use) = pending_tool_use {
|
||||||
self.render_tool_use_markdown(
|
self.render_tool_use_markdown(
|
||||||
tool_use.id.clone(),
|
tool_use.id.clone(),
|
||||||
|
@ -931,18 +924,6 @@ impl ActiveThread {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.thread.read(cx).all_tools_finished() {
|
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
|
||||||
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
|
||||||
self.thread.update(cx, |thread, cx| {
|
|
||||||
thread.attach_tool_results(cx);
|
|
||||||
if !canceled {
|
|
||||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ThreadEvent::CheckpointChanged => cx.notify(),
|
ThreadEvent::CheckpointChanged => cx.notify(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1181,7 +1181,8 @@ impl Thread {
|
||||||
match result.as_ref() {
|
match result.as_ref() {
|
||||||
Ok(stop_reason) => match stop_reason {
|
Ok(stop_reason) => match stop_reason {
|
||||||
StopReason::ToolUse => {
|
StopReason::ToolUse => {
|
||||||
cx.emit(ThreadEvent::UsePendingTools);
|
let tool_uses = thread.use_pending_tools(cx);
|
||||||
|
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
||||||
}
|
}
|
||||||
StopReason::EndTurn => {}
|
StopReason::EndTurn => {}
|
||||||
StopReason::MaxTokens => {}
|
StopReason::MaxTokens => {}
|
||||||
|
@ -1369,10 +1370,7 @@ impl Thread {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn use_pending_tools(
|
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
|
||||||
&mut self,
|
|
||||||
cx: &mut Context<Self>,
|
|
||||||
) -> impl IntoIterator<Item = PendingToolUse> + use<> {
|
|
||||||
let request = self.to_completion_request(RequestKind::Chat, cx);
|
let request = self.to_completion_request(RequestKind::Chat, cx);
|
||||||
let messages = Arc::new(request.messages);
|
let messages = Arc::new(request.messages);
|
||||||
let pending_tool_uses = self
|
let pending_tool_uses = self
|
||||||
|
@ -1460,18 +1458,36 @@ impl Thread {
|
||||||
output,
|
output,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
|
||||||
cx.emit(ThreadEvent::ToolFinished {
|
|
||||||
tool_use_id,
|
|
||||||
pending_tool_use,
|
|
||||||
canceled: false,
|
|
||||||
});
|
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn tool_finished(
|
||||||
|
&mut self,
|
||||||
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
pending_tool_use: Option<PendingToolUse>,
|
||||||
|
canceled: bool,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
if self.all_tools_finished() {
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
||||||
|
self.attach_tool_results(cx);
|
||||||
|
if !canceled {
|
||||||
|
self.send_to_model(model, RequestKind::Chat, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cx.emit(ThreadEvent::ToolFinished {
|
||||||
|
tool_use_id,
|
||||||
|
pending_tool_use,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
|
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
|
||||||
// Insert a user message to contain the tool results.
|
// Insert a user message to contain the tool results.
|
||||||
self.insert_user_message(
|
self.insert_user_message(
|
||||||
|
@ -1495,11 +1511,12 @@ impl Thread {
|
||||||
let mut canceled = false;
|
let mut canceled = false;
|
||||||
for pending_tool_use in self.tool_use.cancel_pending() {
|
for pending_tool_use in self.tool_use.cancel_pending() {
|
||||||
canceled = true;
|
canceled = true;
|
||||||
cx.emit(ThreadEvent::ToolFinished {
|
self.tool_finished(
|
||||||
tool_use_id: pending_tool_use.id.clone(),
|
pending_tool_use.id.clone(),
|
||||||
pending_tool_use: Some(pending_tool_use),
|
Some(pending_tool_use),
|
||||||
canceled: true,
|
true,
|
||||||
});
|
cx,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
canceled
|
canceled
|
||||||
};
|
};
|
||||||
|
@ -1866,12 +1883,7 @@ impl Thread {
|
||||||
|
|
||||||
self.tool_use
|
self.tool_use
|
||||||
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
|
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
|
||||||
|
self.tool_finished(tool_use_id.clone(), None, true, cx);
|
||||||
cx.emit(ThreadEvent::ToolFinished {
|
|
||||||
tool_use_id,
|
|
||||||
pending_tool_use: None,
|
|
||||||
canceled: true,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1897,14 +1909,14 @@ pub enum ThreadEvent {
|
||||||
MessageDeleted(MessageId),
|
MessageDeleted(MessageId),
|
||||||
SummaryGenerated,
|
SummaryGenerated,
|
||||||
SummaryChanged,
|
SummaryChanged,
|
||||||
UsePendingTools,
|
UsePendingTools {
|
||||||
|
tool_uses: Vec<PendingToolUse>,
|
||||||
|
},
|
||||||
ToolFinished {
|
ToolFinished {
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
tool_use_id: LanguageModelToolUseId,
|
tool_use_id: LanguageModelToolUseId,
|
||||||
/// The pending tool use that corresponds to this tool.
|
/// The pending tool use that corresponds to this tool.
|
||||||
pending_tool_use: Option<PendingToolUse>,
|
pending_tool_use: Option<PendingToolUse>,
|
||||||
/// Whether the tool was canceled by the user.
|
|
||||||
canceled: bool,
|
|
||||||
},
|
},
|
||||||
CheckpointChanged,
|
CheckpointChanged,
|
||||||
ToolConfirmationNeeded,
|
ToolConfirmationNeeded,
|
||||||
|
|
|
@ -95,11 +95,7 @@ impl HeadlessAssistant {
|
||||||
self.done_tx.send_blocking(Ok(())).unwrap()
|
self.done_tx.send_blocking(Ok(())).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ThreadEvent::UsePendingTools => {
|
ThreadEvent::UsePendingTools { .. } => {}
|
||||||
thread.update(cx, |thread, cx| {
|
|
||||||
thread.use_pending_tools(cx);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
ThreadEvent::ToolConfirmationNeeded => {
|
ThreadEvent::ToolConfirmationNeeded => {
|
||||||
// Automatically approve all tools that need confirmation in headless mode
|
// Automatically approve all tools that need confirmation in headless mode
|
||||||
println!("Tool confirmation needed - automatically approving in headless mode");
|
println!("Tool confirmation needed - automatically approving in headless mode");
|
||||||
|
@ -152,19 +148,6 @@ impl HeadlessAssistant {
|
||||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||||
println!("Tool result: {:?}", tool_result);
|
println!("Tool result: {:?}", tool_result);
|
||||||
}
|
}
|
||||||
if thread.read(cx).all_tools_finished() {
|
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
|
||||||
if let Some(model) = model_registry.default_model() {
|
|
||||||
thread.update(cx, |thread, cx| {
|
|
||||||
thread.attach_tool_results(cx);
|
|
||||||
thread.send_to_model(model.model, RequestKind::Chat, cx);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
println!(
|
|
||||||
"Warning: No active language model available to continue conversation"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
39
crates/eval/Cargo.toml
Normal file
39
crates/eval/Cargo.toml
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
[package]
|
||||||
|
name = "eval"
|
||||||
|
version = "0.1.0"
|
||||||
|
publish.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
agent.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
assistant_tool.workspace = true
|
||||||
|
assistant_tools.workspace = true
|
||||||
|
client.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
|
context_server.workspace = true
|
||||||
|
dap.workspace = true
|
||||||
|
env_logger.workspace = true
|
||||||
|
fs.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
|
gpui_tokio.workspace = true
|
||||||
|
language.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
|
language_models.workspace = true
|
||||||
|
node_runtime.workspace = true
|
||||||
|
project.workspace = true
|
||||||
|
prompt_store.workspace = true
|
||||||
|
release_channel.workspace = true
|
||||||
|
reqwest_client.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
settings.workspace = true
|
||||||
|
smol.workspace = true
|
||||||
|
toml.workspace = true
|
||||||
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "eval"
|
||||||
|
path = "src/eval.rs"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
1
crates/eval/LICENSE-GPL
Symbolic link
1
crates/eval/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE-GPL
|
7
crates/eval/README.md
Normal file
7
crates/eval/README.md
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Eval
|
||||||
|
|
||||||
|
This eval assumes the working directory is the root of the repository. Run it with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cargo run -p eval
|
||||||
|
```
|
|
@ -0,0 +1,2 @@
|
||||||
|
path = "../zed_worktree"
|
||||||
|
revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
|
|
@ -0,0 +1,3 @@
|
||||||
|
Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation.
|
||||||
|
|
||||||
|
The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.
|
229
crates/eval/src/agent.rs
Normal file
229
crates/eval/src/agent.rs
Normal file
|
@ -0,0 +1,229 @@
|
||||||
|
use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use assistant_tool::ToolWorkingSet;
|
||||||
|
use client::{Client, UserStore};
|
||||||
|
use collections::HashMap;
|
||||||
|
use dap::DapRegistry;
|
||||||
|
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
|
||||||
|
use language::LanguageRegistry;
|
||||||
|
use language_model::{
|
||||||
|
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||||
|
};
|
||||||
|
use node_runtime::NodeRuntime;
|
||||||
|
use project::{Project, RealFs};
|
||||||
|
use prompt_store::PromptBuilder;
|
||||||
|
use settings::SettingsStore;
|
||||||
|
use smol::channel;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||||
|
pub struct AgentAppState {
|
||||||
|
pub languages: Arc<LanguageRegistry>,
|
||||||
|
pub client: Arc<Client>,
|
||||||
|
pub user_store: Entity<UserStore>,
|
||||||
|
pub fs: Arc<dyn fs::Fs>,
|
||||||
|
pub node_runtime: NodeRuntime,
|
||||||
|
|
||||||
|
// Additional fields not present in `workspace::AppState`.
|
||||||
|
pub prompt_builder: Arc<PromptBuilder>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Agent {
|
||||||
|
// pub thread: Entity<Thread>,
|
||||||
|
// pub project: Entity<Project>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub thread_store: Entity<ThreadStore>,
|
||||||
|
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||||
|
pub done_tx: channel::Sender<anyhow::Result<()>>,
|
||||||
|
_subscription: Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Agent {
|
||||||
|
pub fn new(
|
||||||
|
app_state: Arc<AgentAppState>,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
|
||||||
|
let env = None;
|
||||||
|
let project = Project::local(
|
||||||
|
app_state.client.clone(),
|
||||||
|
app_state.node_runtime.clone(),
|
||||||
|
app_state.user_store.clone(),
|
||||||
|
app_state.languages.clone(),
|
||||||
|
Arc::new(DapRegistry::default()),
|
||||||
|
app_state.fs.clone(),
|
||||||
|
env,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
let tools = Arc::new(ToolWorkingSet::default());
|
||||||
|
let thread_store =
|
||||||
|
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
|
||||||
|
|
||||||
|
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||||
|
|
||||||
|
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
|
||||||
|
|
||||||
|
let headless_thread = cx.new(move |cx| Self {
|
||||||
|
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
|
||||||
|
// thread,
|
||||||
|
// project,
|
||||||
|
thread_store,
|
||||||
|
tool_use_counts: HashMap::default(),
|
||||||
|
done_tx,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok((headless_thread, done_rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_thread_event(
|
||||||
|
&mut self,
|
||||||
|
thread: Entity<Thread>,
|
||||||
|
event: &ThreadEvent,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
match event {
|
||||||
|
ThreadEvent::ShowError(err) => self
|
||||||
|
.done_tx
|
||||||
|
.send_blocking(Err(anyhow!("{:?}", err)))
|
||||||
|
.unwrap(),
|
||||||
|
ThreadEvent::DoneStreaming => {
|
||||||
|
let thread = thread.read(cx);
|
||||||
|
if let Some(message) = thread.messages().last() {
|
||||||
|
println!("Message: {}", message.to_string());
|
||||||
|
}
|
||||||
|
if thread.all_tools_finished() {
|
||||||
|
self.done_tx.send_blocking(Ok(())).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ThreadEvent::UsePendingTools { .. } => {}
|
||||||
|
ThreadEvent::ToolConfirmationNeeded => {
|
||||||
|
// Automatically approve all tools that need confirmation in headless mode
|
||||||
|
println!("Tool confirmation needed - automatically approving in headless mode");
|
||||||
|
|
||||||
|
// Get the tools needing confirmation
|
||||||
|
let tools_needing_confirmation: Vec<_> = thread
|
||||||
|
.read(cx)
|
||||||
|
.tools_needing_confirmation()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Run each tool that needs confirmation
|
||||||
|
for tool_use in tools_needing_confirmation {
|
||||||
|
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
println!("Auto-approving tool: {}", tool_use.name);
|
||||||
|
|
||||||
|
// Create a request to send to the tool
|
||||||
|
let request = thread.to_completion_request(RequestKind::Chat, cx);
|
||||||
|
let messages = Arc::new(request.messages);
|
||||||
|
|
||||||
|
// Run the tool
|
||||||
|
thread.run_tool(
|
||||||
|
tool_use.id.clone(),
|
||||||
|
tool_use.ui_text.clone(),
|
||||||
|
tool_use.input.clone(),
|
||||||
|
&messages,
|
||||||
|
tool,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ThreadEvent::ToolFinished {
|
||||||
|
tool_use_id,
|
||||||
|
pending_tool_use,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
if let Some(pending_tool_use) = pending_tool_use {
|
||||||
|
println!(
|
||||||
|
"Used tool {} with input: {}",
|
||||||
|
pending_tool_use.name, pending_tool_use.input
|
||||||
|
);
|
||||||
|
*self
|
||||||
|
.tool_use_counts
|
||||||
|
.entry(pending_tool_use.name.clone())
|
||||||
|
.or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||||
|
println!("Tool result: {:?}", tool_result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||||
|
release_channel::init(SemanticVersion::default(), cx);
|
||||||
|
gpui_tokio::init(cx);
|
||||||
|
|
||||||
|
let mut settings_store = SettingsStore::new(cx);
|
||||||
|
settings_store
|
||||||
|
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||||
|
.unwrap();
|
||||||
|
cx.set_global(settings_store);
|
||||||
|
client::init_settings(cx);
|
||||||
|
Project::init_settings(cx);
|
||||||
|
|
||||||
|
let client = Client::production(cx);
|
||||||
|
cx.set_http_client(client.http_client().clone());
|
||||||
|
|
||||||
|
let git_binary_path = None;
|
||||||
|
let fs = Arc::new(RealFs::new(
|
||||||
|
git_binary_path,
|
||||||
|
cx.background_executor().clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||||
|
|
||||||
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||||
|
|
||||||
|
language::init(cx);
|
||||||
|
language_model::init(client.clone(), cx);
|
||||||
|
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||||
|
assistant_tools::init(client.http_client().clone(), cx);
|
||||||
|
context_server::init(cx);
|
||||||
|
let stdout_is_a_pty = false;
|
||||||
|
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||||
|
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||||
|
|
||||||
|
Arc::new(AgentAppState {
|
||||||
|
languages,
|
||||||
|
client,
|
||||||
|
user_store,
|
||||||
|
fs,
|
||||||
|
node_runtime: NodeRuntime::unavailable(),
|
||||||
|
prompt_builder,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let model = model_registry
|
||||||
|
.available_models(cx)
|
||||||
|
.find(|model| model.id().0 == model_name);
|
||||||
|
|
||||||
|
let Some(model) = model else {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"No language model named {} was available. Available models: {}",
|
||||||
|
model_name,
|
||||||
|
model_registry
|
||||||
|
.available_models(cx)
|
||||||
|
.map(|model| model.id().0.clone())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authenticate_model_provider(
|
||||||
|
provider_id: LanguageModelProviderId,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||||
|
model_provider.authenticate(cx)
|
||||||
|
}
|
101
crates/eval/src/eval.rs
Normal file
101
crates/eval/src/eval.rs
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
use agent::Agent;
|
||||||
|
use anyhow::Result;
|
||||||
|
use gpui::Application;
|
||||||
|
use language_model::LanguageModelRegistry;
|
||||||
|
use reqwest_client::ReqwestClient;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::{
|
||||||
|
fs,
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
mod agent;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ExampleBase {
|
||||||
|
pub path: PathBuf,
|
||||||
|
pub revision: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Example {
|
||||||
|
pub base: ExampleBase,
|
||||||
|
|
||||||
|
/// Content of the prompt.md file
|
||||||
|
pub prompt: String,
|
||||||
|
|
||||||
|
/// Content of the rubric.md file
|
||||||
|
pub rubric: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Example {
|
||||||
|
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
|
||||||
|
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
|
||||||
|
let base_path = dir_path.as_ref().join("base.toml");
|
||||||
|
let prompt_path = dir_path.as_ref().join("prompt.md");
|
||||||
|
let rubric_path = dir_path.as_ref().join("rubric.md");
|
||||||
|
|
||||||
|
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
|
||||||
|
base.path = base.path.canonicalize()?;
|
||||||
|
|
||||||
|
Ok(Example {
|
||||||
|
base,
|
||||||
|
prompt: fs::read_to_string(prompt_path)?,
|
||||||
|
rubric: fs::read_to_string(rubric_path)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set up the example by checking out the specified Git revision
|
||||||
|
pub fn setup(&self) -> Result<()> {
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
// Check if the directory exists
|
||||||
|
let path = Path::new(&self.base.path);
|
||||||
|
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
|
||||||
|
|
||||||
|
// Change to the project directory and checkout the specified revision
|
||||||
|
let output = Command::new("git")
|
||||||
|
.current_dir(&self.base.path)
|
||||||
|
.arg("checkout")
|
||||||
|
.arg(&self.base.revision)
|
||||||
|
.output()?;
|
||||||
|
anyhow::ensure!(
|
||||||
|
output.status.success(),
|
||||||
|
"Failed to checkout revision {}: {}",
|
||||||
|
self.base.revision,
|
||||||
|
String::from_utf8_lossy(&output.stderr),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
env_logger::init();
|
||||||
|
let http_client = Arc::new(ReqwestClient::new());
|
||||||
|
let app = Application::headless().with_http_client(http_client.clone());
|
||||||
|
|
||||||
|
app.run(move |cx| {
|
||||||
|
let app_state = crate::agent::init(cx);
|
||||||
|
let _agent = Agent::new(app_state, cx);
|
||||||
|
|
||||||
|
let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
||||||
|
|
||||||
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||||
|
registry.set_default_model(Some(model.clone()), cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
let model_provider_id = model.provider_id();
|
||||||
|
|
||||||
|
let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
|
||||||
|
|
||||||
|
cx.spawn(async move |_cx| {
|
||||||
|
authenticate.await.unwrap();
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
});
|
||||||
|
|
||||||
|
// let example =
|
||||||
|
// Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
||||||
|
// example.setup()?;
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue