Implement Anthropic prompt caching (#16274)

Release Notes:

- Adds support for prompt caching with Anthropic models. For models that support it, this can dramatically lower cost while improving performance.
Roy Williams 2024-08-15 23:21:06 -04:00 committed by GitHub
parent 09b6e3f2a6
commit 46fb917e02
11 changed files with 338 additions and 70 deletions
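
Background for the diff below: Anthropic's prompt caching works on request prefixes. A `cache_control` marker on a content block asks the API to cache everything up to and including that block, so later requests sharing that prefix can reuse it. A minimal sketch of the request shape, assembled from Anthropic's public prompt-caching beta documentation of the period rather than from this commit (the model string and header are illustrative):

```rust
// Sketch of the wire format this feature targets; not code from this commit.
// At the time, the beta also required the HTTP header
// `anthropic-beta: prompt-caching-2024-07-31`.
use serde_json::{json, Value};

fn cached_request_body() -> Value {
    json!({
        "model": "claude-3-5-sonnet-20240620",
        "max_tokens": 1024,
        "messages": [
            {
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": "<large, stable conversation prefix>",
                    // Cache breakpoint: the prefix ending here is written to
                    // (or read from) the cache on subsequent requests.
                    "cache_control": { "type": "ephemeral" }
                }]
            },
            {
                "role": "user",
                "content": [{ "type": "text", "text": "<the part that changes>" }]
            }
        ]
    })
}

fn main() {
    println!("{}", cached_request_body());
}
```

Cache writes bill at a premium over plain input tokens, cache hits at a steep discount, and entries expire after a few minutes of disuse. That pricing shape motivates both decisions in this diff: mark only large, stable messages, and proactively warm the cache before the user asks for a completion.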


@@ -21,8 +21,8 @@ use gpui::{
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
-    LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    Role,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
+    LanguageModelRequest, LanguageModelRequestMessage, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};
@@ -30,7 +30,7 @@ use project::Project;
 use serde::{Deserialize, Serialize};
 use smallvec::SmallVec;
 use std::{
-    cmp::Ordering,
+    cmp::{max, Ordering},
     collections::hash_map,
     fmt::Debug,
     iter, mem,
@@ -107,6 +107,8 @@ impl ContextOperation {
                         message.status.context("invalid status")?,
                     ),
                     timestamp: id.0,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: language::proto::deserialize_version(&insert.version),
             })
@@ -121,6 +123,8 @@
                     timestamp: language::proto::deserialize_timestamp(
                         update.timestamp.context("invalid timestamp")?,
                     ),
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: language::proto::deserialize_version(&update.version),
             }),
@@ -313,6 +317,8 @@ pub struct MessageMetadata {
     pub role: Role,
     pub status: MessageStatus,
     timestamp: clock::Lamport,
+    should_cache: bool,
+    is_cache_anchor: bool,
 }
 
 #[derive(Clone, Debug)]
@@ -338,6 +344,7 @@ pub struct Message {
     pub anchor: language::Anchor,
     pub role: Role,
     pub status: MessageStatus,
+    pub cache: bool,
 }
 
 impl Message {
@@ -373,6 +380,7 @@ impl Message {
         LanguageModelRequestMessage {
             role: self.role,
             content,
+            cache: self.cache,
         }
     }
 }
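
`to_request_message` now forwards each message's `cache` flag onto the `LanguageModelRequestMessage` it builds. The provider-side half of the feature lives in the other ten changed files, not in this one; the sketch below is a hypothetical rendering of that mapping, with a stand-in `RequestMessage` type:

```rust
use serde_json::{json, Value};

// Stand-in for LanguageModelRequestMessage; only the fields the sketch needs.
struct RequestMessage {
    role: &'static str,
    text: String,
    cache: bool, // mirrors the flag added in this diff
}

// Hypothetical serializer: a message flagged `cache: true` gets a
// cache-control marker on its final content block, turning that message
// into a cache breakpoint for everything before it.
fn to_anthropic_message(msg: &RequestMessage) -> Value {
    let mut block = json!({ "type": "text", "text": msg.text });
    if msg.cache {
        block["cache_control"] = json!({ "type": "ephemeral" });
    }
    json!({ "role": msg.role, "content": [block] })
}

fn main() {
    let msg = RequestMessage {
        role: "user",
        text: "big stable prefix".into(),
        cache: true,
    };
    println!("{}", to_anthropic_message(&msg));
}
```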
@@ -421,6 +429,7 @@ pub struct Context {
     token_count: Option<usize>,
     pending_token_count: Task<Option<()>>,
     pending_save: Task<Result<()>>,
+    pending_cache_warming_task: Task<Option<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
     telemetry: Option<Arc<Telemetry>>,
@@ -498,6 +507,7 @@ impl Context {
             pending_completions: Default::default(),
             token_count: None,
             pending_token_count: Task::ready(None),
+            pending_cache_warming_task: Task::ready(None),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -524,6 +534,8 @@ impl Context {
                 role: Role::User,
                 status: MessageStatus::Done,
                 timestamp: first_message_id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             },
         );
         this.message_anchors.push(message);
@@ -948,6 +960,7 @@ impl Context {
             let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
             this.update(&mut cx, |this, cx| {
                 this.token_count = Some(token_count);
+                this.start_cache_warming(&model, cx);
                 cx.notify()
             })
         }
@@ -955,6 +968,121 @@
         });
     }
 
+    pub fn mark_longest_messages_for_cache(
+        &mut self,
+        cache_configuration: &Option<LanguageModelCacheConfiguration>,
+        speculative: bool,
+        cx: &mut ModelContext<Self>,
+    ) -> bool {
+        let cache_configuration =
+            cache_configuration
+                .as_ref()
+                .unwrap_or(&LanguageModelCacheConfiguration {
+                    max_cache_anchors: 0,
+                    should_speculate: false,
+                    min_total_token: 0,
+                });
+
+        let messages: Vec<Message> = self
+            .messages_from_anchors(
+                self.message_anchors.iter().take(if speculative {
+                    self.message_anchors.len().saturating_sub(1)
+                } else {
+                    self.message_anchors.len()
+                }),
+                cx,
+            )
+            .filter(|message| message.offset_range.len() >= 5_000)
+            .collect();
+
+        let mut sorted_messages = messages.clone();
+        sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
+
+        if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
+            // Some models support caching, but don't support anchors. In that case we want to
+            // mark the largest message as needing to be cached, but we will not mark it as an
+            // anchor.
+            sorted_messages.truncate(1);
+        } else {
+            // Save 1 anchor for the inline assistant.
+            sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
+        }
+
+        let longest_message_ids: HashSet<MessageId> = sorted_messages
+            .into_iter()
+            .map(|message| message.id)
+            .collect();
+
+        let cache_deltas: HashSet<MessageId> = self
+            .messages_metadata
+            .iter()
+            .filter_map(|(id, metadata)| {
+                let should_cache = longest_message_ids.contains(id);
+                let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
+                if metadata.should_cache != should_cache
+                    || metadata.is_cache_anchor != should_be_anchor
+                {
+                    Some(*id)
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        let mut newly_cached_item = false;
+        for id in cache_deltas {
+            newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
+            self.update_metadata(id, cx, |metadata| {
+                metadata.should_cache = longest_message_ids.contains(&id);
+                metadata.is_cache_anchor =
+                    metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
+            });
+        }
+        newly_cached_item
+    }
+
+    fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
+        let cache_configuration = model.cache_configuration();
+
+        if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
+            return;
+        }
+        if let Some(cache_configuration) = cache_configuration {
+            if !cache_configuration.should_speculate {
+                return;
+            }
+        }
+
+        let request = {
+            let mut req = self.to_completion_request(cx);
+            // Skip the last message because it's likely to change and
+            // therefore would be a waste to cache.
+            req.messages.pop();
+            req.messages.push(LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Respond only with OK, nothing else.".into()],
+                cache: false,
+            });
+            req
+        };
+
+        let model = Arc::clone(model);
+        self.pending_cache_warming_task = cx.spawn(|_, cx| {
+            async move {
+                match model.stream_completion(request, &cx).await {
+                    Ok(mut stream) => {
+                        stream.next().await;
+                        log::info!("Cache warming completed successfully");
+                    }
+                    Err(e) => {
+                        log::warn!("Cache warming failed: {}", e);
+                    }
+                };
+                anyhow::Ok(())
+            }
+            .log_err()
+        });
+    }
+
     pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
         let buffer = self.buffer.read(cx);
         let mut row_ranges = self
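
In plain terms, the selection policy above reads: consider only messages of at least 5,000 buffer characters, sort them longest first, and mark at most `max_cache_anchors - 1` of them (one anchor stays reserved for the inline assistant); a model that caches without supporting anchors gets only its single longest message marked. The return value reports whether any message newly became cached, which `start_cache_warming` uses to decide whether a warming request is worth sending. A standalone sketch of that policy, with the editor types swapped for plain tuples:

```rust
// Hypothetical helper mirroring mark_longest_messages_for_cache,
// minus the editor plumbing.
use std::cmp::max;
use std::collections::HashSet;

fn pick_cache_message_ids(
    messages: &[(u64, usize)], // (message id, length in buffer chars)
    max_cache_anchors: usize,
    should_speculate: bool,
) -> HashSet<u64> {
    let mut candidates: Vec<(u64, usize)> = messages
        .iter()
        .copied()
        .filter(|&(_, len)| len >= 5_000) // same flat threshold as the diff
        .collect();
    candidates.sort_by(|a, b| b.1.cmp(&a.1)); // longest first

    if max_cache_anchors == 0 && should_speculate {
        // Model caches but has no anchor concept: mark only the longest message.
        candidates.truncate(1);
    } else {
        // Reserve one anchor for the inline assistant.
        candidates.truncate(max(max_cache_anchors, 1) - 1);
    }
    candidates.into_iter().map(|(id, _)| id).collect()
}

fn main() {
    // Worked example: with 4 anchors allowed, 3 slots remain after the
    // reservation, and the 3_000-char message falls under the threshold.
    let picked = pick_cache_message_ids(
        &[(1, 12_000), (2, 3_000), (3, 8_000), (4, 6_000)],
        4,
        true,
    );
    assert_eq!(picked, HashSet::from([1, 3, 4]));
}
```

The warming request itself is deliberately cheap: it drops the still-changing trailing message and appends a one-word prompt, so the response costs almost nothing while still causing the provider to write the cache entries before the user presses assist.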
@@ -1352,20 +1480,26 @@ impl Context {
         self.count_remaining_tokens(cx);
     }
 
-    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
-        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
-        let model = LanguageModelRegistry::read_global(cx).active_model()?;
-        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
+    fn get_last_valid_message_id(&self, cx: &ModelContext<Self>) -> Option<MessageId> {
+        self.message_anchors.iter().rev().find_map(|message| {
             message
                 .start
                 .is_valid(self.buffer.read(cx))
                 .then_some(message.id)
-        })?;
+        })
+    }
+
+    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
+        let last_message_id = self.get_last_valid_message_id(cx)?;
 
         if !provider.is_authenticated(cx) {
             log::info!("completion provider has no credentials");
             return None;
         }
 
+        // Compute which messages to cache, including the last one.
+        self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
+
         let request = self.to_completion_request(cx);
 
         let assistant_message = self
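
The two call sites of `mark_longest_messages_for_cache` now differ only in the `speculative` flag: the background pass after each token recount passes `true`, excluding the trailing message the user is likely still editing, while `assist` passes `false` so the finalized last message can be included in the cached prefix just before the real completion request is built.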
@@ -1580,6 +1714,8 @@ impl Context {
             role,
             status,
             timestamp: anchor.id.0,
+            should_cache: false,
+            is_cache_anchor: false,
         };
         self.insert_message(anchor.clone(), metadata.clone(), cx);
         self.push_op(
@@ -1696,6 +1832,8 @@ impl Context {
                 role,
                 status: MessageStatus::Done,
                 timestamp: suffix.id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             };
             self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
             self.push_op(
@@ -1745,6 +1883,8 @@ impl Context {
                 role,
                 status: MessageStatus::Done,
                 timestamp: selection.id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             };
             self.insert_message(selection.clone(), selection_metadata.clone(), cx);
             self.push_op(
@@ -1811,6 +1951,7 @@ impl Context {
                 content: vec![
                     "Summarize the context into a short title without punctuation.".into(),
                 ],
+                cache: false,
             }));
 
         let request = LanguageModelRequest {
             messages: messages.collect(),
@@ -1910,14 +2051,22 @@ impl Context {
         result
     }
 
-    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+    fn messages_from_anchors<'a>(
+        &'a self,
+        message_anchors: impl Iterator<Item = &'a MessageAnchor> + 'a,
+        cx: &'a AppContext,
+    ) -> impl 'a + Iterator<Item = Message> {
         let buffer = self.buffer.read(cx);
-        let messages = self.message_anchors.iter().enumerate();
+        let messages = message_anchors.enumerate();
         let images = self.image_anchors.iter();
 
         Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
     }
 
+    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+        self.messages_from_anchors(self.message_anchors.iter(), cx)
+    }
+
     pub fn messages_from_iters<'a>(
         buffer: &'a Buffer,
         metadata: &'a HashMap<MessageId, MessageMetadata>,
@@ -1969,6 +2118,7 @@ impl Context {
                     anchor: message_anchor.start,
                     role: metadata.role,
                     status: metadata.status.clone(),
+                    cache: metadata.is_cache_anchor,
                     image_offsets,
                 });
             }
@@ -2215,6 +2365,8 @@ impl SavedContext {
                     role: message.metadata.role,
                     status: message.metadata.status,
                     timestamp: message.metadata.timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: version.clone(),
             });
@@ -2231,6 +2383,8 @@
                     role: metadata.role,
                     status: metadata.status,
                     timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: version.clone(),
             });
@@ -2325,6 +2479,8 @@ impl SavedContextV0_3_0 {
                     role: metadata.role,
                     status: metadata.status.clone(),
                     timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 image_offsets: Vec::new(),
             })
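
Note that every deserialization path above, the proto operations as well as the current and v0.3.0 saved-context formats, resets `should_cache` and `is_cache_anchor` to `false`: cache state is transient and local, recomputed after the next token count rather than persisted to disk or replicated to collaborators.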