assistant2: Add support for using tools (#21190)

This PR adds rudimentary support for using tools to `assistant2`. There
are currently no visual affordances for tool use.

This is gated behind the `assistant-tool-use` feature flag.

<img width="1079" alt="Screenshot 2024-11-25 at 7 21 31 PM"
src="https://github.com/user-attachments/assets/64d6ca29-c592-4474-8e9d-c344f855bc63">

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-25 19:44:34 -05:00 committed by GitHub
parent 3901d46101
commit f059b6a24b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 263 additions and 37 deletions

3
Cargo.lock generated
View file

@ -455,6 +455,8 @@ name = "assistant2"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"collections",
"command_palette_hooks",
"editor",
"feature_flags",
@ -463,6 +465,7 @@ dependencies = [
"language_model",
"language_model_selector",
"proto",
"serde_json",
"settings",
"smol",
"theme",

View file

@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use fs::{Fs, RemoveOptions};
use futures::{future::Shared, FutureExt, StreamExt};
use gpui::{
@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus {
Error(String),
}
pub(crate) struct ToolUseFeatureFlag;
impl FeatureFlag for ToolUseFeatureFlag {
const NAME: &'static str = "assistant-tool-use";
fn enabled_for_staff() -> bool {
false
}
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: Arc<str>,

View file

@ -14,6 +14,8 @@ doctest = false
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
editor.workspace = true
feature_flags.workspace = true
@ -23,6 +25,7 @@ language_model.workspace = true
language_model_selector.workspace = true
proto.workspace = true
settings.workspace = true
serde_json.workspace = true
smol.workspace = true
theme.workspace = true
ui.workspace = true

View file

@ -1,4 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use gpui::{
prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace;
use crate::message_editor::MessageEditor;
use crate::thread::Thread;
use crate::thread::{Thread, ThreadEvent};
use crate::{NewThread, ToggleFocus, ToggleModelSelector};
pub fn init(cx: &mut AppContext) {
@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) {
}
pub struct AssistantPanel {
workspace: WeakView<Workspace>,
thread: Model<Thread>,
message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>,
_subscriptions: Vec<Subscription>,
}
@ -36,26 +41,36 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default());
workspace.update(&mut cx, |workspace, cx| {
cx.new_view(|cx| Self::new(workspace, cx))
cx.new_view(|cx| Self::new(workspace, tools, cx))
})
})
}
fn new(_workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self {
let thread = cx.new_model(Thread::new);
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
Self {
workspace: workspace.weak_handle(),
thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools,
_subscriptions: subscriptions,
}
}
fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
let thread = cx.new_model(Thread::new);
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
let tools = self.thread.read(cx).tools().clone();
let thread = cx.new_model(|cx| Thread::new(tools, cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread;
@ -63,6 +78,38 @@ impl AssistantPanel {
self.message_editor.focus_handle(cx).focus(cx);
}
fn handle_thread_event(
&mut self,
_: Model<Thread>,
event: &ThreadEvent,
cx: &mut ViewContext<Self>,
) {
match event {
ThreadEvent::StreamedCompletion => {}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}
}
ThreadEvent::ToolFinished { .. } => {}
}
}
}
impl FocusableView for AssistantPanel {

View file

@ -1,6 +1,7 @@
use editor::{Editor, EditorElement, EditorStyle};
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use gpui::{AppContext, FocusableView, Model, TextStyle, View};
use language_model::LanguageModelRegistry;
use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
use settings::Settings;
use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
@ -55,7 +56,21 @@ impl MessageEditor {
self.thread.update(cx, |thread, cx| {
thread.insert_user_message(user_message);
let request = thread.to_completion_request(request_kind, cx);
let mut request = thread.to_completion_request(request_kind, cx);
if cx.has_flag::<ToolUseFeatureFlag>() {
request.tools = thread
.tools()
.tools(cx)
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(),
})
.collect();
}
thread.stream_completion(request, model, cx)
});

View file

@ -1,12 +1,16 @@
use std::sync::Arc;
use futures::StreamExt as _;
use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason,
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
};
use util::{post_inc, ResultExt as _};
use util::post_inc;
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@ -14,14 +18,12 @@ pub enum RequestKind {
}
/// A message in a [`Thread`].
#[derive(Debug)]
pub struct Message {
pub role: Role,
pub text: String,
}
struct PendingCompletion {
id: usize,
_task: Task<()>,
pub tool_uses: Vec<LanguageModelToolUse>,
pub tool_results: Vec<LanguageModelToolResult>,
}
/// A thread of conversation with the LLM.
@ -29,14 +31,20 @@ pub struct Thread {
messages: Vec<Message>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
completed_tool_uses_by_id: HashMap<Arc<str>, String>,
}
impl Thread {
pub fn new(_cx: &mut ModelContext<Self>) -> Self {
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
Self {
tools,
messages: Vec::new(),
completion_count: 0,
pending_completions: Vec::new(),
pending_tool_uses_by_id: HashMap::default(),
completed_tool_uses_by_id: HashMap::default(),
}
}
@ -44,11 +52,31 @@ impl Thread {
self.messages.iter()
}
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
&self.tools
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn insert_user_message(&mut self, text: impl Into<String>) {
self.messages.push(Message {
let mut message = Message {
role: Role::User,
text: text.into(),
});
tool_uses: Vec::new(),
tool_results: Vec::new(),
};
for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
message.tool_results.push(LanguageModelToolResult {
tool_use_id: tool_use_id.to_string(),
content: tool_output,
is_error: false,
});
}
self.messages.push(message);
}
pub fn to_completion_request(
@ -70,9 +98,23 @@ impl Thread {
cache: false,
};
request_message
.content
.push(MessageContent::Text(message.text.clone()));
for tool_result in &message.tool_results {
request_message
.content
.push(MessageContent::ToolResult(tool_result.clone()));
}
if !message.text.is_empty() {
request_message
.content
.push(MessageContent::Text(message.text.clone()));
}
for tool_use in &message.tool_uses {
request_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
}
request.messages.push(request_message);
}
@ -103,6 +145,8 @@ impl Thread {
thread.messages.push(Message {
role: Role::Assistant,
text: String::new(),
tool_uses: Vec::new(),
tool_results: Vec::new(),
});
}
LanguageModelCompletionEvent::Stop(reason) => {
@ -115,7 +159,24 @@ impl Thread {
}
}
}
LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.tool_uses.push(tool_use.clone());
}
}
let tool_use_id: Arc<str> = tool_use.id.into();
thread.pending_tool_uses_by_id.insert(
tool_use_id.clone(),
PendingToolUse {
id: tool_use_id,
name: tool_use.name,
input: tool_use.input,
status: PendingToolUseStatus::Idle,
},
);
}
}
cx.emit(ThreadEvent::StreamedCompletion);
@ -135,7 +196,35 @@ impl Thread {
};
let result = stream_completion.await;
let _ = result.log_err();
thread
.update(&mut cx, |_thread, cx| {
let error_message = if let Some(error) = result.as_ref().err() {
let error_message = error
.chain()
.map(|err| err.to_string())
.collect::<Vec<_>>()
.join("\n");
Some(error_message)
} else {
None
};
if let Some(error_message) = error_message {
eprintln!("Completion failed: {error_message:?}");
}
if let Ok(stop_reason) = result {
match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
}
}
})
.ok();
});
self.pending_completions.push(PendingCompletion {
@ -143,11 +232,80 @@ impl Thread {
_task: task,
});
}
pub fn insert_tool_output(
&mut self,
tool_use_id: Arc<str>,
output: Task<Result<String>>,
cx: &mut ModelContext<Self>,
) {
let insert_output_task = cx.spawn(|thread, mut cx| {
let tool_use_id = tool_use_id.clone();
async move {
let output = output.await;
thread
.update(&mut cx, |thread, cx| match output {
Ok(output) => {
thread
.completed_tool_uses_by_id
.insert(tool_use_id.clone(), output);
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
}
Err(err) => {
if let Some(tool_use) =
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
{
tool_use.status = PendingToolUseStatus::Error(err.to_string());
}
}
})
.ok();
}
});
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Running {
_task: insert_output_task.shared(),
};
}
}
}
#[derive(Debug, Clone)]
pub enum ThreadEvent {
StreamedCompletion,
UsePendingTools,
ToolFinished {
#[allow(unused)]
tool_use_id: Arc<str>,
},
}
impl EventEmitter<ThreadEvent> for Thread {}
struct PendingCompletion {
id: usize,
_task: Task<()>,
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: Arc<str>,
pub name: String,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
Idle,
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] String),
}
impl PendingToolUseStatus {
pub fn is_idle(&self) -> bool {
matches!(self, PendingToolUseStatus::Idle)
}
}

View file

@ -30,7 +30,7 @@ impl Tool for NowTool {
}
fn description(&self) -> String {
"Returns the current datetime in RFC 3339 format.".into()
"Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into()
}
fn input_schema(&self) -> serde_json::Value {

View file

@ -49,6 +49,16 @@ impl FeatureFlag for Assistant2FeatureFlag {
}
}
pub struct ToolUseFeatureFlag;
impl FeatureFlag for ToolUseFeatureFlag {
const NAME: &'static str = "assistant-tool-use";
fn enabled_for_staff() -> bool {
false
}
}
pub struct Remoting {}
impl FeatureFlag for Remoting {
const NAME: &'static str = "remoting";