assistant2: Factor out tool use into its own module (#25819)
This PR factors out the concerns related to tool use out of `Thread` and into their own module. Release Notes: - N/A
This commit is contained in:
parent
b445e4ce24
commit
fc52b43159
4 changed files with 258 additions and 197 deletions
|
@ -15,10 +15,9 @@ use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, Disclosure};
|
use ui::{prelude::*, Disclosure};
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
|
|
||||||
use crate::thread::{
|
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
|
||||||
MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus,
|
|
||||||
};
|
|
||||||
use crate::thread_store::ThreadStore;
|
use crate::thread_store::ThreadStore;
|
||||||
|
use crate::tool_use::{ToolUse, ToolUseStatus};
|
||||||
use crate::ui::ContextPill;
|
use crate::ui::ContextPill;
|
||||||
|
|
||||||
pub struct ActiveThread {
|
pub struct ActiveThread {
|
||||||
|
|
|
@ -16,6 +16,7 @@ mod terminal_inline_assistant;
|
||||||
mod thread;
|
mod thread;
|
||||||
mod thread_history;
|
mod thread_history;
|
||||||
mod thread_store;
|
mod thread_store;
|
||||||
|
mod tool_use;
|
||||||
mod ui;
|
mod ui;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
|
@ -4,14 +4,12 @@ use anyhow::Result;
|
||||||
use assistant_tool::ToolWorkingSet;
|
use assistant_tool::ToolWorkingSet;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use futures::future::Shared;
|
use futures::StreamExt as _;
|
||||||
use futures::{FutureExt as _, StreamExt as _};
|
|
||||||
use gpui::{App, Context, EventEmitter, SharedString, Task};
|
use gpui::{App, Context, EventEmitter, SharedString, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
|
||||||
LanguageModelToolUse, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
|
||||||
PaymentRequiredError, Role, StopReason,
|
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use util::{post_inc, TryFutureExt as _};
|
use util::{post_inc, TryFutureExt as _};
|
||||||
|
@ -19,6 +17,7 @@ use uuid::Uuid;
|
||||||
|
|
||||||
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
|
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
|
||||||
use crate::thread_store::SavedThread;
|
use crate::thread_store::SavedThread;
|
||||||
|
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum RequestKind {
|
pub enum RequestKind {
|
||||||
|
@ -43,7 +42,7 @@ impl std::fmt::Display for ThreadId {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||||
pub struct MessageId(usize);
|
pub struct MessageId(pub(crate) usize);
|
||||||
|
|
||||||
impl MessageId {
|
impl MessageId {
|
||||||
fn post_inc(&mut self) -> Self {
|
fn post_inc(&mut self) -> Self {
|
||||||
|
@ -59,22 +58,6 @@ pub struct Message {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ToolUse {
|
|
||||||
pub id: LanguageModelToolUseId,
|
|
||||||
pub name: SharedString,
|
|
||||||
pub status: ToolUseStatus,
|
|
||||||
pub input: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum ToolUseStatus {
|
|
||||||
Pending,
|
|
||||||
Running,
|
|
||||||
Finished(SharedString),
|
|
||||||
Error(SharedString),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A thread of conversation with the LLM.
|
/// A thread of conversation with the LLM.
|
||||||
pub struct Thread {
|
pub struct Thread {
|
||||||
id: ThreadId,
|
id: ThreadId,
|
||||||
|
@ -88,10 +71,7 @@ pub struct Thread {
|
||||||
completion_count: usize,
|
completion_count: usize,
|
||||||
pending_completions: Vec<PendingCompletion>,
|
pending_completions: Vec<PendingCompletion>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
tool_use: ToolUseState,
|
||||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
|
||||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
|
||||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
|
@ -108,10 +88,7 @@ impl Thread {
|
||||||
completion_count: 0,
|
completion_count: 0,
|
||||||
pending_completions: Vec::new(),
|
pending_completions: Vec::new(),
|
||||||
tools,
|
tools,
|
||||||
tool_uses_by_assistant_message: HashMap::default(),
|
tool_use: ToolUseState::default(),
|
||||||
tool_uses_by_user_message: HashMap::default(),
|
|
||||||
tool_results: HashMap::default(),
|
|
||||||
pending_tool_uses_by_id: HashMap::default(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,10 +120,7 @@ impl Thread {
|
||||||
completion_count: 0,
|
completion_count: 0,
|
||||||
pending_completions: Vec::new(),
|
pending_completions: Vec::new(),
|
||||||
tools,
|
tools,
|
||||||
tool_uses_by_assistant_message: HashMap::default(),
|
tool_use: ToolUseState::default(),
|
||||||
tool_uses_by_user_message: HashMap::default(),
|
|
||||||
tool_results: HashMap::default(),
|
|
||||||
pending_tool_uses_by_id: HashMap::default(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,56 +182,15 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
||||||
self.pending_tool_uses_by_id.values().collect()
|
self.tool_use.pending_tool_uses()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
||||||
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
|
self.tool_use.tool_uses_for_message(id)
|
||||||
return Vec::new();
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut tool_uses = Vec::new();
|
|
||||||
|
|
||||||
for tool_use in tool_uses_for_message.iter() {
|
|
||||||
let tool_result = self.tool_results.get(&tool_use.id);
|
|
||||||
|
|
||||||
let status = (|| {
|
|
||||||
if let Some(tool_result) = tool_result {
|
|
||||||
return if tool_result.is_error {
|
|
||||||
ToolUseStatus::Error(tool_result.content.clone().into())
|
|
||||||
} else {
|
|
||||||
ToolUseStatus::Finished(tool_result.content.clone().into())
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
|
|
||||||
return match pending_tool_use.status {
|
|
||||||
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
|
|
||||||
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
|
|
||||||
PendingToolUseStatus::Error(ref err) => {
|
|
||||||
ToolUseStatus::Error(err.clone().into())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
ToolUseStatus::Pending
|
|
||||||
})();
|
|
||||||
|
|
||||||
tool_uses.push(ToolUse {
|
|
||||||
id: tool_use.id.clone(),
|
|
||||||
name: tool_use.name.clone().into(),
|
|
||||||
input: tool_use.input.clone(),
|
|
||||||
status,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
tool_uses
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
self.tool_uses_by_user_message
|
self.tool_use.message_has_tool_results(message_id)
|
||||||
.get(&message_id)
|
|
||||||
.map_or(false, |results| !results.is_empty())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert_user_message(
|
pub fn insert_user_message(
|
||||||
|
@ -360,20 +293,13 @@ impl Thread {
|
||||||
content: Vec::new(),
|
content: Vec::new(),
|
||||||
cache: false,
|
cache: false,
|
||||||
};
|
};
|
||||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
|
match request_kind {
|
||||||
match request_kind {
|
RequestKind::Chat => {
|
||||||
RequestKind::Chat => {
|
self.tool_use
|
||||||
for tool_use_id in tool_uses {
|
.attach_tool_results(message.id, &mut request_message);
|
||||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
}
|
||||||
request_message
|
RequestKind::Summarize => {
|
||||||
.content
|
// We don't care about tool use during summarization.
|
||||||
.push(MessageContent::ToolResult(tool_result.clone()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RequestKind::Summarize => {
|
|
||||||
// We don't care about tool use during summarization.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -383,18 +309,13 @@ impl Thread {
|
||||||
.push(MessageContent::Text(message.text.clone()));
|
.push(MessageContent::Text(message.text.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
|
match request_kind {
|
||||||
match request_kind {
|
RequestKind::Chat => {
|
||||||
RequestKind::Chat => {
|
self.tool_use
|
||||||
for tool_use in tool_uses {
|
.attach_tool_uses(message.id, &mut request_message);
|
||||||
request_message
|
}
|
||||||
.content
|
RequestKind::Summarize => {
|
||||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
// We don't care about tool use during summarization.
|
||||||
}
|
|
||||||
}
|
|
||||||
RequestKind::Summarize => {
|
|
||||||
// We don't care about tool use during summarization.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -470,32 +391,8 @@ impl Thread {
|
||||||
.rfind(|message| message.role == Role::Assistant)
|
.rfind(|message| message.role == Role::Assistant)
|
||||||
{
|
{
|
||||||
thread
|
thread
|
||||||
.tool_uses_by_assistant_message
|
.tool_use
|
||||||
.entry(last_assistant_message.id)
|
.request_tool_use(last_assistant_message.id, tool_use);
|
||||||
.or_default()
|
|
||||||
.push(tool_use.clone());
|
|
||||||
|
|
||||||
// The tool use is being requested by the
|
|
||||||
// Assistant, so we want to attach the tool
|
|
||||||
// results to the next user message.
|
|
||||||
let next_user_message_id =
|
|
||||||
MessageId(last_assistant_message.id.0 + 1);
|
|
||||||
thread
|
|
||||||
.tool_uses_by_user_message
|
|
||||||
.entry(next_user_message_id)
|
|
||||||
.or_default()
|
|
||||||
.push(tool_use.id.clone());
|
|
||||||
|
|
||||||
thread.pending_tool_uses_by_id.insert(
|
|
||||||
tool_use.id.clone(),
|
|
||||||
PendingToolUse {
|
|
||||||
assistant_message_id: last_assistant_message.id,
|
|
||||||
id: tool_use.id,
|
|
||||||
name: tool_use.name,
|
|
||||||
input: tool_use.input,
|
|
||||||
status: PendingToolUseStatus::Idle,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -624,49 +521,19 @@ impl Thread {
|
||||||
async move {
|
async move {
|
||||||
let output = output.await;
|
let output = output.await;
|
||||||
thread
|
thread
|
||||||
.update(&mut cx, |thread, cx| match output {
|
.update(&mut cx, |thread, cx| {
|
||||||
Ok(output) => {
|
thread
|
||||||
thread.tool_results.insert(
|
.tool_use
|
||||||
tool_use_id.clone(),
|
.insert_tool_output(tool_use_id.clone(), output);
|
||||||
LanguageModelToolResult {
|
|
||||||
tool_use_id: tool_use_id.clone(),
|
|
||||||
content: output.into(),
|
|
||||||
is_error: false,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
thread.pending_tool_uses_by_id.remove(&tool_use_id);
|
|
||||||
|
|
||||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
thread.tool_results.insert(
|
|
||||||
tool_use_id.clone(),
|
|
||||||
LanguageModelToolResult {
|
|
||||||
tool_use_id: tool_use_id.clone(),
|
|
||||||
content: err.to_string().into(),
|
|
||||||
is_error: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(tool_use) =
|
|
||||||
thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
|
|
||||||
{
|
|
||||||
tool_use.status =
|
|
||||||
PendingToolUseStatus::Error(err.to_string().into());
|
|
||||||
}
|
|
||||||
|
|
||||||
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
|
self.tool_use
|
||||||
tool_use.status = PendingToolUseStatus::Running {
|
.run_pending_tool(tool_use_id, insert_output_task);
|
||||||
_task: insert_output_task.shared(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Cancels the last pending completion, if there are any pending.
|
/// Cancels the last pending completion, if there are any pending.
|
||||||
|
@ -708,30 +575,3 @@ struct PendingCompletion {
|
||||||
id: usize,
|
id: usize,
|
||||||
_task: Task<()>,
|
_task: Task<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct PendingToolUse {
|
|
||||||
pub id: LanguageModelToolUseId,
|
|
||||||
/// The ID of the Assistant message in which the tool use was requested.
|
|
||||||
pub assistant_message_id: MessageId,
|
|
||||||
pub name: Arc<str>,
|
|
||||||
pub input: serde_json::Value,
|
|
||||||
pub status: PendingToolUseStatus,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum PendingToolUseStatus {
|
|
||||||
Idle,
|
|
||||||
Running { _task: Shared<Task<()>> },
|
|
||||||
Error(#[allow(unused)] Arc<str>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PendingToolUseStatus {
|
|
||||||
pub fn is_idle(&self) -> bool {
|
|
||||||
matches!(self, PendingToolUseStatus::Idle)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_error(&self) -> bool {
|
|
||||||
matches!(self, PendingToolUseStatus::Error(_))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
221
crates/assistant2/src/tool_use.rs
Normal file
221
crates/assistant2/src/tool_use.rs
Normal file
|
@ -0,0 +1,221 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::future::Shared;
|
||||||
|
use futures::FutureExt as _;
|
||||||
|
use gpui::{SharedString, Task};
|
||||||
|
use language_model::{
|
||||||
|
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
|
||||||
|
LanguageModelToolUseId, MessageContent,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::thread::MessageId;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ToolUse {
|
||||||
|
pub id: LanguageModelToolUseId,
|
||||||
|
pub name: SharedString,
|
||||||
|
pub status: ToolUseStatus,
|
||||||
|
pub input: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ToolUseStatus {
|
||||||
|
Pending,
|
||||||
|
Running,
|
||||||
|
Finished(SharedString),
|
||||||
|
Error(SharedString),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ToolUseState {
|
||||||
|
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||||
|
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||||
|
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||||
|
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolUseState {
|
||||||
|
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
||||||
|
self.pending_tool_uses_by_id.values().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
||||||
|
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
|
||||||
|
return Vec::new();
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tool_uses = Vec::new();
|
||||||
|
|
||||||
|
for tool_use in tool_uses_for_message.iter() {
|
||||||
|
let tool_result = self.tool_results.get(&tool_use.id);
|
||||||
|
|
||||||
|
let status = (|| {
|
||||||
|
if let Some(tool_result) = tool_result {
|
||||||
|
return if tool_result.is_error {
|
||||||
|
ToolUseStatus::Error(tool_result.content.clone().into())
|
||||||
|
} else {
|
||||||
|
ToolUseStatus::Finished(tool_result.content.clone().into())
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
|
||||||
|
return match pending_tool_use.status {
|
||||||
|
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
|
||||||
|
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
|
||||||
|
PendingToolUseStatus::Error(ref err) => {
|
||||||
|
ToolUseStatus::Error(err.clone().into())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
ToolUseStatus::Pending
|
||||||
|
})();
|
||||||
|
|
||||||
|
tool_uses.push(ToolUse {
|
||||||
|
id: tool_use.id.clone(),
|
||||||
|
name: tool_use.name.clone().into(),
|
||||||
|
input: tool_use.input.clone(),
|
||||||
|
status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_uses
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
|
self.tool_uses_by_user_message
|
||||||
|
.get(&message_id)
|
||||||
|
.map_or(false, |results| !results.is_empty())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn request_tool_use(
|
||||||
|
&mut self,
|
||||||
|
assistant_message_id: MessageId,
|
||||||
|
tool_use: LanguageModelToolUse,
|
||||||
|
) {
|
||||||
|
self.tool_uses_by_assistant_message
|
||||||
|
.entry(assistant_message_id)
|
||||||
|
.or_default()
|
||||||
|
.push(tool_use.clone());
|
||||||
|
|
||||||
|
// The tool use is being requested by the Assistant, so we want to
|
||||||
|
// attach the tool results to the next user message.
|
||||||
|
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
|
||||||
|
self.tool_uses_by_user_message
|
||||||
|
.entry(next_user_message_id)
|
||||||
|
.or_default()
|
||||||
|
.push(tool_use.id.clone());
|
||||||
|
|
||||||
|
self.pending_tool_uses_by_id.insert(
|
||||||
|
tool_use.id.clone(),
|
||||||
|
PendingToolUse {
|
||||||
|
assistant_message_id,
|
||||||
|
id: tool_use.id,
|
||||||
|
name: tool_use.name,
|
||||||
|
input: tool_use.input,
|
||||||
|
status: PendingToolUseStatus::Idle,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
|
||||||
|
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
|
||||||
|
tool_use.status = PendingToolUseStatus::Running {
|
||||||
|
_task: task.shared(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert_tool_output(
|
||||||
|
&mut self,
|
||||||
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
output: Result<String>,
|
||||||
|
) {
|
||||||
|
match output {
|
||||||
|
Ok(output) => {
|
||||||
|
self.tool_results.insert(
|
||||||
|
tool_use_id.clone(),
|
||||||
|
LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_use_id.clone(),
|
||||||
|
content: output.into(),
|
||||||
|
is_error: false,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
self.pending_tool_uses_by_id.remove(&tool_use_id);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.tool_results.insert(
|
||||||
|
tool_use_id.clone(),
|
||||||
|
LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_use_id.clone(),
|
||||||
|
content: err.to_string().into(),
|
||||||
|
is_error: true,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
|
||||||
|
tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn attach_tool_uses(
|
||||||
|
&self,
|
||||||
|
message_id: MessageId,
|
||||||
|
request_message: &mut LanguageModelRequestMessage,
|
||||||
|
) {
|
||||||
|
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
|
||||||
|
for tool_use in tool_uses {
|
||||||
|
request_message
|
||||||
|
.content
|
||||||
|
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn attach_tool_results(
|
||||||
|
&self,
|
||||||
|
message_id: MessageId,
|
||||||
|
request_message: &mut LanguageModelRequestMessage,
|
||||||
|
) {
|
||||||
|
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
|
||||||
|
for tool_use_id in tool_uses {
|
||||||
|
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||||
|
request_message
|
||||||
|
.content
|
||||||
|
.push(MessageContent::ToolResult(tool_result.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PendingToolUse {
|
||||||
|
pub id: LanguageModelToolUseId,
|
||||||
|
/// The ID of the Assistant message in which the tool use was requested.
|
||||||
|
pub assistant_message_id: MessageId,
|
||||||
|
pub name: Arc<str>,
|
||||||
|
pub input: serde_json::Value,
|
||||||
|
pub status: PendingToolUseStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum PendingToolUseStatus {
|
||||||
|
Idle,
|
||||||
|
Running { _task: Shared<Task<()>> },
|
||||||
|
Error(#[allow(unused)] Arc<str>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PendingToolUseStatus {
|
||||||
|
pub fn is_idle(&self) -> bool {
|
||||||
|
matches!(self, PendingToolUseStatus::Idle)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_error(&self) -> bool {
|
||||||
|
matches!(self, PendingToolUseStatus::Error(_))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue