assistant: Fix issues when configuring different providers (#15072)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2024-07-24 11:21:31 +02:00 committed by GitHub
parent ba6c36f370
commit af4b9805c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 225 additions and 148 deletions

View file

@ -1,15 +1,15 @@
use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use collections::HashMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::{collections::BTreeMap, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "zed.dev";
#[derive(Default, Clone, Debug, PartialEq)]
@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
struct State {
client: Arc<Client>,
status: client::Status,
settings: ZedDotDevSettings,
_subscription: Subscription,
}
@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
let state = cx.new_model(|cx| State {
client: client.clone(),
status,
settings: ZedDotDevSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
}
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from CloudModel::iter()
for model in CloudModel::iter() {
@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in &AllLanguageModelSettings::get_global(cx)
.zed_dot_dev
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
count_anthropic_tokens(request, cx)
}
_ => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model: self.model.id().to_string(),