assistant: Overhaul provider infrastructure (#14929)

<img width="624" alt="image"
src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815">

- [x] Correctly set custom model token count
- [x] How to count tokens for Gemini models?
- [x] Feature flag zed.dev provider
- [x] Figure out how to configure custom models
- [ ] Update docs

Release Notes:

- Added support for quickly switching between multiple language model
providers in the assistant panel

---------

Co-authored-by: Antonio <antonio@zed.dev>
Commit d0f52e90e6 (parent 17ef9a367f) by Bennet Bo Fenner, 2024-07-23 19:48:41 +02:00, committed via GitHub
55 changed files with 2757 additions and 2023 deletions


```diff
@@ -1,7 +1,7 @@
 use crate::{
     assistant_settings::AssistantSettings, humanize_token_count,
     prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
-    CompletionProvider,
+    LanguageModelCompletionProvider,
 };
 use anyhow::{Context as _, Result};
 use client::telemetry::Telemetry;
```

```diff
@@ -17,7 +17,9 @@ use gpui::{
     Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
 };
 use language::Buffer;
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use settings::{update_settings_file, Settings};
 use std::{
     cmp,
```

```diff
@@ -215,8 +217,6 @@ impl TerminalInlineAssistant {
     ) -> Result<LanguageModelRequest> {
         let assist = self.assists.get(&assist_id).context("invalid assist")?;
-        let model = CompletionProvider::global(cx).model();
-
         let shell = std::env::var("SHELL").ok();
         let working_directory = assist
             .terminal
```

```diff
@@ -268,7 +268,6 @@ impl TerminalInlineAssistant {
         });

         Ok(LanguageModelRequest {
-            model,
             messages,
             stop: Vec::new(),
             temperature: 1.0,
```

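With the provider now resolving the active model itself, `LanguageModelRequest` no longer carries a `model` field. A minimal sketch of building a request under the new shape, assuming the fields visible in this hunk and a plain string `content` on `LanguageModelRequestMessage` (an assumption for illustration):

```rust
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};

// The request describes only the conversation; which model serves it
// is decided by the active provider at send time.
fn build_request(prompt: String) -> LanguageModelRequest {
    LanguageModelRequest {
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: prompt, // assumed String content, per this era of the API
        }],
        stop: Vec::new(),
        temperature: 1.0,
    }
}
```
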
```diff
@@ -559,24 +558,39 @@ impl Render for PromptEditor {
     PopoverMenu::new("model-switcher")
         .menu(move |cx| {
             ContextMenu::build(cx, |mut menu, cx| {
-                for model in CompletionProvider::global(cx).available_models() {
+                for available_model in
+                    LanguageModelRegistry::read_global(cx).available_models(cx)
+                {
                     menu = menu.custom_entry(
                         {
-                            let model = model.clone();
+                            let model_name = available_model.name().0.clone();
+                            let provider =
+                                available_model.provider_name().0.clone();
                             move |_| {
-                                Label::new(model.display_name())
-                                    .into_any_element()
+                                h_flex()
+                                    .w_full()
+                                    .justify_between()
+                                    .child(Label::new(model_name.clone()))
+                                    .child(
+                                        div().ml_4().child(
+                                            Label::new(provider.clone())
+                                                .color(Color::Muted),
+                                        ),
+                                    )
+                                    .into_any()
                             }
                         },
                         {
                             let fs = fs.clone();
-                            let model = model.clone();
+                            let model = available_model.clone();
                             move |cx| {
+                                let model = model.clone();
                                 update_settings_file::<AssistantSettings>(
                                     fs.clone(),
                                     cx,
-                                    move |settings| settings.set_model(model),
+                                    move |settings, _| {
+                                        settings.set_model(model)
+                                    },
                                 );
                             }
                         },
```

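The switcher now pulls candidates from the global `LanguageModelRegistry` rather than from the single active provider, so every configured provider's models appear in one menu. A condensed sketch of that enumeration, assuming only the accessors used in the hunk (`available_models(cx)`, `name().0`, `provider_name().0`):

```rust
use gpui::AppContext;
use language_model::LanguageModelRegistry;

// Build "provider: model" labels across all registered providers,
// mirroring the for-loop in the menu code above.
fn model_labels(cx: &AppContext) -> Vec<String> {
    let mut labels = Vec::new();
    for model in LanguageModelRegistry::read_global(cx).available_models(cx) {
        labels.push(format!("{}: {}", model.provider_name().0, model.name().0));
    }
    labels
}
```
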
```diff
@@ -595,9 +609,10 @@ impl Render for PromptEditor {
                 Tooltip::with_meta(
                     format!(
                         "Using {}",
-                        CompletionProvider::global(cx)
-                            .model()
-                            .display_name()
+                        LanguageModelCompletionProvider::read_global(cx)
+                            .active_model()
+                            .map(|model| model.name().0)
+                            .unwrap_or_else(|| "No model selected".into())
                     ),
                     None,
                     "Change Model",
```

```diff
@@ -748,7 +763,9 @@ impl PromptEditor {
         })??;

         let token_count = cx
-            .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
+            .update(|cx| {
+                LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+            })?
             .await?;
         this.update(&mut cx, |this, cx| {
             this.token_count = Some(token_count);
```

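Token counting goes through the provider too: `count_tokens` is called on the UI thread but hands back a future, so the counting itself can run asynchronously. A sketch of a caller combining it with the active model's limit, assuming the `count_tokens`, `active_model`, and `max_token_count` shapes implied by these hunks (the `LanguageModelCompletionProvider` import is elided; it lives in the assistant crate per the first hunk):

```rust
use anyhow::Result;
use gpui::AsyncAppContext;
use language_model::LanguageModelRequest;

// Hypothetical helper: how many tokens remain before the active
// model's context window is exhausted (None if no model is selected).
async fn remaining_tokens(
    request: LanguageModelRequest,
    cx: &mut AsyncAppContext,
) -> Result<Option<usize>> {
    let used = cx
        .update(|cx| {
            LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
        })?
        .await?;
    let max = cx.update(|cx| {
        LanguageModelCompletionProvider::read_global(cx)
            .active_model()
            .map(|model| model.max_token_count())
    })?;
    Ok(max.map(|max| max.saturating_sub(used)))
}
```
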
```diff
@@ -878,7 +895,7 @@ impl PromptEditor {
     }

     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = CompletionProvider::global(cx).model();
+        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
         let token_count = self.token_count?;

         let max_token_count = model.max_token_count();
```

```diff
@@ -1023,8 +1040,12 @@ impl Codegen {
         self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));

         let telemetry = self.telemetry.clone();
-        let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).stream_completion(prompt, cx);
+        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
+            .active_model()
+            .map(|m| m.telemetry_id())
+            .unwrap_or_default();
+        let response =
+            LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
         self.generation = cx.spawn(|this, mut cx| async move {
             let response = response.await;
```

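Streaming keeps the same two-step shape: read the provider once to start the completion, then consume the stream inside the spawned task. A generic sketch of the consuming side, under the assumption that the awaited response yields a stream of `Result<String>` text chunks (the concrete stream and item types are not shown in this hunk; the real `Codegen` feeds chunks to the terminal rather than buffering them):

```rust
use anyhow::Result;
use futures::{Stream, StreamExt};

// Drain a chunked completion stream into one string, stopping at the
// first error. Hypothetical consumer; signatures are assumptions.
async fn drain_completion(
    mut chunks: impl Stream<Item = Result<String>> + Unpin,
) -> Result<String> {
    let mut output = String::new();
    while let Some(chunk) = chunks.next().await {
        output.push_str(&chunk?);
    }
    Ok(output)
}
```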