Add judge to new eval + provide LSP diagnostics (#28713)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <antonio@zed.dev> Co-authored-by: agus <agus@zed.dev>
This commit is contained in:
parent
2603f36737
commit
6b80eb556c
11 changed files with 838 additions and 84 deletions
|
@ -1,32 +1,75 @@
|
|||
mod example;
|
||||
|
||||
use assistant_settings::AssistantSettings;
|
||||
use client::{Client, UserStore};
|
||||
use client::{Client, ProxySettings, UserStore};
|
||||
pub(crate) use example::*;
|
||||
|
||||
use ::fs::RealFs;
|
||||
use anyhow::anyhow;
|
||||
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::Parser;
|
||||
use extension::ExtensionHostProxy;
|
||||
use futures::future;
|
||||
use gpui::http_client::{Uri, read_proxy_from_env};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use node_runtime::NodeRuntime;
|
||||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use project::Project;
|
||||
use project::project_settings::ProjectSettings;
|
||||
use prompt_store::PromptBuilder;
|
||||
use release_channel::AppVersion;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub const RUNS_DIR: &str = "./crates/eval/runs";
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "eval", disable_version_flag = true)]
|
||||
struct Args {
|
||||
/// Runs all examples that contain these substrings. If unspecified, all examples are run.
|
||||
#[arg(value_name = "EXAMPLE_SUBSTRING")]
|
||||
examples: Vec<String>,
|
||||
/// Model to use (default: "claude-3-7-sonnet-latest")
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
model: String,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
||||
let args = Args::parse();
|
||||
let all_available_examples = list_all_examples().unwrap();
|
||||
let example_paths = all_available_examples
|
||||
.iter()
|
||||
.filter_map(|example_path| {
|
||||
let name = example_path.file_name()?.to_string_lossy();
|
||||
if args.examples.is_empty()
|
||||
|| args
|
||||
.examples
|
||||
.iter()
|
||||
.any(|name_substring| name.contains(name_substring))
|
||||
{
|
||||
Some(example_path.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client.clone());
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = init(cx);
|
||||
|
||||
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
||||
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(Some(model.clone()), cx);
|
||||
|
@ -39,17 +82,142 @@ fn main() {
|
|||
cx.spawn(async move |cx| {
|
||||
authenticate.await.unwrap();
|
||||
|
||||
let example =
|
||||
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
||||
example.setup()?;
|
||||
cx.update(|cx| example.run(model, app_state, cx))?.await?;
|
||||
std::fs::create_dir_all(REPOS_DIR)?;
|
||||
std::fs::create_dir_all(WORKTREES_DIR)?;
|
||||
|
||||
anyhow::Ok(())
|
||||
let run_dir = Path::new(RUNS_DIR).join(format!(
|
||||
"{}",
|
||||
chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
|
||||
));
|
||||
std::fs::create_dir_all(&run_dir)?;
|
||||
|
||||
let mut examples = Vec::new();
|
||||
for example_path in example_paths {
|
||||
let example = Example::load_from_directory(&example_path, &run_dir)?;
|
||||
examples.push((example_path, example));
|
||||
}
|
||||
let mut repo_urls = HashSet::new();
|
||||
|
||||
let mut clone_tasks = Vec::new();
|
||||
|
||||
for (_, example) in examples.iter() {
|
||||
let repo_url = example.base.url.clone();
|
||||
if repo_urls.insert(repo_url.clone()) {
|
||||
let repo_path = repo_path_for_url(&repo_url);
|
||||
|
||||
if !repo_path.join(".git").is_dir() {
|
||||
println!("Cloning: {}", repo_url);
|
||||
|
||||
let git_task = cx.spawn(async move |_cx| {
|
||||
std::fs::create_dir_all(&repo_path)?;
|
||||
run_git(&repo_path, &["init"]).await?;
|
||||
run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
|
||||
});
|
||||
|
||||
clone_tasks.push(git_task);
|
||||
} else {
|
||||
println!("Already cloned: {}", repo_url);
|
||||
|
||||
let actual_origin =
|
||||
run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
|
||||
if actual_origin != repo_url {
|
||||
return Err(anyhow!(
|
||||
"remote origin {} does not match expected origin {}",
|
||||
actual_origin,
|
||||
repo_url,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
future::join_all(clone_tasks).await;
|
||||
|
||||
let tasks = examples
|
||||
.into_iter()
|
||||
.map(|(example_path, example)| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
(
|
||||
example_path,
|
||||
run_example(example, model, app_state, cx).await,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
|
||||
|
||||
println!("\n\n");
|
||||
println!("========================================");
|
||||
println!(" EVAL RESULTS ");
|
||||
println!("========================================");
|
||||
println!("");
|
||||
|
||||
let mut judge_scores = Vec::new();
|
||||
|
||||
for (example_path, result) in results {
|
||||
let example_name = example_path.file_name().unwrap().to_string_lossy();
|
||||
match result {
|
||||
Err(err) => {
|
||||
println!("💥 {:<30}: {:?}", example_name, err);
|
||||
}
|
||||
Ok(judge_output) => {
|
||||
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
|
||||
|
||||
println!(
|
||||
"{} {:<30}: {}",
|
||||
SCORES[judge_output.score.min(5) as usize],
|
||||
example_name,
|
||||
judge_output.score,
|
||||
);
|
||||
judge_scores.push(judge_output.score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let score_count = judge_scores.len();
|
||||
let average_score = judge_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (score_count as f32);
|
||||
println!("\nAverage score: {average_score}");
|
||||
|
||||
cx.update(|cx| cx.quit())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_example(
|
||||
mut example: Example,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
app_state: Arc<AgentAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
example.setup().await?;
|
||||
cx.update(|cx| example.run(model.clone(), app_state, cx))?
|
||||
.await?;
|
||||
let diff = example.repository_diff().await?;
|
||||
example.judge(model, diff, cx).await
|
||||
}
|
||||
|
||||
fn list_all_examples() -> Result<Vec<PathBuf>> {
|
||||
let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
|
||||
let entries = std::fs::read_dir(path).unwrap();
|
||||
let mut result_paths = Vec::new();
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
result_paths.push(path);
|
||||
}
|
||||
}
|
||||
Ok(result_paths)
|
||||
}
|
||||
|
||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||
pub struct AgentAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
|
@ -72,6 +240,27 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
.unwrap();
|
||||
cx.set_global(settings_store);
|
||||
client::init_settings(cx);
|
||||
|
||||
// Set User-Agent so we can download language servers from GitHub
|
||||
let user_agent = format!(
|
||||
"Zed/{} ({}; {})",
|
||||
AppVersion::global(cx),
|
||||
std::env::consts::OS,
|
||||
std::env::consts::ARCH
|
||||
);
|
||||
let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
|
||||
let proxy_url = proxy_str
|
||||
.as_ref()
|
||||
.and_then(|input| input.parse::<Uri>().ok())
|
||||
.or_else(read_proxy_from_env);
|
||||
let http = {
|
||||
let _guard = Tokio::handle(cx).enter();
|
||||
|
||||
ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
|
||||
.expect("could not start HTTP client")
|
||||
};
|
||||
cx.set_http_client(Arc::new(http));
|
||||
|
||||
Project::init_settings(cx);
|
||||
|
||||
let client = Client::production(cx);
|
||||
|
@ -83,13 +272,47 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
cx.background_executor().clone(),
|
||||
));
|
||||
|
||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||
let mut languages = LanguageRegistry::new(cx.background_executor().clone());
|
||||
languages.set_language_server_download_dir(paths::languages_dir().clone());
|
||||
let languages = Arc::new(languages);
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
extension::init(cx);
|
||||
|
||||
let (tx, rx) = async_watch::channel(None);
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
let settings = &ProjectSettings::get_global(cx).node;
|
||||
let options = NodeBinaryOptions {
|
||||
allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
|
||||
allow_binary_download: true,
|
||||
use_paths: settings.path.as_ref().map(|node_path| {
|
||||
let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
|
||||
let npm_path = settings
|
||||
.npm_path
|
||||
.as_ref()
|
||||
.map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
|
||||
(
|
||||
node_path.clone(),
|
||||
npm_path.unwrap_or_else(|| {
|
||||
let base_path = PathBuf::new();
|
||||
node_path.parent().unwrap_or(&base_path).join("npm")
|
||||
}),
|
||||
)
|
||||
}),
|
||||
};
|
||||
tx.send(Some(options)).log_err();
|
||||
})
|
||||
.detach();
|
||||
let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
|
||||
|
||||
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||
|
||||
language::init(cx);
|
||||
language_extension::init(extension_host_proxy.clone(), languages.clone());
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||
assistant_tools::init(client.http_client().clone(), cx);
|
||||
context_server::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
|
@ -109,7 +332,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime: NodeRuntime::unavailable(),
|
||||
node_runtime,
|
||||
prompt_builder,
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue