assistant2: Persist scripting tool uses in saved threads (#26404)
This PR makes it so the scripting tool uses are persisted to and restored from saved threads. Release Notes: - N/A
This commit is contained in:
parent
6cfc4dc857
commit
082cc6184c
3 changed files with 52 additions and 21 deletions
|
@ -117,8 +117,10 @@ impl Thread {
|
||||||
.map(|message| message.id.0 + 1)
|
.map(|message| message.id.0 + 1)
|
||||||
.unwrap_or(0),
|
.unwrap_or(0),
|
||||||
);
|
);
|
||||||
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
let tool_use =
|
||||||
let scripting_tool_use = ToolUseState::new();
|
ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME);
|
||||||
|
let scripting_tool_use =
|
||||||
|
ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id,
|
id,
|
||||||
|
|
|
@ -116,28 +116,35 @@ impl ThreadStore {
|
||||||
updated_at: thread.updated_at(),
|
updated_at: thread.updated_at(),
|
||||||
messages: thread
|
messages: thread
|
||||||
.messages()
|
.messages()
|
||||||
.map(|message| SavedMessage {
|
.map(|message| {
|
||||||
id: message.id,
|
let all_tool_uses = thread
|
||||||
role: message.role,
|
|
||||||
text: message.text.clone(),
|
|
||||||
tool_uses: thread
|
|
||||||
.tool_uses_for_message(message.id)
|
.tool_uses_for_message(message.id)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
.chain(thread.scripting_tool_uses_for_message(message.id))
|
||||||
.map(|tool_use| SavedToolUse {
|
.map(|tool_use| SavedToolUse {
|
||||||
id: tool_use.id,
|
id: tool_use.id,
|
||||||
name: tool_use.name,
|
name: tool_use.name,
|
||||||
input: tool_use.input,
|
input: tool_use.input,
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect();
|
||||||
tool_results: thread
|
let all_tool_results = thread
|
||||||
.tool_results_for_message(message.id)
|
.tool_results_for_message(message.id)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
.chain(thread.scripting_tool_results_for_message(message.id))
|
||||||
.map(|tool_result| SavedToolResult {
|
.map(|tool_result| SavedToolResult {
|
||||||
tool_use_id: tool_result.tool_use_id.clone(),
|
tool_use_id: tool_result.tool_use_id.clone(),
|
||||||
is_error: tool_result.is_error,
|
is_error: tool_result.is_error,
|
||||||
content: tool_result.content.clone(),
|
content: tool_result.content.clone(),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect();
|
||||||
|
|
||||||
|
SavedMessage {
|
||||||
|
id: message.id,
|
||||||
|
role: message.role,
|
||||||
|
text: message.text.clone(),
|
||||||
|
tool_uses: all_tool_uses,
|
||||||
|
tool_results: all_tool_results,
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
};
|
};
|
||||||
|
|
|
@ -46,25 +46,39 @@ impl ToolUseState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
|
/// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s.
|
||||||
|
///
|
||||||
|
/// Accepts a function to filter the tools that should be used to populate the state.
|
||||||
|
pub fn from_saved_messages(
|
||||||
|
messages: &[SavedMessage],
|
||||||
|
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||||
|
) -> Self {
|
||||||
let mut this = Self::new();
|
let mut this = Self::new();
|
||||||
|
let mut tool_names_by_id = HashMap::default();
|
||||||
|
|
||||||
for message in messages {
|
for message in messages {
|
||||||
match message.role {
|
match message.role {
|
||||||
Role::Assistant => {
|
Role::Assistant => {
|
||||||
if !message.tool_uses.is_empty() {
|
if !message.tool_uses.is_empty() {
|
||||||
this.tool_uses_by_assistant_message.insert(
|
let tool_uses = message
|
||||||
message.id,
|
.tool_uses
|
||||||
message
|
.iter()
|
||||||
.tool_uses
|
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
|
||||||
|
.map(|tool_use| LanguageModelToolUse {
|
||||||
|
id: tool_use.id.clone(),
|
||||||
|
name: tool_use.name.clone().into(),
|
||||||
|
input: tool_use.input.clone(),
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
tool_names_by_id.extend(
|
||||||
|
tool_uses
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tool_use| LanguageModelToolUse {
|
.map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
|
||||||
id: tool_use.id.clone(),
|
|
||||||
name: tool_use.name.clone().into(),
|
|
||||||
input: tool_use.input.clone(),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
this.tool_uses_by_assistant_message
|
||||||
|
.insert(message.id, tool_uses);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Role::User => {
|
Role::User => {
|
||||||
|
@ -76,6 +90,14 @@ impl ToolUseState {
|
||||||
|
|
||||||
for tool_result in &message.tool_results {
|
for tool_result in &message.tool_results {
|
||||||
let tool_use_id = tool_result.tool_use_id.clone();
|
let tool_use_id = tool_result.tool_use_id.clone();
|
||||||
|
let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
|
||||||
|
log::warn!("no tool name found for tool use: {tool_use_id:?}");
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if !(filter_by_tool_name)(tool_use.as_ref()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
tool_uses_by_user_message.push(tool_use_id.clone());
|
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||||
this.tool_results.insert(
|
this.tool_results.insert(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue