
Closes: https://github.com/zed-industries/zed/issues/20582

Allows users to select a specific model for each AI-powered feature:

- Agent panel
- Inline assistant
- Thread summarization
- Commit message generation

If unspecified for a given feature, it will use the `default_model` setting.

Release Notes:

- Added support for configuring a specific model for each AI-powered feature

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
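
For context before the eval-runner source below, here is a sketch of what per-feature model selection might look like in Zed's `settings.json`. The key names, providers, and model IDs are illustrative assumptions, not the exact schema this PR introduces:

```jsonc
{
  "agent": {
    // Fallback for any feature without its own model (key names are assumed).
    "default_model": {
      "provider": "anthropic",
      "model": "claude-3-7-sonnet-latest"
    },
    // Hypothetical per-feature overrides.
    "inline_assistant_model": {
      "provider": "openai",
      "model": "gpt-4o"
    },
    "commit_message_model": {
      "provider": "openai",
      "model": "gpt-4o-mini"
    },
    "thread_summary_model": {
      "provider": "openai",
      "model": "gpt-4o-mini"
    }
  }
}
```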

mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
mod judge;
mod templates_eval;

use clap::Parser;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{path::PathBuf, sync::Arc};
use templates_eval::all_templates;

#[derive(Parser, Debug)]
#[command(
    name = "assistant_eval",
    disable_version_flag = true,
    before_help = "Tool eval runner"
)]
struct Args {
    /// Match the names of evals to run.
    #[arg(long)]
    exercise_names: Vec<String>,
    /// Run all exercises; causes `exercise_names` to be ignored.
    #[arg(long)]
    all: bool,
    /// Comma-separated languages to evaluate (default: "internal").
    /// "internal" is data generated from the agent panel.
    #[arg(long, default_value = "internal")]
    languages: String,
    /// Name of the model (default: "claude-3-7-sonnet-latest").
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model_name: String,
    /// Name of the judge model (default: value of `--model-name`).
    #[arg(long)]
    judge_model_name: Option<String>,
    /// Number of evaluations to run concurrently (default: 3).
    #[arg(short, long, default_value = "3")]
    concurrency: usize,
    /// Maximum number of exercises to evaluate per language.
    #[arg(long)]
    max_exercises_per_language: Option<usize>,
}
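
// Hypothetical invocation (clap derives kebab-case flags from the field names above;
// the `assistant_eval` cargo target name is an assumption):
//
//     cargo run -p assistant_eval -- --languages python,go --concurrency 5 --all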

// The order in which templates should be executed for each exercise.
const TEMPLATE_EXECUTION_ORDER: [&str; 3] = [
    "ProjectCreation",
    "CodeModification",
    "ConversationalGuidance",
];

fn main() {
    env_logger::init();
    let args = Args::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client.clone());

    // Path to the zed-ace-framework repo.
    let framework_path = PathBuf::from("../zed-ace-framework")
        .canonicalize()
        .unwrap();

    // Collect owned Strings rather than borrowed slices so `languages` can move
    // into the closures below without a lifetime issue.
    let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();

    println!("Using zed-ace-framework at: {:?}", framework_path);
    println!("Evaluating languages: {:?}", languages);

    app.run(move |cx| {
        let app_state = headless_assistant::init(cx);

        let model = find_model(&args.model_name, cx).unwrap();
        let judge_model = if let Some(model_name) = &args.judge_model_name {
            find_model(model_name, cx).unwrap()
        } else {
            model.clone()
        };

        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_default_model(Some(model.clone()), cx);
        });

        let model_provider_id = model.provider_id();
        let judge_model_provider_id = judge_model.provider_id();

        // Clone values that need to move into the async task below.
        let framework_path_clone = framework_path.clone();
        let languages_clone = languages.clone();
        let exercise_names = args.exercise_names.clone();
        let all_flag = args.all;

        cx.spawn(async move |cx| {
            // Authenticate all model providers first.
            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();
            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();

            // Read base SHA from setup.json.
            let base_sha = read_base_sha(&framework_path_clone).await.unwrap();

            // Find all exercises for the specified languages.
            let all_exercises = find_exercises(
                &framework_path_clone,
                &languages_clone
                    .iter()
                    .map(|s| s.as_str())
                    .collect::<Vec<_>>(),
                args.max_exercises_per_language,
            )
            .unwrap();
            println!("Found {} exercises total", all_exercises.len());

            // Filter exercises if specific ones were requested.
            let exercises_to_run = if !exercise_names.is_empty() {
                // If exercise names are specified, filter by them regardless of the --all flag.
                all_exercises
                    .into_iter()
                    .filter(|path| {
                        let name = get_exercise_name(path);
                        exercise_names.iter().any(|filter| name.contains(filter))
                    })
                    .collect()
            } else if all_flag {
                // Only use all_flag if no exercise names are specified.
                all_exercises
            } else {
                // Default behavior (no filters).
                all_exercises
            };

            println!("Will run {} exercises", exercises_to_run.len());

            // Get all templates and sort them according to the execution order.
            let mut templates = all_templates();
            templates.sort_by_key(|template| {
                TEMPLATE_EXECUTION_ORDER
                    .iter()
                    .position(|&name| name == template.name)
                    .unwrap_or(usize::MAX)
            });
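            // Templates not listed in TEMPLATE_EXECUTION_ORDER sort to the end (usize::MAX).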

            // Create exercise eval tasks. Each exercise is a single task that runs
            // its templates sequentially.
            let exercise_tasks: Vec<_> = exercises_to_run
                .into_iter()
                .map(|exercise_path| {
                    let exercise_name = get_exercise_name(&exercise_path);
                    let templates_clone = templates.clone();
                    let model_clone = model.clone();
                    let judge_model_clone = judge_model.clone();
                    let app_state_clone = app_state.clone();
                    let base_sha_clone = base_sha.clone();
                    let framework_path_clone = framework_path_clone.clone();
                    let cx_clone = cx.clone();

                    async move {
                        println!("Processing exercise: {}", exercise_name);
                        let mut exercise_results = Vec::new();

                        // Determine the language for this exercise.
                        let language = match get_exercise_language(&exercise_path) {
                            Ok(lang) => lang,
                            Err(err) => {
                                println!(
                                    "Error determining language for {}: {}",
                                    exercise_name, err
                                );
                                return exercise_results;
                            }
                        };

                        // Run each template sequentially for this exercise.
                        for template in templates_clone {
                            // For the "multi" or "internal" language, only run the
                            // CodeModification template.
                            if (language == "multi" || language == "internal")
                                && template.name != "CodeModification"
                            {
                                println!(
                                    "Skipping {} template for {} language",
                                    template.name, language
                                );
                                continue;
                            }

                            match run_exercise_eval(
                                exercise_path.clone(),
                                template.clone(),
                                model_clone.clone(),
                                judge_model_clone.clone(),
                                app_state_clone.clone(),
                                base_sha_clone.clone(),
                                framework_path_clone.clone(),
                                cx_clone.clone(),
                            )
                            .await
                            {
                                Ok(result) => {
                                    println!(
                                        "Completed {} with template {} - score: {}",
                                        exercise_name, template.name, result.score
                                    );
                                    exercise_results.push(result);
                                }
                                Err(err) => {
                                    println!(
                                        "Error running {} with template {}: {}",
                                        exercise_name, template.name, err
                                    );
                                }
                            }
                        }

                        // Save results for this exercise.
                        if !exercise_results.is_empty() {
                            if let Err(err) =
                                save_eval_results(&exercise_path, exercise_results.clone()).await
                            {
                                println!("Error saving results for {}: {}", exercise_name, err);
                            } else {
                                println!("Saved results for {}", exercise_name);
                            }
                        }

                        exercise_results
                    }
                })
                .collect();

            println!(
                "Running {} exercises with concurrency: {}",
                exercise_tasks.len(),
                args.concurrency
            );
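
            // `buffer_unordered` polls up to `args.concurrency` exercise futures at once,
            // yielding each exercise's Vec of results as it completes; `flat_map(stream::iter)`
            // then flattens those Vecs into a single stream of results.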
            // Run exercises concurrently, with each exercise running its templates sequentially.
            let all_results = stream::iter(exercise_tasks)
                .buffer_unordered(args.concurrency)
                .flat_map(stream::iter)
                .collect::<Vec<_>>()
                .await;
println!("Completed {} evaluation runs", all_results.len());
|
|
cx.update(|cx| cx.quit()).unwrap();
|
|
})
|
|
.detach();
|
|
});
|
|
|
|
println!("Done running evals");
|
|
}
|