assistant2: Persist scripting tool uses in saved threads (#26404)

This PR makes it so the scripting tool uses are persisted to and
restored from saved threads.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-10 17:42:23 -04:00 committed by GitHub
parent 6cfc4dc857
commit 082cc6184c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 21 deletions

View file

@ -117,8 +117,10 @@ impl Thread {
.map(|message| message.id.0 + 1) .map(|message| message.id.0 + 1)
.unwrap_or(0), .unwrap_or(0),
); );
let tool_use = ToolUseState::from_saved_messages(&saved.messages); let tool_use =
let scripting_tool_use = ToolUseState::new(); ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME);
let scripting_tool_use =
ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME);
Self { Self {
id, id,

View file

@ -116,28 +116,35 @@ impl ThreadStore {
updated_at: thread.updated_at(), updated_at: thread.updated_at(),
messages: thread messages: thread
.messages() .messages()
.map(|message| SavedMessage { .map(|message| {
id: message.id, let all_tool_uses = thread
role: message.role,
text: message.text.clone(),
tool_uses: thread
.tool_uses_for_message(message.id) .tool_uses_for_message(message.id)
.into_iter() .into_iter()
.chain(thread.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SavedToolUse { .map(|tool_use| SavedToolUse {
id: tool_use.id, id: tool_use.id,
name: tool_use.name, name: tool_use.name,
input: tool_use.input, input: tool_use.input,
}) })
.collect(), .collect();
tool_results: thread let all_tool_results = thread
.tool_results_for_message(message.id) .tool_results_for_message(message.id)
.into_iter() .into_iter()
.chain(thread.scripting_tool_results_for_message(message.id))
.map(|tool_result| SavedToolResult { .map(|tool_result| SavedToolResult {
tool_use_id: tool_result.tool_use_id.clone(), tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error, is_error: tool_result.is_error,
content: tool_result.content.clone(), content: tool_result.content.clone(),
}) })
.collect(), .collect();
SavedMessage {
id: message.id,
role: message.role,
text: message.text.clone(),
tool_uses: all_tool_uses,
tool_results: all_tool_results,
}
}) })
.collect(), .collect(),
}; };

View file

@ -46,25 +46,39 @@ impl ToolUseState {
} }
} }
pub fn from_saved_messages(messages: &[SavedMessage]) -> Self { /// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_saved_messages(
messages: &[SavedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
let mut this = Self::new(); let mut this = Self::new();
let mut tool_names_by_id = HashMap::default();
for message in messages { for message in messages {
match message.role { match message.role {
Role::Assistant => { Role::Assistant => {
if !message.tool_uses.is_empty() { if !message.tool_uses.is_empty() {
this.tool_uses_by_assistant_message.insert( let tool_uses = message
message.id, .tool_uses
message .iter()
.tool_uses .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
input: tool_use.input.clone(),
})
.collect::<Vec<_>>();
tool_names_by_id.extend(
tool_uses
.iter() .iter()
.map(|tool_use| LanguageModelToolUse { .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
input: tool_use.input.clone(),
})
.collect(),
); );
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
} }
} }
Role::User => { Role::User => {
@ -76,6 +90,14 @@ impl ToolUseState {
for tool_result in &message.tool_results { for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone(); let tool_use_id = tool_result.tool_use_id.clone();
let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
log::warn!("no tool name found for tool use: {tool_use_id:?}");
continue;
};
if !(filter_by_tool_name)(tool_use.as_ref()) {
continue;
}
tool_uses_by_user_message.push(tool_use_id.clone()); tool_uses_by_user_message.push(tool_use_id.clone());
this.tool_results.insert( this.tool_results.insert(