assistant: Overhaul provider infrastructure (#14929)
<img width="624" alt="image" src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815">

- [x] Correctly set custom model token count
- [x] How to count tokens for Gemini models?
- [x] Feature flag zed.dev provider
- [x] Figure out how to configure custom models
- [ ] Update docs

Release Notes:

- Added support for quickly switching between multiple language model providers in the assistant panel

---------

Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
parent
17ef9a367f
commit
d0f52e90e6
55 changed files with 2757 additions and 2023 deletions
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
|
||||
MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
||||
MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
|
@ -1124,7 +1124,9 @@ impl Context {
|
|||
.await;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
|
@ -1308,7 +1310,9 @@ impl Context {
|
|||
});
|
||||
|
||||
let raw_output = cx
|
||||
.update(|cx| CompletionProvider::global(cx).complete(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let operations = Self::parse_edit_operations(&raw_output);
|
||||
|
@ -1612,13 +1616,14 @@ impl Context {
|
|||
.then_some(message.id)
|
||||
})?;
|
||||
|
||||
if !CompletionProvider::global(cx).is_authenticated() {
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
log::info!("completion provider has no credentials");
|
||||
return None;
|
||||
}
|
||||
|
||||
let request = self.to_completion_request(cx);
|
||||
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
.unwrap();
|
||||
|
@ -1698,11 +1703,14 @@ impl Context {
|
|||
});
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
telemetry.report_assistant_event(
|
||||
Some(this.id.0.clone()),
|
||||
AssistantKind::Panel,
|
||||
model.telemetry_id(),
|
||||
model_telemetry_id,
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
|
@ -1727,7 +1735,6 @@ impl Context {
|
|||
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
||||
|
||||
LanguageModelRequest {
|
||||
model: CompletionProvider::global(cx).model(),
|
||||
messages: messages.collect(),
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
|
@ -1970,7 +1977,7 @@ impl Context {
|
|||
|
||||
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if !CompletionProvider::global(cx).is_authenticated() {
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1982,13 +1989,13 @@ impl Context {
|
|||
content: "Summarize the context into a short title without punctuation.".into(),
|
||||
}));
|
||||
let request = LanguageModelRequest {
|
||||
model: CompletionProvider::global(cx).model(),
|
||||
messages: messages.collect(),
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let mut messages = stream.await?;
|
||||
|
@ -2504,7 +2511,6 @@ mod tests {
|
|||
MessageId,
|
||||
};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
|
||||
use completion::FakeCompletionProvider;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext, WeakView};
|
||||
use indoc::indoc;
|
||||
|
@ -2524,7 +2530,8 @@ mod tests {
|
|||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
FakeCompletionProvider::setup_test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2656,7 +2663,8 @@ mod tests {
|
|||
fn test_message_splitting(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
FakeCompletionProvider::setup_test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
|
@ -2749,7 +2757,8 @@ mod tests {
|
|||
#[gpui::test]
|
||||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
FakeCompletionProvider::setup_test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2834,7 +2843,8 @@ mod tests {
|
|||
async fn test_slash_commands(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(FakeCompletionProvider::setup_test);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(assistant_panel::init);
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
|
@ -2959,7 +2969,11 @@ mod tests {
|
|||
cx.update(prompt_library::init);
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
let fake_provider = cx.update(FakeCompletionProvider::setup_test);
|
||||
|
||||
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
let fake_model = fake_provider.test_model();
|
||||
cx.update(assistant_panel::init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
|
||||
|
@ -3025,8 +3039,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Simulate the LLM completion
|
||||
fake_provider.send_last_completion_chunk(llm_response.to_string());
|
||||
fake_provider.finish_last_completion();
|
||||
fake_model.send_last_completion_chunk(llm_response.to_string());
|
||||
fake_model.finish_last_completion();
|
||||
|
||||
// Wait for the completion to be processed
|
||||
cx.run_until_parked();
|
||||
|
@ -3107,7 +3121,8 @@ mod tests {
|
|||
async fn test_serialization(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(FakeCompletionProvider::setup_test);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(assistant_panel::init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
||||
|
@ -3183,7 +3198,9 @@ mod tests {
|
|||
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(FakeCompletionProvider::setup_test);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
cx.update(assistant_panel::init);
|
||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue