Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
1bfea9d443
commit
99bc90a372
32 changed files with 478 additions and 691 deletions
|
|
@ -4,11 +4,12 @@ use crate::{
|
|||
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
|
||||
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||
},
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
|
||||
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderState,
|
||||
};
|
||||
use client::Client;
|
||||
use collections::BTreeMap;
|
||||
use gpui::{AppContext, Global, Model, ModelContext};
|
||||
use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
|
||||
use std::sync::Arc;
|
||||
use ui::Context;
|
||||
|
||||
|
|
@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
active_model: Option<ActiveModel>,
|
||||
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||
}
|
||||
|
||||
pub struct ActiveModel {
|
||||
provider: Arc<dyn LanguageModelProvider>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
}
|
||||
|
||||
pub struct ActiveModelChanged;
|
||||
|
||||
impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
|
||||
|
||||
impl LanguageModelRegistry {
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
cx.global::<GlobalLanguageModelRegistry>().0.clone()
|
||||
|
|
@ -88,6 +99,8 @@ impl LanguageModelRegistry {
|
|||
let registry = cx.new_model(|cx| {
|
||||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
let model = fake_provider.provided_models(cx)[0].clone();
|
||||
registry.set_active_model(Some(model), cx);
|
||||
registry
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
|
|
@ -136,6 +149,64 @@ impl LanguageModelRegistry {
|
|||
) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
|
||||
pub fn select_active_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(&provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_active_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_provider(
|
||||
&mut self,
|
||||
provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
self.active_model = provider.map(|provider| ActiveModel {
|
||||
provider,
|
||||
model: None,
|
||||
});
|
||||
cx.emit(ActiveModelChanged);
|
||||
}
|
||||
|
||||
pub fn set_active_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.active_model = Some(ActiveModel {
|
||||
provider,
|
||||
model: Some(model),
|
||||
});
|
||||
cx.emit(ActiveModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.active_model = None;
|
||||
cx.emit(ActiveModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
Some(self.active_model.as_ref()?.provider.clone())
|
||||
}
|
||||
|
||||
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.active_model.as_ref()?.model.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue