agent eval: Fix --model arg and add --provider (#29883)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-05-04 13:43:57 -06:00 committed by GitHub
parent 007685f6d4
commit bb82d9ca82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -46,9 +46,12 @@ struct Args {
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run. /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")] #[arg(value_name = "EXAMPLE_SUBSTRING")]
filter: Vec<String>, filter: Vec<String>,
/// Model to use (default: "claude-3-7-sonnet-latest") /// ID of model to use.
#[arg(long, default_value = "claude-3-7-sonnet-latest")] #[arg(long, default_value = "claude-3-7-sonnet-latest")]
model: String, model: String,
/// Model provider to use.
#[arg(long, default_value = "anthropic")]
provider: String,
#[arg(long, value_delimiter = ',', default_value = "rs,ts")] #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
languages: Vec<String>, languages: Vec<String>,
/// How many times to run each example. /// How many times to run each example.
@ -123,7 +126,7 @@ fn main() {
let mut cumulative_tool_metrics = ToolMetrics::default(); let mut cumulative_tool_metrics = ToolMetrics::default();
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap(); let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
let model_provider_id = model.provider_id(); let model_provider_id = model.provider_id();
let model_provider = model_registry.provider(&model_provider_id).unwrap(); let model_provider = model_registry.provider(&model_provider_id).unwrap();
@ -452,27 +455,36 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
} }
pub fn find_model( pub fn find_model(
model_name: &str, provider_id: &str,
model_id: &str,
model_registry: &LanguageModelRegistry, model_registry: &LanguageModelRegistry,
cx: &App, cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> { ) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model = model_registry let matching_models = model_registry
.available_models(cx) .available_models(cx)
.find(|model| model.id().0 == model_name); .filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
.collect::<Vec<_>>();
let Some(model) = model else { match matching_models.as_slice() {
return Err(anyhow!( [model] => Ok(model.clone()),
"No language model named {} was available. Available models: {}", [] => Err(anyhow!(
model_name, "No language model with ID {} was available. Available models: {}",
model_id,
model_registry model_registry
.available_models(cx) .available_models(cx)
.map(|model| model.id().0.clone()) .map(|model| model.id().0.clone())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(", ") .join(", ")
)); )),
}; _ => Err(anyhow!(
"Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
Ok(model) model_id,
matching_models
.iter()
.map(|model| model.provider_id().0)
.collect::<Vec<_>>()
)),
}
} }
pub fn commit_sha_for_path(repo_path: &Path) -> String { pub fn commit_sha_for_path(repo_path: &Path) -> String {