Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
1bfea9d443
commit
99bc90a372
32 changed files with 478 additions and 691 deletions
|
@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings;
|
|||
use assistant_slash_command::SlashCommandRegistry;
|
||||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use completion::LanguageModelCompletionProvider;
|
||||
pub use context::*;
|
||||
pub use context_store::*;
|
||||
use fs::Fs;
|
||||
|
@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
|
||||
context_store::init(&client);
|
||||
prompt_library::init(cx);
|
||||
init_completion_provider(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
register_slash_commands(cx);
|
||||
assistant_panel::init(cx);
|
||||
|
@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
.detach();
|
||||
}
|
||||
|
||||
fn init_completion_provider(cx: &mut AppContext) {
|
||||
completion::init(cx);
|
||||
fn init_language_model_settings(cx: &mut AppContext) {
|
||||
update_active_language_model_from_settings(cx);
|
||||
|
||||
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
|
||||
|
@ -233,20 +231,9 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
|
|||
let settings = AssistantSettings::get_global(cx);
|
||||
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
|
||||
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||
|
||||
let Some(provider) = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.provider(&provider_name)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
|
||||
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
|
||||
completion_provider.set_active_model(model, cx);
|
||||
});
|
||||
}
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.select_active_model(&provider_name, &model_id, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn register_slash_commands(cx: &mut AppContext) {
|
||||
|
|
|
@ -19,7 +19,6 @@ use anyhow::{anyhow, Result};
|
|||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||
use client::proto;
|
||||
use collections::{BTreeSet, HashMap, HashSet};
|
||||
use completion::LanguageModelCompletionProvider;
|
||||
use editor::{
|
||||
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
||||
display_map::{
|
||||
|
@ -43,7 +42,7 @@ use language::{
|
|||
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
|
||||
ToOffset,
|
||||
};
|
||||
use language_model::{LanguageModelProviderId, Role};
|
||||
use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use project::{Project, ProjectLspAdapterDelegate};
|
||||
|
@ -392,9 +391,9 @@ impl AssistantPanel {
|
|||
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
|
||||
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
|
||||
cx.subscribe(&context_store, Self::handle_context_store_event),
|
||||
cx.observe(
|
||||
&LanguageModelCompletionProvider::global(cx),
|
||||
|this, _, cx| {
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|this, _, _: &language_model::ActiveModelChanged, cx| {
|
||||
this.completion_provider_changed(cx);
|
||||
},
|
||||
),
|
||||
|
@ -560,7 +559,7 @@ impl AssistantPanel {
|
|||
})
|
||||
}
|
||||
|
||||
let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx)
|
||||
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map(|p| p.id())
|
||||
else {
|
||||
|
@ -599,7 +598,7 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
|
||||
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
|
||||
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
|
||||
if !provider.is_authenticated(cx) {
|
||||
return Some(provider.authentication_prompt(cx));
|
||||
}
|
||||
|
@ -904,9 +903,9 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.reset_credentials(cx)
|
||||
.detach_and_log_err(cx);
|
||||
if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
|
||||
provider.reset_credentials(cx).detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
|
||||
|
@ -1041,11 +1040,18 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(
|
||||
Task::ready(Err(anyhow!("no active language model provider"))),
|
||||
|provider| provider.authenticate(cx),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
|
@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem {
|
|||
}
|
||||
|
||||
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let context = &self
|
||||
.active_context_editor
|
||||
.as_ref()?
|
||||
|
@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem {
|
|||
.whitespace_nowrap()
|
||||
.child(
|
||||
Label::new(
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
|
|
|
@ -52,7 +52,7 @@ pub struct AssistantSettings {
|
|||
pub dock: AssistantDockPosition,
|
||||
pub default_width: Pixels,
|
||||
pub default_height: Pixels,
|
||||
pub default_model: AssistantDefaultModel,
|
||||
pub default_model: LanguageModelSelection,
|
||||
pub using_outdated_settings_version: bool,
|
||||
}
|
||||
|
||||
|
@ -198,25 +198,25 @@ impl AssistantSettingsContent {
|
|||
.clone()
|
||||
.and_then(|provider| match provider {
|
||||
AssistantProviderContentV1::ZedDotDev { default_model } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
default_model.map(|model| LanguageModelSelection {
|
||||
provider: "zed.dev".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::OpenAi { default_model, .. } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
default_model.map(|model| LanguageModelSelection {
|
||||
provider: "openai".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::Anthropic { default_model, .. } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
default_model.map(|model| LanguageModelSelection {
|
||||
provider: "anthropic".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::Ollama { default_model, .. } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
default_model.map(|model| LanguageModelSelection {
|
||||
provider: "ollama".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
|
@ -231,7 +231,7 @@ impl AssistantSettingsContent {
|
|||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_height,
|
||||
default_model: Some(AssistantDefaultModel {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "openai".to_string(),
|
||||
model: settings
|
||||
.default_open_ai_model
|
||||
|
@ -325,7 +325,7 @@ impl AssistantSettingsContent {
|
|||
_ => {}
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(settings) => {
|
||||
settings.default_model = Some(AssistantDefaultModel { provider, model });
|
||||
settings.default_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
|
@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 {
|
|||
/// Default: 320
|
||||
default_height: Option<f32>,
|
||||
/// The default model to use when creating new contexts.
|
||||
default_model: Option<AssistantDefaultModel>,
|
||||
default_model: Option<LanguageModelSelection>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
pub struct AssistantDefaultModel {
|
||||
pub struct LanguageModelSelection {
|
||||
#[schemars(schema_with = "providers_schema")]
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
|
@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
|
|||
.into()
|
||||
}
|
||||
|
||||
impl Default for AssistantDefaultModel {
|
||||
impl Default for LanguageModelSelection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider: "openai".to_string(),
|
||||
|
@ -542,7 +542,7 @@ mod tests {
|
|||
assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).default_model,
|
||||
AssistantDefaultModel {
|
||||
LanguageModelSelection {
|
||||
provider: "openai".into(),
|
||||
model: "gpt-4o".into(),
|
||||
}
|
||||
|
@ -555,7 +555,7 @@ mod tests {
|
|||
|settings, _| {
|
||||
*settings = AssistantSettingsContent::Versioned(
|
||||
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
|
||||
default_model: Some(AssistantDefaultModel {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
||||
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
|
||||
MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
|
@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
|||
use language::{
|
||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||
};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
|
||||
Role,
|
||||
};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
|
@ -1180,17 +1183,16 @@ impl Context {
|
|||
|
||||
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let request = self.to_completion_request(cx);
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(200))
|
||||
.await;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
|
@ -1368,6 +1370,10 @@ impl Context {
|
|||
}
|
||||
}
|
||||
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return Task::ready(Err(anyhow!("no active model")).log_err());
|
||||
};
|
||||
|
||||
let mut request = self.to_completion_request(cx);
|
||||
let edit_step_range = edit_step.source_range.clone();
|
||||
let step_text = self
|
||||
|
@ -1388,12 +1394,7 @@ impl Context {
|
|||
content: prompt,
|
||||
});
|
||||
|
||||
let tool_use = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.use_tool::<EditTool>(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
let tool_use = model.use_tool::<EditTool>(request, &cx).await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
|
@ -1568,6 +1569,8 @@ impl Context {
|
|||
}
|
||||
|
||||
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
|
||||
message
|
||||
.start
|
||||
|
@ -1575,14 +1578,12 @@ impl Context {
|
|||
.then_some(message.id)
|
||||
})?;
|
||||
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
log::info!("completion provider has no credentials");
|
||||
return None;
|
||||
}
|
||||
|
||||
let request = self.to_completion_request(cx);
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
.unwrap();
|
||||
|
@ -1594,6 +1595,7 @@ impl Context {
|
|||
|
||||
let task = cx.spawn({
|
||||
|this, mut cx| async move {
|
||||
let stream = model.stream_completion(request, &cx);
|
||||
let assistant_message_id = assistant_message.id;
|
||||
let mut response_latency = None;
|
||||
let stream_completion = async {
|
||||
|
@ -1662,14 +1664,10 @@ impl Context {
|
|||
});
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
telemetry.report_assistant_event(
|
||||
Some(this.id.0.clone()),
|
||||
AssistantKind::Panel,
|
||||
model_telemetry_id,
|
||||
model.telemetry_id(),
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
|
@ -1935,8 +1933,15 @@ impl Context {
|
|||
}
|
||||
|
||||
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
return;
|
||||
};
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1953,10 +1958,9 @@ impl Context {
|
|||
temperature: 1.0,
|
||||
};
|
||||
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let stream = model.stream_completion(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
let mut replaced = !replace_old;
|
||||
|
@ -2490,7 +2494,6 @@ mod tests {
|
|||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2623,7 +2626,6 @@ mod tests {
|
|||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
|
@ -2717,7 +2719,6 @@ mod tests {
|
|||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2803,7 +2804,6 @@ mod tests {
|
|||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(assistant_panel::init);
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
|
@ -2930,7 +2930,6 @@ mod tests {
|
|||
cx.set_global(settings_store);
|
||||
|
||||
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
let fake_model = fake_provider.test_model();
|
||||
cx.update(assistant_panel::init);
|
||||
|
@ -3032,7 +3031,6 @@ mod tests {
|
|||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(assistant_panel::init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
||||
|
@ -3109,7 +3107,6 @@ mod tests {
|
|||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
cx.update(assistant_panel::init);
|
||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
|
||||
Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff,
|
||||
Hunk, ModelSelector, StreamingDiff,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -27,7 +27,9 @@ use gpui::{
|
|||
WindowContext,
|
||||
};
|
||||
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
|
@ -1328,7 +1330,7 @@ impl Render for PromptEditor {
|
|||
Tooltip::with_meta(
|
||||
format!(
|
||||
"Using {}",
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
|
@ -1662,7 +1664,7 @@ impl PromptEditor {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let token_count = self.token_count?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -2013,8 +2015,12 @@ impl Codegen {
|
|||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||
model.count_tokens(request, cx)
|
||||
} else {
|
||||
future::ready(Err(anyhow!("no active model"))).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
|
@ -2024,6 +2030,10 @@ impl Codegen {
|
|||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<()> {
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.context("no active model")?;
|
||||
|
||||
self.undo(cx);
|
||||
|
||||
// Handle initial insertion
|
||||
|
@ -2053,10 +2063,7 @@ impl Codegen {
|
|||
None
|
||||
};
|
||||
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model_telemetry_id()
|
||||
.context("no active model")?;
|
||||
|
||||
let telemetry_id = model.telemetry_id();
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
||||
.trim()
|
||||
.to_lowercase()
|
||||
|
@ -2067,10 +2074,10 @@ impl Codegen {
|
|||
let request =
|
||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
||||
let chunks =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
|
||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||
};
|
||||
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
|
||||
self.handle_stream(telemetry_id, edit_range, chunks, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -2657,7 +2664,6 @@ mod tests {
|
|||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
|
@ -2789,7 +2795,6 @@ mod tests {
|
|||
mut rng: StdRng,
|
||||
) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
|
@ -2853,7 +2858,6 @@ mod tests {
|
|||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
|
||||
use crate::assistant_settings::AssistantSettings;
|
||||
use fs::Fs;
|
||||
use gpui::SharedString;
|
||||
use language_model::LanguageModelRegistry;
|
||||
|
@ -81,13 +81,13 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
|
|||
}
|
||||
},
|
||||
{
|
||||
let provider = provider.id();
|
||||
let provider = provider.clone();
|
||||
move |cx| {
|
||||
LanguageModelCompletionProvider::global(cx).update(
|
||||
LanguageModelRegistry::global(cx).update(
|
||||
cx,
|
||||
|completion_provider, cx| {
|
||||
completion_provider
|
||||
.set_active_provider(provider.clone(), cx)
|
||||
.set_active_provider(Some(provider.clone()), cx);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -95,12 +95,12 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
|
|||
);
|
||||
}
|
||||
|
||||
let selected_model = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.id());
|
||||
let selected_provider = LanguageModelCompletionProvider::read_global(cx)
|
||||
let selected_provider = LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map(|m| m.id());
|
||||
let selected_model = LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.id());
|
||||
|
||||
for available_model in available_models {
|
||||
menu = menu.custom_entry(
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use crate::{
|
||||
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
|
||||
LanguageModelCompletionProvider,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assets::Assets;
|
||||
|
@ -19,7 +18,9 @@ use gpui::{
|
|||
};
|
||||
use heed::{types::SerdeBincode, Database, RoTxn};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use rope::Rope;
|
||||
|
@ -636,7 +637,10 @@ impl PromptLibrary {
|
|||
};
|
||||
|
||||
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
||||
let provider = LanguageModelCompletionProvider::read_global(cx);
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let initial_prompt = action.prompt.clone();
|
||||
if provider.is_authenticated(cx) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
|
@ -725,6 +729,9 @@ impl PromptLibrary {
|
|||
}
|
||||
|
||||
fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
|
||||
let editor = &prompt.body_editor.read(cx);
|
||||
let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
|
||||
|
@ -736,7 +743,7 @@ impl PromptLibrary {
|
|||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
||||
model.count_tokens(
|
||||
LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
|
@ -804,7 +811,7 @@ impl PromptLibrary {
|
|||
let prompt_metadata = self.store.metadata(prompt_id)?;
|
||||
let prompt_editor = &self.prompt_editors[&prompt_id];
|
||||
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
||||
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model();
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
|
||||
Some(
|
||||
|
@ -914,7 +921,7 @@ impl PromptLibrary {
|
|||
None,
|
||||
format!(
|
||||
"Model: {}",
|
||||
current_model
|
||||
model
|
||||
.as_ref()
|
||||
.map(|model| model
|
||||
.name()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
|
||||
AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
|
||||
AssistantPanelEvent, ModelSelector,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -16,7 +16,9 @@ use gpui::{
|
|||
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use settings::Settings;
|
||||
use std::{
|
||||
cmp,
|
||||
|
@ -556,7 +558,7 @@ impl Render for PromptEditor {
|
|||
Tooltip::with_meta(
|
||||
format!(
|
||||
"Using {}",
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
|
@ -700,6 +702,9 @@ impl PromptEditor {
|
|||
|
||||
fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let assist_id = self.id;
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
||||
cx.background_executor().timer(Duration::from_secs(1)).await;
|
||||
let request =
|
||||
|
@ -707,11 +712,7 @@ impl PromptEditor {
|
|||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
||||
})??;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
|
@ -840,7 +841,7 @@ impl PromptEditor {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let token_count = self.token_count?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -982,19 +983,16 @@ impl Codegen {
|
|||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
|
||||
self.status = CodegenStatus::Pending;
|
||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let telemetry = self.telemetry.clone();
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
let response =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
|
||||
|
||||
self.status = CodegenStatus::Pending;
|
||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
self.generation = cx.spawn(|this, mut cx| async move {
|
||||
let response = response.await;
|
||||
let model_telemetry_id = model.telemetry_id();
|
||||
let response = model.stream_completion(prompt, &cx).await;
|
||||
let generate = async {
|
||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue