diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 2af9c30434..9ad716aadb 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -15,7 +15,7 @@ use gpui::{AppContext, TestAppContext}; use indoc::{formatdoc, indoc}; use language_model::{ LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, + LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel, }; use project::Project; use rand::prelude::*; @@ -25,6 +25,7 @@ use std::{ cmp::Reverse, fmt::{self, Display}, io::Write as _, + str::FromStr, sync::mpsc, }; use util::path; @@ -1216,7 +1217,7 @@ fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usiz passed_count as f64 / evaluated_count as f64 }; print!( - "\r\x1b[KEvaluated {}/{} ({:.2}%)", + "\r\x1b[KEvaluated {}/{} ({:.2}% passed)", evaluated_count, iterations, passed_ratio * 100.0 @@ -1255,13 +1256,21 @@ impl EditAgentTest { fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let agent_model = SelectedModel::from_str( + &std::env::var("ZED_AGENT_MODEL") + .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()), + ) + .unwrap(); + let judge_model = SelectedModel::from_str( + &std::env::var("ZED_JUDGE_MODEL") + .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()), + ) + .unwrap(); let (agent_model, judge_model) = cx .update(|cx| { cx.spawn(async move |cx| { - let agent_model = - Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await; - let judge_model = - Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await; + let agent_model = Self::load_model(&agent_model, cx).await; + let judge_model = Self::load_model(&judge_model, cx).await; (agent_model.unwrap(), judge_model.unwrap()) }) }) @@ -1276,15 +1285,17 @@ impl EditAgentTest { } async fn load_model( - provider: &str, - id: &str, + selected_model: &SelectedModel, cx: &mut AsyncApp, ) -> Result> { let (provider, model) = cx.update(|cx| { let models = LanguageModelRegistry::read_global(cx); let model = models .available_models(cx) - .find(|model| model.provider_id().0 == provider && model.id().0 == id) + .find(|model| { + model.provider_id() == selected_model.provider + && model.id() == selected_model.model + }) .unwrap(); let provider = models.provider(&model.provider_id()).unwrap(); (provider, model) diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 46b0bc56fd..ce6518f65f 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -4,7 +4,7 @@ use crate::{ }; use collections::BTreeMap; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; -use std::sync::Arc; +use std::{str::FromStr, sync::Arc}; use util::maybe; pub fn init(cx: &mut App) { @@ -27,11 +27,36 @@ pub struct LanguageModelRegistry { inline_alternatives: Vec>, } +#[derive(Debug)] pub struct SelectedModel { pub provider: LanguageModelProviderId, pub model: LanguageModelId, } +impl FromStr for SelectedModel { + type Err = String; + + /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel` + fn from_str(id: &str) -> Result { + let parts: Vec<&str> = id.split('/').collect(); + let [provider_id, model_id] = parts.as_slice() else { + return Err(format!( + "Invalid model identifier format: `{}`. Expected `provider_id/model_id`", + id + )); + }; + + if provider_id.is_empty() || model_id.is_empty() { + return Err(format!("Provider and model ids can't be empty: `{}`", id)); + } + + Ok(SelectedModel { + provider: LanguageModelProviderId(provider_id.to_string().into()), + model: LanguageModelId(model_id.to_string().into()), + }) + } +} + #[derive(Clone)] pub struct ConfiguredModel { pub provider: Arc,