assistant: Overhaul provider infrastructure (#14929)

<img width="624" alt="image"
src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815">

- [x] Correctly set custom model token count
- [x] How to count tokens for Gemini models?
- [x] Feature flag zed.dev provider
- [x] Figure out how to configure custom models
- [ ] Update docs

Release Notes:

- Added support for quickly switching between multiple language model
providers in the assistant panel

---------

Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-07-23 19:48:41 +02:00 committed by GitHub
parent 17ef9a367f
commit d0f52e90e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 2757 additions and 2023 deletions

View file

@ -1,6 +1,6 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
AssistantPanel, AssistantPanelEvent, Hunk, LanguageModelCompletionProvider, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry;
@ -27,7 +27,9 @@ use gpui::{
WindowContext,
};
use language::{Buffer, Point, Selection, TransactionId};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use rope::Rope;
@ -844,7 +846,10 @@ impl InlineAssistant {
}
let codegen = assist.codegen.clone();
let telemetry_id = CompletionProvider::global(cx).model().telemetry_id();
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(stream::empty().boxed()) }.boxed_local()
@ -854,7 +859,10 @@ impl InlineAssistant {
async move {
let request = request.await?;
let chunks = cx
.update(|cx| CompletionProvider::global(cx).stream_completion(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.stream_completion(request, cx)
})?
.await?;
Ok(chunks.boxed())
}
@ -871,8 +879,8 @@ impl InlineAssistant {
cx: &mut WindowContext,
) -> Task<Result<LanguageModelRequest>> {
cx.spawn(|mut cx| async move {
let (user_prompt, context_request, project_name, buffer, range, model) = cx
.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let (user_prompt, context_request, project_name, buffer, range) =
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let assist = this.assists.get(&assist_id).context("invalid assist")?;
let decorations = assist.decorations.as_ref().context("invalid assist")?;
let editor = assist.editor.upgrade().context("invalid assist")?;
@ -906,15 +914,7 @@ impl InlineAssistant {
});
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
let range = assist.codegen.read(cx).range.clone();
let model = CompletionProvider::global(cx).model();
anyhow::Ok((
user_prompt,
context_request,
project_name,
buffer,
range,
model,
))
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
})??;
let language = buffer.language_at(range.start);
@ -973,7 +973,6 @@ impl InlineAssistant {
});
Ok(LanguageModelRequest {
model,
messages,
stop: vec!["|END|>".to_string()],
temperature,
@ -1432,24 +1431,39 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models() {
for available_model in
LanguageModelRegistry::read_global(cx).available_models(cx)
{
menu = menu.custom_entry(
{
let model = model.clone();
let model_name = available_model.name().0.clone();
let provider =
available_model.provider_name().0.clone();
move |_| {
Label::new(model.display_name())
.into_any_element()
h_flex()
.w_full()
.justify_between()
.child(Label::new(model_name.clone()))
.child(
div().ml_4().child(
Label::new(provider.clone())
.color(Color::Muted),
),
)
.into_any()
}
},
{
let fs = fs.clone();
let model = model.clone();
let model = available_model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
move |settings, _| {
settings.set_model(model)
},
);
}
},
@ -1468,9 +1482,10 @@ impl Render for PromptEditor {
Tooltip::with_meta(
format!(
"Using {}",
CompletionProvider::global(cx)
.model()
.display_name()
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
@ -1668,7 +1683,9 @@ impl PromptEditor {
.await?;
let token_count = cx
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
@ -1796,7 +1813,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = CompletionProvider::global(cx).model();
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@ -2601,7 +2618,6 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
#[cfg(test)]
mod tests {
use super::*;
use completion::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
@ -2622,7 +2638,8 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(language_settings::init);
let text = indoc! {"
@ -2749,7 +2766,8 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.update(LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);