assistant: Overhaul provider infrastructure (#14929)

<img width="624" alt="image"
src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815">

- [x] Correctly set custom model token count
- [x] How to count tokens for Gemini models?
- [x] Feature flag zed.dev provider
- [x] Figure out how to configure custom models
- [ ] Update docs

Release Notes:

- Added support for quickly switching between multiple language model
providers in the assistant panel
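
For anyone trying this out: a minimal sketch of the new settings shape. The `assistant` key and the `"version"` tag follow the existing tests in this diff; the field names come from `AssistantSettingsContentV2`, and the model id shown is the default from `AssistantDefaultModel::default()`.

```json
{
  "assistant": {
    "version": "2",
    "default_model": {
      "provider": "openai",
      "model": "gpt-4"
    }
  }
}
```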

---------

Co-authored-by: Antonio <antonio@zed.dev>
Bennet Bo Fenner 2024-07-23 19:48:41 +02:00 committed by GitHub
parent 17ef9a367f
commit d0f52e90e6
55 changed files with 2757 additions and 2023 deletions


@ -15,20 +15,20 @@ use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
use completion::CompletionProvider;
use completion::LanguageModelCompletionProvider;
pub use context::*;
pub use context_store::*;
use fs::Fs;
use gpui::{
actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
};
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::LanguageModelResponseMessage;
use language_model::{
LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
};
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
active_command, default_command, diagnostics_command, docs_command, fetch_command,
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
@ -165,6 +165,16 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
cx.set_global(Assistant::default());
AssistantSettings::register(cx);
// TODO: remove this when 0.148.0 is released.
if AssistantSettings::get_global(cx).using_outdated_settings_version {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let fs = fs.clone();
|content, cx| {
content.update_file(fs, cx);
}
});
}
cx.spawn(|mut cx| {
let client = client.clone();
async move {
@ -182,7 +192,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client);
prompt_library::init(cx);
init_completion_provider(Arc::clone(&client), cx);
init_completion_provider(cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
@ -207,20 +217,38 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
fn init_completion_provider(cx: &mut AppContext) {
completion::init(cx);
update_active_language_model_from_settings(cx);
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
})
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
.detach();
cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
update_active_language_model_from_settings(cx)
})
.detach();
}
fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());
let Some(provider) = LanguageModelRegistry::global(cx)
.read(cx)
.provider(&provider_name)
else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
completion_provider.set_active_model(model, cx);
});
}
}
fn register_slash_commands(cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true);


@ -18,7 +18,7 @@ use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use client::proto;
use collections::{BTreeSet, HashMap, HashSet};
use completion::CompletionProvider;
use completion::LanguageModelCompletionProvider;
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{
@ -364,13 +364,12 @@ impl AssistantPanel {
cx.subscribe(&pane, Self::handle_pane_event),
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
cx.observe_global::<CompletionProvider>({
let mut prev_settings_version = CompletionProvider::global(cx).settings_version();
move |this, cx| {
this.completion_provider_changed(prev_settings_version, cx);
prev_settings_version = CompletionProvider::global(cx).settings_version();
}
}),
cx.observe(
&LanguageModelCompletionProvider::global(cx),
|this, _, cx| {
this.completion_provider_changed(cx);
},
),
];
Self {
@ -483,37 +482,36 @@ impl AssistantPanel {
}
}
fn completion_provider_changed(
&mut self,
prev_settings_version: usize,
cx: &mut ViewContext<Self>,
) {
if self.is_authenticated(cx) {
self.authentication_prompt = None;
match self.active_context_editor(cx) {
Some(editor) => {
editor.update(cx, |active_context, cx| {
active_context
.context
.update(cx, |context, cx| context.completion_provider_changed(cx))
});
}
None => {
self.new_context(cx);
}
}
cx.notify();
} else if self.authentication_prompt.is_none()
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
{
self.authentication_prompt =
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.authentication_prompt(cx)
}));
cx.notify();
fn completion_provider_changed(&mut self, cx: &mut ViewContext<Self>) {
if let Some(editor) = self.active_context_editor(cx) {
editor.update(cx, |active_context, cx| {
active_context
.context
.update(cx, |context, cx| context.completion_provider_changed(cx))
})
}
if self.active_context_editor(cx).is_none() {
self.new_context(cx);
}
let authentication_prompt = Self::authentication_prompt(cx);
for context_editor in self.context_editors(cx) {
context_editor.update(cx, |editor, cx| {
editor.set_authentication_prompt(authentication_prompt.clone(), cx);
});
}
cx.notify();
}
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
if !provider.is_authenticated(cx) {
return Some(provider.authentication_prompt(cx));
}
}
None
}
pub fn inline_assist(
@ -774,7 +772,7 @@ impl AssistantPanel {
}
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
CompletionProvider::global(cx)
LanguageModelCompletionProvider::read_global(cx)
.reset_credentials(cx)
.detach_and_log_err(cx);
}
@ -783,6 +781,13 @@ impl AssistantPanel {
self.model_selector_menu_handle.toggle(cx);
}
fn context_editors(&self, cx: &AppContext) -> Vec<View<ContextEditor>> {
self.pane
.read(cx)
.items_of_type::<ContextEditor>()
.collect()
}
fn active_context_editor(&self, cx: &AppContext) -> Option<View<ContextEditor>> {
self.pane
.read(cx)
@ -904,11 +909,11 @@ impl AssistantPanel {
}
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
CompletionProvider::global(cx).is_authenticated()
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
}
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
cx.update_global::<CompletionProvider, _>(|provider, cx| provider.authenticate(cx))
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
}
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
@ -968,14 +973,18 @@ impl Panel for AssistantPanel {
}
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
let dock = match position {
DockPosition::Left => AssistantDockPosition::Left,
DockPosition::Bottom => AssistantDockPosition::Bottom,
DockPosition::Right => AssistantDockPosition::Right,
};
settings.set_dock(dock);
});
settings::update_settings_file::<AssistantSettings>(
self.fs.clone(),
cx,
move |settings, _| {
let dock = match position {
DockPosition::Left => AssistantDockPosition::Left,
DockPosition::Bottom => AssistantDockPosition::Bottom,
DockPosition::Right => AssistantDockPosition::Right,
};
settings.set_dock(dock);
},
);
}
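Side note for reviewers: the `update_settings_file` callback now takes a second argument, which is why the call sites in this diff gain a `, _`. A minimal sketch of the new call shape, assuming the second parameter is a context the closure may ignore:

```rust
// Hypothetical call site: the closure now receives (settings, cx) instead of just settings.
settings::update_settings_file::<AssistantSettings>(fs.clone(), cx, move |settings, _cx| {
    settings.set_dock(AssistantDockPosition::Right);
});
```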
fn size(&self, cx: &WindowContext) -> Pixels {
@ -1074,6 +1083,7 @@ struct ActiveEditStep {
pub struct ContextEditor {
context: Model<Context>,
authentication_prompt: Option<AnyView>,
fs: Arc<dyn Fs>,
workspace: WeakView<Workspace>,
project: Model<Project>,
@ -1131,6 +1141,7 @@ impl ContextEditor {
let sections = context.read(cx).slash_command_output_sections().to_vec();
let mut this = Self {
context,
authentication_prompt: None,
editor,
lsp_adapter_delegate,
blocks: Default::default(),
@ -1150,6 +1161,15 @@ impl ContextEditor {
this
}
fn set_authentication_prompt(
&mut self,
authentication_prompt: Option<AnyView>,
cx: &mut ViewContext<Self>,
) {
self.authentication_prompt = authentication_prompt;
cx.notify();
}
fn insert_default_prompt(&mut self, cx: &mut ViewContext<Self>) {
let command_name = DefaultSlashCommand.name();
self.editor.update(cx, |editor, cx| {
@ -1176,6 +1196,10 @@ impl ContextEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
if self.authentication_prompt.is_some() {
return;
}
if !self.apply_edit_step(cx) {
self.send_to_model(cx);
}
@ -2203,19 +2227,26 @@ impl Render for ContextEditor {
.size_full()
.v_flex()
.child(
div()
.flex_grow()
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone())
.child(
h_flex()
.w_full()
.absolute()
.bottom_0()
.p_4()
.justify_end()
.child(self.render_send_button(cx)),
),
if let Some(authentication_prompt) = self.authentication_prompt.as_ref() {
div()
.flex_grow()
.bg(cx.theme().colors().editor_background)
.child(authentication_prompt.clone().into_any())
} else {
div()
.flex_grow()
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone())
.child(
h_flex()
.w_full()
.absolute()
.bottom_0()
.p_4()
.justify_end()
.child(self.render_send_button(cx)),
)
},
)
}
}
@ -2543,7 +2574,7 @@ impl ContextEditorToolbarItem {
}
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = CompletionProvider::global(cx).model();
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let context = &self
.active_context_editor
.as_ref()?


@ -1,19 +1,14 @@
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use anthropic::Model as AnthropicModel;
use client::Client;
use completion::{
AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
};
use fs::Fs;
use gpui::{AppContext, Pixels};
use language_model::{CloudModel, LanguageModel};
use language_model::{settings::AllLanguageModelSettings, CloudModel, LanguageModel};
use ollama::Model as OllamaModel;
use open_ai::Model as OpenAiModel;
use parking_lot::RwLock;
use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use settings::{update_settings_file, Settings, SettingsSources};
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@ -24,43 +19,9 @@ pub enum AssistantDockPosition {
Bottom,
}
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
ZedDotDev {
model: CloudModel,
},
OpenAi {
model: OpenAiModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
available_models: Vec<OpenAiModel>,
},
Anthropic {
model: AnthropicModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
Ollama {
model: OllamaModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
}
impl Default for AssistantProvider {
fn default() -> Self {
Self::OpenAi {
model: OpenAiModel::default(),
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
pub enum AssistantProviderContentV1 {
#[serde(rename = "zed.dev")]
ZedDotDev { default_model: Option<CloudModel> },
#[serde(rename = "openai")]
@ -91,7 +52,8 @@ pub struct AssistantSettings {
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub provider: AssistantProvider,
pub default_model: AssistantDefaultModel,
pub using_outdated_settings_version: bool,
}
/// Assistant panel settings
@ -123,34 +85,142 @@ impl Default for AssistantSettingsContent {
}
impl AssistantSettingsContent {
fn upgrade(&self) -> AssistantSettingsContentV1 {
pub fn is_version_outdated(&self) -> bool {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
AssistantSettingsContent::Legacy(_) => true,
}
}
pub fn update_file(&mut self, fs: Arc<dyn Fs>, cx: &AppContext) {
if let AssistantSettingsContent::Versioned(settings) = self {
if let VersionedAssistantSettingsContent::V1(settings) = settings {
if let Some(provider) = settings.provider.clone() {
match provider {
AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.anthropic.is_none() {
content.anthropic =
Some(language_model::settings::AnthropicSettingsContent {
api_url,
low_speed_timeout_in_seconds,
..Default::default()
});
}
},
),
AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.ollama.is_none() {
content.ollama =
Some(language_model::settings::OllamaSettingsContent {
api_url,
low_speed_timeout_in_seconds,
});
}
},
),
AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.open_ai.is_none() {
content.open_ai =
Some(language_model::settings::OpenAiSettingsContent {
api_url,
low_speed_timeout_in_seconds,
available_models,
});
}
},
),
_ => {}
}
}
}
}
*self = AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
self.upgrade(),
));
}
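To make the migration concrete: `update_file` copies the V1 provider's connection details (`api_url`, `low_speed_timeout_in_seconds`, and for OpenAI `available_models`) into `AllLanguageModelSettings` (under whatever settings key that type registers; it isn't shown in this hunk), then rewrites the assistant block itself via `upgrade()`. A rough before/after sketch, with an illustrative model id and API URL:

Before (V1):

```json
{
  "assistant": {
    "version": "1",
    "provider": {
      "name": "openai",
      "default_model": "gpt-4o",
      "api_url": "http://localhost:1234/v1"
    }
  }
}
```

After (V2, with the `api_url` now living in the language model settings):

```json
{
  "assistant": {
    "version": "2",
    "default_model": {
      "provider": "openai",
      "model": "gpt-4o"
    }
  }
}
```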
fn upgrade(&self) -> AssistantSettingsContentV2 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: settings
.provider
.clone()
.and_then(|provider| match provider {
AssistantProviderContentV1::ZedDotDev { default_model } => {
default_model.map(|model| AssistantDefaultModel {
provider: "zed.dev".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::OpenAi { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "openai".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Anthropic { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "anthropic".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Ollama { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "ollama".to_string(),
model: model.id().to_string(),
})
}
}),
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
enabled: None,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
Some(AssistantProviderContent::OpenAi {
default_model: settings.default_open_ai_model.clone(),
api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProviderContent::OpenAi {
default_model: Some(open_ai_model),
api_url: None,
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
}
})
},
default_model: Some(AssistantDefaultModel {
provider: "openai".to_string(),
model: settings
.default_open_ai_model
.clone()
.unwrap_or_default()
.id()
.to_string(),
}),
},
}
}
@ -161,6 +231,9 @@ impl AssistantSettingsContent {
VersionedAssistantSettingsContent::V1(settings) => {
settings.dock = Some(dock);
}
VersionedAssistantSettingsContent::V2(settings) => {
settings.dock = Some(dock);
}
},
AssistantSettingsContent::Legacy(settings) => {
settings.dock = Some(dock);
@ -168,74 +241,78 @@ impl AssistantSettingsContent {
}
}
pub fn set_model(&mut self, new_model: LanguageModel) {
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
let model = language_model.id().0.to_string();
let provider = language_model.provider_name().0.to_string();
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::Cloud(new_model) = new_model {
*model = Some(new_model);
}
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
default_model: CloudModel::from_id(&model).ok(),
});
}
Some(AssistantProviderContent::OpenAi {
default_model: model,
..
}) => {
if let LanguageModel::OpenAi(new_model) = new_model {
*model = Some(new_model);
}
"anthropic" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
});
}
Some(AssistantProviderContent::Anthropic {
default_model: model,
..
}) => {
if let LanguageModel::Anthropic(new_model) = new_model {
*model = Some(new_model);
}
"ollama" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model)),
api_url,
low_speed_timeout_in_seconds,
});
}
Some(AssistantProviderContent::Ollama {
default_model: model,
..
}) => {
if let LanguageModel::Ollama(new_model) = new_model {
*model = Some(new_model);
}
"openai" => {
let (api_url, low_speed_timeout_in_seconds, available_models) =
match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
}) => (
api_url.clone(),
*low_speed_timeout_in_seconds,
available_models.clone(),
),
_ => (None, None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: open_ai::Model::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
available_models,
});
}
provider => match new_model {
LanguageModel::Cloud(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
}
LanguageModel::OpenAi(model) => {
*provider = Some(AssistantProviderContent::OpenAi {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
})
}
LanguageModel::Anthropic(model) => {
*provider = Some(AssistantProviderContent::Anthropic {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
LanguageModel::Ollama(model) => {
*provider = Some(AssistantProviderContent::Ollama {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
},
_ => {}
},
VersionedAssistantSettingsContent::V2(settings) => {
settings.default_model = Some(AssistantDefaultModel { provider, model });
}
},
AssistantSettingsContent::Legacy(settings) => {
if let LanguageModel::OpenAi(model) = new_model {
if let Ok(model) = open_ai::Model::from_id(&language_model.id().0) {
settings.default_open_ai_model = Some(model);
}
}
@ -248,21 +325,78 @@ impl AssistantSettingsContent {
pub enum VersionedAssistantSettingsContent {
#[serde(rename = "1")]
V1(AssistantSettingsContentV1),
#[serde(rename = "2")]
V2(AssistantSettingsContentV2),
}
impl Default for VersionedAssistantSettingsContent {
fn default() -> Self {
Self::V1(AssistantSettingsContentV1 {
Self::V2(AssistantSettingsContentV2 {
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
provider: None,
default_model: None,
})
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV2 {
/// Whether the Assistant is enabled.
///
/// Default: true
enabled: Option<bool>,
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
button: Option<bool>,
/// Where to dock the assistant.
///
/// Default: right
dock: Option<AssistantDockPosition>,
/// Default width in pixels when the assistant is docked to the left or right.
///
/// Default: 640
default_width: Option<f32>,
/// Default height in pixels when the assistant is docked to the bottom.
///
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new contexts.
default_model: Option<AssistantDefaultModel>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AssistantDefaultModel {
#[schemars(schema_with = "providers_schema")]
pub provider: String,
pub model: String,
}
fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"anthropic".into(),
"ollama".into(),
"openai".into(),
"zed.dev".into(),
]),
..Default::default()
}
.into()
}
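Assuming `schemars`' default serialization of `SchemaObject`, this helper emits roughly the following fragment for the `provider` field, which is what gives it enum completion in the settings editor:

```json
{ "enum": ["anthropic", "ollama", "openai", "zed.dev"] }
```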
impl Default for AssistantDefaultModel {
fn default() -> Self {
Self {
provider: "openai".to_string(),
model: "gpt-4".to_string(),
}
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
/// Whether the Assistant is enabled.
@ -289,7 +423,7 @@ pub struct AssistantSettingsContentV1 {
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
provider: Option<AssistantProviderContent>,
provider: Option<AssistantProviderContentV1>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@ -332,6 +466,10 @@ impl Settings for AssistantSettings {
let mut settings = AssistantSettings::default();
for value in sources.defaults_and_customizations() {
if value.is_version_outdated() {
settings.using_outdated_settings_version = true;
}
let value = value.upgrade();
merge(&mut settings.enabled, value.enabled);
merge(&mut settings.button, value.button);
@ -344,123 +482,10 @@ impl Settings for AssistantSettings {
&mut settings.default_height,
value.default_height.map(Into::into),
);
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
AssistantProvider::ZedDotDev { model },
AssistantProviderContent::ZedDotDev {
default_model: model_override,
},
) => {
merge(model, model_override);
}
(
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
},
AssistantProviderContent::OpenAi {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
available_models: available_models_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
merge(available_models, available_models_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Ollama {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(provider, provider_override) => {
*provider = match provider_override {
AssistantProviderContent::ZedDotDev {
default_model: model,
} => AssistantProvider::ZedDotDev {
model: model.unwrap_or_default(),
},
AssistantProviderContent::OpenAi {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => AssistantProvider::OpenAi {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
low_speed_timeout_in_seconds,
available_models: available_models.unwrap_or_default(),
},
AssistantProviderContent::Anthropic {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Anthropic {
model: model.unwrap_or_default(),
api_url: api_url
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Ollama {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Ollama {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
low_speed_timeout_in_seconds,
},
};
}
}
}
merge(
&mut settings.default_model,
value.default_model.map(Into::into),
);
}
Ok(settings)
@ -473,221 +498,103 @@ fn merge<T>(target: &mut T, value: Option<T>) {
}
}
pub fn update_completion_provider_settings(
provider: &mut CompletionProvider,
version: usize,
cx: &mut AppContext,
) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => provider
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// #[cfg(test)]
// mod tests {
// use gpui::{AppContext, UpdateGlobal};
// use settings::SettingsStore;
// The previously configured provider was changed to another one; recreate it from settings
if updated.is_none() {
provider.update_provider(|client| create_provider_from_settings(client, version, cx));
}
}
// use super::*;
pub(crate) fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
available_models.clone(),
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
}
// #[gpui::test]
// fn test_deserialize_assistant_settings(cx: &mut AppContext) {
// let store = settings::SettingsStore::test(cx);
// cx.set_global(store);
/// Choose which model to use for the OpenAI provider.
/// If the configured model is not available, use the first available model, or fall back to the configured model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
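This helper (removed along with the old provider plumbing) encoded a specific fallback order. A hypothetical test pinning it down, assuming `open_ai::Model` derives `PartialEq` and `Debug`, and using the `Four`/`FourOmni` variants that appear in the tests below:

```rust
#[test]
fn choose_openai_model_fallback_order() {
    use open_ai::Model;

    // Configured model is in the available list: keep it.
    assert_eq!(
        choose_openai_model(&Model::Four, &[Model::FourOmni, Model::Four]),
        Model::Four
    );

    // Configured model is missing: use the first available model instead.
    assert_eq!(
        choose_openai_model(&Model::Four, &[Model::FourOmni]),
        Model::FourOmni
    );

    // Nothing is available: fall back to the configured model.
    assert_eq!(choose_openai_model(&Model::Four, &[]), Model::Four);
}
```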
// // Settings default to gpt-4o.
// AssistantSettings::register(cx);
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::FourOmni,
// api_url: open_ai::OPEN_AI_API_URL.into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
#[cfg(test)]
mod tests {
use gpui::{AppContext, UpdateGlobal};
use settings::SettingsStore;
// // Ensure backward-compatibility.
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "openai_api_url": "test-url",
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::FourOmni,
// api_url: "test-url".into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "default_open_ai_model": "gpt-4-0613"
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::Four,
// api_url: open_ai::OPEN_AI_API_URL.into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
use super::*;
#[gpui::test]
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
let store = settings::SettingsStore::test(cx);
cx.set_global(store);
// Settings default to gpt-4o.
AssistantSettings::register(cx);
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::FourOmni,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
// Ensure backward-compatibility.
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"openai_api_url": "test-url",
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::FourOmni,
api_url: "test-url".into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"default_open_ai_model": "gpt-4-0613"
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::Four,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
// The new version supports setting a custom model when using zed.dev.
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": {
"custom": {
"name": "custom-provider"
}
}
}
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: CloudModel::Custom {
name: "custom-provider".into(),
max_tokens: None
}
}
);
}
}
// // The new version supports setting a custom model when using zed.dev.
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "version": "1",
// "provider": {
// "name": "zed.dev",
// "default_model": {
// "custom": {
// "name": "custom-provider"
// }
// }
// }
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::ZedDotDev {
// model: CloudModel::Custom {
// name: "custom-provider".into(),
// max_tokens: None
// }
// }
// );
// }
// }


@ -1,6 +1,6 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
MessageStatus,
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@ -1124,7 +1124,9 @@ impl Context {
.await;
let token_count = cx
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
@ -1308,7 +1310,9 @@ impl Context {
});
let raw_output = cx
.update(|cx| CompletionProvider::global(cx).complete(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
})?
.await?;
let operations = Self::parse_edit_operations(&raw_output);
@ -1612,13 +1616,14 @@ impl Context {
.then_some(message.id)
})?;
if !CompletionProvider::global(cx).is_authenticated() {
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
let request = self.to_completion_request(cx);
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@ -1698,11 +1703,14 @@ impl Context {
});
if let Some(telemetry) = this.telemetry.as_ref() {
let model = CompletionProvider::global(cx).model();
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
telemetry.report_assistant_event(
Some(this.id.0.clone()),
AssistantKind::Panel,
model.telemetry_id(),
model_telemetry_id,
response_latency,
error_message,
);
@ -1727,7 +1735,6 @@ impl Context {
.map(|message| message.to_request_message(self.buffer.read(cx)));
LanguageModelRequest {
model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
@ -1970,7 +1977,7 @@ impl Context {
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !CompletionProvider::global(cx).is_authenticated() {
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
return;
}
@ -1982,13 +1989,13 @@ impl Context {
content: "Summarize the context into a short title without punctuation.".into(),
}));
let request = LanguageModelRequest {
model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
};
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
@ -2504,7 +2511,6 @@ mod tests {
MessageId,
};
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
use completion::FakeCompletionProvider;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext, WeakView};
use indoc::indoc;
@ -2524,7 +2530,8 @@ mod tests {
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
FakeCompletionProvider::setup_test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2656,7 +2663,8 @@ mod tests {
fn test_message_splitting(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
FakeCompletionProvider::setup_test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2749,7 +2757,8 @@ mod tests {
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
FakeCompletionProvider::setup_test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@ -2834,7 +2843,8 @@ mod tests {
async fn test_slash_commands(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(FakeCompletionProvider::setup_test);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(Project::init_settings);
cx.update(assistant_panel::init);
let fs = FakeFs::new(cx.background_executor.clone());
@ -2959,7 +2969,11 @@ mod tests {
cx.update(prompt_library::init);
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
let fake_provider = cx.update(FakeCompletionProvider::setup_test);
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let fake_model = fake_provider.test_model();
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
@ -3025,8 +3039,8 @@ mod tests {
});
// Simulate the LLM completion
fake_provider.send_last_completion_chunk(llm_response.to_string());
fake_provider.finish_last_completion();
fake_model.send_last_completion_chunk(llm_response.to_string());
fake_model.finish_last_completion();
// Wait for the completion to be processed
cx.run_until_parked();
@ -3107,7 +3121,8 @@ mod tests {
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(FakeCompletionProvider::setup_test);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@ -3183,7 +3198,9 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(FakeCompletionProvider::setup_test);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(assistant_panel::init);
let slash_commands = cx.update(SlashCommandRegistry::default_global);
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);


@ -1,6 +1,6 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
AssistantPanel, AssistantPanelEvent, Hunk, LanguageModelCompletionProvider, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry;
@ -27,7 +27,9 @@ use gpui::{
WindowContext,
};
use language::{Buffer, Point, Selection, TransactionId};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use rope::Rope;
@ -844,7 +846,10 @@ impl InlineAssistant {
}
let codegen = assist.codegen.clone();
let telemetry_id = CompletionProvider::global(cx).model().telemetry_id();
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(stream::empty().boxed()) }.boxed_local()
@ -854,7 +859,10 @@ impl InlineAssistant {
async move {
let request = request.await?;
let chunks = cx
.update(|cx| CompletionProvider::global(cx).stream_completion(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.stream_completion(request, cx)
})?
.await?;
Ok(chunks.boxed())
}
@ -871,8 +879,8 @@ impl InlineAssistant {
cx: &mut WindowContext,
) -> Task<Result<LanguageModelRequest>> {
cx.spawn(|mut cx| async move {
let (user_prompt, context_request, project_name, buffer, range, model) = cx
.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let (user_prompt, context_request, project_name, buffer, range) =
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let assist = this.assists.get(&assist_id).context("invalid assist")?;
let decorations = assist.decorations.as_ref().context("invalid assist")?;
let editor = assist.editor.upgrade().context("invalid assist")?;
@ -906,15 +914,7 @@ impl InlineAssistant {
});
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
let range = assist.codegen.read(cx).range.clone();
let model = CompletionProvider::global(cx).model();
anyhow::Ok((
user_prompt,
context_request,
project_name,
buffer,
range,
model,
))
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
})??;
let language = buffer.language_at(range.start);
@ -973,7 +973,6 @@ impl InlineAssistant {
});
Ok(LanguageModelRequest {
model,
messages,
stop: vec!["|END|>".to_string()],
temperature,
@ -1432,24 +1431,39 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models() {
for available_model in
LanguageModelRegistry::read_global(cx).available_models(cx)
{
menu = menu.custom_entry(
{
let model = model.clone();
let model_name = available_model.name().0.clone();
let provider =
available_model.provider_name().0.clone();
move |_| {
Label::new(model.display_name())
.into_any_element()
h_flex()
.w_full()
.justify_between()
.child(Label::new(model_name.clone()))
.child(
div().ml_4().child(
Label::new(provider.clone())
.color(Color::Muted),
),
)
.into_any()
}
},
{
let fs = fs.clone();
let model = model.clone();
let model = available_model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
move |settings, _| {
settings.set_model(model)
},
);
}
},
@ -1468,9 +1482,10 @@ impl Render for PromptEditor {
Tooltip::with_meta(
format!(
"Using {}",
CompletionProvider::global(cx)
.model()
.display_name()
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
@ -1668,7 +1683,9 @@ impl PromptEditor {
.await?;
let token_count = cx
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
@ -1796,7 +1813,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = CompletionProvider::global(cx).model();
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@ -2601,7 +2618,6 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
#[cfg(test)]
mod tests {
use super::*;
use completion::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
@ -2622,7 +2638,8 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(language_settings::init);
let text = indoc! {"
@ -2749,7 +2766,8 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.update(LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);


@ -1,7 +1,10 @@
use std::sync::Arc;
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
use crate::{
assistant_settings::AssistantSettings, LanguageModelCompletionProvider, ToggleModelSelector,
};
use fs::Fs;
use language_model::LanguageModelRegistry;
use settings::update_settings_file;
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip};
@ -23,25 +26,64 @@ impl RenderOnce for ModelSelector {
.with_handle(self.handle)
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();
move |_| Label::new(model.display_name()).into_any_element()
},
{
let fs = self.fs.clone();
let model = model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
);
}
},
);
for (provider, available_models) in LanguageModelRegistry::global(cx)
.read(cx)
.available_models_grouped_by_provider(cx)
{
menu = menu.header(provider.0.clone());
if available_models.is_empty() {
menu = menu.custom_entry(
{
move |_| {
h_flex()
.w_full()
.gap_1()
.child(Icon::new(IconName::Settings))
.child(Label::new("Configure"))
.into_any()
}
},
{
let provider = provider.clone();
move |cx| {
LanguageModelCompletionProvider::global(cx).update(
cx,
|completion_provider, cx| {
completion_provider
.set_active_provider(provider.clone(), cx)
},
);
}
},
);
}
for available_model in available_models {
menu = menu.custom_entry(
{
let model_name = available_model.name().0.clone();
move |_| {
h_flex()
.w_full()
.child(Label::new(model_name.clone()))
.into_any()
}
},
{
let fs = self.fs.clone();
let model = available_model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _| settings.set_model(model),
);
}
},
);
}
}
menu
})
@ -61,7 +103,10 @@ impl RenderOnce for ModelSelector {
.whitespace_nowrap()
.child(
Label::new(
CompletionProvider::global(cx).model().display_name(),
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
)
.size(LabelSize::Small)
.color(Color::Muted),


@ -1,6 +1,6 @@
use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
InlineAssist, InlineAssistant,
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
LanguageModelCompletionProvider,
};
use anyhow::{anyhow, Result};
use assets::Assets;
@ -636,9 +636,9 @@ impl PromptLibrary {
};
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
let provider = CompletionProvider::global(cx);
let provider = LanguageModelCompletionProvider::read_global(cx);
let initial_prompt = action.prompt.clone();
if provider.is_authenticated() {
if provider.is_authenticated(cx) {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, None, None, initial_prompt, cx)
})
@ -736,11 +736,8 @@ impl PromptLibrary {
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
let token_count = cx
.update(|cx| {
let provider = CompletionProvider::global(cx);
let model = provider.model();
provider.count_tokens(
LanguageModelCompletionProvider::read_global(cx).count_tokens(
LanguageModelRequest {
model,
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: body.to_string(),
@ -806,7 +803,7 @@ impl PromptLibrary {
let prompt_metadata = self.store.metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let current_model = CompletionProvider::global(cx).model();
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
let settings = ThemeSettings::get_global(cx);
Some(
@ -917,7 +914,11 @@ impl PromptLibrary {
format!(
"Model: {}",
current_model
.display_name()
.as_ref()
.map(|model| model
.name()
.0)
.unwrap_or_default()
),
cx,
)


@ -1,7 +1,7 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count,
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
CompletionProvider,
LanguageModelCompletionProvider,
};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
@ -17,7 +17,9 @@ use gpui::{
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
};
use language::Buffer;
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use settings::{update_settings_file, Settings};
use std::{
cmp,
@ -215,8 +217,6 @@ impl TerminalInlineAssistant {
) -> Result<LanguageModelRequest> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let model = CompletionProvider::global(cx).model();
let shell = std::env::var("SHELL").ok();
let working_directory = assist
.terminal
@ -268,7 +268,6 @@ impl TerminalInlineAssistant {
});
Ok(LanguageModelRequest {
model,
messages,
stop: Vec::new(),
temperature: 1.0,
@ -559,24 +558,39 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models() {
for available_model in
LanguageModelRegistry::read_global(cx).available_models(cx)
{
menu = menu.custom_entry(
{
let model = model.clone();
let model_name = available_model.name().0.clone();
let provider =
available_model.provider_name().0.clone();
move |_| {
Label::new(model.display_name())
.into_any_element()
h_flex()
.w_full()
.justify_between()
.child(Label::new(model_name.clone()))
.child(
div().ml_4().child(
Label::new(provider.clone())
.color(Color::Muted),
),
)
.into_any()
}
},
{
let fs = fs.clone();
let model = model.clone();
let model = available_model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
move |settings, _| {
settings.set_model(model)
},
);
}
},
@ -595,9 +609,10 @@ impl Render for PromptEditor {
Tooltip::with_meta(
format!(
"Using {}",
CompletionProvider::global(cx)
.model()
.display_name()
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into())
),
None,
"Change Model",
@ -748,7 +763,9 @@ impl PromptEditor {
})??;
let token_count = cx
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
@ -878,7 +895,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = CompletionProvider::global(cx).model();
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@ -1023,8 +1040,12 @@ impl Codegen {
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
let telemetry = self.telemetry.clone();
let model_telemetry_id = prompt.model.telemetry_id();
let response = CompletionProvider::global(cx).stream_completion(prompt, cx);
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let response =
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
self.generation = cx.spawn(|this, mut cx| async move {
let response = response.await;