Restructure agent context (#29233)
Simplifies the data structures involved in agent context by removing caching and limiting the use of `ContextId`:

* `AssistantContext` enum is now like an ID / handle to context that does not need to be updated. `ContextId` still exists but is only used for generating unique `ElementId`s.
* `ContextStore` has an `IndexMap<ContextSetEntry>`; only a `HashSet<ThreadId>` needs to be kept consistent with it. `ContextSetEntry` is a newtype wrapper around `AssistantContext` which implements eq / hash on a subset of fields (see the sketch below).
* Thread `Message` directly stores its context.

Fixes the following bugs:

* If a context entry was removed from the strip and added again, it was re-included in the next message.
* Clicking file context in the thread after it had been removed from the context strip didn't jump to the file.
* Refreshing directory context didn't reflect added / removed files.
* Deleted directories would remain in the message editor's context strip.
* Token counting requests didn't include image context.
* File, directory, and symbol context deduplication relied on `ProjectPath` for identity, and so didn't handle renames.
* Symbol context line numbers didn't update when shifted.

Known bugs (not fixed):

* Deleting a directory causes it to disappear from messages in threads. Fixing this in a nice way is tricky. One easy fix is to store the original path and show that on deletion, though it's weird that deletion would cause the name to "revert". Another possibility would be to snapshot context metadata on add (à la `AddedContext`) and keep that around despite deletion.

Release Notes:

- N/A
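Below is a minimal sketch of the `ContextSetEntry` identity scheme described above, not the actual Zed implementation: the `AssistantContext` variants and their fields are hypothetical placeholders, and a plain `std::collections::HashSet` stands in for the insertion-ordered map used by the real `ContextStore`. The point is that equality and hashing consider only the identity-carrying fields, so re-adding refreshed context does not create a duplicate entry.

```rust
use std::collections::HashSet;
use std::hash::{Hash, Hasher};

// Hypothetical, simplified stand-ins for the real `AssistantContext` variants.
#[derive(Clone, Debug)]
enum AssistantContext {
    File { path: String, text: String },
    Thread { id: u64, title: String },
}

// Newtype wrapper whose Eq/Hash look only at identity fields (path / id),
// ignoring volatile data such as loaded text or titles.
#[derive(Clone, Debug)]
struct ContextSetEntry(AssistantContext);

impl PartialEq for ContextSetEntry {
    fn eq(&self, other: &Self) -> bool {
        use AssistantContext::*;
        match (&self.0, &other.0) {
            (File { path: a, .. }, File { path: b, .. }) => a == b,
            (Thread { id: a, .. }, Thread { id: b, .. }) => a == b,
            _ => false,
        }
    }
}

impl Eq for ContextSetEntry {}

impl Hash for ContextSetEntry {
    fn hash<H: Hasher>(&self, state: &mut H) {
        use AssistantContext::*;
        // Hash a discriminant byte plus the identity field, mirroring `eq`.
        match &self.0 {
            File { path, .. } => {
                0u8.hash(state);
                path.hash(state);
            }
            Thread { id, .. } => {
                1u8.hash(state);
                id.hash(state);
            }
        }
    }
}

fn main() {
    let mut set = HashSet::new();
    set.insert(ContextSetEntry(AssistantContext::File {
        path: "src/lib.rs".into(),
        text: "old contents".into(),
    }));
    // Re-adding the same file with refreshed contents is a no-op for the set,
    // so the strip keeps a single entry per piece of context.
    let added = set.insert(ContextSetEntry(AssistantContext::File {
        path: "src/lib.rs".into(),
        text: "new contents".into(),
    }));
    assert!(!added);
    assert_eq!(set.len(), 1);
}
```

In this sketch the file's identity is its path purely for brevity; per the bug list above, the actual change moves away from `ProjectPath`-based identity so that deduplication also survives renames.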
This commit is contained in:
parent d492939bed
commit 17ecf94f6f
31 changed files with 1893 additions and 2061 deletions
@@ -1,6 +1,6 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::context::ContextLoadResult;
use crate::inline_prompt_editor::CodegenStatus;
use crate::{context::load_context, context_store::ContextStore};
use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
@@ -8,7 +8,7 @@ use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
@@ -16,7 +16,9 @@ use language_model::{
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use rope::Rope;
use smol::future::FutureExt;
use std::{
@@ -41,6 +43,8 @@ pub struct BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
pub is_insertion: bool,
@@ -52,6 +56,8 @@ impl BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -62,6 +68,8 @@ impl BufferCodegen {
range.clone(),
false,
Some(context_store.clone()),
project.clone(),
prompt_store.clone(),
Some(telemetry.clone()),
builder.clone(),
cx,
@@ -77,6 +85,8 @@ impl BufferCodegen {
range,
initial_transaction_id,
context_store,
project,
prompt_store,
telemetry,
builder,
};
@@ -155,6 +165,8 @@ impl BufferCodegen {
self.range.clone(),
false,
Some(self.context_store.clone()),
self.project.clone(),
self.prompt_store.clone(),
Some(self.telemetry.clone()),
self.builder.clone(),
cx,
@@ -231,13 +243,14 @@ pub struct CodegenAlternative {
generation: Task<()>,
diff: Diff,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
active: bool,
edits: Vec<(Range<Anchor>, String)>,
line_operations: Vec<LineOperation>,
request: Option<LanguageModelRequest>,
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
@@ -251,6 +264,8 @@ impl CodegenAlternative {
range: Range<Anchor>,
active: bool,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -292,6 +307,8 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
context_store,
project,
prompt_store,
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
@@ -299,7 +316,6 @@ impl CodegenAlternative {
edits: Vec::new(),
line_operations: Vec::new(),
range,
request: None,
elapsed_time: None,
completion: None,
}
@@ -368,16 +384,18 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(user_prompt, cx)?;
self.request = Some(request.clone());

cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())
}

fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
fn build_request(
&self,
user_prompt: String,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -410,30 +428,44 @@ impl CodegenAlternative {
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;

let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
let context_task = self.context_store.as_ref().map(|context_store| {
if let Some(project) = self.project.upgrade() {
let context = context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
load_context(context, &project, &self.prompt_store, cx)
} else {
Task::ready(ContextLoadResult::default())
}
});

if let Some(context_store) = &self.context_store {
attach_context_to_message(
&mut request_message,
context_store.read(cx).context().iter(),
cx,
);
}
Ok(cx.spawn(async move |_cx| {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};

request_message.content.push(prompt.into());
if let Some(context_task) = context_task {
context_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
}

Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
})
request_message.content.push(prompt.into());

LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
}
}))
}

pub fn handle_stream(
@@ -1038,6 +1070,7 @@ impl Diff {
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use futures::{
Stream,
stream::{self},
@@ -1080,12 +1113,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1144,12 +1181,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1211,12 +1252,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1278,12 +1323,16 @@ mod tests {
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1333,12 +1382,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
false,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,