From bb82d9ca829dcbbf7031421e83ae94d52169b31d Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Sun, 4 May 2025 13:43:57 -0600 Subject: [PATCH] agent eval: Fix `--model` arg and add `--provider` (#29883) Release Notes: - N/A --- crates/eval/src/eval.rs | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 00bc60c9ea..265009c5bb 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -46,9 +46,12 @@ struct Args { /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run. #[arg(value_name = "EXAMPLE_SUBSTRING")] filter: Vec, - /// Model to use (default: "claude-3-7-sonnet-latest") + /// ID of model to use. #[arg(long, default_value = "claude-3-7-sonnet-latest")] model: String, + /// Model provider to use. + #[arg(long, default_value = "anthropic")] + provider: String, #[arg(long, value_delimiter = ',', default_value = "rs,ts")] languages: Vec, /// How many times to run each example. @@ -123,7 +126,7 @@ fn main() { let mut cumulative_tool_metrics = ToolMetrics::default(); let model_registry = LanguageModelRegistry::read_global(cx); - let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap(); + let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap(); let model_provider_id = model.provider_id(); let model_provider = model_registry.provider(&model_provider_id).unwrap(); @@ -452,27 +455,36 @@ pub fn init(cx: &mut App) -> Arc { } pub fn find_model( - model_name: &str, + provider_id: &str, + model_id: &str, model_registry: &LanguageModelRegistry, cx: &App, ) -> anyhow::Result> { - let model = model_registry + let matching_models = model_registry .available_models(cx) - .find(|model| model.id().0 == model_name); + .filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id) + .collect::>(); - let Some(model) = model else { - return Err(anyhow!( - "No language model named {} was available. Available models: {}", - model_name, + match matching_models.as_slice() { + [model] => Ok(model.clone()), + [] => Err(anyhow!( + "No language model with ID {} was available. Available models: {}", + model_id, model_registry .available_models(cx) .map(|model| model.id().0.clone()) .collect::>() .join(", ") - )); - }; - - Ok(model) + )), + _ => Err(anyhow!( + "Multiple language models with ID {} available - use `--provider` to choose one of: {:?}", + model_id, + matching_models + .iter() + .map(|model| model.provider_id().0) + .collect::>() + )), + } } pub fn commit_sha_for_path(repo_path: &Path) -> String {