bedrock: Add prompt caching support (#33194)
Closes https://github.com/zed-industries/zed/issues/33221 Bedrock has similar to anthropic caching api, if we want to cache messages up to a certain point, we should add a special block into that message. Additionally, we can cache tools definition by adding cache point block after tools spec. See: [Bedrock User Guide: Prompt Caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models) Release Notes: - bedrock: Added prompt caching support --------- Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
This commit is contained in:
parent
59aeede50d
commit
0905255fd1
2 changed files with 101 additions and 10 deletions
|
@ -11,6 +11,13 @@ pub enum BedrockModelMode {
|
|||
},
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct BedrockModelCacheConfiguration {
|
||||
pub max_cache_anchors: usize,
|
||||
pub min_total_token: u64,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
|
@ -104,6 +111,7 @@ pub enum Model {
|
|||
display_name: Option<String>,
|
||||
max_output_tokens: Option<u64>,
|
||||
default_temperature: Option<f32>,
|
||||
cache_configuration: Option<BedrockModelCacheConfiguration>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -401,6 +409,56 @@ impl Model {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn supports_caching(&self) -> bool {
|
||||
match self {
|
||||
// Only Claude models on Bedrock support caching
|
||||
// Nova models support only text caching
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
|
||||
Self::Claude3_5Haiku
|
||||
| Self::Claude3_7Sonnet
|
||||
| Self::Claude3_7SonnetThinking
|
||||
| Self::ClaudeSonnet4
|
||||
| Self::ClaudeSonnet4Thinking
|
||||
| Self::ClaudeOpus4
|
||||
| Self::ClaudeOpus4Thinking => true,
|
||||
|
||||
// Custom models - check if they have cache configuration
|
||||
Self::Custom {
|
||||
cache_configuration,
|
||||
..
|
||||
} => cache_configuration.is_some(),
|
||||
|
||||
// All other models don't support caching
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
|
||||
match self {
|
||||
Self::Claude3_7Sonnet
|
||||
| Self::Claude3_7SonnetThinking
|
||||
| Self::ClaudeSonnet4
|
||||
| Self::ClaudeSonnet4Thinking
|
||||
| Self::ClaudeOpus4
|
||||
| Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration {
|
||||
max_cache_anchors: 4,
|
||||
min_total_token: 1024,
|
||||
}),
|
||||
|
||||
Self::Claude3_5Haiku => Some(BedrockModelCacheConfiguration {
|
||||
max_cache_anchors: 4,
|
||||
min_total_token: 2048,
|
||||
}),
|
||||
|
||||
Self::Custom {
|
||||
cache_configuration,
|
||||
..
|
||||
} => cache_configuration.clone(),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> BedrockModelMode {
|
||||
match self {
|
||||
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
|
||||
|
@ -660,6 +718,7 @@ mod tests {
|
|||
display_name: Some("My Custom Model".to_string()),
|
||||
max_output_tokens: Some(8192),
|
||||
default_temperature: Some(0.7),
|
||||
cache_configuration: None,
|
||||
};
|
||||
|
||||
// Custom model should return its name unchanged
|
||||
|
|
|
@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
|
|||
use bedrock::bedrock_client::Client as BedrockClient;
|
||||
use bedrock::bedrock_client::config::timeout::TimeoutConfig;
|
||||
use bedrock::bedrock_client::types::{
|
||||
ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
|
||||
StopReason,
|
||||
CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
|
||||
ReasoningContentBlockDelta, StopReason,
|
||||
};
|
||||
use bedrock::{
|
||||
BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
|
||||
|
@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
|
|||
use theme::ThemeSettings;
|
||||
use tokio::runtime::Handle;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::{ResultExt, default};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
|
||||
|
@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
|
|||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
default_temperature: model.default_temperature,
|
||||
cache_configuration: model.cache_configuration.as_ref().map(|config| {
|
||||
bedrock::BedrockModelCacheConfiguration {
|
||||
max_cache_anchors: config.max_cache_anchors,
|
||||
min_total_token: config.min_total_token,
|
||||
}
|
||||
}),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel {
|
|||
self.model.default_temperature(),
|
||||
self.model.max_output_tokens(),
|
||||
self.model.mode(),
|
||||
self.model.supports_caching(),
|
||||
) {
|
||||
Ok(request) => request,
|
||||
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
||||
|
@ -581,7 +588,13 @@ impl LanguageModel for BedrockModel {
|
|||
}
|
||||
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
None
|
||||
self.model
|
||||
.cache_configuration()
|
||||
.map(|config| LanguageModelCacheConfiguration {
|
||||
max_cache_anchors: config.max_cache_anchors,
|
||||
should_speculate: false,
|
||||
min_total_token: config.min_total_token,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -608,6 +621,7 @@ pub fn into_bedrock(
|
|||
default_temperature: f32,
|
||||
max_output_tokens: u64,
|
||||
mode: BedrockModelMode,
|
||||
supports_caching: bool,
|
||||
) -> Result<bedrock::Request> {
|
||||
let mut new_messages: Vec<BedrockMessage> = Vec::new();
|
||||
let mut system_message = String::new();
|
||||
|
@ -619,7 +633,7 @@ pub fn into_bedrock(
|
|||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let bedrock_message_content: Vec<BedrockInnerContent> = message
|
||||
let mut bedrock_message_content: Vec<BedrockInnerContent> = message
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
|
@ -703,6 +717,14 @@ pub fn into_bedrock(
|
|||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
if message.cache && supports_caching {
|
||||
bedrock_message_content.push(BedrockInnerContent::CachePoint(
|
||||
CachePointBlock::builder()
|
||||
.r#type(CachePointType::Default)
|
||||
.build()
|
||||
.context("failed to build cache point block")?,
|
||||
));
|
||||
}
|
||||
let bedrock_role = match message.role {
|
||||
Role::User => bedrock::BedrockRole::User,
|
||||
Role::Assistant => bedrock::BedrockRole::Assistant,
|
||||
|
@ -731,7 +753,7 @@ pub fn into_bedrock(
|
|||
}
|
||||
}
|
||||
|
||||
let tool_spec: Vec<BedrockTool> = request
|
||||
let mut tool_spec: Vec<BedrockTool> = request
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|tool| {
|
||||
|
@ -748,6 +770,15 @@ pub fn into_bedrock(
|
|||
})
|
||||
.collect();
|
||||
|
||||
if !tool_spec.is_empty() && supports_caching {
|
||||
tool_spec.push(BedrockTool::CachePoint(
|
||||
CachePointBlock::builder()
|
||||
.r#type(CachePointType::Default)
|
||||
.build()
|
||||
.context("failed to build cache point block")?,
|
||||
));
|
||||
}
|
||||
|
||||
let tool_choice = match request.tool_choice {
|
||||
Some(LanguageModelToolChoice::Auto) | None => {
|
||||
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
|
||||
|
@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events(
|
|||
LanguageModelCompletionEvent::UsageUpdate(
|
||||
TokenUsage {
|
||||
input_tokens: metadata.input_tokens as u64,
|
||||
output_tokens: metadata.output_tokens
|
||||
as u64,
|
||||
cache_creation_input_tokens: default(),
|
||||
cache_read_input_tokens: default(),
|
||||
output_tokens: metadata.output_tokens as u64,
|
||||
cache_creation_input_tokens:
|
||||
metadata.cache_write_input_tokens.unwrap_or_default() as u64,
|
||||
cache_read_input_tokens:
|
||||
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
|
||||
},
|
||||
);
|
||||
return Some((Some(Ok(completion_event)), state));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue