assistant2: Add support for using tools (#21190)
This PR adds rudimentary support for using tools to `assistant2`. There are currently no visual affordances for tool use. This is gated behind the `assistant-tool-use` feature flag. <img width="1079" alt="Screenshot 2024-11-25 at 7 21 31 PM" src="https://github.com/user-attachments/assets/64d6ca29-c592-4474-8e9d-c344f855bc63"> Release Notes: - N/A
This commit is contained in:
parent
3901d46101
commit
f059b6a24b
8 changed files with 263 additions and 37 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -455,6 +455,8 @@ name = "assistant2"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
|
@ -463,6 +465,7 @@ dependencies = [
|
|||
"language_model",
|
||||
"language_model_selector",
|
||||
"proto",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"theme",
|
||||
|
|
|
@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet;
|
|||
use client::{self, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use collections::{HashMap, HashSet};
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::{future::Shared, FutureExt, StreamExt};
|
||||
use gpui::{
|
||||
|
@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus {
|
|||
Error(String),
|
||||
}
|
||||
|
||||
pub(crate) struct ToolUseFeatureFlag;
|
||||
|
||||
impl FeatureFlag for ToolUseFeatureFlag {
|
||||
const NAME: &'static str = "assistant-tool-use";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingToolUse {
|
||||
pub id: Arc<str>,
|
||||
|
|
|
@ -14,6 +14,8 @@ doctest = false
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
|
@ -23,6 +25,7 @@ language_model.workspace = true
|
|||
language_model_selector.workspace = true
|
||||
proto.workspace = true
|
||||
settings.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use gpui::{
|
||||
prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
|
||||
FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
|
||||
|
@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
|
|||
use workspace::Workspace;
|
||||
|
||||
use crate::message_editor::MessageEditor;
|
||||
use crate::thread::Thread;
|
||||
use crate::thread::{Thread, ThreadEvent};
|
||||
use crate::{NewThread, ToggleFocus, ToggleModelSelector};
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
|
@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) {
|
|||
}
|
||||
|
||||
pub struct AssistantPanel {
|
||||
workspace: WeakView<Workspace>,
|
||||
thread: Model<Thread>,
|
||||
message_editor: View<MessageEditor>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
|
@ -36,26 +41,36 @@ impl AssistantPanel {
|
|||
cx: AsyncWindowContext,
|
||||
) -> Task<Result<View<Self>>> {
|
||||
cx.spawn(|mut cx| async move {
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
workspace.update(&mut cx, |workspace, cx| {
|
||||
cx.new_view(|cx| Self::new(workspace, cx))
|
||||
cx.new_view(|cx| Self::new(workspace, tools, cx))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn new(_workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self {
|
||||
let thread = cx.new_model(Thread::new);
|
||||
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
|
||||
fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
|
||||
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
|
||||
let subscriptions = vec![
|
||||
cx.observe(&thread, |_, _, cx| cx.notify()),
|
||||
cx.subscribe(&thread, Self::handle_thread_event),
|
||||
];
|
||||
|
||||
Self {
|
||||
workspace: workspace.weak_handle(),
|
||||
thread: thread.clone(),
|
||||
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
|
||||
tools,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let thread = cx.new_model(Thread::new);
|
||||
let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
|
||||
let tools = self.thread.read(cx).tools().clone();
|
||||
let thread = cx.new_model(|cx| Thread::new(tools, cx));
|
||||
let subscriptions = vec![
|
||||
cx.observe(&thread, |_, _, cx| cx.notify()),
|
||||
cx.subscribe(&thread, Self::handle_thread_event),
|
||||
];
|
||||
|
||||
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
|
||||
self.thread = thread;
|
||||
|
@ -63,6 +78,38 @@ impl AssistantPanel {
|
|||
|
||||
self.message_editor.focus_handle(cx).focus(cx);
|
||||
}
|
||||
|
||||
fn handle_thread_event(
|
||||
&mut self,
|
||||
_: Model<Thread>,
|
||||
event: &ThreadEvent,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
match event {
|
||||
ThreadEvent::StreamedCompletion => {}
|
||||
ThreadEvent::UsePendingTools => {
|
||||
let pending_tool_uses = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.pending_tool_uses()
|
||||
.into_iter()
|
||||
.filter(|tool_use| tool_use.status.is_idle())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for tool_use in pending_tool_uses {
|
||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
ThreadEvent::ToolFinished { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FocusableView for AssistantPanel {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
|
||||
use gpui::{AppContext, FocusableView, Model, TextStyle, View};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
|
||||
use settings::Settings;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
|
||||
|
@ -55,7 +56,21 @@ impl MessageEditor {
|
|||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(user_message);
|
||||
let request = thread.to_completion_request(request_kind, cx);
|
||||
let mut request = thread.to_completion_request(request_kind, cx);
|
||||
|
||||
if cx.has_flag::<ToolUseFeatureFlag>() {
|
||||
request.tools = thread
|
||||
.tools()
|
||||
.tools(cx)
|
||||
.into_iter()
|
||||
.map(|tool| LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema: tool.input_schema(),
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
thread.stream_completion(request, model, cx)
|
||||
});
|
||||
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt as _;
|
||||
use anyhow::Result;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use collections::HashMap;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use gpui::{AppContext, EventEmitter, ModelContext, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role, StopReason,
|
||||
LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use util::{post_inc, ResultExt as _};
|
||||
use util::post_inc;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum RequestKind {
|
||||
|
@ -14,14 +18,12 @@ pub enum RequestKind {
|
|||
}
|
||||
|
||||
/// A message in a [`Thread`].
|
||||
#[derive(Debug)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
struct PendingCompletion {
|
||||
id: usize,
|
||||
_task: Task<()>,
|
||||
pub tool_uses: Vec<LanguageModelToolUse>,
|
||||
pub tool_results: Vec<LanguageModelToolResult>,
|
||||
}
|
||||
|
||||
/// A thread of conversation with the LLM.
|
||||
|
@ -29,14 +31,20 @@ pub struct Thread {
|
|||
messages: Vec<Message>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
|
||||
completed_tool_uses_by_id: HashMap<Arc<str>, String>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(_cx: &mut ModelContext<Self>) -> Self {
|
||||
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
|
||||
Self {
|
||||
tools,
|
||||
messages: Vec::new(),
|
||||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
completed_tool_uses_by_id: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,11 +52,31 @@ impl Thread {
|
|||
self.messages.iter()
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
|
||||
&self.tools
|
||||
}
|
||||
|
||||
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
||||
self.pending_tool_uses_by_id.values().collect()
|
||||
}
|
||||
|
||||
pub fn insert_user_message(&mut self, text: impl Into<String>) {
|
||||
self.messages.push(Message {
|
||||
let mut message = Message {
|
||||
role: Role::User,
|
||||
text: text.into(),
|
||||
});
|
||||
tool_uses: Vec::new(),
|
||||
tool_results: Vec::new(),
|
||||
};
|
||||
|
||||
for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
|
||||
message.tool_results.push(LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.to_string(),
|
||||
content: tool_output,
|
||||
is_error: false,
|
||||
});
|
||||
}
|
||||
|
||||
self.messages.push(message);
|
||||
}
|
||||
|
||||
pub fn to_completion_request(
|
||||
|
@ -70,9 +98,23 @@ impl Thread {
|
|||
cache: false,
|
||||
};
|
||||
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message.text.clone()));
|
||||
for tool_result in &message.tool_results {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(tool_result.clone()));
|
||||
}
|
||||
|
||||
if !message.text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message.text.clone()));
|
||||
}
|
||||
|
||||
for tool_use in &message.tool_uses {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
}
|
||||
|
||||
request.messages.push(request_message);
|
||||
}
|
||||
|
@ -103,6 +145,8 @@ impl Thread {
|
|||
thread.messages.push(Message {
|
||||
role: Role::Assistant,
|
||||
text: String::new(),
|
||||
tool_uses: Vec::new(),
|
||||
tool_results: Vec::new(),
|
||||
});
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
|
@ -115,7 +159,24 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.tool_uses.push(tool_use.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let tool_use_id: Arc<str> = tool_use.id.into();
|
||||
thread.pending_tool_uses_by_id.insert(
|
||||
tool_use_id.clone(),
|
||||
PendingToolUse {
|
||||
id: tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: tool_use.input,
|
||||
status: PendingToolUseStatus::Idle,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
cx.emit(ThreadEvent::StreamedCompletion);
|
||||
|
@ -135,7 +196,35 @@ impl Thread {
|
|||
};
|
||||
|
||||
let result = stream_completion.await;
|
||||
let _ = result.log_err();
|
||||
|
||||
thread
|
||||
.update(&mut cx, |_thread, cx| {
|
||||
let error_message = if let Some(error) = result.as_ref().err() {
|
||||
let error_message = error
|
||||
.chain()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
Some(error_message)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(error_message) = error_message {
|
||||
eprintln!("Completion failed: {error_message:?}");
|
||||
}
|
||||
|
||||
if let Ok(stop_reason) = result {
|
||||
match stop_reason {
|
||||
StopReason::ToolUse => {
|
||||
cx.emit(ThreadEvent::UsePendingTools);
|
||||
}
|
||||
StopReason::EndTurn => {}
|
||||
StopReason::MaxTokens => {}
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
self.pending_completions.push(PendingCompletion {
|
||||
|
@ -143,11 +232,80 @@ impl Thread {
|
|||
_task: task,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn insert_tool_output(
|
||||
&mut self,
|
||||
tool_use_id: Arc<str>,
|
||||
output: Task<Result<String>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let insert_output_task = cx.spawn(|thread, mut cx| {
|
||||
let tool_use_id = tool_use_id.clone();
|
||||
async move {
|
||||
let output = output.await;
|
||||
thread
|
||||
.update(&mut cx, |thread, cx| match output {
|
||||
Ok(output) => {
|
||||
thread
|
||||
.completed_tool_uses_by_id
|
||||
.insert(tool_use_id.clone(), output);
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(tool_use) =
|
||||
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
|
||||
{
|
||||
tool_use.status = PendingToolUseStatus::Error(err.to_string());
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
|
||||
tool_use.status = PendingToolUseStatus::Running {
|
||||
_task: insert_output_task.shared(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ThreadEvent {
|
||||
StreamedCompletion,
|
||||
UsePendingTools,
|
||||
ToolFinished {
|
||||
#[allow(unused)]
|
||||
tool_use_id: Arc<str>,
|
||||
},
|
||||
}
|
||||
|
||||
impl EventEmitter<ThreadEvent> for Thread {}
|
||||
|
||||
struct PendingCompletion {
|
||||
id: usize,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PendingToolUse {
|
||||
pub id: Arc<str>,
|
||||
pub name: String,
|
||||
pub input: serde_json::Value,
|
||||
pub status: PendingToolUseStatus,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PendingToolUseStatus {
|
||||
Idle,
|
||||
Running { _task: Shared<Task<()>> },
|
||||
Error(#[allow(unused)] String),
|
||||
}
|
||||
|
||||
impl PendingToolUseStatus {
|
||||
pub fn is_idle(&self) -> bool {
|
||||
matches!(self, PendingToolUseStatus::Idle)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ impl Tool for NowTool {
|
|||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Returns the current datetime in RFC 3339 format.".into()
|
||||
"Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into()
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
|
|
|
@ -49,6 +49,16 @@ impl FeatureFlag for Assistant2FeatureFlag {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ToolUseFeatureFlag;
|
||||
|
||||
impl FeatureFlag for ToolUseFeatureFlag {
|
||||
const NAME: &'static str = "assistant-tool-use";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Remoting {}
|
||||
impl FeatureFlag for Remoting {
|
||||
const NAME: &'static str = "remoting";
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue