Fix issues with Claude in Assistant2 (#12619)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Mikayla Maki 2024-06-03 16:30:09 -07:00 committed by GitHub
parent afc0650a49
commit 3cd6719b30
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 276 additions and 213 deletions

View file

@ -1,14 +1,14 @@
mod anthropic;
mod cloud;
#[cfg(test)]
mod fake;
mod open_ai;
mod zed;
pub use anthropic::*;
pub use cloud::*;
#[cfg(test)]
pub use fake::*;
pub use open_ai::*;
pub use zed::*;
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
@ -25,8 +25,8 @@ use std::time::Duration;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => CompletionProvider::ZedDotDev(
ZedDotDevCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
),
AssistantProvider::OpenAi {
model,
@ -87,14 +87,11 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings_version,
);
}
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { model },
) => {
(CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
provider.update(model.clone(), settings_version);
}
(_, AssistantProvider::ZedDotDev { model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
*provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
model.clone(),
client.clone(),
settings_version,
@ -142,7 +139,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider),
Anthropic(AnthropicCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider),
Cloud(CloudCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
}
@ -164,9 +161,9 @@ impl CompletionProvider {
.available_models()
.map(LanguageModel::Anthropic)
.collect(),
CompletionProvider::ZedDotDev(provider) => provider
CompletionProvider::Cloud(provider) => provider
.available_models()
.map(LanguageModel::ZedDotDev)
.map(LanguageModel::Cloud)
.collect(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
@ -177,7 +174,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::Anthropic(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
CompletionProvider::Cloud(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@ -187,7 +184,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
CompletionProvider::Cloud(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
}
@ -197,7 +194,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
CompletionProvider::Cloud(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
@ -207,7 +204,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
@ -217,7 +214,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
CompletionProvider::Cloud(_) => Task::ready(Ok(())),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
@ -227,7 +224,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
CompletionProvider::ZedDotDev(provider) => LanguageModel::ZedDotDev(provider.model()),
CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
#[cfg(test)]
CompletionProvider::Fake(_) => LanguageModel::default(),
}
@ -241,7 +238,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
}
@ -254,7 +251,7 @@ impl CompletionProvider {
match self {
CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::Anthropic(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
CompletionProvider::Cloud(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
}