Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
1bfea9d443
commit
99bc90a372
32 changed files with 478 additions and 691 deletions
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
||||
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
|
||||
MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
|
@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
|||
use language::{
|
||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||
};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
|
||||
Role,
|
||||
};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
|
@ -1180,17 +1183,16 @@ impl Context {
|
|||
|
||||
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let request = self.to_completion_request(cx);
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(200))
|
||||
.await;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
|
@ -1368,6 +1370,10 @@ impl Context {
|
|||
}
|
||||
}
|
||||
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return Task::ready(Err(anyhow!("no active model")).log_err());
|
||||
};
|
||||
|
||||
let mut request = self.to_completion_request(cx);
|
||||
let edit_step_range = edit_step.source_range.clone();
|
||||
let step_text = self
|
||||
|
@ -1388,12 +1394,7 @@ impl Context {
|
|||
content: prompt,
|
||||
});
|
||||
|
||||
let tool_use = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.use_tool::<EditTool>(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
let tool_use = model.use_tool::<EditTool>(request, &cx).await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
|
@ -1568,6 +1569,8 @@ impl Context {
|
|||
}
|
||||
|
||||
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
|
||||
message
|
||||
.start
|
||||
|
@ -1575,14 +1578,12 @@ impl Context {
|
|||
.then_some(message.id)
|
||||
})?;
|
||||
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
log::info!("completion provider has no credentials");
|
||||
return None;
|
||||
}
|
||||
|
||||
let request = self.to_completion_request(cx);
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
.unwrap();
|
||||
|
@ -1594,6 +1595,7 @@ impl Context {
|
|||
|
||||
let task = cx.spawn({
|
||||
|this, mut cx| async move {
|
||||
let stream = model.stream_completion(request, &cx);
|
||||
let assistant_message_id = assistant_message.id;
|
||||
let mut response_latency = None;
|
||||
let stream_completion = async {
|
||||
|
@ -1662,14 +1664,10 @@ impl Context {
|
|||
});
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
telemetry.report_assistant_event(
|
||||
Some(this.id.0.clone()),
|
||||
AssistantKind::Panel,
|
||||
model_telemetry_id,
|
||||
model.telemetry_id(),
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
|
@ -1935,8 +1933,15 @@ impl Context {
|
|||
}
|
||||
|
||||
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
return;
|
||||
};
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1953,10 +1958,9 @@ impl Context {
|
|||
temperature: 1.0,
|
||||
};
|
||||
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let stream = model.stream_completion(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
let mut replaced = !replace_old;
|
||||
|
@ -2490,7 +2494,6 @@ mod tests {
|
|||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2623,7 +2626,6 @@ mod tests {
|
|||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
|
@ -2717,7 +2719,6 @@ mod tests {
|
|||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2803,7 +2804,6 @@ mod tests {
|
|||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(assistant_panel::init);
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
|
@ -2930,7 +2930,6 @@ mod tests {
|
|||
cx.set_global(settings_store);
|
||||
|
||||
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
let fake_model = fake_provider.test_model();
|
||||
cx.update(assistant_panel::init);
|
||||
|
@ -3032,7 +3031,6 @@ mod tests {
|
|||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(assistant_panel::init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
||||
|
@ -3109,7 +3107,6 @@ mod tests {
|
|||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
cx.update(assistant_panel::init);
|
||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue