Allow using a custom model when using zed.dev (#14933)
Release Notes: - N/A
This commit is contained in:
parent
a334c69e05
commit
0155435142
6 changed files with 114 additions and 110 deletions
|
@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
|
|||
pub use anthropic::Model as AnthropicModel;
|
||||
pub use ollama::Model as OllamaModel;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::{
|
||||
schema::{InstanceType, Metadata, Schema, SchemaObject},
|
||||
JsonSchema,
|
||||
};
|
||||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
};
|
||||
use std::fmt;
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::EnumIter;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
|
||||
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
|
||||
pub enum CloudModel {
|
||||
#[serde(rename = "gpt-3.5-turbo")]
|
||||
Gpt3Point5Turbo,
|
||||
#[serde(rename = "gpt-4")]
|
||||
Gpt4,
|
||||
#[serde(rename = "gpt-4-turbo-preview")]
|
||||
Gpt4Turbo,
|
||||
#[serde(rename = "gpt-4o")]
|
||||
#[default]
|
||||
Gpt4Omni,
|
||||
#[serde(rename = "gpt-4o-mini")]
|
||||
Gpt4OmniMini,
|
||||
#[serde(rename = "claude-3-5-sonnet")]
|
||||
Claude3_5Sonnet,
|
||||
#[serde(rename = "claude-3-opus")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-haiku")]
|
||||
Claude3Haiku,
|
||||
#[serde(rename = "gemini-1.5-pro")]
|
||||
Gemini15Pro,
|
||||
#[serde(rename = "gemini-1.5-flash")]
|
||||
Gemini15Flash,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl Serialize for CloudModel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.id())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for CloudModel {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ZedDotDevModelVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
||||
type Value = CloudModel;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
let model = CloudModel::iter()
|
||||
.find(|model| model.id() == value)
|
||||
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(ZedDotDevModelVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for CloudModel {
|
||||
fn schema_name() -> String {
|
||||
"ZedDotDevModel".to_owned()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
let variants = CloudModel::iter()
|
||||
.filter_map(|model| {
|
||||
let id = model.id();
|
||||
if id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(id.to_string())
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
|
||||
metadata: Some(Box::new(Metadata {
|
||||
title: Some("ZedDotDevModel".to_owned()),
|
||||
default: Some(CloudModel::default().id().into()),
|
||||
examples: variants.into_iter().map(Into::into).collect(),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
max_tokens: Option<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
impl CloudModel {
|
||||
|
@ -112,7 +52,7 @@ impl CloudModel {
|
|||
Self::Claude3Haiku => "claude-3-haiku",
|
||||
Self::Gemini15Pro => "gemini-1.5-pro",
|
||||
Self::Gemini15Flash => "gemini-1.5-flash",
|
||||
Self::Custom(id) => id,
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,7 +69,7 @@ impl CloudModel {
|
|||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Gemini15Pro => "Gemini 1.5 Pro",
|
||||
Self::Gemini15Flash => "Gemini 1.5 Flash",
|
||||
Self::Custom(id) => id.as_str(),
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -145,14 +85,20 @@ impl CloudModel {
|
|||
| Self::Claude3Haiku => 200000,
|
||||
Self::Gemini15Pro => 128000,
|
||||
Self::Gemini15Flash => 32000,
|
||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
|
||||
request.preprocess_anthropic()
|
||||
Self::Claude3Opus
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3Haiku
|
||||
| Self::Claude3_5Sonnet => {
|
||||
request.preprocess_anthropic();
|
||||
}
|
||||
Self::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
request.preprocess_anthropic();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue