assistant2: Add support for using tools (#21190)

This PR adds rudimentary support for using tools to `assistant2`. There
are currently no visual affordances for tool use.

This is gated behind the `assistant-tool-use` feature flag.

<img width="1079" alt="Screenshot 2024-11-25 at 7 21 31 PM"
src="https://github.com/user-attachments/assets/64d6ca29-c592-4474-8e9d-c344f855bc63">

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-25 19:44:34 -05:00 committed by GitHub
parent 3901d46101
commit f059b6a24b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 263 additions and 37 deletions

3
Cargo.lock generated
View file

@ -455,6 +455,8 @@ name = "assistant2"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assistant_tool",
"collections",
"command_palette_hooks", "command_palette_hooks",
"editor", "editor",
"feature_flags", "feature_flags",
@ -463,6 +465,7 @@ dependencies = [
"language_model", "language_model",
"language_model_selector", "language_model_selector",
"proto", "proto",
"serde_json",
"settings", "settings",
"smol", "smol",
"theme", "theme",

View file

@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet;
use client::{self, proto, telemetry::Telemetry}; use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId; use clock::ReplicaId;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use feature_flags::{FeatureFlag, FeatureFlagAppExt}; use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use fs::{Fs, RemoveOptions}; use fs::{Fs, RemoveOptions};
use futures::{future::Shared, FutureExt, StreamExt}; use futures::{future::Shared, FutureExt, StreamExt};
use gpui::{ use gpui::{
@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus {
Error(String), Error(String),
} }
pub(crate) struct ToolUseFeatureFlag;
impl FeatureFlag for ToolUseFeatureFlag {
const NAME: &'static str = "assistant-tool-use";
fn enabled_for_staff() -> bool {
false
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PendingToolUse { pub struct PendingToolUse {
pub id: Arc<str>, pub id: Arc<str>,

View file

@ -14,6 +14,8 @@ doctest = false
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
editor.workspace = true editor.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
@ -23,6 +25,7 @@ language_model.workspace = true
language_model_selector.workspace = true language_model_selector.workspace = true
proto.workspace = true proto.workspace = true
settings.workspace = true settings.workspace = true
serde_json.workspace = true
smol.workspace = true smol.workspace = true
theme.workspace = true theme.workspace = true
ui.workspace = true ui.workspace = true

View file

@ -1,4 +1,7 @@
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use gpui::{ use gpui::{
prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext, FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace; use workspace::Workspace;
use crate::message_editor::MessageEditor; use crate::message_editor::MessageEditor;
use crate::thread::Thread; use crate::thread::{Thread, ThreadEvent};
use crate::{NewThread, ToggleFocus, ToggleModelSelector}; use crate::{NewThread, ToggleFocus, ToggleModelSelector};
pub fn init(cx: &mut AppContext) { pub fn init(cx: &mut AppContext) {
@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) {
} }
pub struct AssistantPanel { pub struct AssistantPanel {
workspace: WeakView<Workspace>,
thread: Model<Thread>, thread: Model<Thread>,
message_editor: View<MessageEditor>, message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -36,26 +41,36 @@ impl AssistantPanel {
cx: AsyncWindowContext, cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> { ) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default());
workspace.update(&mut cx, |workspace, cx| { workspace.update(&mut cx, |workspace, cx| {
cx.new_view(|cx| Self::new(workspace, cx)) cx.new_view(|cx| Self::new(workspace, tools, cx))
}) })
}) })
} }
fn new(_workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self { fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
let thread = cx.new_model(Thread::new); let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())]; let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
Self { Self {
workspace: workspace.weak_handle(),
thread: thread.clone(), thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools,
_subscriptions: subscriptions, _subscriptions: subscriptions,
} }
} }
fn new_thread(&mut self, cx: &mut ViewContext<Self>) { fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
let thread = cx.new_model(Thread::new); let tools = self.thread.read(cx).tools().clone();
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())]; let thread = cx.new_model(|cx| Thread::new(tools, cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx)); self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread; self.thread = thread;
@ -63,6 +78,38 @@ impl AssistantPanel {
self.message_editor.focus_handle(cx).focus(cx); self.message_editor.focus_handle(cx).focus(cx);
} }
fn handle_thread_event(
&mut self,
_: Model<Thread>,
event: &ThreadEvent,
cx: &mut ViewContext<Self>,
) {
match event {
ThreadEvent::StreamedCompletion => {}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}
}
ThreadEvent::ToolFinished { .. } => {}
}
}
} }
impl FocusableView for AssistantPanel { impl FocusableView for AssistantPanel {

View file

@ -1,6 +1,7 @@
use editor::{Editor, EditorElement, EditorStyle}; use editor::{Editor, EditorElement, EditorStyle};
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use gpui::{AppContext, FocusableView, Model, TextStyle, View}; use gpui::{AppContext, FocusableView, Model, TextStyle, View};
use language_model::LanguageModelRegistry; use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
use settings::Settings; use settings::Settings;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding}; use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
@ -55,7 +56,21 @@ impl MessageEditor {
self.thread.update(cx, |thread, cx| { self.thread.update(cx, |thread, cx| {
thread.insert_user_message(user_message); thread.insert_user_message(user_message);
let request = thread.to_completion_request(request_kind, cx); let mut request = thread.to_completion_request(request_kind, cx);
if cx.has_flag::<ToolUseFeatureFlag>() {
request.tools = thread
.tools()
.tools(cx)
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(),
})
.collect();
}
thread.stream_completion(request, model, cx) thread.stream_completion(request, model, cx)
}); });

View file

@ -1,12 +1,16 @@
use std::sync::Arc; use std::sync::Arc;
use futures::StreamExt as _; use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, Task}; use gpui::{AppContext, EventEmitter, ModelContext, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
}; };
use util::{post_inc, ResultExt as _}; use util::post_inc;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum RequestKind { pub enum RequestKind {
@ -14,14 +18,12 @@ pub enum RequestKind {
} }
/// A message in a [`Thread`]. /// A message in a [`Thread`].
#[derive(Debug)]
pub struct Message { pub struct Message {
pub role: Role, pub role: Role,
pub text: String, pub text: String,
} pub tool_uses: Vec<LanguageModelToolUse>,
pub tool_results: Vec<LanguageModelToolResult>,
struct PendingCompletion {
id: usize,
_task: Task<()>,
} }
/// A thread of conversation with the LLM. /// A thread of conversation with the LLM.
@ -29,14 +31,20 @@ pub struct Thread {
messages: Vec<Message>, messages: Vec<Message>,
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
completed_tool_uses_by_id: HashMap<Arc<str>, String>,
} }
impl Thread { impl Thread {
pub fn new(_cx: &mut ModelContext<Self>) -> Self { pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
Self { Self {
tools,
messages: Vec::new(), messages: Vec::new(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
pending_tool_uses_by_id: HashMap::default(),
completed_tool_uses_by_id: HashMap::default(),
} }
} }
@ -44,13 +52,33 @@ impl Thread {
self.messages.iter() self.messages.iter()
} }
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
&self.tools
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn insert_user_message(&mut self, text: impl Into<String>) { pub fn insert_user_message(&mut self, text: impl Into<String>) {
self.messages.push(Message { let mut message = Message {
role: Role::User, role: Role::User,
text: text.into(), text: text.into(),
tool_uses: Vec::new(),
tool_results: Vec::new(),
};
for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
message.tool_results.push(LanguageModelToolResult {
tool_use_id: tool_use_id.to_string(),
content: tool_output,
is_error: false,
}); });
} }
self.messages.push(message);
}
pub fn to_completion_request( pub fn to_completion_request(
&self, &self,
_request_kind: RequestKind, _request_kind: RequestKind,
@ -70,9 +98,23 @@ impl Thread {
cache: false, cache: false,
}; };
for tool_result in &message.tool_results {
request_message
.content
.push(MessageContent::ToolResult(tool_result.clone()));
}
if !message.text.is_empty() {
request_message request_message
.content .content
.push(MessageContent::Text(message.text.clone())); .push(MessageContent::Text(message.text.clone()));
}
for tool_use in &message.tool_uses {
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
}
request.messages.push(request_message); request.messages.push(request_message);
} }
@ -103,6 +145,8 @@ impl Thread {
thread.messages.push(Message { thread.messages.push(Message {
role: Role::Assistant, role: Role::Assistant,
text: String::new(), text: String::new(),
tool_uses: Vec::new(),
tool_results: Vec::new(),
}); });
} }
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
@ -115,7 +159,24 @@ impl Thread {
} }
} }
} }
LanguageModelCompletionEvent::ToolUse(_tool_use) => {} LanguageModelCompletionEvent::ToolUse(tool_use) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.tool_uses.push(tool_use.clone());
}
}
let tool_use_id: Arc<str> = tool_use.id.into();
thread.pending_tool_uses_by_id.insert(
tool_use_id.clone(),
PendingToolUse {
id: tool_use_id,
name: tool_use.name,
input: tool_use.input,
status: PendingToolUseStatus::Idle,
},
);
}
} }
cx.emit(ThreadEvent::StreamedCompletion); cx.emit(ThreadEvent::StreamedCompletion);
@ -135,7 +196,35 @@ impl Thread {
}; };
let result = stream_completion.await; let result = stream_completion.await;
let _ = result.log_err();
thread
.update(&mut cx, |_thread, cx| {
let error_message = if let Some(error) = result.as_ref().err() {
let error_message = error
.chain()
.map(|err| err.to_string())
.collect::<Vec<_>>()
.join("\n");
Some(error_message)
} else {
None
};
if let Some(error_message) = error_message {
eprintln!("Completion failed: {error_message:?}");
}
if let Ok(stop_reason) = result {
match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
}
}
})
.ok();
}); });
self.pending_completions.push(PendingCompletion { self.pending_completions.push(PendingCompletion {
@ -143,11 +232,80 @@ impl Thread {
_task: task, _task: task,
}); });
} }
pub fn insert_tool_output(
&mut self,
tool_use_id: Arc<str>,
output: Task<Result<String>>,
cx: &mut ModelContext<Self>,
) {
let insert_output_task = cx.spawn(|thread, mut cx| {
let tool_use_id = tool_use_id.clone();
async move {
let output = output.await;
thread
.update(&mut cx, |thread, cx| match output {
Ok(output) => {
thread
.completed_tool_uses_by_id
.insert(tool_use_id.clone(), output);
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
}
Err(err) => {
if let Some(tool_use) =
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
{
tool_use.status = PendingToolUseStatus::Error(err.to_string());
}
}
})
.ok();
}
});
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Running {
_task: insert_output_task.shared(),
};
}
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ThreadEvent { pub enum ThreadEvent {
StreamedCompletion, StreamedCompletion,
UsePendingTools,
ToolFinished {
#[allow(unused)]
tool_use_id: Arc<str>,
},
} }
impl EventEmitter<ThreadEvent> for Thread {} impl EventEmitter<ThreadEvent> for Thread {}
struct PendingCompletion {
id: usize,
_task: Task<()>,
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: Arc<str>,
pub name: String,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
Idle,
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] String),
}
impl PendingToolUseStatus {
pub fn is_idle(&self) -> bool {
matches!(self, PendingToolUseStatus::Idle)
}
}

View file

@ -30,7 +30,7 @@ impl Tool for NowTool {
} }
fn description(&self) -> String { fn description(&self) -> String {
"Returns the current datetime in RFC 3339 format.".into() "Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into()
} }
fn input_schema(&self) -> serde_json::Value { fn input_schema(&self) -> serde_json::Value {

View file

@ -49,6 +49,16 @@ impl FeatureFlag for Assistant2FeatureFlag {
} }
} }
pub struct ToolUseFeatureFlag;
impl FeatureFlag for ToolUseFeatureFlag {
const NAME: &'static str = "assistant-tool-use";
fn enabled_for_staff() -> bool {
false
}
}
pub struct Remoting {} pub struct Remoting {}
impl FeatureFlag for Remoting { impl FeatureFlag for Remoting {
const NAME: &'static str = "remoting"; const NAME: &'static str = "remoting";