Allow using a custom model when using zed.dev (#14933)
Release Notes: - N/A
This commit is contained in:
parent
a334c69e05
commit
0155435142
6 changed files with 114 additions and 110 deletions
|
@ -20,6 +20,12 @@ pub enum Model {
|
|||
Claude3Sonnet,
|
||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||
Claude3Haiku,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
max_tokens: Option<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Model {
|
||||
|
@ -33,30 +39,41 @@ impl Model {
|
|||
} else if id.starts_with("claude-3-haiku") {
|
||||
Ok(Self::Claude3Haiku)
|
||||
} else {
|
||||
Err(anyhow!("Invalid model id: {}", id))
|
||||
Ok(Self::Custom {
|
||||
name: id.to_string(),
|
||||
max_tokens: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
|
||||
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||
Model::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
200_000
|
||||
match self {
|
||||
Self::Claude3_5Sonnet
|
||||
| Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3Haiku => 200_000,
|
||||
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,6 +107,7 @@ impl From<Role> for String {
|
|||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Request {
|
||||
#[serde(serialize_with = "serialize_request_model")]
|
||||
pub model: Model,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
|
@ -97,6 +115,13 @@ pub struct Request {
|
|||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&model.id())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue