bedrock: Add support for tool use, cross-region inference, and Claude 3.7 Thinking (#28137)
Closes #27223

Merges: #27996, #26734, #27949

Release Notes:

- AWS Bedrock: Added advanced authentication strategies with:
  - Short-lived credentials with Session Tokens
  - AWS Named Profile
  - EC2 Identity, Pod Identity, Web Identity
- AWS Bedrock: Added Claude 3.7 Thinking support.
- AWS Bedrock: Added Cross Region Inference for all combinations of regions and model availability.
- Agent Beta: Added support for AWS Bedrock.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
ea0f5144c9
commit
525755c28e
6 changed files with 1042 additions and 318 deletions
74
Cargo.lock
generated
74
Cargo.lock
generated
|
@ -2141,9 +2141,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "blake3"
|
||||
version = "1.8.0"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34a796731680be7931955498a16a10b2270c7762963d5d570fdbfe02dcbf314f"
|
||||
checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"arrayvec",
|
||||
|
@ -2455,9 +2455,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cap-fs-ext"
|
||||
version = "3.4.2"
|
||||
version = "3.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f78efdd7378980d79c0f36b519e51191742d2c9f91ffa5e228fba9f3806d2e1"
|
||||
checksum = "f6323b9baffb4d6d9c65bfef3129db62b1391f7fb953fe67c0d7cb0804feb77b"
|
||||
dependencies = [
|
||||
"cap-primitives",
|
||||
"cap-std",
|
||||
|
@ -2467,9 +2467,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cap-net-ext"
|
||||
version = "3.4.2"
|
||||
version = "3.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ac68674a6042af2bcee1adad9f6abd432642cf03444ce3a5b36c3f39f23baf8"
|
||||
checksum = "66022e5e076ea27041a05ebd4349022e2133e6f4977974dffd54ceb7b982e871"
|
||||
dependencies = [
|
||||
"cap-primitives",
|
||||
"cap-std",
|
||||
|
@ -2479,9 +2479,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cap-primitives"
|
||||
version = "3.4.2"
|
||||
version = "3.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fc15faeed2223d8b8e8cc1857f5861935a06d06713c4ac106b722ae9ce3c369"
|
||||
checksum = "50ad0183a9659850877cefe8f5b87d564b2dd1fe78b18945813687f29c0a6878"
|
||||
dependencies = [
|
||||
"ambient-authority",
|
||||
"fs-set-times",
|
||||
|
@ -2496,9 +2496,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cap-rand"
|
||||
version = "3.4.2"
|
||||
version = "3.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dea13372b49df066d1ae654e5c6e41799c1efd9f6b36794b921e877ea4037977"
|
||||
checksum = "ab78a9f6301e70c0fe5df7328adbcb9228277fdb7bab36f312fc072f505e38c2"
|
||||
dependencies = [
|
||||
"ambient-authority",
|
||||
"rand 0.8.5",
|
||||
|
@ -2506,9 +2506,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cap-std"
|
||||
version = "3.4.2"
|
||||
version = "3.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3dbd3e8e8d093d6ccb4b512264869e1281cdb032f7940bd50b2894f96f25609"
|
||||
checksum = "1c41814365b796ed12688026cb90a1e03236a84ccf009628f9c43c8aa3af250a"
|
||||
dependencies = [
|
||||
"cap-primitives",
|
||||
"io-extras",
|
||||
|
@ -2518,9 +2518,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cap-time-ext"
|
||||
version = "3.4.2"
|
||||
version = "3.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd736b20fc033f564a1995fb82fc349146de43aabba19c7368b4cb17d8f9ea53"
|
||||
checksum = "eb57b71bb69b97c638ec38b477e9b33fa1c1cff0e437dd55d15c117cf17ed5dc"
|
||||
dependencies = [
|
||||
"ambient-authority",
|
||||
"cap-primitives",
|
||||
|
@ -2598,9 +2598,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.17"
|
||||
version = "1.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a"
|
||||
checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
|
@ -3926,9 +3926,9 @@ checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d"
|
|||
|
||||
[[package]]
|
||||
name = "ctrlc"
|
||||
version = "3.4.5"
|
||||
version = "3.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3"
|
||||
checksum = "697b5419f348fd5ae2478e8018cb016c00a5881c7f46c717de98ffd135a5651c"
|
||||
dependencies = [
|
||||
"nix",
|
||||
"windows-sys 0.59.0",
|
||||
|
@ -4805,9 +4805,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.10"
|
||||
version = "0.3.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
|
||||
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
|
@ -5252,9 +5252,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
|
|||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.1.0"
|
||||
version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
|
||||
checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
|
@ -7891,9 +7891,9 @@ checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
|
|||
|
||||
[[package]]
|
||||
name = "libmimalloc-sys"
|
||||
version = "0.1.40"
|
||||
version = "0.1.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07d0e07885d6a754b9c7993f2625187ad694ee985d60f23355ff0e7077261502"
|
||||
checksum = "6b20daca3a4ac14dbdc753c5e90fc7b490a48a9131daed3c9a9ced7b2defd37b"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
|
@ -8562,9 +8562,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "mimalloc"
|
||||
version = "0.1.44"
|
||||
version = "0.1.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "99585191385958383e13f6b822e6b6d8d9cf928e7d286ceb092da92b43c87bc1"
|
||||
checksum = "03cb1f88093fe50061ca1195d336ffec131347c7b833db31f9ab62a2d1b7925f"
|
||||
dependencies = [
|
||||
"libmimalloc-sys",
|
||||
]
|
||||
|
@ -8593,9 +8593,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
|||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.8.5"
|
||||
version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
|
||||
checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430"
|
||||
dependencies = [
|
||||
"adler2",
|
||||
"simd-adler32",
|
||||
|
@ -9474,9 +9474,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.71"
|
||||
version = "0.10.72"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
|
||||
checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"cfg-if",
|
||||
|
@ -9515,9 +9515,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.106"
|
||||
version = "0.9.107"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
|
||||
checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
|
@ -12149,7 +12149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"errno 0.3.10",
|
||||
"errno 0.3.11",
|
||||
"itoa",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
|
@ -12164,7 +12164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"errno 0.3.10",
|
||||
"errno 0.3.11",
|
||||
"libc",
|
||||
"linux-raw-sys 0.9.3",
|
||||
"windows-sys 0.59.0",
|
||||
|
@ -12176,7 +12176,7 @@ version = "0.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a25c3aad9fc1424eb82c88087789a7d938e1829724f3e4043163baf0d13cfc12"
|
||||
dependencies = [
|
||||
"errno 0.3.10",
|
||||
"errno 0.3.11",
|
||||
"libc",
|
||||
"rustix 0.38.44",
|
||||
]
|
||||
|
@ -13798,9 +13798,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "swash"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13d5bbc2aa266907ed8ee977c9c9e16363cc2b001266104e13397b57f1d15f71"
|
||||
checksum = "fae9a562c7b46107d9c78cd78b75bbe1e991c16734c0aee8ff0ee711fb8b620a"
|
||||
dependencies = [
|
||||
"skrifa",
|
||||
"yazi",
|
||||
|
|
10
Cargo.toml
10
Cargo.toml
|
@ -399,11 +399,11 @@ async-trait = "0.1"
|
|||
async-tungstenite = "0.28"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
|
||||
aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] }
|
||||
aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] }
|
||||
aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] }
|
||||
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] }
|
||||
aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] }
|
||||
aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
|
||||
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
|
||||
base64 = "0.22"
|
||||
bitflags = "2.6.0"
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
|
||||
|
|
|
@ -1,21 +1,25 @@
|
|||
mod models;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::{Context, Error, Result, anyhow};
|
||||
use anyhow::{Error, Result, anyhow};
|
||||
use aws_sdk_bedrockruntime as bedrock;
|
||||
pub use aws_sdk_bedrockruntime as bedrock_client;
|
||||
pub use aws_sdk_bedrockruntime::types::{
|
||||
ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
|
||||
ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
|
||||
ToolSpecification as BedrockTool,
|
||||
AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
|
||||
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
|
||||
ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
|
||||
};
|
||||
use aws_smithy_types::{Document, Number as AwsNumber};
|
||||
pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
|
||||
pub use bedrock::types::{
|
||||
ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
|
||||
ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
|
||||
Message as BedrockMessage, ResponseStream as BedrockResponseStream,
|
||||
ImageBlock as BedrockImageBlock, Message as BedrockMessage,
|
||||
ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
|
||||
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
|
||||
};
|
||||
use futures::stream::{self, BoxStream, Stream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -24,25 +28,6 @@ use thiserror::Error;
|
|||
|
||||
pub use crate::models::*;
|
||||
|
||||
pub async fn complete(
|
||||
client: &bedrock::Client,
|
||||
request: Request,
|
||||
) -> Result<BedrockResponse, BedrockError> {
|
||||
let response = bedrock::Client::converse(client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into())
|
||||
.send()
|
||||
.await
|
||||
.context("failed to send request to Bedrock");
|
||||
|
||||
match response {
|
||||
Ok(output) => output
|
||||
.output
|
||||
.ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
|
||||
Err(err) => Err(BedrockError::Other(err)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: bedrock::Client,
|
||||
request: Request,
|
||||
|
@ -50,11 +35,32 @@ pub async fn stream_completion(
|
|||
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
|
||||
handle
|
||||
.spawn(async move {
|
||||
let response = bedrock::Client::converse_stream(&client)
|
||||
let mut response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into())
|
||||
.send()
|
||||
.await;
|
||||
.set_messages(request.messages.into());
|
||||
|
||||
if let Some(Thinking::Enabled {
|
||||
budget_tokens: Some(budget_tokens),
|
||||
}) = request.thinking
|
||||
{
|
||||
response =
|
||||
response.additional_model_request_fields(Document::Object(HashMap::from([(
|
||||
"thinking".to_string(),
|
||||
Document::from(HashMap::from([
|
||||
("type".to_string(), Document::String("enabled".to_string())),
|
||||
(
|
||||
"budget_tokens".to_string(),
|
||||
Document::Number(AwsNumber::PosInt(budget_tokens)),
|
||||
),
|
||||
])),
|
||||
)])));
|
||||
}
|
||||
|
||||
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
|
||||
response = response.set_tool_config(request.tools);
|
||||
}
|
||||
|
||||
let response = response.send().await;
|
||||
|
||||
match response {
|
||||
Ok(output) => {
|
||||
|
@ -65,7 +71,7 @@ pub async fn stream_completion(
|
|||
>,
|
||||
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
|
||||
match stream.recv().await {
|
||||
Ok(Some(output)) => Some((Ok(output), stream)),
|
||||
Ok(Some(output)) => Some(({ Ok(output) }, stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => {
|
||||
Some((
|
||||
|
@ -135,13 +141,18 @@ pub fn value_to_aws_document(value: &Value) -> Document {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extended-thinking configuration carried in the `thinking` field of a `Request`.
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum Thinking {
|
||||
/// Thinking is requested. `stream_completion` only attaches a `thinking`
/// object (`type = "enabled"` plus `budget_tokens`) to the request when
/// `budget_tokens` is `Some`; with `None` no extra request field is added.
Enabled { budget_tokens: Option<u64> },
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<BedrockMessage>,
|
||||
pub tools: Vec<BedrockTool>,
|
||||
pub tool_choice: Option<BedrockToolChoice>,
|
||||
pub tools: Option<BedrockToolConfig>,
|
||||
pub thinking: Option<Thinking>,
|
||||
pub system: Option<String>,
|
||||
pub metadata: Option<Metadata>,
|
||||
pub stop_sequences: Vec<String>,
|
||||
|
|
|
@ -2,21 +2,38 @@ use anyhow::anyhow;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
||||
/// Operating mode of a Bedrock model, as returned by `Model::mode`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub enum BedrockModelMode {
|
||||
/// Regular mode; this is also the `Default::default()` value.
#[default]
|
||||
Default,
|
||||
/// Thinking mode with an optional token budget.
Thinking {
|
||||
/// Token budget for thinking; `Model::mode` fills in `Some(4096)` for
/// `Claude3_7SonnetThinking`.
budget_tokens: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
// Anthropic models (already included)
|
||||
#[default]
|
||||
#[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")]
|
||||
Claude3_5Sonnet,
|
||||
Claude3_5SonnetV2,
|
||||
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
|
||||
Claude3_7Sonnet,
|
||||
#[serde(
|
||||
rename = "claude-3-7-sonnet-thinking",
|
||||
alias = "claude-3-7-sonnet-thinking-latest"
|
||||
)]
|
||||
Claude3_7SonnetThinking,
|
||||
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
|
||||
Claude3_5Haiku,
|
||||
Claude3_5Sonnet,
|
||||
Claude3Haiku,
|
||||
// Amazon Nova Models
|
||||
AmazonNovaLite,
|
||||
AmazonNovaMicro,
|
||||
|
@ -69,7 +86,7 @@ pub enum Model {
|
|||
impl Model {
|
||||
pub fn from_id(id: &str) -> anyhow::Result<Self> {
|
||||
if id.starts_with("claude-3-5-sonnet-v2") {
|
||||
Ok(Self::Claude3_5Sonnet)
|
||||
Ok(Self::Claude3_5SonnetV2)
|
||||
} else if id.starts_with("claude-3-opus") {
|
||||
Ok(Self::Claude3Opus)
|
||||
} else if id.starts_with("claude-3-sonnet") {
|
||||
|
@ -78,6 +95,8 @@ impl Model {
|
|||
Ok(Self::Claude3_5Haiku)
|
||||
} else if id.starts_with("claude-3-7-sonnet") {
|
||||
Ok(Self::Claude3_7Sonnet)
|
||||
} else if id.starts_with("claude-3-7-sonnet-thinking") {
|
||||
Ok(Self::Claude3_7SonnetThinking)
|
||||
} else {
|
||||
Err(anyhow!("invalid model id"))
|
||||
}
|
||||
|
@ -85,14 +104,18 @@ impl Model {
|
|||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
|
||||
Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
Model::Claude3_7Sonnet => "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
|
||||
Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
|
||||
Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
|
||||
Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0",
|
||||
Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => {
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0"
|
||||
}
|
||||
Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
|
||||
Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
|
||||
Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
|
||||
Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
|
||||
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
|
||||
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
|
||||
|
@ -128,11 +151,14 @@ impl Model {
|
|||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet v2",
|
||||
Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2",
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Claude3_5Haiku => "Claude 3.5 Haiku",
|
||||
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
|
||||
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
|
||||
Self::AmazonNovaLite => "Amazon Nova Lite",
|
||||
Self::AmazonNovaMicro => "Amazon Nova Micro",
|
||||
Self::AmazonNovaPro => "Amazon Nova Pro",
|
||||
|
@ -173,7 +199,7 @@ impl Model {
|
|||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet
|
||||
Self::Claude3_5SonnetV2
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3_5Haiku
|
||||
|
@ -186,7 +212,8 @@ impl Model {
|
|||
pub fn max_output_tokens(&self) -> u32 {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
|
||||
Self::Claude3_5Sonnet => 8_192,
|
||||
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
|
||||
Self::Claude3_5SonnetV2 => 8_192,
|
||||
Self::Custom {
|
||||
max_output_tokens, ..
|
||||
} => max_output_tokens.unwrap_or(4_096),
|
||||
|
@ -196,7 +223,7 @@ impl Model {
|
|||
|
||||
pub fn default_temperature(&self) -> f32 {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet
|
||||
Self::Claude3_5SonnetV2
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3_5Haiku
|
||||
|
@ -208,4 +235,253 @@ impl Model {
|
|||
_ => 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_tool_use(&self) -> bool {
|
||||
match self {
|
||||
// Anthropic Claude 3 models (all support tool use)
|
||||
Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3_5Sonnet
|
||||
| Self::Claude3_5SonnetV2
|
||||
| Self::Claude3_7Sonnet
|
||||
| Self::Claude3_7SonnetThinking
|
||||
| Self::Claude3_5Haiku => true,
|
||||
|
||||
// Amazon Nova models (all support tool use)
|
||||
Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true,
|
||||
|
||||
// AI21 Jamba 1.5 models support tool use
|
||||
Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
|
||||
|
||||
// Cohere Command R models support tool use
|
||||
Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true,
|
||||
|
||||
// All other models don't support tool use
|
||||
// Including Meta Llama 3.2, AI21 Jurassic, and others
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> BedrockModelMode {
|
||||
match self {
|
||||
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
|
||||
budget_tokens: Some(4096),
|
||||
},
|
||||
_ => BedrockModelMode::Default,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cross_region_inference_id(&self, region: &str) -> Result<String, anyhow::Error> {
|
||||
let region_group = if region.starts_with("us-gov-") {
|
||||
"us-gov"
|
||||
} else if region.starts_with("us-") {
|
||||
"us"
|
||||
} else if region.starts_with("eu-") {
|
||||
"eu"
|
||||
} else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
|
||||
"apac"
|
||||
} else if region.starts_with("ca-") || region.starts_with("sa-") {
|
||||
// Canada and South America regions - default to US profiles
|
||||
"us"
|
||||
} else {
|
||||
// Unknown region
|
||||
return Err(anyhow!("Unsupported Region"));
|
||||
};
|
||||
|
||||
let model_id = self.id();
|
||||
|
||||
match (self, region_group) {
|
||||
// Custom models can't have CRI IDs
|
||||
(Model::Custom { .. }, _) => Ok(self.id().into()),
|
||||
|
||||
// Models with US Gov only
|
||||
(Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// Models available only in US
|
||||
(Model::Claude3Opus, "us")
|
||||
| (Model::Claude3_7Sonnet, "us")
|
||||
| (Model::Claude3_7SonnetThinking, "us") => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// Models available in US, EU, and APAC
|
||||
(Model::Claude3_5SonnetV2, "us")
|
||||
| (Model::Claude3_5SonnetV2, "apac")
|
||||
| (Model::Claude3_5Sonnet, _)
|
||||
| (Model::Claude3Haiku, _)
|
||||
| (Model::Claude3Sonnet, _)
|
||||
| (Model::AmazonNovaLite, _)
|
||||
| (Model::AmazonNovaMicro, _)
|
||||
| (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)),
|
||||
|
||||
// Models with limited EU availability
|
||||
(Model::MetaLlama321BInstructV1, "us")
|
||||
| (Model::MetaLlama321BInstructV1, "eu")
|
||||
| (Model::MetaLlama323BInstructV1, "us")
|
||||
| (Model::MetaLlama323BInstructV1, "eu") => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// US-only models (all remaining Meta models)
|
||||
(Model::MetaLlama38BInstructV1, "us")
|
||||
| (Model::MetaLlama370BInstructV1, "us")
|
||||
| (Model::MetaLlama318BInstructV1, "us")
|
||||
| (Model::MetaLlama318BInstructV1_128k, "us")
|
||||
| (Model::MetaLlama3170BInstructV1, "us")
|
||||
| (Model::MetaLlama3170BInstructV1_128k, "us")
|
||||
| (Model::MetaLlama3211BInstructV1, "us")
|
||||
| (Model::MetaLlama3290BInstructV1, "us") => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// Any other combination is not supported
|
||||
_ => Ok(self.id().into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every `(model, region)` pair resolves to the expected id.
    fn assert_inference_ids(cases: Vec<(Model, &str, &str)>) -> anyhow::Result<()> {
        for (model, region, expected) in cases {
            assert_eq!(model.cross_region_inference_id(region)?, expected);
        }
        Ok(())
    }

    #[test]
    fn test_us_region_inference_ids() -> anyhow::Result<()> {
        // US regions
        assert_inference_ids(vec![
            (
                Model::Claude3_5SonnetV2,
                "us-east-1",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::Claude3_5SonnetV2,
                "us-west-2",
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (Model::AmazonNovaPro, "us-east-2", "us.amazon.nova-pro-v1:0"),
        ])
    }

    #[test]
    fn test_eu_region_inference_ids() -> anyhow::Result<()> {
        // European regions
        assert_inference_ids(vec![
            (
                Model::Claude3Sonnet,
                "eu-west-1",
                "eu.anthropic.claude-3-sonnet-20240229-v1:0",
            ),
            (
                Model::AmazonNovaMicro,
                "eu-north-1",
                "eu.amazon.nova-micro-v1:0",
            ),
        ])
    }

    #[test]
    fn test_apac_region_inference_ids() -> anyhow::Result<()> {
        // Asia-Pacific regions
        assert_inference_ids(vec![
            (
                Model::Claude3_5SonnetV2,
                "ap-northeast-1",
                "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
            ),
            (
                Model::AmazonNovaLite,
                "ap-south-1",
                "apac.amazon.nova-lite-v1:0",
            ),
        ])
    }

    #[test]
    fn test_gov_region_inference_ids() -> anyhow::Result<()> {
        // Government regions
        assert_inference_ids(vec![
            (
                Model::Claude3_5Sonnet,
                "us-gov-east-1",
                "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
            ),
            (
                Model::Claude3Haiku,
                "us-gov-west-1",
                "us-gov.anthropic.claude-3-haiku-20240307-v1:0",
            ),
        ])
    }

    #[test]
    fn test_meta_models_inference_ids() -> anyhow::Result<()> {
        // Meta models
        assert_inference_ids(vec![
            (
                Model::MetaLlama370BInstructV1,
                "us-east-1",
                "us.meta.llama3-70b-instruct-v1:0",
            ),
            (
                Model::MetaLlama321BInstructV1,
                "eu-west-1",
                "eu.meta.llama3-2-1b-instruct-v1:0",
            ),
        ])
    }

    #[test]
    fn test_mistral_models_inference_ids() -> anyhow::Result<()> {
        // Mistral models get no regional prefix, so the base id comes back as is.
        assert_inference_ids(vec![
            (
                Model::MistralMistralLarge2402V1,
                "us-east-1",
                "mistral.mistral-large-2402-v1:0",
            ),
            (
                Model::MistralMixtral8x7BInstructV0,
                "eu-west-1",
                "mistral.mixtral-8x7b-instruct-v0:1",
            ),
        ])
    }

    #[test]
    fn test_ai21_models_inference_ids() -> anyhow::Result<()> {
        // AI21 models get no regional prefix, so the base id comes back as is.
        assert_inference_ids(vec![
            (Model::AI21J2UltraV1, "us-east-1", "ai21.j2-ultra-v1"),
            (
                Model::AI21JambaInstructV1,
                "eu-west-1",
                "ai21.jamba-instruct-v1:0",
            ),
        ])
    }

    #[test]
    fn test_cohere_models_inference_ids() -> anyhow::Result<()> {
        // Cohere models get no regional prefix, so the base id comes back as is.
        assert_inference_ids(vec![
            (Model::CohereCommandRV1, "us-east-1", "cohere.command-r-v1:0"),
            (
                Model::CohereCommandTextV14_4k,
                "ap-southeast-1",
                "cohere.command-text-v14:7:4k",
            ),
        ])
    }

    #[test]
    fn test_custom_model_inference_ids() -> anyhow::Result<()> {
        // A custom model must hand back its own name untouched.
        let custom_model = Model::Custom {
            name: "custom.my-model-v1:0".to_string(),
            max_tokens: 100000,
            display_name: Some("My Custom Model".to_string()),
            max_output_tokens: Some(8192),
            default_temperature: Some(0.7),
        };

        assert_inference_ids(vec![(custom_model, "us-east-1", "custom.my-model-v1:0")])
    }
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -72,6 +72,7 @@ pub struct AllLanguageModelSettings {
|
|||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AllLanguageModelSettingsContent {
|
||||
pub anthropic: Option<AnthropicSettingsContent>,
|
||||
pub bedrock: Option<AmazonBedrockSettingsContent>,
|
||||
pub ollama: Option<OllamaSettingsContent>,
|
||||
pub lmstudio: Option<LmStudioSettingsContent>,
|
||||
pub openai: Option<OpenAiSettingsContent>,
|
||||
|
@ -160,6 +161,15 @@ pub struct AnthropicSettingsContentV1 {
|
|||
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
|
||||
}
|
||||
|
||||
/// Amazon Bedrock section of the language-model settings content; stored in
/// the `bedrock` field of `AllLanguageModelSettingsContent`.
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct AmazonBedrockSettingsContent {
|
||||
// Additional models to offer - presumably user-defined; how this is
// consumed is not visible in this part of the file (TODO confirm).
available_models: Option<Vec<provider::bedrock::AvailableModel>>,
|
||||
// Merged into `settings.bedrock.endpoint` when settings are loaded.
endpoint_url: Option<String>,
|
||||
// Merged into `settings.bedrock.region` when settings are loaded.
region: Option<String>,
|
||||
// Merged into `settings.bedrock.profile_name` when settings are loaded.
profile: Option<String>,
|
||||
// Merged into `settings.bedrock.authentication_method` when settings are loaded.
authentication_method: Option<provider::bedrock::BedrockAuthMethod>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OllamaSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
|
@ -297,6 +307,25 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
anthropic.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
// Bedrock
|
||||
let bedrock = value.bedrock.clone();
|
||||
merge(
|
||||
&mut settings.bedrock.profile_name,
|
||||
bedrock.as_ref().map(|s| s.profile.clone()),
|
||||
);
|
||||
merge(
|
||||
&mut settings.bedrock.authentication_method,
|
||||
bedrock.as_ref().map(|s| s.authentication_method.clone()),
|
||||
);
|
||||
merge(
|
||||
&mut settings.bedrock.region,
|
||||
bedrock.as_ref().map(|s| s.region.clone()),
|
||||
);
|
||||
merge(
|
||||
&mut settings.bedrock.endpoint,
|
||||
bedrock.as_ref().map(|s| s.endpoint_url.clone()),
|
||||
);
|
||||
|
||||
// Ollama
|
||||
let ollama = value.ollama.clone();
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue