Add language_models crate to house language model providers (#20945)

This PR adds a new `language_models` crate to house the various language
model providers.

By extracting the provider definitions out of `language_model`, we're
able to remove `language_model`'s dependency on `editor`, which improves
incremental compilation when changing `editor`.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-20 18:49:34 -05:00 committed by GitHub
parent 335b112abd
commit cbba44900d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 265 additions and 199 deletions

View file

@ -1,76 +1,17 @@
use crate::provider::cloud::RefreshLlmTokenListener;
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use client::{Client, UserStore};
use collections::BTreeMap;
use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
use std::sync::Arc;
use ui::Context;
pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
let registry = cx.new_model(|cx| {
let mut registry = LanguageModelRegistry::default();
register_language_model_providers(&mut registry, user_store, client, cx);
registry
});
pub fn init(cx: &mut AppContext) {
let registry = cx.new_model(|_cx| LanguageModelRegistry::default());
cx.set_global(GlobalLanguageModelRegistry(registry));
}
fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Model<UserStore>,
client: Arc<Client>,
cx: &mut ModelContext<LanguageModelRegistry>,
) {
use feature_flags::FeatureFlagAppExt;
RefreshLlmTokenListener::register(client.clone(), cx);
registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OpenAiLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
OllamaLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
GoogleLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
let user_store = user_store.clone();
let client = client.clone();
LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
if enabled {
registry.register_provider(
CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
cx,
);
} else {
registry.unregister_provider(
LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
cx,
);
}
});
})
.detach();
}
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
impl Global for GlobalLanguageModelRegistry {}
@ -106,8 +47,8 @@ impl LanguageModelRegistry {
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
let registry = cx.new_model(|cx| {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
@ -148,7 +89,7 @@ impl LanguageModelRegistry {
}
pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
let zed_provider_id = LanguageModelProviderId("zed.dev".into());
let mut providers = Vec::with_capacity(self.providers.len());
if let Some(provider) = self.providers.get(&zed_provider_id) {
providers.push(provider.clone());
@ -269,7 +210,7 @@ impl LanguageModelRegistry {
#[cfg(test)]
mod tests {
use super::*;
use crate::provider::fake::FakeLanguageModelProvider;
use crate::fake_provider::FakeLanguageModelProvider;
#[gpui::test]
fn test_register_providers(cx: &mut AppContext) {
@ -281,10 +222,10 @@ mod tests {
let providers = registry.read(cx).providers();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].id(), crate::provider::fake::provider_id());
assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
registry.update(cx, |registry, cx| {
registry.unregister_provider(crate::provider::fake::provider_id(), cx);
registry.unregister_provider(crate::fake_provider::provider_id(), cx);
});
let providers = registry.read(cx).providers();