bedrock: Add prompt caching support (#33194)

Closes https://github.com/zed-industries/zed/issues/33221

Bedrock has a prompt-caching API similar to Anthropic's: to cache
messages up to a certain point, we add a special cache-point block to that
message.

Additionally, we can cache the tool definitions by adding a cache-point block
after the tool spec.

See: [Bedrock User Guide: Prompt
Caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models)

Release Notes:

- bedrock: Added prompt caching support

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
This commit is contained in:
Vladimir Kuznichenkov 2025-06-25 17:15:13 +03:00 committed by GitHub
parent 59aeede50d
commit 0905255fd1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 101 additions and 10 deletions

View file

@ -11,6 +11,13 @@ pub enum BedrockModelMode {
}, },
} }
/// Prompt-caching configuration for a Bedrock model.
///
/// See the Bedrock User Guide on prompt caching for per-model limits:
/// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct BedrockModelCacheConfiguration {
    /// Maximum number of cache-point ("anchor") blocks a request may contain.
    pub max_cache_anchors: usize,
    /// Minimum total prompt tokens before caching takes effect
    /// (presumably per the Bedrock per-model minimums — confirm against the guide).
    pub min_total_token: u64,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model { pub enum Model {
@ -104,6 +111,7 @@ pub enum Model {
display_name: Option<String>, display_name: Option<String>,
max_output_tokens: Option<u64>, max_output_tokens: Option<u64>,
default_temperature: Option<f32>, default_temperature: Option<f32>,
cache_configuration: Option<BedrockModelCacheConfiguration>,
}, },
} }
@ -401,6 +409,56 @@ impl Model {
} }
} }
pub fn supports_caching(&self) -> bool {
match self {
// Only Claude models on Bedrock support caching
// Nova models support only text caching
// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
Self::Claude3_5Haiku
| Self::Claude3_7Sonnet
| Self::Claude3_7SonnetThinking
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4Thinking
| Self::ClaudeOpus4
| Self::ClaudeOpus4Thinking => true,
// Custom models - check if they have cache configuration
Self::Custom {
cache_configuration,
..
} => cache_configuration.is_some(),
// All other models don't support caching
_ => false,
}
}
/// Per-model prompt-caching limits, or `None` for models without caching.
///
/// Values follow the Bedrock User Guide:
/// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
    // Custom models carry their own (optional) configuration verbatim.
    if let Self::Custom {
        cache_configuration,
        ..
    } = self
    {
        return cache_configuration.clone();
    }

    // The Claude models all allow 4 cache anchors; they differ only in the
    // minimum prompt size required before a cache point takes effect.
    let min_total_token = match self {
        Self::Claude3_7Sonnet
        | Self::Claude3_7SonnetThinking
        | Self::ClaudeSonnet4
        | Self::ClaudeSonnet4Thinking
        | Self::ClaudeOpus4
        | Self::ClaudeOpus4Thinking => 1024,
        Self::Claude3_5Haiku => 2048,
        // Every other model: caching unsupported.
        _ => return None,
    };

    Some(BedrockModelCacheConfiguration {
        max_cache_anchors: 4,
        min_total_token,
    })
}
pub fn mode(&self) -> BedrockModelMode { pub fn mode(&self) -> BedrockModelMode {
match self { match self {
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking { Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
@ -660,6 +718,7 @@ mod tests {
display_name: Some("My Custom Model".to_string()), display_name: Some("My Custom Model".to_string()),
max_output_tokens: Some(8192), max_output_tokens: Some(8192),
default_temperature: Some(0.7), default_temperature: Some(0.7),
cache_configuration: None,
}; };
// Custom model should return its name unchanged // Custom model should return its name unchanged

View file

@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
use bedrock::bedrock_client::Client as BedrockClient; use bedrock::bedrock_client::Client as BedrockClient;
use bedrock::bedrock_client::config::timeout::TimeoutConfig; use bedrock::bedrock_client::config::timeout::TimeoutConfig;
use bedrock::bedrock_client::types::{ use bedrock::bedrock_client::types::{
ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta, CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
StopReason, ReasoningContentBlockDelta, StopReason,
}; };
use bedrock::{ use bedrock::{
BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent, BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use theme::ThemeSettings; use theme::ThemeSettings;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use ui::{Icon, IconName, List, Tooltip, prelude::*}; use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, default}; use util::ResultExt;
use crate::AllLanguageModelSettings; use crate::AllLanguageModelSettings;
@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens, max_output_tokens: model.max_output_tokens,
default_temperature: model.default_temperature, default_temperature: model.default_temperature,
cache_configuration: model.cache_configuration.as_ref().map(|config| {
bedrock::BedrockModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
min_total_token: config.min_total_token,
}
}),
}, },
); );
} }
@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel {
self.model.default_temperature(), self.model.default_temperature(),
self.model.max_output_tokens(), self.model.max_output_tokens(),
self.model.mode(), self.model.mode(),
self.model.supports_caching(),
) { ) {
Ok(request) => request, Ok(request) => request,
Err(err) => return futures::future::ready(Err(err.into())).boxed(), Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@ -581,7 +588,13 @@ impl LanguageModel for BedrockModel {
} }
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> { fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
None self.model
.cache_configuration()
.map(|config| LanguageModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
should_speculate: false,
min_total_token: config.min_total_token,
})
} }
} }
@ -608,6 +621,7 @@ pub fn into_bedrock(
default_temperature: f32, default_temperature: f32,
max_output_tokens: u64, max_output_tokens: u64,
mode: BedrockModelMode, mode: BedrockModelMode,
supports_caching: bool,
) -> Result<bedrock::Request> { ) -> Result<bedrock::Request> {
let mut new_messages: Vec<BedrockMessage> = Vec::new(); let mut new_messages: Vec<BedrockMessage> = Vec::new();
let mut system_message = String::new(); let mut system_message = String::new();
@ -619,7 +633,7 @@ pub fn into_bedrock(
match message.role { match message.role {
Role::User | Role::Assistant => { Role::User | Role::Assistant => {
let bedrock_message_content: Vec<BedrockInnerContent> = message let mut bedrock_message_content: Vec<BedrockInnerContent> = message
.content .content
.into_iter() .into_iter()
.filter_map(|content| match content { .filter_map(|content| match content {
@ -703,6 +717,14 @@ pub fn into_bedrock(
_ => None, _ => None,
}) })
.collect(); .collect();
if message.cache && supports_caching {
bedrock_message_content.push(BedrockInnerContent::CachePoint(
CachePointBlock::builder()
.r#type(CachePointType::Default)
.build()
.context("failed to build cache point block")?,
));
}
let bedrock_role = match message.role { let bedrock_role = match message.role {
Role::User => bedrock::BedrockRole::User, Role::User => bedrock::BedrockRole::User,
Role::Assistant => bedrock::BedrockRole::Assistant, Role::Assistant => bedrock::BedrockRole::Assistant,
@ -731,7 +753,7 @@ pub fn into_bedrock(
} }
} }
let tool_spec: Vec<BedrockTool> = request let mut tool_spec: Vec<BedrockTool> = request
.tools .tools
.iter() .iter()
.filter_map(|tool| { .filter_map(|tool| {
@ -748,6 +770,15 @@ pub fn into_bedrock(
}) })
.collect(); .collect();
if !tool_spec.is_empty() && supports_caching {
tool_spec.push(BedrockTool::CachePoint(
CachePointBlock::builder()
.r#type(CachePointType::Default)
.build()
.context("failed to build cache point block")?,
));
}
let tool_choice = match request.tool_choice { let tool_choice = match request.tool_choice {
Some(LanguageModelToolChoice::Auto) | None => { Some(LanguageModelToolChoice::Auto) | None => {
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build()) BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events(
LanguageModelCompletionEvent::UsageUpdate( LanguageModelCompletionEvent::UsageUpdate(
TokenUsage { TokenUsage {
input_tokens: metadata.input_tokens as u64, input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens output_tokens: metadata.output_tokens as u64,
as u64, cache_creation_input_tokens:
cache_creation_input_tokens: default(), metadata.cache_write_input_tokens.unwrap_or_default() as u64,
cache_read_input_tokens: default(), cache_read_input_tokens:
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
}, },
); );
return Some((Some(Ok(completion_event)), state)); return Some((Some(Ok(completion_event)), state));