ai: Separate model settings for each feature (#28088)
Closes: https://github.com/zed-industries/zed/issues/20582 Allows users to select a specific model for each AI-powered feature: - Agent panel - Inline assistant - Thread summarization - Commit message generation If unspecified for a given feature, it will use the `default_model` setting. Release Notes: - Added support for configuring a specific model for each AI-powered feature --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent
cf0d1e4229
commit
43cb925a59
27 changed files with 670 additions and 381 deletions
|
@ -1272,7 +1272,7 @@ impl AssistantContext {
|
|||
// Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
|
||||
// because otherwise you see in the UI that your empty message has a bunch of tokens already used.
|
||||
let request = self.to_completion_request(RequestType::Chat, cx);
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
let debounce = self.token_count.is_some();
|
||||
|
@ -1284,10 +1284,12 @@ impl AssistantContext {
|
|||
.await;
|
||||
}
|
||||
|
||||
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||
let token_count = cx
|
||||
.update(|cx| model.model.count_tokens(request, cx))?
|
||||
.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
this.start_cache_warming(&model, cx);
|
||||
this.start_cache_warming(&model.model, cx);
|
||||
cx.notify()
|
||||
})
|
||||
}
|
||||
|
@ -2304,14 +2306,16 @@ impl AssistantContext {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Option<MessageAnchor> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let provider = model_registry.active_provider()?;
|
||||
let model = model_registry.active_model()?;
|
||||
let model = model_registry.default_model()?;
|
||||
let last_message_id = self.get_last_valid_message_id(cx)?;
|
||||
|
||||
if !provider.is_authenticated(cx) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
log::info!("completion provider has no credentials");
|
||||
return None;
|
||||
}
|
||||
|
||||
let model = model.model;
|
||||
|
||||
// Compute which messages to cache, including the last one.
|
||||
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
||||
|
||||
|
@ -2940,15 +2944,12 @@ impl AssistantContext {
|
|||
}
|
||||
|
||||
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
return;
|
||||
};
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2964,7 +2965,7 @@ impl AssistantContext {
|
|||
|
||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let stream = model.stream_completion_text(request, &cx);
|
||||
let stream = model.model.stream_completion_text(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
let mut replaced = !replace_old;
|
||||
|
|
|
@ -384,7 +384,9 @@ impl ContextEditor {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
if provider
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx))
|
||||
|
@ -2395,13 +2397,13 @@ impl ContextEditor {
|
|||
None => (ButtonStyle::Filled, None),
|
||||
};
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
&& provider
|
||||
&& model
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx));
|
||||
.map_or(false, |model| model.provider.must_accept_terms(cx));
|
||||
let disabled = has_configuration_error || needs_to_accept_terms;
|
||||
|
||||
ButtonLike::new("send_button")
|
||||
|
@ -2454,7 +2456,9 @@ impl ContextEditor {
|
|||
None => (ButtonStyle::Filled, None),
|
||||
};
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
|
@ -2500,7 +2504,9 @@ impl ContextEditor {
|
|||
}
|
||||
|
||||
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let active_model = LanguageModelRegistry::read_global(cx).active_model();
|
||||
let active_model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.model);
|
||||
let focus_handle = self.editor().focus_handle(cx).clone();
|
||||
let model_name = match active_model {
|
||||
Some(model) => model.name().0,
|
||||
|
@ -3020,7 +3026,9 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
|
|||
|
||||
impl Render for ContextEditor {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
let accept_terms = if self.show_accept_terms {
|
||||
provider.as_ref().and_then(|provider| {
|
||||
provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
|
||||
|
@ -3616,7 +3624,9 @@ enum TokenState {
|
|||
fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
|
||||
const WARNING_TOKEN_THRESHOLD: f32 = 0.8;
|
||||
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()?
|
||||
.model;
|
||||
let token_count = context.read(cx).token_count()?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -3669,16 +3679,16 @@ pub enum ConfigurationError {
|
|||
}
|
||||
|
||||
fn configuration_error(cx: &App) -> Option<ConfigurationError> {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let is_authenticated = provider
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let is_authenticated = model
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.is_authenticated(cx));
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx));
|
||||
|
||||
if provider.is_some() && is_authenticated {
|
||||
if model.is_some() && is_authenticated {
|
||||
return None;
|
||||
}
|
||||
|
||||
if provider.is_none() {
|
||||
if model.is_none() {
|
||||
return Some(ConfigurationError::NoProvider);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue