diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 74ab6d3049..cee26a0830 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -23,7 +23,7 @@ use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex}; use serde::Deserialize; use settings::Settings; use std::sync::Arc; -use ui::Composer; +use ui::{Composer, ProjectIndexButton}; use util::{paths::EMBEDDINGS_DIR, ResultExt}; use workspace::{ dock::{DockPosition, Panel, PanelEvent}, @@ -228,6 +228,7 @@ pub struct AssistantChat { list_state: ListState, language_registry: Arc, composer_editor: View, + project_index_button: Option>, user_store: Model, next_message_id: MessageId, collapsed_messages: HashMap, @@ -263,6 +264,10 @@ impl AssistantChat { }, ); + let project_index_button = project_index.clone().map(|project_index| { + cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx)) + }); + Self { model, messages: Vec::new(), @@ -275,6 +280,7 @@ impl AssistantChat { list_state, user_store, language_registry, + project_index_button, project_index, next_message_id: MessageId(0), editing_message: None, @@ -397,7 +403,7 @@ impl AssistantChat { { this.tool_registry.definitions() } else { - &[] + Vec::new() }; call_count += 1; @@ -590,7 +596,7 @@ impl AssistantChat { element.child(Composer::new( body.clone(), self.user_store.read(cx).current_user(), - self.tool_registry.clone(), + self.project_index_button.clone(), crate::ui::ModelSelector::new( cx.view().downgrade(), self.model.clone(), @@ -768,7 +774,7 @@ impl Render for AssistantChat { .child(Composer::new( self.composer_editor.clone(), self.user_store.read(cx).current_user(), - self.tool_registry.clone(), + self.project_index_button.clone(), crate::ui::ModelSelector::new(cx.view().downgrade(), self.model.clone()) .into_any_element(), )) diff --git a/crates/assistant2/src/completion_provider.rs b/crates/assistant2/src/completion_provider.rs index 7351573119..07f2f75010 100644 --- a/crates/assistant2/src/completion_provider.rs +++ b/crates/assistant2/src/completion_provider.rs @@ -33,7 +33,7 @@ impl CompletionProvider { messages: Vec, stop: Vec, temperature: f32, - tools: &[ToolFunctionDefinition], + tools: Vec, ) -> BoxFuture<'static, Result>>> { self.0.complete(model, messages, stop, temperature, tools) @@ -51,7 +51,7 @@ pub trait CompletionProviderBackend: 'static { messages: Vec, stop: Vec, temperature: f32, - tools: &[ToolFunctionDefinition], + tools: Vec, ) -> BoxFuture<'static, Result>>>; } @@ -80,7 +80,7 @@ impl CompletionProviderBackend for CloudCompletionProvider { messages: Vec, stop: Vec, temperature: f32, - tools: &[ToolFunctionDefinition], + tools: Vec, ) -> BoxFuture<'static, Result>>> { let client = self.client.clone(); diff --git a/crates/assistant2/src/tools/project_index.rs b/crates/assistant2/src/tools/project_index.rs index 072ff87c1f..b6b8258964 100644 --- a/crates/assistant2/src/tools/project_index.rs +++ b/crates/assistant2/src/tools/project_index.rs @@ -1,14 +1,17 @@ use anyhow::Result; -use assistant_tooling::LanguageModelTool; -use gpui::{percentage, prelude::*, Animation, AnimationExt, AnyView, Model, Task, Transformation}; +use assistant_tooling::{ + // assistant_tool_button::{AssistantToolButton, ToolStatus}, + LanguageModelTool, +}; +use gpui::{prelude::*, Model, Task}; use project::Fs; use schemars::JsonSchema; use semantic_index::{ProjectIndex, Status}; use serde::Deserialize; -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use ui::{ - div, prelude::*, ButtonLike, CollapsibleContainer, Color, Icon, IconName, Indicator, Label, - SharedString, Tooltip, WindowContext, + div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString, + WindowContext, }; use util::ResultExt as _; @@ -199,13 +202,6 @@ impl LanguageModelTool for ProjectIndexTool { cx.new_view(|_cx| ProjectIndexView { input, output }) } - fn status_view(&self, cx: &mut WindowContext) -> Option { - Some( - cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx)) - .into(), - ) - } - fn format(_input: &Self::Input, output: &Result) -> String { match &output { Ok(output) => { @@ -236,82 +232,3 @@ impl LanguageModelTool for ProjectIndexTool { } } } - -struct ProjectIndexStatusView { - project_index: Model, -} - -impl ProjectIndexStatusView { - pub fn new(project_index: Model, cx: &mut ViewContext) -> Self { - cx.subscribe(&project_index, |_this, _, _status: &Status, cx| { - cx.notify(); - }) - .detach(); - Self { project_index } - } -} - -impl Render for ProjectIndexStatusView { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let status = self.project_index.read(cx).status(); - - let is_enabled = match status { - Status::Idle => true, - _ => false, - }; - - let icon = match status { - Status::Idle => Icon::new(IconName::Code) - .size(IconSize::XSmall) - .color(Color::Default), - Status::Loading => Icon::new(IconName::Code) - .size(IconSize::XSmall) - .color(Color::Muted), - Status::Scanning { .. } => Icon::new(IconName::Code) - .size(IconSize::XSmall) - .color(Color::Muted), - }; - - let indicator = match status { - Status::Idle => Some(Indicator::dot().color(Color::Success)), - Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)), - Status::Loading => Some(Indicator::icon( - Icon::new(IconName::Spinner) - .color(Color::Accent) - .with_animation( - "arrow-circle", - Animation::new(Duration::from_secs(2)).repeat(), - |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), - ), - )), - }; - - ButtonLike::new("project-index") - .disabled(!is_enabled) - .child( - ui::IconWithIndicator::new(icon, indicator) - .indicator_border_color(Some(gpui::transparent_black())), - ) - .tooltip({ - move |cx| { - let (tooltip, meta) = match status { - Status::Idle => ( - "Project index ready".to_string(), - Some("Click to disable".to_string()), - ), - Status::Loading => ("Project index loading...".to_string(), None), - Status::Scanning { remaining_count } => ( - "Project index scanning...".to_string(), - Some(format!("{} remaining...", remaining_count)), - ), - }; - - if let Some(meta) = meta { - Tooltip::with_meta(tooltip, None, meta, cx) - } else { - Tooltip::text(tooltip, cx) - } - } - }) - } -} diff --git a/crates/assistant2/src/ui.rs b/crates/assistant2/src/ui.rs index 25b367949e..4dffffef91 100644 --- a/crates/assistant2/src/ui.rs +++ b/crates/assistant2/src/ui.rs @@ -1,6 +1,7 @@ mod chat_message; mod chat_notice; mod composer; +mod project_index_button; #[cfg(feature = "stories")] mod stories; @@ -8,6 +9,7 @@ mod stories; pub use chat_message::*; pub use chat_notice::*; pub use composer::*; +pub use project_index_button::*; #[cfg(feature = "stories")] pub use stories::*; diff --git a/crates/assistant2/src/ui/composer.rs b/crates/assistant2/src/ui/composer.rs index 105b3a242a..b094b3e8ce 100644 --- a/crates/assistant2/src/ui/composer.rs +++ b/crates/assistant2/src/ui/composer.rs @@ -1,4 +1,4 @@ -use assistant_tooling::ToolRegistry; +use crate::{ui::ProjectIndexButton, AssistantChat, CompletionProvider}; use client::User; use editor::{Editor, EditorElement, EditorStyle}; use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace}; @@ -7,13 +7,11 @@ use std::sync::Arc; use theme::ThemeSettings; use ui::{popover_menu, prelude::*, Avatar, ButtonLike, ContextMenu, Tooltip}; -use crate::{AssistantChat, CompletionProvider}; - #[derive(IntoElement)] pub struct Composer { editor: View, player: Option>, - tool_registry: Arc, + project_index_button: Option>, model_selector: AnyElement, } @@ -21,20 +19,28 @@ impl Composer { pub fn new( editor: View, player: Option>, - tool_registry: Arc, + project_index_button: Option>, model_selector: AnyElement, ) -> Self { Self { editor, player, - tool_registry, + project_index_button, model_selector, } } + + fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement { + h_flex().children( + self.project_index_button + .clone() + .map(|view| view.into_any_element()), + ) + } } impl RenderOnce for Composer { - fn render(self, cx: &mut WindowContext) -> impl IntoElement { + fn render(mut self, cx: &mut WindowContext) -> impl IntoElement { let mut player_avatar = div().size(rems_from_px(20.)).into_any_element(); if let Some(player) = self.player.clone() { player_avatar = Avatar::new(player.avatar_uri.clone()) @@ -95,9 +101,7 @@ impl RenderOnce for Composer { .gap_2() .justify_between() .w_full() - .child(h_flex().gap_1().children( - self.tool_registry.status_views().iter().cloned(), - )) + .child(h_flex().gap_1().child(self.render_tools(cx))) .child(h_flex().gap_1().child(self.model_selector)), ), ), diff --git a/crates/assistant2/src/ui/project_index_button.rs b/crates/assistant2/src/ui/project_index_button.rs new file mode 100644 index 0000000000..a34a3639d8 --- /dev/null +++ b/crates/assistant2/src/ui/project_index_button.rs @@ -0,0 +1,109 @@ +use assistant_tooling::ToolRegistry; +use gpui::{percentage, prelude::*, Animation, AnimationExt, Model, Transformation}; +use semantic_index::{ProjectIndex, Status}; +use std::{sync::Arc, time::Duration}; +use ui::{prelude::*, ButtonLike, Color, Icon, IconName, Indicator, Tooltip}; + +use crate::tools::ProjectIndexTool; + +pub struct ProjectIndexButton { + project_index: Model, + tool_registry: Arc, +} + +impl ProjectIndexButton { + pub fn new( + project_index: Model, + tool_registry: Arc, + cx: &mut ViewContext, + ) -> Self { + cx.subscribe(&project_index, |_this, _, _status: &Status, cx| { + cx.notify(); + }) + .detach(); + Self { + project_index, + tool_registry, + } + } + + pub fn set_enabled(&mut self, enabled: bool) { + self.tool_registry + .set_tool_enabled::(enabled); + } +} + +impl Render for ProjectIndexButton { + // Expanded information on ToolView + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let status = self.project_index.read(cx).status(); + let is_enabled = self.tool_registry.is_tool_enabled::(); + + let icon = if is_enabled { + match status { + Status::Idle => Icon::new(IconName::Code) + .size(IconSize::XSmall) + .color(Color::Default), + Status::Loading => Icon::new(IconName::Code) + .size(IconSize::XSmall) + .color(Color::Muted), + Status::Scanning { .. } => Icon::new(IconName::Code) + .size(IconSize::XSmall) + .color(Color::Muted), + } + } else { + Icon::new(IconName::Code) + .size(IconSize::XSmall) + .color(Color::Disabled) + }; + + let indicator = if is_enabled { + match status { + Status::Idle => Some(Indicator::dot().color(Color::Success)), + Status::Scanning { .. } => Some(Indicator::dot().color(Color::Warning)), + Status::Loading => Some(Indicator::icon( + Icon::new(IconName::Spinner) + .color(Color::Accent) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| icon.transform(Transformation::rotate(percentage(delta))), + ), + )), + } + } else { + None + }; + + ButtonLike::new("project-index") + .child( + ui::IconWithIndicator::new(icon, indicator) + .indicator_border_color(Some(gpui::transparent_black())), + ) + .tooltip({ + move |cx| { + let (tooltip, meta) = match status { + Status::Idle => ( + "Project index ready".to_string(), + Some("Click to disable".to_string()), + ), + Status::Loading => ("Project index loading...".to_string(), None), + Status::Scanning { remaining_count } => ( + "Project index scanning...".to_string(), + Some(format!("{} remaining...", remaining_count)), + ), + }; + + if let Some(meta) = meta { + Tooltip::with_meta(tooltip, None, meta, cx) + } else { + Tooltip::text(tooltip, cx) + } + } + }) + .on_click(cx.listener(move |this, _, cx| { + this.set_enabled(!is_enabled); + cx.notify(); + })) + } +} diff --git a/crates/assistant_tooling/src/registry.rs b/crates/assistant_tooling/src/registry.rs index da9fe94b9e..b0e7dc4f2e 100644 --- a/crates/assistant_tooling/src/registry.rs +++ b/crates/assistant_tooling/src/registry.rs @@ -1,48 +1,86 @@ use anyhow::{anyhow, Result}; -use gpui::{AnyView, Task, WindowContext}; -use std::collections::HashMap; +use gpui::{Task, WindowContext}; +use std::{ + any::TypeId, + collections::HashMap, + sync::atomic::{AtomicBool, Ordering::SeqCst}, +}; use crate::tool::{ LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition, }; +// Internal Tool representation for the registry +pub struct Tool { + enabled: AtomicBool, + type_id: TypeId, + call: Box Task>>, + definition: ToolFunctionDefinition, +} + +impl Tool { + fn new( + type_id: TypeId, + call: Box Task>>, + definition: ToolFunctionDefinition, + ) -> Self { + Self { + enabled: AtomicBool::new(true), + type_id, + call, + definition, + } + } +} + pub struct ToolRegistry { - tools: HashMap< - String, - Box Task>>, - >, - definitions: Vec, - status_views: Vec, + tools: HashMap, } impl ToolRegistry { pub fn new() -> Self { Self { tools: HashMap::new(), - definitions: Vec::new(), - status_views: Vec::new(), } } - pub fn definitions(&self) -> &[ToolFunctionDefinition] { - &self.definitions + pub fn set_tool_enabled(&self, is_enabled: bool) { + for tool in self.tools.values() { + if tool.type_id == TypeId::of::() { + tool.enabled.store(is_enabled, SeqCst); + return; + } + } + } + + pub fn is_tool_enabled(&self) -> bool { + for tool in self.tools.values() { + if tool.type_id == TypeId::of::() { + return tool.enabled.load(SeqCst); + } + } + false + } + + pub fn definitions(&self) -> Vec { + self.tools + .values() + .filter(|tool| tool.enabled.load(SeqCst)) + .map(|tool| tool.definition.clone()) + .collect() } pub fn register( &mut self, tool: T, - cx: &mut WindowContext, + _cx: &mut WindowContext, ) -> Result<()> { - self.definitions.push(tool.definition()); - - if let Some(tool_view) = tool.status_view(cx) { - self.status_views.push(tool_view); - } + let definition = tool.definition(); let name = tool.name(); - let previous = self.tools.insert( - name.clone(), - // registry.call(tool_call, cx) + + let registered_tool = Tool::new( + TypeId::of::(), Box::new( move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { let name = tool_call.name.clone(); @@ -77,8 +115,11 @@ impl ToolRegistry { }) }, ), + definition, ); + let previous = self.tools.insert(name.clone(), registered_tool); + if previous.is_some() { return Err(anyhow!("already registered a tool with name {}", name)); } @@ -109,11 +150,7 @@ impl ToolRegistry { } }; - tool(tool_call, cx) - } - - pub fn status_views(&self) -> &[AnyView] { - &self.status_views + (tool.call)(tool_call, cx) } } diff --git a/crates/assistant_tooling/src/tool.rs b/crates/assistant_tooling/src/tool.rs index 82536b9e8a..256d2eef8a 100644 --- a/crates/assistant_tooling/src/tool.rs +++ b/crates/assistant_tooling/src/tool.rs @@ -104,8 +104,4 @@ pub trait LanguageModelTool { output: Result, cx: &mut WindowContext, ) -> View; - - fn status_view(&self, _cx: &mut WindowContext) -> Option { - None - } }