assistant2: Add helper methods to Thread
for dealing with tool use (#26310)
This PR adds two new helper methods to the `Thread` for dealing with tool use: - `use_pending_tools` - This uses all of the tools that are pending - The reason we aren't calling this directly in `stream_completion` is that we still might need to have a way for users to confirm that they want tools to be run, which would need to happen at the UI layer in the `ActiveThread`. - `send_tool_results_to_model` - This encapsulates inserting a new user message that contains the tool results and sending them up to the model. Release Notes: - N/A
This commit is contained in:
parent
18f3f8097f
commit
921c24e274
4 changed files with 61 additions and 52 deletions
|
@ -1,17 +1,15 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use assistant_tool::ToolWorkingSet;
|
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use editor::{Editor, MultiBuffer};
|
use editor::{Editor, MultiBuffer};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
|
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
|
||||||
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
||||||
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
|
Task, TextStyleRefinement, UnderlineStyle,
|
||||||
};
|
};
|
||||||
use language::{Buffer, LanguageRegistry};
|
use language::{Buffer, LanguageRegistry};
|
||||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||||
use markdown::{Markdown, MarkdownStyle};
|
use markdown::{Markdown, MarkdownStyle};
|
||||||
use project::Project;
|
|
||||||
use settings::Settings as _;
|
use settings::Settings as _;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, Disclosure, KeyBinding};
|
use ui::{prelude::*, Disclosure, KeyBinding};
|
||||||
|
@ -23,9 +21,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
|
||||||
use crate::ui::ContextPill;
|
use crate::ui::ContextPill;
|
||||||
|
|
||||||
pub struct ActiveThread {
|
pub struct ActiveThread {
|
||||||
project: WeakEntity<Project>,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
save_thread_task: Option<Task<()>>,
|
save_thread_task: Option<Task<()>>,
|
||||||
|
@ -46,9 +42,7 @@ impl ActiveThread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
project: WeakEntity<Project>,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -58,9 +52,7 @@ impl ActiveThread {
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
project,
|
|
||||||
language_registry,
|
language_registry,
|
||||||
tools,
|
|
||||||
thread_store,
|
thread_store,
|
||||||
thread: thread.clone(),
|
thread: thread.clone(),
|
||||||
save_thread_task: None,
|
save_thread_task: None,
|
||||||
|
@ -300,25 +292,10 @@ impl ActiveThread {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
ThreadEvent::UsePendingTools => {
|
ThreadEvent::UsePendingTools => {
|
||||||
let pending_tool_uses = self
|
|
||||||
.thread
|
|
||||||
.read(cx)
|
|
||||||
.pending_tool_uses()
|
|
||||||
.into_iter()
|
|
||||||
.filter(|tool_use| tool_use.status.is_idle())
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
for tool_use in pending_tool_uses {
|
|
||||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
|
||||||
let task = tool.run(tool_use.input, self.project.clone(), cx);
|
|
||||||
|
|
||||||
self.thread.update(cx, |thread, cx| {
|
self.thread.update(cx, |thread, cx| {
|
||||||
thread.insert_tool_output(tool_use.id.clone(), task, cx);
|
thread.use_pending_tools(cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
ThreadEvent::ToolFinished { .. } => {
|
ThreadEvent::ToolFinished { .. } => {
|
||||||
let all_tools_finished = self
|
let all_tools_finished = self
|
||||||
.thread
|
.thread
|
||||||
|
@ -330,16 +307,7 @@ impl ActiveThread {
|
||||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
if let Some(model) = model_registry.active_model() {
|
if let Some(model) = model_registry.active_model() {
|
||||||
self.thread.update(cx, |thread, cx| {
|
self.thread.update(cx, |thread, cx| {
|
||||||
// Insert a user message to contain the tool results.
|
thread.send_tool_results_to_model(model, cx);
|
||||||
thread.insert_user_message(
|
|
||||||
// TODO: Sending up a user message without any content results in the model sending back
|
|
||||||
// responses that also don't have any content. We currently don't handle this case well,
|
|
||||||
// so for now we provide some text to keep the model on track.
|
|
||||||
"Here are the tool results.",
|
|
||||||
Vec::new(),
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
thread.send_to_model(model, RequestKind::Chat, true, cx);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,7 +92,6 @@ pub struct AssistantPanel {
|
||||||
context_editor: Option<Entity<ContextEditor>>,
|
context_editor: Option<Entity<ContextEditor>>,
|
||||||
configuration: Option<Entity<AssistantConfiguration>>,
|
configuration: Option<Entity<AssistantConfiguration>>,
|
||||||
configuration_subscription: Option<Subscription>,
|
configuration_subscription: Option<Subscription>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
|
||||||
local_timezone: UtcOffset,
|
local_timezone: UtcOffset,
|
||||||
active_view: ActiveView,
|
active_view: ActiveView,
|
||||||
history_store: Entity<HistoryStore>,
|
history_store: Entity<HistoryStore>,
|
||||||
|
@ -133,7 +132,7 @@ impl AssistantPanel {
|
||||||
log::info!("[assistant2-debug] finished initializing ContextStore");
|
log::info!("[assistant2-debug] finished initializing ContextStore");
|
||||||
|
|
||||||
workspace.update_in(&mut cx, |workspace, window, cx| {
|
workspace.update_in(&mut cx, |workspace, window, cx| {
|
||||||
cx.new(|cx| Self::new(workspace, thread_store, context_store, tools, window, cx))
|
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -142,7 +141,6 @@ impl AssistantPanel {
|
||||||
workspace: &Workspace,
|
workspace: &Workspace,
|
||||||
thread_store: Entity<ThreadStore>,
|
thread_store: Entity<ThreadStore>,
|
||||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -179,9 +177,7 @@ impl AssistantPanel {
|
||||||
ActiveThread::new(
|
ActiveThread::new(
|
||||||
thread.clone(),
|
thread.clone(),
|
||||||
thread_store.clone(),
|
thread_store.clone(),
|
||||||
project.downgrade(),
|
|
||||||
language_registry,
|
language_registry,
|
||||||
tools.clone(),
|
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -191,7 +187,6 @@ impl AssistantPanel {
|
||||||
context_editor: None,
|
context_editor: None,
|
||||||
configuration: None,
|
configuration: None,
|
||||||
configuration_subscription: None,
|
configuration_subscription: None,
|
||||||
tools,
|
|
||||||
local_timezone: UtcOffset::from_whole_seconds(
|
local_timezone: UtcOffset::from_whole_seconds(
|
||||||
chrono::Local::now().offset().local_minus_utc(),
|
chrono::Local::now().offset().local_minus_utc(),
|
||||||
)
|
)
|
||||||
|
@ -246,9 +241,7 @@ impl AssistantPanel {
|
||||||
ActiveThread::new(
|
ActiveThread::new(
|
||||||
thread.clone(),
|
thread.clone(),
|
||||||
self.thread_store.clone(),
|
self.thread_store.clone(),
|
||||||
self.project.downgrade(),
|
|
||||||
self.language_registry.clone(),
|
self.language_registry.clone(),
|
||||||
self.tools.clone(),
|
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -381,9 +374,7 @@ impl AssistantPanel {
|
||||||
ActiveThread::new(
|
ActiveThread::new(
|
||||||
thread.clone(),
|
thread.clone(),
|
||||||
this.thread_store.clone(),
|
this.thread_store.clone(),
|
||||||
this.project.downgrade(),
|
|
||||||
this.language_registry.clone(),
|
this.language_registry.clone(),
|
||||||
this.tools.clone(),
|
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,13 +5,14 @@ use assistant_tool::ToolWorkingSet;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use futures::StreamExt as _;
|
use futures::StreamExt as _;
|
||||||
use gpui::{App, Context, EventEmitter, SharedString, Task};
|
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||||
Role, StopReason,
|
Role, StopReason,
|
||||||
};
|
};
|
||||||
|
use project::Project;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use util::{post_inc, TryFutureExt as _};
|
use util::{post_inc, TryFutureExt as _};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
@ -71,12 +72,17 @@ pub struct Thread {
|
||||||
context_by_message: HashMap<MessageId, Vec<ContextId>>,
|
context_by_message: HashMap<MessageId, Vec<ContextId>>,
|
||||||
completion_count: usize,
|
completion_count: usize,
|
||||||
pending_completions: Vec<PendingCompletion>,
|
pending_completions: Vec<PendingCompletion>,
|
||||||
|
project: WeakEntity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
tool_use: ToolUseState,
|
tool_use: ToolUseState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
|
pub fn new(
|
||||||
|
project: Entity<Project>,
|
||||||
|
tools: Arc<ToolWorkingSet>,
|
||||||
|
_cx: &mut Context<Self>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: ThreadId::new(),
|
id: ThreadId::new(),
|
||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
|
@ -88,6 +94,7 @@ impl Thread {
|
||||||
context_by_message: HashMap::default(),
|
context_by_message: HashMap::default(),
|
||||||
completion_count: 0,
|
completion_count: 0,
|
||||||
pending_completions: Vec::new(),
|
pending_completions: Vec::new(),
|
||||||
|
project: project.downgrade(),
|
||||||
tools,
|
tools,
|
||||||
tool_use: ToolUseState::new(),
|
tool_use: ToolUseState::new(),
|
||||||
}
|
}
|
||||||
|
@ -96,6 +103,7 @@ impl Thread {
|
||||||
pub fn from_saved(
|
pub fn from_saved(
|
||||||
id: ThreadId,
|
id: ThreadId,
|
||||||
saved: SavedThread,
|
saved: SavedThread,
|
||||||
|
project: Entity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
_cx: &mut Context<Self>,
|
_cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -127,6 +135,7 @@ impl Thread {
|
||||||
context_by_message: HashMap::default(),
|
context_by_message: HashMap::default(),
|
||||||
completion_count: 0,
|
completion_count: 0,
|
||||||
pending_completions: Vec::new(),
|
pending_completions: Vec::new(),
|
||||||
|
project: project.downgrade(),
|
||||||
tools,
|
tools,
|
||||||
tool_use,
|
tool_use,
|
||||||
}
|
}
|
||||||
|
@ -550,6 +559,23 @@ impl Thread {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
|
||||||
|
let pending_tool_uses = self
|
||||||
|
.pending_tool_uses()
|
||||||
|
.into_iter()
|
||||||
|
.filter(|tool_use| tool_use.status.is_idle())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
for tool_use in pending_tool_uses {
|
||||||
|
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||||
|
let task = tool.run(tool_use.input, self.project.clone(), cx);
|
||||||
|
|
||||||
|
self.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn insert_tool_output(
|
pub fn insert_tool_output(
|
||||||
&mut self,
|
&mut self,
|
||||||
tool_use_id: LanguageModelToolUseId,
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
@ -576,6 +602,23 @@ impl Thread {
|
||||||
.run_pending_tool(tool_use_id, insert_output_task);
|
.run_pending_tool(tool_use_id, insert_output_task);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn send_tool_results_to_model(
|
||||||
|
&mut self,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
// Insert a user message to contain the tool results.
|
||||||
|
self.insert_user_message(
|
||||||
|
// TODO: Sending up a user message without any content results in the model sending back
|
||||||
|
// responses that also don't have any content. We currently don't handle this case well,
|
||||||
|
// so for now we provide some text to keep the model on track.
|
||||||
|
"Here are the tool results.",
|
||||||
|
Vec::new(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
self.send_to_model(model, RequestKind::Chat, true, cx);
|
||||||
|
}
|
||||||
|
|
||||||
/// Cancels the last pending completion, if there are any pending.
|
/// Cancels the last pending completion, if there are any pending.
|
||||||
///
|
///
|
||||||
/// Returns whether a completion was canceled.
|
/// Returns whether a completion was canceled.
|
||||||
|
|
|
@ -26,7 +26,6 @@ pub fn init(cx: &mut App) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ThreadStore {
|
pub struct ThreadStore {
|
||||||
#[allow(unused)]
|
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
context_server_manager: Entity<ContextServerManager>,
|
context_server_manager: Entity<ContextServerManager>,
|
||||||
|
@ -78,7 +77,7 @@ impl ThreadStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
|
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
|
||||||
cx.new(|cx| Thread::new(self.tools.clone(), cx))
|
cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn open_thread(
|
pub fn open_thread(
|
||||||
|
@ -96,7 +95,15 @@ impl ThreadStore {
|
||||||
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
|
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
|
cx.new(|cx| {
|
||||||
|
Thread::from_saved(
|
||||||
|
id.clone(),
|
||||||
|
thread,
|
||||||
|
this.project.clone(),
|
||||||
|
this.tools.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue