assistant: Overhaul provider infrastructure (#14929)
<img width="624" alt="image" src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815"> - [x] Correctly set custom model token count - [x] How to count tokens for Gemini models? - [x] Feature flag zed.dev provider - [x] Figure out how to configure custom models - [ ] Update docs Release Notes: - Added support for quickly switching between multiple language model providers in the assistant panel --------- Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
parent
17ef9a367f
commit
d0f52e90e6
55 changed files with 2757 additions and 2023 deletions
|
@ -15,20 +15,20 @@ use assistant_settings::AssistantSettings;
|
|||
use assistant_slash_command::SlashCommandRegistry;
|
||||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use completion::CompletionProvider;
|
||||
use completion::LanguageModelCompletionProvider;
|
||||
pub use context::*;
|
||||
pub use context_store::*;
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
|
||||
};
|
||||
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||
use indexed_docs::IndexedDocsRegistry;
|
||||
pub(crate) use inline_assistant::*;
|
||||
use language_model::LanguageModelResponseMessage;
|
||||
use language_model::{
|
||||
LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
|
||||
};
|
||||
pub(crate) use model_selector::*;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use settings::{update_settings_file, Settings, SettingsStore};
|
||||
use slash_command::{
|
||||
active_command, default_command, diagnostics_command, docs_command, fetch_command,
|
||||
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
|
||||
|
@ -165,6 +165,16 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
cx.set_global(Assistant::default());
|
||||
AssistantSettings::register(cx);
|
||||
|
||||
// TODO: remove this when 0.148.0 is released.
|
||||
if AssistantSettings::get_global(cx).using_outdated_settings_version {
|
||||
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
|
||||
let fs = fs.clone();
|
||||
|content, cx| {
|
||||
content.update_file(fs, cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
cx.spawn(|mut cx| {
|
||||
let client = client.clone();
|
||||
async move {
|
||||
|
@ -182,7 +192,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
|
||||
context_store::init(&client);
|
||||
prompt_library::init(cx);
|
||||
init_completion_provider(Arc::clone(&client), cx);
|
||||
init_completion_provider(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
register_slash_commands(cx);
|
||||
assistant_panel::init(cx);
|
||||
|
@ -207,20 +217,38 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
|||
.detach();
|
||||
}
|
||||
|
||||
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
|
||||
cx.set_global(CompletionProvider::new(provider, Some(client)));
|
||||
fn init_completion_provider(cx: &mut AppContext) {
|
||||
completion::init(cx);
|
||||
update_active_language_model_from_settings(cx);
|
||||
|
||||
let mut settings_version = 0;
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
settings_version += 1;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
|
||||
})
|
||||
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
|
||||
.detach();
|
||||
cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
|
||||
update_active_language_model_from_settings(cx)
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn update_active_language_model_from_settings(cx: &mut AppContext) {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
|
||||
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||
|
||||
let Some(provider) = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.provider(&provider_name)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
|
||||
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
|
||||
completion_provider.set_active_model(model, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn register_slash_commands(cx: &mut AppContext) {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
slash_command_registry.register_command(file_command::FileSlashCommand, true);
|
||||
|
|
|
@ -18,7 +18,7 @@ use anyhow::{anyhow, Result};
|
|||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||
use client::proto;
|
||||
use collections::{BTreeSet, HashMap, HashSet};
|
||||
use completion::CompletionProvider;
|
||||
use completion::LanguageModelCompletionProvider;
|
||||
use editor::{
|
||||
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
||||
display_map::{
|
||||
|
@ -364,13 +364,12 @@ impl AssistantPanel {
|
|||
cx.subscribe(&pane, Self::handle_pane_event),
|
||||
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
|
||||
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
|
||||
cx.observe_global::<CompletionProvider>({
|
||||
let mut prev_settings_version = CompletionProvider::global(cx).settings_version();
|
||||
move |this, cx| {
|
||||
this.completion_provider_changed(prev_settings_version, cx);
|
||||
prev_settings_version = CompletionProvider::global(cx).settings_version();
|
||||
}
|
||||
}),
|
||||
cx.observe(
|
||||
&LanguageModelCompletionProvider::global(cx),
|
||||
|this, _, cx| {
|
||||
this.completion_provider_changed(cx);
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
Self {
|
||||
|
@ -483,37 +482,36 @@ impl AssistantPanel {
|
|||
}
|
||||
}
|
||||
|
||||
fn completion_provider_changed(
|
||||
&mut self,
|
||||
prev_settings_version: usize,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
if self.is_authenticated(cx) {
|
||||
self.authentication_prompt = None;
|
||||
|
||||
match self.active_context_editor(cx) {
|
||||
Some(editor) => {
|
||||
editor.update(cx, |active_context, cx| {
|
||||
active_context
|
||||
.context
|
||||
.update(cx, |context, cx| context.completion_provider_changed(cx))
|
||||
});
|
||||
}
|
||||
None => {
|
||||
self.new_context(cx);
|
||||
}
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
} else if self.authentication_prompt.is_none()
|
||||
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
|
||||
{
|
||||
self.authentication_prompt =
|
||||
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
provider.authentication_prompt(cx)
|
||||
}));
|
||||
cx.notify();
|
||||
fn completion_provider_changed(&mut self, cx: &mut ViewContext<Self>) {
|
||||
if let Some(editor) = self.active_context_editor(cx) {
|
||||
editor.update(cx, |active_context, cx| {
|
||||
active_context
|
||||
.context
|
||||
.update(cx, |context, cx| context.completion_provider_changed(cx))
|
||||
})
|
||||
}
|
||||
|
||||
if self.active_context_editor(cx).is_none() {
|
||||
self.new_context(cx);
|
||||
}
|
||||
|
||||
let authentication_prompt = Self::authentication_prompt(cx);
|
||||
for context_editor in self.context_editors(cx) {
|
||||
context_editor.update(cx, |editor, cx| {
|
||||
editor.set_authentication_prompt(authentication_prompt.clone(), cx);
|
||||
});
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
|
||||
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
|
||||
if !provider.is_authenticated(cx) {
|
||||
return Some(provider.authentication_prompt(cx));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn inline_assist(
|
||||
|
@ -774,7 +772,7 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||
CompletionProvider::global(cx)
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.reset_credentials(cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
@ -783,6 +781,13 @@ impl AssistantPanel {
|
|||
self.model_selector_menu_handle.toggle(cx);
|
||||
}
|
||||
|
||||
fn context_editors(&self, cx: &AppContext) -> Vec<View<ContextEditor>> {
|
||||
self.pane
|
||||
.read(cx)
|
||||
.items_of_type::<ContextEditor>()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn active_context_editor(&self, cx: &AppContext) -> Option<View<ContextEditor>> {
|
||||
self.pane
|
||||
.read(cx)
|
||||
|
@ -904,11 +909,11 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||
CompletionProvider::global(cx).is_authenticated()
|
||||
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| provider.authenticate(cx))
|
||||
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
|
||||
}
|
||||
|
||||
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
|
@ -968,14 +973,18 @@ impl Panel for AssistantPanel {
|
|||
}
|
||||
|
||||
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
||||
settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
|
||||
let dock = match position {
|
||||
DockPosition::Left => AssistantDockPosition::Left,
|
||||
DockPosition::Bottom => AssistantDockPosition::Bottom,
|
||||
DockPosition::Right => AssistantDockPosition::Right,
|
||||
};
|
||||
settings.set_dock(dock);
|
||||
});
|
||||
settings::update_settings_file::<AssistantSettings>(
|
||||
self.fs.clone(),
|
||||
cx,
|
||||
move |settings, _| {
|
||||
let dock = match position {
|
||||
DockPosition::Left => AssistantDockPosition::Left,
|
||||
DockPosition::Bottom => AssistantDockPosition::Bottom,
|
||||
DockPosition::Right => AssistantDockPosition::Right,
|
||||
};
|
||||
settings.set_dock(dock);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn size(&self, cx: &WindowContext) -> Pixels {
|
||||
|
@ -1074,6 +1083,7 @@ struct ActiveEditStep {
|
|||
|
||||
pub struct ContextEditor {
|
||||
context: Model<Context>,
|
||||
authentication_prompt: Option<AnyView>,
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakView<Workspace>,
|
||||
project: Model<Project>,
|
||||
|
@ -1131,6 +1141,7 @@ impl ContextEditor {
|
|||
let sections = context.read(cx).slash_command_output_sections().to_vec();
|
||||
let mut this = Self {
|
||||
context,
|
||||
authentication_prompt: None,
|
||||
editor,
|
||||
lsp_adapter_delegate,
|
||||
blocks: Default::default(),
|
||||
|
@ -1150,6 +1161,15 @@ impl ContextEditor {
|
|||
this
|
||||
}
|
||||
|
||||
fn set_authentication_prompt(
|
||||
&mut self,
|
||||
authentication_prompt: Option<AnyView>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
self.authentication_prompt = authentication_prompt;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn insert_default_prompt(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let command_name = DefaultSlashCommand.name();
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
|
@ -1176,6 +1196,10 @@ impl ContextEditor {
|
|||
}
|
||||
|
||||
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
||||
if self.authentication_prompt.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
if !self.apply_edit_step(cx) {
|
||||
self.send_to_model(cx);
|
||||
}
|
||||
|
@ -2203,19 +2227,26 @@ impl Render for ContextEditor {
|
|||
.size_full()
|
||||
.v_flex()
|
||||
.child(
|
||||
div()
|
||||
.flex_grow()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(self.editor.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.absolute()
|
||||
.bottom_0()
|
||||
.p_4()
|
||||
.justify_end()
|
||||
.child(self.render_send_button(cx)),
|
||||
),
|
||||
if let Some(authentication_prompt) = self.authentication_prompt.as_ref() {
|
||||
div()
|
||||
.flex_grow()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(authentication_prompt.clone().into_any())
|
||||
} else {
|
||||
div()
|
||||
.flex_grow()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(self.editor.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.absolute()
|
||||
.bottom_0()
|
||||
.p_4()
|
||||
.justify_end()
|
||||
.child(self.render_send_button(cx)),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -2543,7 +2574,7 @@ impl ContextEditorToolbarItem {
|
|||
}
|
||||
|
||||
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||
let context = &self
|
||||
.active_context_editor
|
||||
.as_ref()?
|
||||
|
|
|
@ -1,19 +1,14 @@
|
|||
use std::{sync::Arc, time::Duration};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anthropic::Model as AnthropicModel;
|
||||
use client::Client;
|
||||
use completion::{
|
||||
AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
|
||||
LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
|
||||
};
|
||||
use fs::Fs;
|
||||
use gpui::{AppContext, Pixels};
|
||||
use language_model::{CloudModel, LanguageModel};
|
||||
use language_model::{settings::AllLanguageModelSettings, CloudModel, LanguageModel};
|
||||
use ollama::Model as OllamaModel;
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use parking_lot::RwLock;
|
||||
use schemars::{schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use settings::{update_settings_file, Settings, SettingsSources};
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
@ -24,43 +19,9 @@ pub enum AssistantDockPosition {
|
|||
Bottom,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum AssistantProvider {
|
||||
ZedDotDev {
|
||||
model: CloudModel,
|
||||
},
|
||||
OpenAi {
|
||||
model: OpenAiModel,
|
||||
api_url: String,
|
||||
low_speed_timeout_in_seconds: Option<u64>,
|
||||
available_models: Vec<OpenAiModel>,
|
||||
},
|
||||
Anthropic {
|
||||
model: AnthropicModel,
|
||||
api_url: String,
|
||||
low_speed_timeout_in_seconds: Option<u64>,
|
||||
},
|
||||
Ollama {
|
||||
model: OllamaModel,
|
||||
api_url: String,
|
||||
low_speed_timeout_in_seconds: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for AssistantProvider {
|
||||
fn default() -> Self {
|
||||
Self::OpenAi {
|
||||
model: OpenAiModel::default(),
|
||||
api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(tag = "name", rename_all = "snake_case")]
|
||||
pub enum AssistantProviderContent {
|
||||
pub enum AssistantProviderContentV1 {
|
||||
#[serde(rename = "zed.dev")]
|
||||
ZedDotDev { default_model: Option<CloudModel> },
|
||||
#[serde(rename = "openai")]
|
||||
|
@ -91,7 +52,8 @@ pub struct AssistantSettings {
|
|||
pub dock: AssistantDockPosition,
|
||||
pub default_width: Pixels,
|
||||
pub default_height: Pixels,
|
||||
pub provider: AssistantProvider,
|
||||
pub default_model: AssistantDefaultModel,
|
||||
pub using_outdated_settings_version: bool,
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
|
@ -123,34 +85,142 @@ impl Default for AssistantSettingsContent {
|
|||
}
|
||||
|
||||
impl AssistantSettingsContent {
|
||||
fn upgrade(&self) -> AssistantSettingsContentV1 {
|
||||
pub fn is_version_outdated(&self) -> bool {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
|
||||
VersionedAssistantSettingsContent::V1(_) => true,
|
||||
VersionedAssistantSettingsContent::V2(_) => false,
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
|
||||
AssistantSettingsContent::Legacy(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_file(&mut self, fs: Arc<dyn Fs>, cx: &AppContext) {
|
||||
if let AssistantSettingsContent::Versioned(settings) = self {
|
||||
if let VersionedAssistantSettingsContent::V1(settings) = settings {
|
||||
if let Some(provider) = settings.provider.clone() {
|
||||
match provider {
|
||||
AssistantProviderContentV1::Anthropic {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
..
|
||||
} => update_settings_file::<AllLanguageModelSettings>(
|
||||
fs,
|
||||
cx,
|
||||
move |content, _| {
|
||||
if content.anthropic.is_none() {
|
||||
content.anthropic =
|
||||
Some(language_model::settings::AnthropicSettingsContent {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
},
|
||||
),
|
||||
AssistantProviderContentV1::Ollama {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
..
|
||||
} => update_settings_file::<AllLanguageModelSettings>(
|
||||
fs,
|
||||
cx,
|
||||
move |content, _| {
|
||||
if content.ollama.is_none() {
|
||||
content.ollama =
|
||||
Some(language_model::settings::OllamaSettingsContent {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
});
|
||||
}
|
||||
},
|
||||
),
|
||||
AssistantProviderContentV1::OpenAi {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
..
|
||||
} => update_settings_file::<AllLanguageModelSettings>(
|
||||
fs,
|
||||
cx,
|
||||
move |content, _| {
|
||||
if content.open_ai.is_none() {
|
||||
content.open_ai =
|
||||
Some(language_model::settings::OpenAiSettingsContent {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
});
|
||||
}
|
||||
},
|
||||
),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*self = AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
|
||||
self.upgrade(),
|
||||
));
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> AssistantSettingsContentV2 {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
|
||||
enabled: settings.enabled,
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_width,
|
||||
default_model: settings
|
||||
.provider
|
||||
.clone()
|
||||
.and_then(|provider| match provider {
|
||||
AssistantProviderContentV1::ZedDotDev { default_model } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
provider: "zed.dev".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::OpenAi { default_model, .. } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
provider: "openai".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::Anthropic { default_model, .. } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
provider: "anthropic".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::Ollama { default_model, .. } => {
|
||||
default_model.map(|model| AssistantDefaultModel {
|
||||
provider: "ollama".to_string(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
}),
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
|
||||
enabled: None,
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_height,
|
||||
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
|
||||
Some(AssistantProviderContent::OpenAi {
|
||||
default_model: settings.default_open_ai_model.clone(),
|
||||
api_url: Some(open_ai_api_url.clone()),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Some(Default::default()),
|
||||
})
|
||||
} else {
|
||||
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
||||
AssistantProviderContent::OpenAi {
|
||||
default_model: Some(open_ai_model),
|
||||
api_url: None,
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Some(Default::default()),
|
||||
}
|
||||
})
|
||||
},
|
||||
default_model: Some(AssistantDefaultModel {
|
||||
provider: "openai".to_string(),
|
||||
model: settings
|
||||
.default_open_ai_model
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.id()
|
||||
.to_string(),
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -161,6 +231,9 @@ impl AssistantSettingsContent {
|
|||
VersionedAssistantSettingsContent::V1(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
VersionedAssistantSettingsContent::V2(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
|
@ -168,74 +241,78 @@ impl AssistantSettingsContent {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn set_model(&mut self, new_model: LanguageModel) {
|
||||
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
|
||||
let model = language_model.id().0.to_string();
|
||||
let provider = language_model.provider_name().0.to_string();
|
||||
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
|
||||
Some(AssistantProviderContent::ZedDotDev {
|
||||
default_model: model,
|
||||
}) => {
|
||||
if let LanguageModel::Cloud(new_model) = new_model {
|
||||
*model = Some(new_model);
|
||||
}
|
||||
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
|
||||
"zed.dev" => {
|
||||
settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
|
||||
default_model: CloudModel::from_id(&model).ok(),
|
||||
});
|
||||
}
|
||||
Some(AssistantProviderContent::OpenAi {
|
||||
default_model: model,
|
||||
..
|
||||
}) => {
|
||||
if let LanguageModel::OpenAi(new_model) = new_model {
|
||||
*model = Some(new_model);
|
||||
}
|
||||
"anthropic" => {
|
||||
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::Anthropic {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
..
|
||||
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
|
||||
_ => (None, None),
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::Anthropic {
|
||||
default_model: AnthropicModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
});
|
||||
}
|
||||
Some(AssistantProviderContent::Anthropic {
|
||||
default_model: model,
|
||||
..
|
||||
}) => {
|
||||
if let LanguageModel::Anthropic(new_model) = new_model {
|
||||
*model = Some(new_model);
|
||||
}
|
||||
"ollama" => {
|
||||
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::Ollama {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
..
|
||||
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
|
||||
_ => (None, None),
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::Ollama {
|
||||
default_model: Some(ollama::Model::new(&model)),
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
});
|
||||
}
|
||||
Some(AssistantProviderContent::Ollama {
|
||||
default_model: model,
|
||||
..
|
||||
}) => {
|
||||
if let LanguageModel::Ollama(new_model) = new_model {
|
||||
*model = Some(new_model);
|
||||
}
|
||||
"openai" => {
|
||||
let (api_url, low_speed_timeout_in_seconds, available_models) =
|
||||
match &settings.provider {
|
||||
Some(AssistantProviderContentV1::OpenAi {
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
..
|
||||
}) => (
|
||||
api_url.clone(),
|
||||
*low_speed_timeout_in_seconds,
|
||||
available_models.clone(),
|
||||
),
|
||||
_ => (None, None, None),
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::OpenAi {
|
||||
default_model: open_ai::Model::from_id(&model).ok(),
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
});
|
||||
}
|
||||
provider => match new_model {
|
||||
LanguageModel::Cloud(model) => {
|
||||
*provider = Some(AssistantProviderContent::ZedDotDev {
|
||||
default_model: Some(model),
|
||||
})
|
||||
}
|
||||
LanguageModel::OpenAi(model) => {
|
||||
*provider = Some(AssistantProviderContent::OpenAi {
|
||||
default_model: Some(model),
|
||||
api_url: None,
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Some(Default::default()),
|
||||
})
|
||||
}
|
||||
LanguageModel::Anthropic(model) => {
|
||||
*provider = Some(AssistantProviderContent::Anthropic {
|
||||
default_model: Some(model),
|
||||
api_url: None,
|
||||
low_speed_timeout_in_seconds: None,
|
||||
})
|
||||
}
|
||||
LanguageModel::Ollama(model) => {
|
||||
*provider = Some(AssistantProviderContent::Ollama {
|
||||
default_model: Some(model),
|
||||
api_url: None,
|
||||
low_speed_timeout_in_seconds: None,
|
||||
})
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(settings) => {
|
||||
settings.default_model = Some(AssistantDefaultModel { provider, model });
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
if let LanguageModel::OpenAi(model) = new_model {
|
||||
if let Ok(model) = open_ai::Model::from_id(&language_model.id().0) {
|
||||
settings.default_open_ai_model = Some(model);
|
||||
}
|
||||
}
|
||||
|
@ -248,21 +325,78 @@ impl AssistantSettingsContent {
|
|||
pub enum VersionedAssistantSettingsContent {
|
||||
#[serde(rename = "1")]
|
||||
V1(AssistantSettingsContentV1),
|
||||
#[serde(rename = "2")]
|
||||
V2(AssistantSettingsContentV2),
|
||||
}
|
||||
|
||||
impl Default for VersionedAssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::V1(AssistantSettingsContentV1 {
|
||||
Self::V2(AssistantSettingsContentV2 {
|
||||
enabled: None,
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
provider: None,
|
||||
default_model: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContentV2 {
|
||||
/// Whether the Assistant is enabled.
|
||||
///
|
||||
/// Default: true
|
||||
enabled: Option<bool>,
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
///
|
||||
/// Default: true
|
||||
button: Option<bool>,
|
||||
/// Where to dock the assistant.
|
||||
///
|
||||
/// Default: right
|
||||
dock: Option<AssistantDockPosition>,
|
||||
/// Default width in pixels when the assistant is docked to the left or right.
|
||||
///
|
||||
/// Default: 640
|
||||
default_width: Option<f32>,
|
||||
/// Default height in pixels when the assistant is docked to the bottom.
|
||||
///
|
||||
/// Default: 320
|
||||
default_height: Option<f32>,
|
||||
/// The default model to use when creating new contexts.
|
||||
default_model: Option<AssistantDefaultModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
pub struct AssistantDefaultModel {
|
||||
#[schemars(schema_with = "providers_schema")]
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
|
||||
schemars::schema::SchemaObject {
|
||||
enum_values: Some(vec![
|
||||
"anthropic".into(),
|
||||
"ollama".into(),
|
||||
"openai".into(),
|
||||
"zed.dev".into(),
|
||||
]),
|
||||
..Default::default()
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
impl Default for AssistantDefaultModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider: "openai".to_string(),
|
||||
model: "gpt-4".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContentV1 {
|
||||
/// Whether the Assistant is enabled.
|
||||
|
@ -289,7 +423,7 @@ pub struct AssistantSettingsContentV1 {
|
|||
///
|
||||
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
||||
/// each with their respective default models and configurations.
|
||||
provider: Option<AssistantProviderContent>,
|
||||
provider: Option<AssistantProviderContentV1>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
|
@ -332,6 +466,10 @@ impl Settings for AssistantSettings {
|
|||
let mut settings = AssistantSettings::default();
|
||||
|
||||
for value in sources.defaults_and_customizations() {
|
||||
if value.is_version_outdated() {
|
||||
settings.using_outdated_settings_version = true;
|
||||
}
|
||||
|
||||
let value = value.upgrade();
|
||||
merge(&mut settings.enabled, value.enabled);
|
||||
merge(&mut settings.button, value.button);
|
||||
|
@ -344,123 +482,10 @@ impl Settings for AssistantSettings {
|
|||
&mut settings.default_height,
|
||||
value.default_height.map(Into::into),
|
||||
);
|
||||
if let Some(provider) = value.provider.clone() {
|
||||
match (&mut settings.provider, provider) {
|
||||
(
|
||||
AssistantProvider::ZedDotDev { model },
|
||||
AssistantProviderContent::ZedDotDev {
|
||||
default_model: model_override,
|
||||
},
|
||||
) => {
|
||||
merge(model, model_override);
|
||||
}
|
||||
(
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
},
|
||||
AssistantProviderContent::OpenAi {
|
||||
default_model: model_override,
|
||||
api_url: api_url_override,
|
||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||
available_models: available_models_override,
|
||||
},
|
||||
) => {
|
||||
merge(model, model_override);
|
||||
merge(api_url, api_url_override);
|
||||
merge(available_models, available_models_override);
|
||||
if let Some(low_speed_timeout_in_seconds_override) =
|
||||
low_speed_timeout_in_seconds_override
|
||||
{
|
||||
*low_speed_timeout_in_seconds =
|
||||
Some(low_speed_timeout_in_seconds_override);
|
||||
}
|
||||
}
|
||||
(
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
AssistantProviderContent::Ollama {
|
||||
default_model: model_override,
|
||||
api_url: api_url_override,
|
||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||
},
|
||||
) => {
|
||||
merge(model, model_override);
|
||||
merge(api_url, api_url_override);
|
||||
if let Some(low_speed_timeout_in_seconds_override) =
|
||||
low_speed_timeout_in_seconds_override
|
||||
{
|
||||
*low_speed_timeout_in_seconds =
|
||||
Some(low_speed_timeout_in_seconds_override);
|
||||
}
|
||||
}
|
||||
(
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
AssistantProviderContent::Anthropic {
|
||||
default_model: model_override,
|
||||
api_url: api_url_override,
|
||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||
},
|
||||
) => {
|
||||
merge(model, model_override);
|
||||
merge(api_url, api_url_override);
|
||||
if let Some(low_speed_timeout_in_seconds_override) =
|
||||
low_speed_timeout_in_seconds_override
|
||||
{
|
||||
*low_speed_timeout_in_seconds =
|
||||
Some(low_speed_timeout_in_seconds_override);
|
||||
}
|
||||
}
|
||||
(provider, provider_override) => {
|
||||
*provider = match provider_override {
|
||||
AssistantProviderContent::ZedDotDev {
|
||||
default_model: model,
|
||||
} => AssistantProvider::ZedDotDev {
|
||||
model: model.unwrap_or_default(),
|
||||
},
|
||||
AssistantProviderContent::OpenAi {
|
||||
default_model: model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => AssistantProvider::OpenAi {
|
||||
model: model.unwrap_or_default(),
|
||||
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models: available_models.unwrap_or_default(),
|
||||
},
|
||||
AssistantProviderContent::Anthropic {
|
||||
default_model: model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => AssistantProvider::Anthropic {
|
||||
model: model.unwrap_or_default(),
|
||||
api_url: api_url
|
||||
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
AssistantProviderContent::Ollama {
|
||||
default_model: model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => AssistantProvider::Ollama {
|
||||
model: model.unwrap_or_default(),
|
||||
api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
merge(
|
||||
&mut settings.default_model,
|
||||
value.default_model.map(Into::into),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
|
@ -473,221 +498,103 @@ fn merge<T>(target: &mut T, value: Option<T>) {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn update_completion_provider_settings(
|
||||
provider: &mut CompletionProvider,
|
||||
version: usize,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let updated = match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { model } => provider
|
||||
.update_current_as::<_, CloudCompletionProvider>(|provider| {
|
||||
provider.update(model.clone(), version);
|
||||
}),
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
choose_openai_model(&model, &available_models),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
);
|
||||
}),
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
);
|
||||
}),
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
||||
provider.update(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
version,
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
};
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use gpui::{AppContext, UpdateGlobal};
|
||||
// use settings::SettingsStore;
|
||||
|
||||
// Previously configured provider was changed to another one
|
||||
if updated.is_none() {
|
||||
provider.update_provider(|client| create_provider_from_settings(client, version, cx));
|
||||
}
|
||||
}
|
||||
// use super::*;
|
||||
|
||||
pub(crate) fn create_provider_from_settings(
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
|
||||
match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
|
||||
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
|
||||
)),
|
||||
AssistantProvider::OpenAi {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
available_models,
|
||||
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
|
||||
choose_openai_model(&model, &available_models),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
available_models.clone(),
|
||||
))),
|
||||
AssistantProvider::Anthropic {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
))),
|
||||
AssistantProvider::Ollama {
|
||||
model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
|
||||
model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
cx,
|
||||
))),
|
||||
}
|
||||
}
|
||||
// #[gpui::test]
|
||||
// fn test_deserialize_assistant_settings(cx: &mut AppContext) {
|
||||
// let store = settings::SettingsStore::test(cx);
|
||||
// cx.set_global(store);
|
||||
|
||||
/// Choose which model to use for openai provider.
|
||||
/// If the model is not available, try to use the first available model, or fallback to the original model.
|
||||
fn choose_openai_model(
|
||||
model: &::open_ai::Model,
|
||||
available_models: &[::open_ai::Model],
|
||||
) -> ::open_ai::Model {
|
||||
available_models
|
||||
.iter()
|
||||
.find(|&m| m == model)
|
||||
.or_else(|| available_models.first())
|
||||
.unwrap_or_else(|| model)
|
||||
.clone()
|
||||
}
|
||||
// // Settings default to gpt-4-turbo.
|
||||
// AssistantSettings::register(cx);
|
||||
// assert_eq!(
|
||||
// AssistantSettings::get_global(cx).provider,
|
||||
// AssistantProvider::OpenAi {
|
||||
// model: OpenAiModel::FourOmni,
|
||||
// api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||
// low_speed_timeout_in_seconds: None,
|
||||
// available_models: Default::default(),
|
||||
// }
|
||||
// );
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext, UpdateGlobal};
|
||||
use settings::SettingsStore;
|
||||
// // Ensure backward-compatibility.
|
||||
// SettingsStore::update_global(cx, |store, cx| {
|
||||
// store
|
||||
// .set_user_settings(
|
||||
// r#"{
|
||||
// "assistant": {
|
||||
// "openai_api_url": "test-url",
|
||||
// }
|
||||
// }"#,
|
||||
// cx,
|
||||
// )
|
||||
// .unwrap();
|
||||
// });
|
||||
// assert_eq!(
|
||||
// AssistantSettings::get_global(cx).provider,
|
||||
// AssistantProvider::OpenAi {
|
||||
// model: OpenAiModel::FourOmni,
|
||||
// api_url: "test-url".into(),
|
||||
// low_speed_timeout_in_seconds: None,
|
||||
// available_models: Default::default(),
|
||||
// }
|
||||
// );
|
||||
// SettingsStore::update_global(cx, |store, cx| {
|
||||
// store
|
||||
// .set_user_settings(
|
||||
// r#"{
|
||||
// "assistant": {
|
||||
// "default_open_ai_model": "gpt-4-0613"
|
||||
// }
|
||||
// }"#,
|
||||
// cx,
|
||||
// )
|
||||
// .unwrap();
|
||||
// });
|
||||
// assert_eq!(
|
||||
// AssistantSettings::get_global(cx).provider,
|
||||
// AssistantProvider::OpenAi {
|
||||
// model: OpenAiModel::Four,
|
||||
// api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||
// low_speed_timeout_in_seconds: None,
|
||||
// available_models: Default::default(),
|
||||
// }
|
||||
// );
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
|
||||
let store = settings::SettingsStore::test(cx);
|
||||
cx.set_global(store);
|
||||
|
||||
// Settings default to gpt-4-turbo.
|
||||
AssistantSettings::register(cx);
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
model: OpenAiModel::FourOmni,
|
||||
api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
);
|
||||
|
||||
// Ensure backward-compatibility.
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"openai_api_url": "test-url",
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
model: OpenAiModel::FourOmni,
|
||||
api_url: "test-url".into(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
);
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"default_open_ai_model": "gpt-4-0613"
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
model: OpenAiModel::Four,
|
||||
api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
available_models: Default::default(),
|
||||
}
|
||||
);
|
||||
|
||||
// The new version supports setting a custom model when using zed.dev.
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"version": "1",
|
||||
"provider": {
|
||||
"name": "zed.dev",
|
||||
"default_model": {
|
||||
"custom": {
|
||||
"name": "custom-provider"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::ZedDotDev {
|
||||
model: CloudModel::Custom {
|
||||
name: "custom-provider".into(),
|
||||
max_tokens: None
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
// // The new version supports setting a custom model when using zed.dev.
|
||||
// SettingsStore::update_global(cx, |store, cx| {
|
||||
// store
|
||||
// .set_user_settings(
|
||||
// r#"{
|
||||
// "assistant": {
|
||||
// "version": "1",
|
||||
// "provider": {
|
||||
// "name": "zed.dev",
|
||||
// "default_model": {
|
||||
// "custom": {
|
||||
// "name": "custom-provider"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }"#,
|
||||
// cx,
|
||||
// )
|
||||
// .unwrap();
|
||||
// });
|
||||
// assert_eq!(
|
||||
// AssistantSettings::get_global(cx).provider,
|
||||
// AssistantProvider::ZedDotDev {
|
||||
// model: CloudModel::Custom {
|
||||
// name: "custom-provider".into(),
|
||||
// max_tokens: None
|
||||
// }
|
||||
// }
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
|
||||
MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
||||
MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
|
@ -1124,7 +1124,9 @@ impl Context {
|
|||
.await;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
|
@ -1308,7 +1310,9 @@ impl Context {
|
|||
});
|
||||
|
||||
let raw_output = cx
|
||||
.update(|cx| CompletionProvider::global(cx).complete(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let operations = Self::parse_edit_operations(&raw_output);
|
||||
|
@ -1612,13 +1616,14 @@ impl Context {
|
|||
.then_some(message.id)
|
||||
})?;
|
||||
|
||||
if !CompletionProvider::global(cx).is_authenticated() {
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
log::info!("completion provider has no credentials");
|
||||
return None;
|
||||
}
|
||||
|
||||
let request = self.to_completion_request(cx);
|
||||
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
.unwrap();
|
||||
|
@ -1698,11 +1703,14 @@ impl Context {
|
|||
});
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
telemetry.report_assistant_event(
|
||||
Some(this.id.0.clone()),
|
||||
AssistantKind::Panel,
|
||||
model.telemetry_id(),
|
||||
model_telemetry_id,
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
|
@ -1727,7 +1735,6 @@ impl Context {
|
|||
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
||||
|
||||
LanguageModelRequest {
|
||||
model: CompletionProvider::global(cx).model(),
|
||||
messages: messages.collect(),
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
|
@ -1970,7 +1977,7 @@ impl Context {
|
|||
|
||||
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if !CompletionProvider::global(cx).is_authenticated() {
|
||||
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1982,13 +1989,13 @@ impl Context {
|
|||
content: "Summarize the context into a short title without punctuation.".into(),
|
||||
}));
|
||||
let request = LanguageModelRequest {
|
||||
model: CompletionProvider::global(cx).model(),
|
||||
messages: messages.collect(),
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
|
||||
let stream =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let mut messages = stream.await?;
|
||||
|
@ -2504,7 +2511,6 @@ mod tests {
|
|||
MessageId,
|
||||
};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
|
||||
use completion::FakeCompletionProvider;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext, WeakView};
|
||||
use indoc::indoc;
|
||||
|
@ -2524,7 +2530,8 @@ mod tests {
|
|||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
FakeCompletionProvider::setup_test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2656,7 +2663,8 @@ mod tests {
|
|||
fn test_message_splitting(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
FakeCompletionProvider::setup_test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
|
@ -2749,7 +2757,8 @@ mod tests {
|
|||
#[gpui::test]
|
||||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
FakeCompletionProvider::setup_test(cx);
|
||||
language_model::LanguageModelRegistry::test(cx);
|
||||
completion::LanguageModelCompletionProvider::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
assistant_panel::init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
@ -2834,7 +2843,8 @@ mod tests {
|
|||
async fn test_slash_commands(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(FakeCompletionProvider::setup_test);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(assistant_panel::init);
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
|
@ -2959,7 +2969,11 @@ mod tests {
|
|||
cx.update(prompt_library::init);
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
let fake_provider = cx.update(FakeCompletionProvider::setup_test);
|
||||
|
||||
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
let fake_model = fake_provider.test_model();
|
||||
cx.update(assistant_panel::init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
|
||||
|
@ -3025,8 +3039,8 @@ mod tests {
|
|||
});
|
||||
|
||||
// Simulate the LLM completion
|
||||
fake_provider.send_last_completion_chunk(llm_response.to_string());
|
||||
fake_provider.finish_last_completion();
|
||||
fake_model.send_last_completion_chunk(llm_response.to_string());
|
||||
fake_model.finish_last_completion();
|
||||
|
||||
// Wait for the completion to be processed
|
||||
cx.run_until_parked();
|
||||
|
@ -3107,7 +3121,8 @@ mod tests {
|
|||
async fn test_serialization(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(FakeCompletionProvider::setup_test);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(assistant_panel::init);
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
||||
|
@ -3183,7 +3198,9 @@ mod tests {
|
|||
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(FakeCompletionProvider::setup_test);
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
|
||||
cx.update(assistant_panel::init);
|
||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
|
||||
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
|
||||
AssistantPanel, AssistantPanelEvent, Hunk, LanguageModelCompletionProvider, StreamingDiff,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -27,7 +27,9 @@ use gpui::{
|
|||
WindowContext,
|
||||
};
|
||||
use language::{Buffer, Point, Selection, TransactionId};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
|
@ -844,7 +846,10 @@ impl InlineAssistant {
|
|||
}
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let telemetry_id = CompletionProvider::global(cx).model().telemetry_id();
|
||||
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
|
@ -854,7 +859,10 @@ impl InlineAssistant {
|
|||
async move {
|
||||
let request = request.await?;
|
||||
let chunks = cx
|
||||
.update(|cx| CompletionProvider::global(cx).stream_completion(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.stream_completion(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
Ok(chunks.boxed())
|
||||
}
|
||||
|
@ -871,8 +879,8 @@ impl InlineAssistant {
|
|||
cx: &mut WindowContext,
|
||||
) -> Task<Result<LanguageModelRequest>> {
|
||||
cx.spawn(|mut cx| async move {
|
||||
let (user_prompt, context_request, project_name, buffer, range, model) = cx
|
||||
.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
||||
let (user_prompt, context_request, project_name, buffer, range) =
|
||||
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
||||
let assist = this.assists.get(&assist_id).context("invalid assist")?;
|
||||
let decorations = assist.decorations.as_ref().context("invalid assist")?;
|
||||
let editor = assist.editor.upgrade().context("invalid assist")?;
|
||||
|
@ -906,15 +914,7 @@ impl InlineAssistant {
|
|||
});
|
||||
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let range = assist.codegen.read(cx).range.clone();
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
anyhow::Ok((
|
||||
user_prompt,
|
||||
context_request,
|
||||
project_name,
|
||||
buffer,
|
||||
range,
|
||||
model,
|
||||
))
|
||||
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
|
||||
})??;
|
||||
|
||||
let language = buffer.language_at(range.start);
|
||||
|
@ -973,7 +973,6 @@ impl InlineAssistant {
|
|||
});
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
model,
|
||||
messages,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
|
@ -1432,24 +1431,39 @@ impl Render for PromptEditor {
|
|||
PopoverMenu::new("model-switcher")
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
for available_model in
|
||||
LanguageModelRegistry::read_global(cx).available_models(cx)
|
||||
{
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
let model_name = available_model.name().0.clone();
|
||||
let provider =
|
||||
available_model.provider_name().0.clone();
|
||||
move |_| {
|
||||
Label::new(model.display_name())
|
||||
.into_any_element()
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.child(Label::new(model_name.clone()))
|
||||
.child(
|
||||
div().ml_4().child(
|
||||
Label::new(provider.clone())
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
},
|
||||
{
|
||||
let fs = fs.clone();
|
||||
let model = model.clone();
|
||||
let model = available_model.clone();
|
||||
move |cx| {
|
||||
let model = model.clone();
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings| settings.set_model(model),
|
||||
move |settings, _| {
|
||||
settings.set_model(model)
|
||||
},
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -1468,9 +1482,10 @@ impl Render for PromptEditor {
|
|||
Tooltip::with_meta(
|
||||
format!(
|
||||
"Using {}",
|
||||
CompletionProvider::global(cx)
|
||||
.model()
|
||||
.display_name()
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
),
|
||||
None,
|
||||
"Change Model",
|
||||
|
@ -1668,7 +1683,9 @@ impl PromptEditor {
|
|||
.await?;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
|
@ -1796,7 +1813,7 @@ impl PromptEditor {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||
let token_count = self.token_count?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -2601,7 +2618,6 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use completion::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
use gpui::{Context, TestAppContext};
|
||||
use indoc::indoc;
|
||||
|
@ -2622,7 +2638,8 @@ mod tests {
|
|||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
|
||||
cx.update(language_model::LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
|
@ -2749,7 +2766,8 @@ mod tests {
|
|||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
|
||||
use crate::{
|
||||
assistant_settings::AssistantSettings, LanguageModelCompletionProvider, ToggleModelSelector,
|
||||
};
|
||||
use fs::Fs;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use settings::update_settings_file;
|
||||
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip};
|
||||
|
||||
|
@ -23,25 +26,64 @@ impl RenderOnce for ModelSelector {
|
|||
.with_handle(self.handle)
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
move |_| Label::new(model.display_name()).into_any_element()
|
||||
},
|
||||
{
|
||||
let fs = self.fs.clone();
|
||||
let model = model.clone();
|
||||
move |cx| {
|
||||
let model = model.clone();
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings| settings.set_model(model),
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
for (provider, available_models) in LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.available_models_grouped_by_provider(cx)
|
||||
{
|
||||
menu = menu.header(provider.0.clone());
|
||||
|
||||
if available_models.is_empty() {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
move |_| {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Settings))
|
||||
.child(Label::new("Configure"))
|
||||
.into_any()
|
||||
}
|
||||
},
|
||||
{
|
||||
let provider = provider.clone();
|
||||
move |cx| {
|
||||
LanguageModelCompletionProvider::global(cx).update(
|
||||
cx,
|
||||
|completion_provider, cx| {
|
||||
completion_provider
|
||||
.set_active_provider(provider.clone(), cx)
|
||||
},
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
for available_model in available_models {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model_name = available_model.name().0.clone();
|
||||
move |_| {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.child(Label::new(model_name.clone()))
|
||||
.into_any()
|
||||
}
|
||||
},
|
||||
{
|
||||
let fs = self.fs.clone();
|
||||
let model = available_model.clone();
|
||||
move |cx| {
|
||||
let model = model.clone();
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _| settings.set_model(model),
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
menu
|
||||
})
|
||||
|
@ -61,7 +103,10 @@ impl RenderOnce for ModelSelector {
|
|||
.whitespace_nowrap()
|
||||
.child(
|
||||
Label::new(
|
||||
CompletionProvider::global(cx).model().display_name(),
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
)
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
|
||||
InlineAssist, InlineAssistant,
|
||||
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
|
||||
LanguageModelCompletionProvider,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assets::Assets;
|
||||
|
@ -636,9 +636,9 @@ impl PromptLibrary {
|
|||
};
|
||||
|
||||
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
||||
let provider = CompletionProvider::global(cx);
|
||||
let provider = LanguageModelCompletionProvider::read_global(cx);
|
||||
let initial_prompt = action.prompt.clone();
|
||||
if provider.is_authenticated() {
|
||||
if provider.is_authenticated(cx) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assistant.assist(&prompt_editor, None, None, initial_prompt, cx)
|
||||
})
|
||||
|
@ -736,11 +736,8 @@ impl PromptLibrary {
|
|||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
let provider = CompletionProvider::global(cx);
|
||||
let model = provider.model();
|
||||
provider.count_tokens(
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
||||
LanguageModelRequest {
|
||||
model,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: body.to_string(),
|
||||
|
@ -806,7 +803,7 @@ impl PromptLibrary {
|
|||
let prompt_metadata = self.store.metadata(prompt_id)?;
|
||||
let prompt_editor = &self.prompt_editors[&prompt_id];
|
||||
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
||||
let current_model = CompletionProvider::global(cx).model();
|
||||
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
|
||||
Some(
|
||||
|
@ -917,7 +914,11 @@ impl PromptLibrary {
|
|||
format!(
|
||||
"Model: {}",
|
||||
current_model
|
||||
.display_name()
|
||||
.as_ref()
|
||||
.map(|model| model
|
||||
.name()
|
||||
.0)
|
||||
.unwrap_or_default()
|
||||
),
|
||||
cx,
|
||||
)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
assistant_settings::AssistantSettings, humanize_token_count,
|
||||
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
|
||||
CompletionProvider,
|
||||
LanguageModelCompletionProvider,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -17,7 +17,9 @@ use gpui::{
|
|||
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use settings::{update_settings_file, Settings};
|
||||
use std::{
|
||||
cmp,
|
||||
|
@ -215,8 +217,6 @@ impl TerminalInlineAssistant {
|
|||
) -> Result<LanguageModelRequest> {
|
||||
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
||||
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
|
||||
let shell = std::env::var("SHELL").ok();
|
||||
let working_directory = assist
|
||||
.terminal
|
||||
|
@ -268,7 +268,6 @@ impl TerminalInlineAssistant {
|
|||
});
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
model,
|
||||
messages,
|
||||
stop: Vec::new(),
|
||||
temperature: 1.0,
|
||||
|
@ -559,24 +558,39 @@ impl Render for PromptEditor {
|
|||
PopoverMenu::new("model-switcher")
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
for available_model in
|
||||
LanguageModelRegistry::read_global(cx).available_models(cx)
|
||||
{
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
let model_name = available_model.name().0.clone();
|
||||
let provider =
|
||||
available_model.provider_name().0.clone();
|
||||
move |_| {
|
||||
Label::new(model.display_name())
|
||||
.into_any_element()
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.child(Label::new(model_name.clone()))
|
||||
.child(
|
||||
div().ml_4().child(
|
||||
Label::new(provider.clone())
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
},
|
||||
{
|
||||
let fs = fs.clone();
|
||||
let model = model.clone();
|
||||
let model = available_model.clone();
|
||||
move |cx| {
|
||||
let model = model.clone();
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings| settings.set_model(model),
|
||||
move |settings, _| {
|
||||
settings.set_model(model)
|
||||
},
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -595,9 +609,10 @@ impl Render for PromptEditor {
|
|||
Tooltip::with_meta(
|
||||
format!(
|
||||
"Using {}",
|
||||
CompletionProvider::global(cx)
|
||||
.model()
|
||||
.display_name()
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into())
|
||||
),
|
||||
None,
|
||||
"Change Model",
|
||||
|
@ -748,7 +763,9 @@ impl PromptEditor {
|
|||
})??;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
|
@ -878,7 +895,7 @@ impl PromptEditor {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||
let model = CompletionProvider::global(cx).model();
|
||||
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||
let token_count = self.token_count?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -1023,8 +1040,12 @@ impl Codegen {
|
|||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
|
||||
let telemetry = self.telemetry.clone();
|
||||
let model_telemetry_id = prompt.model.telemetry_id();
|
||||
let response = CompletionProvider::global(cx).stream_completion(prompt, cx);
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
let response =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
|
||||
|
||||
self.generation = cx.spawn(|this, mut cx| async move {
|
||||
let response = response.await;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue