assistant: Overhaul provider infrastructure (#14929)

<img width="624" alt="image"
src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815">

- [x] Correctly set custom model token count
- [x] How to count tokens for Gemini models?
- [x] Feature flag zed.dev provider
- [x] Figure out how to configure custom models
- [ ] Update docs

Release Notes:

- Added support for quickly switching between multiple language model
providers in the assistant panel

---------

Co-authored-by: Antonio <antonio@zed.dev>
Bennet Bo Fenner 2024-07-23 19:48:41 +02:00 committed by GitHub
parent 17ef9a367f
commit d0f52e90e6
55 changed files with 2757 additions and 2023 deletions


@@ -1,6 +1,6 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
-    MessageStatus,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
+    MessageId, MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -1124,7 +1124,9 @@ impl Context {
                 .await;
             let token_count = cx
-                .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
+                .update(|cx| {
+                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+                })?
                 .await?;
             this.update(&mut cx, |this, cx| {
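This hunk shows the core of the migration: call sites stop reaching for a provider-specific global and instead read a single `LanguageModelCompletionProvider` facade. A minimal sketch of the new token-counting pattern, assuming an async task with a prebuilt `request` (the helper name and exact types here are illustrative, not taken from the PR):

```rust
// Sketch only: mirrors the call pattern above; surrounding task setup is elided.
async fn count_request_tokens(
    request: LanguageModelRequest,
    cx: &mut AsyncAppContext, // gpui async handle, as used by the spawned task above
) -> Result<usize> {
    // Return type assumed: the diff only shows the count being awaited.
    let token_count = cx
        .update(|cx| {
            // One global facade; whichever provider/model is currently active answers.
            LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
        })?
        .await?;
    Ok(token_count)
}
```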
@@ -1308,7 +1310,9 @@ impl Context {
             });
             let raw_output = cx
-                .update(|cx| CompletionProvider::global(cx).complete(request, cx))?
+                .update(|cx| {
+                    LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
+                })?
                 .await?;
             let operations = Self::parse_edit_operations(&raw_output);
@@ -1612,13 +1616,14 @@ impl Context {
                 .then_some(message.id)
         })?;
 
-        if !CompletionProvider::global(cx).is_authenticated() {
+        if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
             log::info!("completion provider has no credentials");
             return None;
         }
 
         let request = self.to_completion_request(cx);
-        let stream = CompletionProvider::global(cx).stream_completion(request, cx);
+        let stream =
+            LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
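The same facade also answers the credentials check and hands out streamed completions, so `Context` no longer needs to know which provider is active. A rough sketch of that flow as a hypothetical helper (the real method also inserts a pending assistant message and drains the stream in a spawned task, as the hunk shows):

```rust
// Sketch of a hypothetical helper following the pattern above; not the PR's exact method.
fn start_streamed_completion(&mut self, cx: &mut ModelContext<Self>) -> Option<()> {
    // Bail out early when the active provider has no credentials.
    if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
        log::info!("completion provider has no credentials");
        return None;
    }

    let request = self.to_completion_request(cx);
    // Resolves to a stream of completion chunks from the currently active model.
    let _stream = LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
    // ...spawn a task that awaits `_stream` and appends chunks to the assistant message...
    Some(())
}
```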
@@ -1698,11 +1703,14 @@ impl Context {
                     });
 
                     if let Some(telemetry) = this.telemetry.as_ref() {
-                        let model = CompletionProvider::global(cx).model();
+                        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
+                            .active_model()
+                            .map(|m| m.telemetry_id())
+                            .unwrap_or_default();
                         telemetry.report_assistant_event(
                             Some(this.id.0.clone()),
                             AssistantKind::Panel,
-                            model.telemetry_id(),
+                            model_telemetry_id,
                             response_latency,
                             error_message,
                         );
@@ -1727,7 +1735,6 @@ impl Context {
             .map(|message| message.to_request_message(self.buffer.read(cx)));
 
         LanguageModelRequest {
-            model: CompletionProvider::global(cx).model(),
             messages: messages.collect(),
             stop: vec![],
             temperature: 1.0,
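Worth noting in this hunk: `LanguageModelRequest` no longer carries a `model` field, since the provider now resolves the active model itself when the request is executed. Roughly (only the fields visible in the diff; any remaining fields are elided):

```rust
// Sketch: the request is model-agnostic; the active model is chosen by the provider.
let request = LanguageModelRequest {
    messages: messages.collect(),
    stop: vec![],
    temperature: 1.0,
};
```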
@@ -1970,7 +1977,7 @@ impl Context {
     pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
         if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
-            if !CompletionProvider::global(cx).is_authenticated() {
+            if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
                 return;
             }
@@ -1982,13 +1989,13 @@ impl Context {
                 content: "Summarize the context into a short title without punctuation.".into(),
             }));
             let request = LanguageModelRequest {
-                model: CompletionProvider::global(cx).model(),
                 messages: messages.collect(),
                 stop: vec![],
                 temperature: 1.0,
             };
 
-            let stream = CompletionProvider::global(cx).stream_completion(request, cx);
+            let stream =
+                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
             self.pending_summary = cx.spawn(|this, mut cx| {
                 async move {
                     let mut messages = stream.await?;
@@ -2504,7 +2511,6 @@ mod tests {
         MessageId,
     };
     use assistant_slash_command::{ArgumentCompletion, SlashCommand};
-    use completion::FakeCompletionProvider;
     use fs::FakeFs;
     use gpui::{AppContext, TestAppContext, WeakView};
     use indoc::indoc;
@@ -2524,7 +2530,8 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        FakeCompletionProvider::setup_test(cx);
+        language_model::LanguageModelRegistry::test(cx);
+        completion::LanguageModelCompletionProvider::test(cx);
         cx.set_global(settings_store);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
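This is the first of several identical test-harness changes: the single `FakeCompletionProvider::setup_test(cx)` call becomes a two-step setup that installs a test `LanguageModelRegistry` and then the `LanguageModelCompletionProvider` global. A minimal sketch of the pattern for a synchronous test (test body omitted):

```rust
// Sketch of the new two-step test setup used throughout this file.
#[gpui::test]
fn test_example(cx: &mut AppContext) {
    let settings_store = SettingsStore::test(cx);
    cx.set_global(settings_store);

    // 1) Register a fake language model provider in the registry...
    language_model::LanguageModelRegistry::test(cx);
    // 2) ...then install the completion-provider facade that reads from it.
    completion::LanguageModelCompletionProvider::test(cx);

    assistant_panel::init(cx);
    // ...test body...
}
```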
@@ -2656,7 +2663,8 @@ mod tests {
     fn test_message_splitting(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
-        FakeCompletionProvider::setup_test(cx);
+        language_model::LanguageModelRegistry::test(cx);
+        completion::LanguageModelCompletionProvider::test(cx);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2749,7 +2757,8 @@ mod tests {
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        FakeCompletionProvider::setup_test(cx);
+        language_model::LanguageModelRegistry::test(cx);
+        completion::LanguageModelCompletionProvider::test(cx);
         cx.set_global(settings_store);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2834,7 +2843,8 @@ mod tests {
     async fn test_slash_commands(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.update(FakeCompletionProvider::setup_test);
+        cx.update(language_model::LanguageModelRegistry::test);
+        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(Project::init_settings);
         cx.update(assistant_panel::init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -2959,7 +2969,11 @@ mod tests {
         cx.update(prompt_library::init);
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        let fake_provider = cx.update(FakeCompletionProvider::setup_test);
+        let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
+        cx.update(completion::LanguageModelCompletionProvider::test);
+        let fake_model = fake_provider.test_model();
         cx.update(assistant_panel::init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
@@ -3025,8 +3039,8 @@ mod tests {
         });
 
         // Simulate the LLM completion
-        fake_provider.send_last_completion_chunk(llm_response.to_string());
-        fake_provider.finish_last_completion();
+        fake_model.send_last_completion_chunk(llm_response.to_string());
+        fake_model.finish_last_completion();
 
         // Wait for the completion to be processed
         cx.run_until_parked();
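For async tests that drive an actual completion, the handle returned by the registry's test setup replaces the old fake provider: the test grabs a fake model and pushes chunks through it. A rough sketch combining the two hunks above (identifiers follow the diff; exact types are assumed):

```rust
// Sketch: install the fakes, grab the test model, then simulate an LLM reply.
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let fake_model = fake_provider.test_model();

// ...run the code under test so it issues a completion request...

// Stream one chunk, close the completion, and let pending tasks settle.
fake_model.send_last_completion_chunk("Hello from the fake model".to_string());
fake_model.finish_last_completion();
cx.run_until_parked();
```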
@@ -3107,7 +3121,8 @@ mod tests {
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.update(FakeCompletionProvider::setup_test);
+        cx.update(language_model::LanguageModelRegistry::test);
+        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(assistant_panel::init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@@ -3183,7 +3198,9 @@ mod tests {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.update(FakeCompletionProvider::setup_test);
+        cx.update(language_model::LanguageModelRegistry::test);
+        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(assistant_panel::init);
         let slash_commands = cx.update(SlashCommandRegistry::default_global);
         slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);