Actually run the eval and fix a hang when retrieving outline (#28547)

Release Notes:

- Fixed a regression that could sometimes cause the agent to hang.
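
For reference, the new eval example below signals completion by resolving a oneshot channel when the thread reports that it has stopped, rather than combining separate "done streaming" and "all tools finished" signals as the old headless agent did. Here is a minimal, self-contained sketch of that pattern; the `ThreadEvent` and `StopReason` enums are hypothetical stand-ins, not Zed's actual types.

```rust
use futures::channel::oneshot;
use futures::executor::block_on;

// Hypothetical stand-ins for Zed's ThreadEvent / StopReason types.
enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

enum ThreadEvent {
    Stopped(Result<StopReason, String>),
    StreamedText(String),
}

fn main() {
    // Completion is signalled through a oneshot channel that is resolved at
    // most once, the first time a terminal event arrives.
    let (tx, rx) = oneshot::channel::<Result<(), String>>();
    let mut tx = Some(tx);

    // In the real code this closure is registered via cx.subscribe(&thread, ...).
    let mut on_event = move |event: ThreadEvent| match event {
        ThreadEvent::Stopped(Ok(StopReason::EndTurn)) => {
            if let Some(tx) = tx.take() {
                tx.send(Ok(())).ok();
            }
        }
        ThreadEvent::Stopped(Ok(StopReason::MaxTokens)) => {
            if let Some(tx) = tx.take() {
                tx.send(Err("exceeded maximum tokens".to_string())).ok();
            }
        }
        // Tool use is not terminal: keep going until the turn actually ends.
        ThreadEvent::Stopped(Ok(StopReason::ToolUse)) => {}
        ThreadEvent::Stopped(Err(error)) => {
            if let Some(tx) = tx.take() {
                tx.send(Err(error)).ok();
            }
        }
        ThreadEvent::StreamedText(chunk) => print!("{chunk}"),
    };

    // Simulate a short stream that ends with an end-of-turn stop.
    on_event(ThreadEvent::StreamedText("hello".to_string()));
    on_event(ThreadEvent::Stopped(Ok(StopReason::EndTurn)));

    // The eval's run() task awaits the receiver; here we just block for the demo.
    block_on(rx).expect("sender dropped").expect("eval failed");
}
```

The eval itself would presumably be run with `cargo run -p eval` from the repository root, assuming the package is named after its `crates/eval` directory.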

---------

Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Michael Sloan <mgsloan@gmail.com>
Antonio Scandurra, 2025-04-10 18:01:33 -06:00, committed by GitHub
parent c0262cf62f
commit 2440faf4b2
28 changed files with 642 additions and 1862 deletions


@@ -9,12 +9,13 @@ agent.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
assistant_settings.workspace = true
client.workspace = true
collections.workspace = true
context_server.workspace = true
dap.workspace = true
env_logger.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
language.workspace = true
@@ -27,7 +28,6 @@ release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
settings.workspace = true
smol.workspace = true
toml.workspace = true
workspace-hack.workspace = true


@@ -1,229 +0,0 @@
use ::agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
use anyhow::anyhow;
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use dap::DapRegistry;
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub struct Agent {
// pub thread: Entity<Thread>,
// pub project: Entity<Project>,
#[allow(dead_code)]
pub thread_store: Entity<ThreadStore>,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub done_tx: channel::Sender<anyhow::Result<()>>,
_subscription: Subscription,
}
impl Agent {
pub fn new(
app_state: Arc<AgentAppState>,
cx: &mut App,
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
let env = None;
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
Arc::new(DapRegistry::default()),
app_state.fs.clone(),
env,
cx,
);
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
let headless_thread = cx.new(move |cx| Self {
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
// thread,
// project,
thread_store,
tool_use_counts: HashMap::default(),
done_tx,
});
Ok((headless_thread, done_rx))
}
fn handle_thread_event(
&mut self,
thread: Entity<Thread>,
event: &ThreadEvent,
cx: &mut Context<Self>,
) {
match event {
ThreadEvent::ShowError(err) => self
.done_tx
.send_blocking(Err(anyhow!("{:?}", err)))
.unwrap(),
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
println!("Message: {}", message.to_string());
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
ThreadEvent::UsePendingTools { .. } => {}
ThreadEvent::ToolConfirmationNeeded => {
// Automatically approve all tools that need confirmation in headless mode
println!("Tool confirmation needed - automatically approving in headless mode");
// Get the tools needing confirmation
let tools_needing_confirmation: Vec<_> = thread
.read(cx)
.tools_needing_confirmation()
.cloned()
.collect();
// Run each tool that needs confirmation
for tool_use in tools_needing_confirmation {
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
thread.update(cx, |thread, cx| {
println!("Auto-approving tool: {}", tool_use.name);
// Create a request to send to the tool
let request = thread.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
// Run the tool
thread.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
tool,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(pending_tool_use) = pending_tool_use {
println!(
"Used tool {} with input: {}",
pending_tool_use.name, pending_tool_use.input
);
*self
.tool_use_counts
.entry(pending_tool_use.name.clone())
.or_insert(0) += 1;
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
}
_ => {}
}
}
}
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
Arc::new(AgentAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}


@@ -1,74 +1,22 @@
use agent::Agent;
use anyhow::Result;
use gpui::Application;
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use serde::Deserialize;
use std::{
fs,
path::{Path, PathBuf},
sync::Arc,
mod example;
use assistant_settings::AssistantSettings;
use client::{Client, UserStore};
pub(crate) use example::*;
use ::fs::RealFs;
use anyhow::anyhow;
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
mod agent;
#[derive(Debug, Deserialize)]
pub struct ExampleBase {
pub path: PathBuf,
pub revision: String,
}
#[derive(Debug)]
pub struct Example {
pub base: ExampleBase,
/// Content of the prompt.md file
pub prompt: String,
/// Content of the rubric.md file
pub rubric: String,
}
impl Example {
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
let base_path = dir_path.as_ref().join("base.toml");
let prompt_path = dir_path.as_ref().join("prompt.md");
let rubric_path = dir_path.as_ref().join("rubric.md");
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
base.path = base.path.canonicalize()?;
Ok(Example {
base,
prompt: fs::read_to_string(prompt_path)?,
rubric: fs::read_to_string(rubric_path)?,
})
}
/// Set up the example by checking out the specified Git revision
pub fn setup(&self) -> Result<()> {
use std::process::Command;
// Check if the directory exists
let path = Path::new(&self.base.path);
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
// Change to the project directory and checkout the specified revision
let output = Command::new("git")
.current_dir(&self.base.path)
.arg("checkout")
.arg(&self.base.revision)
.output()?;
anyhow::ensure!(
output.status.success(),
"Failed to checkout revision {}: {}",
self.base.revision,
String::from_utf8_lossy(&output.stderr),
);
Ok(())
}
}
use node_runtime::NodeRuntime;
use project::Project;
use prompt_store::PromptBuilder;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::sync::Arc;
fn main() {
env_logger::init();
@@ -76,10 +24,9 @@ fn main() {
let app = Application::headless().with_http_client(http_client.clone());
app.run(move |cx| {
let app_state = crate::agent::init(cx);
let _agent = Agent::new(app_state, cx);
let app_state = init(cx);
let model = agent::find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
@@ -87,15 +34,112 @@ fn main() {
let model_provider_id = model.provider_id();
let authenticate = agent::authenticate_model_provider(model_provider_id.clone(), cx);
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
cx.spawn(async move |_cx| {
cx.spawn(async move |cx| {
authenticate.await.unwrap();
})
.detach();
});
// let example =
// Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
// example.setup()?;
let example =
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
example.setup()?;
cx.update(|cx| example.run(model, app_state, cx))?.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
});
}
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
AssistantSettings::override_global(
AssistantSettings {
always_allow_tool_actions: true,
..AssistantSettings::get_global(cx).clone()
},
cx,
);
Arc::new(AgentAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}

crates/eval/src/example.rs (new file, 178 lines)

@@ -0,0 +1,178 @@
use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Result, anyhow};
use assistant_tool::ToolWorkingSet;
use dap::DapRegistry;
use futures::channel::oneshot;
use gpui::{App, Task};
use language_model::{LanguageModel, StopReason};
use project::Project;
use serde::Deserialize;
use std::process::Command;
use std::sync::Arc;
use std::{
fs,
path::{Path, PathBuf},
};
use crate::AgentAppState;
#[derive(Debug, Deserialize)]
pub struct ExampleBase {
pub path: PathBuf,
pub revision: String,
}
#[derive(Debug)]
pub struct Example {
pub base: ExampleBase,
/// Content of the prompt.md file
pub prompt: String,
/// Content of the rubric.md file
pub _rubric: String,
}
impl Example {
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
let base_path = dir_path.as_ref().join("base.toml");
let prompt_path = dir_path.as_ref().join("prompt.md");
let rubric_path = dir_path.as_ref().join("rubric.md");
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
base.path = base.path.canonicalize()?;
Ok(Example {
base,
prompt: fs::read_to_string(prompt_path)?,
_rubric: fs::read_to_string(rubric_path)?,
})
}
/// Set up the example by checking out the specified Git revision
pub fn setup(&self) -> Result<()> {
// Check if the directory exists
let path = Path::new(&self.base.path);
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
// Change to the project directory and checkout the specified revision
let output = Command::new("git")
.current_dir(&self.base.path)
.arg("checkout")
.arg(&self.base.revision)
.output()?;
anyhow::ensure!(
output.status.success(),
"Failed to checkout revision {}: {}",
self.base.revision,
String::from_utf8_lossy(&output.stderr),
);
Ok(())
}
pub fn run(
self,
model: Arc<dyn LanguageModel>,
app_state: Arc<AgentAppState>,
cx: &mut App,
) -> Task<Result<()>> {
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
Arc::new(DapRegistry::default()),
app_state.fs.clone(),
None,
cx,
);
let worktree = project.update(cx, |project, cx| {
project.create_worktree(self.base.path, true, cx)
});
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
println!("USER:");
println!("{}", self.prompt);
println!("ASSISTANT:");
cx.spawn(async move |cx| {
worktree.await?;
let thread_store = thread_store.await;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let _subscription =
cx.subscribe(
&thread,
move |thread, event: &ThreadEvent, cx| match event {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn) => {
if let Some(tx) = tx.take() {
tx.send(Ok(())).ok();
}
}
Ok(StopReason::MaxTokens) => {
if let Some(tx) = tx.take() {
tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
}
}
Ok(StopReason::ToolUse) => {}
Err(error) => {
if let Some(tx) = tx.take() {
tx.send(Err(anyhow!(error.clone()))).ok();
}
}
},
ThreadEvent::ShowError(thread_error) => {
if let Some(tx) = tx.take() {
tx.send(Err(anyhow!(thread_error.clone()))).ok();
}
}
ThreadEvent::StreamedAssistantText(_, chunk) => {
print!("{}", chunk);
}
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
print!("{}", chunk);
}
ThreadEvent::UsePendingTools { tool_uses } => {
println!("\n\nUSING TOOLS:");
for tool_use in tool_uses {
println!("{}: {}", tool_use.name, tool_use.input);
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(tool_use) = pending_tool_use {
println!("\nTOOL FINISHED: {}", tool_use.name);
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("\n{}\n", tool_result.content);
}
}
_ => {}
},
)?;
thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.prompt.clone(), context, None, cx);
thread.send_to_model(model, RequestKind::Chat, cx);
})?;
rx.await??;
Ok(())
})
}
}