diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index f31a2b04c6..0816e541b4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; -use crate::context::{attach_context_to_message, Context}; +use crate::context::{attach_context_to_message, Context, ContextId}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -64,7 +64,8 @@ pub struct Thread { pending_summary: Task>, messages: Vec, next_message_id: MessageId, - context_by_message: HashMap>, + context: HashMap, + context_by_message: HashMap>, completion_count: usize, pending_completions: Vec, tools: Arc, @@ -82,6 +83,7 @@ impl Thread { pending_summary: Task::ready(None), messages: Vec::new(), next_message_id: MessageId(0), + context: HashMap::default(), context_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), @@ -129,8 +131,15 @@ impl Thread { &self.tools } - pub fn context_for_message(&self, id: MessageId) -> Option<&Vec> { - self.context_by_message.get(&id) + pub fn context_for_message(&self, id: MessageId) -> Option> { + let context = self.context_by_message.get(&id)?; + Some( + context + .into_iter() + .filter_map(|context_id| self.context.get(&context_id)) + .cloned() + .collect::>(), + ) } pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { @@ -144,7 +153,10 @@ impl Thread { cx: &mut ModelContext, ) { let message_id = self.insert_message(Role::User, text, cx); - self.context_by_message.insert(message_id, context); + let context_ids = context.iter().map(|context| context.id).collect::>(); + self.context + .extend(context.into_iter().map(|context| (context.id, context))); + self.context_by_message.insert(message_id, context_ids); } pub fn insert_message(