From cdd07fdf29ab03ffd64a0c9ed1404e2addd80020 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 24 Feb 2025 15:28:20 -0500 Subject: [PATCH] Add `aws_http_client` and `bedrock` crates (#25490) This PR adds new `aws_http_client` and `bedrock` crates for supporting AWS Bedrock. Pulling out of https://github.com/zed-industries/zed/pull/21092 to make it easier to land. Release Notes: - N/A --------- Co-authored-by: Shardul Vaidya Co-authored-by: Anthony Eid --- Cargo.lock | 51 +++++ Cargo.toml | 9 + crates/aws_http_client/Cargo.toml | 22 ++ crates/aws_http_client/LICENSE-GPL | 1 + crates/aws_http_client/src/aws_http_client.rs | 118 +++++++++++ crates/bedrock/Cargo.toml | 28 +++ crates/bedrock/LICENSE-GPL | 1 + crates/bedrock/src/bedrock.rs | 166 +++++++++++++++ crates/bedrock/src/models.rs | 199 ++++++++++++++++++ 9 files changed, 595 insertions(+) create mode 100644 crates/aws_http_client/Cargo.toml create mode 120000 crates/aws_http_client/LICENSE-GPL create mode 100644 crates/aws_http_client/src/aws_http_client.rs create mode 100644 crates/bedrock/Cargo.toml create mode 120000 crates/bedrock/LICENSE-GPL create mode 100644 crates/bedrock/src/bedrock.rs create mode 100644 crates/bedrock/src/models.rs diff --git a/Cargo.lock b/Cargo.lock index 76a3fdcc4d..9cade7daf8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1269,6 +1269,30 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-bedrockruntime" +version = "1.74.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6938541d1948a543bca23303fec4cff9c36bf0e63b8fa3ae1b337bcb9d5b81af" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-eventstream", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes 1.10.0", + "fastrand 2.3.0", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-kinesis" version = "1.61.0" @@ -1598,6 +1622,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws_http_client" +version = "0.1.0" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "futures 0.3.31", + "http_client", + "tokio", +] + [[package]] name = "axum" version = "0.6.20" @@ -1727,6 +1762,22 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bedrock" +version = "0.1.0" +dependencies = [ + "anyhow", + "aws-sdk-bedrockruntime", + "aws-smithy-types", + "futures 0.3.31", + "schemars", + "serde", + "serde_json", + "strum", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "bigdecimal" version = "0.4.7" diff --git a/Cargo.toml b/Cargo.toml index 448f653a51..f4e90eae46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,8 @@ members = [ "crates/audio", "crates/auto_update", "crates/auto_update_ui", + "crates/aws_http_client", + "crates/bedrock", "crates/breadcrumbs", "crates/buffer_diff", "crates/call", @@ -218,6 +220,8 @@ assistant_tools = { path = "crates/assistant_tools" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } auto_update_ui = { path = "crates/auto_update_ui" } +aws_http_client = { path = "crates/aws_http_client" } +bedrock = { path = "crates/bedrock" } breadcrumbs = { path = "crates/breadcrumbs" } call = { path = "crates/call" } channel = { path = "crates/channel" } @@ -382,6 +386,11 @@ async-trait = "0.1" async-tungstenite = "0.28" async-watch = "0.3.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } +aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } +aws-credential-types = { version = "1.2.1", features = ["hardcoded-credentials"] } +aws-sdk-bedrockruntime = { version = "1.73.0", features = ["behavior-version-latest"] } +aws-smithy-runtime-api = { version = "1.7.3", features = ["http-1x", "client"] } +aws-smithy-types = { version = "1.2.13", features = ["http-body-1-x"] } base64 = "0.22" bitflags = "2.6.0" blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" } diff --git a/crates/aws_http_client/Cargo.toml b/crates/aws_http_client/Cargo.toml new file mode 100644 index 0000000000..8715fe1b56 --- /dev/null +++ b/crates/aws_http_client/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "aws_http_client" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/aws_http_client.rs" + +[features] +default = [] + +[dependencies] +aws-smithy-runtime-api.workspace = true +aws-smithy-types.workspace = true +futures.workspace = true +http_client.workspace = true +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } diff --git a/crates/aws_http_client/LICENSE-GPL b/crates/aws_http_client/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/aws_http_client/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/aws_http_client/src/aws_http_client.rs b/crates/aws_http_client/src/aws_http_client.rs new file mode 100644 index 0000000000..f992806581 --- /dev/null +++ b/crates/aws_http_client/src/aws_http_client.rs @@ -0,0 +1,118 @@ +use std::fmt; +use std::sync::Arc; + +use aws_smithy_runtime_api::client::http::{ + HttpClient as AwsClient, HttpConnector as AwsConnector, + HttpConnectorFuture as AwsConnectorFuture, HttpConnectorFuture, HttpConnectorSettings, + SharedHttpConnector, +}; +use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsHttpRequest, HttpResponse}; +use aws_smithy_runtime_api::client::result::ConnectorError; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; +use aws_smithy_runtime_api::http::StatusCode; +use aws_smithy_types::body::SdkBody; +use futures::AsyncReadExt; +use http_client::{AsyncBody, Inner}; +use http_client::{HttpClient, Request}; +use tokio::runtime::Handle; + +struct AwsHttpConnector { + client: Arc, + handle: Handle, +} + +impl std::fmt::Debug for AwsHttpConnector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AwsHttpConnector").finish() + } +} + +impl AwsConnector for AwsHttpConnector { + fn call(&self, request: AwsHttpRequest) -> AwsConnectorFuture { + let req = match request.try_into_http1x() { + Ok(req) => req, + Err(err) => { + return HttpConnectorFuture::ready(Err(ConnectorError::other(err.into(), None))) + } + }; + + let (parts, body) = req.into_parts(); + + let response = self + .client + .send(Request::from_parts(parts, convert_to_async_body(body))); + + let handle = self.handle.clone(); + + HttpConnectorFuture::new(async move { + let response = match response.await { + Ok(response) => response, + Err(err) => return Err(ConnectorError::other(err.into(), None)), + }; + let (parts, body) = response.into_parts(); + let body = convert_to_sdk_body(body, handle).await; + + Ok(HttpResponse::new( + StatusCode::try_from(parts.status.as_u16()).unwrap(), + body, + )) + }) + } +} + +#[derive(Clone)] +pub struct AwsHttpClient { + client: Arc, + handler: Handle, +} + +impl std::fmt::Debug for AwsHttpClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AwsHttpClient").finish() + } +} + +impl AwsHttpClient { + pub fn new(client: Arc, handle: Handle) -> Self { + Self { + client, + handler: handle, + } + } +} + +impl AwsClient for AwsHttpClient { + fn http_connector( + &self, + _settings: &HttpConnectorSettings, + _components: &RuntimeComponents, + ) -> SharedHttpConnector { + SharedHttpConnector::new(AwsHttpConnector { + client: self.client.clone(), + handle: self.handler.clone(), + }) + } +} + +pub async fn convert_to_sdk_body(body: AsyncBody, handle: Handle) -> SdkBody { + match body.0 { + Inner::Empty => SdkBody::empty(), + Inner::Bytes(bytes) => SdkBody::from(bytes.into_inner()), + Inner::AsyncReader(mut reader) => { + let buffer = handle.spawn(async move { + let mut buffer = Vec::new(); + let _ = reader.read_to_end(&mut buffer).await; + buffer + }); + + SdkBody::from(buffer.await.unwrap_or_default()) + } + } +} + +pub fn convert_to_async_body(body: SdkBody) -> AsyncBody { + match body.bytes() { + Some(bytes) => AsyncBody::from((*bytes).to_vec()), + None => AsyncBody::empty(), + } +} diff --git a/crates/bedrock/Cargo.toml b/crates/bedrock/Cargo.toml new file mode 100644 index 0000000000..e99f7e2cf0 --- /dev/null +++ b/crates/bedrock/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "bedrock" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/bedrock.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +aws-sdk-bedrockruntime = { workspace = true, features = ["behavior-version-latest"] } +aws-smithy-types = {workspace = true} +futures.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +strum.workspace = true +thiserror.workspace = true +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } diff --git a/crates/bedrock/LICENSE-GPL b/crates/bedrock/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/bedrock/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/bedrock/src/bedrock.rs b/crates/bedrock/src/bedrock.rs new file mode 100644 index 0000000000..fa17bc0383 --- /dev/null +++ b/crates/bedrock/src/bedrock.rs @@ -0,0 +1,166 @@ +mod models; + +use std::pin::Pin; + +use anyhow::{anyhow, Context, Error, Result}; +use aws_sdk_bedrockruntime as bedrock; +pub use aws_sdk_bedrockruntime as bedrock_client; +pub use aws_sdk_bedrockruntime::types::{ + ContentBlock as BedrockInnerContent, SpecificToolChoice as BedrockSpecificTool, + ToolChoice as BedrockToolChoice, ToolInputSchema as BedrockToolInputSchema, + ToolSpecification as BedrockTool, +}; +use aws_smithy_types::{Document, Number as AwsNumber}; +pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest; +pub use bedrock::types::{ + ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole, + ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse, + Message as BedrockMessage, ResponseStream as BedrockResponseStream, +}; +use futures::stream::{self, BoxStream, Stream}; +use serde::{Deserialize, Serialize}; +use serde_json::{Number, Value}; +use thiserror::Error; + +pub use crate::models::*; + +pub async fn complete( + client: &bedrock::Client, + request: Request, +) -> Result { + let response = bedrock::Client::converse(client) + .model_id(request.model.clone()) + .set_messages(request.messages.into()) + .send() + .await + .context("failed to send request to Bedrock"); + + match response { + Ok(output) => output + .output + .ok_or_else(|| BedrockError::Other(anyhow!("no output"))), + Err(err) => Err(BedrockError::Other(err)), + } +} + +pub async fn stream_completion( + client: bedrock::Client, + request: Request, + handle: tokio::runtime::Handle, +) -> Result>, Error> { + handle + .spawn(async move { + let response = bedrock::Client::converse_stream(&client) + .model_id(request.model.clone()) + .set_messages(request.messages.into()) + .send() + .await; + + match response { + Ok(output) => { + let stream: Pin< + Box< + dyn Stream> + + Send, + >, + > = Box::pin(stream::unfold(output.stream, |mut stream| async move { + match stream.recv().await { + Ok(Some(output)) => Some((Ok(output), stream)), + Ok(None) => None, + Err(err) => { + Some(( + // TODO: Figure out how we can capture Throttling Exceptions + Err(BedrockError::ClientError(anyhow!( + "{:?}", + aws_sdk_bedrockruntime::error::DisplayErrorContext(err) + ))), + stream, + )) + } + } + })); + Ok(stream) + } + Err(err) => Err(anyhow!( + "{:?}", + aws_sdk_bedrockruntime::error::DisplayErrorContext(err) + )), + } + }) + .await + .map_err(|err| anyhow!("failed to spawn task: {err:?}"))? +} + +pub fn aws_document_to_value(document: &Document) -> Value { + match document { + Document::Null => Value::Null, + Document::Bool(value) => Value::Bool(*value), + Document::Number(value) => match *value { + AwsNumber::PosInt(value) => Value::Number(Number::from(value)), + AwsNumber::NegInt(value) => Value::Number(Number::from(value)), + AwsNumber::Float(value) => Value::Number(Number::from_f64(value).unwrap()), + }, + Document::String(value) => Value::String(value.clone()), + Document::Array(array) => Value::Array(array.iter().map(aws_document_to_value).collect()), + Document::Object(map) => Value::Object( + map.iter() + .map(|(key, value)| (key.clone(), aws_document_to_value(value))) + .collect(), + ), + } +} + +pub fn value_to_aws_document(value: &Value) -> Document { + match value { + Value::Null => Document::Null, + Value::Bool(value) => Document::Bool(*value), + Value::Number(value) => { + if let Some(value) = value.as_u64() { + Document::Number(AwsNumber::PosInt(value)) + } else if let Some(value) = value.as_i64() { + Document::Number(AwsNumber::NegInt(value)) + } else if let Some(value) = value.as_f64() { + Document::Number(AwsNumber::Float(value)) + } else { + Document::Null + } + } + Value::String(value) => Document::String(value.clone()), + Value::Array(array) => Document::Array(array.iter().map(value_to_aws_document).collect()), + Value::Object(map) => Document::Object( + map.iter() + .map(|(key, value)| (key.clone(), value_to_aws_document(value))) + .collect(), + ), + } +} + +#[derive(Debug)] +pub struct Request { + pub model: String, + pub max_tokens: u32, + pub messages: Vec, + pub tools: Vec, + pub tool_choice: Option, + pub system: Option, + pub metadata: Option, + pub stop_sequences: Vec, + pub temperature: Option, + pub top_k: Option, + pub top_p: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Metadata { + pub user_id: Option, +} + +#[derive(Error, Debug)] +pub enum BedrockError { + #[error("client error: {0}")] + ClientError(anyhow::Error), + #[error("extension error: {0}")] + ExtensionError(anyhow::Error), + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/crates/bedrock/src/models.rs b/crates/bedrock/src/models.rs new file mode 100644 index 0000000000..a8d0614e5d --- /dev/null +++ b/crates/bedrock/src/models.rs @@ -0,0 +1,199 @@ +use anyhow::anyhow; +use serde::{Deserialize, Serialize}; +use strum::EnumIter; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + // Anthropic models (already included) + #[default] + #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")] + Claude3_5Sonnet, + #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")] + Claude3Opus, + #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")] + Claude3Sonnet, + #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")] + Claude3_5Haiku, + // Amazon Nova Models + AmazonNovaLite, + AmazonNovaMicro, + AmazonNovaPro, + // AI21 models + AI21J2GrandeInstruct, + AI21J2JumboInstruct, + AI21J2Mid, + AI21J2MidV1, + AI21J2Ultra, + AI21J2UltraV1_8k, + AI21J2UltraV1, + AI21JambaInstructV1, + AI21Jamba15LargeV1, + AI21Jamba15MiniV1, + // Cohere models + CohereCommandTextV14_4k, + CohereCommandRV1, + CohereCommandRPlusV1, + CohereCommandLightTextV14_4k, + // Meta models + MetaLlama38BInstructV1, + MetaLlama370BInstructV1, + MetaLlama318BInstructV1_128k, + MetaLlama318BInstructV1, + MetaLlama3170BInstructV1_128k, + MetaLlama3170BInstructV1, + MetaLlama3211BInstructV1, + MetaLlama3290BInstructV1, + MetaLlama321BInstructV1, + MetaLlama323BInstructV1, + // Mistral models + MistralMistral7BInstructV0, + MistralMixtral8x7BInstructV0, + MistralMistralLarge2402V1, + MistralMistralSmall2402V1, + #[serde(rename = "custom")] + Custom { + name: String, + max_tokens: usize, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. + display_name: Option, + max_output_tokens: Option, + default_temperature: Option, + }, +} + +impl Model { + pub fn from_id(id: &str) -> anyhow::Result { + if id.starts_with("claude-3-5-sonnet") { + Ok(Self::Claude3_5Sonnet) + } else if id.starts_with("claude-3-opus") { + Ok(Self::Claude3Opus) + } else if id.starts_with("claude-3-sonnet") { + Ok(Self::Claude3Sonnet) + } else if id.starts_with("claude-3-5-haiku") { + Ok(Self::Claude3_5Haiku) + } else { + Err(anyhow!("invalid model id")) + } + } + + pub fn id(&self) -> &str { + match self { + Model::Claude3_5Sonnet => "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + Model::Claude3Opus => "us.anthropic.claude-3-opus-20240229-v1:0", + Model::Claude3Sonnet => "us.anthropic.claude-3-sonnet-20240229-v1:0", + Model::Claude3_5Haiku => "us.anthropic.claude-3-5-haiku-20241022-v1:0", + Model::AmazonNovaLite => "us.amazon.nova-lite-v1:0", + Model::AmazonNovaMicro => "us.amazon.nova-micro-v1:0", + Model::AmazonNovaPro => "us.amazon.nova-pro-v1:0", + Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct", + Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct", + Model::AI21J2Mid => "ai21.j2-mid", + Model::AI21J2MidV1 => "ai21.j2-mid-v1", + Model::AI21J2Ultra => "ai21.j2-ultra", + Model::AI21J2UltraV1_8k => "ai21.j2-ultra-v1:0:8k", + Model::AI21J2UltraV1 => "ai21.j2-ultra-v1", + Model::AI21JambaInstructV1 => "ai21.jamba-instruct-v1:0", + Model::AI21Jamba15LargeV1 => "ai21.jamba-1-5-large-v1:0", + Model::AI21Jamba15MiniV1 => "ai21.jamba-1-5-mini-v1:0", + Model::CohereCommandTextV14_4k => "cohere.command-text-v14:7:4k", + Model::CohereCommandRV1 => "cohere.command-r-v1:0", + Model::CohereCommandRPlusV1 => "cohere.command-r-plus-v1:0", + Model::CohereCommandLightTextV14_4k => "cohere.command-light-text-v14:7:4k", + Model::MetaLlama38BInstructV1 => "meta.llama3-8b-instruct-v1:0", + Model::MetaLlama370BInstructV1 => "meta.llama3-70b-instruct-v1:0", + Model::MetaLlama318BInstructV1_128k => "meta.llama3-1-8b-instruct-v1:0:128k", + Model::MetaLlama318BInstructV1 => "meta.llama3-1-8b-instruct-v1:0", + Model::MetaLlama3170BInstructV1_128k => "meta.llama3-1-70b-instruct-v1:0:128k", + Model::MetaLlama3170BInstructV1 => "meta.llama3-1-70b-instruct-v1:0", + Model::MetaLlama3211BInstructV1 => "meta.llama3-2-11b-instruct-v1:0", + Model::MetaLlama3290BInstructV1 => "meta.llama3-2-90b-instruct-v1:0", + Model::MetaLlama321BInstructV1 => "meta.llama3-2-1b-instruct-v1:0", + Model::MetaLlama323BInstructV1 => "meta.llama3-2-3b-instruct-v1:0", + Model::MistralMistral7BInstructV0 => "mistral.mistral-7b-instruct-v0:2", + Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1", + Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0", + Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3_5Haiku => "Claude 3.5 Haiku", + Self::AmazonNovaLite => "Amazon Nova Lite", + Self::AmazonNovaMicro => "Amazon Nova Micro", + Self::AmazonNovaPro => "Amazon Nova Pro", + Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct", + Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct", + Self::AI21J2Mid => "AI21 Jurassic2 Mid", + Self::AI21J2MidV1 => "AI21 Jurassic2 Mid V1", + Self::AI21J2Ultra => "AI21 Jurassic2 Ultra", + Self::AI21J2UltraV1_8k => "AI21 Jurassic2 Ultra V1 8K", + Self::AI21J2UltraV1 => "AI21 Jurassic2 Ultra V1", + Self::AI21JambaInstructV1 => "AI21 Jamba Instruct", + Self::AI21Jamba15LargeV1 => "AI21 Jamba 1.5 Large", + Self::AI21Jamba15MiniV1 => "AI21 Jamba 1.5 Mini", + Self::CohereCommandTextV14_4k => "Cohere Command Text V14 4K", + Self::CohereCommandRV1 => "Cohere Command R V1", + Self::CohereCommandRPlusV1 => "Cohere Command R Plus V1", + Self::CohereCommandLightTextV14_4k => "Cohere Command Light Text V14 4K", + Self::MetaLlama38BInstructV1 => "Meta Llama 3 8B Instruct V1", + Self::MetaLlama370BInstructV1 => "Meta Llama 3 70B Instruct V1", + Self::MetaLlama318BInstructV1_128k => "Meta Llama 3 1.8B Instruct V1 128K", + Self::MetaLlama318BInstructV1 => "Meta Llama 3 1.8B Instruct V1", + Self::MetaLlama3170BInstructV1_128k => "Meta Llama 3 1 70B Instruct V1 128K", + Self::MetaLlama3170BInstructV1 => "Meta Llama 3 1 70B Instruct V1", + Self::MetaLlama3211BInstructV1 => "Meta Llama 3 2 11B Instruct V1", + Self::MetaLlama3290BInstructV1 => "Meta Llama 3 2 90B Instruct V1", + Self::MetaLlama321BInstructV1 => "Meta Llama 3 2 1B Instruct V1", + Self::MetaLlama323BInstructV1 => "Meta Llama 3 2 3B Instruct V1", + Self::MistralMistral7BInstructV0 => "Mistral 7B Instruct V0", + Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0", + Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1", + Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1", + Self::Custom { + display_name, name, .. + } => display_name.as_deref().unwrap_or(name), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3_5Haiku => 200_000, + Self::Custom { max_tokens, .. } => *max_tokens, + _ => 200_000, + } + } + + pub fn max_output_tokens(&self) -> u32 { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096, + Self::Claude3_5Sonnet => 8_192, + Self::Custom { + max_output_tokens, .. + } => max_output_tokens.unwrap_or(4_096), + _ => 4_096, + } + } + + pub fn default_temperature(&self) -> f32 { + match self { + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3_5Haiku => 1.0, + Self::Custom { + default_temperature, + .. + } => default_temperature.unwrap_or(1.0), + _ => 1.0, + } + } +}