Agent2 Model Selector (#36028)

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
Ben Brandt 2025-08-13 11:01:02 +02:00 committed by GitHub
parent 8ff2e3e195
commit db497ac867
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1078 additions and 148 deletions

View file

@ -4,18 +4,22 @@ use crate::{
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
};
use acp_thread::ModelSelector;
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@ -48,6 +52,104 @@ struct Session {
_subscription: Subscription,
}
pub struct LanguageModels {
/// Access language model by ID
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
/// Cached list for returning language model information
model_list: acp_thread::AgentModelList,
refresh_models_rx: watch::Receiver<()>,
refresh_models_tx: watch::Sender<()>,
}
impl LanguageModels {
fn new(cx: &App) -> Self {
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
let mut this = Self {
models: HashMap::default(),
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
refresh_models_rx,
refresh_models_tx,
};
this.refresh_list(cx);
this
}
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
let mut language_model_list = IndexMap::default();
let mut recommended_models = HashSet::default();
let mut recommended = Vec::new();
for provider in &providers {
for model in provider.recommended_models(cx) {
recommended_models.insert(model.id());
recommended.push(Self::map_language_model_to_info(&model, &provider));
}
}
if !recommended.is_empty() {
language_model_list.insert(
acp_thread::AgentModelGroupName("Recommended".into()),
recommended,
);
}
let mut models = HashMap::default();
for provider in providers {
let mut provider_models = Vec::new();
for model in provider.provided_models(cx) {
let model_info = Self::map_language_model_to_info(&model, &provider);
let model_id = model_info.id.clone();
if !recommended_models.contains(&model.id()) {
provider_models.push(model_info);
}
models.insert(model_id, model);
}
if !provider_models.is_empty() {
language_model_list.insert(
acp_thread::AgentModelGroupName(provider.name().0.clone()),
provider_models,
);
}
}
self.models = models;
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
self.refresh_models_tx.send(()).ok();
}
fn watch(&self) -> watch::Receiver<()> {
self.refresh_models_rx.clone()
}
pub fn model_from_id(
&self,
model_id: &acp_thread::AgentModelId,
) -> Option<Arc<dyn LanguageModel>> {
self.models.get(model_id).cloned()
}
fn map_language_model_to_info(
model: &Arc<dyn LanguageModel>,
provider: &Arc<dyn LanguageModelProvider>,
) -> acp_thread::AgentModelInfo {
acp_thread::AgentModelInfo {
id: Self::model_id(model),
name: model.name().0,
icon: Some(provider.icon()),
}
}
fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
}
}
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
@ -58,8 +160,11 @@ pub struct NativeAgent {
context_server_registry: Entity<ContextServerRegistry>,
/// Shared templates for all threads
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
@ -68,6 +173,7 @@ impl NativeAgent {
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
@ -77,7 +183,13 @@ impl NativeAgent {
.await;
cx.new(|cx| {
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
let mut subscriptions = vec![
cx.subscribe(&project, Self::handle_project_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_models_updated_event,
),
];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
@ -95,13 +207,19 @@ impl NativeAgent {
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
}),
templates,
models: LanguageModels::new(cx),
project,
prompt_store,
fs,
_subscriptions: subscriptions,
}
})
}
pub fn models(&self) -> &LanguageModels {
&self.models
}
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
@ -297,75 +415,104 @@ impl NativeAgent {
) {
self.project_context_needs_refresh.send(()).ok();
}
fn handle_models_updated_event(
&mut self,
_registry: Entity<LanguageModelRegistry>,
_event: &language_model::Event,
cx: &mut Context<Self>,
) {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
let model_id = LanguageModels::model_id(&thread.selected_model);
if let Some(model) = self.models.model_from_id(&model_id) {
thread.selected_model = model.clone();
}
});
}
}
}
/// Wrapper struct that implements the AgentConnection trait
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl ModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
log::debug!("NativeAgentConnection::list_models called");
cx.spawn(async move |cx| {
cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
log::info!("Found {} available models", models.len());
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
})?
let list = self.0.read(cx).models.model_list.clone();
Task::ready(if list.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(list)
})
}
fn select_model(
&self,
session_id: acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
model_id: acp_thread::AgentModelId,
cx: &mut App,
) -> Task<Result<()>> {
log::info!(
"Setting model for session {}: {:?}",
session_id,
model.name()
);
let agent = self.0.clone();
log::info!("Setting model for session {}: {}", session_id, model_id);
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
cx.spawn(async move |cx| {
agent.update(cx, |agent, cx| {
if let Some(session) = agent.sessions.get(&session_id) {
session.thread.update(cx, |thread, _cx| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow!("Session not found"))
}
})?
})
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, _cx| {
thread.selected_model = model.clone();
});
update_settings_file::<AgentSettings>(
self.0.read(cx).fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model);
},
);
Task::ready(Ok(()))
}
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
cx: &mut App,
) -> Task<Result<acp_thread::AgentModelInfo>> {
let session_id = session_id.clone();
cx.spawn(async move |cx| {
let thread = agent
.read_with(cx, |agent, _| {
agent
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
})?
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
Ok(selected)
})
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let model = thread.read(cx).selected_model.clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
};
Task::ready(Ok(LanguageModels::map_language_model_to_info(
&model, &provider,
)))
}
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
self.0.read(cx).models.watch()
}
}
@ -413,13 +560,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let default_model = registry
.default_model()
.map(|configured| {
log::info!(
"Using configured default model: {:?} from provider: {:?}",
configured.model.name(),
configured.provider.name()
);
configured.model
.and_then(|default_model| {
agent
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
})
.ok_or_else(|| {
log::warn!("No default model configured in settings");
@ -487,8 +631,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
Task::ready(Ok(()))
}
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
}
fn prompt(
@ -629,6 +773,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
#[cfg(test)]
mod tests {
use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
use fs::FakeFs;
use gpui::TestAppContext;
use serde_json::json;
@ -646,9 +791,15 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
.await
.unwrap();
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
@ -689,13 +840,131 @@ mod tests {
});
}
#[gpui::test]
async fn test_listing_models(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap(),
);
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
let acp_thread::AgentModelList::Grouped(models) = models else {
panic!("Unexpected model group");
};
assert_eq!(
models,
IndexMap::from_iter([(
AgentModelGroupName("Fake".into()),
vec![AgentModelInfo {
id: AgentModelId("fake/fake".into()),
name: "Fake".into(),
icon: Some(ui::IconName::ZedAssistant),
}]
)])
);
}
#[gpui::test]
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"default_model": {
"provider": "foo",
"model": "bar"
}
}
})
.to_string()
.into_bytes(),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
let acp_thread = cx
.update(|cx| {
Rc::new(connection.clone()).new_thread(
project.clone(),
Path::new("/a"),
&mut cx.to_async(),
)
})
.await
.unwrap();
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
// Select a model
let model_id = AgentModelId("fake/fake".into());
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
.await
.unwrap();
// Verify the thread has the selected model
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.selected_model.id().0, "fake");
});
});
cx.run_until_parked();
// Verify settings file was updated
let settings_content = fs.load(paths::settings_file()).await.unwrap();
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
// Check that the agent settings contain the selected model
assert_eq!(
settings_json["agent"]["default_model"]["model"],
json!("fake")
);
assert_eq!(
settings_json["agent"]["default_model"]["provider"],
json!("fake")
);
}
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
agent_settings::init(cx);
language::init(cx);
LanguageModelRegistry::test(cx);
});
}
}

View file

@ -1,8 +1,8 @@
use std::path::Path;
use std::rc::Rc;
use std::{path::Path, rc::Rc, sync::Arc};
use agent_servers::AgentServer;
use anyhow::Result;
use fs::Fs;
use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
@ -10,7 +10,15 @@ use prompt_store::PromptStore;
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)]
pub struct NativeAgentServer;
pub struct NativeAgentServer {
fs: Arc<dyn Fs>,
}
impl NativeAgentServer {
pub fn new(fs: Arc<dyn Fs>) -> Self {
Self { fs }
}
}
impl AgentServer for NativeAgentServer {
fn name(&self) -> &'static str {
@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer {
_root_dir
);
let project = project.clone();
let fs = self.fs.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View file

@ -1,6 +1,6 @@
use super::*;
use crate::MessageContent;
use acp_thread::AgentConnection;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
@ -686,13 +686,19 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let agent = NativeAgent::new(
project.clone(),
templates.clone(),
None,
fake_fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
@ -705,22 +711,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Test list_models
let listed_models = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.list_models(&mut async_cx)
})
.update(|cx| selector.list_models(cx))
.await
.expect("list_models should succeed");
let AgentModelList::Grouped(listed_models) = listed_models else {
panic!("Unexpected model list type");
};
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
assert_eq!(
listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
"fake/fake"
);
// Create a thread using new_thread
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
let mut async_cx = cx.to_async();
connection_rc.new_thread(project, cwd, &mut async_cx)
})
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
.await
.expect("new_thread should succeed");
@ -729,12 +735,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Test selected_model returns the default
let model = cx
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.update(|cx| selector.selected_model(&session_id, cx))
.await
.expect("selected_model should succeed");
let model = cx
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
.unwrap();
let model = model.as_fake();
assert_eq!(model.id().0, "fake", "should return default model");