assistant2: Add helper methods to Thread for dealing with tool use (#26310)

This PR adds two new helper methods to the `Thread` for dealing with
tool use:

- `use_pending_tools` - This uses all of the tools that are pending
- The reason we aren't calling this directly in `stream_completion` is
that we still might need to have a way for users to confirm that they
want tools to be run, which would need to happen at the UI layer in the
`ActiveThread`.
- `send_tool_results_to_model` - This encapsulates inserting a new user
message that contains the tool results and sending them up to the model.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-07 18:16:45 -05:00 committed by GitHub
parent 18f3f8097f
commit 921c24e274
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 61 additions and 52 deletions

View file

@ -1,17 +1,15 @@
use std::sync::Arc; use std::sync::Arc;
use assistant_tool::ToolWorkingSet;
use collections::HashMap; use collections::HashMap;
use editor::{Editor, MultiBuffer}; use editor::{Editor, MultiBuffer};
use gpui::{ use gpui::{
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty, list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
Task, TextStyleRefinement, UnderlineStyle, WeakEntity, Task, TextStyleRefinement, UnderlineStyle,
}; };
use language::{Buffer, LanguageRegistry}; use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle}; use markdown::{Markdown, MarkdownStyle};
use project::Project;
use settings::Settings as _; use settings::Settings as _;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding}; use ui::{prelude::*, Disclosure, KeyBinding};
@ -23,9 +21,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill; use crate::ui::ContextPill;
pub struct ActiveThread { pub struct ActiveThread {
project: WeakEntity<Project>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
thread: Entity<Thread>, thread: Entity<Thread>,
save_thread_task: Option<Task<()>>, save_thread_task: Option<Task<()>>,
@ -46,9 +42,7 @@ impl ActiveThread {
pub fn new( pub fn new(
thread: Entity<Thread>, thread: Entity<Thread>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
project: WeakEntity<Project>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -58,9 +52,7 @@ impl ActiveThread {
]; ];
let mut this = Self { let mut this = Self {
project,
language_registry, language_registry,
tools,
thread_store, thread_store,
thread: thread.clone(), thread: thread.clone(),
save_thread_task: None, save_thread_task: None,
@ -300,24 +292,9 @@ impl ActiveThread {
cx.notify(); cx.notify();
} }
ThreadEvent::UsePendingTools => { ThreadEvent::UsePendingTools => {
let pending_tool_uses = self self.thread.update(cx, |thread, cx| {
.thread thread.use_pending_tools(cx);
.read(cx) });
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.project.clone(), cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}
} }
ThreadEvent::ToolFinished { .. } => { ThreadEvent::ToolFinished { .. } => {
let all_tools_finished = self let all_tools_finished = self
@ -330,16 +307,7 @@ impl ActiveThread {
let model_registry = LanguageModelRegistry::read_global(cx); let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() { if let Some(model) = model_registry.active_model() {
self.thread.update(cx, |thread, cx| { self.thread.update(cx, |thread, cx| {
// Insert a user message to contain the tool results. thread.send_tool_results_to_model(model, cx);
thread.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
// responses that also don't have any content. We currently don't handle this case well,
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
cx,
);
thread.send_to_model(model, RequestKind::Chat, true, cx);
}); });
} }
} }

View file

@ -92,7 +92,6 @@ pub struct AssistantPanel {
context_editor: Option<Entity<ContextEditor>>, context_editor: Option<Entity<ContextEditor>>,
configuration: Option<Entity<AssistantConfiguration>>, configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>, configuration_subscription: Option<Subscription>,
tools: Arc<ToolWorkingSet>,
local_timezone: UtcOffset, local_timezone: UtcOffset,
active_view: ActiveView, active_view: ActiveView,
history_store: Entity<HistoryStore>, history_store: Entity<HistoryStore>,
@ -133,7 +132,7 @@ impl AssistantPanel {
log::info!("[assistant2-debug] finished initializing ContextStore"); log::info!("[assistant2-debug] finished initializing ContextStore");
workspace.update_in(&mut cx, |workspace, window, cx| { workspace.update_in(&mut cx, |workspace, window, cx| {
cx.new(|cx| Self::new(workspace, thread_store, context_store, tools, window, cx)) cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
}) })
}) })
} }
@ -142,7 +141,6 @@ impl AssistantPanel {
workspace: &Workspace, workspace: &Workspace,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>, context_store: Entity<assistant_context_editor::ContextStore>,
tools: Arc<ToolWorkingSet>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -179,9 +177,7 @@ impl AssistantPanel {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
thread_store.clone(), thread_store.clone(),
project.downgrade(),
language_registry, language_registry,
tools.clone(),
window, window,
cx, cx,
) )
@ -191,7 +187,6 @@ impl AssistantPanel {
context_editor: None, context_editor: None,
configuration: None, configuration: None,
configuration_subscription: None, configuration_subscription: None,
tools,
local_timezone: UtcOffset::from_whole_seconds( local_timezone: UtcOffset::from_whole_seconds(
chrono::Local::now().offset().local_minus_utc(), chrono::Local::now().offset().local_minus_utc(),
) )
@ -246,9 +241,7 @@ impl AssistantPanel {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
self.thread_store.clone(), self.thread_store.clone(),
self.project.downgrade(),
self.language_registry.clone(), self.language_registry.clone(),
self.tools.clone(),
window, window,
cx, cx,
) )
@ -381,9 +374,7 @@ impl AssistantPanel {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
this.thread_store.clone(), this.thread_store.clone(),
this.project.downgrade(),
this.language_registry.clone(), this.language_registry.clone(),
this.tools.clone(),
window, window,
cx, cx,
) )

View file

@ -5,13 +5,14 @@ use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _; use futures::StreamExt as _;
use gpui::{App, Context, EventEmitter, SharedString, Task}; use gpui::{App, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, Role, StopReason,
}; };
use project::Project;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _}; use util::{post_inc, TryFutureExt as _};
use uuid::Uuid; use uuid::Uuid;
@ -71,12 +72,17 @@ pub struct Thread {
context_by_message: HashMap<MessageId, Vec<ContextId>>, context_by_message: HashMap<MessageId, Vec<ContextId>>,
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
project: WeakEntity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState, tool_use: ToolUseState,
} }
impl Thread { impl Thread {
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self { pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
) -> Self {
Self { Self {
id: ThreadId::new(), id: ThreadId::new(),
updated_at: Utc::now(), updated_at: Utc::now(),
@ -88,6 +94,7 @@ impl Thread {
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
project: project.downgrade(),
tools, tools,
tool_use: ToolUseState::new(), tool_use: ToolUseState::new(),
} }
@ -96,6 +103,7 @@ impl Thread {
pub fn from_saved( pub fn from_saved(
id: ThreadId, id: ThreadId,
saved: SavedThread, saved: SavedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>, _cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -127,6 +135,7 @@ impl Thread {
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
project: project.downgrade(),
tools, tools,
tool_use, tool_use,
} }
@ -550,6 +559,23 @@ impl Thread {
}); });
} }
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
let pending_tool_uses = self
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.project.clone(), cx);
self.insert_tool_output(tool_use.id.clone(), task, cx);
}
}
}
pub fn insert_tool_output( pub fn insert_tool_output(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
@ -576,6 +602,23 @@ impl Thread {
.run_pending_tool(tool_use_id, insert_output_task); .run_pending_tool(tool_use_id, insert_output_task);
} }
pub fn send_tool_results_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) {
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
// responses that also don't have any content. We currently don't handle this case well,
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
cx,
);
self.send_to_model(model, RequestKind::Chat, true, cx);
}
/// Cancels the last pending completion, if there are any pending. /// Cancels the last pending completion, if there are any pending.
/// ///
/// Returns whether a completion was canceled. /// Returns whether a completion was canceled.

View file

@ -26,7 +26,6 @@ pub fn init(cx: &mut App) {
} }
pub struct ThreadStore { pub struct ThreadStore {
#[allow(unused)]
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
@ -78,7 +77,7 @@ impl ThreadStore {
} }
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> { pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| Thread::new(self.tools.clone(), cx)) cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
} }
pub fn open_thread( pub fn open_thread(
@ -96,7 +95,15 @@ impl ThreadStore {
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?; .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx)) cx.new(|cx| {
Thread::from_saved(
id.clone(),
thread,
this.project.clone(),
this.tools.clone(),
cx,
)
})
}) })
}) })
} }