agent: Make ToolWorkingSet an Entity (#28757)

Motivation is to emit events when enabled tools change, want to use this
in #28755

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-04-15 06:42:31 -06:00 committed by GitHub
parent 7e1b419243
commit e26f0a331f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 183 additions and 202 deletions

View file

@ -894,6 +894,7 @@ mod tests {
use super::*; use super::*;
use crate::{ThreadStore, thread_store}; use crate::{ThreadStore, thread_store};
use assistant_settings::AssistantSettings; use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
use context_server::ContextServerSettings; use context_server::ContextServerSettings;
use editor::EditorSettings; use editor::EditorSettings;
use gpui::TestAppContext; use gpui::TestAppContext;
@ -937,7 +938,7 @@ mod tests {
.update(|cx| { .update(|cx| {
ThreadStore::load( ThreadStore::load(
project.clone(), project.clone(),
Arc::default(), cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()), Arc::new(PromptBuilder::new(None).unwrap()),
cx, cx,
) )

View file

@ -29,7 +29,7 @@ pub struct AssistantConfiguration {
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>, configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>, expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription, _registry_subscription: Subscription,
} }
@ -37,7 +37,7 @@ impl AssistantConfiguration {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -226,7 +226,7 @@ impl AssistantConfiguration {
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement { fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone(); let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.tools_by_source(cx); let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let empty = Vec::new(); let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly."; const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";

View file

@ -84,7 +84,7 @@ pub struct NewProfileMode {
pub struct ManageProfilesModal { pub struct ManageProfilesModal {
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>, thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle, focus_handle: FocusHandle,
mode: Mode, mode: Mode,
@ -117,7 +117,7 @@ impl ManageProfilesModal {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>, thread_store: WeakEntity<ThreadStore>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,

View file

@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
impl ToolPickerDelegate { impl ToolPickerDelegate {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
tool_set: Arc<ToolWorkingSet>, tool_set: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>, thread_store: WeakEntity<ThreadStore>,
profile_id: AgentProfileId, profile_id: AgentProfileId,
profile: AgentProfile, profile: AgentProfile,
@ -68,7 +68,7 @@ impl ToolPickerDelegate {
) -> Self { ) -> Self {
let mut tool_entries = Vec::new(); let mut tool_entries = Vec::new();
for (source, tools) in tool_set.tools_by_source(cx) { for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry { tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
name: tool.name().into(), name: tool.name().into(),
source: source.clone(), source: source.clone(),
@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
if active_profile_id == &self.profile_id { if active_profile_id == &self.profile_id {
self.thread_store self.thread_store
.update(cx, |this, cx| { .update(cx, |this, cx| {
this.load_profile(&self.profile, cx); this.load_profile(self.profile.clone(), cx);
}) })
.log_err(); .log_err();
} }

View file

@ -203,7 +203,7 @@ impl AssistantPanel {
cx: AsyncWindowContext, cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> { ) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let tools = Arc::new(ToolWorkingSet::default()); let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace let thread_store = workspace
.update(cx, |workspace, cx| { .update(cx, |workspace, cx| {
let project = workspace.project().clone(); let project = workspace.project().clone();

View file

@ -86,7 +86,7 @@ impl ProfileSelector {
thread_store thread_store
.update(cx, |this, cx| { .update(cx, |this, cx| {
this.load_profile_by_id(&profile_id, cx); this.load_profile_by_id(profile_id.clone(), cx);
}) })
.log_err(); .log_err();
} }

View file

@ -254,7 +254,7 @@ pub struct Thread {
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
project: Entity<Project>, project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState, tool_use: ToolUseState,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>, last_restore_checkpoint: Option<LastRestoreCheckpoint>,
@ -278,7 +278,7 @@ pub struct ExceededWindowError {
impl Thread { impl Thread {
pub fn new( pub fn new(
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext, system_prompt: SharedProjectContext,
cx: &mut Context<Self>, cx: &mut Context<Self>,
@ -322,7 +322,7 @@ impl Thread {
id: ThreadId, id: ThreadId,
serialized: SerializedThread, serialized: SerializedThread,
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext, project_context: SharedProjectContext,
cx: &mut Context<Self>, cx: &mut Context<Self>,
@ -458,7 +458,7 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished() !self.pending_completions.is_empty() || !self.all_tools_finished()
} }
pub fn tools(&self) -> &Arc<ToolWorkingSet> { pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools &self.tools
} }
@ -846,6 +846,7 @@ impl Thread {
let mut tools = Vec::new(); let mut tools = Vec::new();
tools.extend( tools.extend(
self.tools() self.tools()
.read(cx)
.enabled_tools(cx) .enabled_tools(cx)
.into_iter() .into_iter()
.filter_map(|tool| { .filter_map(|tool| {
@ -1354,7 +1355,7 @@ impl Thread {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for tool_use in pending_tool_uses.iter() { for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) { if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
if tool.needs_confirmation(&tool_use.input, cx) if tool.needs_confirmation(&tool_use.input, cx)
&& !AssistantSettings::get_global(cx).always_allow_tool_actions && !AssistantSettings::get_global(cx).always_allow_tool_actions
{ {
@ -1406,7 +1407,7 @@ impl Thread {
) -> Task<()> { ) -> Task<()> {
let tool_name: Arc<str> = tool.name().into(); let tool_name: Arc<str> = tool.name().into();
let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) { let run_tool = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))) Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
} else { } else {
tool.run( tool.run(
@ -1521,6 +1522,7 @@ impl Thread {
let enabled_tool_names: Vec<String> = self let enabled_tool_names: Vec<String> = self
.tools() .tools()
.read(cx)
.enabled_tools(cx) .enabled_tools(cx)
.iter() .iter()
.map(|tool| tool.name().to_string()) .map(|tool| tool.name().to_string())
@ -2341,7 +2343,7 @@ fn main() {{
.update(|_, cx| { .update(|_, cx| {
ThreadStore::load( ThreadStore::load(
project.clone(), project.clone(),
Arc::default(), cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()), Arc::new(PromptBuilder::new(None).unwrap()),
cx, cx,
) )

View file

@ -56,7 +56,7 @@ impl SharedProjectContext {
pub struct ThreadStore { pub struct ThreadStore {
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>, context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
@ -74,7 +74,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore { impl ThreadStore {
pub fn load( pub fn load(
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
cx: &mut App, cx: &mut App,
) -> Task<Entity<Self>> { ) -> Task<Entity<Self>> {
@ -88,7 +88,7 @@ impl ThreadStore {
fn new( fn new(
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -248,7 +248,7 @@ impl ThreadStore {
self.context_server_manager.clone() self.context_server_manager.clone()
} }
pub fn tools(&self) -> Arc<ToolWorkingSet> { pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone() self.tools.clone()
} }
@ -355,52 +355,60 @@ impl ThreadStore {
}) })
} }
fn load_default_profile(&self, cx: &Context<Self>) { fn load_default_profile(&self, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx); let assistant_settings = AssistantSettings::get_global(cx);
self.load_profile_by_id(&assistant_settings.default_profile, cx); self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
} }
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) { pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx); let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) { if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
self.load_profile(profile, cx); self.load_profile(profile.clone(), cx);
} }
} }
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) { pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
self.tools.disable_all_tools(); self.tools.update(cx, |tools, cx| {
self.tools.enable( tools.disable_all_tools(cx);
ToolSource::Native, tools.enable(
&profile ToolSource::Native,
.tools &profile
.iter() .tools
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) .iter()
.collect::<Vec<_>>(), .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
); .collect::<Vec<_>>(),
cx,
);
});
if profile.enable_all_context_servers { if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() { for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.enable_source( self.tools.update(cx, |tools, cx| {
ToolSource::ContextServer { tools.enable_source(
id: context_server.id().into(), ToolSource::ContextServer {
}, id: context_server.id().into(),
cx, },
); cx,
);
});
} }
} else { } else {
for (context_server_id, preset) in &profile.context_servers { for (context_server_id, preset) in &profile.context_servers {
self.tools.enable( self.tools.update(cx, |tools, cx| {
ToolSource::ContextServer { tools.enable(
id: context_server_id.clone().into(), ToolSource::ContextServer {
}, id: context_server_id.clone().into(),
&preset },
.tools &preset
.iter() .tools
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) .iter()
.collect::<Vec<_>>(), .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
) .collect::<Vec<_>>(),
cx,
)
})
} }
} }
} }
@ -434,29 +442,36 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) { if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() { if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools let tool_ids = tool_working_set
.tools .update(cx, |tool_working_set, _| {
.into_iter() tools
.map(|tool| { .tools
log::info!( .into_iter()
"registering context server tool: {:?}", .map(|tool| {
tool.name log::info!(
); "registering context server tool: {:?}",
tool_working_set.insert(Arc::new( tool.name
ContextServerTool::new( );
context_server_manager.clone(), tool_working_set.insert(Arc::new(
server.id(), ContextServerTool::new(
tool, context_server_manager.clone(),
), server.id(),
)) tool,
),
))
})
.collect::<Vec<_>>()
}) })
.collect::<Vec<_>>(); .log_err();
this.update(cx, |this, cx| { if let Some(tool_ids) = tool_ids {
this.context_server_tool_ids.insert(server_id, tool_ids); this.update(cx, |this, cx| {
this.load_default_profile(cx); this.context_server_tool_ids
}) .insert(server_id, tool_ids);
.log_err(); this.load_default_profile(cx);
})
.log_err();
}
} }
} }
} }
@ -466,7 +481,9 @@ impl ThreadStore {
} }
context_server::manager::Event::ServerStopped { server_id } => { context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids); tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx); self.load_default_profile(cx);
} }
} }

View file

@ -5,7 +5,7 @@ use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap; use collections::HashMap;
use futures::FutureExt as _; use futures::FutureExt as _;
use futures::future::Shared; use futures::future::Shared;
use gpui::{App, SharedString, Task}; use gpui::{App, Entity, SharedString, Task};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
@ -49,7 +49,7 @@ impl ToolUseStatus {
} }
pub struct ToolUseState { pub struct ToolUseState {
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>, tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>, tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>, tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@ -59,7 +59,7 @@ pub struct ToolUseState {
pub const USING_TOOL_MARKER: &str = "<using_tool>"; pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState { impl ToolUseState {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self { pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self { Self {
tools, tools,
tool_uses_by_assistant_message: HashMap::default(), tool_uses_by_assistant_message: HashMap::default(),
@ -73,7 +73,7 @@ impl ToolUseState {
/// ///
/// Accepts a function to filter the tools that should be used to populate the state. /// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages( pub fn from_serialized_messages(
tools: Arc<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage], messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool, mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self { ) -> Self {
@ -199,12 +199,12 @@ impl ToolUseState {
} }
})(); })();
let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx) let (icon, needs_confirmation) =
{ if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else { } else {
(IconName::Cog, false) (IconName::Cog, false)
}; };
tool_uses.push(ToolUse { tool_uses.push(ToolUse {
id: tool_use.id.clone(), id: tool_use.id.clone(),
@ -226,7 +226,7 @@ impl ToolUseState {
input: &serde_json::Value, input: &serde_json::Value,
cx: &App, cx: &App,
) -> SharedString { ) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) { if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
tool.ui_text(input).into() tool.ui_text(input).into()
} else { } else {
format!("Unknown tool {tool_name:?}").into() format!("Unknown tool {tool_name:?}").into()

View file

@ -1,8 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use collections::{HashMap, HashSet, IndexMap}; use collections::{HashMap, HashSet, IndexMap};
use gpui::App; use gpui::{App, Context, EventEmitter};
use parking_lot::Mutex;
use crate::{Tool, ToolRegistry, ToolSource}; use crate::{Tool, ToolRegistry, ToolSource};
@ -12,11 +11,6 @@ pub struct ToolId(usize);
/// A working set of tools for use in one instance of the Assistant Panel. /// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)] #[derive(Default)]
pub struct ToolWorkingSet { pub struct ToolWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>, context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>, context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>, enabled_sources: HashSet<ToolSource>,
@ -24,99 +18,27 @@ struct WorkingSetState {
next_tool_id: ToolId, next_tool_id: ToolId,
} }
pub enum ToolWorkingSetEvent {
EnabledToolsChanged,
}
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
impl ToolWorkingSet { impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> { pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.state self.context_server_tools_by_name
.lock()
.context_server_tools_by_name
.get(name) .get(name)
.cloned() .cloned()
.or_else(|| ToolRegistry::global(cx).tool(name)) .or_else(|| ToolRegistry::global(cx).tool(name))
} }
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> { pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().tools(cx)
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
self.state.lock().tools_by_source(cx)
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
pub fn disable_all_tools(&self) {
let mut state = self.state.lock();
state.disable_all_tools();
}
pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
state.enable_source(source, cx);
}
pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock();
state.disable_source(source);
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(tool_id, tool.clone());
state.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_enabled(source, name)
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_disabled(source, name)
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state.enable(source, tools_to_enable);
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state.disable(source, tools_to_disable);
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools(); let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned()); tools.extend(self.context_server_tools_by_id.values().cloned());
tools tools
} }
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> { pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default(); let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) { for tool in self.tools(cx) {
@ -135,7 +57,7 @@ impl WorkingSetState {
tools_by_source tools_by_source
} }
fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> { pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let all_tools = self.tools(cx); let all_tools = self.tools(cx);
all_tools all_tools
@ -144,31 +66,12 @@ impl WorkingSetState {
.collect() .collect()
} }
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool { pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
self.enabled_tools_by_source self.enabled_tools_by_source.clear();
.get(source) cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
.map_or(false, |enabled_tools| enabled_tools.contains(name))
} }
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool { pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
!self.is_enabled(source, name)
}
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
}
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
self.enabled_sources.insert(source.clone()); self.enabled_sources.insert(source.clone());
let tools_by_source = self.tools_by_source(cx); let tools_by_source = self.tools_by_source(cx);
@ -181,14 +84,72 @@ impl WorkingSetState {
.collect::<HashSet<_>>(), .collect::<HashSet<_>>(),
); );
} }
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
} }
fn disable_source(&mut self, source: &ToolSource) { pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.remove(source); self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source); self.enabled_tools_by_source.remove(source);
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
} }
fn disable_all_tools(&mut self) { pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
self.enabled_tools_by_source.clear(); let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
pub fn enable(
&mut self,
source: ToolSource,
tools_to_enable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn disable(
&mut self,
source: ToolSource,
tools_to_disable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
} }
} }

View file

@ -6,7 +6,7 @@ use collections::HashMap;
use dap::DapRegistry; use dap::DapRegistry;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::{FutureExt, StreamExt as _, select_biased}; use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AsyncApp, Entity, Task}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
use handlebars::Handlebars; use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt}; use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{ use language_model::{
@ -181,7 +181,7 @@ impl Example {
project.create_worktree(&worktree_path, true, cx) project.create_worktree(&worktree_path, true, cx)
}); });
let tools = Arc::new(ToolWorkingSet::default()); let tools = cx.new(|_| ToolWorkingSet::default());
let thread_store = let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let this = self.clone(); let this = self.clone();