From 0905255fd14df57b4dd250ea31bdda368e62d6ce Mon Sep 17 00:00:00 2001 From: Vladimir Kuznichenkov <5330267+kuzaxak@users.noreply.github.com> Date: Wed, 25 Jun 2025 17:15:13 +0300 Subject: [PATCH] bedrock: Add prompt caching support (#33194) Closes https://github.com/zed-industries/zed/issues/33221 Bedrock has a prompt-caching API similar to Anthropic's: to cache messages up to a certain point, we add a special cache-point block to that message. Additionally, we can cache the tool definitions by adding a cache-point block after the tool spec. See: [Bedrock User Guide: Prompt Caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models) Release Notes: - bedrock: Added prompt caching support --------- Co-authored-by: Oleksiy Syvokon --- crates/bedrock/src/models.rs | 59 +++++++++++++++++++ .../language_models/src/provider/bedrock.rs | 52 ++++++++++++---- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index 272ac0e52c..b6eeafa2d6 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -11,6 +11,13 @@ pub enum BedrockModelMode { }, } +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct BedrockModelCacheConfiguration { + pub max_cache_anchors: usize, + pub min_total_token: u64, +} + #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] pub enum Model { @@ -104,6 +111,7 @@ pub enum Model { display_name: Option, max_output_tokens: Option, default_temperature: Option, + cache_configuration: Option, }, } @@ -401,6 +409,56 @@ impl Model { } } + pub fn supports_caching(&self) -> bool { + match self { + // Only Claude models on Bedrock support caching + // Nova models support only text caching + // 
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models + Self::Claude3_5Haiku + | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking + | Self::ClaudeSonnet4 + | Self::ClaudeSonnet4Thinking + | Self::ClaudeOpus4 + | Self::ClaudeOpus4Thinking => true, + + // Custom models - check if they have cache configuration + Self::Custom { + cache_configuration, + .. + } => cache_configuration.is_some(), + + // All other models don't support caching + _ => false, + } + } + + pub fn cache_configuration(&self) -> Option { + match self { + Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking + | Self::ClaudeSonnet4 + | Self::ClaudeSonnet4Thinking + | Self::ClaudeOpus4 + | Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration { + max_cache_anchors: 4, + min_total_token: 1024, + }), + + Self::Claude3_5Haiku => Some(BedrockModelCacheConfiguration { + max_cache_anchors: 4, + min_total_token: 2048, + }), + + Self::Custom { + cache_configuration, + .. 
+ } => cache_configuration.clone(), + + _ => None, + } + } + pub fn mode(&self) -> BedrockModelMode { match self { Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking { @@ -660,6 +718,7 @@ mod tests { display_name: Some("My Custom Model".to_string()), max_output_tokens: Some(8192), default_temperature: Some(0.7), + cache_configuration: None, }; // Custom model should return its name unchanged diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 2b2527f1ac..a55fc5bc11 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient; use bedrock::bedrock_client::Client as BedrockClient; use bedrock::bedrock_client::config::timeout::TimeoutConfig; use bedrock::bedrock_client::types::{ - ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta, - StopReason, + CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, + ReasoningContentBlockDelta, StopReason, }; use bedrock::{ BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent, @@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr}; use theme::ThemeSettings; use tokio::runtime::Handle; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::{ResultExt, default}; +use util::ResultExt; use crate::AllLanguageModelSettings; @@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { max_tokens: model.max_tokens, max_output_tokens: model.max_output_tokens, default_temperature: model.default_temperature, + cache_configuration: model.cache_configuration.as_ref().map(|config| { + bedrock::BedrockModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + min_total_token: config.min_total_token, + } + }), }, ); } @@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel { 
self.model.default_temperature(), self.model.max_output_tokens(), self.model.mode(), + self.model.supports_caching(), ) { Ok(request) => request, Err(err) => return futures::future::ready(Err(err.into())).boxed(), @@ -581,7 +588,13 @@ impl LanguageModel for BedrockModel { } fn cache_configuration(&self) -> Option { - None + self.model + .cache_configuration() + .map(|config| LanguageModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: false, + min_total_token: config.min_total_token, + }) } } @@ -608,6 +621,7 @@ pub fn into_bedrock( default_temperature: f32, max_output_tokens: u64, mode: BedrockModelMode, + supports_caching: bool, ) -> Result { let mut new_messages: Vec = Vec::new(); let mut system_message = String::new(); @@ -619,7 +633,7 @@ pub fn into_bedrock( match message.role { Role::User | Role::Assistant => { - let bedrock_message_content: Vec = message + let mut bedrock_message_content: Vec = message .content .into_iter() .filter_map(|content| match content { @@ -703,6 +717,14 @@ pub fn into_bedrock( _ => None, }) .collect(); + if message.cache && supports_caching { + bedrock_message_content.push(BedrockInnerContent::CachePoint( + CachePointBlock::builder() + .r#type(CachePointType::Default) + .build() + .context("failed to build cache point block")?, + )); + } let bedrock_role = match message.role { Role::User => bedrock::BedrockRole::User, Role::Assistant => bedrock::BedrockRole::Assistant, @@ -731,7 +753,7 @@ pub fn into_bedrock( } } - let tool_spec: Vec = request + let mut tool_spec: Vec = request .tools .iter() .filter_map(|tool| { @@ -748,6 +770,15 @@ pub fn into_bedrock( }) .collect(); + if !tool_spec.is_empty() && supports_caching { + tool_spec.push(BedrockTool::CachePoint( + CachePointBlock::builder() + .r#type(CachePointType::Default) + .build() + .context("failed to build cache point block")?, + )); + } + let tool_choice = match request.tool_choice { Some(LanguageModelToolChoice::Auto) | None => { 
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build()) @@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events( LanguageModelCompletionEvent::UsageUpdate( TokenUsage { input_tokens: metadata.input_tokens as u64, - output_tokens: metadata.output_tokens - as u64, - cache_creation_input_tokens: default(), - cache_read_input_tokens: default(), + output_tokens: metadata.output_tokens as u64, + cache_creation_input_tokens: + metadata.cache_write_input_tokens.unwrap_or_default() as u64, + cache_read_input_tokens: + metadata.cache_read_input_tokens.unwrap_or_default() as u64, }, ); return Some((Some(Ok(completion_event)), state));