Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
use completion::LanguageModelCompletionProvider;
pub use context::*;
pub use context_store::*;
use fs::Fs;
@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client);
prompt_library::init(cx);
init_completion_provider(cx);
init_language_model_settings(cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}
fn init_completion_provider(cx: &mut AppContext) {
completion::init(cx);
fn init_language_model_settings(cx: &mut AppContext) {
update_active_language_model_from_settings(cx);
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
@ -233,20 +231,9 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());
let Some(provider) = LanguageModelRegistry::global(cx)
.read(cx)
.provider(&provider_name)
else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
completion_provider.set_active_model(model, cx);
});
}
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.select_active_model(&provider_name, &model_id, cx);
});
}
fn register_slash_commands(cx: &mut AppContext) {

View file

@ -19,7 +19,6 @@ use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use client::proto;
use collections::{BTreeSet, HashMap, HashSet};
use completion::LanguageModelCompletionProvider;
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{
@ -43,7 +42,7 @@ use language::{
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
ToOffset,
};
use language_model::{LanguageModelProviderId, Role};
use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role};
use multi_buffer::MultiBufferRow;
use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate};
@ -392,9 +391,9 @@ impl AssistantPanel {
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
cx.subscribe(&context_store, Self::handle_context_store_event),
cx.observe(
&LanguageModelCompletionProvider::global(cx),
|this, _, cx| {
cx.subscribe(
&LanguageModelRegistry::global(cx),
|this, _, _: &language_model::ActiveModelChanged, cx| {
this.completion_provider_changed(cx);
},
),
@ -560,7 +559,7 @@ impl AssistantPanel {
})
}
let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx)
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|p| p.id())
else {
@ -599,7 +598,7 @@ impl AssistantPanel {
}
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
if !provider.is_authenticated(cx) {
return Some(provider.authentication_prompt(cx));
}
@ -904,9 +903,9 @@ impl AssistantPanel {
}
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
LanguageModelCompletionProvider::read_global(cx)
.reset_credentials(cx)
.detach_and_log_err(cx);
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
provider.reset_credentials(cx).detach_and_log_err(cx);
}
}
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
@ -1041,11 +1040,18 @@ impl AssistantPanel {
}
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
}
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(
Task::ready(Err(anyhow!("no active language model provider"))),
|provider| provider.authenticate(cx),
)
}
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem {
}
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let context = &self
.active_context_editor
.as_ref()?
@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem {
.whitespace_nowrap()
.child(
Label::new(
LanguageModelCompletionProvider::read_global(cx)
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),

View file

@ -52,7 +52,7 @@ pub struct AssistantSettings {
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: AssistantDefaultModel,
pub default_model: LanguageModelSelection,
pub using_outdated_settings_version: bool,
}
@ -198,25 +198,25 @@ impl AssistantSettingsContent {
.clone()
.and_then(|provider| match provider {
AssistantProviderContentV1::ZedDotDev { default_model } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "zed.dev".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::OpenAi { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "openai".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Anthropic { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "anthropic".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Ollama { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
default_model.map(|model| LanguageModelSelection {
provider: "ollama".to_string(),
model: model.id().to_string(),
})
@ -231,7 +231,7 @@ impl AssistantSettingsContent {
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: Some(AssistantDefaultModel {
default_model: Some(LanguageModelSelection {
provider: "openai".to_string(),
model: settings
.default_open_ai_model
@ -325,7 +325,7 @@ impl AssistantSettingsContent {
_ => {}
},
VersionedAssistantSettingsContent::V2(settings) => {
settings.default_model = Some(AssistantDefaultModel { provider, model });
settings.default_model = Some(LanguageModelSelection { provider, model });
}
},
AssistantSettingsContent::Legacy(settings) => {
@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 {
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new contexts.
default_model: Option<AssistantDefaultModel>,
default_model: Option<LanguageModelSelection>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AssistantDefaultModel {
pub struct LanguageModelSelection {
#[schemars(schema_with = "providers_schema")]
pub provider: String,
pub model: String,
@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
.into()
}
impl Default for AssistantDefaultModel {
impl Default for LanguageModelSelection {
fn default() -> Self {
Self {
provider: "openai".to_string(),
@ -542,7 +542,7 @@ mod tests {
assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
assert_eq!(
AssistantSettings::get_global(cx).default_model,
AssistantDefaultModel {
LanguageModelSelection {
provider: "openai".into(),
model: "gpt-4o".into(),
}
@ -555,7 +555,7 @@ mod tests {
|settings, _| {
*settings = AssistantSettingsContent::Versioned(
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
default_model: Some(AssistantDefaultModel {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),

View file

@ -1,6 +1,6 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
LanguageModelCompletionProvider, MessageId, MessageStatus,
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
Role,
};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
@ -1180,17 +1183,16 @@ impl Context {
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let request = self.to_completion_request(cx);
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
self.pending_token_count = cx.spawn(|this, mut cx| {
async move {
cx.background_executor()
.timer(Duration::from_millis(200))
.await;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
@ -1368,6 +1370,10 @@ impl Context {
}
}
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return Task::ready(Err(anyhow!("no active model")).log_err());
};
let mut request = self.to_completion_request(cx);
let edit_step_range = edit_step.source_range.clone();
let step_text = self
@ -1388,12 +1394,7 @@ impl Context {
content: prompt,
});
let tool_use = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.use_tool::<EditTool>(request, cx)
})?
.await?;
let tool_use = model.use_tool::<EditTool>(request, &cx).await?;
this.update(&mut cx, |this, cx| {
let step_index = this
@ -1568,6 +1569,8 @@ impl Context {
}
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
message
.start
@ -1575,14 +1578,12 @@ impl Context {
.then_some(message.id)
})?;
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
if !provider.is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
let request = self.to_completion_request(cx);
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@ -1594,6 +1595,7 @@ impl Context {
let task = cx.spawn({
|this, mut cx| async move {
let stream = model.stream_completion(request, &cx);
let assistant_message_id = assistant_message.id;
let mut response_latency = None;
let stream_completion = async {
@ -1662,14 +1664,10 @@ impl Context {
});
if let Some(telemetry) = this.telemetry.as_ref() {
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
telemetry.report_assistant_event(
Some(this.id.0.clone()),
AssistantKind::Panel,
model_telemetry_id,
model.telemetry_id(),
response_latency,
error_message,
);
@ -1935,8 +1933,15 @@ impl Context {
}
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
if !provider.is_authenticated(cx) {
return;
}
@ -1953,10 +1958,9 @@ impl Context {
temperature: 1.0,
};
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let stream = model.stream_completion(request, &cx);
let mut messages = stream.await?;
let mut replaced = !replace_old;
@ -2490,7 +2494,6 @@ mod tests {
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2623,7 +2626,6 @@ mod tests {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2717,7 +2719,6 @@ mod tests {
fn test_messages_for_offsets(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2803,7 +2804,6 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(Project::init_settings);
cx.update(assistant_panel::init);
let fs = FakeFs::new(cx.background_executor.clone());
@ -2930,7 +2930,6 @@ mod tests {
cx.set_global(settings_store);
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let fake_model = fake_provider.test_model();
cx.update(assistant_panel::init);
@ -3032,7 +3031,6 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@ -3109,7 +3107,6 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(assistant_panel::init);
let slash_commands = cx.update(SlashCommandRegistry::default_global);

View file

@ -1,6 +1,6 @@
use crate::{
humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff,
Hunk, ModelSelector, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry;
@ -27,7 +27,9 @@ use gpui::{
WindowContext,
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use rope::Rope;
@ -1328,7 +1330,7 @@ impl Render for PromptEditor {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelCompletionProvider::read_global(cx)
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
@ -1662,7 +1664,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@ -2013,8 +2015,12 @@ impl Codegen {
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
model.count_tokens(request, cx)
} else {
future::ready(Err(anyhow!("no active model"))).boxed()
}
}
pub fn start(
@ -2024,6 +2030,10 @@ impl Codegen {
assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
let model = LanguageModelRegistry::read_global(cx)
.active_model()
.context("no active model")?;
self.undo(cx);
// Handle initial insertion
@ -2053,10 +2063,7 @@ impl Codegen {
None
};
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model_telemetry_id()
.context("no active model")?;
let telemetry_id = model.telemetry_id();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
.trim()
.to_lowercase()
@ -2067,10 +2074,10 @@ impl Codegen {
let request =
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
let chunks =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
async move { Ok(chunks.await?.boxed()) }.boxed_local()
};
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
self.handle_stream(telemetry_id, edit_range, chunks, cx);
Ok(())
}
@ -2657,7 +2664,6 @@ mod tests {
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(language_settings::init);
let text = indoc! {"
@ -2789,7 +2795,6 @@ mod tests {
mut rng: StdRng,
) {
cx.update(LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@ -2853,7 +2858,6 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
use crate::assistant_settings::AssistantSettings;
use fs::Fs;
use gpui::SharedString;
use language_model::LanguageModelRegistry;
@ -81,13 +81,13 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
}
},
{
let provider = provider.id();
let provider = provider.clone();
move |cx| {
LanguageModelCompletionProvider::global(cx).update(
LanguageModelRegistry::global(cx).update(
cx,
|completion_provider, cx| {
completion_provider
.set_active_provider(provider.clone(), cx)
.set_active_provider(Some(provider.clone()), cx);
},
);
}
@ -95,12 +95,12 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
);
}
let selected_model = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.id());
let selected_provider = LanguageModelCompletionProvider::read_global(cx)
let selected_provider = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|m| m.id());
let selected_model = LanguageModelRegistry::read_global(cx)
.active_model()
.map(|m| m.id());
for available_model in available_models {
menu = menu.custom_entry(

View file

@ -1,6 +1,5 @@
use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
LanguageModelCompletionProvider,
};
use anyhow::{anyhow, Result};
use assets::Assets;
@ -19,7 +18,9 @@ use gpui::{
};
use heed::{types::SerdeBincode, Database, RoTxn};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use parking_lot::RwLock;
use picker::{Picker, PickerDelegate};
use rope::Rope;
@ -636,7 +637,10 @@ impl PromptLibrary {
};
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
let provider = LanguageModelCompletionProvider::read_global(cx);
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let initial_prompt = action.prompt.clone();
if provider.is_authenticated(cx) {
InlineAssistant::update_global(cx, |assistant, cx| {
@ -725,6 +729,9 @@ impl PromptLibrary {
}
fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
let editor = &prompt.body_editor.read(cx);
let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
@ -736,7 +743,7 @@ impl PromptLibrary {
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(
model.count_tokens(
LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::System,
@ -804,7 +811,7 @@ impl PromptLibrary {
let prompt_metadata = self.store.metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
let model = LanguageModelRegistry::read_global(cx).active_model();
let settings = ThemeSettings::get_global(cx);
Some(
@ -914,7 +921,7 @@ impl PromptLibrary {
None,
format!(
"Model: {}",
current_model
model
.as_ref()
.map(|model| model
.name()

View file

@ -1,6 +1,6 @@
use crate::{
humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
AssistantPanelEvent, ModelSelector,
};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
@ -16,7 +16,9 @@ use gpui::{
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
};
use language::Buffer;
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use settings::Settings;
use std::{
cmp,
@ -556,7 +558,7 @@ impl Render for PromptEditor {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelCompletionProvider::read_global(cx)
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
@ -700,6 +702,9 @@ impl PromptEditor {
fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
let assist_id = self.id;
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
self.pending_token_count = cx.spawn(|this, mut cx| async move {
cx.background_executor().timer(Duration::from_secs(1)).await;
let request =
@ -707,11 +712,7 @@ impl PromptEditor {
inline_assistant.request_for_inline_assist(assist_id, cx)
})??;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
@ -840,7 +841,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@ -982,19 +983,16 @@ impl Codegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
return;
};
let telemetry = self.telemetry.clone();
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let response =
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(|this, mut cx| async move {
let response = response.await;
let model_telemetry_id = model.telemetry_id();
let response = model.stream_completion(prompt, &cx).await;
let generate = async {
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);