assistant: Fix issues when configuring different providers (#15072)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
ba6c36f370
commit
af4b9805c9
16 changed files with 225 additions and 148 deletions
|
@ -1,15 +1,15 @@
|
|||
use super::open_ai::count_open_ai_tokens;
|
||||
use crate::{
|
||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use collections::HashMap;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
|
@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
|
|||
|
||||
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
|
||||
|
||||
pub const PROVIDER_ID: &str = "zed.dev";
|
||||
pub const PROVIDER_NAME: &str = "zed.dev";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
|
@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
|
|||
struct State {
|
||||
client: Arc<Client>,
|
||||
status: client::Status,
|
||||
settings: ZedDotDevSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
|
@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
|
|||
let state = cx.new_model(|cx| State {
|
||||
client: client.clone(),
|
||||
status,
|
||||
settings: ZedDotDevSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = HashMap::default();
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
// Add base models from CloudModel::iter()
|
||||
for model in CloudModel::iter() {
|
||||
|
@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &self.state.read(cx).settings.available_models {
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.zed_dot_dev
|
||||
.available_models
|
||||
{
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
|
@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
|
|||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
|
||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
count_anthropic_tokens(request, cx)
|
||||
}
|
||||
_ => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue