ZIm/crates/language_model/src/registry.rs
Marshall Bowers 937186da12
gpui: Don't export named Context from prelude (#21869)
This PR updates the `gpui::prelude` so that it no longer exports the `Context` trait
by name.

This prevents some naming clashes in downstream consumers.

Release Notes:

- N/A
2024-12-11 13:21:40 -05:00

233 lines
7.4 KiB
Rust

use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use collections::BTreeMap;
use gpui::{prelude::*, AppContext, EventEmitter, Global, Model, ModelContext};
use std::sync::Arc;
/// Creates an empty [`LanguageModelRegistry`] model and stores it as the
/// application-wide global registry.
pub fn init(cx: &mut AppContext) {
    let global = GlobalLanguageModelRegistry(cx.new_model(|_| LanguageModelRegistry::default()));
    cx.set_global(global);
}
/// Newtype that lets the registry model be stored as a gpui global
/// (see `init` and `LanguageModelRegistry::global`).
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
// Marker impl so the wrapper can be passed to `set_global` / `global`.
impl Global for GlobalLanguageModelRegistry {}
/// Registry of language model providers, together with the currently active
/// model and the models chosen as inline-assist alternatives.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// The active provider and (optionally) model; `None` when nothing is active.
    active_model: Option<ActiveModel>,
    /// Registered providers, keyed by the id each provider reported when it was
    /// registered.
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Models selected via `select_inline_alternative_models`.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
/// The active provider together with the model selected from it, if any.
pub struct ActiveModel {
    /// Provider the active model belongs to.
    provider: Arc<dyn LanguageModelProvider>,
    /// `None` when a provider was activated via `set_active_provider` without
    /// choosing a specific model.
    model: Option<Arc<dyn LanguageModel>>,
}
/// Events emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// The active provider and/or model was set or cleared.
    ActiveModelChanged,
    /// A registered provider's state subscription fired.
    ProviderStateChanged,
    /// A provider was registered under the given id.
    AddedProvider(LanguageModelProviderId),
    /// A previously registered provider with the given id was removed.
    RemovedProvider(LanguageModelProviderId),
}
// Lets subscribers observe the `Event`s above on the registry model.
impl EventEmitter<Event> for LanguageModelRegistry {}
impl LanguageModelRegistry {
    /// Returns a handle to the global registry model.
    ///
    /// The global is installed by [`init`] (or by [`Self::test`] in tests).
    pub fn global(cx: &AppContext) -> Model<Self> {
        cx.global::<GlobalLanguageModelRegistry>().0.clone()
    }

    /// Returns a shared reference to the global registry's current state.
    pub fn read_global(cx: &AppContext) -> &Self {
        cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
    }

    /// Installs a global registry containing only the fake provider, makes the
    /// fake provider's first model active, and returns the fake provider.
    ///
    /// # Panics
    ///
    /// Panics if the fake provider reports no models (the first model is
    /// obtained by indexing).
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(cx: &mut AppContext) -> crate::fake_provider::FakeLanguageModelProvider {
        let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
        let registry = cx.new_model(|cx| {
            let mut registry = Self::default();
            registry.register_provider(fake_provider.clone(), cx);
            let model = fake_provider.provided_models(cx)[0].clone();
            registry.set_active_model(Some(model), cx);
            registry
        });
        cx.set_global(GlobalLanguageModelRegistry(registry));
        fake_provider
    }

    /// Adds `provider` under its own id and emits [`Event::AddedProvider`].
    ///
    /// If the provider exposes a state subscription, each notification is
    /// re-emitted as [`Event::ProviderStateChanged`]. That subscription is
    /// detached, and [`Self::unregister_provider`] does not cancel it.
    ///
    /// Registering a provider whose id is already present replaces the previous
    /// entry. `AddedProvider` is still emitted in that case.
    pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
        &mut self,
        provider: T,
        cx: &mut ModelContext<Self>,
    ) {
        let id = provider.id();

        let subscription = provider.subscribe(cx, |_, cx| {
            cx.emit(Event::ProviderStateChanged);
        });
        if let Some(subscription) = subscription {
            subscription.detach();
        }

        self.providers.insert(id.clone(), Arc::new(provider));
        cx.emit(Event::AddedProvider(id));
    }

    /// Removes the provider registered under `id`.
    ///
    /// [`Event::RemovedProvider`] is emitted only if such a provider was present.
    ///
    /// NOTE(review): the active model and the inline alternatives are not
    /// cleared here, so they may keep referring to the removed provider.
    pub fn unregister_provider(
        &mut self,
        id: LanguageModelProviderId,
        cx: &mut ModelContext<Self>,
    ) {
        if self.providers.remove(&id).is_some() {
            cx.emit(Event::RemovedProvider(id));
        }
    }

    /// Returns all registered providers.
    ///
    /// The `zed.dev` provider comes first, if it is registered. The remaining
    /// providers follow in the map's key order.
    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
        let zed_provider_id = LanguageModelProviderId("zed.dev".into());

        let mut providers = Vec::with_capacity(self.providers.len());
        if let Some(zed_provider) = self.providers.get(&zed_provider_id) {
            providers.push(zed_provider.clone());
        }
        providers.extend(
            self.providers
                .values()
                .filter(|provider| provider.id() != zed_provider_id)
                .cloned(),
        );
        providers
    }

    /// Iterates over every model offered by every registered provider.
    ///
    /// Providers are visited in the map's key order.
    pub fn available_models<'a>(
        &'a self,
        cx: &'a AppContext,
    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
        self.providers
            .values()
            .flat_map(|provider| provider.provided_models(cx))
    }

    /// Looks up a registered provider by id.
    pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
        self.providers.get(id).cloned()
    }

    /// Returns the model with id `model_id` among those offered by `provider`.
    fn find_provided_model(
        provider: &Arc<dyn LanguageModelProvider>,
        model_id: &LanguageModelId,
        cx: &AppContext,
    ) -> Option<Arc<dyn LanguageModel>> {
        provider
            .provided_models(cx)
            .iter()
            .find(|model| &model.id() == model_id)
            .cloned()
    }

    /// Makes the model `model_id` offered by `provider` the active model.
    ///
    /// Does nothing in either of these cases:
    /// - the provider is not registered;
    /// - the provider does not offer a model with that id.
    pub fn select_active_model(
        &mut self,
        provider: &LanguageModelProviderId,
        model_id: &LanguageModelId,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(provider) = self.provider(provider) else {
            return;
        };
        if let Some(model) = Self::find_provided_model(&provider, model_id, cx) {
            self.set_active_model(Some(model), cx);
        }
    }

    /// Sets the active provider without selecting a model from it.
    ///
    /// Passing `None` clears the active provider and model. Any previously
    /// active model is discarded. Always emits [`Event::ActiveModelChanged`].
    pub fn set_active_provider(
        &mut self,
        provider: Option<Arc<dyn LanguageModelProvider>>,
        cx: &mut ModelContext<Self>,
    ) {
        self.active_model = provider.map(|provider| ActiveModel {
            provider,
            model: None,
        });
        cx.emit(Event::ActiveModelChanged);
    }

    /// Sets the active model, and with it the active provider.
    ///
    /// The provider is looked up from the model's provider id.
    ///
    /// - `None` clears the active model and emits [`Event::ActiveModelChanged`].
    /// - If the model's provider is not registered, a warning is logged. The
    ///   active model is then left unchanged and no event is emitted.
    pub fn set_active_model(
        &mut self,
        model: Option<Arc<dyn LanguageModel>>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(model) = model else {
            self.active_model = None;
            cx.emit(Event::ActiveModelChanged);
            return;
        };

        let Some(provider) = self.providers.get(&model.provider_id()).cloned() else {
            log::warn!("Active model's provider not found in registry");
            return;
        };

        self.active_model = Some(ActiveModel {
            provider,
            model: Some(model),
        });
        cx.emit(Event::ActiveModelChanged);
    }

    /// Returns the active provider, if any.
    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
        Some(self.active_model.as_ref()?.provider.clone())
    }

    /// Returns the active model, if any.
    ///
    /// This is `None` when no provider is active. It is also `None` when a
    /// provider was activated without choosing a model.
    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.as_ref()?.model.clone()
    }

    /// Replaces the inline alternative models with the models identified by
    /// `alternatives`, given as `(provider id, model id)` pairs.
    ///
    /// A pair is skipped if its provider is not registered, or if that provider
    /// does not offer the model. Order is preserved for the pairs that resolve.
    /// No event is emitted.
    pub fn select_inline_alternative_models(
        &mut self,
        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
        cx: &mut ModelContext<Self>,
    ) {
        let selected: Vec<_> = alternatives
            .into_iter()
            .filter_map(|(provider_id, model_id)| {
                let provider = self.providers.get(&provider_id)?;
                Self::find_provided_model(provider, &model_id, cx)
            })
            .collect();
        self.inline_alternatives = selected;
    }

    /// Returns the inline-assist alternative models.
    ///
    /// These are the models last set by
    /// [`Self::select_inline_alternative_models`]. The active model is NOT
    /// included; only the inline alternatives are returned. See
    /// [`Self::active_model`] for the active model.
    pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
        &self.inline_alternatives
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::{provider_id, FakeLanguageModelProvider};

    /// Registering and then unregistering the fake provider round-trips the
    /// registry back to having no providers.
    #[gpui::test]
    fn test_register_providers(cx: &mut AppContext) {
        let registry = cx.new_model(|_| LanguageModelRegistry::default());

        // After registration the fake provider is the only one visible.
        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx)
        });
        let registered = registry.read(cx).providers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id(), provider_id());

        // Removing it by id leaves the registry empty again.
        registry.update(cx, |registry, cx| {
            registry.unregister_provider(provider_id(), cx)
        });
        assert!(registry.read(cx).providers().is_empty());
    }
}