Agent2 Model Selector (#36028)
Release Notes: - N/A --------- Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent
8ff2e3e195
commit
db497ac867
13 changed files with 1078 additions and 148 deletions
|
@ -4,18 +4,22 @@ use crate::{
|
|||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
||||
};
|
||||
use acp_thread::ModelSelector;
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
@ -48,6 +52,104 @@ struct Session {
|
|||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
pub struct LanguageModels {
|
||||
/// Access language model by ID
|
||||
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
|
||||
/// Cached list for returning language model information
|
||||
model_list: acp_thread::AgentModelList,
|
||||
refresh_models_rx: watch::Receiver<()>,
|
||||
refresh_models_tx: watch::Sender<()>,
|
||||
}
|
||||
|
||||
impl LanguageModels {
|
||||
fn new(cx: &App) -> Self {
|
||||
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
|
||||
let mut this = Self {
|
||||
models: HashMap::default(),
|
||||
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
|
||||
refresh_models_rx,
|
||||
refresh_models_tx,
|
||||
};
|
||||
this.refresh_list(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn refresh_list(&mut self, cx: &App) {
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut language_model_list = IndexMap::default();
|
||||
let mut recommended_models = HashSet::default();
|
||||
|
||||
let mut recommended = Vec::new();
|
||||
for provider in &providers {
|
||||
for model in provider.recommended_models(cx) {
|
||||
recommended_models.insert(model.id());
|
||||
recommended.push(Self::map_language_model_to_info(&model, &provider));
|
||||
}
|
||||
}
|
||||
if !recommended.is_empty() {
|
||||
language_model_list.insert(
|
||||
acp_thread::AgentModelGroupName("Recommended".into()),
|
||||
recommended,
|
||||
);
|
||||
}
|
||||
|
||||
let mut models = HashMap::default();
|
||||
for provider in providers {
|
||||
let mut provider_models = Vec::new();
|
||||
for model in provider.provided_models(cx) {
|
||||
let model_info = Self::map_language_model_to_info(&model, &provider);
|
||||
let model_id = model_info.id.clone();
|
||||
if !recommended_models.contains(&model.id()) {
|
||||
provider_models.push(model_info);
|
||||
}
|
||||
models.insert(model_id, model);
|
||||
}
|
||||
if !provider_models.is_empty() {
|
||||
language_model_list.insert(
|
||||
acp_thread::AgentModelGroupName(provider.name().0.clone()),
|
||||
provider_models,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
self.models = models;
|
||||
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
|
||||
self.refresh_models_tx.send(()).ok();
|
||||
}
|
||||
|
||||
fn watch(&self) -> watch::Receiver<()> {
|
||||
self.refresh_models_rx.clone()
|
||||
}
|
||||
|
||||
pub fn model_from_id(
|
||||
&self,
|
||||
model_id: &acp_thread::AgentModelId,
|
||||
) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.models.get(model_id).cloned()
|
||||
}
|
||||
|
||||
fn map_language_model_to_info(
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
) -> acp_thread::AgentModelInfo {
|
||||
acp_thread::AgentModelInfo {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
icon: Some(provider.icon()),
|
||||
}
|
||||
}
|
||||
|
||||
fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
|
||||
acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
|
@ -58,8 +160,11 @@ pub struct NativeAgent {
|
|||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
/// Cached model information
|
||||
models: LanguageModels,
|
||||
project: Entity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
|
@ -68,6 +173,7 @@ impl NativeAgent {
|
|||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<NativeAgent>> {
|
||||
log::info!("Creating new NativeAgent");
|
||||
|
@ -77,7 +183,13 @@ impl NativeAgent {
|
|||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
|
||||
let mut subscriptions = vec![
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
Self::handle_models_updated_event,
|
||||
),
|
||||
];
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
|
||||
}
|
||||
|
@ -95,13 +207,19 @@ impl NativeAgent {
|
|||
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||
}),
|
||||
templates,
|
||||
models: LanguageModels::new(cx),
|
||||
project,
|
||||
prompt_store,
|
||||
fs,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn models(&self) -> &LanguageModels {
|
||||
&self.models
|
||||
}
|
||||
|
||||
async fn maintain_project_context(
|
||||
this: WeakEntity<Self>,
|
||||
mut needs_refresh: watch::Receiver<()>,
|
||||
|
@ -297,75 +415,104 @@ impl NativeAgent {
|
|||
) {
|
||||
self.project_context_needs_refresh.send(()).ok();
|
||||
}
|
||||
|
||||
fn handle_models_updated_event(
|
||||
&mut self,
|
||||
_registry: Entity<LanguageModelRegistry>,
|
||||
_event: &language_model::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.models.refresh_list(cx);
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, _| {
|
||||
let model_id = LanguageModels::model_id(&thread.selected_model);
|
||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.selected_model = model.clone();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl ModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
impl AgentModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
log::info!("Found {} available models", models.len());
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
let list = self.0.read(cx).models.model_list.clone();
|
||||
Task::ready(if list.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(list)
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
model_id: acp_thread::AgentModelId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!(
|
||||
"Setting model for session {}: {:?}",
|
||||
session_id,
|
||||
model.name()
|
||||
);
|
||||
let agent = self.0.clone();
|
||||
log::info!("Setting model for session {}: {}", session_id, model_id);
|
||||
let Some(thread) = self
|
||||
.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(session) = agent.sessions.get(&session_id) {
|
||||
session.thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
|
||||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model.clone();
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
self.0.read(cx).fs.clone(),
|
||||
cx,
|
||||
move |settings, _cx| {
|
||||
settings.set_model(model);
|
||||
},
|
||||
);
|
||||
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp_thread::AgentModelInfo>> {
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let thread = agent
|
||||
.read_with(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
})?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
|
||||
let Some(thread) = self
|
||||
.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
let model = thread.read(cx).selected_model.clone();
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Provider not found")));
|
||||
};
|
||||
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
||||
&model, &provider,
|
||||
)))
|
||||
}
|
||||
|
||||
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
|
||||
self.0.read(cx).models.watch()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -413,13 +560,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|
||||
let default_model = registry
|
||||
.default_model()
|
||||
.map(|configured| {
|
||||
log::info!(
|
||||
"Using configured default model: {:?} from provider: {:?}",
|
||||
configured.model.name(),
|
||||
configured.provider.name()
|
||||
);
|
||||
configured.model
|
||||
.and_then(|default_model| {
|
||||
agent
|
||||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
|
@ -487,8 +631,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
|
@ -629,6 +773,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
|
@ -646,9 +791,15 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
assert_eq!(agent.project_context.borrow().worktrees, vec![])
|
||||
});
|
||||
|
@ -689,13 +840,131 @@ mod tests {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_listing_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/", json!({ "a": {} })).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let connection = NativeAgentConnection(
|
||||
NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
|
||||
|
||||
let acp_thread::AgentModelList::Grouped(models) = models else {
|
||||
panic!("Unexpected model group");
|
||||
};
|
||||
assert_eq!(
|
||||
models,
|
||||
IndexMap::from_iter([(
|
||||
AgentModelGroupName("Fake".into()),
|
||||
vec![AgentModelInfo {
|
||||
id: AgentModelId("fake/fake".into()),
|
||||
name: "Fake".into(),
|
||||
icon: Some(ui::IconName::ZedAssistant),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"default_model": {
|
||||
"provider": "foo",
|
||||
"model": "bar"
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
|
||||
// Create the agent and connection
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Create a thread/session
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
Rc::new(connection.clone()).new_thread(
|
||||
project.clone(),
|
||||
Path::new("/a"),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
||||
|
||||
// Select a model
|
||||
let model_id = AgentModelId("fake/fake".into());
|
||||
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the thread has the selected model
|
||||
agent.read_with(cx, |agent, _| {
|
||||
let session = agent.sessions.get(&session_id).unwrap();
|
||||
session.thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.selected_model.id().0, "fake");
|
||||
});
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Verify settings file was updated
|
||||
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
||||
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
||||
|
||||
// Check that the agent settings contain the selected model
|
||||
assert_eq!(
|
||||
settings_json["agent"]["default_model"]["model"],
|
||||
json!("fake")
|
||||
);
|
||||
assert_eq!(
|
||||
settings_json["agent"]["default_model"]["provider"],
|
||||
json!("fake")
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
agent_settings::init(cx);
|
||||
language::init(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Task};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
|
@ -10,7 +10,15 @@ use prompt_store::PromptStore;
|
|||
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer;
|
||||
pub struct NativeAgentServer {
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
impl NativeAgentServer {
|
||||
pub fn new(fs: Arc<dyn Fs>) -> Self {
|
||||
Self { fs }
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn name(&self) -> &'static str {
|
||||
|
@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer {
|
|||
_root_dir
|
||||
);
|
||||
let project = project.clone();
|
||||
let fs = self.fs.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
|
@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer {
|
|||
let prompt_store = prompt_store.await?;
|
||||
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::*;
|
||||
use crate::MessageContent;
|
||||
use acp_thread::AgentConnection;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_settings::AgentProfileId;
|
||||
|
@ -686,13 +686,19 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
|
||||
let cwd = Path::new("/test");
|
||||
|
||||
// Create agent and connection
|
||||
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
templates.clone(),
|
||||
None,
|
||||
fake_fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
|
@ -705,22 +711,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.update(|cx| selector.list_models(cx))
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
let AgentModelList::Grouped(listed_models) = listed_models else {
|
||||
panic!("Unexpected model list type");
|
||||
};
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
assert_eq!(
|
||||
listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
|
||||
"fake/fake"
|
||||
);
|
||||
|
||||
// Create a thread using new_thread
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
|
@ -729,12 +735,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
|
||||
// Test selected_model returns the default
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.update(|cx| selector.selected_model(&session_id, cx))
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
let model = cx
|
||||
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
|
||||
.unwrap();
|
||||
let model = model.as_fake();
|
||||
assert_eq!(model.id().0, "fake", "should return default model");
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue