agent: Make ToolWorkingSet an Entity (#28757)

Motivation is to emit events when enabled tools change, want to use this
in #28755

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-04-15 06:42:31 -06:00 committed by GitHub
parent 7e1b419243
commit e26f0a331f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 183 additions and 202 deletions

View file

@ -56,7 +56,7 @@ impl SharedProjectContext {
pub struct ThreadStore {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
@ -74,7 +74,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Entity<Self>> {
@ -88,7 +88,7 @@ impl ThreadStore {
fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
@ -248,7 +248,7 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn tools(&self) -> Arc<ToolWorkingSet> {
pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone()
}
@ -355,52 +355,60 @@ impl ThreadStore {
})
}
fn load_default_profile(&self, cx: &Context<Self>) {
fn load_default_profile(&self, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
self.load_profile_by_id(&assistant_settings.default_profile, cx);
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
}
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile, cx);
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
self.load_profile(profile.clone(), cx);
}
}
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
self.tools.update(cx, |tools, cx| {
tools.disable_all_tools(cx);
tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
);
});
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
});
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
self.tools.update(cx, |tools, cx| {
tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
}
}
}
@ -434,29 +442,36 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
.log_err();
this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
}
@ -466,7 +481,9 @@ impl ThreadStore {
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
}