assistant2: Make context persistent in the thread (#22789)

This PR makes it so the context is persistent in the thread, rather than
having to reattach it for each message.

This PR intentionally does not make an attempt to refresh the attached
context if it changes. That will come in a follow-up.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-01-07 14:16:30 -05:00 committed by GitHub
parent 76a8b55f77
commit fffa40f973
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 12 deletions

View file

@ -36,12 +36,6 @@ impl ContextStore {
&self.context &self.context
} }
pub fn drain(&mut self) -> Vec<Context> {
let context = self.context.drain(..).collect();
self.clear();
context
}
pub fn clear(&mut self) { pub fn clear(&mut self) {
self.context.clear(); self.context.clear();
self.files.clear(); self.files.clear();

View file

@ -142,7 +142,9 @@ impl MessageEditor {
editor.clear(cx); editor.clear(cx);
text text
}); });
let context = self.context_store.update(cx, |this, _cx| this.drain()); let context = self
.context_store
.update(cx, |this, _cx| this.context().clone());
self.thread.update(cx, |thread, cx| { self.thread.update(cx, |thread, cx| {
thread.insert_user_message(user_message, context, cx); thread.insert_user_message(user_message, context, cx);

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::HashMap; use collections::{HashMap, HashSet};
use futures::future::Shared; use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _}; use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task}; use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
@ -209,7 +209,13 @@ impl Thread {
temperature: None, temperature: None,
}; };
let mut referenced_context_ids = HashSet::default();
for message in &self.messages { for message in &self.messages {
if let Some(context_ids) = self.context_by_message.get(&message.id) {
referenced_context_ids.extend(context_ids);
}
let mut request_message = LanguageModelRequestMessage { let mut request_message = LanguageModelRequestMessage {
role: message.role, role: message.role,
content: Vec::new(), content: Vec::new(),
@ -224,10 +230,6 @@ impl Thread {
} }
} }
if let Some(context) = self.context_for_message(message.id) {
attach_context_to_message(&mut request_message, context.clone());
}
if !message.text.is_empty() { if !message.text.is_empty() {
request_message request_message
.content .content
@ -245,6 +247,22 @@ impl Thread {
request.messages.push(request_message); request.messages.push(request_message);
} }
if !referenced_context_ids.is_empty() {
let mut context_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
let referenced_context = referenced_context_ids
.into_iter()
.filter_map(|context_id| self.context.get(context_id))
.cloned();
attach_context_to_message(&mut context_message, referenced_context);
request.messages.push(context_message);
}
request request
} }