Allow customization of the model used for tool calling (#15479)

We also eliminated the `completion` crate and moved its logic into
`LanguageModelRegistry`.
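
As an illustration of the new call pattern this commit introduces, here is a minimal sketch (the helper `current_model_name` is hypothetical; `LanguageModelRegistry::read_global` and `active_model` are the APIs exercised in the diff below):

```rust
use gpui::AppContext;
use language_model::LanguageModelRegistry;

// Hypothetical helper: look up the globally selected model in the
// registry and report its display name, falling back to a placeholder
// when no model is configured.
fn current_model_name(cx: &AppContext) -> String {
    LanguageModelRegistry::read_global(cx)
        .active_model()
        .map(|model| model.name().0.to_string())
        .unwrap_or_else(|| "No model selected".into())
}
```

Callers that previously went through `LanguageModelCompletionProvider` now resolve the active model once up front and call methods such as `count_tokens` or `stream_completion` on the model directly.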

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
commit 99bc90a372
parent 1bfea9d443
Antonio Scandurra, 2024-07-30 16:18:53 +02:00, committed by GitHub
32 changed files with 478 additions and 691 deletions


@@ -1,6 +1,6 @@
 use crate::{
     humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
-    AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
+    AssistantPanelEvent, ModelSelector,
 };
 use anyhow::{Context as _, Result};
 use client::telemetry::Telemetry;
@@ -16,7 +16,9 @@ use gpui::{
     Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
 };
 use language::Buffer;
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use settings::Settings;
 use std::{
     cmp,
@@ -556,7 +558,7 @@ impl Render for PromptEditor {
                 Tooltip::with_meta(
                     format!(
                         "Using {}",
-                        LanguageModelCompletionProvider::read_global(cx)
+                        LanguageModelRegistry::read_global(cx)
                             .active_model()
                             .map(|model| model.name().0)
                             .unwrap_or_else(|| "No model selected".into()),
@@ -700,6 +702,9 @@ impl PromptEditor {
     fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
         let assist_id = self.id;
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         self.pending_token_count = cx.spawn(|this, mut cx| async move {
             cx.background_executor().timer(Duration::from_secs(1)).await;
             let request =
@@ -707,11 +712,7 @@ impl PromptEditor {
                 inline_assistant.request_for_inline_assist(assist_id, cx)
             })??;
-            let token_count = cx
-                .update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                })?
-                .await?;
+            let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
             this.update(&mut cx, |this, cx| {
                 this.token_count = Some(token_count);
                 cx.notify();
@@ -840,7 +841,7 @@ impl PromptEditor {
     }

     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let token_count = self.token_count?;

         let max_token_count = model.max_token_count();
@@ -982,19 +983,16 @@ impl Codegen {
     }

     pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
-        self.status = CodegenStatus::Pending;
-        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
+
         let telemetry = self.telemetry.clone();
-        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-            .active_model()
-            .map(|m| m.telemetry_id())
-            .unwrap_or_default();
-        let response =
-            LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
+        self.status = CodegenStatus::Pending;
+        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
         self.generation = cx.spawn(|this, mut cx| async move {
-            let response = response.await;
+            let model_telemetry_id = model.telemetry_id();
+            let response = model.stream_completion(prompt, &cx).await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);