From 689e4aef2f18b3ad255f64d56257ce0eb47d14b1 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Wed, 8 May 2024 10:24:51 -0700 Subject: [PATCH] Render messages as early as possible to show progress (#11569) This shows "Researching..." as placeholder text as early as possible so that the user can see the model is working on reading/researching/etc. This also adds on an `Option` to the `render_running` function so that tools can hopefully render based on partially completed JSON (still to come). Release Notes: - N/A --- Cargo.lock | 1 + crates/assistant2/src/assistant2.rs | 7 +++- crates/assistant2/src/tools/project_index.rs | 13 +++++-- crates/assistant_tooling/Cargo.toml | 1 + .../src/assistant_tooling.rs | 3 +- crates/assistant_tooling/src/tool_registry.rs | 36 +++++++++++++------ 6 files changed, 46 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 55a3763ad6..1663793bc9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -423,6 +423,7 @@ dependencies = [ "serde_json", "settings", "sum_tree", + "ui", "unindent", "util", ] diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 66be8f4da6..79ae7c9c43 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -16,7 +16,8 @@ use crate::{ use ::ui::{div, prelude::*, Color, Tooltip, ViewContext}; use anyhow::{Context, Result}; use assistant_tooling::{ - AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment, + tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, + UserAttachment, }; use client::{proto, Client, UserStore}; use collections::HashMap; @@ -864,6 +865,10 @@ impl AssistantChat { } } + if message_elements.is_empty() { + message_elements.push(tool_running_placeholder()); + } + div() .when(is_first, |this| this.pt(padding)) .child( diff --git a/crates/assistant2/src/tools/project_index.rs b/crates/assistant2/src/tools/project_index.rs index c67c9216c1..32a56fe9ed 100644 --- a/crates/assistant2/src/tools/project_index.rs +++ b/crates/assistant2/src/tools/project_index.rs @@ -6,6 +6,7 @@ use project::ProjectPath; use schemars::JsonSchema; use semantic_index::{ProjectIndex, Status}; use serde::Deserialize; +use serde_json::Value; use std::{fmt::Write as _, ops::Range}; use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext}; @@ -202,8 +203,14 @@ impl LanguageModelTool for ProjectIndexTool { cx.new_view(|_cx| ProjectIndexView::new(input, output)) } - fn render_running(_: &mut WindowContext) -> impl IntoElement { - CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false) - .start_slot("Searching code base") + fn render_running(arguments: &Option, _: &mut WindowContext) -> impl IntoElement { + let text: String = arguments + .as_ref() + .and_then(|arguments| arguments.get("query")) + .and_then(|query| query.as_str()) + .map(|query| format!("Searching for: {}", query)) + .unwrap_or_else(|| "Preparing search...".to_string()); + + CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false).start_slot(text) } } diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml index a69d1729d3..c7290f9c98 100644 --- a/crates/assistant_tooling/Cargo.toml +++ b/crates/assistant_tooling/Cargo.toml @@ -21,6 +21,7 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true sum_tree.workspace = true +ui.workspace = true util.workspace = true [dev-dependencies] diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs index 6e5903c1f4..39dabf0830 100644 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ b/crates/assistant_tooling/src/assistant_tooling.rs @@ -5,5 +5,6 @@ mod tool_registry; pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment}; pub use project_context::ProjectContext; pub use tool_registry::{ - LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry, + tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, + ToolOutput, ToolRegistry, }; diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs index 5e1da303f9..d32f756e5f 100644 --- a/crates/assistant_tooling/src/tool_registry.rs +++ b/crates/assistant_tooling/src/tool_registry.rs @@ -4,6 +4,7 @@ use gpui::{ }; use schemars::{schema::RootSchema, schema_for, JsonSchema}; use serde::Deserialize; +use serde_json::Value; use std::{ any::TypeId, collections::HashMap, @@ -78,17 +79,22 @@ pub trait LanguageModelTool { /// Executes the tool with the given input. fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task>; + /// A view of the output of running the tool, for displaying to the user. fn output_view( input: Self::Input, output: Result, cx: &mut WindowContext, ) -> View; - fn render_running(_cx: &mut WindowContext) -> impl IntoElement { - div() + fn render_running(_arguments: &Option, _cx: &mut WindowContext) -> impl IntoElement { + tool_running_placeholder() } } +pub fn tool_running_placeholder() -> AnyElement { + ui::Label::new("Researching...").into_any_element() +} + pub trait ToolOutput: Sized { fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; } @@ -97,7 +103,7 @@ struct RegisteredTool { enabled: AtomicBool, type_id: TypeId, call: Box Task>>, - render_running: fn(&mut WindowContext) -> gpui::AnyElement, + render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement, definition: ToolFunctionDefinition, } @@ -144,11 +150,15 @@ impl ToolRegistry { .p_2() .child(result.into_any_element(&tool_call.name)) .into_any_element(), - None => self - .registered_tools - .get(&tool_call.name) - .map(|tool| (tool.render_running)(cx)) - .unwrap_or_else(|| div().into_any_element()), + None => { + let tool = self.registered_tools.get(&tool_call.name); + + if let Some(tool) = tool { + (tool.render_running)(&tool_call, cx) + } else { + tool_running_placeholder() + } + } } } @@ -205,8 +215,14 @@ impl ToolRegistry { return Ok(()); - fn render_running(cx: &mut WindowContext) -> AnyElement { - T::render_running(cx).into_any_element() + fn render_running( + tool_call: &ToolFunctionCall, + cx: &mut WindowContext, + ) -> AnyElement { + // Attempt to parse the string arguments that are JSON as a JSON value + let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok(); + + T::render_running(&maybe_arguments, cx).into_any_element() } fn generate(