bedrock: Add prompt caching support (#33194)

Closes https://github.com/zed-industries/zed/issues/33221

Bedrock's prompt caching API is similar to Anthropic's: to cache messages up to a certain
point, we add a special cache point block to that message.

Additionally, we can cache the tool definitions by adding a cache point block
after the tool spec (see the sketch below).

See: [Bedrock User Guide: Prompt
Caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models)
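For illustration, the core of the change looks roughly like the following Rust sketch: a `CachePointType::Default` cache point is appended to a message's content block list, and another one after the tool specs, whenever the model supports caching. This is a minimal sketch based on the diff below; the helper function names are hypothetical, and the exact re-export paths (e.g. where `BedrockTool` comes from) are assumptions.

```rust
use anyhow::{Context, Result};
// Re-exports used by the Zed bedrock provider; exact paths are assumptions.
use bedrock::bedrock_client::types::{CachePointBlock, CachePointType};
use bedrock::{BedrockInnerContent, BedrockTool};

/// Hypothetical helper: mark everything up to and including this message's
/// content as cacheable by appending a cache point block.
fn append_message_cache_point(content: &mut Vec<BedrockInnerContent>) -> Result<()> {
    content.push(BedrockInnerContent::CachePoint(
        CachePointBlock::builder()
            .r#type(CachePointType::Default)
            .build()
            .context("failed to build cache point block")?,
    ));
    Ok(())
}

/// Hypothetical helper: cache the tool definitions by appending a cache point
/// after the last tool spec.
fn append_tool_cache_point(tool_spec: &mut Vec<BedrockTool>) -> Result<()> {
    if !tool_spec.is_empty() {
        tool_spec.push(BedrockTool::CachePoint(
            CachePointBlock::builder()
                .r#type(CachePointType::Default)
                .build()
                .context("failed to build cache point block")?,
        ));
    }
    Ok(())
}
```

Bedrock then reports `cache_read_input_tokens` / `cache_write_input_tokens` in the converse-stream usage metadata, which the provider maps onto Zed's `TokenUsage` (see the last hunk below).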

Release Notes:

- bedrock: Added prompt caching support

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Vladimir Kuznichenkov 2025-06-25 17:15:13 +03:00 committed by GitHub
parent 59aeede50d
commit 0905255fd1
2 changed files with 101 additions and 10 deletions

@@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
 use bedrock::bedrock_client::Client as BedrockClient;
 use bedrock::bedrock_client::config::timeout::TimeoutConfig;
 use bedrock::bedrock_client::types::{
-    ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
-    StopReason,
+    CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
+    ReasoningContentBlockDelta, StopReason,
 };
 use bedrock::{
     BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
@@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
 use theme::ThemeSettings;
 use tokio::runtime::Handle;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, default};
+use util::ResultExt;
 
 use crate::AllLanguageModelSettings;
 
@@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
                 max_tokens: model.max_tokens,
                 max_output_tokens: model.max_output_tokens,
                 default_temperature: model.default_temperature,
+                cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                    bedrock::BedrockModelCacheConfiguration {
+                        max_cache_anchors: config.max_cache_anchors,
+                        min_total_token: config.min_total_token,
+                    }
+                }),
             },
         );
     }
@@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel {
             self.model.default_temperature(),
             self.model.max_output_tokens(),
             self.model.mode(),
+            self.model.supports_caching(),
         ) {
             Ok(request) => request,
             Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@@ -581,7 +588,13 @@
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
-        None
+        self.model
+            .cache_configuration()
+            .map(|config| LanguageModelCacheConfiguration {
+                max_cache_anchors: config.max_cache_anchors,
+                should_speculate: false,
+                min_total_token: config.min_total_token,
+            })
     }
 }
@@ -608,6 +621,7 @@ pub fn into_bedrock(
     default_temperature: f32,
     max_output_tokens: u64,
     mode: BedrockModelMode,
+    supports_caching: bool,
 ) -> Result<bedrock::Request> {
     let mut new_messages: Vec<BedrockMessage> = Vec::new();
     let mut system_message = String::new();
@@ -619,7 +633,7 @@
         match message.role {
             Role::User | Role::Assistant => {
-                let bedrock_message_content: Vec<BedrockInnerContent> = message
+                let mut bedrock_message_content: Vec<BedrockInnerContent> = message
                     .content
                     .into_iter()
                     .filter_map(|content| match content {
@@ -703,6 +717,14 @@
                         _ => None,
                     })
                     .collect();
+                if message.cache && supports_caching {
+                    bedrock_message_content.push(BedrockInnerContent::CachePoint(
+                        CachePointBlock::builder()
+                            .r#type(CachePointType::Default)
+                            .build()
+                            .context("failed to build cache point block")?,
+                    ));
+                }
                 let bedrock_role = match message.role {
                     Role::User => bedrock::BedrockRole::User,
                     Role::Assistant => bedrock::BedrockRole::Assistant,
@@ -731,7 +753,7 @@
         }
     }
 
-    let tool_spec: Vec<BedrockTool> = request
+    let mut tool_spec: Vec<BedrockTool> = request
         .tools
         .iter()
         .filter_map(|tool| {
@@ -748,6 +770,15 @@
         })
         .collect();
 
+    if !tool_spec.is_empty() && supports_caching {
+        tool_spec.push(BedrockTool::CachePoint(
+            CachePointBlock::builder()
+                .r#type(CachePointType::Default)
+                .build()
+                .context("failed to build cache point block")?,
+        ));
+    }
+
     let tool_choice = match request.tool_choice {
         Some(LanguageModelToolChoice::Auto) | None => {
             BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
@@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events(
                                 LanguageModelCompletionEvent::UsageUpdate(
                                     TokenUsage {
                                         input_tokens: metadata.input_tokens as u64,
-                                        output_tokens: metadata.output_tokens
-                                            as u64,
-                                        cache_creation_input_tokens: default(),
-                                        cache_read_input_tokens: default(),
+                                        output_tokens: metadata.output_tokens as u64,
+                                        cache_creation_input_tokens:
+                                            metadata.cache_write_input_tokens.unwrap_or_default() as u64,
+                                        cache_read_input_tokens:
+                                            metadata.cache_read_input_tokens.unwrap_or_default() as u64,
                                     },
                                 );
                                 return Some((Some(Ok(completion_event)), state));