Kyle Kelley 2024-05-02 13:26:46 -07:00 committed by GitHub
parent 43ad470e58
commit 1915a756a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 209 additions and 138 deletions

View file

@ -1,48 +1,86 @@
use anyhow::{anyhow, Result};
use gpui::{AnyView, Task, WindowContext};
use std::collections::HashMap;
use gpui::{Task, WindowContext};
use std::{
any::TypeId,
collections::HashMap,
sync::atomic::{AtomicBool, Ordering::SeqCst},
};
use crate::tool::{
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
// Internal Tool representation for the registry
pub struct Tool {
enabled: AtomicBool,
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
definition: ToolFunctionDefinition,
}
impl Tool {
fn new(
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
definition: ToolFunctionDefinition,
) -> Self {
Self {
enabled: AtomicBool::new(true),
type_id,
call,
definition,
}
}
}
pub struct ToolRegistry {
tools: HashMap<
String,
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
>,
definitions: Vec<ToolFunctionDefinition>,
status_views: Vec<AnyView>,
tools: HashMap<String, Tool>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
definitions: Vec::new(),
status_views: Vec::new(),
}
}
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
&self.definitions
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
for tool in self.tools.values() {
if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst);
return;
}
}
}
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
for tool in self.tools.values() {
if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst);
}
}
false
}
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
self.tools
.values()
.filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone())
.collect()
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
cx: &mut WindowContext,
_cx: &mut WindowContext,
) -> Result<()> {
self.definitions.push(tool.definition());
if let Some(tool_view) = tool.status_view(cx) {
self.status_views.push(tool_view);
}
let definition = tool.definition();
let name = tool.name();
let previous = self.tools.insert(
name.clone(),
// registry.call(tool_call, cx)
let registered_tool = Tool::new(
TypeId::of::<T>(),
Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let name = tool_call.name.clone();
@ -77,8 +115,11 @@ impl ToolRegistry {
})
},
),
definition,
);
let previous = self.tools.insert(name.clone(), registered_tool);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
@ -109,11 +150,7 @@ impl ToolRegistry {
}
};
tool(tool_call, cx)
}
pub fn status_views(&self) -> &[AnyView] {
&self.status_views
(tool.call)(tool_call, cx)
}
}

View file

@ -104,8 +104,4 @@ pub trait LanguageModelTool {
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
None
}
}