assistant: Fix issues when configuring different providers (#15072)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2024-07-24 11:21:31 +02:00 committed by GitHub
parent ba6c36f370
commit af4b9805c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 225 additions and 148 deletions

View file

@ -25,6 +25,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync {
}
pub trait LanguageModelProvider: 'static {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelName(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelProviderId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelProviderName(pub SharedString);
@ -77,6 +83,12 @@ impl From<String> for LanguageModelName {
}
}
impl From<String> for LanguageModelProviderId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
}
}
impl From<String> for LanguageModelProviderName {
fn from(value: String) -> Self {
Self(SharedString::from(value))

View file

@ -1,6 +1,5 @@
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
use collections::HashMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{
@ -9,7 +8,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@ -17,11 +16,12 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelRequestMessage, Role,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
const PROVIDER_NAME: &str = "anthropic";
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider {
struct State {
api_key: Option<String>,
settings: AnthropicSettings,
_subscription: Subscription,
}
@ -45,9 +44,7 @@ impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
settings: AnthropicSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
}
impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from anthropic::Model::iter()
for model in anthropic::Model::iter() {
@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in AllLanguageModelSettings::get_global(cx)
.anthropic
.available_models
.iter()
{
models.insert(model.id().to_string(), model.clone());
}
@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
let api_url = self.state.read(cx).settings.api_url.clone();
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel {
let request = self.to_anthropic_request(request);
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
state.settings.api_url.clone(),
state.settings.low_speed_timeout,
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@ -365,7 +380,10 @@ impl AuthenticationPrompt {
}
let write_credentials = cx.write_credentials(
&self.state.read(cx).settings.api_url,
AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.as_str(),
"Bearer",
api_key.as_bytes(),
);

View file

@ -1,15 +1,15 @@
use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use collections::HashMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::{collections::BTreeMap, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "zed.dev";
#[derive(Default, Clone, Debug, PartialEq)]
@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
struct State {
client: Arc<Client>,
status: client::Status,
settings: ZedDotDevSettings,
_subscription: Subscription,
}
@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
let state = cx.new_model(|cx| State {
client: client.clone(),
status,
settings: ZedDotDevSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
}
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from CloudModel::iter()
for model in CloudModel::iter() {
@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in &AllLanguageModelSettings::get_global(cx)
.zed_dot_dev
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
count_anthropic_tokens(request, cx)
}
_ => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model: self.model.id().to_string(),

View file

@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName {
LanguageModelName::from("Fake".to_string())
}
pub fn provider_id() -> LanguageModelProviderId {
LanguageModelProviderId::from("fake".to_string())
}
pub fn provider_name() -> LanguageModelProviderName {
LanguageModelProviderName::from("fake".to_string())
LanguageModelProviderName::from("Fake".to_string())
}
#[derive(Clone, Default)]
@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
}
impl LanguageModelProvider for FakeLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
provider_id()
}
fn name(&self) -> LanguageModelProviderName {
provider_name()
}
@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel {
language_model_name()
}
fn provider_id(&self) -> LanguageModelProviderId {
provider_id()
}
fn provider_name(&self) -> LanguageModelProviderName {
provider_name()
}

View file

@ -2,21 +2,24 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, Role,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const PROVIDER_NAME: &str = "ollama";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider {
struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
settings: OllamaSettings,
_subscription: Subscription,
}
impl State {
fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = self.settings.api_url.clone();
let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|this, mut cx| async move {
@ -66,23 +69,25 @@ impl State {
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new_model(|cx| State {
http_client,
available_models: Default::default(),
settings: OllamaSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
this.fetch_models(cx).detach_and_log_err(cx);
cx.notify();
}),
}),
}
};
this.fetch_models(cx).detach_and_log_err(cx);
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = self.state.read(cx).settings.api_url.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
state: self.state.clone(),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
.detach_and_log_err(cx);
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
!self.state.read(cx).available_models.is_empty()
}
@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
}
@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel {
format!("ollama/{}", self.model.id())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn count_tokens(
&self,
request: LanguageModelRequest,
@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
(
state.settings.api_url.clone(),
state.settings.low_speed_timeout,
)
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use collections::HashMap;
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::{
@ -17,11 +17,12 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, Role,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_NAME: &str = "openai";
const PROVIDER_ID: &str = "openai";
const PROVIDER_NAME: &str = "OpenAI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider {
struct State {
api_key: Option<String>,
settings: OpenAiSettings,
_subscription: Subscription,
}
@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
settings: OpenAiSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
});
@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
}
impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = HashMap::default();
let mut models = BTreeMap::default();
// Add base models from open_ai::Model::iter()
for model in open_ai::Model::iter() {
@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
// Override with available models from settings
for model in &self.state.read(cx).settings.available_models {
for model in &AllLanguageModelSettings::get_global(cx)
.openai
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
let api_url = self.state.read(cx).settings.api_url.clone();
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone();
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
state.settings.api_url.clone(),
state.settings.low_speed_timeout,
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@ -307,11 +321,9 @@ impl AuthenticationPrompt {
return;
}
let write_credentials = cx.write_credentials(
&self.state.read(cx).settings.api_url,
"Bearer",
api_key.as_bytes(),
);
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let write_credentials =
cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
let state = self.state.clone();
cx.spawn(|_, mut cx| async move {
write_credentials.await?;

View file

@ -9,7 +9,7 @@ use crate::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@ -48,7 +48,7 @@ fn register_language_model_providers(
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
} else {
registry.unregister_provider(
&LanguageModelProviderName::from(
&LanguageModelProviderId::from(
crate::provider::cloud::PROVIDER_NAME.to_string(),
),
cx,
@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
impl LanguageModelRegistry {
@ -94,7 +94,7 @@ impl LanguageModelRegistry {
provider: T,
cx: &mut ModelContext<Self>,
) {
let name = provider.name();
let name = provider.id();
if let Some(subscription) = provider.subscribe(cx) {
subscription.detach();
@ -106,7 +106,7 @@ impl LanguageModelRegistry {
pub fn unregister_provider(
&mut self,
name: &LanguageModelProviderName,
name: &LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
if self.providers.remove(name).is_some() {
@ -116,7 +116,7 @@ impl LanguageModelRegistry {
pub fn providers(
&self,
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
self.providers.iter()
}
@ -130,7 +130,7 @@ impl LanguageModelRegistry {
pub fn available_models_grouped_by_provider(
&self,
cx: &AppContext,
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
self.providers
.iter()
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
@ -139,7 +139,7 @@ impl LanguageModelRegistry {
pub fn provider(
&self,
name: &LanguageModelProviderName,
name: &LanguageModelProviderId,
) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned()
}
@ -160,10 +160,10 @@ mod tests {
let providers = registry.read(cx).providers().collect::<Vec<_>>();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
registry.update(cx, |registry, cx| {
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
});
let providers = registry.read(cx).providers().collect::<Vec<_>>();

View file

@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) {
#[derive(Default)]
pub struct AllLanguageModelSettings {
pub open_ai: OpenAiSettings,
pub anthropic: AnthropicSettings,
pub ollama: OllamaSettings,
pub openai: OpenAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
}
@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings {
pub struct AllLanguageModelSettingsContent {
pub anthropic: Option<AnthropicSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
pub open_ai: Option<OpenAiSettingsContent>,
pub openai: Option<OpenAiSettingsContent>,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
}
@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings {
}
merge(
&mut settings.open_ai.api_url,
value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
&mut settings.openai.api_url,
value.openai.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
.open_ai
.openai
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
settings.open_ai.low_speed_timeout =
settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.open_ai.available_models,
&mut settings.openai.available_models,
value
.open_ai
.openai
.as_ref()
.and_then(|s| s.available_models.clone()),
);