From 3d8625f25c67c1af4dfd1516691472c33c7971d0 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 7 Jan 2025 12:21:39 -0500 Subject: [PATCH] assistant2: Store deduped context on the `Thread` (#22781) This PR is a small refactoring in advance of some other changes. Previously we were storing the whole `Context` associated with each message. However, it's likely that multiple messages may end up using the same context. We now store the deduped context in a separate collection and refer to it from each message by its `ContextId`. Release Notes: - N/A --- crates/assistant2/src/thread.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index f31a2b04c6..0816e541b4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; -use crate::context::{attach_context_to_message, Context}; +use crate::context::{attach_context_to_message, Context, ContextId}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -64,7 +64,8 @@ pub struct Thread { pending_summary: Task>, messages: Vec, next_message_id: MessageId, - context_by_message: HashMap>, + context: HashMap, + context_by_message: HashMap>, completion_count: usize, pending_completions: Vec, tools: Arc, @@ -82,6 +83,7 @@ impl Thread { pending_summary: Task::ready(None), messages: Vec::new(), next_message_id: MessageId(0), + context: HashMap::default(), context_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -129,8 +131,15 @@ impl Thread { &self.tools } - pub fn context_for_message(&self, id: MessageId) -> Option<&Vec> { - self.context_by_message.get(&id) + pub fn context_for_message(&self, id: MessageId) -> Option> { + let context = self.context_by_message.get(&id)?; + Some( + context + .into_iter() + .filter_map(|context_id| self.context.get(&context_id)) + .cloned() + .collect::>(), + ) } pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { @@ -144,7 +153,10 @@ impl Thread { cx: &mut ModelContext, ) { let message_id = self.insert_message(Role::User, text, cx); - self.context_by_message.insert(message_id, context); + let context_ids = context.iter().map(|context| context.id).collect::>(); + self.context + .extend(context.into_iter().map(|context| (context.id, context))); + self.context_by_message.insert(message_id, context_ids); } pub fn insert_message(