From b1a6e2427f37b71b11dbe0da91ef459cac109b27 Mon Sep 17 00:00:00 2001 From: Roy Williams Date: Fri, 3 Jan 2025 18:46:32 -0500 Subject: [PATCH] anthropic: Allow specifying additional beta headers for custom models (#20551) Release Notes: - Added the ability to specify additional beta headers for custom Anthropic models. --------- Co-authored-by: David Soria Parra <167242713+dsp-ant@users.noreply.github.com> Co-authored-by: Marshall Bowers --- crates/anthropic/src/anthropic.rs | 29 +++++++++++++++---- .../language_models/src/provider/anthropic.rs | 3 ++ crates/language_models/src/provider/cloud.rs | 4 +++ crates/language_models/src/settings.rs | 2 ++ 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 29e42c03ba..03f60d5a86 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -50,6 +50,8 @@ pub enum Model { cache_configuration: Option, max_output_tokens: Option, default_temperature: Option, + #[serde(default)] + extra_beta_headers: Vec, }, } @@ -146,6 +148,24 @@ impl Model { } } + pub fn beta_headers(&self) -> String { + let mut headers = vec!["prompt-caching-2024-07-31".to_string()]; + + if let Self::Custom { + extra_beta_headers, .. + } = self + { + headers.extend( + extra_beta_headers + .iter() + .filter(|header| !header.trim().is_empty()) + .cloned(), + ); + } + + headers.join(",") + } + pub fn tool_model_id(&self) -> &str { if let Self::Custom { tool_override: Some(tool_override), @@ -166,11 +186,12 @@ pub async fn complete( request: Request, ) -> Result { let uri = format!("{api_url}/v1/messages"); + let model = Model::from_id(&request.model)?; let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Anthropic-Version", "2023-06-01") - .header("Anthropic-Beta", "prompt-caching-2024-07-31") + .header("Anthropic-Beta", model.beta_headers()) .header("X-Api-Key", api_key) .header("Content-Type", "application/json"); @@ -281,14 +302,12 @@ pub async fn stream_completion_with_rate_limit_info( stream: true, }; let uri = format!("{api_url}/v1/messages"); + let model = Model::from_id(&request.base.model)?; let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Anthropic-Version", "2023-06-01") - .header( - "Anthropic-Beta", - "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15", - ) + .header("Anthropic-Beta", model.beta_headers()) .header("X-Api-Key", api_key) .header("Content-Type", "application/json"); let serialized_request = diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index e882bb900d..1404a3428e 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -52,6 +52,8 @@ pub struct AvailableModel { pub cache_configuration: Option, pub max_output_tokens: Option, pub default_temperature: Option, + #[serde(default)] + pub extra_beta_headers: Vec, } pub struct AnthropicLanguageModelProvider { @@ -202,6 +204,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { }), max_output_tokens: model.max_output_tokens, default_temperature: model.default_temperature, + extra_beta_headers: model.extra_beta_headers.clone(), }, ); } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 6d76b733b7..4621236785 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -94,6 +94,9 @@ pub struct AvailableModel { pub cache_configuration: Option, /// The default temperature to use for this model. pub default_temperature: Option, + /// Any extra beta headers to provide when using the model. + #[serde(default)] + pub extra_beta_headers: Vec, } struct GlobalRefreshLlmTokenListener(Model); @@ -323,6 +326,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { }), default_temperature: model.default_temperature, max_output_tokens: model.max_output_tokens, + extra_beta_headers: model.extra_beta_headers.clone(), }), AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { name: model.name.clone(), diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index f6602427cb..c8ec9f7369 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -97,6 +97,7 @@ impl AnthropicSettingsContent { cache_configuration, max_output_tokens, default_temperature, + extra_beta_headers, } => Some(provider::anthropic::AvailableModel { name, display_name, @@ -111,6 +112,7 @@ impl AnthropicSettingsContent { ), max_output_tokens, default_temperature, + extra_beta_headers, }), _ => None, })