assistant2: Add helper methods to Thread for dealing with tool use (#26310)

This PR adds two new helper methods to the `Thread` for dealing with
tool use:

- `use_pending_tools` - This uses all of the tools that are pending
- The reason we aren't calling this directly in `stream_completion` is
that we still might need to have a way for users to confirm that they
want tools to be run, which would need to happen at the UI layer in the
`ActiveThread`.
- `send_tool_results_to_model` - This encapsulates inserting a new user
message that contains the tool results and sending them up to the model.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-07 18:16:45 -05:00 committed by GitHub
parent 18f3f8097f
commit 921c24e274
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 61 additions and 52 deletions

View file

@ -1,17 +1,15 @@
use std::sync::Arc;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
use gpui::{
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
Task, TextStyleRefinement, UnderlineStyle,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use project::Project;
use settings::Settings as _;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
@ -23,9 +21,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
pub struct ActiveThread {
project: WeakEntity<Project>,
language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
save_thread_task: Option<Task<()>>,
@ -46,9 +42,7 @@ impl ActiveThread {
pub fn new(
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
project: WeakEntity<Project>,
language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -58,9 +52,7 @@ impl ActiveThread {
];
let mut this = Self {
project,
language_registry,
tools,
thread_store,
thread: thread.clone(),
save_thread_task: None,
@ -300,24 +292,9 @@ impl ActiveThread {
cx.notify();
}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.project.clone(), cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
});
}
}
self.thread.update(cx, |thread, cx| {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolFinished { .. } => {
let all_tools_finished = self
@ -330,16 +307,7 @@ impl ActiveThread {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
self.thread.update(cx, |thread, cx| {
// Insert a user message to contain the tool results.
thread.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
// responses that also don't have any content. We currently don't handle this case well,
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
cx,
);
thread.send_to_model(model, RequestKind::Chat, true, cx);
thread.send_tool_results_to_model(model, cx);
});
}
}

View file

@ -92,7 +92,6 @@ pub struct AssistantPanel {
context_editor: Option<Entity<ContextEditor>>,
configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>,
tools: Arc<ToolWorkingSet>,
local_timezone: UtcOffset,
active_view: ActiveView,
history_store: Entity<HistoryStore>,
@ -133,7 +132,7 @@ impl AssistantPanel {
log::info!("[assistant2-debug] finished initializing ContextStore");
workspace.update_in(&mut cx, |workspace, window, cx| {
cx.new(|cx| Self::new(workspace, thread_store, context_store, tools, window, cx))
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
})
})
}
@ -142,7 +141,6 @@ impl AssistantPanel {
workspace: &Workspace,
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
tools: Arc<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -179,9 +177,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
project.downgrade(),
language_registry,
tools.clone(),
window,
cx,
)
@ -191,7 +187,6 @@ impl AssistantPanel {
context_editor: None,
configuration: None,
configuration_subscription: None,
tools,
local_timezone: UtcOffset::from_whole_seconds(
chrono::Local::now().offset().local_minus_utc(),
)
@ -246,9 +241,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
self.project.downgrade(),
self.language_registry.clone(),
self.tools.clone(),
window,
cx,
)
@ -381,9 +374,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
this.thread_store.clone(),
this.project.downgrade(),
this.language_registry.clone(),
this.tools.clone(),
window,
cx,
)

View file

@ -5,13 +5,14 @@ use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _;
use gpui::{App, Context, EventEmitter, SharedString, Task};
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason,
};
use project::Project;
use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
use uuid::Uuid;
@ -71,12 +72,17 @@ pub struct Thread {
context_by_message: HashMap<MessageId, Vec<ContextId>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: WeakEntity<Project>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
}
impl Thread {
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
) -> Self {
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@ -88,6 +94,7 @@ impl Thread {
context_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.downgrade(),
tools,
tool_use: ToolUseState::new(),
}
@ -96,6 +103,7 @@ impl Thread {
pub fn from_saved(
id: ThreadId,
saved: SavedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
) -> Self {
@ -127,6 +135,7 @@ impl Thread {
context_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
project: project.downgrade(),
tools,
tool_use,
}
@ -550,6 +559,23 @@ impl Thread {
});
}
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
let pending_tool_uses = self
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.project.clone(), cx);
self.insert_tool_output(tool_use.id.clone(), task, cx);
}
}
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -576,6 +602,23 @@ impl Thread {
.run_pending_tool(tool_use_id, insert_output_task);
}
pub fn send_tool_results_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) {
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
// responses that also don't have any content. We currently don't handle this case well,
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
cx,
);
self.send_to_model(model, RequestKind::Chat, true, cx);
}
/// Cancels the last pending completion, if there are any pending.
///
/// Returns whether a completion was canceled.

View file

@ -26,7 +26,6 @@ pub fn init(cx: &mut App) {
}
pub struct ThreadStore {
#[allow(unused)]
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
context_server_manager: Entity<ContextServerManager>,
@ -78,7 +77,7 @@ impl ThreadStore {
}
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| Thread::new(self.tools.clone(), cx))
cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
}
pub fn open_thread(
@ -96,7 +95,15 @@ impl ThreadStore {
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
this.update(&mut cx, |this, cx| {
cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
cx.new(|cx| {
Thread::from_saved(
id.clone(),
thread,
this.project.clone(),
this.tools.clone(),
cx,
)
})
})
})
}