Add judge to new eval + provide LSP diagnostics (#28713)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <antonio@zed.dev> Co-authored-by: agus <agus@zed.dev>
This commit is contained in:
parent
2603f36737
commit
6b80eb556c
11 changed files with 838 additions and 84 deletions
11
Cargo.lock
generated
11
Cargo.lock
generated
|
@ -4878,25 +4878,36 @@ dependencies = [
|
||||||
"assistant_settings",
|
"assistant_settings",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
"assistant_tools",
|
"assistant_tools",
|
||||||
|
"async-watch",
|
||||||
|
"chrono",
|
||||||
|
"clap",
|
||||||
"client",
|
"client",
|
||||||
"context_server",
|
"context_server",
|
||||||
"dap",
|
"dap",
|
||||||
"env_logger 0.11.8",
|
"env_logger 0.11.8",
|
||||||
|
"extension",
|
||||||
"fs",
|
"fs",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
"gpui_tokio",
|
"gpui_tokio",
|
||||||
|
"handlebars 4.5.0",
|
||||||
"language",
|
"language",
|
||||||
|
"language_extension",
|
||||||
"language_model",
|
"language_model",
|
||||||
"language_models",
|
"language_models",
|
||||||
|
"languages",
|
||||||
"node_runtime",
|
"node_runtime",
|
||||||
|
"paths",
|
||||||
"project",
|
"project",
|
||||||
"prompt_store",
|
"prompt_store",
|
||||||
"release_channel",
|
"release_channel",
|
||||||
"reqwest_client",
|
"reqwest_client",
|
||||||
"serde",
|
"serde",
|
||||||
"settings",
|
"settings",
|
||||||
|
"shellexpand 2.1.2",
|
||||||
"toml 0.8.20",
|
"toml 0.8.20",
|
||||||
|
"unindent",
|
||||||
|
"util",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -827,7 +827,7 @@ impl Thread {
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
initial_project_snapshot,
|
initial_project_snapshot,
|
||||||
cumulative_token_usage: this.cumulative_token_usage.clone(),
|
cumulative_token_usage: this.cumulative_token_usage,
|
||||||
detailed_summary_state: this.detailed_summary_state.clone(),
|
detailed_summary_state: this.detailed_summary_state.clone(),
|
||||||
exceeded_window_error: this.exceeded_window_error.clone(),
|
exceeded_window_error: this.exceeded_window_error.clone(),
|
||||||
})
|
})
|
||||||
|
@ -1016,7 +1016,7 @@ impl Thread {
|
||||||
let task = cx.spawn(async move |thread, cx| {
|
let task = cx.spawn(async move |thread, cx| {
|
||||||
let stream = model.stream_completion(request, &cx);
|
let stream = model.stream_completion(request, &cx);
|
||||||
let initial_token_usage =
|
let initial_token_usage =
|
||||||
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
|
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
let mut events = stream.await?;
|
let mut events = stream.await?;
|
||||||
let mut stop_reason = StopReason::EndTurn;
|
let mut stop_reason = StopReason::EndTurn;
|
||||||
|
@ -1038,9 +1038,9 @@ impl Thread {
|
||||||
stop_reason = reason;
|
stop_reason = reason;
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
|
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
|
||||||
thread.cumulative_token_usage =
|
thread.cumulative_token_usage = thread.cumulative_token_usage
|
||||||
thread.cumulative_token_usage.clone() + token_usage.clone()
|
+ token_usage
|
||||||
- current_token_usage.clone();
|
- current_token_usage;
|
||||||
current_token_usage = token_usage;
|
current_token_usage = token_usage;
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::Text(chunk) => {
|
LanguageModelCompletionEvent::Text(chunk) => {
|
||||||
|
@ -1183,7 +1183,7 @@ impl Thread {
|
||||||
thread.auto_capture_telemetry(cx);
|
thread.auto_capture_telemetry(cx);
|
||||||
|
|
||||||
if let Ok(initial_usage) = initial_token_usage {
|
if let Ok(initial_usage) = initial_token_usage {
|
||||||
let usage = thread.cumulative_token_usage.clone() - initial_usage;
|
let usage = thread.cumulative_token_usage - initial_usage;
|
||||||
|
|
||||||
telemetry::event!(
|
telemetry::event!(
|
||||||
"Assistant Thread Completion",
|
"Assistant Thread Completion",
|
||||||
|
@ -1862,6 +1862,10 @@ impl Thread {
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||||
|
self.cumulative_token_usage
|
||||||
|
}
|
||||||
|
|
||||||
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
|
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
let Some(model) = model_registry.default_model() else {
|
let Some(model) = model_registry.default_model() else {
|
||||||
|
|
3
crates/eval/.gitignore
vendored
Normal file
3
crates/eval/.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
repos/
|
||||||
|
worktrees/
|
||||||
|
runs/
|
|
@ -7,28 +7,39 @@ edition.workspace = true
|
||||||
[dependencies]
|
[dependencies]
|
||||||
agent.workspace = true
|
agent.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
async-watch.workspace = true
|
||||||
|
assistant_settings.workspace = true
|
||||||
assistant_tool.workspace = true
|
assistant_tool.workspace = true
|
||||||
assistant_tools.workspace = true
|
assistant_tools.workspace = true
|
||||||
assistant_settings.workspace = true
|
chrono.workspace = true
|
||||||
|
clap.workspace = true
|
||||||
client.workspace = true
|
client.workspace = true
|
||||||
context_server.workspace = true
|
context_server.workspace = true
|
||||||
dap.workspace = true
|
dap.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
extension.workspace = true
|
||||||
fs.workspace = true
|
fs.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
gpui_tokio.workspace = true
|
gpui_tokio.workspace = true
|
||||||
|
handlebars.workspace = true
|
||||||
language.workspace = true
|
language.workspace = true
|
||||||
|
language_extension.workspace = true
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
language_models.workspace = true
|
language_models.workspace = true
|
||||||
|
languages.workspace = true
|
||||||
node_runtime.workspace = true
|
node_runtime.workspace = true
|
||||||
|
paths.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
prompt_store.workspace = true
|
prompt_store.workspace = true
|
||||||
release_channel.workspace = true
|
release_channel.workspace = true
|
||||||
reqwest_client.workspace = true
|
reqwest_client.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
|
shellexpand.workspace = true
|
||||||
toml.workspace = true
|
toml.workspace = true
|
||||||
|
unindent.workspace = true
|
||||||
|
util.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
path = "../zed_worktree"
|
url = "https://github.com/zed-industries/zed.git"
|
||||||
revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
|
revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
|
||||||
|
language_extension = "rs"
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
1. The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field that is the same as the string we were returning before, and a new `card` field that contains a view for the card
|
||||||
|
2. The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged.
|
|
@ -1,32 +1,75 @@
|
||||||
mod example;
|
mod example;
|
||||||
|
|
||||||
use assistant_settings::AssistantSettings;
|
use assistant_settings::AssistantSettings;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, ProxySettings, UserStore};
|
||||||
pub(crate) use example::*;
|
pub(crate) use example::*;
|
||||||
|
|
||||||
use ::fs::RealFs;
|
use ::fs::RealFs;
|
||||||
use anyhow::anyhow;
|
use anyhow::{Result, anyhow};
|
||||||
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
|
use clap::Parser;
|
||||||
|
use extension::ExtensionHostProxy;
|
||||||
|
use futures::future;
|
||||||
|
use gpui::http_client::{Uri, read_proxy_from_env};
|
||||||
|
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
|
||||||
|
use gpui_tokio::Tokio;
|
||||||
use language::LanguageRegistry;
|
use language::LanguageRegistry;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||||
};
|
};
|
||||||
use node_runtime::NodeRuntime;
|
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
use project::project_settings::ProjectSettings;
|
||||||
use prompt_store::PromptBuilder;
|
use prompt_store::PromptBuilder;
|
||||||
|
use release_channel::AppVersion;
|
||||||
use reqwest_client::ReqwestClient;
|
use reqwest_client::ReqwestClient;
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use util::ResultExt as _;
|
||||||
|
|
||||||
|
pub const RUNS_DIR: &str = "./crates/eval/runs";
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(name = "eval", disable_version_flag = true)]
|
||||||
|
struct Args {
|
||||||
|
/// Runs all examples that contain these substrings. If unspecified, all examples are run.
|
||||||
|
#[arg(value_name = "EXAMPLE_SUBSTRING")]
|
||||||
|
examples: Vec<String>,
|
||||||
|
/// Model to use (default: "claude-3-7-sonnet-latest")
|
||||||
|
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||||
|
model: String,
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let all_available_examples = list_all_examples().unwrap();
|
||||||
|
let example_paths = all_available_examples
|
||||||
|
.iter()
|
||||||
|
.filter_map(|example_path| {
|
||||||
|
let name = example_path.file_name()?.to_string_lossy();
|
||||||
|
if args.examples.is_empty()
|
||||||
|
|| args
|
||||||
|
.examples
|
||||||
|
.iter()
|
||||||
|
.any(|name_substring| name.contains(name_substring))
|
||||||
|
{
|
||||||
|
Some(example_path.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let http_client = Arc::new(ReqwestClient::new());
|
let http_client = Arc::new(ReqwestClient::new());
|
||||||
let app = Application::headless().with_http_client(http_client.clone());
|
let app = Application::headless().with_http_client(http_client.clone());
|
||||||
|
|
||||||
app.run(move |cx| {
|
app.run(move |cx| {
|
||||||
let app_state = init(cx);
|
let app_state = init(cx);
|
||||||
|
|
||||||
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
|
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
|
||||||
|
|
||||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||||
registry.set_default_model(Some(model.clone()), cx);
|
registry.set_default_model(Some(model.clone()), cx);
|
||||||
|
@ -39,17 +82,142 @@ fn main() {
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
authenticate.await.unwrap();
|
authenticate.await.unwrap();
|
||||||
|
|
||||||
let example =
|
std::fs::create_dir_all(REPOS_DIR)?;
|
||||||
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
|
std::fs::create_dir_all(WORKTREES_DIR)?;
|
||||||
example.setup()?;
|
|
||||||
cx.update(|cx| example.run(model, app_state, cx))?.await?;
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
let run_dir = Path::new(RUNS_DIR).join(format!(
|
||||||
|
"{}",
|
||||||
|
chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
|
||||||
|
));
|
||||||
|
std::fs::create_dir_all(&run_dir)?;
|
||||||
|
|
||||||
|
let mut examples = Vec::new();
|
||||||
|
for example_path in example_paths {
|
||||||
|
let example = Example::load_from_directory(&example_path, &run_dir)?;
|
||||||
|
examples.push((example_path, example));
|
||||||
|
}
|
||||||
|
let mut repo_urls = HashSet::new();
|
||||||
|
|
||||||
|
let mut clone_tasks = Vec::new();
|
||||||
|
|
||||||
|
for (_, example) in examples.iter() {
|
||||||
|
let repo_url = example.base.url.clone();
|
||||||
|
if repo_urls.insert(repo_url.clone()) {
|
||||||
|
let repo_path = repo_path_for_url(&repo_url);
|
||||||
|
|
||||||
|
if !repo_path.join(".git").is_dir() {
|
||||||
|
println!("Cloning: {}", repo_url);
|
||||||
|
|
||||||
|
let git_task = cx.spawn(async move |_cx| {
|
||||||
|
std::fs::create_dir_all(&repo_path)?;
|
||||||
|
run_git(&repo_path, &["init"]).await?;
|
||||||
|
run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
|
||||||
|
});
|
||||||
|
|
||||||
|
clone_tasks.push(git_task);
|
||||||
|
} else {
|
||||||
|
println!("Already cloned: {}", repo_url);
|
||||||
|
|
||||||
|
let actual_origin =
|
||||||
|
run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
|
||||||
|
if actual_origin != repo_url {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"remote origin {} does not match expected origin {}",
|
||||||
|
actual_origin,
|
||||||
|
repo_url,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
future::join_all(clone_tasks).await;
|
||||||
|
|
||||||
|
let tasks = examples
|
||||||
|
.into_iter()
|
||||||
|
.map(|(example_path, example)| {
|
||||||
|
let app_state = app_state.clone();
|
||||||
|
let model = model.clone();
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
(
|
||||||
|
example_path,
|
||||||
|
run_example(example, model, app_state, cx).await,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
|
||||||
|
|
||||||
|
println!("\n\n");
|
||||||
|
println!("========================================");
|
||||||
|
println!(" EVAL RESULTS ");
|
||||||
|
println!("========================================");
|
||||||
|
println!("");
|
||||||
|
|
||||||
|
let mut judge_scores = Vec::new();
|
||||||
|
|
||||||
|
for (example_path, result) in results {
|
||||||
|
let example_name = example_path.file_name().unwrap().to_string_lossy();
|
||||||
|
match result {
|
||||||
|
Err(err) => {
|
||||||
|
println!("💥 {:<30}: {:?}", example_name, err);
|
||||||
|
}
|
||||||
|
Ok(judge_output) => {
|
||||||
|
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"{} {:<30}: {}",
|
||||||
|
SCORES[judge_output.score.min(5) as usize],
|
||||||
|
example_name,
|
||||||
|
judge_output.score,
|
||||||
|
);
|
||||||
|
judge_scores.push(judge_output.score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let score_count = judge_scores.len();
|
||||||
|
let average_score = judge_scores
|
||||||
|
.into_iter()
|
||||||
|
.map(|score| score as f32)
|
||||||
|
.sum::<f32>()
|
||||||
|
/ (score_count as f32);
|
||||||
|
println!("\nAverage score: {average_score}");
|
||||||
|
|
||||||
|
cx.update(|cx| cx.quit())
|
||||||
})
|
})
|
||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn run_example(
|
||||||
|
mut example: Example,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
app_state: Arc<AgentAppState>,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<JudgeOutput> {
|
||||||
|
example.setup().await?;
|
||||||
|
cx.update(|cx| example.run(model.clone(), app_state, cx))?
|
||||||
|
.await?;
|
||||||
|
let diff = example.repository_diff().await?;
|
||||||
|
example.judge(model, diff, cx).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_all_examples() -> Result<Vec<PathBuf>> {
|
||||||
|
let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
|
||||||
|
let entries = std::fs::read_dir(path).unwrap();
|
||||||
|
let mut result_paths = Vec::new();
|
||||||
|
for entry in entries {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
if path.is_dir() {
|
||||||
|
result_paths.push(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result_paths)
|
||||||
|
}
|
||||||
|
|
||||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||||
pub struct AgentAppState {
|
pub struct AgentAppState {
|
||||||
pub languages: Arc<LanguageRegistry>,
|
pub languages: Arc<LanguageRegistry>,
|
||||||
|
@ -72,6 +240,27 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
client::init_settings(cx);
|
client::init_settings(cx);
|
||||||
|
|
||||||
|
// Set User-Agent so we can download language servers from GitHub
|
||||||
|
let user_agent = format!(
|
||||||
|
"Zed/{} ({}; {})",
|
||||||
|
AppVersion::global(cx),
|
||||||
|
std::env::consts::OS,
|
||||||
|
std::env::consts::ARCH
|
||||||
|
);
|
||||||
|
let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
|
||||||
|
let proxy_url = proxy_str
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|input| input.parse::<Uri>().ok())
|
||||||
|
.or_else(read_proxy_from_env);
|
||||||
|
let http = {
|
||||||
|
let _guard = Tokio::handle(cx).enter();
|
||||||
|
|
||||||
|
ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
|
||||||
|
.expect("could not start HTTP client")
|
||||||
|
};
|
||||||
|
cx.set_http_client(Arc::new(http));
|
||||||
|
|
||||||
Project::init_settings(cx);
|
Project::init_settings(cx);
|
||||||
|
|
||||||
let client = Client::production(cx);
|
let client = Client::production(cx);
|
||||||
|
@ -83,13 +272,47 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||||
cx.background_executor().clone(),
|
cx.background_executor().clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
let mut languages = LanguageRegistry::new(cx.background_executor().clone());
|
||||||
|
languages.set_language_server_download_dir(paths::languages_dir().clone());
|
||||||
|
let languages = Arc::new(languages);
|
||||||
|
|
||||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||||
|
|
||||||
|
extension::init(cx);
|
||||||
|
|
||||||
|
let (tx, rx) = async_watch::channel(None);
|
||||||
|
cx.observe_global::<SettingsStore>(move |cx| {
|
||||||
|
let settings = &ProjectSettings::get_global(cx).node;
|
||||||
|
let options = NodeBinaryOptions {
|
||||||
|
allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
|
||||||
|
allow_binary_download: true,
|
||||||
|
use_paths: settings.path.as_ref().map(|node_path| {
|
||||||
|
let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
|
||||||
|
let npm_path = settings
|
||||||
|
.npm_path
|
||||||
|
.as_ref()
|
||||||
|
.map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
|
||||||
|
(
|
||||||
|
node_path.clone(),
|
||||||
|
npm_path.unwrap_or_else(|| {
|
||||||
|
let base_path = PathBuf::new();
|
||||||
|
node_path.parent().unwrap_or(&base_path).join("npm")
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
tx.send(Some(options)).log_err();
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
|
||||||
|
|
||||||
|
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||||
|
|
||||||
language::init(cx);
|
language::init(cx);
|
||||||
|
language_extension::init(extension_host_proxy.clone(), languages.clone());
|
||||||
language_model::init(client.clone(), cx);
|
language_model::init(client.clone(), cx);
|
||||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||||
|
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||||
assistant_tools::init(client.http_client().clone(), cx);
|
assistant_tools::init(client.http_client().clone(), cx);
|
||||||
context_server::init(cx);
|
context_server::init(cx);
|
||||||
let stdout_is_a_pty = false;
|
let stdout_is_a_pty = false;
|
||||||
|
@ -109,7 +332,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||||
client,
|
client,
|
||||||
user_store,
|
user_store,
|
||||||
fs,
|
fs,
|
||||||
node_runtime: NodeRuntime::unavailable(),
|
node_runtime,
|
||||||
prompt_builder,
|
prompt_builder,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,83 +1,161 @@
|
||||||
use agent::{RequestKind, ThreadEvent, ThreadStore};
|
use agent::{RequestKind, ThreadEvent, ThreadStore};
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::ToolWorkingSet;
|
use assistant_tool::ToolWorkingSet;
|
||||||
|
use client::proto::LspWorkProgress;
|
||||||
use dap::DapRegistry;
|
use dap::DapRegistry;
|
||||||
use futures::channel::oneshot;
|
use futures::channel::{mpsc, oneshot};
|
||||||
use gpui::{App, Task};
|
use futures::{FutureExt, StreamExt as _};
|
||||||
use language_model::{LanguageModel, StopReason};
|
use gpui::{App, AsyncApp, Entity, Task};
|
||||||
use project::Project;
|
use handlebars::Handlebars;
|
||||||
use serde::Deserialize;
|
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||||
use std::process::Command;
|
use language_model::{
|
||||||
use std::sync::Arc;
|
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||||
|
StopReason, TokenUsage,
|
||||||
|
};
|
||||||
|
use project::{LspStore, Project, ProjectPath};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt::Write as _;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::Write as _;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::time::Duration;
|
||||||
use std::{
|
use std::{
|
||||||
fs,
|
fs,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
|
use unindent::Unindent as _;
|
||||||
|
use util::ResultExt as _;
|
||||||
|
use util::command::new_smol_command;
|
||||||
|
use util::serde::default_true;
|
||||||
|
|
||||||
use crate::AgentAppState;
|
use crate::AgentAppState;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
|
||||||
|
pub const REPOS_DIR: &str = "./crates/eval/repos";
|
||||||
|
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
pub struct ExampleBase {
|
pub struct ExampleBase {
|
||||||
pub path: PathBuf,
|
pub url: String,
|
||||||
pub revision: String,
|
pub revision: String,
|
||||||
|
pub language_extension: Option<String>,
|
||||||
|
pub insert_id: Option<String>,
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub require_lsp: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Example {
|
pub struct Example {
|
||||||
|
pub name: String,
|
||||||
|
/// Content of `base.toml`
|
||||||
pub base: ExampleBase,
|
pub base: ExampleBase,
|
||||||
|
/// Content of `prompt.md`
|
||||||
/// Content of the prompt.md file
|
|
||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
|
/// Content of `criteria.md`
|
||||||
|
pub criteria: String,
|
||||||
|
/// Markdown log file to append to
|
||||||
|
pub log_file: Arc<Mutex<File>>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Content of the rubric.md file
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub _rubric: String,
|
pub struct RunOutput {
|
||||||
|
pub repository_diff: String,
|
||||||
|
pub diagnostics: String,
|
||||||
|
pub response_count: usize,
|
||||||
|
pub token_usage: TokenUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct JudgeInput {
|
||||||
|
pub repository_diff: String,
|
||||||
|
pub criteria: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct JudgeOutput {
|
||||||
|
pub analysis: String,
|
||||||
|
pub score: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Example {
|
impl Example {
|
||||||
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
|
/// Load an example from a directory containing base.toml, prompt.md, and criteria.md
|
||||||
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
|
pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
|
||||||
let base_path = dir_path.as_ref().join("base.toml");
|
let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
|
||||||
let prompt_path = dir_path.as_ref().join("prompt.md");
|
let base_path = dir_path.join("base.toml");
|
||||||
let rubric_path = dir_path.as_ref().join("rubric.md");
|
let prompt_path = dir_path.join("prompt.md");
|
||||||
|
let criteria_path = dir_path.join("criteria.md");
|
||||||
|
|
||||||
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
|
let log_file_path = run_dir.join(format!(
|
||||||
base.path = base.path.canonicalize()?;
|
"{}.md",
|
||||||
|
dir_path.file_name().unwrap().to_str().unwrap()
|
||||||
|
));
|
||||||
|
let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
|
||||||
|
println!("{}> Logging to {:?}", name, log_file_path);
|
||||||
|
|
||||||
Ok(Example {
|
Ok(Example {
|
||||||
base,
|
name,
|
||||||
prompt: fs::read_to_string(prompt_path)?,
|
base: toml::from_str(&fs::read_to_string(&base_path)?)?,
|
||||||
_rubric: fs::read_to_string(rubric_path)?,
|
prompt: fs::read_to_string(prompt_path.clone())?,
|
||||||
|
criteria: fs::read_to_string(criteria_path.clone())?,
|
||||||
|
log_file,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set up the example by checking out the specified Git revision
|
pub fn worktree_path(&self) -> PathBuf {
|
||||||
pub fn setup(&self) -> Result<()> {
|
Path::new(WORKTREES_DIR)
|
||||||
// Check if the directory exists
|
.canonicalize()
|
||||||
let path = Path::new(&self.base.path);
|
.context(format!("No such directory {WORKTREES_DIR}"))
|
||||||
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
|
.unwrap()
|
||||||
|
.join(&self.name)
|
||||||
|
}
|
||||||
|
|
||||||
// Change to the project directory and checkout the specified revision
|
/// Set up the example by checking out the specified Git revision
|
||||||
let output = Command::new("git")
|
pub async fn setup(&self) -> Result<()> {
|
||||||
.current_dir(&self.base.path)
|
let repo_path = repo_path_for_url(&self.base.url);
|
||||||
.arg("checkout")
|
|
||||||
.arg(&self.base.revision)
|
run_git(
|
||||||
.output()?;
|
&repo_path,
|
||||||
anyhow::ensure!(
|
&["fetch", "--depth", "1", "origin", &self.base.revision],
|
||||||
output.status.success(),
|
)
|
||||||
"Failed to checkout revision {}: {}",
|
.await?;
|
||||||
self.base.revision,
|
|
||||||
String::from_utf8_lossy(&output.stderr),
|
let worktree_path = self.worktree_path();
|
||||||
);
|
|
||||||
|
if worktree_path.is_dir() {
|
||||||
|
println!("{}> Resetting existing worktree", self.name);
|
||||||
|
|
||||||
|
// TODO: consider including "-x" to remove ignored files. The downside of this is that
|
||||||
|
// it will also remove build artifacts, and so prevent incremental reuse there.
|
||||||
|
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||||
|
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||||
|
run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
|
||||||
|
} else {
|
||||||
|
println!("{}> Creating worktree", self.name);
|
||||||
|
|
||||||
|
let worktree_path_string = worktree_path.to_string_lossy().to_string();
|
||||||
|
|
||||||
|
run_git(
|
||||||
|
&repo_path,
|
||||||
|
&[
|
||||||
|
"worktree",
|
||||||
|
"add",
|
||||||
|
"-f",
|
||||||
|
&worktree_path_string,
|
||||||
|
&self.base.revision,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(
|
pub fn run(
|
||||||
self,
|
&self,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
app_state: Arc<AgentAppState>,
|
app_state: Arc<AgentAppState>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<RunOutput>> {
|
||||||
let project = Project::local(
|
let project = Project::local(
|
||||||
app_state.client.clone(),
|
app_state.client.clone(),
|
||||||
app_state.node_runtime.clone(),
|
app_state.node_runtime.clone(),
|
||||||
|
@ -89,30 +167,119 @@ impl Example {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let worktree_path = self.worktree_path();
|
||||||
let worktree = project.update(cx, |project, cx| {
|
let worktree = project.update(cx, |project, cx| {
|
||||||
project.create_worktree(self.base.path, true, cx)
|
project.create_worktree(&worktree_path, true, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
let tools = Arc::new(ToolWorkingSet::default());
|
let tools = Arc::new(ToolWorkingSet::default());
|
||||||
let thread_store =
|
let thread_store =
|
||||||
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
|
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
|
||||||
|
let this = self.clone();
|
||||||
|
|
||||||
println!("USER:");
|
|
||||||
println!("{}", self.prompt);
|
|
||||||
println!("ASSISTANT:");
|
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
worktree.await?;
|
let worktree = worktree.await?;
|
||||||
|
|
||||||
|
// Wait for worktree scan to finish before choosing a file to open.
|
||||||
|
worktree
|
||||||
|
.update(cx, |worktree, _cx| {
|
||||||
|
worktree.as_local().unwrap().scan_complete()
|
||||||
|
})?
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let lsp_open_handle_and_store = if this.base.require_lsp {
|
||||||
|
let language_extension = this.base.language_extension.as_deref().context(
|
||||||
|
"language_extension field is required in base.toml when `require_lsp == true`",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Open a file that matches the language to cause LSP to start.
|
||||||
|
let language_file = worktree.read_with(cx, |worktree, _cx| {
|
||||||
|
worktree
|
||||||
|
.files(false, 0)
|
||||||
|
.find_map(|e| {
|
||||||
|
if e.path.clone().extension().and_then(|ext| ext.to_str())
|
||||||
|
== Some(language_extension)
|
||||||
|
{
|
||||||
|
Some(ProjectPath {
|
||||||
|
worktree_id: worktree.id(),
|
||||||
|
path: e.path.clone(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.context("Failed to find a file for example language")
|
||||||
|
})??;
|
||||||
|
|
||||||
|
let open_language_file_buffer_task = project.update(cx, |project, cx| {
|
||||||
|
project.open_buffer(language_file.clone(), cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let language_file_buffer = open_language_file_buffer_task.await?;
|
||||||
|
|
||||||
|
let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
|
||||||
|
(
|
||||||
|
project.register_buffer_with_language_servers(&language_file_buffer, cx),
|
||||||
|
project.lsp_store().clone(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// TODO: remove this once the diagnostics tool waits for new diagnostics
|
||||||
|
cx.background_executor().timer(Duration::new(5, 0)).await;
|
||||||
|
wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;
|
||||||
|
|
||||||
|
lsp_store.update(cx, |lsp_store, cx| {
|
||||||
|
lsp_open_handle.update(cx, |buffer, cx| {
|
||||||
|
buffer.update(cx, |buffer, cx| {
|
||||||
|
let has_language_server = lsp_store
|
||||||
|
.language_servers_for_local_buffer(buffer, cx)
|
||||||
|
.next()
|
||||||
|
.is_some();
|
||||||
|
if has_language_server {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(
|
||||||
|
"`{:?}` was opened to cause the language server to start, \
|
||||||
|
but no language servers are registered for its buffer. \
|
||||||
|
Set `require_lsp = false` in `base.toml` to skip this.",
|
||||||
|
language_file
|
||||||
|
))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})??;
|
||||||
|
|
||||||
|
Some((lsp_open_handle, lsp_store))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
|
||||||
|
return Err(anyhow!("Setup only mode"));
|
||||||
|
}
|
||||||
|
|
||||||
let thread_store = thread_store.await;
|
let thread_store = thread_store.await;
|
||||||
let thread =
|
let thread =
|
||||||
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
|
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut log_file = this.log_file.lock().unwrap();
|
||||||
|
writeln!(&mut log_file, "👤 USER:").log_err();
|
||||||
|
writeln!(&mut log_file, "{}", this.prompt).log_err();
|
||||||
|
writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
|
||||||
|
log_file.flush().log_err();
|
||||||
|
}
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
let mut tx = Some(tx);
|
let mut tx = Some(tx);
|
||||||
|
|
||||||
let _subscription =
|
let _subscription = cx.subscribe(&thread, {
|
||||||
cx.subscribe(
|
let log_file = this.log_file.clone();
|
||||||
&thread,
|
let name = this.name.clone();
|
||||||
move |thread, event: &ThreadEvent, cx| match event {
|
move |thread, event: &ThreadEvent, cx| {
|
||||||
|
let mut log_file = log_file.lock().unwrap();
|
||||||
|
|
||||||
|
match event {
|
||||||
ThreadEvent::Stopped(reason) => match reason {
|
ThreadEvent::Stopped(reason) => match reason {
|
||||||
Ok(StopReason::EndTurn) => {
|
Ok(StopReason::EndTurn) => {
|
||||||
if let Some(tx) = tx.take() {
|
if let Some(tx) = tx.take() {
|
||||||
|
@ -137,15 +304,16 @@ impl Example {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ThreadEvent::StreamedAssistantText(_, chunk) => {
|
ThreadEvent::StreamedAssistantText(_, chunk) => {
|
||||||
print!("{}", chunk);
|
write!(&mut log_file, "{}", chunk).log_err();
|
||||||
}
|
}
|
||||||
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
|
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
|
||||||
print!("{}", chunk);
|
write!(&mut log_file, "{}", chunk).log_err();
|
||||||
}
|
}
|
||||||
ThreadEvent::UsePendingTools { tool_uses } => {
|
ThreadEvent::UsePendingTools { tool_uses } => {
|
||||||
println!("\n\nUSING TOOLS:");
|
writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
|
||||||
for tool_use in tool_uses {
|
for tool_use in tool_uses {
|
||||||
println!("{}: {}", tool_use.name, tool_use.input);
|
writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
|
||||||
|
.log_err();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ThreadEvent::ToolFinished {
|
ThreadEvent::ToolFinished {
|
||||||
|
@ -154,25 +322,331 @@ impl Example {
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
if let Some(tool_use) = pending_tool_use {
|
if let Some(tool_use) = pending_tool_use {
|
||||||
println!("\nTOOL FINISHED: {}", tool_use.name);
|
let message = format!("TOOL FINISHED: {}", tool_use.name);
|
||||||
|
println!("{name}> {message}");
|
||||||
|
writeln!(&mut log_file, "\n{}", message).log_err();
|
||||||
}
|
}
|
||||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||||
println!("\n{}\n", tool_result.content);
|
let message = format!("\n{}\n", tool_result.content);
|
||||||
|
writeln!(&mut log_file, "{}", message).log_err();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
},
|
}
|
||||||
)?;
|
|
||||||
|
log_file.flush().log_err();
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
let context = vec![];
|
let context = vec![];
|
||||||
thread.insert_user_message(self.prompt.clone(), context, None, cx);
|
thread.insert_user_message(this.prompt.clone(), context, None, cx);
|
||||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
rx.await??;
|
rx.await??;
|
||||||
|
|
||||||
Ok(())
|
if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
|
||||||
|
wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let repository_diff = this.repository_diff().await?;
|
||||||
|
let diagnostics = cx
|
||||||
|
.update(move |cx| {
|
||||||
|
cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
|
||||||
|
})?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
drop(lsp_open_handle_and_store);
|
||||||
|
|
||||||
|
thread.update(cx, |thread, _cx| {
|
||||||
|
let response_count = thread
|
||||||
|
.messages()
|
||||||
|
.filter(|message| message.role == language_model::Role::Assistant)
|
||||||
|
.count();
|
||||||
|
RunOutput {
|
||||||
|
repository_diff,
|
||||||
|
diagnostics,
|
||||||
|
response_count,
|
||||||
|
token_usage: thread.cumulative_token_usage(),
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn judge(
|
||||||
|
&mut self,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
repository_diff: String,
|
||||||
|
cx: &AsyncApp,
|
||||||
|
) -> Result<JudgeOutput> {
|
||||||
|
let judge_prompt = include_str!("judge_prompt.hbs");
|
||||||
|
let judge_prompt_name = "judge_prompt";
|
||||||
|
let mut handlebars = Handlebars::new();
|
||||||
|
handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
|
||||||
|
let prompt = handlebars.render(
|
||||||
|
judge_prompt_name,
|
||||||
|
&JudgeInput {
|
||||||
|
repository_diff,
|
||||||
|
criteria: self.criteria.clone(),
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let request = LanguageModelRequest {
|
||||||
|
messages: vec![LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: vec![MessageContent::Text(prompt)],
|
||||||
|
cache: false,
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
tools: Vec::new(),
|
||||||
|
stop: Vec::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = send_language_model_request(model, request, cx).await?;
|
||||||
|
|
||||||
|
let mut log_file = self.log_file.lock().unwrap();
|
||||||
|
|
||||||
|
writeln!(&mut log_file, "\n\n").log_err();
|
||||||
|
writeln!(&mut log_file, "========================================").log_err();
|
||||||
|
writeln!(&mut log_file, " JUDGE OUTPUT ").log_err();
|
||||||
|
writeln!(&mut log_file, "========================================").log_err();
|
||||||
|
writeln!(&mut log_file, "\n{}", &response).log_err();
|
||||||
|
|
||||||
|
parse_judge_output(&response)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn repository_diff(&self) -> Result<String> {
|
||||||
|
let worktree_path = self.worktree_path();
|
||||||
|
run_git(&worktree_path, &["add", "-N"]).await?;
|
||||||
|
run_git(&worktree_path, &["diff"]).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wait_for_lang_server(
|
||||||
|
lsp_store: &Entity<LspStore>,
|
||||||
|
name: String,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<()>> {
|
||||||
|
if cx
|
||||||
|
.update(|cx| !has_pending_lang_server_work(lsp_store, cx))
|
||||||
|
.unwrap()
|
||||||
|
|| std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok()
|
||||||
|
{
|
||||||
|
return Task::ready(anyhow::Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{}> ⏵ Waiting for language server", name);
|
||||||
|
|
||||||
|
let (mut tx, mut rx) = mpsc::channel(1);
|
||||||
|
|
||||||
|
let subscription =
|
||||||
|
cx.subscribe(&lsp_store, {
|
||||||
|
let name = name.clone();
|
||||||
|
move |lsp_store, event, cx| {
|
||||||
|
match event {
|
||||||
|
project::LspStoreEvent::LanguageServerUpdate {
|
||||||
|
message:
|
||||||
|
client::proto::update_language_server::Variant::WorkProgress(
|
||||||
|
LspWorkProgress {
|
||||||
|
message: Some(message),
|
||||||
|
..
|
||||||
|
},
|
||||||
|
),
|
||||||
|
..
|
||||||
|
} => println!("{name}> ⟲ {message}"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !has_pending_lang_server_work(&lsp_store, cx) {
|
||||||
|
tx.try_send(()).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
|
||||||
|
let result = futures::select! {
|
||||||
|
_ = rx.next() => {
|
||||||
|
println!("{}> ⚑ Language server idle", name);
|
||||||
|
anyhow::Ok(())
|
||||||
|
},
|
||||||
|
_ = timeout.fuse() => {
|
||||||
|
Err(anyhow!("LSP wait timed out after 5 minutes"))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
drop(subscription);
|
||||||
|
result
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reports whether any language server status in `lsp_store` has a non-empty
/// `pending_work` collection.
fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
    let store = lsp_store.read(cx);
    for (_, status) in store.language_server_statuses() {
        if !status.pending_work.is_empty() {
            return true;
        }
    }
    false
}
|
||||||
|
|
||||||
|
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
|
||||||
|
let paths_with_diagnostics = project.update(cx, |project, cx| {
|
||||||
|
project
|
||||||
|
.diagnostic_summaries(true, cx)
|
||||||
|
.filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
|
||||||
|
.map(|(project_path, _, _)| project_path)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
for project_path in paths_with_diagnostics {
|
||||||
|
let buffer = project
|
||||||
|
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||||
|
.await?;
|
||||||
|
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||||
|
|
||||||
|
for (_, group) in snapshot.diagnostic_groups(None) {
|
||||||
|
let entry = &group.entries[group.primary_ix];
|
||||||
|
let range = entry.range.to_point(&snapshot);
|
||||||
|
let severity = match entry.diagnostic.severity {
|
||||||
|
DiagnosticSeverity::ERROR => "error",
|
||||||
|
DiagnosticSeverity::WARNING => "warning",
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
writeln!(
|
||||||
|
output,
|
||||||
|
"{} at line {}: {}",
|
||||||
|
severity,
|
||||||
|
range.start.row + 1,
|
||||||
|
entry.diagnostic.message
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anyhow::Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
|
||||||
|
let analysis = get_tag("analysis", response)?.to_string();
|
||||||
|
let score = get_tag("score", response)?
|
||||||
|
.parse()
|
||||||
|
.context("error parsing score")?;
|
||||||
|
|
||||||
|
Ok(JudgeOutput { analysis, score })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_tag(name: &'static str, response: &str) -> Result<String> {
|
||||||
|
let start_tag = format!("<{}>", name);
|
||||||
|
let end_tag = format!("</{}>", name);
|
||||||
|
|
||||||
|
let start_ix = response
|
||||||
|
.find(&start_tag)
|
||||||
|
.context(format!("{} start tag not found", name))?;
|
||||||
|
let content_start_ix = start_ix + start_tag.len();
|
||||||
|
|
||||||
|
let end_ix = content_start_ix
|
||||||
|
+ response[content_start_ix..]
|
||||||
|
.find(&end_tag)
|
||||||
|
.context(format!("{} end tag not found", name))?;
|
||||||
|
|
||||||
|
let content = response[content_start_ix..end_ix].trim().unindent();
|
||||||
|
|
||||||
|
anyhow::Ok(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
|
||||||
|
let repo_name = repo_url
|
||||||
|
.trim_start_matches("https://")
|
||||||
|
.replace(|c: char| !c.is_alphanumeric(), "-");
|
||||||
|
Path::new(REPOS_DIR)
|
||||||
|
.canonicalize()
|
||||||
|
.context(format!("No such directory {REPOS_DIR}"))
|
||||||
|
.unwrap()
|
||||||
|
.join(repo_name)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||||
|
let output = new_smol_command("git")
|
||||||
|
.current_dir(repo_path)
|
||||||
|
.args(args)
|
||||||
|
.output()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if output.status.success() {
|
||||||
|
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(
|
||||||
|
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||||
|
args.join(" "),
|
||||||
|
repo_path.display(),
|
||||||
|
output.status,
|
||||||
|
String::from_utf8_lossy(&output.stderr),
|
||||||
|
String::from_utf8_lossy(&output.stdout),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_language_model_request(
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AsyncApp,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
match model.stream_completion_text(request, &cx).await {
|
||||||
|
Ok(mut stream) => {
|
||||||
|
let mut full_response = String::new();
|
||||||
|
while let Some(chunk_result) = stream.stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk_str) => {
|
||||||
|
print!("{}", &chunk_str);
|
||||||
|
full_response.push_str(&chunk_str);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Error receiving response from language model: {err}"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(full_response)
|
||||||
|
}
|
||||||
|
Err(err) => Err(anyhow!(
|
||||||
|
"Failed to get response from language model. Error was: {err}"
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod test {
    use super::*;

    /// Checks tag extraction for a single-line response and for a multi-line
    /// response with surrounding text.
    #[test]
    fn test_parse_judge_output() {
        // Both tags on their own lines, with no surrounding text.
        let single_line = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let JudgeOutput { analysis, score } = parse_judge_output(&single_line).unwrap();
        assert_eq!(
            analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(score, 3);

        // Text outside the tags is ignored and multi-line analysis is trimmed.
        let multi_line = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let JudgeOutput { analysis, score } = parse_judge_output(&multi_line).unwrap();
        assert_eq!(analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(score, 1);
    }
}
|
||||||
|
|
25
crates/eval/src/judge_prompt.hbs
Normal file
25
crates/eval/src/judge_prompt.hbs
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
You are an expert software developer tasked with evaluating the following changes to a codebase:
|
||||||
|
|
||||||
|
<changes>
|
||||||
|
{{repository_diff}}
|
||||||
|
</changes>
|
||||||
|
|
||||||
|
Use the following criteria to score the above changes.
|
||||||
|
|
||||||
|
<criteria>
|
||||||
|
{{criteria}}
|
||||||
|
</criteria>
|
||||||
|
|
||||||
|
Based on these criteria, give the changes a score between 0 and 5.
|
||||||
|
|
||||||
|
- 5 means: changes meet all criteria
|
||||||
|
- 0 means: changes don't meet any criteria
|
||||||
|
|
||||||
|
Be suspicious of the changes because they were generated by an LLM.
|
||||||
|
Sometimes the LLM decides to change random code, so if the changes are not mentioned in the criteria, penalize the score.
|
||||||
|
Analyze the diff hunk by hunk and describe how each change meets or fails to meet the criteria.
|
||||||
|
|
||||||
|
```
|
||||||
|
<analysis>{YOUR ANALYSIS HERE}</analysis>
|
||||||
|
<score>{YOUR SCORE HERE}</score>
|
||||||
|
```
|
|
@ -83,7 +83,7 @@ pub enum StopReason {
|
||||||
ToolUse,
|
ToolUse,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
|
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
|
||||||
pub struct TokenUsage {
|
pub struct TokenUsage {
|
||||||
#[serde(default, skip_serializing_if = "is_default")]
|
#[serde(default, skip_serializing_if = "is_default")]
|
||||||
pub input_tokens: u32,
|
pub input_tokens: u32,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue