Simplify logic & add UI affordances to show model cache status (#16395)
Release Notes: - Adds UI affordances to the assistant panel to show which messages have been cached - Migrate cache invalidation to be based on `has_edits_since_in_range` to be smarter and more selective about when to invalidate the cache and when to fetch. <img width="310" alt="Screenshot 2024-08-16 at 11 19 23 PM" src="https://github.com/user-attachments/assets/4ee2d111-2f55-4b0e-b944-50c4f78afc42"> <img width="580" alt="Screenshot 2024-08-18 at 10 05 16 PM" src="https://github.com/user-attachments/assets/17630a60-7b78-421c-ae39-425246638a12"> I had originally added the lightening bolt on every message and only added the tooltip warning about editing prior messages on the first anchor, but thought it looked too busy, so I settled on just annotating the last anchor.
This commit is contained in:
parent
971db5c6f6
commit
0042c24d3c
5 changed files with 365 additions and 77 deletions
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context,
|
||||
ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
|
||||
assistant_panel, prompt_library, slash_command::file_command, workflow::tool, CacheStatus,
|
||||
Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{
|
||||
|
@ -12,7 +12,7 @@ use fs::{FakeFs, Fs as _};
|
|||
use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
|
||||
use indoc::indoc;
|
||||
use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
|
||||
use language_model::{LanguageModelRegistry, Role};
|
||||
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
|
@ -33,6 +33,8 @@ use unindent::Unindent;
|
|||
use util::{test::marked_text_ranges, RandomCharIter};
|
||||
use workspace::Workspace;
|
||||
|
||||
use super::MessageCacheMetadata;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
|
@ -1002,6 +1004,159 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_mark_cache_anchors(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context =
|
||||
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
|
||||
let buffer = context.read(cx).buffer.clone();
|
||||
|
||||
// Create a test cache configuration
|
||||
let cache_configuration = &Some(LanguageModelCacheConfiguration {
|
||||
max_cache_anchors: 3,
|
||||
should_speculate: true,
|
||||
min_total_token: 10,
|
||||
});
|
||||
|
||||
let message_1 = context.read(cx).message_anchors[0].clone();
|
||||
|
||||
context.update(cx, |context, cx| {
|
||||
context.mark_cache_anchors(cache_configuration, false, cx)
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
||||
.count(),
|
||||
0,
|
||||
"Empty messages should not have any cache anchors."
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
||||
let message_2 = context
|
||||
.update(cx, |context, cx| {
|
||||
context.insert_message_after(message_1.id, Role::User, MessageStatus::Pending, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbbbbbb")], None, cx));
|
||||
let message_3 = context
|
||||
.update(cx, |context, cx| {
|
||||
context.insert_message_after(message_2.id, Role::User, MessageStatus::Pending, cx)
|
||||
})
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(12..12, "cccccc")], None, cx));
|
||||
|
||||
context.update(cx, |context, cx| {
|
||||
context.mark_cache_anchors(cache_configuration, false, cx)
|
||||
});
|
||||
assert_eq!(buffer.read(cx).text(), "aaa\nbbbbbbb\ncccccc");
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.filter(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
||||
.count(),
|
||||
0,
|
||||
"Messages should not be marked for cache before going over the token minimum."
|
||||
);
|
||||
context.update(cx, |context, _| {
|
||||
context.token_count = Some(20);
|
||||
});
|
||||
|
||||
context.update(cx, |context, cx| {
|
||||
context.mark_cache_anchors(cache_configuration, true, cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
||||
.collect::<Vec<bool>>(),
|
||||
vec![true, true, false],
|
||||
"Last message should not be an anchor on speculative request."
|
||||
);
|
||||
|
||||
context
|
||||
.update(cx, |context, cx| {
|
||||
context.insert_message_after(message_3.id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
context.update(cx, |context, cx| {
|
||||
context.mark_cache_anchors(cache_configuration, false, cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.map(|(_, cache)| cache.as_ref().map_or(false, |cache| cache.is_anchor))
|
||||
.collect::<Vec<bool>>(),
|
||||
vec![false, true, true, false],
|
||||
"Most recent message should also be cached if not a speculative request."
|
||||
);
|
||||
context.update(cx, |context, cx| {
|
||||
context.update_cache_status_for_completion(cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.map(|(_, cache)| cache
|
||||
.as_ref()
|
||||
.map_or(None, |cache| Some(cache.status.clone())))
|
||||
.collect::<Vec<Option<CacheStatus>>>(),
|
||||
vec![
|
||||
Some(CacheStatus::Cached),
|
||||
Some(CacheStatus::Cached),
|
||||
Some(CacheStatus::Cached),
|
||||
None
|
||||
],
|
||||
"All user messages prior to anchor should be marked as cached."
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(14..14, "d")], None, cx));
|
||||
context.update(cx, |context, cx| {
|
||||
context.mark_cache_anchors(cache_configuration, false, cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.map(|(_, cache)| cache
|
||||
.as_ref()
|
||||
.map_or(None, |cache| Some(cache.status.clone())))
|
||||
.collect::<Vec<Option<CacheStatus>>>(),
|
||||
vec![
|
||||
Some(CacheStatus::Cached),
|
||||
Some(CacheStatus::Cached),
|
||||
Some(CacheStatus::Pending),
|
||||
None
|
||||
],
|
||||
"Modifying a message should invalidate it's cache but leave previous messages."
|
||||
);
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "e")], None, cx));
|
||||
context.update(cx, |context, cx| {
|
||||
context.mark_cache_anchors(cache_configuration, false, cx)
|
||||
});
|
||||
assert_eq!(
|
||||
messages_cache(&context, cx)
|
||||
.iter()
|
||||
.map(|(_, cache)| cache
|
||||
.as_ref()
|
||||
.map_or(None, |cache| Some(cache.status.clone())))
|
||||
.collect::<Vec<Option<CacheStatus>>>(),
|
||||
vec![
|
||||
Some(CacheStatus::Pending),
|
||||
Some(CacheStatus::Pending),
|
||||
Some(CacheStatus::Pending),
|
||||
None
|
||||
],
|
||||
"Modifying a message should invalidate all future messages."
|
||||
);
|
||||
}
|
||||
|
||||
fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
|
||||
context
|
||||
.read(cx)
|
||||
|
@ -1010,6 +1165,17 @@ fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role,
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn messages_cache(
|
||||
context: &Model<Context>,
|
||||
cx: &AppContext,
|
||||
) -> Vec<(MessageId, Option<MessageCacheMetadata>)> {
|
||||
context
|
||||
.read(cx)
|
||||
.messages(cx)
|
||||
.map(|message| (message.id, message.cache.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FakeSlashCommand(String);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue