assistant_tool: Reduce locking in ToolWorkingSet (#26605)

This PR updates the `ToolWorkingSet` to reduce the amount of locking we
need to do.

A number of the methods have had corresponding versions moved to the
`ToolWorkingSetState` so that we can take out the lock once and do a
number of operations without needing to continually acquire and release
the lock.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-12 17:26:26 -04:00 committed by GitHub
parent edeed7b619
commit e60e8f3a0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -46,56 +46,104 @@ impl ToolWorkingSet {
} }
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> { pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools(); self.state.lock().tools(cx)
tools.extend( }
self.state
.lock()
.context_server_tools_by_id
.values()
.cloned(),
);
tools pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
self.state.lock().tools_by_source(cx)
} }
pub fn are_all_tools_enabled(&self) -> bool { pub fn are_all_tools_enabled(&self) -> bool {
let state = self.state.lock(); let state = self.state.lock();
state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
} }
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
pub fn enable_all_tools(&self) { pub fn enable_all_tools(&self) {
let mut state = self.state.lock(); let mut state = self.state.lock();
state.disabled_tools_by_source.clear(); state.disabled_tools_by_source.clear();
state.is_scripting_tool_disabled = false; state.enable_scripting_tool();
} }
pub fn disable_all_tools(&self, cx: &App) { pub fn disable_all_tools(&self, cx: &App) {
let tools = self.tools_by_source(cx); let mut state = self.state.lock();
state.disable_all_tools(cx);
for (source, tools) in tools {
let tool_names = tools
.into_iter()
.map(|tool| tool.name().into())
.collect::<Vec<_>>();
self.disable(source, &tool_names);
} }
self.disable_scripting_tool(); pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(tool_id, tool.clone());
state.tools_changed();
tool_id
} }
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> { pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
let all_tools = self.tools(cx); self.state.lock().is_enabled(source, name)
all_tools
.into_iter()
.filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
.collect()
} }
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> { pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_disabled(source, name)
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state.enable(source, tools_to_enable);
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state.disable(source, tools_to_disable);
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
pub fn is_scripting_tool_enabled(&self) -> bool {
let state = self.state.lock();
!state.is_scripting_tool_disabled
}
pub fn enable_scripting_tool(&self) {
let mut state = self.state.lock();
state.enable_scripting_tool();
}
pub fn disable_scripting_tool(&self) {
let mut state = self.state.lock();
state.disable_scripting_tool();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default(); let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) { for tool in self.tools(cx) {
@ -114,78 +162,59 @@ impl ToolWorkingSet {
tools_by_source tools_by_source
} }
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId { fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut state = self.state.lock(); let all_tools = self.tools(cx);
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1; all_tools
state .into_iter()
.context_server_tools_by_id .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
.insert(tool_id, tool.clone()); .collect()
state.tools_changed();
tool_id
} }
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool { fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_disabled(source, name) !self.is_disabled(source, name)
} }
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool { fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
let state = self.state.lock(); self.disabled_tools_by_source
state
.disabled_tools_by_source
.get(source) .get(source)
.map_or(false, |disabled_tools| disabled_tools.contains(name)) .map_or(false, |disabled_tools| disabled_tools.contains(name))
} }
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) { fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock(); self.disabled_tools_by_source
state
.disabled_tools_by_source
.entry(source) .entry(source)
.or_default() .or_default()
.retain(|name| !tools_to_enable.contains(name)); .retain(|name| !tools_to_enable.contains(name));
} }
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) { fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock(); self.disabled_tools_by_source
state
.disabled_tools_by_source
.entry(source) .entry(source)
.or_default() .or_default()
.extend(tools_to_disable.into_iter().cloned()); .extend(tools_to_disable.into_iter().cloned());
} }
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) { fn disable_all_tools(&mut self, cx: &App) {
let mut state = self.state.lock(); let tools = self.tools_by_source(cx);
state
.context_server_tools_by_id for (source, tools) in tools {
.retain(|id, _| !tool_ids_to_remove.contains(id)); let tool_names = tools
state.tools_changed(); .into_iter()
.map(|tool| tool.name().into())
.collect::<Vec<_>>();
self.disable(source, &tool_names);
} }
pub fn is_scripting_tool_enabled(&self) -> bool { self.disable_scripting_tool();
let state = self.state.lock();
!state.is_scripting_tool_disabled
} }
pub fn enable_scripting_tool(&self) { fn enable_scripting_tool(&mut self) {
let mut state = self.state.lock(); self.is_scripting_tool_disabled = false;
state.is_scripting_tool_disabled = false;
} }
pub fn disable_scripting_tool(&self) { fn disable_scripting_tool(&mut self) {
let mut state = self.state.lock(); self.is_scripting_tool_disabled = true;
state.is_scripting_tool_disabled = true;
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
} }
} }