diff --git a/Cargo.lock b/Cargo.lock index 3a1fbe5e17..52a3417047 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2141,9 +2141,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34a796731680be7931955498a16a10b2270c7762963d5d570fdbfe02dcbf314f" +checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3" dependencies = [ "arrayref", "arrayvec", @@ -2455,9 +2455,9 @@ dependencies = [ [[package]] name = "cap-fs-ext" -version = "3.4.2" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f78efdd7378980d79c0f36b519e51191742d2c9f91ffa5e228fba9f3806d2e1" +checksum = "f6323b9baffb4d6d9c65bfef3129db62b1391f7fb953fe67c0d7cb0804feb77b" dependencies = [ "cap-primitives", "cap-std", @@ -2467,9 +2467,9 @@ dependencies = [ [[package]] name = "cap-net-ext" -version = "3.4.2" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac68674a6042af2bcee1adad9f6abd432642cf03444ce3a5b36c3f39f23baf8" +checksum = "66022e5e076ea27041a05ebd4349022e2133e6f4977974dffd54ceb7b982e871" dependencies = [ "cap-primitives", "cap-std", @@ -2479,9 +2479,9 @@ dependencies = [ [[package]] name = "cap-primitives" -version = "3.4.2" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc15faeed2223d8b8e8cc1857f5861935a06d06713c4ac106b722ae9ce3c369" +checksum = "50ad0183a9659850877cefe8f5b87d564b2dd1fe78b18945813687f29c0a6878" dependencies = [ "ambient-authority", "fs-set-times", @@ -2496,9 +2496,9 @@ dependencies = [ [[package]] name = "cap-rand" -version = "3.4.2" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dea13372b49df066d1ae654e5c6e41799c1efd9f6b36794b921e877ea4037977" +checksum = "ab78a9f6301e70c0fe5df7328adbcb9228277fdb7bab36f312fc072f505e38c2" dependencies = [ "ambient-authority", "rand 0.8.5", @@ -2506,9 +2506,9 @@ dependencies = [ [[package]] name = "cap-std" -version = "3.4.2" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3dbd3e8e8d093d6ccb4b512264869e1281cdb032f7940bd50b2894f96f25609" +checksum = "1c41814365b796ed12688026cb90a1e03236a84ccf009628f9c43c8aa3af250a" dependencies = [ "cap-primitives", "io-extras", @@ -2518,9 +2518,9 @@ dependencies = [ [[package]] name = "cap-time-ext" -version = "3.4.2" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd736b20fc033f564a1995fb82fc349146de43aabba19c7368b4cb17d8f9ea53" +checksum = "eb57b71bb69b97c638ec38b477e9b33fa1c1cff0e437dd55d15c117cf17ed5dc" dependencies = [ "ambient-authority", "cap-primitives", @@ -2598,9 +2598,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.17" +version = "1.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fcb57c740ae1daf453ae85f16e37396f672b039e00d9d866e07ddb24e328e3a" +checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c" dependencies = [ "jobserver", "libc", @@ -3926,9 +3926,9 @@ checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d" [[package]] name = "ctrlc" -version = "3.4.5" +version = "3.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" +checksum = "697b5419f348fd5ae2478e8018cb016c00a5881c7f46c717de98ffd135a5651c" dependencies = [ "nix", "windows-sys 0.59.0", @@ -4805,9 +4805,9 @@ dependencies = [ [[package]] name = "errno" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", "windows-sys 0.59.0", @@ -5252,9 +5252,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" dependencies = [ "crc32fast", "miniz_oxide", @@ -7891,9 +7891,9 @@ checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libmimalloc-sys" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07d0e07885d6a754b9c7993f2625187ad694ee985d60f23355ff0e7077261502" +checksum = "6b20daca3a4ac14dbdc753c5e90fc7b490a48a9131daed3c9a9ced7b2defd37b" dependencies = [ "cc", "libc", @@ -8562,9 +8562,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.44" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99585191385958383e13f6b822e6b6d8d9cf928e7d286ceb092da92b43c87bc1" +checksum = "03cb1f88093fe50061ca1195d336ffec131347c7b833db31f9ab62a2d1b7925f" dependencies = [ "libmimalloc-sys", ] @@ -8593,9 +8593,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.5" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" dependencies = [ "adler2", "simd-adler32", @@ -9474,9 +9474,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.71" +version = "0.10.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" dependencies = [ "bitflags 2.9.0", "cfg-if", @@ -9515,9 +9515,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.106" +version = "0.9.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" +checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" dependencies = [ "cc", "libc", @@ -12149,7 +12149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags 2.9.0", - "errno 0.3.10", + "errno 0.3.11", "itoa", "libc", "linux-raw-sys 0.4.15", @@ -12164,7 +12164,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf" dependencies = [ "bitflags 2.9.0", - "errno 0.3.10", + "errno 0.3.11", "libc", "linux-raw-sys 0.9.3", "windows-sys 0.59.0", @@ -12176,7 +12176,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a25c3aad9fc1424eb82c88087789a7d938e1829724f3e4043163baf0d13cfc12" dependencies = [ - "errno 0.3.10", + "errno 0.3.11", "libc", "rustix 0.38.44", ] @@ -13798,9 +13798,9 @@ dependencies = [ [[package]] name = "swash" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13d5bbc2aa266907ed8ee977c9c9e16363cc2b001266104e13397b57f1d15f71" +checksum = "fae9a562c7b46107d9c78cd78b75bbe1e991c16734c0aee8ff0ee711fb8b620a" dependencies = [ "skrifa", "yazi", diff --git a/Cargo.toml b/Cargo.toml index bd4b9bbd2f..8316b1e865 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -399,11 +399,11 @@ async-trait = "0.1" async-tungstenite = "0.28" async-watch = "0.3.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } -aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } -aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] } -aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] } -aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] } -aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] } +aws-config = { version = "1.6.1", features = ["behavior-version-latest"] } +aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] } +aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] } +aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] } +aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] } base64 = "0.22" bitflags = "2.6.0" blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" } diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs index df6701e474..5b34388a9a 100644 --- a/crates/bedrock/src/bedrock.rs +++ b/crates/bedrock/src/bedrock.rs @@ -1,21 +1,25 @@ mod models; +use std::collections::HashMap; use std::pin::Pin; -use anyhow::{Context, Error, Result, anyhow}; +use anyhow::{Error, Result, anyhow}; use aws_sdk_bedrockruntime as bedrock; pub use aws_sdk_bedrockruntime as bedrock_client; pub use aws_sdk_bedrockruntime::types::{ - ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool, - ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema, - ToolSpecification as BedrockTool, + AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent, + Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig, + ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec, }; use aws_smithy_types::{Document, Number as AwsNumber}; pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; pub use bedrock::types::{ ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole, ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse, - Message as BedrockMessage, ResponseStream as BedrockResponseStream, + ImageBlock as BedrockImageBlock, Message as BedrockMessage, + ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock, + ToolResultContentBlock as BedrockToolResultContentBlock, + ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock, }; use futures::stream::{self, BoxStream, Stream}; use serde::{Deserialize, Serialize}; @@ -24,25 +28,6 @@ use thiserror::Error; pub use crate::models::*; -pub async fn complete( - client: &bedrock::Client, - request: Request, -) -> Result { - let response = bedrock::Client::converse(client) - .model_id(request.model.clone()) - .set_messages(request.messages.into()) - .send() - .await - .context("failed to send request to Bedrock"); - - match response { - Ok(output) => output - .output - .ok_or_else(|| BedrockError::Other(anyhow!("no output"))), - Err(err) => Err(BedrockError::Other(err)), - } -} - pub async fn stream_completion( client: bedrock::Client, request: Request, @@ -50,11 +35,32 @@ pub async fn stream_completion( ) -> Result>, Error> { handle .spawn(async move { - let response = bedrock::Client::converse_stream(&client) + let mut response = bedrock::Client::converse_stream(&client) .model_id(request.model.clone()) - .set_messages(request.messages.into()) - .send() - .await; + .set_messages(request.messages.into()); + + if let Some(Thinking::Enabled { + budget_tokens: Some(budget_tokens), + }) = request.thinking + { + response = + response.additional_model_request_fields(Document::Object(HashMap::from([( + "thinking".to_string(), + Document::from(HashMap::from([ + ("type".to_string(), Document::String("enabled".to_string())), + ( + "budget_tokens".to_string(), + Document::Number(AwsNumber::PosInt(budget_tokens)), + ), + ])), + )]))); + } + + if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() { + response = response.set_tool_config(request.tools); + } + + let response = response.send().await; match response { Ok(output) => { @@ -65,7 +71,7 @@ pub async fn stream_completion( >, > = Box::pin(stream::unfold(output.stream, |mut stream| async move { match stream.recv().await { - Ok(Some(output)) => Some((Ok(output), stream)), + Ok(Some(output)) => Some(({ Ok(output) }, stream)), Ok(None) => None, Err(err) => { Some(( @@ -135,13 +141,18 @@ pub fn value_to_aws_document(value: &Value) -> Document { } } +#[derive(Debug, Serialize, Deserialize)] +pub enum Thinking { + Enabled { budget_tokens: Option }, +} + #[derive(Debug)] pub struct Request { pub model: String, pub max_tokens: u32, pub messages: Vec, - pub tools: Vec, - pub tool_choice: Option, + pub tools: Option, + pub thinking: Option, pub system: Option, pub metadata: Option, pub stop_sequences: Vec, diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs index 26c6d35dd2..052e5c2ca1 100644 --- a/crates/bedrock/src/models.rs +++ b/crates/bedrock/src/models.rs @@ -2,21 +2,38 @@ use anyhow::anyhow; use serde::{Deserialize, Serialize}; use strum::EnumIter; +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum BedrockModelMode { + #[default] + Default, + Thinking { + budget_tokens: Option, + }, +} + #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] pub enum Model { // Anthropic models (already included) #[default] #[serde(rename = "claude-3-5-sonnet-v2", alias = "claude-3-5-sonnet-latest")] - Claude3_5Sonnet, + Claude3_5SonnetV2, #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")] Claude3_7Sonnet, + #[serde( + rename = "claude-3-7-sonnet-thinking", + alias = "claude-3-7-sonnet-thinking-latest" + )] + Claude3_7SonnetThinking, #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] Claude3Opus, #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")] Claude3Sonnet, #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] Claude3_5Haiku, + Claude3_5Sonnet, + Claude3Haiku, // Amazon Nova Models AmazonNovaLite, AmazonNovaMicro, @@ -69,7 +86,7 @@ pub enum Model { impl Model { pub fn from_id(id: &str) -> anyhow::Result { if id.starts_with("claude-3-5-sonnet-v2") { - Ok(Self::Claude3_5Sonnet) + Ok(Self::Claude3_5SonnetV2) } else if id.starts_with("claude-3-opus") { Ok(Self::Claude3Opus) } else if id.starts_with("claude-3-sonnet") { @@ -78,6 +95,8 @@ impl Model { Ok(Self::Claude3_5Haiku) } else if id.starts_with("claude-3-7-sonnet") { Ok(Self::Claude3_7Sonnet) + } else if id.starts_with("claude-3-7-sonnet-thinking") { + Ok(Self::Claude3_7SonnetThinking) } else { Err(anyhow!("invalid model id")) } @@ -85,14 +104,18 @@ impl Model { pub fn id(&self) -> &str { match self { - Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0", - Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0", - Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0", - Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0", - Model::Claude3_7Sonnet => "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0", - Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0", - Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0", + Model::Claude3_5SonnetV2 => "anthropic.claude-3-5-sonnet-20241022-v2:0", + Model::Claude3_5Sonnet => "anthropic.claude-3-5-sonnet-20240620-v1:0", + Model::Claude3Opus => "anthropic.claude-3-opus-20240229-v1:0", + Model::Claude3Sonnet => "anthropic.claude-3-sonnet-20240229-v1:0", + Model::Claude3Haiku => "anthropic.claude-3-haiku-20240307-v1:0", + Model::Claude3_5Haiku => "anthropic.claude-3-5-haiku-20241022-v1:0", + Model::Claude3_7Sonnet | Model::Claude3_7SonnetThinking => { + "anthropic.claude-3-7-sonnet-20250219-v1:0" + } + Model::AmazonNovaLite => "amazon.nova-lite-v1:0", + Model::AmazonNovaMicro => "amazon.nova-micro-v1:0", + Model::AmazonNovaPro => "amazon.nova-pro-v1:0", Model::DeepSeekR1 => "us.deepseek.r1-v1:0", Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct", Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct", @@ -128,11 +151,14 @@ impl Model { pub fn display_name(&self) -> &str { match self { - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet v2", + Self::Claude3_5SonnetV2 => "Claude 3.5 Sonnet v2", + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", Self::Claude3Opus => "Claude 3 Opus", Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", Self::Claude3_5Haiku => "Claude 3.5 Haiku", Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", + Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking", Self::AmazonNovaLite => "Amazon Nova Lite", Self::AmazonNovaMicro => "Amazon Nova Micro", Self::AmazonNovaPro => "Amazon Nova Pro", @@ -173,7 +199,7 @@ impl Model { pub fn max_token_count(&self) -> usize { match self { - Self::Claude3_5Sonnet + Self::Claude3_5SonnetV2 | Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku @@ -186,7 +212,8 @@ impl Model { pub fn max_output_tokens(&self) -> u32 { match self { Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096, - Self::Claude3_5Sonnet => 8_192, + Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000, + Self::Claude3_5SonnetV2 => 8_192, Self::Custom { max_output_tokens, .. } => max_output_tokens.unwrap_or(4_096), @@ -196,7 +223,7 @@ impl Model { pub fn default_temperature(&self) -> f32 { match self { - Self::Claude3_5Sonnet + Self::Claude3_5SonnetV2 | Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku @@ -208,4 +235,253 @@ impl Model { _ => 1.0, } } + + pub fn supports_tool_use(&self) -> bool { + match self { + // Anthropic Claude 3 models (all support tool use) + Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3_5Sonnet + | Self::Claude3_5SonnetV2 + | Self::Claude3_7Sonnet + | Self::Claude3_7SonnetThinking + | Self::Claude3_5Haiku => true, + + // Amazon Nova models (all support tool use) + Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true, + + // AI21 Jamba 1.5 models support tool use + Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true, + + // Cohere Command R models support tool use + Self::CohereCommandRV1 | Self::CohereCommandRPlusV1 => true, + + // All other models don't support tool use + // Including Meta Llama 3.2, AI21 Jurassic, and others + _ => false, + } + } + + pub fn mode(&self) -> BedrockModelMode { + match self { + Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking { + budget_tokens: Some(4096), + }, + _ => BedrockModelMode::Default, + } + } + + pub fn cross_region_inference_id(&self, region: &str) -> Result { + let region_group = if region.starts_with("us-gov-") { + "us-gov" + } else if region.starts_with("us-") { + "us" + } else if region.starts_with("eu-") { + "eu" + } else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" { + "apac" + } else if region.starts_with("ca-") || region.starts_with("sa-") { + // Canada and South America regions - default to US profiles + "us" + } else { + // Unknown region + return Err(anyhow!("Unsupported Region")); + }; + + let model_id = self.id(); + + match (self, region_group) { + // Custom models can't have CRI IDs + (Model::Custom { .. }, _) => Ok(self.id().into()), + + // Models with US Gov only + (Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => { + Ok(format!("{}.{}", region_group, model_id)) + } + + // Models available only in US + (Model::Claude3Opus, "us") + | (Model::Claude3_7Sonnet, "us") + | (Model::Claude3_7SonnetThinking, "us") => { + Ok(format!("{}.{}", region_group, model_id)) + } + + // Models available in US, EU, and APAC + (Model::Claude3_5SonnetV2, "us") + | (Model::Claude3_5SonnetV2, "apac") + | (Model::Claude3_5Sonnet, _) + | (Model::Claude3Haiku, _) + | (Model::Claude3Sonnet, _) + | (Model::AmazonNovaLite, _) + | (Model::AmazonNovaMicro, _) + | (Model::AmazonNovaPro, _) => Ok(format!("{}.{}", region_group, model_id)), + + // Models with limited EU availability + (Model::MetaLlama321BInstructV1, "us") + | (Model::MetaLlama321BInstructV1, "eu") + | (Model::MetaLlama323BInstructV1, "us") + | (Model::MetaLlama323BInstructV1, "eu") => { + Ok(format!("{}.{}", region_group, model_id)) + } + + // US-only models (all remaining Meta models) + (Model::MetaLlama38BInstructV1, "us") + | (Model::MetaLlama370BInstructV1, "us") + | (Model::MetaLlama318BInstructV1, "us") + | (Model::MetaLlama318BInstructV1_128k, "us") + | (Model::MetaLlama3170BInstructV1, "us") + | (Model::MetaLlama3170BInstructV1_128k, "us") + | (Model::MetaLlama3211BInstructV1, "us") + | (Model::MetaLlama3290BInstructV1, "us") => { + Ok(format!("{}.{}", region_group, model_id)) + } + + // Any other combination is not supported + _ => Ok(self.id().into()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_us_region_inference_ids() -> anyhow::Result<()> { + // Test US regions + assert_eq!( + Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1")?, + "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + ); + assert_eq!( + Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2")?, + "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + ); + assert_eq!( + Model::AmazonNovaPro.cross_region_inference_id("us-east-2")?, + "us.amazon.nova-pro-v1:0" + ); + Ok(()) + } + + #[test] + fn test_eu_region_inference_ids() -> anyhow::Result<()> { + // Test European regions + assert_eq!( + Model::Claude3Sonnet.cross_region_inference_id("eu-west-1")?, + "eu.anthropic.claude-3-sonnet-20240229-v1:0" + ); + assert_eq!( + Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1")?, + "eu.amazon.nova-micro-v1:0" + ); + Ok(()) + } + + #[test] + fn test_apac_region_inference_ids() -> anyhow::Result<()> { + // Test Asia-Pacific regions + assert_eq!( + Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1")?, + "apac.anthropic.claude-3-5-sonnet-20241022-v2:0" + ); + assert_eq!( + Model::AmazonNovaLite.cross_region_inference_id("ap-south-1")?, + "apac.amazon.nova-lite-v1:0" + ); + Ok(()) + } + + #[test] + fn test_gov_region_inference_ids() -> anyhow::Result<()> { + // Test Government regions + assert_eq!( + Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1")?, + "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0" + ); + assert_eq!( + Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1")?, + "us-gov.anthropic.claude-3-haiku-20240307-v1:0" + ); + Ok(()) + } + + #[test] + fn test_meta_models_inference_ids() -> anyhow::Result<()> { + // Test Meta models + assert_eq!( + Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1")?, + "us.meta.llama3-70b-instruct-v1:0" + ); + assert_eq!( + Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1")?, + "eu.meta.llama3-2-1b-instruct-v1:0" + ); + Ok(()) + } + + #[test] + fn test_mistral_models_inference_ids() -> anyhow::Result<()> { + // Mistral models don't follow the regional prefix pattern, + // so they should return their original IDs + assert_eq!( + Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1")?, + "mistral.mistral-large-2402-v1:0" + ); + assert_eq!( + Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1")?, + "mistral.mixtral-8x7b-instruct-v0:1" + ); + Ok(()) + } + + #[test] + fn test_ai21_models_inference_ids() -> anyhow::Result<()> { + // AI21 models don't follow the regional prefix pattern, + // so they should return their original IDs + assert_eq!( + Model::AI21J2UltraV1.cross_region_inference_id("us-east-1")?, + "ai21.j2-ultra-v1" + ); + assert_eq!( + Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1")?, + "ai21.jamba-instruct-v1:0" + ); + Ok(()) + } + + #[test] + fn test_cohere_models_inference_ids() -> anyhow::Result<()> { + // Cohere models don't follow the regional prefix pattern, + // so they should return their original IDs + assert_eq!( + Model::CohereCommandRV1.cross_region_inference_id("us-east-1")?, + "cohere.command-r-v1:0" + ); + assert_eq!( + Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1")?, + "cohere.command-text-v14:7:4k" + ); + Ok(()) + } + + #[test] + fn test_custom_model_inference_ids() -> anyhow::Result<()> { + // Test custom models + let custom_model = Model::Custom { + name: "custom.my-model-v1:0".to_string(), + max_tokens: 100000, + display_name: Some("My Custom Model".to_string()), + max_output_tokens: Some(8192), + default_temperature: Some(0.7), + }; + + // Custom model should return its name unchanged + assert_eq!( + custom_model.cross_region_inference_id("us-east-1")?, + "custom.my-model-v1:0" + ); + + Ok(()) + } } diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 14bd6b436d..aa603534e0 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -4,19 +4,29 @@ use std::sync::Arc; use crate::ui::InstructionListItem; use anyhow::{Context as _, Result, anyhow}; -use aws_config::Region; use aws_config::stalled_stream_protection::StalledStreamProtectionConfig; +use aws_config::{BehaviorVersion, Region}; use aws_credential_types::Credentials; use aws_http_client::AwsHttpClient; -use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ConverseStreamOutput}; -use bedrock::bedrock_client::{self, Config}; -use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model}; +use bedrock::bedrock_client::Client as BedrockClient; +use bedrock::bedrock_client::config::timeout::TimeoutConfig; +use bedrock::bedrock_client::types::{ + ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta, + StopReason, +}; +use bedrock::{ + BedrockAutoToolChoice, BedrockError, BedrockInnerContent, BedrockMessage, BedrockModelMode, + BedrockStreamingResponse, BedrockTool, BedrockToolChoice, BedrockToolConfig, + BedrockToolInputSchema, BedrockToolResultBlock, BedrockToolResultContentBlock, + BedrockToolResultStatus, BedrockToolSpec, BedrockToolUseBlock, Model, value_to_aws_document, +}; use collections::{BTreeMap, HashMap}; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, FontWeight, Subscription, Task, TextStyle, + WhiteSpace, }; use gpui_tokio::Tokio; use http_client::HttpClient; @@ -24,17 +34,18 @@ use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, + LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; use settings::{Settings, SettingsStore}; -use strum::IntoEnumIterator; +use smol::lock::OnceCell; +use strum::{EnumIter, IntoEnumIterator, IntoStaticStr}; use theme::ThemeSettings; use tokio::runtime::Handle; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::{ResultExt, maybe}; +use util::{ResultExt, default}; use crate::AllLanguageModelSettings; @@ -43,15 +54,33 @@ const PROVIDER_NAME: &str = "Amazon Bedrock"; #[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)] pub struct BedrockCredentials { - pub region: String, pub access_key_id: String, pub secret_access_key: String, + pub session_token: Option, + pub region: String, } #[derive(Default, Clone, Debug, PartialEq)] pub struct AmazonBedrockSettings { - pub session_token: Option, pub available_models: Vec, + pub region: Option, + pub endpoint: Option, + pub profile_name: Option, + pub role_arn: Option, + pub authentication_method: Option, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, EnumIter, IntoStaticStr, JsonSchema)] +pub enum BedrockAuthMethod { + #[serde(rename = "named_profile")] + NamedProfile, + #[serde(rename = "static_credentials")] + StaticCredentials, + #[serde(rename = "sso")] + SingleSignOn, + /// IMDSv2, PodIdentity, env vars, etc. + #[serde(rename = "default")] + Automatic, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -62,6 +91,36 @@ pub struct AvailableModel { pub cache_configuration: Option, pub max_output_tokens: Option, pub default_temperature: Option, + pub mode: Option, +} + +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ModelMode { + #[default] + Default, + Thinking { + /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`. + budget_tokens: Option, + }, +} + +impl From for BedrockModelMode { + fn from(value: ModelMode) -> Self { + match value { + ModelMode::Default => BedrockModelMode::Default, + ModelMode::Thinking { budget_tokens } => BedrockModelMode::Thinking { budget_tokens }, + } + } +} + +impl From for ModelMode { + fn from(value: BedrockModelMode) -> Self { + match value { + BedrockModelMode::Default => ModelMode::Default, + BedrockModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens }, + } + } } /// The URL of the base AWS service. @@ -73,11 +132,15 @@ const AMAZON_AWS_URL: &str = "https://amazonaws.com"; // These environment variables all use a `ZED_` prefix because we don't want to overwrite the user's AWS credentials. const ZED_BEDROCK_ACCESS_KEY_ID_VAR: &str = "ZED_ACCESS_KEY_ID"; const ZED_BEDROCK_SECRET_ACCESS_KEY_VAR: &str = "ZED_SECRET_ACCESS_KEY"; +const ZED_BEDROCK_SESSION_TOKEN_VAR: &str = "ZED_SESSION_TOKEN"; +const ZED_AWS_PROFILE_VAR: &str = "ZED_AWS_PROFILE"; const ZED_BEDROCK_REGION_VAR: &str = "ZED_AWS_REGION"; const ZED_AWS_CREDENTIALS_VAR: &str = "ZED_AWS_CREDENTIALS"; +const ZED_AWS_ENDPOINT_VAR: &str = "ZED_AWS_ENDPOINT"; pub struct State { credentials: Option, + settings: Option, credentials_from_env: bool, _subscription: Subscription, } @@ -93,6 +156,7 @@ impl State { this.update(cx, |this, cx| { this.credentials = None; this.credentials_from_env = false; + this.settings = None; cx.notify(); }) }) @@ -120,12 +184,47 @@ impl State { }) } - fn is_authenticated(&self) -> bool { - self.credentials.is_some() + fn is_authenticated(&self) -> Option { + match self + .settings + .as_ref() + .and_then(|s| s.authentication_method.as_ref()) + { + Some(BedrockAuthMethod::StaticCredentials) => Some(String::from( + "You are authenticated using Static Credentials.", + )), + Some(BedrockAuthMethod::NamedProfile) | Some(BedrockAuthMethod::SingleSignOn) => { + match self.settings.as_ref() { + None => Some(String::from( + "You are authenticated using a Named Profile, but no profile is set.", + )), + Some(settings) => match settings.clone().profile_name { + None => Some(String::from( + "You are authenticated using a Named Profile, but no profile is set.", + )), + Some(profile_name) => Some(format!( + "You are authenticated using a Named Profile: {profile_name}", + )), + }, + } + } + Some(BedrockAuthMethod::Automatic) => Some(String::from( + "You are authenticated using Automatic Credentials.", + )), + None => { + if self.credentials.is_some() { + Some(String::from( + "You are authenticated using Static Credentials.", + )) + } else { + None + } + } + } } fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { + if self.is_authenticated().is_some() { return Task::ready(Ok(())); } @@ -170,6 +269,7 @@ impl BedrockLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { let state = cx.new(|cx| State { credentials: None, + settings: Some(AllLanguageModelSettings::get_global(cx).bedrock.clone()), credentials_from_env: false, _subscription: cx.observe_global::(|_, cx| { cx.notify(); @@ -209,6 +309,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { http_client: self.http_client.clone(), handler: self.handler.clone(), state: self.state.clone(), + client: OnceCell::new(), request_limiter: RateLimiter::new(4), })) } @@ -249,6 +350,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { http_client: self.http_client.clone(), handler: self.handler.clone(), state: self.state.clone(), + client: OnceCell::new(), request_limiter: RateLimiter::new(4), }) as Arc }) @@ -256,7 +358,7 @@ impl LanguageModelProvider for BedrockLanguageModelProvider { } fn is_authenticated(&self, cx: &App) -> bool { - self.state.read(cx).is_authenticated() + self.state.read(cx).is_authenticated().is_some() } fn authenticate(&self, cx: &mut App) -> Task> { @@ -287,11 +389,94 @@ struct BedrockModel { model: Model, http_client: AwsHttpClient, handler: tokio::runtime::Handle, + client: OnceCell, state: gpui::Entity, request_limiter: RateLimiter, } impl BedrockModel { + fn get_or_init_client(&self, cx: &AsyncApp) -> Result<&BedrockClient, anyhow::Error> { + self.client + .get_or_try_init_blocking(|| { + let Ok((auth_method, credentials, endpoint, region, settings)) = + cx.read_entity(&self.state, |state, _cx| { + let auth_method = state + .settings + .as_ref() + .and_then(|s| s.authentication_method.clone()) + .unwrap_or(BedrockAuthMethod::Automatic); + + let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone()); + + let region = state + .settings + .as_ref() + .and_then(|s| s.region.clone()) + .unwrap_or(String::from("us-east-1")); + + ( + auth_method, + state.credentials.clone(), + endpoint, + region, + state.settings.clone(), + ) + }) + else { + return Err(anyhow!("App state dropped")); + }; + + let mut config_builder = aws_config::defaults(BehaviorVersion::latest()) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) + .http_client(self.http_client.clone()) + .region(Region::new(region)) + .timeout_config(TimeoutConfig::disabled()); + + if let Some(endpoint_url) = endpoint { + if !endpoint_url.is_empty() { + config_builder = config_builder.endpoint_url(endpoint_url); + } + } + + match auth_method { + BedrockAuthMethod::StaticCredentials => { + if let Some(creds) = credentials { + let aws_creds = Credentials::new( + creds.access_key_id, + creds.secret_access_key, + creds.session_token, + None, + "zed-bedrock-provider", + ); + config_builder = config_builder.credentials_provider(aws_creds); + } + } + BedrockAuthMethod::NamedProfile | BedrockAuthMethod::SingleSignOn => { + // Currently NamedProfile and SSO behave the same way but only the instructions change + // Until we support BearerAuth through SSO, this will not change. + let profile_name = settings + .and_then(|s| s.profile_name) + .unwrap_or_else(|| "default".to_string()); + + if !profile_name.is_empty() { + config_builder = config_builder.profile_name(profile_name); + } + } + BedrockAuthMethod::Automatic => { + // Use default credential provider chain + } + } + + let config = self.handler.block_on(config_builder.load()); + Ok(BedrockClient::new(&config)) + }) + .map_err(|err| anyhow!("Failed to initialize Bedrock client: {err}"))?; + + self.client + .get() + .ok_or_else(|| anyhow!("Bedrock client not initialized")) + } + fn stream_completion( &self, request: bedrock::Request, @@ -299,37 +484,10 @@ impl BedrockModel { ) -> Result< BoxFuture<'static, BoxStream<'static, Result>>, > { - let Ok(Ok((access_key_id, secret_access_key, region))) = - cx.read_entity(&self.state, |state, _cx| { - if let Some(credentials) = &state.credentials { - Ok(( - credentials.access_key_id.clone(), - credentials.secret_access_key.clone(), - credentials.region.clone(), - )) - } else { - return Err(anyhow!("Failed to read credentials")); - } - }) - else { - return Err(anyhow!("App state dropped")); - }; - - let runtime_client = bedrock_client::Client::from_conf( - Config::builder() - .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) - .credentials_provider(Credentials::new( - access_key_id, - secret_access_key, - None, - None, - "Keychain", - )) - .region(Region::new(region)) - .http_client(self.http_client.clone()) - .build(), - ); - + let runtime_client = self + .get_or_init_client(cx) + .cloned() + .context("Bedrock client not initialized")?; let owned_handle = self.handler.clone(); Ok(async move { @@ -360,7 +518,7 @@ impl LanguageModel for BedrockModel { } fn supports_tools(&self) -> bool { - true + self.model.supports_tool_use() } fn telemetry_id(&self) -> String { @@ -388,12 +546,36 @@ impl LanguageModel for BedrockModel { request: LanguageModelRequest, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { - let request = into_bedrock( + let Ok(region) = cx.read_entity(&self.state, |state, _cx| { + // Get region - from credentials or directly from settings + let region = state + .credentials + .as_ref() + .map(|s| s.region.clone()) + .unwrap_or(String::from("us-east-1")); + + region + }) else { + return async move { Err(anyhow!("App State Dropped")) }.boxed(); + }; + + let model_id = match self.model.cross_region_inference_id(®ion) { + Ok(s) => s, + Err(e) => { + return async move { Err(e) }.boxed(); + } + }; + + let request = match into_bedrock( request, - self.model.id().into(), + model_id, self.model.default_temperature(), self.model.max_output_tokens(), - ); + self.model.mode(), + ) { + Ok(request) => request, + Err(err) => return futures::future::ready(Err(err)).boxed(), + }; let owned_handle = self.handler.clone(); @@ -418,7 +600,8 @@ pub fn into_bedrock( model: String, default_temperature: f32, max_output_tokens: u32, -) -> bedrock::Request { + mode: BedrockModelMode, +) -> Result { let mut new_messages: Vec = Vec::new(); let mut system_message = String::new(); @@ -440,6 +623,32 @@ pub fn into_bedrock( None } } + MessageContent::ToolUse(tool_use) => BedrockToolUseBlock::builder() + .name(tool_use.name.to_string()) + .tool_use_id(tool_use.id.to_string()) + .input(value_to_aws_document(&tool_use.input)) + .build() + .context("failed to build Bedrock tool use block") + .log_err() + .map(BedrockInnerContent::ToolUse), + MessageContent::ToolResult(tool_result) => { + BedrockToolResultBlock::builder() + .tool_use_id(tool_result.tool_use_id.to_string()) + .content(BedrockToolResultContentBlock::Text( + tool_result.content.to_string(), + )) + .status({ + if tool_result.is_error { + BedrockToolResultStatus::Error + } else { + BedrockToolResultStatus::Success + } + }) + .build() + .context("failed to build Bedrock tool result block") + .log_err() + .map(BedrockInnerContent::ToolResult) + } _ => None, }) .collect(); @@ -459,7 +668,7 @@ pub fn into_bedrock( .role(bedrock_role) .set_content(Some(bedrock_message_content)) .build() - .expect("failed to build Bedrock message"), + .context("failed to build Bedrock message")?, ); } Role::System => { @@ -471,19 +680,47 @@ pub fn into_bedrock( } } - bedrock::Request { + let tool_spec: Vec = request + .tools + .iter() + .filter_map(|tool| { + Some(BedrockTool::ToolSpec( + BedrockToolSpec::builder() + .name(tool.name.clone()) + .description(tool.description.clone()) + .input_schema(BedrockToolInputSchema::Json(value_to_aws_document( + &tool.input_schema, + ))) + .build() + .log_err()?, + )) + }) + .collect(); + + let tool_config: BedrockToolConfig = BedrockToolConfig::builder() + .set_tools(Some(tool_spec)) + .tool_choice(BedrockToolChoice::Auto( + BedrockAutoToolChoice::builder().build(), + )) + .build()?; + + Ok(bedrock::Request { model, messages: new_messages, max_tokens: max_output_tokens, system: Some(system_message), - tools: vec![], - tool_choice: None, + tools: Some(tool_config), + thinking: if let BedrockModelMode::Thinking { budget_tokens } = mode { + Some(bedrock::Thinking::Enabled { budget_tokens }) + } else { + None + }, metadata: None, stop_sequences: Vec::new(), temperature: request.temperature.or(Some(default_temperature)), top_k: None, top_p: None, - } + }) } // TODO: just call the ConverseOutput.usage() method: @@ -571,48 +808,72 @@ pub fn map_to_language_model_completion_events( match event { Ok(event) => match event { ConverseStreamOutput::ContentBlockDelta(cb_delta) => { - if let Some(ContentBlockDelta::Text(text_out)) = - cb_delta.delta - { - return Some(( - Some(Ok(LanguageModelCompletionEvent::Text( - text_out, - ))), - state, - )); - } else if let Some(ContentBlockDelta::ToolUse(text_out)) = - cb_delta.delta - { - if let Some(tool_use) = state - .tool_uses_by_index - .get_mut(&cb_delta.content_block_index) - { - tool_use.input_json.push_str(text_out.input()); - return Some((None, state)); - }; + match cb_delta.delta { + Some(ContentBlockDelta::Text(text_out)) => { + let completion_event = + LanguageModelCompletionEvent::Text(text_out); + return Some((Some(Ok(completion_event)), state)); + } - return Some((None, state)); - } else if cb_delta.delta.is_none() { - return Some((None, state)); + Some(ContentBlockDelta::ToolUse(text_out)) => { + if let Some(tool_use) = state + .tool_uses_by_index + .get_mut(&cb_delta.content_block_index) + { + tool_use.input_json.push_str(text_out.input()); + } + } + + Some(ContentBlockDelta::ReasoningContent(thinking)) => { + match thinking { + ReasoningContentBlockDelta::RedactedContent( + redacted, + ) => { + let thinking_event = + LanguageModelCompletionEvent::Thinking( + String::from_utf8( + redacted.into_inner(), + ) + .unwrap_or("REDACTED".to_string()), + ); + + return Some(( + Some(Ok(thinking_event)), + state, + )); + } + ReasoningContentBlockDelta::Signature(_sig) => { + } + ReasoningContentBlockDelta::Text(thoughts) => { + let thinking_event = + LanguageModelCompletionEvent::Thinking( + thoughts.to_string(), + ); + + return Some(( + Some(Ok(thinking_event)), + state, + )); + } + _ => {} + } + } + _ => {} } } ConverseStreamOutput::ContentBlockStart(cb_start) => { - if let Some(start) = cb_start.start { - match start { - ContentBlockStart::ToolUse(text_out) => { - let tool_use = RawToolUse { - id: text_out.tool_use_id, - name: text_out.name, - input_json: String::new(), - }; + if let Some(ContentBlockStart::ToolUse(text_out)) = + cb_start.start + { + let tool_use = RawToolUse { + id: text_out.tool_use_id, + name: text_out.name, + input_json: String::new(), + }; - state.tool_uses_by_index.insert( - cb_start.content_block_index, - tool_use, - ); - } - _ => {} - } + state + .tool_uses_by_index + .insert(cb_start.content_block_index, tool_use); } } ConverseStreamOutput::ContentBlockStop(cb_stop) => { @@ -620,30 +881,85 @@ pub fn map_to_language_model_completion_events( .tool_uses_by_index .remove(&cb_stop.content_block_index) { + let tool_use_event = LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + input: if tool_use.input_json.is_empty() { + Value::Null + } else { + serde_json::Value::from_str( + &tool_use.input_json, + ) + .map_err(|err| anyhow!(err)) + .unwrap() + }, + }; + return Some(( - Some(maybe!({ - Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - input: if tool_use.input_json.is_empty() - { - Value::Null - } else { - serde_json::Value::from_str( - &tool_use.input_json, - ) - .map_err(|err| anyhow!(err))? - }, - }, - )) - })), + Some(Ok(LanguageModelCompletionEvent::ToolUse( + tool_use_event, + ))), state, )); } } + + ConverseStreamOutput::Metadata(cb_meta) => { + if let Some(metadata) = cb_meta.usage { + let completion_event = + LanguageModelCompletionEvent::UsageUpdate( + TokenUsage { + input_tokens: metadata.input_tokens as u32, + output_tokens: metadata.output_tokens + as u32, + cache_creation_input_tokens: default(), + cache_read_input_tokens: default(), + }, + ); + return Some((Some(Ok(completion_event)), state)); + } + } + ConverseStreamOutput::MessageStop(message_stop) => { + let reason = match message_stop.stop_reason { + StopReason::ContentFiltered => { + LanguageModelCompletionEvent::Stop( + language_model::StopReason::EndTurn, + ) + } + StopReason::EndTurn => { + LanguageModelCompletionEvent::Stop( + language_model::StopReason::EndTurn, + ) + } + StopReason::GuardrailIntervened => { + LanguageModelCompletionEvent::Stop( + language_model::StopReason::EndTurn, + ) + } + StopReason::MaxTokens => { + LanguageModelCompletionEvent::Stop( + language_model::StopReason::EndTurn, + ) + } + StopReason::StopSequence => { + LanguageModelCompletionEvent::Stop( + language_model::StopReason::EndTurn, + ) + } + StopReason::ToolUse => { + LanguageModelCompletionEvent::Stop( + language_model::StopReason::ToolUse, + ) + } + _ => LanguageModelCompletionEvent::Stop( + language_model::StopReason::EndTurn, + ), + }; + return Some((Some(Ok(reason)), state)); + } _ => {} }, + Err(err) => return Some((Some(Err(anyhow!(err))), state)), } } @@ -661,6 +977,7 @@ pub fn map_to_language_model_completion_events( struct ConfigurationView { access_key_id_editor: Entity, secret_access_key_editor: Entity, + session_token_editor: Entity, region_editor: Entity, state: gpui::Entity, load_credentials_task: Option>, @@ -670,6 +987,7 @@ impl ConfigurationView { const PLACEHOLDER_ACCESS_KEY_ID_TEXT: &'static str = "XXXXXXXXXXXXXXXX"; const PLACEHOLDER_SECRET_ACCESS_KEY_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"; + const PLACEHOLDER_SESSION_TOKEN_TEXT: &'static str = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"; const PLACEHOLDER_REGION: &'static str = "us-east-1"; fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { @@ -707,6 +1025,11 @@ impl ConfigurationView { editor.set_placeholder_text(Self::PLACEHOLDER_SECRET_ACCESS_KEY_TEXT, cx); editor }), + session_token_editor: cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text(Self::PLACEHOLDER_SESSION_TOKEN_TEXT, cx); + editor + }), region_editor: cx.new(|cx| { let mut editor = Editor::single_line(window, cx); editor.set_placeholder_text(Self::PLACEHOLDER_REGION, cx); @@ -737,6 +1060,18 @@ impl ConfigurationView { .to_string() .trim() .to_string(); + let session_token = self + .session_token_editor + .read(cx) + .text(cx) + .to_string() + .trim() + .to_string(); + let session_token = if session_token.is_empty() { + None + } else { + Some(session_token) + }; let region = self .region_editor .read(cx) @@ -744,15 +1079,21 @@ impl ConfigurationView { .to_string() .trim() .to_string(); + let region = if region.is_empty() { + "us-east-1".to_string() + } else { + region + }; let state = self.state.clone(); cx.spawn(async move |_, cx| { state .update(cx, |state, cx| { let credentials: BedrockCredentials = BedrockCredentials { + region: region.clone(), access_key_id: access_key_id.clone(), secret_access_key: secret_access_key.clone(), - region: region.clone(), + session_token: session_token.clone(), }; state.set_credentials(credentials, cx) @@ -767,6 +1108,8 @@ impl ConfigurationView { .update(cx, |editor, cx| editor.set_text("", window, cx)); self.secret_access_key_editor .update(cx, |editor, cx| editor.set_text("", window, cx)); + self.session_token_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); self.region_editor .update(cx, |editor, cx| editor.set_text("", window, cx)); @@ -800,7 +1143,102 @@ impl ConfigurationView { } } - fn render_aa_id_editor(&self, cx: &mut Context) -> impl IntoElement { + fn make_input_styles(&self, cx: &Context) -> Div { + let bg_color = cx.theme().colors().editor_background; + let border_color = cx.theme().colors().border_variant; + + h_flex() + .w_full() + .px_2() + .py_1() + .bg(bg_color) + .border_1() + .border_color(border_color) + .rounded_sm() + } + + fn should_render_editor(&self, cx: &mut Context) -> Option { + self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).credentials_from_env; + let creds_type = self.should_render_editor(cx).is_some(); + + if self.load_credentials_task.is_some() { + return div().child(Label::new("Loading credentials...")).into_any(); + } + + if let Some(auth) = self.should_render_editor(cx) { + return h_flex() + .size_full() + .justify_between() + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.") + } else { + auth.clone() + })), + ) + .child( + Button::new("reset-key", "Reset key") + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .disabled(env_var_set || creds_type) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables."))) + }) + .when(creds_type, |this| { + this.tooltip(Tooltip::text("You cannot reset credentials as they're being derived, check Zed settings to understand how")) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))), + ) + .into_any(); + } + + v_flex() + .size_full() + .on_action(cx.listener(ConfigurationView::save_credentials)) + .child(Label::new("To use Zed's assistant with Bedrock, you can set a custom authentication strategy through the settings.json, or use static credentials.")) + .child(Label::new("Though to access models on AWS first, you will have to: ")) + .child( + List::new() + .child( + InstructionListItem::new( + "Grant permissions to the strategy you plan to use according to this documentation: ", + Some("Prerequisites"), + Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"), + ) + ) + .child( + InstructionListItem::new( + "Select the models you would like access to: ", + Some("Bedrock Model Catalog"), + Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"), + ) + ) + ) + .child(self.render_static_credentials_ui(cx)) + .child(self.render_common_fields(cx)) + .child( + Label::new( + format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR} AND {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed.\n Optionally, if your environment uses AWS CLI profiles, you can set {ZED_AWS_PROFILE_VAR}; if it requires a custom endpoint, you can set {ZED_AWS_ENDPOINT_VAR}; and if it requires a Session Token, you can set {ZED_BEDROCK_SESSION_TOKEN_VAR}."), + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any() + } +} + +impl ConfigurationView { + fn render_access_key_id_editor(&self, cx: &mut Context) -> impl IntoElement { let text_style = self.make_text_style(cx); EditorElement::new( @@ -814,7 +1252,7 @@ impl ConfigurationView { ) } - fn render_sk_editor(&self, cx: &mut Context) -> impl IntoElement { + fn render_secret_key_editor(&self, cx: &mut Context) -> impl IntoElement { let text_style = self.make_text_style(cx); EditorElement::new( @@ -828,6 +1266,20 @@ impl ConfigurationView { ) } + fn render_session_token_editor(&self, cx: &mut Context) -> impl IntoElement { + let text_style = self.make_text_style(cx); + + EditorElement::new( + &self.session_token_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + fn render_region_editor(&self, cx: &mut Context) -> impl IntoElement { let text_style = self.make_text_style(cx); @@ -842,124 +1294,80 @@ impl ConfigurationView { ) } - fn should_render_editor(&self, cx: &mut Context) -> bool { - !self.state.read(cx).is_authenticated() - } -} - -impl Render for ConfigurationView { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).credentials_from_env; - let bg_color = cx.theme().colors().editor_background; - let border_color = cx.theme().colors().border_variant; - let input_base_styles = || { - h_flex() - .w_full() - .px_2() - .py_1() - .bg(bg_color) - .border_1() - .border_color(border_color) - .rounded_sm() - }; - - if self.load_credentials_task.is_some() { - div().child(Label::new("Loading credentials...")).into_any() - } else if self.should_render_editor(cx) { - v_flex() - .size_full() - .on_action(cx.listener(ConfigurationView::save_credentials)) - .child(Label::new("To use Zed's assistant with Bedrock, you need to add the Access Key ID, Secret Access Key and AWS Region. Follow these steps:")) - .child( - List::new() - .child( - InstructionListItem::new( - "Start by", - Some("creating a user and security credentials"), - Some("https://us-east-1.console.aws.amazon.com/iam/home") - ) - ) - .child( - InstructionListItem::new( - "Grant that user permissions according to this documentation:", - Some("Prerequisites"), - Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html") - ) - ) - .child( - InstructionListItem::new( - "Select the models you would like access to:", - Some("Bedrock Model Catalog"), - Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess") - ) - ) - .child( - InstructionListItem::text_only("Fill the fields below and hit enter to start using the assistant") - ) - ) - .child( - v_flex() - .my_2() - .gap_1p5() - .child( - v_flex() - .gap_0p5() - .child(Label::new("Access Key ID").size(LabelSize::Small)) - .child( - input_base_styles().child(self.render_aa_id_editor(cx)) - ) - ) - .child( - v_flex() - .gap_0p5() - .child(Label::new("Secret Access Key").size(LabelSize::Small)) - .child( - input_base_styles().child(self.render_sk_editor(cx)) - ) - ) - .child( - v_flex() - .gap_0p5() - .child(Label::new("Region").size(LabelSize::Small)) - .child( - input_base_styles().child(self.render_region_editor(cx)) - ) - ) - ) - .child( - Label::new( - format!("You can also assign the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables and restart Zed."), - ) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any() - } else { - h_flex() - .size_full() - .justify_between() - .child( - h_flex() - .gap_1() - .child(Icon::new(IconName::Check).color(Color::Success)) - .child(Label::new(if env_var_set { - format!("Access Key ID is set in {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, Secret Key is set in {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, Region is set in {ZED_BEDROCK_REGION_VAR} environment variables.") - } else { - "Credentials configured.".to_string() - })), - ) - .child( - Button::new("reset-key", "Reset key") - .icon(Some(IconName::Trash)) - .icon_size(IconSize::Small) - .icon_position(IconPosition::Start) - .disabled(env_var_set) - .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your credentials, unset the {ZED_BEDROCK_ACCESS_KEY_ID_VAR}, {ZED_BEDROCK_SECRET_ACCESS_KEY_VAR}, and {ZED_BEDROCK_REGION_VAR} environment variables."))) - }) - .on_click(cx.listener(|this, _, window, cx| this.reset_credentials(window, cx))), - ) - .into_any() - } + fn render_static_credentials_ui(&self, cx: &mut Context) -> AnyElement { + v_flex() + .my_2() + .gap_1p5() + .child( + Label::new("Static Keys") + .size(LabelSize::Default) + .weight(FontWeight::BOLD), + ) + .child( + Label::new( + "This method uses your AWS access key ID and secret access key directly.", + ) + .size(LabelSize::Small), + ) + .child( + List::new() + .child(InstructionListItem::new( + "Create an IAM user in the AWS console with programmatic access", + Some("IAM Console"), + Some("https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"), + )) + .child(InstructionListItem::new( + "Attach the necessary Bedrock permissions to this ", + Some("user"), + Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"), + )) + .child(InstructionListItem::text_only( + "Copy the access key ID and secret access key when provided", + )) + .child(InstructionListItem::text_only( + "Enter these credentials below", + )), + ) + .child( + v_flex() + .gap_0p5() + .child(Label::new("Access Key ID").size(LabelSize::Small)) + .child( + self.make_input_styles(cx) + .child(self.render_access_key_id_editor(cx)), + ), + ) + .child( + v_flex() + .gap_0p5() + .child(Label::new("Secret Access Key").size(LabelSize::Small)) + .child(self.make_input_styles(cx).child(self.render_secret_key_editor(cx))), + ) + .child( + v_flex() + .gap_0p5() + .child(Label::new("Session Token (Optional)").size(LabelSize::Small)) + .child( + self.make_input_styles(cx) + .child(self.render_session_token_editor(cx)), + ), + ) + .into_any_element() + } + + fn render_common_fields(&self, cx: &mut Context) -> AnyElement { + v_flex() + .my_2() + .gap_1p5() + .child( + v_flex() + .gap_0p5() + .child(Label::new("Region").size(LabelSize::Small)) + .child( + self.make_input_styles(cx) + .child(self.render_region_editor(cx)), + ), + ) + .into_any_element() } } diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index 68833d8d63..9ac058f3c9 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -72,6 +72,7 @@ pub struct AllLanguageModelSettings { #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct AllLanguageModelSettingsContent { pub anthropic: Option, + pub bedrock: Option, pub ollama: Option, pub lmstudio: Option, pub openai: Option, @@ -160,6 +161,15 @@ pub struct AnthropicSettingsContentV1 { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct AmazonBedrockSettingsContent { + available_models: Option>, + endpoint_url: Option, + region: Option, + profile: Option, + authentication_method: Option, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct OllamaSettingsContent { pub api_url: Option, @@ -297,6 +307,25 @@ impl settings::Settings for AllLanguageModelSettings { anthropic.as_ref().and_then(|s| s.available_models.clone()), ); + // Bedrock + let bedrock = value.bedrock.clone(); + merge( + &mut settings.bedrock.profile_name, + bedrock.as_ref().map(|s| s.profile.clone()), + ); + merge( + &mut settings.bedrock.authentication_method, + bedrock.as_ref().map(|s| s.authentication_method.clone()), + ); + merge( + &mut settings.bedrock.region, + bedrock.as_ref().map(|s| s.region.clone()), + ); + merge( + &mut settings.bedrock.endpoint, + bedrock.as_ref().map(|s| s.endpoint_url.clone()), + ); + // Ollama let ollama = value.ollama.clone();