Add aws_http_client
and bedrock
crates (#25490)
This PR adds new `aws_http_client` and `bedrock` crates for supporting AWS Bedrock. Pulling out of https://github.com/zed-industries/zed/pull/21092 to make it easier to land. Release Notes: - N/A --------- Co-authored-by: Shardul Vaidya <cam.v737@gmail.com> Co-authored-by: Anthony Eid <hello@anthonyeid.me>
This commit is contained in:
parent
8a3fb890b0
commit
cdd07fdf29
9 changed files with 595 additions and 0 deletions
51
Cargo.lock
generated
51
Cargo.lock
generated
|
@ -1269,6 +1269,30 @@ dependencies = [
|
|||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-bedrockruntime"
|
||||
version = "1.74.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6938541d1948a543bca23303fec4cff9c36bf0e63b8fa3ae1b337bcb9d5b81af"
|
||||
dependencies = [
|
||||
"aws-credential-types",
|
||||
"aws-runtime",
|
||||
"aws-smithy-async",
|
||||
"aws-smithy-eventstream",
|
||||
"aws-smithy-http",
|
||||
"aws-smithy-json",
|
||||
"aws-smithy-runtime",
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
"aws-types",
|
||||
"bytes 1.10.0",
|
||||
"fastrand 2.3.0",
|
||||
"http 0.2.12",
|
||||
"once_cell",
|
||||
"regex-lite",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws-sdk-kinesis"
|
||||
version = "1.61.0"
|
||||
|
@ -1598,6 +1622,17 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aws_http_client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"aws-smithy-runtime-api",
|
||||
"aws-smithy-types",
|
||||
"futures 0.3.31",
|
||||
"http_client",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.6.20"
|
||||
|
@ -1727,6 +1762,22 @@ version = "1.6.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "bedrock"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"aws-sdk-bedrockruntime",
|
||||
"aws-smithy-types",
|
||||
"futures 0.3.31",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.4.7"
|
||||
|
|
|
@ -15,6 +15,8 @@ members = [
|
|||
"crates/audio",
|
||||
"crates/auto_update",
|
||||
"crates/auto_update_ui",
|
||||
"crates/aws_http_client",
|
||||
"crates/bedrock",
|
||||
"crates/breadcrumbs",
|
||||
"crates/buffer_diff",
|
||||
"crates/call",
|
||||
|
@ -218,6 +220,8 @@ assistant_tools = { path = "crates/assistant_tools" }
|
|||
audio = { path = "crates/audio" }
|
||||
auto_update = { path = "crates/auto_update" }
|
||||
auto_update_ui = { path = "crates/auto_update_ui" }
|
||||
aws_http_client = { path = "crates/aws_http_client" }
|
||||
bedrock = { path = "crates/bedrock" }
|
||||
breadcrumbs = { path = "crates/breadcrumbs" }
|
||||
call = { path = "crates/call" }
|
||||
channel = { path = "crates/channel" }
|
||||
|
@ -382,6 +386,11 @@ async-trait = "0.1"
|
|||
async-tungstenite = "0.28"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] }
|
||||
aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] }
|
||||
aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] }
|
||||
aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] }
|
||||
base64 = "0.22"
|
||||
bitflags = "2.6.0"
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
|
||||
|
|
22
crates/aws_http_client/Cargo.toml
Normal file
22
crates/aws_http_client/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "aws_http_client"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/aws_http_client.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
aws-smithy-runtime-api.workspace = true
|
||||
aws-smithy-types.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
1
crates/aws_http_client/LICENSE-GPL
Symbolic link
1
crates/aws_http_client/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
118
crates/aws_http_client/src/aws_http_client.rs
Normal file
118
crates/aws_http_client/src/aws_http_client.rs
Normal file
|
@ -0,0 +1,118 @@
|
|||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use aws_smithy_runtime_api::client::http::{
|
||||
HttpClient as AwsClient, HttpConnector as AwsConnector,
|
||||
HttpConnectorFuture as AwsConnectorFuture, HttpConnectorFuture, HttpConnectorSettings,
|
||||
SharedHttpConnector,
|
||||
};
|
||||
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsHttpRequest, HttpResponse};
|
||||
use aws_smithy_runtime_api::client::result::ConnectorError;
|
||||
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
|
||||
use aws_smithy_runtime_api::http::StatusCode;
|
||||
use aws_smithy_types::body::SdkBody;
|
||||
use futures::AsyncReadExt;
|
||||
use http_client::{AsyncBody, Inner};
|
||||
use http_client::{HttpClient, Request};
|
||||
use tokio::runtime::Handle;
|
||||
|
||||
struct AwsHttpConnector {
|
||||
client: Arc<dyn HttpClient>,
|
||||
handle: Handle,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AwsHttpConnector {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("AwsHttpConnector").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AwsConnector for AwsHttpConnector {
|
||||
fn call(&self, request: AwsHttpRequest) -> AwsConnectorFuture {
|
||||
let req = match request.try_into_http1x() {
|
||||
Ok(req) => req,
|
||||
Err(err) => {
|
||||
return HttpConnectorFuture::ready(Err(ConnectorError::other(err.into(), None)))
|
||||
}
|
||||
};
|
||||
|
||||
let (parts, body) = req.into_parts();
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.send(Request::from_parts(parts, convert_to_async_body(body)));
|
||||
|
||||
let handle = self.handle.clone();
|
||||
|
||||
HttpConnectorFuture::new(async move {
|
||||
let response = match response.await {
|
||||
Ok(response) => response,
|
||||
Err(err) => return Err(ConnectorError::other(err.into(), None)),
|
||||
};
|
||||
let (parts, body) = response.into_parts();
|
||||
let body = convert_to_sdk_body(body, handle).await;
|
||||
|
||||
Ok(HttpResponse::new(
|
||||
StatusCode::try_from(parts.status.as_u16()).unwrap(),
|
||||
body,
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AwsHttpClient {
|
||||
client: Arc<dyn HttpClient>,
|
||||
handler: Handle,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AwsHttpClient {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("AwsHttpClient").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AwsHttpClient {
|
||||
pub fn new(client: Arc<dyn HttpClient>, handle: Handle) -> Self {
|
||||
Self {
|
||||
client,
|
||||
handler: handle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AwsClient for AwsHttpClient {
|
||||
fn http_connector(
|
||||
&self,
|
||||
_settings: &HttpConnectorSettings,
|
||||
_components: &RuntimeComponents,
|
||||
) -> SharedHttpConnector {
|
||||
SharedHttpConnector::new(AwsHttpConnector {
|
||||
client: self.client.clone(),
|
||||
handle: self.handler.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody {
|
||||
match body.0 {
|
||||
Inner::Empty => SdkBody::empty(),
|
||||
Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()),
|
||||
Inner::AsyncReader(mut reader) => {
|
||||
let buffer = handle.spawn(async move {
|
||||
let mut buffer = Vec::new();
|
||||
let _ = reader.read_to_end(&mut buffer).await;
|
||||
buffer
|
||||
});
|
||||
|
||||
SdkBody::from(buffer.await.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_to_async_body(body: SdkBody) -> AsyncBody {
|
||||
match body.bytes() {
|
||||
Some(bytes) => AsyncBody::from((*bytes).to_vec()),
|
||||
None => AsyncBody::empty(),
|
||||
}
|
||||
}
|
28
crates/bedrock/Cargo.toml
Normal file
28
crates/bedrock/Cargo.toml
Normal file
|
@ -0,0 +1,28 @@
|
|||
[package]
|
||||
name = "bedrock"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/bedrock.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] }
|
||||
aws-smithy-types = {workspace = true}
|
||||
futures.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
1
crates/bedrock/LICENSE-GPL
Symbolic link
1
crates/bedrock/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
166
crates/bedrock/src/bedrock.rs
Normal file
166
crates/bedrock/src/bedrock.rs
Normal file
|
@ -0,0 +1,166 @@
|
|||
mod models;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::{anyhow, Context, Error, Result};
|
||||
use aws_sdk_bedrockruntime as bedrock;
|
||||
pub use aws_sdk_bedrockruntime as bedrock_client;
|
||||
pub use aws_sdk_bedrockruntime::types::{
|
||||
ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool,
|
||||
ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema,
|
||||
ToolSpecification as BedrockTool,
|
||||
};
|
||||
use aws_smithy_types::{Document, Number as AwsNumber};
|
||||
pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
|
||||
pub use bedrock::types::{
|
||||
ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
|
||||
ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
|
||||
Message as BedrockMessage, ResponseStream as BedrockResponseStream,
|
||||
};
|
||||
use futures::stream::{self, BoxStream, Stream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Number, Value};
|
||||
use thiserror::Error;
|
||||
|
||||
pub use crate::models::*;
|
||||
|
||||
pub async fn complete(
|
||||
client: &bedrock::Client,
|
||||
request: Request,
|
||||
) -> Result<BedrockResponse, BedrockError> {
|
||||
let response = bedrock::Client::converse(client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into())
|
||||
.send()
|
||||
.await
|
||||
.context("failed to send request to Bedrock");
|
||||
|
||||
match response {
|
||||
Ok(output) => output
|
||||
.output
|
||||
.ok_or_else(|| BedrockError::Other(anyhow!("no output"))),
|
||||
Err(err) => Err(BedrockError::Other(err)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: bedrock::Client,
|
||||
request: Request,
|
||||
handle: tokio::runtime::Handle,
|
||||
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
|
||||
handle
|
||||
.spawn(async move {
|
||||
let response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into())
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match response {
|
||||
Ok(output) => {
|
||||
let stream: Pin<
|
||||
Box<
|
||||
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
|
||||
+ Send,
|
||||
>,
|
||||
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
|
||||
match stream.recv().await {
|
||||
Ok(Some(output)) => Some((Ok(output), stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => {
|
||||
Some((
|
||||
// TODO: Figure out how we can capture Throttling Exceptions
|
||||
Err(BedrockError::ClientError(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
))),
|
||||
stream,
|
||||
))
|
||||
}
|
||||
}
|
||||
}));
|
||||
Ok(stream)
|
||||
}
|
||||
Err(err) => Err(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
)),
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(|err| anyhow!("failed to spawn task: {err:?}"))?
|
||||
}
|
||||
|
||||
pub fn aws_document_to_value(document: &Document) -> Value {
|
||||
match document {
|
||||
Document::Null => Value::Null,
|
||||
Document::Bool(value) => Value::Bool(*value),
|
||||
Document::Number(value) => match *value {
|
||||
AwsNumber::PosInt(value) => Value::Number(Number::from(value)),
|
||||
AwsNumber::NegInt(value) => Value::Number(Number::from(value)),
|
||||
AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()),
|
||||
},
|
||||
Document::String(value) => Value::String(value.clone()),
|
||||
Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()),
|
||||
Document::Object(map) => Value::Object(
|
||||
map.iter()
|
||||
.map(|(key, value)| (key.clone(), aws_document_to_value(value)))
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn value_to_aws_document(value: &Value) -> Document {
|
||||
match value {
|
||||
Value::Null => Document::Null,
|
||||
Value::Bool(value) => Document::Bool(*value),
|
||||
Value::Number(value) => {
|
||||
if let Some(value) = value.as_u64() {
|
||||
Document::Number(AwsNumber::PosInt(value))
|
||||
} else if let Some(value) = value.as_i64() {
|
||||
Document::Number(AwsNumber::NegInt(value))
|
||||
} else if let Some(value) = value.as_f64() {
|
||||
Document::Number(AwsNumber::Float(value))
|
||||
} else {
|
||||
Document::Null
|
||||
}
|
||||
}
|
||||
Value::String(value) => Document::String(value.clone()),
|
||||
Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()),
|
||||
Value::Object(map) => Document::Object(
|
||||
map.iter()
|
||||
.map(|(key, value)| (key.clone(), value_to_aws_document(value)))
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<BedrockMessage>,
|
||||
pub tools: Vec<BedrockTool>,
|
||||
pub tool_choice: Option<BedrockToolChoice>,
|
||||
pub system: Option<String>,
|
||||
pub metadata: Option<Metadata>,
|
||||
pub stop_sequences: Vec<String>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_k: Option<u32>,
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Metadata {
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum BedrockError {
|
||||
#[error("client error: {0}")]
|
||||
ClientError(anyhow::Error),
|
||||
#[error("extension error: {0}")]
|
||||
ExtensionError(anyhow::Error),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
199
crates/bedrock/src/models.rs
Normal file
199
crates/bedrock/src/models.rs
Normal file
|
@ -0,0 +1,199 @@
|
|||
use anyhow::anyhow;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
// Anthropic models (already included)
|
||||
#[default]
|
||||
#[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
|
||||
Claude3_5Sonnet,
|
||||
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
|
||||
Claude3_5Haiku,
|
||||
// Amazon Nova Models
|
||||
AmazonNovaLite,
|
||||
AmazonNovaMicro,
|
||||
AmazonNovaPro,
|
||||
// AI21 models
|
||||
AI21J2GrandeInstruct,
|
||||
AI21J2JumboInstruct,
|
||||
AI21J2Mid,
|
||||
AI21J2MidV1,
|
||||
AI21J2Ultra,
|
||||
AI21J2UltraV1_8k,
|
||||
AI21J2UltraV1,
|
||||
AI21JambaInstructV1,
|
||||
AI21Jamba15LargeV1,
|
||||
AI21Jamba15MiniV1,
|
||||
// Cohere models
|
||||
CohereCommandTextV14_4k,
|
||||
CohereCommandRV1,
|
||||
CohereCommandRPlusV1,
|
||||
CohereCommandLightTextV14_4k,
|
||||
// Meta models
|
||||
MetaLlama38BInstructV1,
|
||||
MetaLlama370BInstructV1,
|
||||
MetaLlama318BInstructV1_128k,
|
||||
MetaLlama318BInstructV1,
|
||||
MetaLlama3170BInstructV1_128k,
|
||||
MetaLlama3170BInstructV1,
|
||||
MetaLlama3211BInstructV1,
|
||||
MetaLlama3290BInstructV1,
|
||||
MetaLlama321BInstructV1,
|
||||
MetaLlama323BInstructV1,
|
||||
// Mistral models
|
||||
MistralMistral7BInstructV0,
|
||||
MistralMixtral8x7BInstructV0,
|
||||
MistralMistralLarge2402V1,
|
||||
MistralMistralSmall2402V1,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
max_tokens: usize,
|
||||
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
|
||||
display_name: Option<String>,
|
||||
max_output_tokens: Option<u32>,
|
||||
default_temperature: Option<f32>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn from_id(id: &str) -> anyhow::Result<Self> {
|
||||
if id.starts_with("claude-3-5-sonnet") {
|
||||
Ok(Self::Claude3_5Sonnet)
|
||||
} else if id.starts_with("claude-3-opus") {
|
||||
Ok(Self::Claude3Opus)
|
||||
} else if id.starts_with("claude-3-sonnet") {
|
||||
Ok(Self::Claude3Sonnet)
|
||||
} else if id.starts_with("claude-3-5-haiku") {
|
||||
Ok(Self::Claude3_5Haiku)
|
||||
} else {
|
||||
Err(anyhow!("invalid model id"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0",
|
||||
Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0",
|
||||
Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0",
|
||||
Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0",
|
||||
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
|
||||
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
|
||||
Model::AI21J2Mid => "ai21.j2-mid",
|
||||
Model::AI21J2MidV1 => "ai21.j2-mid-v1",
|
||||
Model::AI21J2Ultra => "ai21.j2-ultra",
|
||||
Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k",
|
||||
Model::AI21J2UltraV1 => "ai21.j2-ultra-v1",
|
||||
Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0",
|
||||
Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0",
|
||||
Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0",
|
||||
Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k",
|
||||
Model::CohereCommandRV1 => "cohere.command-r-v1:0",
|
||||
Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0",
|
||||
Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k",
|
||||
Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0",
|
||||
Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0",
|
||||
Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k",
|
||||
Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0",
|
||||
Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k",
|
||||
Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0",
|
||||
Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0",
|
||||
Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0",
|
||||
Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0",
|
||||
Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0",
|
||||
Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2",
|
||||
Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
|
||||
Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3_5Haiku => "Claude 3.5 Haiku",
|
||||
Self::AmazonNovaLite => "Amazon Nova Lite",
|
||||
Self::AmazonNovaMicro => "Amazon Nova Micro",
|
||||
Self::AmazonNovaPro => "Amazon Nova Pro",
|
||||
Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
|
||||
Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
|
||||
Self::AI21J2Mid => "AI21 Jurassic2 Mid",
|
||||
Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1",
|
||||
Self::AI21J2Ultra => "AI21 Jurassic2 Ultra",
|
||||
Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K",
|
||||
Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1",
|
||||
Self::AI21JambaInstructV1 => "AI21 Jamba Instruct",
|
||||
Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large",
|
||||
Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini",
|
||||
Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K",
|
||||
Self::CohereCommandRV1 => "Cohere Command R V1",
|
||||
Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1",
|
||||
Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K",
|
||||
Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1",
|
||||
Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1",
|
||||
Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K",
|
||||
Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1",
|
||||
Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K",
|
||||
Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1",
|
||||
Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1",
|
||||
Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1",
|
||||
Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1",
|
||||
Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1",
|
||||
Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0",
|
||||
Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
|
||||
Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
|
||||
Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
|
||||
Self::Custom {
|
||||
display_name, name, ..
|
||||
} => display_name.as_deref().unwrap_or(name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3_5Haiku => 200_000,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
_ => 200_000,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_output_tokens(&self) -> u32 {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
|
||||
Self::Claude3_5Sonnet => 8_192,
|
||||
Self::Custom {
|
||||
max_output_tokens, ..
|
||||
} => max_output_tokens.unwrap_or(4_096),
|
||||
_ => 4_096,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_temperature(&self) -> f32 {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3_5Haiku => 1.0,
|
||||
Self::Custom {
|
||||
default_temperature,
|
||||
..
|
||||
} => default_temperature.unwrap_or(1.0),
|
||||
_ => 1.0,
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue