Add language_models
crate to house language model providers (#20945)
This PR adds a new `language_models` crate to house the various language model providers. By extracting the provider definitions out of `language_model`, we're able to remove `language_model`'s dependency on `editor`, which improves incremental compilation when changing `editor`. Release Notes: - N/A
This commit is contained in:
parent
335b112abd
commit
cbba44900d
27 changed files with 265 additions and 199 deletions
|
@ -1,76 +1,17 @@
|
|||
use crate::provider::cloud::RefreshLlmTokenListener;
|
||||
use crate::{
|
||||
provider::{
|
||||
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
|
||||
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||
},
|
||||
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderState,
|
||||
};
|
||||
use client::{Client, UserStore};
|
||||
use collections::BTreeMap;
|
||||
use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
|
||||
use std::sync::Arc;
|
||||
use ui::Context;
|
||||
|
||||
pub fn init(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) {
|
||||
let registry = cx.new_model(|cx| {
|
||||
let mut registry = LanguageModelRegistry::default();
|
||||
register_language_model_providers(&mut registry, user_store, client, cx);
|
||||
registry
|
||||
});
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
let registry = cx.new_model(|_cx| LanguageModelRegistry::default());
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
}
|
||||
|
||||
fn register_language_model_providers(
|
||||
registry: &mut LanguageModelRegistry,
|
||||
user_store: Model<UserStore>,
|
||||
client: Arc<Client>,
|
||||
cx: &mut ModelContext<LanguageModelRegistry>,
|
||||
) {
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
|
||||
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
|
||||
registry.register_provider(
|
||||
AnthropicLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
OpenAiLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
OllamaLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
GoogleLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
|
||||
|
||||
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
|
||||
let user_store = user_store.clone();
|
||||
let client = client.clone();
|
||||
LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
|
||||
if enabled {
|
||||
registry.register_provider(
|
||||
CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
|
||||
cx,
|
||||
);
|
||||
} else {
|
||||
registry.unregister_provider(
|
||||
LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
|
||||
|
||||
impl Global for GlobalLanguageModelRegistry {}
|
||||
|
@ -106,8 +47,8 @@ impl LanguageModelRegistry {
|
|||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
|
||||
let fake_provider = crate::provider::fake::FakeLanguageModelProvider;
|
||||
pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
|
||||
let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
|
||||
let registry = cx.new_model(|cx| {
|
||||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
|
@ -148,7 +89,7 @@ impl LanguageModelRegistry {
|
|||
}
|
||||
|
||||
pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
|
||||
let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
|
||||
let zed_provider_id = LanguageModelProviderId("zed.dev".into());
|
||||
let mut providers = Vec::with_capacity(self.providers.len());
|
||||
if let Some(provider) = self.providers.get(&zed_provider_id) {
|
||||
providers.push(provider.clone());
|
||||
|
@ -269,7 +210,7 @@ impl LanguageModelRegistry {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::fake::FakeLanguageModelProvider;
|
||||
use crate::fake_provider::FakeLanguageModelProvider;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_register_providers(cx: &mut AppContext) {
|
||||
|
@ -281,10 +222,10 @@ mod tests {
|
|||
|
||||
let providers = registry.read(cx).providers();
|
||||
assert_eq!(providers.len(), 1);
|
||||
assert_eq!(providers[0].id(), crate::provider::fake::provider_id());
|
||||
assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.unregister_provider(crate::provider::fake::provider_id(), cx);
|
||||
registry.unregister_provider(crate::fake_provider::provider_id(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue