Allow using a custom model when using zed.dev (#14933)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-07-22 12:25:53 +02:00 committed by GitHub
parent a334c69e05
commit 0155435142
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 114 additions and 110 deletions

View file

@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
pub use anthropic::Model as AnthropicModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use strum::{EnumIter, IntoEnumIterator};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum CloudModel {
#[serde(rename = "gpt-3.5-turbo")]
Gpt3Point5Turbo,
#[serde(rename = "gpt-4")]
Gpt4,
#[serde(rename = "gpt-4-turbo-preview")]
Gpt4Turbo,
#[serde(rename = "gpt-4o")]
#[default]
Gpt4Omni,
#[serde(rename = "gpt-4o-mini")]
Gpt4OmniMini,
#[serde(rename = "claude-3-5-sonnet")]
Claude3_5Sonnet,
#[serde(rename = "claude-3-opus")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku")]
Claude3Haiku,
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
#[serde(rename = "custom")]
Custom {
name: String,
max_tokens: Option<usize>,
},
}
impl CloudModel {
@ -112,7 +52,7 @@ impl CloudModel {
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
Self::Custom { name, .. } => name,
}
}
@ -129,7 +69,7 @@ impl CloudModel {
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
Self::Custom { name, .. } => name,
}
}
@ -145,14 +85,20 @@ impl CloudModel {
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
request.preprocess_anthropic()
Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku
| Self::Claude3_5Sonnet => {
request.preprocess_anthropic();
}
Self::Custom { name, .. } if name.starts_with("anthropic/") => {
request.preprocess_anthropic();
}
_ => {}
}