Add aws_http_client and bedrock crates (#25490)

This PR adds new `aws_http_client` and `bedrock` crates for supporting
AWS Bedrock.

Pulling out of https://github.com/zed-industries/zed/pull/21092 to make
it easier to land.

Release Notes:

- N/A

---------

Co-authored-by: Shardul Vaidya <cam.v737@gmail.com>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
This commit is contained in:
Marshall Bowers 2025-02-24 15:28:20 -05:00 committed by GitHub
parent 8a3fb890b0
commit cdd07fdf29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 595 additions and 0 deletions

51
Cargo.lock generated
View file

@ -1269,6 +1269,30 @@ dependencies = [
"uuid",
]
[[package]]
name = "aws-sdk-bedrockruntime"
version = "1.74.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6938541d1948a543bca23303fec4cff9c36bf0e63b8fa3ae1b337bcb9d5b81af"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-eventstream",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes 1.10.0",
"fastrand 2.3.0",
"http 0.2.12",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-kinesis"
version = "1.61.0"
@ -1598,6 +1622,17 @@ dependencies = [
"tracing",
]
[[package]]
name = "aws_http_client"
version = "0.1.0"
dependencies = [
"aws-smithy-runtime-api",
"aws-smithy-types",
"futures 0.3.31",
"http_client",
"tokio",
]
[[package]]
name = "axum"
version = "0.6.20"
@ -1727,6 +1762,22 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bedrock"
version = "0.1.0"
dependencies = [
"anyhow",
"aws-sdk-bedrockruntime",
"aws-smithy-types",
"futures 0.3.31",
"schemars",
"serde",
"serde_json",
"strum",
"thiserror 1.0.69",
"tokio",
]
[[package]]
name = "bigdecimal"
version = "0.4.7"

View file

@ -15,6 +15,8 @@ members = [
"crates/audio",
"crates/auto_update",
"crates/auto_update_ui",
"crates/aws_http_client",
"crates/bedrock",
"crates/breadcrumbs",
"crates/buffer_diff",
"crates/call",
@ -218,6 +220,8 @@ assistant_tools = { path = "crates/assistant_tools" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
auto_update_ui = { path = "crates/auto_update_ui" }
aws_http_client = { path = "crates/aws_http_client" }
bedrock = { path = "crates/bedrock" }
breadcrumbs = { path = "crates/breadcrumbs" }
call = { path = "crates/call" }
channel = { path = "crates/channel" }
@ -382,6 +386,11 @@ async-trait = "0.1"
async-tungstenite = "0.28"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] }
aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] }
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }

View file

@ -0,0 +1,22 @@
[package]
name = "aws_http_client"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/aws_http_client.rs"
[features]
default = []
[dependencies]
aws-smithy-runtime-api.workspace = true
aws-smithy-types.workspace = true
futures.workspace = true
http_client.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,118 @@
use std::fmt;
use std::sync::Arc;
use aws_smithy_runtime_api::client::http::{
HttpClient as AwsClient, HttpConnector as AwsConnector,
HttpConnectorFuture as AwsConnectorFuture, HttpConnectorFuture, HttpConnectorSettings,
SharedHttpConnector,
};
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsHttpRequest, HttpResponse};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::http::StatusCode;
use aws_smithy_types::body::SdkBody;
use futures::AsyncReadExt;
use http_client::{AsyncBody, Inner};
use http_client::{HttpClient, Request};
use tokio::runtime::Handle;
struct AwsHttpConnector {
client: Arc<dyn HttpClient>,
handle: Handle,
}
impl std::fmt::Debug for AwsHttpConnector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AwsHttpConnector").finish()
}
}
impl AwsConnector for AwsHttpConnector {
fn call(&self, request: AwsHttpRequest) -> AwsConnectorFuture {
let req = match request.try_into_http1x() {
Ok(req) => req,
Err(err) => {
return HttpConnectorFuture::ready(Err(ConnectorError::other(err.into(), None)))
}
};
let (parts, body) = req.into_parts();
let response = self
.client
.send(Request::from_parts(parts, convert_to_async_body(body)));
let handle = self.handle.clone();
HttpConnectorFuture::new(async move {
let response = match response.await {
Ok(response) => response,
Err(err) => return Err(ConnectorError::other(err.into(), None)),
};
let (parts, body) = response.into_parts();
let body = convert_to_sdk_body(body, handle).await;
Ok(HttpResponse::new(
StatusCode::try_from(parts.status.as_u16()).unwrap(),
body,
))
})
}
}
#[derive(Clone)]
pub struct AwsHttpClient {
client: Arc<dyn HttpClient>,
handler: Handle,
}
impl std::fmt::Debug for AwsHttpClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AwsHttpClient").finish()
}
}
impl AwsHttpClient {
pub fn new(client: Arc<dyn HttpClient>, handle: Handle) -> Self {
Self {
client,
handler: handle,
}
}
}
impl AwsClient for AwsHttpClient {
fn http_connector(
&self,
_settings: &HttpConnectorSettings,
_components: &RuntimeComponents,
) -> SharedHttpConnector {
SharedHttpConnector::new(AwsHttpConnector {
client: self.client.clone(),
handle: self.handler.clone(),
})
}
}
pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody {
match body.0 {
Inner::Empty => SdkBody::empty(),
Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()),
Inner::AsyncReader(mut reader) => {
let buffer = handle.spawn(async move {
let mut buffer = Vec::new();
let _ = reader.read_to_end(&mut buffer).await;
buffer
});
SdkBody::from(buffer.await.unwrap_or_default())
}
}
}
pub fn convert_to_async_body(body: SdkBody) -> AsyncBody {
match body.bytes() {
Some(bytes) => AsyncBody::from((*bytes).to_vec()),
None => AsyncBody::empty(),
}
}

28
crates/bedrock/Cargo.toml Normal file
View file

@ -0,0 +1,28 @@
[package]
name = "bedrock"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/bedrock.rs"
[features]
default = []
schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] }
aws-smithy-types = {workspace = true}
futures.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }

1
crates/bedrock/LICENSE-GPL Symbolic link
View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,166 @@
mod models;
use std::pin::Pin;
use anyhow::{anyhow, Context, Error, Result};
use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{
ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
ToolSpecification as BedrockTool,
};
use aws_smithy_types::{Document, Number as AwsNumber};
pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
pub use bedrock::types::{
ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
Message as BedrockMessage, ResponseStream as BedrockResponseStream,
};
use futures::stream::{self, BoxStream, Stream};
use serde::{Deserialize, Serialize};
use serde_json::{Number, Value};
use thiserror::Error;
pub use crate::models::*;
pub async fn complete(
client: &bedrock::Client,
request: Request,
) -> Result<BedrockResponse, BedrockError> {
let response = bedrock::Client::converse(client)
.model_id(request.model.clone())
.set_messages(request.messages.into())
.send()
.await
.context("failed to send request to Bedrock");
match response {
Ok(output) => output
.output
.ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
Err(err) => Err(BedrockError::Other(err)),
}
}
pub async fn stream_completion(
client: bedrock::Client,
request: Request,
handle: tokio::runtime::Handle,
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
handle
.spawn(async move {
let response = bedrock::Client::converse_stream(&client)
.model_id(request.model.clone())
.set_messages(request.messages.into())
.send()
.await;
match response {
Ok(output) => {
let stream: Pin<
Box<
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
+ Send,
>,
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
match stream.recv().await {
Ok(Some(output)) => Some((Ok(output), stream)),
Ok(None) => None,
Err(err) => {
Some((
// TODO: Figure out how we can capture Throttling Exceptions
Err(BedrockError::ClientError(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
))),
stream,
))
}
}
}));
Ok(stream)
}
Err(err) => Err(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
)),
}
})
.await
.map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
}
pub fn aws_document_to_value(document: &Document) -> Value {
match document {
Document::Null => Value::Null,
Document::Bool(value) => Value::Bool(*value),
Document::Number(value) => match *value {
AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
},
Document::String(value) => Value::String(value.clone()),
Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
Document::Object(map) => Value::Object(
map.iter()
.map(|(key, value)| (key.clone(), aws_document_to_value(value)))
.collect(),
),
}
}
pub fn value_to_aws_document(value: &Value) -> Document {
match value {
Value::Null => Document::Null,
Value::Bool(value) => Document::Bool(*value),
Value::Number(value) => {
if let Some(value) = value.as_u64() {
Document::Number(AwsNumber::PosInt(value))
} else if let Some(value) = value.as_i64() {
Document::Number(AwsNumber::NegInt(value))
} else if let Some(value) = value.as_f64() {
Document::Number(AwsNumber::Float(value))
} else {
Document::Null
}
}
Value::String(value) => Document::String(value.clone()),
Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
Value::Object(map) => Document::Object(
map.iter()
.map(|(key, value)| (key.clone(), value_to_aws_document(value)))
.collect(),
),
}
}
#[derive(Debug)]
pub struct Request {
pub model: String,
pub max_tokens: u32,
pub messages: Vec<BedrockMessage>,
pub tools: Vec<BedrockTool>,
pub tool_choice: Option<BedrockToolChoice>,
pub system: Option<String>,
pub metadata: Option<Metadata>,
pub stop_sequences: Vec<String>,
pub temperature: Option<f32>,
pub top_k: Option<u32>,
pub top_p: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
pub user_id: Option<String>,
}
#[derive(Error, Debug)]
pub enum BedrockError {
#[error("client error: {0}")]
ClientError(anyhow::Error),
#[error("extension error: {0}")]
ExtensionError(anyhow::Error),
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View file

@ -0,0 +1,199 @@
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
// Anthropic models (already included)
#[default]
#[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
Claude3_5Sonnet,
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
Claude3Sonnet,
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
Claude3_5Haiku,
// Amazon Nova Models
AmazonNovaLite,
AmazonNovaMicro,
AmazonNovaPro,
// AI21 models
AI21J2GrandeInstruct,
AI21J2JumboInstruct,
AI21J2Mid,
AI21J2MidV1,
AI21J2Ultra,
AI21J2UltraV1_8k,
AI21J2UltraV1,
AI21JambaInstructV1,
AI21Jamba15LargeV1,
AI21Jamba15MiniV1,
// Cohere models
CohereCommandTextV14_4k,
CohereCommandRV1,
CohereCommandRPlusV1,
CohereCommandLightTextV14_4k,
// Meta models
MetaLlama38BInstructV1,
MetaLlama370BInstructV1,
MetaLlama318BInstructV1_128k,
MetaLlama318BInstructV1,
MetaLlama3170BInstructV1_128k,
MetaLlama3170BInstructV1,
MetaLlama3211BInstructV1,
MetaLlama3290BInstructV1,
MetaLlama321BInstructV1,
MetaLlama323BInstructV1,
// Mistral models
MistralMistral7BInstructV0,
MistralMixtral8x7BInstructV0,
MistralMistralLarge2402V1,
MistralMistralSmall2402V1,
#[serde(rename = "custom")]
Custom {
name: String,
max_tokens: usize,
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
display_name: Option<String>,
max_output_tokens: Option<u32>,
default_temperature: Option<f32>,
},
}
impl Model {
pub fn from_id(id: &str) -> anyhow::Result<Self> {
if id.starts_with("claude-3-5-sonnet") {
Ok(Self::Claude3_5Sonnet)
} else if id.starts_with("claude-3-opus") {
Ok(Self::Claude3Opus)
} else if id.starts_with("claude-3-sonnet") {
Ok(Self::Claude3Sonnet)
} else if id.starts_with("claude-3-5-haiku") {
Ok(Self::Claude3_5Haiku)
} else {
Err(anyhow!("invalid model id"))
}
}
pub fn id(&self) -> &str {
match self {
Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
Model::AI21J2Mid => "ai21.j2-mid",
Model::AI21J2MidV1 => "ai21.j2-mid-v1",
Model::AI21J2Ultra => "ai21.j2-ultra",
Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
Model::CohereCommandRV1 => "cohere.command-r-v1:0",
Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
Self::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3_5Haiku => "Claude 3.5 Haiku",
Self::AmazonNovaLite => "Amazon Nova Lite",
Self::AmazonNovaMicro => "Amazon Nova Micro",
Self::AmazonNovaPro => "Amazon Nova Pro",
Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
Self::AI21J2Mid => "AI21 Jurassic2 Mid",
Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
Self::CohereCommandRV1 => "Cohere Command R V1",
Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
Self::Custom {
display_name, name, ..
} => display_name.as_deref().unwrap_or(name),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3_5Haiku => 200_000,
Self::Custom { max_tokens, .. } => *max_tokens,
_ => 200_000,
}
}
pub fn max_output_tokens(&self) -> u32 {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
Self::Claude3_5Sonnet => 8_192,
Self::Custom {
max_output_tokens, ..
} => max_output_tokens.unwrap_or(4_096),
_ => 4_096,
}
}
pub fn default_temperature(&self) -> f32 {
match self {
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3_5Haiku => 1.0,
Self::Custom {
default_temperature,
..
} => default_temperature.unwrap_or(1.0),
_ => 1.0,
}
}
}