assistant2: Store deduped context on the Thread (#22781)

This PR is a small refactoring in advance of some other changes.

Previously we were storing the whole `Context` associated with each
message. However, it's likely that multiple messages may end up using
the same context.

We now store the deduped context in a separate collection and refer to
it from each message by its `ContextId`.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-01-07 12:21:39 -05:00 committed by GitHub
parent f53a17b044
commit 3d8625f25c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _};
use uuid::Uuid;
use crate::context::{attach_context_to_message, Context};
use crate::context::{attach_context_to_message, Context, ContextId};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@ -64,7 +64,8 @@ pub struct Thread {
pending_summary: Task<Option<()>>,
messages: Vec<Message>,
next_message_id: MessageId,
context_by_message: HashMap<MessageId, Vec<Context>>,
context: HashMap<ContextId, Context>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
tools: Arc<ToolWorkingSet>,
@ -82,6 +83,7 @@ impl Thread {
pending_summary: Task::ready(None),
messages: Vec::new(),
next_message_id: MessageId(0),
context: HashMap::default(),
context_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@ -129,8 +131,15 @@ impl Thread {
&self.tools
}
pub fn context_for_message(&self, id: MessageId) -> Option<&Vec<Context>> {
self.context_by_message.get(&id)
pub fn context_for_message(&self, id: MessageId) -> Option<Vec<Context>> {
let context = self.context_by_message.get(&id)?;
Some(
context
.into_iter()
.filter_map(|context_id| self.context.get(&context_id))
.cloned()
.collect::<Vec<_>>(),
)
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
@ -144,7 +153,10 @@ impl Thread {
cx: &mut ModelContext<Self>,
) {
let message_id = self.insert_message(Role::User, text, cx);
self.context_by_message.insert(message_id, context);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids);
}
pub fn insert_message(