Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -4,11 +4,12 @@ use crate::{
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use client::Client;
use collections::BTreeMap;
use gpui::{AppContext, Global, Model, ModelContext};
use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
use std::sync::Arc;
use ui::Context;
@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
active_model: Option<ActiveModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
pub struct ActiveModel {
provider: Arc<dyn LanguageModelProvider>,
model: Option<Arc<dyn LanguageModel>>,
}
pub struct ActiveModelChanged;
impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
impl LanguageModelRegistry {
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalLanguageModelRegistry>().0.clone()
@ -88,6 +99,8 @@ impl LanguageModelRegistry {
let registry = cx.new_model(|cx| {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
registry.set_active_model(Some(model), cx);
registry
});
cx.set_global(GlobalLanguageModelRegistry(registry));
@ -136,6 +149,64 @@ impl LanguageModelRegistry {
) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned()
}
pub fn select_active_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut ModelContext<Self>,
) {
let Some(provider) = self.provider(&provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_active_model(Some(model), cx);
}
}
pub fn set_active_provider(
&mut self,
provider: Option<Arc<dyn LanguageModelProvider>>,
cx: &mut ModelContext<Self>,
) {
self.active_model = provider.map(|provider| ActiveModel {
provider,
model: None,
});
cx.emit(ActiveModelChanged);
}
pub fn set_active_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut ModelContext<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.active_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(ActiveModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
cx.emit(ActiveModelChanged);
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
Some(self.active_model.as_ref()?.provider.clone())
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.as_ref()?.model.clone()
}
}
#[cfg(test)]