bedrock: Add support for tool use, cross-region inference, and Claude 3.7 Thinking (#28137)

Closes #27223
Merges: #27996, #26734, #27949 

Release Notes:

- AWS Bedrock: Added advanced authentication strategies with:
  - Short-lived credentials with Session Tokens
  - AWS Named Profile
  - EC2 Identity, Pod Identity, Web Identity
- AWS Bedrock: Added Claude 3.7 Thinking support.
- AWS Bedrock: Added cross-region inference for all combinations of
regions and model availability.
- Agent Beta: Added support for AWS Bedrock.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Shardul Vaidya 2025-04-05 11:16:26 -04:00 committed by GitHub
parent ea0f5144c9
commit 525755c28e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1042 additions and 318 deletions

74
Cargo.lock generated
View file

@ -2141,9 +2141,9 @@ dependencies = [
[[package]] [[package]]
name = "blake3" name = "blake3"
version = "1.8.0" version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34a796731680be7931955498a16a10b2270c7762963d5d570fdbfe02dcbf314f" checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3"
dependencies = [ dependencies = [
"arrayref", "arrayref",
"arrayvec", "arrayvec",
@ -2455,9 +2455,9 @@ dependencies = [
[[package]] [[package]]
name = "cap-fs-ext" name = "cap-fs-ext"
version = "3.4.2" version = "3.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f78efdd7378980d79c0f36b519e51191742d2c9f91ffa5e228fba9f3806d2e1" checksum = "f6323b9baffb4d6d9c65bfef3129db62b1391f7fb953fe67c0d7cb0804feb77b"
dependencies = [ dependencies = [
"cap-primitives", "cap-primitives",
"cap-std", "cap-std",
@ -2467,9 +2467,9 @@ dependencies = [
[[package]] [[package]]
name = "cap-net-ext" name = "cap-net-ext"
version = "3.4.2" version = "3.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac68674a6042af2bcee1adad9f6abd432642cf03444ce3a5b36c3f39f23baf8" checksum = "66022e5e076ea27041a05ebd4349022e2133e6f4977974dffd54ceb7b982e871"
dependencies = [ dependencies = [
"cap-primitives", "cap-primitives",
"cap-std", "cap-std",
@ -2479,9 +2479,9 @@ dependencies = [
[[package]] [[package]]
name = "cap-primitives" name = "cap-primitives"
version = "3.4.2" version = "3.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc15faeed2223d8b8e8cc1857f5861935a06d06713c4ac106b722ae9ce3c369" checksum = "50ad0183a9659850877cefe8f5b87d564b2dd1fe78b18945813687f29c0a6878"
dependencies = [ dependencies = [
"ambient-authority", "ambient-authority",
"fs-set-times", "fs-set-times",
@ -2496,9 +2496,9 @@ dependencies = [
[[package]] [[package]]
name = "cap-rand" name = "cap-rand"
version = "3.4.2" version = "3.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dea13372b49df066d1ae654e5c6e41799c1efd9f6b36794b921e877ea4037977" checksum = "ab78a9f6301e70c0fe5df7328adbcb9228277fdb7bab36f312fc072f505e38c2"
dependencies = [ dependencies = [
"ambient-authority", "ambient-authority",
"rand 0.8.5", "rand 0.8.5",
@ -2506,9 +2506,9 @@ dependencies = [
[[package]] [[package]]
name = "cap-std" name = "cap-std"
version = "3.4.2" version = "3.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3dbd3e8e8d093d6ccb4b512264869e1281cdb032f7940bd50b2894f96f25609" checksum = "1c41814365b796ed12688026cb90a1e03236a84ccf009628f9c43c8aa3af250a"
dependencies = [ dependencies = [
"cap-primitives", "cap-primitives",
"io-extras", "io-extras",
@ -2518,9 +2518,9 @@ dependencies = [
[[package]] [[package]]
name = "cap-time-ext" name = "cap-time-ext"
version = "3.4.2" version = "3.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd736b20fc033f564a1995fb82fc349146de43aabba19c7368b4cb17d8f9ea53" checksum = "eb57b71bb69b97c638ec38b477e9b33fa1c1cff0e437dd55d15c117cf17ed5dc"
dependencies = [ dependencies = [
"ambient-authority", "ambient-authority",
"cap-primitives", "cap-primitives",
@ -2598,9 +2598,9 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.17" version = "1.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a" checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
@ -3926,9 +3926,9 @@ checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d"
[[package]] [[package]]
name = "ctrlc" name = "ctrlc"
version = "3.4.5" version = "3.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" checksum = "697b5419f348fd5ae2478e8018cb016c00a5881c7f46c717de98ffd135a5651c"
dependencies = [ dependencies = [
"nix", "nix",
"windows-sys 0.59.0", "windows-sys 0.59.0",
@ -4805,9 +4805,9 @@ dependencies = [
[[package]] [[package]]
name = "errno" name = "errno"
version = "0.3.10" version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.59.0", "windows-sys 0.59.0",
@ -5252,9 +5252,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.1.0" version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide", "miniz_oxide",
@ -7891,9 +7891,9 @@ checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
[[package]] [[package]]
name = "libmimalloc-sys" name = "libmimalloc-sys"
version = "0.1.40" version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07d0e07885d6a754b9c7993f2625187ad694ee985d60f23355ff0e7077261502" checksum = "6b20daca3a4ac14dbdc753c5e90fc7b490a48a9131daed3c9a9ced7b2defd37b"
dependencies = [ dependencies = [
"cc", "cc",
"libc", "libc",
@ -8562,9 +8562,9 @@ dependencies = [
[[package]] [[package]]
name = "mimalloc" name = "mimalloc"
version = "0.1.44" version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99585191385958383e13f6b822e6b6d8d9cf928e7d286ceb092da92b43c87bc1" checksum = "03cb1f88093fe50061ca1195d336ffec131347c7b833db31f9ab62a2d1b7925f"
dependencies = [ dependencies = [
"libmimalloc-sys", "libmimalloc-sys",
] ]
@ -8593,9 +8593,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]] [[package]]
name = "miniz_oxide" name = "miniz_oxide"
version = "0.8.5" version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430"
dependencies = [ dependencies = [
"adler2", "adler2",
"simd-adler32", "simd-adler32",
@ -9474,9 +9474,9 @@ dependencies = [
[[package]] [[package]]
name = "openssl" name = "openssl"
version = "0.10.71" version = "0.10.72"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.0",
"cfg-if", "cfg-if",
@ -9515,9 +9515,9 @@ dependencies = [
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.106" version = "0.9.107"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07"
dependencies = [ dependencies = [
"cc", "cc",
"libc", "libc",
@ -12149,7 +12149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.0",
"errno 0.3.10", "errno 0.3.11",
"itoa", "itoa",
"libc", "libc",
"linux-raw-sys 0.4.15", "linux-raw-sys 0.4.15",
@ -12164,7 +12164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf" checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf"
dependencies = [ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.0",
"errno 0.3.10", "errno 0.3.11",
"libc", "libc",
"linux-raw-sys 0.9.3", "linux-raw-sys 0.9.3",
"windows-sys 0.59.0", "windows-sys 0.59.0",
@ -12176,7 +12176,7 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a25c3aad9fc1424eb82c88087789a7d938e1829724f3e4043163baf0d13cfc12" checksum = "a25c3aad9fc1424eb82c88087789a7d938e1829724f3e4043163baf0d13cfc12"
dependencies = [ dependencies = [
"errno 0.3.10", "errno 0.3.11",
"libc", "libc",
"rustix 0.38.44", "rustix 0.38.44",
] ]
@ -13798,9 +13798,9 @@ dependencies = [
[[package]] [[package]]
name = "swash" name = "swash"
version = "0.2.1" version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13d5bbc2aa266907ed8ee977c9c9e16363cc2b001266104e13397b57f1d15f71" checksum = "fae9a562c7b46107d9c78cd78b75bbe1e991c16734c0aee8ff0ee711fb8b620a"
dependencies = [ dependencies = [
"skrifa", "skrifa",
"yazi", "yazi",

View file

@ -399,11 +399,11 @@ async-trait = "0.1"
async-tungstenite = "0.28" async-tungstenite = "0.28"
async-watch = "0.3.1" async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] } aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] }
aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] }
aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] } aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] } aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22" base64 = "0.22"
bitflags = "2.6.0" bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" } blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }

View file

@ -1,21 +1,25 @@
mod models; mod models;
use std::collections::HashMap;
use std::pin::Pin; use std::pin::Pin;
use anyhow::{Context, Error, Result, anyhow}; use anyhow::{Error, Result, anyhow};
use aws_sdk_bedrockruntime as bedrock; use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client; pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{ pub use aws_sdk_bedrockruntime::types::{
ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool, AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema, Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
ToolSpecification as BedrockTool, ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
}; };
use aws_smithy_types::{Document, Number as AwsNumber}; use aws_smithy_types::{Document, Number as AwsNumber};
pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
pub use bedrock::types::{ pub use bedrock::types::{
ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole, ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse, ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
Message as BedrockMessage, ResponseStream as BedrockResponseStream, ImageBlock as BedrockImageBlock, Message as BedrockMessage,
ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
}; };
use futures::stream::{self, BoxStream, Stream}; use futures::stream::{self, BoxStream, Stream};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -24,25 +28,6 @@ use thiserror::Error;
pub use crate::models::*; pub use crate::models::*;
pub async fn complete(
client: &bedrock::Client,
request: Request,
) -> Result<BedrockResponse, BedrockError> {
let response = bedrock::Client::converse(client)
.model_id(request.model.clone())
.set_messages(request.messages.into())
.send()
.await
.context("failed to send request to Bedrock");
match response {
Ok(output) => output
.output
.ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
Err(err) => Err(BedrockError::Other(err)),
}
}
pub async fn stream_completion( pub async fn stream_completion(
client: bedrock::Client, client: bedrock::Client,
request: Request, request: Request,
@ -50,11 +35,32 @@ pub async fn stream_completion(
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> { ) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
handle handle
.spawn(async move { .spawn(async move {
let response = bedrock::Client::converse_stream(&client) let mut response = bedrock::Client::converse_stream(&client)
.model_id(request.model.clone()) .model_id(request.model.clone())
.set_messages(request.messages.into()) .set_messages(request.messages.into());
.send()
.await; if let Some(Thinking::Enabled {
budget_tokens: Some(budget_tokens),
}) = request.thinking
{
response =
response.additional_model_request_fields(Document::Object(HashMap::from([(
"thinking".to_string(),
Document::from(HashMap::from([
("type".to_string(), Document::String("enabled".to_string())),
(
"budget_tokens".to_string(),
Document::Number(AwsNumber::PosInt(budget_tokens)),
),
])),
)])));
}
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
response = response.set_tool_config(request.tools);
}
let response = response.send().await;
match response { match response {
Ok(output) => { Ok(output) => {
@ -65,7 +71,7 @@ pub async fn stream_completion(
>, >,
> = Box::pin(stream::unfold(output.stream, |mut stream| async move { > = Box::pin(stream::unfold(output.stream, |mut stream| async move {
match stream.recv().await { match stream.recv().await {
Ok(Some(output)) => Some((Ok(output), stream)), Ok(Some(output)) => Some(({ Ok(output) }, stream)),
Ok(None) => None, Ok(None) => None,
Err(err) => { Err(err) => {
Some(( Some((
@ -135,13 +141,18 @@ pub fn value_to_aws_document(value: &Value) -> Document {
} }
} }
/// Thinking configuration carried on a [`Request`] (its `thinking` field).
#[derive(Debug, Serialize, Deserialize)]
pub enum Thinking {
/// Thinking is requested. `stream_completion` only attaches the `thinking`
/// additional model request field when `budget_tokens` is `Some`; the value
/// is sent as that field's `budget_tokens` entry alongside `type: "enabled"`.
Enabled { budget_tokens: Option<u64> },
}
#[derive(Debug)] #[derive(Debug)]
pub struct Request { pub struct Request {
pub model: String, pub model: String,
pub max_tokens: u32, pub max_tokens: u32,
pub messages: Vec<BedrockMessage>, pub messages: Vec<BedrockMessage>,
pub tools: Vec<BedrockTool>, pub tools: Option<BedrockToolConfig>,
pub tool_choice: Option<BedrockToolChoice>, pub thinking: Option<Thinking>,
pub system: Option<String>, pub system: Option<String>,
pub metadata: Option<Metadata>, pub metadata: Option<Metadata>,
pub stop_sequences: Vec<String>, pub stop_sequences: Vec<String>,

View file

@ -2,21 +2,38 @@ use anyhow::anyhow;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use strum::EnumIter; use strum::EnumIter;
/// Mode a Bedrock model runs in, as reported by `Model::mode`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum BedrockModelMode {
/// Regular mode; this is the variant produced by `Default::default()`.
#[default]
Default,
/// Thinking mode with an optional `budget_tokens` value
/// (`Model::mode` supplies `Some(4096)` for `Claude3_7SonnetThinking`).
Thinking {
budget_tokens: Option<u64>,
},
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model { pub enum Model {
// Anthropic models (already included) // Anthropic models (already included)
#[default] #[default]
#[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")] #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
Claude3_5Sonnet, Claude3_5SonnetV2,
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")] #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet, Claude3_7Sonnet,
#[serde(
rename = "claude-3-7-sonnet-thinking",
alias = "claude-3-7-sonnet-thinking-latest"
)]
Claude3_7SonnetThinking,
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
Claude3Opus, Claude3Opus,
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")] #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
Claude3Sonnet, Claude3Sonnet,
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
Claude3_5Haiku, Claude3_5Haiku,
Claude3_5Sonnet,
Claude3Haiku,
// Amazon Nova Models // Amazon Nova Models
AmazonNovaLite, AmazonNovaLite,
AmazonNovaMicro, AmazonNovaMicro,
@ -69,7 +86,7 @@ pub enum Model {
impl Model { impl Model {
pub fn from_id(id: &str) -> anyhow::Result<Self> { pub fn from_id(id: &str) -> anyhow::Result<Self> {
if id.starts_with("claude-3-5-sonnet-v2") { if id.starts_with("claude-3-5-sonnet-v2") {
Ok(Self::Claude3_5Sonnet) Ok(Self::Claude3_5SonnetV2)
} else if id.starts_with("claude-3-opus") { } else if id.starts_with("claude-3-opus") {
Ok(Self::Claude3Opus) Ok(Self::Claude3Opus)
} else if id.starts_with("claude-3-sonnet") { } else if id.starts_with("claude-3-sonnet") {
@ -78,6 +95,8 @@ impl Model {
Ok(Self::Claude3_5Haiku) Ok(Self::Claude3_5Haiku)
} else if id.starts_with("claude-3-7-sonnet") { } else if id.starts_with("claude-3-7-sonnet") {
Ok(Self::Claude3_7Sonnet) Ok(Self::Claude3_7Sonnet)
} else if id.starts_with("claude-3-7-sonnet-thinking") {
Ok(Self::Claude3_7SonnetThinking)
} else { } else {
Err(anyhow!("invalid model id")) Err(anyhow!("invalid model id"))
} }
@ -85,14 +104,18 @@ impl Model {
pub fn id(&self) -> &str { pub fn id(&self) -> &str {
match self { match self {
Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0", Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0", Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0", Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0", Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
Model::Claude3_7Sonnet => "us.anthropic.claude-3-7-sonnet-20250219-v1:0", Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0", Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0", Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0", "anthropic.claude-3-7-sonnet-20250219-v1:0"
}
Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
Model::DeepSeekR1 => "us.deepseek.r1-v1:0", Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct", Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct", Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
@ -128,11 +151,14 @@ impl Model {
pub fn display_name(&self) -> &str { pub fn display_name(&self) -> &str {
match self { match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet v2", Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet", Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Claude3_5Haiku => "Claude 3.5 Haiku", Self::Claude3_5Haiku => "Claude 3.5 Haiku",
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
Self::AmazonNovaLite => "Amazon Nova Lite", Self::AmazonNovaLite => "Amazon Nova Lite",
Self::AmazonNovaMicro => "Amazon Nova Micro", Self::AmazonNovaMicro => "Amazon Nova Micro",
Self::AmazonNovaPro => "Amazon Nova Pro", Self::AmazonNovaPro => "Amazon Nova Pro",
@ -173,7 +199,7 @@ impl Model {
pub fn max_token_count(&self) -> usize { pub fn max_token_count(&self) -> usize {
match self { match self {
Self::Claude3_5Sonnet Self::Claude3_5SonnetV2
| Self::Claude3Opus | Self::Claude3Opus
| Self::Claude3Sonnet | Self::Claude3Sonnet
| Self::Claude3_5Haiku | Self::Claude3_5Haiku
@ -186,7 +212,8 @@ impl Model {
pub fn max_output_tokens(&self) -> u32 { pub fn max_output_tokens(&self) -> u32 {
match self { match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096, Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
Self::Claude3_5Sonnet => 8_192, Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
Self::Claude3_5SonnetV2 => 8_192,
Self::Custom { Self::Custom {
max_output_tokens, .. max_output_tokens, ..
} => max_output_tokens.unwrap_or(4_096), } => max_output_tokens.unwrap_or(4_096),
@ -196,7 +223,7 @@ impl Model {
pub fn default_temperature(&self) -> f32 { pub fn default_temperature(&self) -> f32 {
match self { match self {
Self::Claude3_5Sonnet Self::Claude3_5SonnetV2
| Self::Claude3Opus | Self::Claude3Opus
| Self::Claude3Sonnet | Self::Claude3Sonnet
| Self::Claude3_5Haiku | Self::Claude3_5Haiku
@ -208,4 +235,253 @@ impl Model {
_ => 1.0, _ => 1.0,
} }
} }
/// Reports whether this model accepts tool definitions.
///
/// Returns `true` for the Anthropic Claude 3 family, the Amazon Nova models,
/// AI21 Jamba 1.5 and Cohere Command R / R+. Every other variant (Meta Llama,
/// AI21 Jurassic, custom models, ...) returns `false`.
pub fn supports_tool_use(&self) -> bool {
    matches!(
        self,
        // Anthropic Claude 3 models (all support tool use). Claude 3 Haiku is
        // listed explicitly so it does not fall through to `false`.
        Self::Claude3Opus
            | Self::Claude3Sonnet
            | Self::Claude3Haiku
            | Self::Claude3_5Sonnet
            | Self::Claude3_5SonnetV2
            | Self::Claude3_5Haiku
            | Self::Claude3_7Sonnet
            | Self::Claude3_7SonnetThinking
            // Amazon Nova models (all support tool use)
            | Self::AmazonNovaPro
            | Self::AmazonNovaLite
            | Self::AmazonNovaMicro
            // AI21 Jamba 1.5 models
            | Self::AI21Jamba15LargeV1
            | Self::AI21Jamba15MiniV1
            // Cohere Command R models
            | Self::CohereCommandRV1
            | Self::CohereCommandRPlusV1
    )
}
/// Returns the mode this model runs in.
///
/// Only `Claude3_7SonnetThinking` runs in thinking mode, with a budget of
/// 4096 tokens; every other model uses `BedrockModelMode::Default`.
pub fn mode(&self) -> BedrockModelMode {
    if matches!(self, Model::Claude3_7SonnetThinking) {
        BedrockModelMode::Thinking {
            budget_tokens: Some(4096),
        }
    } else {
        BedrockModelMode::Default
    }
}
/// Builds the model ID to use for cross-region inference from `region`.
///
/// The region is first mapped to a geography prefix (`us-gov`, `us`, `eu` or
/// `apac`). When the model/geography pair is one of the listed combinations
/// the result is `"<prefix>.<model id>"`; otherwise the plain `self.id()` is
/// returned unchanged.
///
/// # Errors
///
/// Returns an error when `region` does not match any known geography. The
/// region is validated before the model is inspected, so this also applies
/// to `Model::Custom`.
pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
    // Guard order matters: "us-gov-" has to be tested before the broader "us-".
    let geo = match region {
        r if r.starts_with("us-gov-") => "us-gov",
        r if r.starts_with("us-") => "us",
        r if r.starts_with("eu-") => "eu",
        r if r.starts_with("ap-") => "apac",
        "me-central-1" | "me-south-1" => "apac",
        // Canada and South America regions - default to US profiles
        r if r.starts_with("ca-") || r.starts_with("sa-") => "us",
        // Unknown region
        _ => return Err(anyhow!("Unsupported Region")),
    };

    let uses_prefix = match (self, geo) {
        // Custom models can't have CRI IDs
        (Model::Custom { .. }, _) => false,
        // Prefixed in every geography (this includes the us-gov profiles)
        (
            Model::Claude3_5Sonnet
            | Model::Claude3Haiku
            | Model::Claude3Sonnet
            | Model::AmazonNovaLite
            | Model::AmazonNovaMicro
            | Model::AmazonNovaPro,
            _,
        ) => true,
        // Prefixed in US and APAC only
        (Model::Claude3_5SonnetV2, "us" | "apac") => true,
        // Prefixed in US and EU only
        (Model::MetaLlama321BInstructV1 | Model::MetaLlama323BInstructV1, "us" | "eu") => true,
        // Prefixed in US only
        (
            Model::Claude3Opus
            | Model::Claude3_7Sonnet
            | Model::Claude3_7SonnetThinking
            | Model::MetaLlama38BInstructV1
            | Model::MetaLlama370BInstructV1
            | Model::MetaLlama318BInstructV1
            | Model::MetaLlama318BInstructV1_128k
            | Model::MetaLlama3170BInstructV1
            | Model::MetaLlama3170BInstructV1_128k
            | Model::MetaLlama3211BInstructV1
            | Model::MetaLlama3290BInstructV1,
            "us",
        ) => true,
        // Any other combination keeps the plain model ID
        _ => false,
    };

    let base_id = self.id();
    if uses_prefix {
        Ok(format!("{}.{}", geo, base_id))
    } else {
        Ok(base_id.into())
    }
}
}
// Unit tests for `Model::cross_region_inference_id`: each test checks the ID
// returned for a model in a given AWS region, covering the prefixed
// geographies (us, eu, apac, us-gov) and the models whose ID is returned
// unchanged (Mistral, AI21, Cohere, custom).
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_us_region_inference_ids() -> anyhow::Result<()> {
// Test US regions
assert_eq!(
Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1")?,
"us.anthropic.claude-3-5-sonnet-20241022-v2:0"
);
assert_eq!(
Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2")?,
"us.anthropic.claude-3-5-sonnet-20241022-v2:0"
);
assert_eq!(
Model::AmazonNovaPro.cross_region_inference_id("us-east-2")?,
"us.amazon.nova-pro-v1:0"
);
Ok(())
}
#[test]
fn test_eu_region_inference_ids() -> anyhow::Result<()> {
// Test European regions
assert_eq!(
Model::Claude3Sonnet.cross_region_inference_id("eu-west-1")?,
"eu.anthropic.claude-3-sonnet-20240229-v1:0"
);
assert_eq!(
Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1")?,
"eu.amazon.nova-micro-v1:0"
);
Ok(())
}
#[test]
fn test_apac_region_inference_ids() -> anyhow::Result<()> {
// Test Asia-Pacific regions
assert_eq!(
Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1")?,
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
);
assert_eq!(
Model::AmazonNovaLite.cross_region_inference_id("ap-south-1")?,
"apac.amazon.nova-lite-v1:0"
);
Ok(())
}
#[test]
fn test_gov_region_inference_ids() -> anyhow::Result<()> {
// Test Government regions
assert_eq!(
Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1")?,
"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0"
);
assert_eq!(
Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1")?,
"us-gov.anthropic.claude-3-haiku-20240307-v1:0"
);
Ok(())
}
#[test]
fn test_meta_models_inference_ids() -> anyhow::Result<()> {
// Test Meta models
assert_eq!(
Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1")?,
"us.meta.llama3-70b-instruct-v1:0"
);
assert_eq!(
Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1")?,
"eu.meta.llama3-2-1b-instruct-v1:0"
);
Ok(())
}
#[test]
fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
// Mistral models don't follow the regional prefix pattern,
// so they should return their original IDs
assert_eq!(
Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1")?,
"mistral.mistral-large-2402-v1:0"
);
assert_eq!(
Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1")?,
"mistral.mixtral-8x7b-instruct-v0:1"
);
Ok(())
}
#[test]
fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
// AI21 models don't follow the regional prefix pattern,
// so they should return their original IDs
assert_eq!(
Model::AI21J2UltraV1.cross_region_inference_id("us-east-1")?,
"ai21.j2-ultra-v1"
);
assert_eq!(
Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1")?,
"ai21.jamba-instruct-v1:0"
);
Ok(())
}
#[test]
fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
// Cohere models don't follow the regional prefix pattern,
// so they should return their original IDs
assert_eq!(
Model::CohereCommandRV1.cross_region_inference_id("us-east-1")?,
"cohere.command-r-v1:0"
);
assert_eq!(
Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1")?,
"cohere.command-text-v14:7:4k"
);
Ok(())
}
#[test]
fn test_custom_model_inference_ids() -> anyhow::Result<()> {
// Test custom models
let custom_model = Model::Custom {
name: "custom.my-model-v1:0".to_string(),
max_tokens: 100000,
display_name: Some("My Custom Model".to_string()),
max_output_tokens: Some(8192),
default_temperature: Some(0.7),
};
// Custom model should return its name unchanged
assert_eq!(
custom_model.cross_region_inference_id("us-east-1")?,
"custom.my-model-v1:0"
);
Ok(())
}
} }

File diff suppressed because it is too large Load diff

View file

@ -72,6 +72,7 @@ pub struct AllLanguageModelSettings {
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AllLanguageModelSettingsContent { pub struct AllLanguageModelSettingsContent {
pub anthropic: Option<AnthropicSettingsContent>, pub anthropic: Option<AnthropicSettingsContent>,
pub bedrock: Option<AmazonBedrockSettingsContent>,
pub ollama: Option<OllamaSettingsContent>, pub ollama: Option<OllamaSettingsContent>,
pub lmstudio: Option<LmStudioSettingsContent>, pub lmstudio: Option<LmStudioSettingsContent>,
pub openai: Option<OpenAiSettingsContent>, pub openai: Option<OpenAiSettingsContent>,
@ -160,6 +161,15 @@ pub struct AnthropicSettingsContentV1 {
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>, pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
} }
/// Settings content for the Amazon Bedrock provider (the `bedrock` entry of
/// `AllLanguageModelSettingsContent`).
///
/// Every field is optional. Fields are `pub`, matching the sibling
/// `*SettingsContent` structs, so the content can be read and constructed
/// outside this module.
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AmazonBedrockSettingsContent {
    /// Models configured for this provider.
    /// NOTE(review): not merged by the visible settings-loading code; confirm
    /// where this is consumed.
    pub available_models: Option<Vec<provider::bedrock::AvailableModel>>,
    /// Merged into `settings.bedrock.endpoint` when settings are loaded.
    pub endpoint_url: Option<String>,
    /// Merged into `settings.bedrock.region` when settings are loaded.
    pub region: Option<String>,
    /// Merged into `settings.bedrock.profile_name` when settings are loaded.
    pub profile: Option<String>,
    /// Merged into `settings.bedrock.authentication_method` when settings are loaded.
    pub authentication_method: Option<provider::bedrock::BedrockAuthMethod>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OllamaSettingsContent { pub struct OllamaSettingsContent {
pub api_url: Option<String>, pub api_url: Option<String>,
@ -297,6 +307,25 @@ impl settings::Settings for AllLanguageModelSettings {
anthropic.as_ref().and_then(|s| s.available_models.clone()), anthropic.as_ref().and_then(|s| s.available_models.clone()),
); );
// Bedrock
let bedrock = value.bedrock.clone();
merge(
&mut settings.bedrock.profile_name,
bedrock.as_ref().map(|s| s.profile.clone()),
);
merge(
&mut settings.bedrock.authentication_method,
bedrock.as_ref().map(|s| s.authentication_method.clone()),
);
merge(
&mut settings.bedrock.region,
bedrock.as_ref().map(|s| s.region.clone()),
);
merge(
&mut settings.bedrock.endpoint,
bedrock.as_ref().map(|s| s.endpoint_url.clone()),
);
// Ollama // Ollama
let ollama = value.ollama.clone(); let ollama = value.ollama.clone();