Simplify logic & add UI affordances to show model cache status (#16395)

Release Notes:

- Adds UI affordances to the assistant panel to show which messages have
been cached
- Migrates cache invalidation to `has_edits_since_in_range`, making it smarter and more selective about when to invalidate the cache and when to fetch (sketched below)
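
The gist of the new invalidation check, as a rough standalone sketch (the `Version`, `Edit`, and `Buffer` types below are illustrative stand-ins, not Zed's actual `text` crate API): a cache entry records the buffer version it was created at, and stays valid only while no edits have touched its range since.

```rust
use std::ops::Range;

#[derive(Clone, Copy, PartialEq, PartialOrd)]
struct Version(u64);

struct Edit {
    at_version: Version,
    range: Range<usize>,
}

struct Buffer {
    edits: Vec<Edit>,
}

impl Buffer {
    // Rough analogue of `BufferSnapshot::has_edits_since_in_range`.
    fn has_edits_since_in_range(&self, since: Version, range: &Range<usize>) -> bool {
        self.edits.iter().any(|edit| {
            edit.at_version > since
                && edit.range.start < range.end
                && range.start < edit.range.end
        })
    }
}

// Mirrors `MessageMetadata::is_cache_valid`: the entry is valid only if
// nothing inside the message's range changed after `cached_at`.
fn is_cache_valid(buffer: &Buffer, cached_at: Version, range: &Range<usize>) -> bool {
    !buffer.has_edits_since_in_range(cached_at, range)
}
```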

<img width="310" alt="Screenshot 2024-08-16 at 11 19 23 PM"
src="https://github.com/user-attachments/assets/4ee2d111-2f55-4b0e-b944-50c4f78afc42">

<img width="580" alt="Screenshot 2024-08-18 at 10 05 16 PM"
src="https://github.com/user-attachments/assets/17630a60-7b78-421c-ae39-425246638a12">


I had originally added the lightning bolt to every message and put the tooltip warning about editing prior messages only on the first anchor, but that looked too busy, so I settled on annotating just the last anchor.

Roy Williams authored 2024-08-19 15:06:14 -04:00 (committed by GitHub)
parent 971db5c6f6 · commit 0042c24d3c
5 changed files with 365 additions and 77 deletions

@@ -40,6 +40,7 @@ use std::{
time::{Duration, Instant},
};
use telemetry_events::AssistantKind;
use text::BufferSnapshot;
use util::{post_inc, ResultExt, TryFutureExt};
use uuid::Uuid;
@@ -107,8 +108,7 @@ impl ContextOperation {
message.status.context("invalid status")?,
),
timestamp: id.0,
should_cache: false,
is_cache_anchor: false,
cache: None,
},
version: language::proto::deserialize_version(&insert.version),
})
@@ -123,8 +123,7 @@ impl ContextOperation {
timestamp: language::proto::deserialize_timestamp(
update.timestamp.context("invalid timestamp")?,
),
should_cache: false,
is_cache_anchor: false,
cache: None,
},
version: language::proto::deserialize_version(&update.version),
}),
@@ -313,13 +312,43 @@ pub struct MessageAnchor {
pub start: language::Anchor,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CacheStatus {
Pending,
Cached,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageCacheMetadata {
pub is_anchor: bool,
pub is_final_anchor: bool,
pub status: CacheStatus,
pub cached_at: clock::Global,
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
pub role: Role,
pub status: MessageStatus,
timestamp: clock::Lamport,
should_cache: bool,
is_cache_anchor: bool,
#[serde(skip)]
pub cache: Option<MessageCacheMetadata>,
}
impl MessageMetadata {
pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool {
match &self.cache {
Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range(
cached_at,
// Bias the anchors inward so insertions that merely touch the
// message's boundaries don't count as edits inside it.
Range {
start: buffer.anchor_at(range.start, Bias::Right),
end: buffer.anchor_at(range.end, Bias::Left),
},
),
_ => false,
}
}
}
#[derive(Clone, Debug)]
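
One detail worth noting from the hunk above: `cache` is marked `#[serde(skip)]`, so cache metadata is never persisted — every deserialization path later in this diff restores it as `None`, and anchors are recomputed on the next request. A minimal sketch of that behavior (the `Metadata` struct here is a hypothetical stand-in; assumes `serde` and `serde_json` as dependencies):

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Metadata {
    role: String,
    #[serde(skip)]
    cache: Option<String>, // stand-in for Option<MessageCacheMetadata>
}

fn main() {
    let saved = Metadata {
        role: "user".into(),
        cache: Some("cached".into()),
    };
    let json = serde_json::to_string(&saved).unwrap();
    // Skipped fields come back as their Default value, i.e. `None`.
    let loaded: Metadata = serde_json::from_str(&json).unwrap();
    assert_eq!(loaded, Metadata { role: "user".into(), cache: None });
}
```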
@@ -345,7 +374,7 @@ pub struct Message {
pub anchor: language::Anchor,
pub role: Role,
pub status: MessageStatus,
pub cache: bool,
pub cache: Option<MessageCacheMetadata>,
}
impl Message {
@@ -381,7 +410,7 @@ impl Message {
Some(LanguageModelRequestMessage {
role: self.role,
content,
cache: self.cache,
cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor),
})
}
@@ -544,8 +573,7 @@ impl Context {
role: Role::User,
status: MessageStatus::Done,
timestamp: first_message_id.0,
should_cache: false,
is_cache_anchor: false,
cache: None,
},
);
this.message_anchors.push(message);
@@ -979,7 +1007,7 @@ impl Context {
});
}
pub fn mark_longest_messages_for_cache(
pub fn mark_cache_anchors(
&mut self,
cache_configuration: &Option<LanguageModelCacheConfiguration>,
speculative: bool,
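
Before the body of `mark_cache_anchors` in the next hunk, a condensed sketch of the selection policy it implements may help (names here are hypothetical; the real method works over `Message` values and a `ModelContext`): drop the trailing message on speculative runs, consider only user messages, prefer the longest ones, and reserve one anchor slot for the inline assistant.

```rust
use std::collections::HashSet;

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct MessageId(u64);

struct Msg {
    id: MessageId,
    is_user: bool,
    len: usize,
}

fn select_cache_anchors(
    mut messages: Vec<Msg>,
    max_cache_anchors: usize,
    min_total_tokens: usize,
    total_tokens: usize,
    speculative: bool,
) -> HashSet<MessageId> {
    if speculative {
        // The trailing message is still being typed, so don't anchor on it.
        messages.pop();
    }
    // Only user messages are anchor candidates; prefer the longest ones.
    messages.retain(|m| m.is_user);
    messages.sort_by(|a, b| b.len.cmp(&a.len));
    let budget = if total_tokens < min_total_tokens {
        // Below the provider's minimum, caching isn't worthwhile.
        0
    } else {
        // Keep one anchor slot in reserve for the inline assistant.
        max_cache_anchors.max(1) - 1
    };
    messages.truncate(budget);
    messages.into_iter().map(|m| m.id).collect()
}
```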
@@ -994,66 +1022,104 @@
min_total_token: 0,
});
let messages: Vec<Message> = self
.messages_from_anchors(
self.message_anchors.iter().take(if speculative {
self.message_anchors.len().saturating_sub(1)
} else {
self.message_anchors.len()
}),
cx,
)
.filter(|message| message.offset_range.len() >= 5_000)
.collect();
let messages: Vec<Message> = self.messages(cx).collect();
let mut sorted_messages = messages.clone();
sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
// Some models support caching, but don't support anchors. In that case we want to
// mark the largest message as needing to be cached, but we will not mark it as an
// anchor.
sorted_messages.truncate(1);
} else {
// Save 1 anchor for the inline assistant.
sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
if speculative {
// Avoid caching the last message if this is a speculative cache fetch as
// it's likely to change.
sorted_messages.pop();
}
sorted_messages.retain(|m| m.role == Role::User);
sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
let longest_message_ids: HashSet<MessageId> = sorted_messages
let cache_anchors = if self.token_count.unwrap_or(0) < cache_configuration.min_total_token {
// If we haven't hit the minimum threshold to enable caching, don't cache anything.
0
} else {
// Save 1 anchor for the inline assistant to use.
max(cache_configuration.max_cache_anchors, 1) - 1
};
sorted_messages.truncate(cache_anchors);
let anchors: HashSet<MessageId> = sorted_messages
.into_iter()
.map(|message| message.id)
.collect();
let cache_deltas: HashSet<MessageId> = self
.messages_metadata
let buffer = self.buffer.read(cx).snapshot();
let invalidated_caches: HashSet<MessageId> = messages
.iter()
.filter_map(|(id, metadata)| {
let should_cache = longest_message_ids.contains(id);
let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
if metadata.should_cache != should_cache
|| metadata.is_cache_anchor != should_be_anchor
{
Some(*id)
} else {
None
}
.scan(false, |encountered_invalid, message| {
let message_id = message.id;
let is_invalid = self
.messages_metadata
.get(&message_id)
.map_or(true, |metadata| {
!metadata.is_cache_valid(&buffer, &message.offset_range)
|| *encountered_invalid
});
*encountered_invalid |= is_invalid;
Some(if is_invalid { Some(message_id) } else { None })
})
.flatten()
.collect();
let mut newly_cached_item = false;
for id in cache_deltas {
newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
self.update_metadata(id, cx, |metadata| {
metadata.should_cache = longest_message_ids.contains(&id);
metadata.is_cache_anchor =
metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
let last_anchor = messages.iter().rev().find_map(|message| {
if anchors.contains(&message.id) {
Some(message.id)
} else {
None
}
});
let mut new_anchor_needs_caching = false;
let current_version = &buffer.version;
// If we have no anchors, mark all messages as not being cached.
let mut hit_last_anchor = last_anchor.is_none();
for message in messages.iter() {
if hit_last_anchor {
self.update_metadata(message.id, cx, |metadata| metadata.cache = None);
continue;
}
if let Some(last_anchor) = last_anchor {
if message.id == last_anchor {
hit_last_anchor = true;
}
}
new_anchor_needs_caching = new_anchor_needs_caching
|| (invalidated_caches.contains(&message.id) && anchors.contains(&message.id));
self.update_metadata(message.id, cx, |metadata| {
let cache_status = if invalidated_caches.contains(&message.id) {
CacheStatus::Pending
} else {
metadata
.cache
.as_ref()
.map_or(CacheStatus::Pending, |cm| cm.status.clone())
};
metadata.cache = Some(MessageCacheMetadata {
is_anchor: anchors.contains(&message.id),
is_final_anchor: hit_last_anchor,
status: cache_status,
cached_at: current_version.clone(),
});
});
}
newly_cached_item
new_anchor_needs_caching
}
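
The `scan(false, …)` above encodes the cascading rule for prefix-based prompt caches: once any message's cached range has been edited, every later message must be re-cached as well, even if its own text is unchanged. A stripped-down version of that fold, with message ids and validity flags standing in for the real metadata lookup:

```rust
fn invalidated_ids(messages: &[(u64, bool)]) -> Vec<u64> {
    messages
        .iter()
        .scan(false, |hit_invalid, &(id, is_valid)| {
            // A message is invalidated if its own cache is stale, or if any
            // earlier message was already invalidated.
            let invalid = !is_valid || *hit_invalid;
            *hit_invalid |= invalid;
            Some(invalid.then_some(id))
        })
        .flatten()
        .collect()
}

fn main() {
    // Message 2 was edited; 3 is invalidated too, despite being unchanged.
    assert_eq!(invalidated_ids(&[(1, true), (2, false), (3, true)]), vec![2, 3]);
}
```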
fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
let cache_configuration = model.cache_configuration();
if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
if !self.mark_cache_anchors(&cache_configuration, true, cx) {
return;
}
if !self.pending_completions.is_empty() {
return;
}
if let Some(cache_configuration) = cache_configuration {
@@ -1076,7 +1142,7 @@ impl Context {
};
let model = Arc::clone(model);
self.pending_cache_warming_task = cx.spawn(|_, cx| {
self.pending_cache_warming_task = cx.spawn(|this, mut cx| {
async move {
match model.stream_completion(request, &cx).await {
Ok(mut stream) => {
@@ -1087,13 +1153,41 @@
log::warn!("Cache warming failed: {}", e);
}
};
this.update(&mut cx, |this, cx| {
this.update_cache_status_for_completion(cx);
})
.ok();
anyhow::Ok(())
}
.log_err()
});
}
pub fn update_cache_status_for_completion(&mut self, cx: &mut ModelContext<Self>) {
let cached_message_ids: Vec<MessageId> = self
.messages_metadata
.iter()
.filter_map(|(message_id, metadata)| {
metadata.cache.as_ref().and_then(|cache| {
if cache.status == CacheStatus::Pending {
Some(*message_id)
} else {
None
}
})
})
.collect();
for message_id in cached_message_ids {
self.update_metadata(message_id, cx, |metadata| {
if let Some(cache) = &mut metadata.cache {
cache.status = CacheStatus::Cached;
}
});
}
cx.notify();
}
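
Taken together with `mark_cache_anchors`, this gives each message's cache entry a small two-state lifecycle: it is (re)marked `Pending` whenever it is selected or invalidated, and flips to `Cached` once a completion or cache-warming request has round-tripped to the provider. A minimal model of that transition (a sketch, not the real update path through `update_metadata`):

```rust
#[derive(Clone, Debug, PartialEq)]
enum CacheStatus {
    Pending,
    Cached,
}

// Analogue of `update_cache_status_for_completion`: everything that was
// in flight when the request finished is now cached upstream.
fn on_completion_finished(statuses: &mut [CacheStatus]) {
    for status in statuses.iter_mut() {
        if *status == CacheStatus::Pending {
            *status = CacheStatus::Cached;
        }
    }
}
```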
pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
let buffer = self.buffer.read(cx);
let mut row_ranges = self
@@ -1531,7 +1625,7 @@ impl Context {
return None;
}
// Compute which messages to cache, including the last one.
self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
let request = self.to_completion_request(cx);
let assistant_message = self
@@ -1596,6 +1690,7 @@ impl Context {
this.pending_completions
.retain(|completion| completion.id != pending_completion_id);
this.summarize(false, cx);
this.update_cache_status_for_completion(cx);
})?;
anyhow::Ok(())
@@ -1746,8 +1841,7 @@ impl Context {
role,
status,
timestamp: anchor.id.0,
should_cache: false,
is_cache_anchor: false,
cache: None,
};
self.insert_message(anchor.clone(), metadata.clone(), cx);
self.push_op(
@@ -1864,8 +1958,7 @@ impl Context {
role,
status: MessageStatus::Done,
timestamp: suffix.id.0,
should_cache: false,
is_cache_anchor: false,
cache: None,
};
self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
self.push_op(
@@ -1915,8 +2008,7 @@ impl Context {
role,
status: MessageStatus::Done,
timestamp: selection.id.0,
should_cache: false,
is_cache_anchor: false,
cache: None,
};
self.insert_message(selection.clone(), selection_metadata.clone(), cx);
self.push_op(
@@ -2150,7 +2242,7 @@ impl Context {
anchor: message_anchor.start,
role: metadata.role,
status: metadata.status.clone(),
cache: metadata.is_cache_anchor,
cache: metadata.cache.clone(),
image_offsets,
});
}
@@ -2397,8 +2489,7 @@ impl SavedContext {
role: message.metadata.role,
status: message.metadata.status,
timestamp: message.metadata.timestamp,
should_cache: false,
is_cache_anchor: false,
cache: None,
},
version: version.clone(),
});
@@ -2415,8 +2506,7 @@ impl SavedContext {
role: metadata.role,
status: metadata.status,
timestamp,
should_cache: false,
is_cache_anchor: false,
cache: None,
},
version: version.clone(),
});
@@ -2511,8 +2601,7 @@ impl SavedContextV0_3_0 {
role: metadata.role,
status: metadata.status.clone(),
timestamp,
should_cache: false,
is_cache_anchor: false,
cache: None,
},
image_offsets: Vec::new(),
})