Allow codebase search to be turned on or off within the composer for assistant2 (#11315)
   Release Notes: - N/A
This commit is contained in:
parent
43ad470e58
commit
1915a756a0
8 changed files with 209 additions and 138 deletions
|
@ -23,7 +23,7 @@ use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use settings::Settings;
|
use settings::Settings;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use ui::Composer;
|
use ui::{Composer, ProjectIndexButton};
|
||||||
use util::{paths::EMBEDDINGS_DIR, ResultExt};
|
use util::{paths::EMBEDDINGS_DIR, ResultExt};
|
||||||
use workspace::{
|
use workspace::{
|
||||||
dock::{DockPosition, Panel, PanelEvent},
|
dock::{DockPosition, Panel, PanelEvent},
|
||||||
|
@ -228,6 +228,7 @@ pub struct AssistantChat {
|
||||||
list_state: ListState,
|
list_state: ListState,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
composer_editor: View<Editor>,
|
composer_editor: View<Editor>,
|
||||||
|
project_index_button: Option<View<ProjectIndexButton>>,
|
||||||
user_store: Model<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
next_message_id: MessageId,
|
next_message_id: MessageId,
|
||||||
collapsed_messages: HashMap<MessageId, bool>,
|
collapsed_messages: HashMap<MessageId, bool>,
|
||||||
|
@ -263,6 +264,10 @@ impl AssistantChat {
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let project_index_button = project_index.clone().map(|project_index| {
|
||||||
|
cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
|
||||||
|
});
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
|
@ -275,6 +280,7 @@ impl AssistantChat {
|
||||||
list_state,
|
list_state,
|
||||||
user_store,
|
user_store,
|
||||||
language_registry,
|
language_registry,
|
||||||
|
project_index_button,
|
||||||
project_index,
|
project_index,
|
||||||
next_message_id: MessageId(0),
|
next_message_id: MessageId(0),
|
||||||
editing_message: None,
|
editing_message: None,
|
||||||
|
@ -397,7 +403,7 @@ impl AssistantChat {
|
||||||
{
|
{
|
||||||
this.tool_registry.definitions()
|
this.tool_registry.definitions()
|
||||||
} else {
|
} else {
|
||||||
&[]
|
Vec::new()
|
||||||
};
|
};
|
||||||
call_count += 1;
|
call_count += 1;
|
||||||
|
|
||||||
|
@ -590,7 +596,7 @@ impl AssistantChat {
|
||||||
element.child(Composer::new(
|
element.child(Composer::new(
|
||||||
body.clone(),
|
body.clone(),
|
||||||
self.user_store.read(cx).current_user(),
|
self.user_store.read(cx).current_user(),
|
||||||
self.tool_registry.clone(),
|
self.project_index_button.clone(),
|
||||||
crate::ui::ModelSelector::new(
|
crate::ui::ModelSelector::new(
|
||||||
cx.view().downgrade(),
|
cx.view().downgrade(),
|
||||||
self.model.clone(),
|
self.model.clone(),
|
||||||
|
@ -768,7 +774,7 @@ impl Render for AssistantChat {
|
||||||
.child(Composer::new(
|
.child(Composer::new(
|
||||||
self.composer_editor.clone(),
|
self.composer_editor.clone(),
|
||||||
self.user_store.read(cx).current_user(),
|
self.user_store.read(cx).current_user(),
|
||||||
self.tool_registry.clone(),
|
self.project_index_button.clone(),
|
||||||
crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone())
|
crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone())
|
||||||
.into_any_element(),
|
.into_any_element(),
|
||||||
))
|
))
|
||||||
|
|
|
@ -33,7 +33,7 @@ impl CompletionProvider {
|
||||||
messages: Vec<CompletionMessage>,
|
messages: Vec<CompletionMessage>,
|
||||||
stop: Vec<String>,
|
stop: Vec<String>,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
tools: &[ToolFunctionDefinition],
|
tools: Vec<ToolFunctionDefinition>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
|
||||||
{
|
{
|
||||||
self.0.complete(model, messages, stop, temperature, tools)
|
self.0.complete(model, messages, stop, temperature, tools)
|
||||||
|
@ -51,7 +51,7 @@ pub trait CompletionProviderBackend: 'static {
|
||||||
messages: Vec<CompletionMessage>,
|
messages: Vec<CompletionMessage>,
|
||||||
stop: Vec<String>,
|
stop: Vec<String>,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
tools: &[ToolFunctionDefinition],
|
tools: Vec<ToolFunctionDefinition>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ impl CompletionProviderBackend for CloudCompletionProvider {
|
||||||
messages: Vec<CompletionMessage>,
|
messages: Vec<CompletionMessage>,
|
||||||
stop: Vec<String>,
|
stop: Vec<String>,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
tools: &[ToolFunctionDefinition],
|
tools: Vec<ToolFunctionDefinition>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
|
||||||
{
|
{
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use assistant_tooling::LanguageModelTool;
|
use assistant_tooling::{
|
||||||
use gpui::{percentage, prelude::*, Animation, AnimationExt, AnyView, Model, Task, Transformation};
|
// assistant_tool_button::{AssistantToolButton, ToolStatus},
|
||||||
|
LanguageModelTool,
|
||||||
|
};
|
||||||
|
use gpui::{prelude::*, Model, Task};
|
||||||
use project::Fs;
|
use project::Fs;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use semantic_index::{ProjectIndex, Status};
|
use semantic_index::{ProjectIndex, Status};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::sync::Arc;
|
||||||
use ui::{
|
use ui::{
|
||||||
div, prelude::*, ButtonLike, CollapsibleContainer, Color, Icon, IconName, Indicator, Label,
|
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
|
||||||
SharedString, Tooltip, WindowContext,
|
WindowContext,
|
||||||
};
|
};
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
|
|
||||||
|
@ -199,13 +202,6 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||||
cx.new_view(|_cx| ProjectIndexView { input, output })
|
cx.new_view(|_cx| ProjectIndexView { input, output })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
|
|
||||||
Some(
|
|
||||||
cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
|
|
||||||
.into(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
|
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
|
||||||
match &output {
|
match &output {
|
||||||
Ok(output) => {
|
Ok(output) => {
|
||||||
|
@ -236,82 +232,3 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ProjectIndexStatusView {
|
|
||||||
project_index: Model<ProjectIndex>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProjectIndexStatusView {
|
|
||||||
pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
|
|
||||||
cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
|
|
||||||
cx.notify();
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
Self { project_index }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Render for ProjectIndexStatusView {
|
|
||||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
|
||||||
let status = self.project_index.read(cx).status();
|
|
||||||
|
|
||||||
let is_enabled = match status {
|
|
||||||
Status::Idle => true,
|
|
||||||
_ => false,
|
|
||||||
};
|
|
||||||
|
|
||||||
let icon = match status {
|
|
||||||
Status::Idle => Icon::new(IconName::Code)
|
|
||||||
.size(IconSize::XSmall)
|
|
||||||
.color(Color::Default),
|
|
||||||
Status::Loading => Icon::new(IconName::Code)
|
|
||||||
.size(IconSize::XSmall)
|
|
||||||
.color(Color::Muted),
|
|
||||||
Status::Scanning { .. } => Icon::new(IconName::Code)
|
|
||||||
.size(IconSize::XSmall)
|
|
||||||
.color(Color::Muted),
|
|
||||||
};
|
|
||||||
|
|
||||||
let indicator = match status {
|
|
||||||
Status::Idle => Some(Indicator::dot().color(Color::Success)),
|
|
||||||
Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
|
|
||||||
Status::Loading => Some(Indicator::icon(
|
|
||||||
Icon::new(IconName::Spinner)
|
|
||||||
.color(Color::Accent)
|
|
||||||
.with_animation(
|
|
||||||
"arrow-circle",
|
|
||||||
Animation::new(Duration::from_secs(2)).repeat(),
|
|
||||||
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
|
||||||
),
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
|
|
||||||
ButtonLike::new("project-index")
|
|
||||||
.disabled(!is_enabled)
|
|
||||||
.child(
|
|
||||||
ui::IconWithIndicator::new(icon, indicator)
|
|
||||||
.indicator_border_color(Some(gpui::transparent_black())),
|
|
||||||
)
|
|
||||||
.tooltip({
|
|
||||||
move |cx| {
|
|
||||||
let (tooltip, meta) = match status {
|
|
||||||
Status::Idle => (
|
|
||||||
"Project index ready".to_string(),
|
|
||||||
Some("Click to disable".to_string()),
|
|
||||||
),
|
|
||||||
Status::Loading => ("Project index loading...".to_string(), None),
|
|
||||||
Status::Scanning { remaining_count } => (
|
|
||||||
"Project index scanning...".to_string(),
|
|
||||||
Some(format!("{} remaining...", remaining_count)),
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(meta) = meta {
|
|
||||||
Tooltip::with_meta(tooltip, None, meta, cx)
|
|
||||||
} else {
|
|
||||||
Tooltip::text(tooltip, cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
mod chat_message;
|
mod chat_message;
|
||||||
mod chat_notice;
|
mod chat_notice;
|
||||||
mod composer;
|
mod composer;
|
||||||
|
mod project_index_button;
|
||||||
|
|
||||||
#[cfg(feature = "stories")]
|
#[cfg(feature = "stories")]
|
||||||
mod stories;
|
mod stories;
|
||||||
|
@ -8,6 +9,7 @@ mod stories;
|
||||||
pub use chat_message::*;
|
pub use chat_message::*;
|
||||||
pub use chat_notice::*;
|
pub use chat_notice::*;
|
||||||
pub use composer::*;
|
pub use composer::*;
|
||||||
|
pub use project_index_button::*;
|
||||||
|
|
||||||
#[cfg(feature = "stories")]
|
#[cfg(feature = "stories")]
|
||||||
pub use stories::*;
|
pub use stories::*;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use assistant_tooling::ToolRegistry;
|
use crate::{ui::ProjectIndexButton, AssistantChat, CompletionProvider};
|
||||||
use client::User;
|
use client::User;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
|
use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
|
||||||
|
@ -7,13 +7,11 @@ use std::sync::Arc;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip};
|
use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip};
|
||||||
|
|
||||||
use crate::{AssistantChat, CompletionProvider};
|
|
||||||
|
|
||||||
#[derive(IntoElement)]
|
#[derive(IntoElement)]
|
||||||
pub struct Composer {
|
pub struct Composer {
|
||||||
editor: View<Editor>,
|
editor: View<Editor>,
|
||||||
player: Option<Arc<User>>,
|
player: Option<Arc<User>>,
|
||||||
tool_registry: Arc<ToolRegistry>,
|
project_index_button: Option<View<ProjectIndexButton>>,
|
||||||
model_selector: AnyElement,
|
model_selector: AnyElement,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,20 +19,28 @@ impl Composer {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
editor: View<Editor>,
|
editor: View<Editor>,
|
||||||
player: Option<Arc<User>>,
|
player: Option<Arc<User>>,
|
||||||
tool_registry: Arc<ToolRegistry>,
|
project_index_button: Option<View<ProjectIndexButton>>,
|
||||||
model_selector: AnyElement,
|
model_selector: AnyElement,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
editor,
|
editor,
|
||||||
player,
|
player,
|
||||||
tool_registry,
|
project_index_button,
|
||||||
model_selector,
|
model_selector,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
|
||||||
|
h_flex().children(
|
||||||
|
self.project_index_button
|
||||||
|
.clone()
|
||||||
|
.map(|view| view.into_any_element()),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RenderOnce for Composer {
|
impl RenderOnce for Composer {
|
||||||
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
|
fn render(mut self, cx: &mut WindowContext) -> impl IntoElement {
|
||||||
let mut player_avatar = div().size(rems_from_px(20.)).into_any_element();
|
let mut player_avatar = div().size(rems_from_px(20.)).into_any_element();
|
||||||
if let Some(player) = self.player.clone() {
|
if let Some(player) = self.player.clone() {
|
||||||
player_avatar = Avatar::new(player.avatar_uri.clone())
|
player_avatar = Avatar::new(player.avatar_uri.clone())
|
||||||
|
@ -95,9 +101,7 @@ impl RenderOnce for Composer {
|
||||||
.gap_2()
|
.gap_2()
|
||||||
.justify_between()
|
.justify_between()
|
||||||
.w_full()
|
.w_full()
|
||||||
.child(h_flex().gap_1().children(
|
.child(h_flex().gap_1().child(self.render_tools(cx)))
|
||||||
self.tool_registry.status_views().iter().cloned(),
|
|
||||||
))
|
|
||||||
.child(h_flex().gap_1().child(self.model_selector)),
|
.child(h_flex().gap_1().child(self.model_selector)),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|
109
crates/assistant2/src/ui/project_index_button.rs
Normal file
109
crates/assistant2/src/ui/project_index_button.rs
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
use assistant_tooling::ToolRegistry;
|
||||||
|
use gpui::{percentage, prelude::*, Animation, AnimationExt, Model, Transformation};
|
||||||
|
use semantic_index::{ProjectIndex, Status};
|
||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
use ui::{prelude::*, ButtonLike, Color, Icon, IconName, Indicator, Tooltip};
|
||||||
|
|
||||||
|
use crate::tools::ProjectIndexTool;
|
||||||
|
|
||||||
|
pub struct ProjectIndexButton {
|
||||||
|
project_index: Model<ProjectIndex>,
|
||||||
|
tool_registry: Arc<ToolRegistry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProjectIndexButton {
|
||||||
|
pub fn new(
|
||||||
|
project_index: Model<ProjectIndex>,
|
||||||
|
tool_registry: Arc<ToolRegistry>,
|
||||||
|
cx: &mut ViewContext<Self>,
|
||||||
|
) -> Self {
|
||||||
|
cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
Self {
|
||||||
|
project_index,
|
||||||
|
tool_registry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_enabled(&mut self, enabled: bool) {
|
||||||
|
self.tool_registry
|
||||||
|
.set_tool_enabled::<ProjectIndexTool>(enabled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for ProjectIndexButton {
|
||||||
|
// Expanded information on ToolView
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
let status = self.project_index.read(cx).status();
|
||||||
|
let is_enabled = self.tool_registry.is_tool_enabled::<ProjectIndexTool>();
|
||||||
|
|
||||||
|
let icon = if is_enabled {
|
||||||
|
match status {
|
||||||
|
Status::Idle => Icon::new(IconName::Code)
|
||||||
|
.size(IconSize::XSmall)
|
||||||
|
.color(Color::Default),
|
||||||
|
Status::Loading => Icon::new(IconName::Code)
|
||||||
|
.size(IconSize::XSmall)
|
||||||
|
.color(Color::Muted),
|
||||||
|
Status::Scanning { .. } => Icon::new(IconName::Code)
|
||||||
|
.size(IconSize::XSmall)
|
||||||
|
.color(Color::Muted),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Icon::new(IconName::Code)
|
||||||
|
.size(IconSize::XSmall)
|
||||||
|
.color(Color::Disabled)
|
||||||
|
};
|
||||||
|
|
||||||
|
let indicator = if is_enabled {
|
||||||
|
match status {
|
||||||
|
Status::Idle => Some(Indicator::dot().color(Color::Success)),
|
||||||
|
Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)),
|
||||||
|
Status::Loading => Some(Indicator::icon(
|
||||||
|
Icon::new(IconName::Spinner)
|
||||||
|
.color(Color::Accent)
|
||||||
|
.with_animation(
|
||||||
|
"arrow-circle",
|
||||||
|
Animation::new(Duration::from_secs(2)).repeat(),
|
||||||
|
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
ButtonLike::new("project-index")
|
||||||
|
.child(
|
||||||
|
ui::IconWithIndicator::new(icon, indicator)
|
||||||
|
.indicator_border_color(Some(gpui::transparent_black())),
|
||||||
|
)
|
||||||
|
.tooltip({
|
||||||
|
move |cx| {
|
||||||
|
let (tooltip, meta) = match status {
|
||||||
|
Status::Idle => (
|
||||||
|
"Project index ready".to_string(),
|
||||||
|
Some("Click to disable".to_string()),
|
||||||
|
),
|
||||||
|
Status::Loading => ("Project index loading...".to_string(), None),
|
||||||
|
Status::Scanning { remaining_count } => (
|
||||||
|
"Project index scanning...".to_string(),
|
||||||
|
Some(format!("{} remaining...", remaining_count)),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(meta) = meta {
|
||||||
|
Tooltip::with_meta(tooltip, None, meta, cx)
|
||||||
|
} else {
|
||||||
|
Tooltip::text(tooltip, cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.on_click(cx.listener(move |this, _, cx| {
|
||||||
|
this.set_enabled(!is_enabled);
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,48 +1,86 @@
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use gpui::{AnyView, Task, WindowContext};
|
use gpui::{Task, WindowContext};
|
||||||
use std::collections::HashMap;
|
use std::{
|
||||||
|
any::TypeId,
|
||||||
|
collections::HashMap,
|
||||||
|
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
|
};
|
||||||
|
|
||||||
use crate::tool::{
|
use crate::tool::{
|
||||||
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Internal Tool representation for the registry
|
||||||
|
pub struct Tool {
|
||||||
|
enabled: AtomicBool,
|
||||||
|
type_id: TypeId,
|
||||||
|
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||||
|
definition: ToolFunctionDefinition,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tool {
|
||||||
|
fn new(
|
||||||
|
type_id: TypeId,
|
||||||
|
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||||
|
definition: ToolFunctionDefinition,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
enabled: AtomicBool::new(true),
|
||||||
|
type_id,
|
||||||
|
call,
|
||||||
|
definition,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ToolRegistry {
|
pub struct ToolRegistry {
|
||||||
tools: HashMap<
|
tools: HashMap<String, Tool>,
|
||||||
String,
|
|
||||||
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
|
||||||
>,
|
|
||||||
definitions: Vec<ToolFunctionDefinition>,
|
|
||||||
status_views: Vec<AnyView>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolRegistry {
|
impl ToolRegistry {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
tools: HashMap::new(),
|
tools: HashMap::new(),
|
||||||
definitions: Vec::new(),
|
|
||||||
status_views: Vec::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
|
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
|
||||||
&self.definitions
|
for tool in self.tools.values() {
|
||||||
|
if tool.type_id == TypeId::of::<T>() {
|
||||||
|
tool.enabled.store(is_enabled, SeqCst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
|
||||||
|
for tool in self.tools.values() {
|
||||||
|
if tool.type_id == TypeId::of::<T>() {
|
||||||
|
return tool.enabled.load(SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
|
||||||
|
self.tools
|
||||||
|
.values()
|
||||||
|
.filter(|tool| tool.enabled.load(SeqCst))
|
||||||
|
.map(|tool| tool.definition.clone())
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register<T: 'static + LanguageModelTool>(
|
pub fn register<T: 'static + LanguageModelTool>(
|
||||||
&mut self,
|
&mut self,
|
||||||
tool: T,
|
tool: T,
|
||||||
cx: &mut WindowContext,
|
_cx: &mut WindowContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
self.definitions.push(tool.definition());
|
let definition = tool.definition();
|
||||||
|
|
||||||
if let Some(tool_view) = tool.status_view(cx) {
|
|
||||||
self.status_views.push(tool_view);
|
|
||||||
}
|
|
||||||
|
|
||||||
let name = tool.name();
|
let name = tool.name();
|
||||||
let previous = self.tools.insert(
|
|
||||||
name.clone(),
|
let registered_tool = Tool::new(
|
||||||
// registry.call(tool_call, cx)
|
TypeId::of::<T>(),
|
||||||
Box::new(
|
Box::new(
|
||||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||||
let name = tool_call.name.clone();
|
let name = tool_call.name.clone();
|
||||||
|
@ -77,8 +115,11 @@ impl ToolRegistry {
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
definition,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let previous = self.tools.insert(name.clone(), registered_tool);
|
||||||
|
|
||||||
if previous.is_some() {
|
if previous.is_some() {
|
||||||
return Err(anyhow!("already registered a tool with name {}", name));
|
return Err(anyhow!("already registered a tool with name {}", name));
|
||||||
}
|
}
|
||||||
|
@ -109,11 +150,7 @@ impl ToolRegistry {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
tool(tool_call, cx)
|
(tool.call)(tool_call, cx)
|
||||||
}
|
|
||||||
|
|
||||||
pub fn status_views(&self) -> &[AnyView] {
|
|
||||||
&self.status_views
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -104,8 +104,4 @@ pub trait LanguageModelTool {
|
||||||
output: Result<Self::Output>,
|
output: Result<Self::Output>,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> View<Self::View>;
|
) -> View<Self::View>;
|
||||||
|
|
||||||
fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue