bedrock: Add prompt caching support (#33194)
Closes https://github.com/zed-industries/zed/issues/33221 Bedrock has a caching API similar to Anthropic's: to cache messages up to a certain point, we add a special cache point block to that message. Additionally, we can cache tool definitions by adding a cache point block after the tool specs. See: [Bedrock User Guide: Prompt Caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models) Release Notes: - bedrock: Added prompt caching support --------- Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
This commit is contained in:
parent
59aeede50d
commit
0905255fd1
2 changed files with 101 additions and 10 deletions
|
@ -11,6 +11,13 @@ pub enum BedrockModelMode {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub struct BedrockModelCacheConfiguration {
|
||||||
|
pub max_cache_anchors: usize,
|
||||||
|
pub min_total_token: u64,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
|
@ -104,6 +111,7 @@ pub enum Model {
|
||||||
display_name: Option<String>,
|
display_name: Option<String>,
|
||||||
max_output_tokens: Option<u64>,
|
max_output_tokens: Option<u64>,
|
||||||
default_temperature: Option<f32>,
|
default_temperature: Option<f32>,
|
||||||
|
cache_configuration: Option<BedrockModelCacheConfiguration>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -401,6 +409,56 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn supports_caching(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
// Only Claude models on Bedrock support caching
|
||||||
|
// Nova models support only text caching
|
||||||
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
|
||||||
|
Self::Claude3_5Haiku
|
||||||
|
| Self::Claude3_7Sonnet
|
||||||
|
| Self::Claude3_7SonnetThinking
|
||||||
|
| Self::ClaudeSonnet4
|
||||||
|
| Self::ClaudeSonnet4Thinking
|
||||||
|
| Self::ClaudeOpus4
|
||||||
|
| Self::ClaudeOpus4Thinking => true,
|
||||||
|
|
||||||
|
// Custom models - check if they have cache configuration
|
||||||
|
Self::Custom {
|
||||||
|
cache_configuration,
|
||||||
|
..
|
||||||
|
} => cache_configuration.is_some(),
|
||||||
|
|
||||||
|
// All other models don't support caching
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
|
||||||
|
match self {
|
||||||
|
Self::Claude3_7Sonnet
|
||||||
|
| Self::Claude3_7SonnetThinking
|
||||||
|
| Self::ClaudeSonnet4
|
||||||
|
| Self::ClaudeSonnet4Thinking
|
||||||
|
| Self::ClaudeOpus4
|
||||||
|
| Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration {
|
||||||
|
max_cache_anchors: 4,
|
||||||
|
min_total_token: 1024,
|
||||||
|
}),
|
||||||
|
|
||||||
|
Self::Claude3_5Haiku => Some(BedrockModelCacheConfiguration {
|
||||||
|
max_cache_anchors: 4,
|
||||||
|
min_total_token: 2048,
|
||||||
|
}),
|
||||||
|
|
||||||
|
Self::Custom {
|
||||||
|
cache_configuration,
|
||||||
|
..
|
||||||
|
} => cache_configuration.clone(),
|
||||||
|
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn mode(&self) -> BedrockModelMode {
|
pub fn mode(&self) -> BedrockModelMode {
|
||||||
match self {
|
match self {
|
||||||
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
|
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
|
||||||
|
@ -660,6 +718,7 @@ mod tests {
|
||||||
display_name: Some("My Custom Model".to_string()),
|
display_name: Some("My Custom Model".to_string()),
|
||||||
max_output_tokens: Some(8192),
|
max_output_tokens: Some(8192),
|
||||||
default_temperature: Some(0.7),
|
default_temperature: Some(0.7),
|
||||||
|
cache_configuration: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Custom model should return its name unchanged
|
// Custom model should return its name unchanged
|
||||||
|
|
|
@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
|
||||||
use bedrock::bedrock_client::Client as BedrockClient;
|
use bedrock::bedrock_client::Client as BedrockClient;
|
||||||
use bedrock::bedrock_client::config::timeout::TimeoutConfig;
|
use bedrock::bedrock_client::config::timeout::TimeoutConfig;
|
||||||
use bedrock::bedrock_client::types::{
|
use bedrock::bedrock_client::types::{
|
||||||
ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
|
CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
|
||||||
StopReason,
|
ReasoningContentBlockDelta, StopReason,
|
||||||
};
|
};
|
||||||
use bedrock::{
|
use bedrock::{
|
||||||
BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
|
BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
|
||||||
|
@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use tokio::runtime::Handle;
|
use tokio::runtime::Handle;
|
||||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||||
use util::{ResultExt, default};
|
use util::ResultExt;
|
||||||
|
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
|
|
||||||
|
@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
max_output_tokens: model.max_output_tokens,
|
max_output_tokens: model.max_output_tokens,
|
||||||
default_temperature: model.default_temperature,
|
default_temperature: model.default_temperature,
|
||||||
|
cache_configuration: model.cache_configuration.as_ref().map(|config| {
|
||||||
|
bedrock::BedrockModelCacheConfiguration {
|
||||||
|
max_cache_anchors: config.max_cache_anchors,
|
||||||
|
min_total_token: config.min_total_token,
|
||||||
|
}
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel {
|
||||||
self.model.default_temperature(),
|
self.model.default_temperature(),
|
||||||
self.model.max_output_tokens(),
|
self.model.max_output_tokens(),
|
||||||
self.model.mode(),
|
self.model.mode(),
|
||||||
|
self.model.supports_caching(),
|
||||||
) {
|
) {
|
||||||
Ok(request) => request,
|
Ok(request) => request,
|
||||||
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
||||||
|
@ -581,7 +588,13 @@ impl LanguageModel for BedrockModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||||
None
|
self.model
|
||||||
|
.cache_configuration()
|
||||||
|
.map(|config| LanguageModelCacheConfiguration {
|
||||||
|
max_cache_anchors: config.max_cache_anchors,
|
||||||
|
should_speculate: false,
|
||||||
|
min_total_token: config.min_total_token,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -608,6 +621,7 @@ pub fn into_bedrock(
|
||||||
default_temperature: f32,
|
default_temperature: f32,
|
||||||
max_output_tokens: u64,
|
max_output_tokens: u64,
|
||||||
mode: BedrockModelMode,
|
mode: BedrockModelMode,
|
||||||
|
supports_caching: bool,
|
||||||
) -> Result<bedrock::Request> {
|
) -> Result<bedrock::Request> {
|
||||||
let mut new_messages: Vec<BedrockMessage> = Vec::new();
|
let mut new_messages: Vec<BedrockMessage> = Vec::new();
|
||||||
let mut system_message = String::new();
|
let mut system_message = String::new();
|
||||||
|
@ -619,7 +633,7 @@ pub fn into_bedrock(
|
||||||
|
|
||||||
match message.role {
|
match message.role {
|
||||||
Role::User | Role::Assistant => {
|
Role::User | Role::Assistant => {
|
||||||
let bedrock_message_content: Vec<BedrockInnerContent> = message
|
let mut bedrock_message_content: Vec<BedrockInnerContent> = message
|
||||||
.content
|
.content
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|content| match content {
|
.filter_map(|content| match content {
|
||||||
|
@ -703,6 +717,14 @@ pub fn into_bedrock(
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
if message.cache && supports_caching {
|
||||||
|
bedrock_message_content.push(BedrockInnerContent::CachePoint(
|
||||||
|
CachePointBlock::builder()
|
||||||
|
.r#type(CachePointType::Default)
|
||||||
|
.build()
|
||||||
|
.context("failed to build cache point block")?,
|
||||||
|
));
|
||||||
|
}
|
||||||
let bedrock_role = match message.role {
|
let bedrock_role = match message.role {
|
||||||
Role::User => bedrock::BedrockRole::User,
|
Role::User => bedrock::BedrockRole::User,
|
||||||
Role::Assistant => bedrock::BedrockRole::Assistant,
|
Role::Assistant => bedrock::BedrockRole::Assistant,
|
||||||
|
@ -731,7 +753,7 @@ pub fn into_bedrock(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let tool_spec: Vec<BedrockTool> = request
|
let mut tool_spec: Vec<BedrockTool> = request
|
||||||
.tools
|
.tools
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|tool| {
|
.filter_map(|tool| {
|
||||||
|
@ -748,6 +770,15 @@ pub fn into_bedrock(
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
if !tool_spec.is_empty() && supports_caching {
|
||||||
|
tool_spec.push(BedrockTool::CachePoint(
|
||||||
|
CachePointBlock::builder()
|
||||||
|
.r#type(CachePointType::Default)
|
||||||
|
.build()
|
||||||
|
.context("failed to build cache point block")?,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let tool_choice = match request.tool_choice {
|
let tool_choice = match request.tool_choice {
|
||||||
Some(LanguageModelToolChoice::Auto) | None => {
|
Some(LanguageModelToolChoice::Auto) | None => {
|
||||||
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
|
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
|
||||||
|
@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events(
|
||||||
LanguageModelCompletionEvent::UsageUpdate(
|
LanguageModelCompletionEvent::UsageUpdate(
|
||||||
TokenUsage {
|
TokenUsage {
|
||||||
input_tokens: metadata.input_tokens as u64,
|
input_tokens: metadata.input_tokens as u64,
|
||||||
output_tokens: metadata.output_tokens
|
output_tokens: metadata.output_tokens as u64,
|
||||||
as u64,
|
cache_creation_input_tokens:
|
||||||
cache_creation_input_tokens: default(),
|
metadata.cache_write_input_tokens.unwrap_or_default() as u64,
|
||||||
cache_read_input_tokens: default(),
|
cache_read_input_tokens:
|
||||||
|
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
return Some((Some(Ok(completion_event)), state));
|
return Some((Some(Ok(completion_event)), state));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue