Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -1,6 +1,6 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
LanguageModelCompletionProvider, MessageId, MessageStatus,
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
Role,
};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
@ -1180,17 +1183,16 @@ impl Context {
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let request = self.to_completion_request(cx);
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
self.pending_token_count = cx.spawn(|this, mut cx| {
async move {
cx.background_executor()
.timer(Duration::from_millis(200))
.await;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
@ -1368,6 +1370,10 @@ impl Context {
}
}
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return Task::ready(Err(anyhow!("no active model")).log_err());
};
let mut request = self.to_completion_request(cx);
let edit_step_range = edit_step.source_range.clone();
let step_text = self
@ -1388,12 +1394,7 @@ impl Context {
content: prompt,
});
let tool_use = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.use_tool::<EditTool>(request, cx)
})?
.await?;
let tool_use = model.use_tool::<EditTool>(request, &cx).await?;
this.update(&mut cx, |this, cx| {
let step_index = this
@ -1568,6 +1569,8 @@ impl Context {
}
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
message
.start
@ -1575,14 +1578,12 @@ impl Context {
.then_some(message.id)
})?;
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
if !provider.is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
let request = self.to_completion_request(cx);
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@ -1594,6 +1595,7 @@ impl Context {
let task = cx.spawn({
|this, mut cx| async move {
let stream = model.stream_completion(request, &cx);
let assistant_message_id = assistant_message.id;
let mut response_latency = None;
let stream_completion = async {
@ -1662,14 +1664,10 @@ impl Context {
});
if let Some(telemetry) = this.telemetry.as_ref() {
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
telemetry.report_assistant_event(
Some(this.id.0.clone()),
AssistantKind::Panel,
model_telemetry_id,
model.telemetry_id(),
response_latency,
error_message,
);
@ -1935,8 +1933,15 @@ impl Context {
}
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
if !provider.is_authenticated(cx) {
return;
}
@ -1953,10 +1958,9 @@ impl Context {
temperature: 1.0,
};
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let stream = model.stream_completion(request, &cx);
let mut messages = stream.await?;
let mut replaced = !replace_old;
@ -2490,7 +2494,6 @@ mod tests {
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2623,7 +2626,6 @@ mod tests {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2717,7 +2719,6 @@ mod tests {
fn test_messages_for_offsets(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2803,7 +2804,6 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(Project::init_settings);
cx.update(assistant_panel::init);
let fs = FakeFs::new(cx.background_executor.clone());
@ -2930,7 +2930,6 @@ mod tests {
cx.set_global(settings_store);
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let fake_model = fake_provider.test_model();
cx.update(assistant_panel::init);
@ -3032,7 +3031,6 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@ -3109,7 +3107,6 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(assistant_panel::init);
let slash_commands = cx.update(SlashCommandRegistry::default_global);