assistant: Add basic tool invocation (#17368)
This PR adds the initial groundwork for invoking tools in response to tool uses from the model. Tool uses are run when the model responds with a `stop_reason` of `tool_use`. Currently the tool results are just inserted as text into the user message. We'll want to include these as `tool_result` content on the message, but Claude seems to understand it regardless. Release Notes: - N/A
This commit is contained in:
parent
7fb94c4c4d
commit
01525f17fa
3 changed files with 92 additions and 10 deletions
|
@ -20,6 +20,7 @@ use crate::{
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||||
|
use assistant_tool::ToolRegistry;
|
||||||
use client::{proto, Client, Status};
|
use client::{proto, Client, Status};
|
||||||
use collections::{BTreeSet, HashMap, HashSet};
|
use collections::{BTreeSet, HashMap, HashSet};
|
||||||
use editor::{
|
use editor::{
|
||||||
|
@ -2091,6 +2092,27 @@ impl ContextEditor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ContextEvent::UsePendingTools => {
|
||||||
|
let pending_tool_uses = self
|
||||||
|
.context
|
||||||
|
.read(cx)
|
||||||
|
.pending_tool_uses()
|
||||||
|
.into_iter()
|
||||||
|
.filter(|tool_use| tool_use.status.is_idle())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
for tool_use in pending_tool_uses {
|
||||||
|
let tool_registry = ToolRegistry::global(cx);
|
||||||
|
if let Some(tool) = tool_registry.tool(&tool_use.name) {
|
||||||
|
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
|
||||||
|
|
||||||
|
self.context.update(cx, |context, cx| {
|
||||||
|
context.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ContextEvent::Operation(_) => {}
|
ContextEvent::Operation(_) => {}
|
||||||
ContextEvent::ShowAssistError(error_message) => {
|
ContextEvent::ShowAssistError(error_message) => {
|
||||||
self.error_message = Some(error_message.clone());
|
self.error_message = Some(error_message.clone());
|
||||||
|
|
|
@ -29,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelRequestTool, MessageContent, Role,
|
LanguageModelRequestTool, MessageContent, Role, StopReason,
|
||||||
};
|
};
|
||||||
use open_ai::Model as OpenAiModel;
|
use open_ai::Model as OpenAiModel;
|
||||||
use paths::{context_images_dir, contexts_dir};
|
use paths::{context_images_dir, contexts_dir};
|
||||||
|
@ -306,6 +306,7 @@ pub enum ContextEvent {
|
||||||
run_commands_in_output: bool,
|
run_commands_in_output: bool,
|
||||||
expand_result: bool,
|
expand_result: bool,
|
||||||
},
|
},
|
||||||
|
UsePendingTools,
|
||||||
Operation(ContextOperation),
|
Operation(ContextOperation),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -416,6 +417,7 @@ impl Message {
|
||||||
|
|
||||||
range_start = *image_offset;
|
range_start = *image_offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
if range_start != self.offset_range.end {
|
if range_start != self.offset_range.end {
|
||||||
if let Some(text) =
|
if let Some(text) =
|
||||||
Self::collect_text_content(buffer, range_start..self.offset_range.end)
|
Self::collect_text_content(buffer, range_start..self.offset_range.end)
|
||||||
|
@ -492,7 +494,7 @@ pub struct Context {
|
||||||
edits_since_last_parse: language::Subscription,
|
edits_since_last_parse: language::Subscription,
|
||||||
finished_slash_commands: HashSet<SlashCommandId>,
|
finished_slash_commands: HashSet<SlashCommandId>,
|
||||||
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
|
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
|
||||||
pending_tool_uses_by_id: HashMap<String, PendingToolUse>,
|
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
|
||||||
message_anchors: Vec<MessageAnchor>,
|
message_anchors: Vec<MessageAnchor>,
|
||||||
images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
|
images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
|
||||||
image_anchors: Vec<ImageAnchor>,
|
image_anchors: Vec<ImageAnchor>,
|
||||||
|
@ -1012,7 +1014,7 @@ impl Context {
|
||||||
self.pending_tool_uses_by_id.values().collect()
|
self.pending_tool_uses_by_id.values().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_tool_use_by_id(&self, id: &String) -> Option<&PendingToolUse> {
|
pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
|
||||||
self.pending_tool_uses_by_id.get(id)
|
self.pending_tool_uses_by_id.get(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1919,6 +1921,45 @@ impl Context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn insert_tool_output(
|
||||||
|
&mut self,
|
||||||
|
tool_id: Arc<str>,
|
||||||
|
output: Task<Result<String>>,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
let insert_output_task = cx.spawn(|this, mut cx| {
|
||||||
|
let tool_id = tool_id.clone();
|
||||||
|
async move {
|
||||||
|
let output = output.await;
|
||||||
|
this.update(&mut cx, |this, cx| match output {
|
||||||
|
Ok(mut output) => {
|
||||||
|
if !output.ends_with('\n') {
|
||||||
|
output.push('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
this.buffer.update(cx, |buffer, cx| {
|
||||||
|
let buffer_end = buffer.len().to_offset(buffer);
|
||||||
|
|
||||||
|
buffer.edit([(buffer_end..buffer_end, output)], None, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
if let Some(tool_use) = this.pending_tool_uses_by_id.get_mut(&tool_id) {
|
||||||
|
tool_use.status = PendingToolUseStatus::Error(err.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_id) {
|
||||||
|
tool_use.status = PendingToolUseStatus::Running {
|
||||||
|
_task: insert_output_task.shared(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
|
pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
self.count_remaining_tokens(cx);
|
self.count_remaining_tokens(cx);
|
||||||
}
|
}
|
||||||
|
@ -1990,7 +2031,7 @@ impl Context {
|
||||||
.message_anchors
|
.message_anchors
|
||||||
.iter()
|
.iter()
|
||||||
.position(|message| message.id == assistant_message_id)?;
|
.position(|message| message.id == assistant_message_id)?;
|
||||||
this.buffer.update(cx, |buffer, cx| {
|
let event_to_emit = this.buffer.update(cx, |buffer, cx| {
|
||||||
let message_old_end_offset = this.message_anchors[message_ix + 1..]
|
let message_old_end_offset = this.message_anchors[message_ix + 1..]
|
||||||
.iter()
|
.iter()
|
||||||
.find(|message| message.start.is_valid(buffer))
|
.find(|message| message.start.is_valid(buffer))
|
||||||
|
@ -2000,9 +2041,11 @@ impl Context {
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
LanguageModelCompletionEvent::Stop(reason) => match reason {
|
LanguageModelCompletionEvent::Stop(reason) => match reason {
|
||||||
language_model::StopReason::ToolUse => {}
|
StopReason::ToolUse => {
|
||||||
language_model::StopReason::EndTurn => {}
|
return Some(ContextEvent::UsePendingTools);
|
||||||
language_model::StopReason::MaxTokens => {}
|
}
|
||||||
|
StopReason::EndTurn => {}
|
||||||
|
StopReason::MaxTokens => {}
|
||||||
},
|
},
|
||||||
LanguageModelCompletionEvent::Text(chunk) => {
|
LanguageModelCompletionEvent::Text(chunk) => {
|
||||||
buffer.edit(
|
buffer.edit(
|
||||||
|
@ -2041,10 +2084,11 @@ impl Context {
|
||||||
let source_range = buffer.anchor_after(start_ix)
|
let source_range = buffer.anchor_after(start_ix)
|
||||||
..buffer.anchor_after(end_ix);
|
..buffer.anchor_after(end_ix);
|
||||||
|
|
||||||
|
let tool_use_id: Arc<str> = tool_use.id.into();
|
||||||
this.pending_tool_uses_by_id.insert(
|
this.pending_tool_uses_by_id.insert(
|
||||||
tool_use.id.clone(),
|
tool_use_id.clone(),
|
||||||
PendingToolUse {
|
PendingToolUse {
|
||||||
id: tool_use.id,
|
id: tool_use_id,
|
||||||
name: tool_use.name,
|
name: tool_use.name,
|
||||||
input: tool_use.input,
|
input: tool_use.input,
|
||||||
status: PendingToolUseStatus::Idle,
|
status: PendingToolUseStatus::Idle,
|
||||||
|
@ -2053,9 +2097,14 @@ impl Context {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
None
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.emit(ContextEvent::StreamedCompletion);
|
cx.emit(ContextEvent::StreamedCompletion);
|
||||||
|
if let Some(event) = event_to_emit {
|
||||||
|
cx.emit(event);
|
||||||
|
}
|
||||||
|
|
||||||
Some(())
|
Some(())
|
||||||
})?;
|
})?;
|
||||||
|
@ -2821,7 +2870,7 @@ impl FeatureFlag for ToolUseFeatureFlag {
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PendingToolUse {
|
pub struct PendingToolUse {
|
||||||
pub id: String,
|
pub id: Arc<str>,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub input: serde_json::Value,
|
pub input: serde_json::Value,
|
||||||
pub status: PendingToolUseStatus,
|
pub status: PendingToolUseStatus,
|
||||||
|
@ -2835,6 +2884,12 @@ pub enum PendingToolUseStatus {
|
||||||
Error(String),
|
Error(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl PendingToolUseStatus {
|
||||||
|
pub fn is_idle(&self) -> bool {
|
||||||
|
matches!(self, PendingToolUseStatus::Idle)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub struct SavedMessage {
|
pub struct SavedMessage {
|
||||||
pub id: MessageId,
|
pub id: MessageId,
|
||||||
|
|
|
@ -66,4 +66,9 @@ impl ToolRegistry {
|
||||||
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
|
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
|
||||||
self.state.read().tools.values().cloned().collect()
|
self.state.read().tools.values().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the [`Tool`] with the given name.
|
||||||
|
pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||||
|
self.state.read().tools.get(name).cloned()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue