Fix issues with Claude in Assistant2 (#12619)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Mikayla Maki 2024-06-03 16:30:09 -07:00 committed by GitHub
parent afc0650a49
commit 3cd6719b30
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 276 additions and 213 deletions

View file

@ -14,10 +14,10 @@ use serde::{
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
use crate::LanguageModel;
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel {
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
@ -29,7 +29,7 @@ pub enum ZedDotDevModel {
Custom(String),
}
impl Serialize for ZedDotDevModel {
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
@ -38,7 +38,7 @@ impl Serialize for ZedDotDevModel {
}
}
impl<'de> Deserialize<'de> for ZedDotDevModel {
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
@ -46,7 +46,7 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = ZedDotDevModel;
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
@ -56,9 +56,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
where
E: de::Error,
{
let model = ZedDotDevModel::iter()
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
@ -67,13 +67,13 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
}
}
impl JsonSchema for ZedDotDevModel {
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = ZedDotDevModel::iter()
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
@ -88,7 +88,7 @@ impl JsonSchema for ZedDotDevModel {
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(ZedDotDevModel::default().id().into()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
@ -97,7 +97,7 @@ impl JsonSchema for ZedDotDevModel {
}
}
impl ZedDotDevModel {
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
@ -133,6 +133,15 @@ impl ZedDotDevModel {
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
@ -147,7 +156,7 @@ pub enum AssistantDockPosition {
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
ZedDotDev {
model: ZedDotDevModel,
model: CloudModel,
},
OpenAi {
model: OpenAiModel,
@ -175,9 +184,7 @@ impl Default for AssistantProvider {
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")]
ZedDotDev {
default_model: Option<ZedDotDevModel>,
},
ZedDotDev { default_model: Option<CloudModel> },
#[serde(rename = "openai")]
OpenAi {
default_model: Option<OpenAiModel>,
@ -281,7 +288,7 @@ impl AssistantSettingsContent {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::ZedDotDev(new_model) = new_model {
if let LanguageModel::Cloud(new_model) = new_model {
*model = Some(new_model);
}
}
@ -302,7 +309,7 @@ impl AssistantSettingsContent {
}
}
provider => match new_model {
LanguageModel::ZedDotDev(model) => {
LanguageModel::Cloud(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
@ -613,7 +620,7 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: ZedDotDevModel::Custom("custom".into())
model: CloudModel::Custom("custom".into())
}
);
}