Simplify language model registry + only emit change events on change (#29086)
* Now only does default fallback logic in the registry * Only emits change events when there is actually a change Release Notes: - N/A
This commit is contained in:
parent
98ceffe026
commit
d88b06a5dc
5 changed files with 119 additions and 194 deletions
|
@ -27,7 +27,6 @@ gpui.workspace = true
|
|||
http_client.workspace = true
|
||||
icons.workspace = true
|
||||
image.workspace = true
|
||||
log.workspace = true
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
parking_lot.workspace = true
|
||||
proto.workspace = true
|
||||
|
|
|
@ -25,12 +25,23 @@ pub struct LanguageModelRegistry {
|
|||
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
|
||||
}
|
||||
|
||||
pub struct SelectedModel {
|
||||
pub provider: LanguageModelProviderId,
|
||||
pub model: LanguageModelId,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConfiguredModel {
|
||||
pub provider: Arc<dyn LanguageModelProvider>,
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
}
|
||||
|
||||
impl ConfiguredModel {
|
||||
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
|
||||
self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
DefaultModelChanged,
|
||||
InlineAssistantModelChanged,
|
||||
|
@ -59,7 +70,11 @@ impl LanguageModelRegistry {
|
|||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
let model = fake_provider.provided_models(cx)[0].clone();
|
||||
registry.set_default_model(Some(model), cx);
|
||||
let configured_model = ConfiguredModel {
|
||||
provider: Arc::new(fake_provider.clone()),
|
||||
model,
|
||||
};
|
||||
registry.set_default_model(Some(configured_model), cx);
|
||||
registry
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
|
@ -119,144 +134,114 @@ impl LanguageModelRegistry {
|
|||
self.providers.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn select_default_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_default_model(Some(model), cx);
|
||||
}
|
||||
pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_default_model(configured_model, cx);
|
||||
}
|
||||
|
||||
pub fn select_inline_assistant_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
model: Option<&SelectedModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_inline_assistant_model(Some(model), cx);
|
||||
}
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_inline_assistant_model(configured_model, cx);
|
||||
}
|
||||
|
||||
pub fn select_commit_message_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
model: Option<&SelectedModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_commit_message_model(Some(model), cx);
|
||||
}
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_commit_message_model(configured_model, cx);
|
||||
}
|
||||
|
||||
pub fn select_thread_summary_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
model: Option<&SelectedModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_thread_summary_model(Some(model), cx);
|
||||
}
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_thread_summary_model(configured_model, cx);
|
||||
}
|
||||
|
||||
pub fn set_default_model(
|
||||
/// Selects and sets the inline alternatives for language models based on
|
||||
/// provider name and id.
|
||||
pub fn select_inline_alternative_models(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
alternatives: impl IntoIterator<Item = SelectedModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.default_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::DefaultModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.default_model = None;
|
||||
cx.emit(Event::DefaultModelChanged);
|
||||
self.inline_alternatives = alternatives
|
||||
.into_iter()
|
||||
.flat_map(|alternative| {
|
||||
self.select_model(&alternative, cx)
|
||||
.map(|configured_model| configured_model.model)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&mut self,
|
||||
selected_model: &SelectedModel,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<ConfiguredModel> {
|
||||
let provider = self.provider(&selected_model.provider)?;
|
||||
let model = provider
|
||||
.provided_models(cx)
|
||||
.iter()
|
||||
.find(|model| model.id() == selected_model.model)?
|
||||
.clone();
|
||||
Some(ConfiguredModel { provider, model })
|
||||
}
|
||||
|
||||
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
|
||||
match (self.default_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::DefaultModelChanged),
|
||||
}
|
||||
self.default_model = model;
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
model: Option<ConfiguredModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.inline_assistant_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::InlineAssistantModelChanged);
|
||||
} else {
|
||||
log::warn!("Inline assistant model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.inline_assistant_model = None;
|
||||
cx.emit(Event::InlineAssistantModelChanged);
|
||||
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::InlineAssistantModelChanged),
|
||||
}
|
||||
self.inline_assistant_model = model;
|
||||
}
|
||||
|
||||
pub fn set_commit_message_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
model: Option<ConfiguredModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.commit_message_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::CommitMessageModelChanged);
|
||||
} else {
|
||||
log::warn!("Commit message model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.commit_message_model = None;
|
||||
cx.emit(Event::CommitMessageModelChanged);
|
||||
match (self.commit_message_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::CommitMessageModelChanged),
|
||||
}
|
||||
self.commit_message_model = model;
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
model: Option<ConfiguredModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.thread_summary_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::ThreadSummaryModelChanged);
|
||||
} else {
|
||||
log::warn!("Thread summary model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.thread_summary_model = None;
|
||||
cx.emit(Event::ThreadSummaryModelChanged);
|
||||
match (self.thread_summary_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::ThreadSummaryModelChanged),
|
||||
}
|
||||
self.thread_summary_model = model;
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> Option<ConfiguredModel> {
|
||||
|
@ -286,30 +271,6 @@ impl LanguageModelRegistry {
|
|||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
/// Selects and sets the inline alternatives for language models based on
|
||||
/// provider name and id.
|
||||
pub fn select_inline_alternative_models(
|
||||
&mut self,
|
||||
alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let mut selected_alternatives = Vec::new();
|
||||
|
||||
for (provider_id, model_id) in alternatives {
|
||||
if let Some(provider) = self.providers.get(&provider_id) {
|
||||
if let Some(model) = provider
|
||||
.provided_models(cx)
|
||||
.iter()
|
||||
.find(|m| m.id() == model_id)
|
||||
{
|
||||
selected_alternatives.push(model.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.inline_alternatives = selected_alternatives;
|
||||
}
|
||||
|
||||
/// The models to use for inline assists. Returns the union of the active
|
||||
/// model and all inline alternatives. When there are multiple models, the
|
||||
/// user will be able to cycle through results.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue