assistant2: Restore tool uses when loading saved threads (#25942)

This PR makes it so tool uses are restored when loading saved threads in
Assistant 2.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-03 13:32:26 -05:00 committed by GitHub
parent 6635462f7b
commit b2add8c803
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 120 additions and 8 deletions

View file

@ -8,8 +8,9 @@ use futures::StreamExt as _;
use gpui::{App, Context, EventEmitter, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason,
};
use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
@ -88,7 +89,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
tool_use: ToolUseState::default(),
tool_use: ToolUseState::new(),
}
}
@ -99,6 +100,7 @@ impl Thread {
_cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(saved.messages.len());
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
Self {
id,
@ -120,7 +122,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
tools,
tool_use: ToolUseState::default(),
tool_use,
}
}
@ -189,6 +191,10 @@ impl Thread {
self.tool_use.tool_uses_for_message(id)
}
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(id)
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}

View file

@ -14,7 +14,7 @@ use gpui::{
};
use heed::types::{SerdeBincode, SerdeJson};
use heed::Database;
use language_model::Role;
use language_model::{LanguageModelToolUseId, Role};
use project::Project;
use serde::{Deserialize, Serialize};
use util::ResultExt as _;
@ -113,6 +113,24 @@ impl ThreadStore {
id: message.id,
role: message.role,
text: message.text.clone(),
tool_uses: thread
.tool_uses_for_message(message.id)
.into_iter()
.map(|tool_use| SavedToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
})
.collect(),
tool_results: thread
.tool_results_for_message(message.id)
.into_iter()
.map(|tool_result| SavedToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
})
.collect(),
})
.collect(),
};
@ -239,11 +257,29 @@ pub struct SavedThread {
pub messages: Vec<SavedMessage>,
}
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,
pub role: Role,
pub text: String,
#[serde(default)]
pub tool_uses: Vec<SavedToolUse>,
#[serde(default)]
pub tool_results: Vec<SavedToolResult>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
pub content: Arc<str>,
}
struct GlobalThreadsDatabase(

View file

@ -7,10 +7,11 @@ use futures::FutureExt as _;
use gpui::{SharedString, Task};
use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent,
LanguageModelToolUseId, MessageContent, Role,
};
use crate::thread::MessageId;
use crate::thread_store::SavedMessage;
#[derive(Debug)]
pub struct ToolUse {
@ -28,7 +29,6 @@ pub enum ToolUseStatus {
Error(SharedString),
}
#[derive(Default)]
pub struct ToolUseState {
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
@ -37,6 +37,65 @@ pub struct ToolUseState {
}
impl ToolUseState {
pub fn new() -> Self {
Self {
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
}
}
pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
let mut this = Self::new();
for message in messages {
match message.role {
Role::Assistant => {
if !message.tool_uses.is_empty() {
this.tool_uses_by_assistant_message.insert(
message.id,
message
.tool_uses
.iter()
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
input: tool_use.input.clone(),
})
.collect(),
);
}
}
Role::User => {
if !message.tool_results.is_empty() {
let tool_uses_by_user_message = this
.tool_uses_by_user_message
.entry(message.id)
.or_default();
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
tool_uses_by_user_message.push(tool_use_id.clone());
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id,
is_error: tool_result.is_error,
content: tool_result.content.clone(),
},
);
}
}
}
Role::System => {}
}
}
this
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
@ -84,6 +143,17 @@ impl ToolUseState {
tool_uses
}
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
let empty = Vec::new();
self.tool_uses_by_user_message
.get(&message_id)
.unwrap_or(&empty)
.iter()
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
.collect()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_uses_by_user_message
.get(&message_id)