assistant: Overhaul provider infrastructure (#14929)
<img width="624" alt="image" src="https://github.com/user-attachments/assets/f492b0bd-14c3-49e2-b2ff-dc78e52b0815"> - [x] Correctly set custom model token count - [x] How to count tokens for Gemini models? - [x] Feature flag zed.dev provider - [x] Figure out how to configure custom models - [ ] Update docs Release Notes: - Added support for quickly switching between multiple language model providers in the assistant panel --------- Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
parent
17ef9a367f
commit
d0f52e90e6
55 changed files with 2757 additions and 2023 deletions
30
Cargo.lock
generated
30
Cargo.lock
generated
|
@ -2509,6 +2509,7 @@ dependencies = [
|
||||||
"http 0.1.0",
|
"http 0.1.0",
|
||||||
"indoc",
|
"indoc",
|
||||||
"language",
|
"language",
|
||||||
|
"language_model",
|
||||||
"live_kit_client",
|
"live_kit_client",
|
||||||
"live_kit_server",
|
"live_kit_server",
|
||||||
"log",
|
"log",
|
||||||
|
@ -2678,36 +2679,22 @@ dependencies = [
|
||||||
name = "completion"
|
name = "completion"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anthropic",
|
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"client",
|
|
||||||
"collections",
|
|
||||||
"ctor",
|
"ctor",
|
||||||
"editor",
|
"editor",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"futures 0.3.28",
|
"futures 0.3.28",
|
||||||
"gpui",
|
"gpui",
|
||||||
"http 0.1.0",
|
|
||||||
"language",
|
"language",
|
||||||
"language_model",
|
"language_model",
|
||||||
"log",
|
|
||||||
"menu",
|
|
||||||
"ollama",
|
|
||||||
"open_ai",
|
|
||||||
"parking_lot",
|
|
||||||
"project",
|
"project",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
|
||||||
"settings",
|
"settings",
|
||||||
"smol",
|
"smol",
|
||||||
"strum",
|
|
||||||
"text",
|
"text",
|
||||||
"theme",
|
|
||||||
"tiktoken-rs",
|
|
||||||
"ui",
|
"ui",
|
||||||
"unindent",
|
"unindent",
|
||||||
"util",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -6040,11 +6027,19 @@ name = "language_model"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anthropic",
|
"anthropic",
|
||||||
|
"anyhow",
|
||||||
|
"client",
|
||||||
|
"collections",
|
||||||
"ctor",
|
"ctor",
|
||||||
"editor",
|
"editor",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"feature_flags",
|
||||||
|
"futures 0.3.28",
|
||||||
|
"gpui",
|
||||||
|
"http 0.1.0",
|
||||||
"language",
|
"language",
|
||||||
"log",
|
"log",
|
||||||
|
"menu",
|
||||||
"ollama",
|
"ollama",
|
||||||
"open_ai",
|
"open_ai",
|
||||||
"project",
|
"project",
|
||||||
|
@ -6052,9 +6047,15 @@ dependencies = [
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"settings",
|
||||||
"strum",
|
"strum",
|
||||||
"text",
|
"text",
|
||||||
|
"theme",
|
||||||
|
"tiktoken-rs",
|
||||||
|
"ui",
|
||||||
"unindent",
|
"unindent",
|
||||||
|
"util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -13802,6 +13803,7 @@ dependencies = [
|
||||||
"isahc",
|
"isahc",
|
||||||
"journal",
|
"journal",
|
||||||
"language",
|
"language",
|
||||||
|
"language_model",
|
||||||
"language_selector",
|
"language_selector",
|
||||||
"language_tools",
|
"language_tools",
|
||||||
"languages",
|
"languages",
|
||||||
|
|
|
@ -375,7 +375,7 @@
|
||||||
},
|
},
|
||||||
"assistant": {
|
"assistant": {
|
||||||
// Version of this setting.
|
// Version of this setting.
|
||||||
"version": "1",
|
"version": "2",
|
||||||
// Whether the assistant is enabled.
|
// Whether the assistant is enabled.
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
// Whether to show the assistant panel button in the status bar.
|
// Whether to show the assistant panel button in the status bar.
|
||||||
|
@ -386,18 +386,12 @@
|
||||||
"default_width": 640,
|
"default_width": 640,
|
||||||
// Default height when the assistant is docked to the bottom.
|
// Default height when the assistant is docked to the bottom.
|
||||||
"default_height": 320,
|
"default_height": 320,
|
||||||
// AI provider.
|
// The default model to use when creating new contexts.
|
||||||
"provider": {
|
"default_model": {
|
||||||
"name": "openai",
|
// The provider to use.
|
||||||
// The default model to use when creating new contexts. This
|
"provider": "openai",
|
||||||
// setting can take three values:
|
// The model to use.
|
||||||
//
|
"model": "gpt-4o"
|
||||||
// 1. "gpt-3.5-turbo"
|
|
||||||
// 2. "gpt-4"
|
|
||||||
// 3. "gpt-4-turbo-preview"
|
|
||||||
// 4. "gpt-4o"
|
|
||||||
// 5. "gpt-4o-mini"
|
|
||||||
"default_model": "gpt-4o"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// Whether the screen sharing icon is shown in the os status bar.
|
// Whether the screen sharing icon is shown in the os status bar.
|
||||||
|
@ -858,6 +852,8 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
// Different settings for specific language models.
|
||||||
|
"language_models": {},
|
||||||
// Zed's Prettier integration settings.
|
// Zed's Prettier integration settings.
|
||||||
// Allows to enable/disable formatting with Prettier
|
// Allows to enable/disable formatting with Prettier
|
||||||
// and configure default Prettier, used when no project-level Prettier installation is found.
|
// and configure default Prettier, used when no project-level Prettier installation is found.
|
||||||
|
|
|
@ -21,11 +21,7 @@ pub enum Model {
|
||||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||||
Claude3Haiku,
|
Claude3Haiku,
|
||||||
#[serde(rename = "custom")]
|
#[serde(rename = "custom")]
|
||||||
Custom {
|
Custom { name: String, max_tokens: usize },
|
||||||
name: String,
|
|
||||||
#[serde(default)]
|
|
||||||
max_tokens: Option<usize>,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
|
@ -39,10 +35,7 @@ impl Model {
|
||||||
} else if id.starts_with("claude-3-haiku") {
|
} else if id.starts_with("claude-3-haiku") {
|
||||||
Ok(Self::Claude3Haiku)
|
Ok(Self::Claude3Haiku)
|
||||||
} else {
|
} else {
|
||||||
Ok(Self::Custom {
|
Err(anyhow!("invalid model id"))
|
||||||
name: id.to_string(),
|
|
||||||
max_tokens: None,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +45,7 @@ impl Model {
|
||||||
Model::Claude3Opus => "claude-3-opus-20240229",
|
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||||
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||||
Model::Claude3Haiku => "claude-3-opus-20240307",
|
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||||
Model::Custom { name, .. } => name,
|
Self::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +65,7 @@ impl Model {
|
||||||
| Self::Claude3Opus
|
| Self::Claude3Opus
|
||||||
| Self::Claude3Sonnet
|
| Self::Claude3Sonnet
|
||||||
| Self::Claude3Haiku => 200_000,
|
| Self::Claude3Haiku => 200_000,
|
||||||
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
|
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,20 +15,20 @@ use assistant_settings::AssistantSettings;
|
||||||
use assistant_slash_command::SlashCommandRegistry;
|
use assistant_slash_command::SlashCommandRegistry;
|
||||||
use client::{proto, Client};
|
use client::{proto, Client};
|
||||||
use command_palette_hooks::CommandPaletteFilter;
|
use command_palette_hooks::CommandPaletteFilter;
|
||||||
use completion::CompletionProvider;
|
use completion::LanguageModelCompletionProvider;
|
||||||
pub use context::*;
|
pub use context::*;
|
||||||
pub use context_store::*;
|
pub use context_store::*;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use gpui::{
|
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||||
actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
|
|
||||||
};
|
|
||||||
use indexed_docs::IndexedDocsRegistry;
|
use indexed_docs::IndexedDocsRegistry;
|
||||||
pub(crate) use inline_assistant::*;
|
pub(crate) use inline_assistant::*;
|
||||||
use language_model::LanguageModelResponseMessage;
|
use language_model::{
|
||||||
|
LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
|
||||||
|
};
|
||||||
pub(crate) use model_selector::*;
|
pub(crate) use model_selector::*;
|
||||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{update_settings_file, Settings, SettingsStore};
|
||||||
use slash_command::{
|
use slash_command::{
|
||||||
active_command, default_command, diagnostics_command, docs_command, fetch_command,
|
active_command, default_command, diagnostics_command, docs_command, fetch_command,
|
||||||
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
|
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
|
||||||
|
@ -165,6 +165,16 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
||||||
cx.set_global(Assistant::default());
|
cx.set_global(Assistant::default());
|
||||||
AssistantSettings::register(cx);
|
AssistantSettings::register(cx);
|
||||||
|
|
||||||
|
// TODO: remove this when 0.148.0 is released.
|
||||||
|
if AssistantSettings::get_global(cx).using_outdated_settings_version {
|
||||||
|
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
|
||||||
|
let fs = fs.clone();
|
||||||
|
|content, cx| {
|
||||||
|
content.update_file(fs, cx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
cx.spawn(|mut cx| {
|
cx.spawn(|mut cx| {
|
||||||
let client = client.clone();
|
let client = client.clone();
|
||||||
async move {
|
async move {
|
||||||
|
@ -182,7 +192,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
||||||
|
|
||||||
context_store::init(&client);
|
context_store::init(&client);
|
||||||
prompt_library::init(cx);
|
prompt_library::init(cx);
|
||||||
init_completion_provider(Arc::clone(&client), cx);
|
init_completion_provider(cx);
|
||||||
assistant_slash_command::init(cx);
|
assistant_slash_command::init(cx);
|
||||||
register_slash_commands(cx);
|
register_slash_commands(cx);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
|
@ -207,20 +217,38 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
|
fn init_completion_provider(cx: &mut AppContext) {
|
||||||
let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
|
completion::init(cx);
|
||||||
cx.set_global(CompletionProvider::new(provider, Some(client)));
|
update_active_language_model_from_settings(cx);
|
||||||
|
|
||||||
let mut settings_version = 0;
|
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
|
||||||
cx.observe_global::<SettingsStore>(move |cx| {
|
.detach();
|
||||||
settings_version += 1;
|
cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
update_active_language_model_from_settings(cx)
|
||||||
assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn update_active_language_model_from_settings(cx: &mut AppContext) {
|
||||||
|
let settings = AssistantSettings::get_global(cx);
|
||||||
|
let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
|
||||||
|
let model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||||
|
|
||||||
|
let Some(provider) = LanguageModelRegistry::global(cx)
|
||||||
|
.read(cx)
|
||||||
|
.provider(&provider_name)
|
||||||
|
else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let models = provider.provided_models(cx);
|
||||||
|
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
|
||||||
|
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
|
||||||
|
completion_provider.set_active_model(model, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn register_slash_commands(cx: &mut AppContext) {
|
fn register_slash_commands(cx: &mut AppContext) {
|
||||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||||
slash_command_registry.register_command(file_command::FileSlashCommand, true);
|
slash_command_registry.register_command(file_command::FileSlashCommand, true);
|
||||||
|
|
|
@ -18,7 +18,7 @@ use anyhow::{anyhow, Result};
|
||||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||||
use client::proto;
|
use client::proto;
|
||||||
use collections::{BTreeSet, HashMap, HashSet};
|
use collections::{BTreeSet, HashMap, HashSet};
|
||||||
use completion::CompletionProvider;
|
use completion::LanguageModelCompletionProvider;
|
||||||
use editor::{
|
use editor::{
|
||||||
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
|
||||||
display_map::{
|
display_map::{
|
||||||
|
@ -364,13 +364,12 @@ impl AssistantPanel {
|
||||||
cx.subscribe(&pane, Self::handle_pane_event),
|
cx.subscribe(&pane, Self::handle_pane_event),
|
||||||
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
|
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
|
||||||
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
|
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
|
||||||
cx.observe_global::<CompletionProvider>({
|
cx.observe(
|
||||||
let mut prev_settings_version = CompletionProvider::global(cx).settings_version();
|
&LanguageModelCompletionProvider::global(cx),
|
||||||
move |this, cx| {
|
|this, _, cx| {
|
||||||
this.completion_provider_changed(prev_settings_version, cx);
|
this.completion_provider_changed(cx);
|
||||||
prev_settings_version = CompletionProvider::global(cx).settings_version();
|
},
|
||||||
}
|
),
|
||||||
}),
|
|
||||||
];
|
];
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
@ -483,37 +482,36 @@ impl AssistantPanel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn completion_provider_changed(
|
fn completion_provider_changed(&mut self, cx: &mut ViewContext<Self>) {
|
||||||
&mut self,
|
if let Some(editor) = self.active_context_editor(cx) {
|
||||||
prev_settings_version: usize,
|
|
||||||
cx: &mut ViewContext<Self>,
|
|
||||||
) {
|
|
||||||
if self.is_authenticated(cx) {
|
|
||||||
self.authentication_prompt = None;
|
|
||||||
|
|
||||||
match self.active_context_editor(cx) {
|
|
||||||
Some(editor) => {
|
|
||||||
editor.update(cx, |active_context, cx| {
|
editor.update(cx, |active_context, cx| {
|
||||||
active_context
|
active_context
|
||||||
.context
|
.context
|
||||||
.update(cx, |context, cx| context.completion_provider_changed(cx))
|
.update(cx, |context, cx| context.completion_provider_changed(cx))
|
||||||
});
|
})
|
||||||
}
|
}
|
||||||
None => {
|
|
||||||
|
if self.active_context_editor(cx).is_none() {
|
||||||
self.new_context(cx);
|
self.new_context(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let authentication_prompt = Self::authentication_prompt(cx);
|
||||||
|
for context_editor in self.context_editors(cx) {
|
||||||
|
context_editor.update(cx, |editor, cx| {
|
||||||
|
editor.set_authentication_prompt(authentication_prompt.clone(), cx);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
cx.notify();
|
cx.notify();
|
||||||
} else if self.authentication_prompt.is_none()
|
|
||||||
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
|
|
||||||
{
|
|
||||||
self.authentication_prompt =
|
|
||||||
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
|
||||||
provider.authentication_prompt(cx)
|
|
||||||
}));
|
|
||||||
cx.notify();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
|
||||||
|
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
|
||||||
|
if !provider.is_authenticated(cx) {
|
||||||
|
return Some(provider.authentication_prompt(cx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inline_assist(
|
pub fn inline_assist(
|
||||||
|
@ -774,7 +772,7 @@ impl AssistantPanel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||||
CompletionProvider::global(cx)
|
LanguageModelCompletionProvider::read_global(cx)
|
||||||
.reset_credentials(cx)
|
.reset_credentials(cx)
|
||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
}
|
}
|
||||||
|
@ -783,6 +781,13 @@ impl AssistantPanel {
|
||||||
self.model_selector_menu_handle.toggle(cx);
|
self.model_selector_menu_handle.toggle(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn context_editors(&self, cx: &AppContext) -> Vec<View<ContextEditor>> {
|
||||||
|
self.pane
|
||||||
|
.read(cx)
|
||||||
|
.items_of_type::<ContextEditor>()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn active_context_editor(&self, cx: &AppContext) -> Option<View<ContextEditor>> {
|
fn active_context_editor(&self, cx: &AppContext) -> Option<View<ContextEditor>> {
|
||||||
self.pane
|
self.pane
|
||||||
.read(cx)
|
.read(cx)
|
||||||
|
@ -904,11 +909,11 @@ impl AssistantPanel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||||
CompletionProvider::global(cx).is_authenticated()
|
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, cx| provider.authenticate(cx))
|
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
@ -968,14 +973,18 @@ impl Panel for AssistantPanel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
||||||
settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
|
settings::update_settings_file::<AssistantSettings>(
|
||||||
|
self.fs.clone(),
|
||||||
|
cx,
|
||||||
|
move |settings, _| {
|
||||||
let dock = match position {
|
let dock = match position {
|
||||||
DockPosition::Left => AssistantDockPosition::Left,
|
DockPosition::Left => AssistantDockPosition::Left,
|
||||||
DockPosition::Bottom => AssistantDockPosition::Bottom,
|
DockPosition::Bottom => AssistantDockPosition::Bottom,
|
||||||
DockPosition::Right => AssistantDockPosition::Right,
|
DockPosition::Right => AssistantDockPosition::Right,
|
||||||
};
|
};
|
||||||
settings.set_dock(dock);
|
settings.set_dock(dock);
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size(&self, cx: &WindowContext) -> Pixels {
|
fn size(&self, cx: &WindowContext) -> Pixels {
|
||||||
|
@ -1074,6 +1083,7 @@ struct ActiveEditStep {
|
||||||
|
|
||||||
pub struct ContextEditor {
|
pub struct ContextEditor {
|
||||||
context: Model<Context>,
|
context: Model<Context>,
|
||||||
|
authentication_prompt: Option<AnyView>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
workspace: WeakView<Workspace>,
|
workspace: WeakView<Workspace>,
|
||||||
project: Model<Project>,
|
project: Model<Project>,
|
||||||
|
@ -1131,6 +1141,7 @@ impl ContextEditor {
|
||||||
let sections = context.read(cx).slash_command_output_sections().to_vec();
|
let sections = context.read(cx).slash_command_output_sections().to_vec();
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
context,
|
context,
|
||||||
|
authentication_prompt: None,
|
||||||
editor,
|
editor,
|
||||||
lsp_adapter_delegate,
|
lsp_adapter_delegate,
|
||||||
blocks: Default::default(),
|
blocks: Default::default(),
|
||||||
|
@ -1150,6 +1161,15 @@ impl ContextEditor {
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_authentication_prompt(
|
||||||
|
&mut self,
|
||||||
|
authentication_prompt: Option<AnyView>,
|
||||||
|
cx: &mut ViewContext<Self>,
|
||||||
|
) {
|
||||||
|
self.authentication_prompt = authentication_prompt;
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
fn insert_default_prompt(&mut self, cx: &mut ViewContext<Self>) {
|
fn insert_default_prompt(&mut self, cx: &mut ViewContext<Self>) {
|
||||||
let command_name = DefaultSlashCommand.name();
|
let command_name = DefaultSlashCommand.name();
|
||||||
self.editor.update(cx, |editor, cx| {
|
self.editor.update(cx, |editor, cx| {
|
||||||
|
@ -1176,6 +1196,10 @@ impl ContextEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
||||||
|
if self.authentication_prompt.is_some() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if !self.apply_edit_step(cx) {
|
if !self.apply_edit_step(cx) {
|
||||||
self.send_to_model(cx);
|
self.send_to_model(cx);
|
||||||
}
|
}
|
||||||
|
@ -2203,6 +2227,12 @@ impl Render for ContextEditor {
|
||||||
.size_full()
|
.size_full()
|
||||||
.v_flex()
|
.v_flex()
|
||||||
.child(
|
.child(
|
||||||
|
if let Some(authentication_prompt) = self.authentication_prompt.as_ref() {
|
||||||
|
div()
|
||||||
|
.flex_grow()
|
||||||
|
.bg(cx.theme().colors().editor_background)
|
||||||
|
.child(authentication_prompt.clone().into_any())
|
||||||
|
} else {
|
||||||
div()
|
div()
|
||||||
.flex_grow()
|
.flex_grow()
|
||||||
.bg(cx.theme().colors().editor_background)
|
.bg(cx.theme().colors().editor_background)
|
||||||
|
@ -2215,7 +2245,8 @@ impl Render for ContextEditor {
|
||||||
.p_4()
|
.p_4()
|
||||||
.justify_end()
|
.justify_end()
|
||||||
.child(self.render_send_button(cx)),
|
.child(self.render_send_button(cx)),
|
||||||
),
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2543,7 +2574,7 @@ impl ContextEditorToolbarItem {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||||
let model = CompletionProvider::global(cx).model();
|
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||||
let context = &self
|
let context = &self
|
||||||
.active_context_editor
|
.active_context_editor
|
||||||
.as_ref()?
|
.as_ref()?
|
||||||
|
|
|
@ -1,19 +1,14 @@
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anthropic::Model as AnthropicModel;
|
use anthropic::Model as AnthropicModel;
|
||||||
use client::Client;
|
use fs::Fs;
|
||||||
use completion::{
|
|
||||||
AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
|
|
||||||
LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
|
|
||||||
};
|
|
||||||
use gpui::{AppContext, Pixels};
|
use gpui::{AppContext, Pixels};
|
||||||
use language_model::{CloudModel, LanguageModel};
|
use language_model::{settings::AllLanguageModelSettings, CloudModel, LanguageModel};
|
||||||
use ollama::Model as OllamaModel;
|
use ollama::Model as OllamaModel;
|
||||||
use open_ai::Model as OpenAiModel;
|
use open_ai::Model as OpenAiModel;
|
||||||
use parking_lot::RwLock;
|
|
||||||
use schemars::{schema::Schema, JsonSchema};
|
use schemars::{schema::Schema, JsonSchema};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsSources};
|
use settings::{update_settings_file, Settings, SettingsSources};
|
||||||
|
|
||||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
|
@ -24,43 +19,9 @@ pub enum AssistantDockPosition {
|
||||||
Bottom,
|
Bottom,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
|
||||||
pub enum AssistantProvider {
|
|
||||||
ZedDotDev {
|
|
||||||
model: CloudModel,
|
|
||||||
},
|
|
||||||
OpenAi {
|
|
||||||
model: OpenAiModel,
|
|
||||||
api_url: String,
|
|
||||||
low_speed_timeout_in_seconds: Option<u64>,
|
|
||||||
available_models: Vec<OpenAiModel>,
|
|
||||||
},
|
|
||||||
Anthropic {
|
|
||||||
model: AnthropicModel,
|
|
||||||
api_url: String,
|
|
||||||
low_speed_timeout_in_seconds: Option<u64>,
|
|
||||||
},
|
|
||||||
Ollama {
|
|
||||||
model: OllamaModel,
|
|
||||||
api_url: String,
|
|
||||||
low_speed_timeout_in_seconds: Option<u64>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for AssistantProvider {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::OpenAi {
|
|
||||||
model: OpenAiModel::default(),
|
|
||||||
api_url: open_ai::OPEN_AI_API_URL.into(),
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
available_models: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||||
#[serde(tag = "name", rename_all = "snake_case")]
|
#[serde(tag = "name", rename_all = "snake_case")]
|
||||||
pub enum AssistantProviderContent {
|
pub enum AssistantProviderContentV1 {
|
||||||
#[serde(rename = "zed.dev")]
|
#[serde(rename = "zed.dev")]
|
||||||
ZedDotDev { default_model: Option<CloudModel> },
|
ZedDotDev { default_model: Option<CloudModel> },
|
||||||
#[serde(rename = "openai")]
|
#[serde(rename = "openai")]
|
||||||
|
@ -91,7 +52,8 @@ pub struct AssistantSettings {
|
||||||
pub dock: AssistantDockPosition,
|
pub dock: AssistantDockPosition,
|
||||||
pub default_width: Pixels,
|
pub default_width: Pixels,
|
||||||
pub default_height: Pixels,
|
pub default_height: Pixels,
|
||||||
pub provider: AssistantProvider,
|
pub default_model: AssistantDefaultModel,
|
||||||
|
pub using_outdated_settings_version: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Assistant panel settings
|
/// Assistant panel settings
|
||||||
|
@ -123,34 +85,142 @@ impl Default for AssistantSettingsContent {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AssistantSettingsContent {
|
impl AssistantSettingsContent {
|
||||||
fn upgrade(&self) -> AssistantSettingsContentV1 {
|
pub fn is_version_outdated(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||||
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
|
VersionedAssistantSettingsContent::V1(_) => true,
|
||||||
|
VersionedAssistantSettingsContent::V2(_) => false,
|
||||||
},
|
},
|
||||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
|
AssistantSettingsContent::Legacy(_) => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_file(&mut self, fs: Arc<dyn Fs>, cx: &AppContext) {
|
||||||
|
if let AssistantSettingsContent::Versioned(settings) = self {
|
||||||
|
if let VersionedAssistantSettingsContent::V1(settings) = settings {
|
||||||
|
if let Some(provider) = settings.provider.clone() {
|
||||||
|
match provider {
|
||||||
|
AssistantProviderContentV1::Anthropic {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
..
|
||||||
|
} => update_settings_file::<AllLanguageModelSettings>(
|
||||||
|
fs,
|
||||||
|
cx,
|
||||||
|
move |content, _| {
|
||||||
|
if content.anthropic.is_none() {
|
||||||
|
content.anthropic =
|
||||||
|
Some(language_model::settings::AnthropicSettingsContent {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
AssistantProviderContentV1::Ollama {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
..
|
||||||
|
} => update_settings_file::<AllLanguageModelSettings>(
|
||||||
|
fs,
|
||||||
|
cx,
|
||||||
|
move |content, _| {
|
||||||
|
if content.ollama.is_none() {
|
||||||
|
content.ollama =
|
||||||
|
Some(language_model::settings::OllamaSettingsContent {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
AssistantProviderContentV1::OpenAi {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
available_models,
|
||||||
|
..
|
||||||
|
} => update_settings_file::<AllLanguageModelSettings>(
|
||||||
|
fs,
|
||||||
|
cx,
|
||||||
|
move |content, _| {
|
||||||
|
if content.open_ai.is_none() {
|
||||||
|
content.open_ai =
|
||||||
|
Some(language_model::settings::OpenAiSettingsContent {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
available_models,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*self = AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
|
||||||
|
self.upgrade(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upgrade(&self) -> AssistantSettingsContentV2 {
|
||||||
|
match self {
|
||||||
|
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||||
|
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
|
||||||
|
enabled: settings.enabled,
|
||||||
|
button: settings.button,
|
||||||
|
dock: settings.dock,
|
||||||
|
default_width: settings.default_width,
|
||||||
|
default_height: settings.default_width,
|
||||||
|
default_model: settings
|
||||||
|
.provider
|
||||||
|
.clone()
|
||||||
|
.and_then(|provider| match provider {
|
||||||
|
AssistantProviderContentV1::ZedDotDev { default_model } => {
|
||||||
|
default_model.map(|model| AssistantDefaultModel {
|
||||||
|
provider: "zed.dev".to_string(),
|
||||||
|
model: model.id().to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
AssistantProviderContentV1::OpenAi { default_model, .. } => {
|
||||||
|
default_model.map(|model| AssistantDefaultModel {
|
||||||
|
provider: "openai".to_string(),
|
||||||
|
model: model.id().to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
AssistantProviderContentV1::Anthropic { default_model, .. } => {
|
||||||
|
default_model.map(|model| AssistantDefaultModel {
|
||||||
|
provider: "anthropic".to_string(),
|
||||||
|
model: model.id().to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
AssistantProviderContentV1::Ollama { default_model, .. } => {
|
||||||
|
default_model.map(|model| AssistantDefaultModel {
|
||||||
|
provider: "ollama".to_string(),
|
||||||
|
model: model.id().to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
|
||||||
|
},
|
||||||
|
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
|
||||||
enabled: None,
|
enabled: None,
|
||||||
button: settings.button,
|
button: settings.button,
|
||||||
dock: settings.dock,
|
dock: settings.dock,
|
||||||
default_width: settings.default_width,
|
default_width: settings.default_width,
|
||||||
default_height: settings.default_height,
|
default_height: settings.default_height,
|
||||||
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
|
default_model: Some(AssistantDefaultModel {
|
||||||
Some(AssistantProviderContent::OpenAi {
|
provider: "openai".to_string(),
|
||||||
default_model: settings.default_open_ai_model.clone(),
|
model: settings
|
||||||
api_url: Some(open_ai_api_url.clone()),
|
.default_open_ai_model
|
||||||
low_speed_timeout_in_seconds: None,
|
.clone()
|
||||||
available_models: Some(Default::default()),
|
.unwrap_or_default()
|
||||||
})
|
.id()
|
||||||
} else {
|
.to_string(),
|
||||||
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
}),
|
||||||
AssistantProviderContent::OpenAi {
|
|
||||||
default_model: Some(open_ai_model),
|
|
||||||
api_url: None,
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
available_models: Some(Default::default()),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -161,6 +231,9 @@ impl AssistantSettingsContent {
|
||||||
VersionedAssistantSettingsContent::V1(settings) => {
|
VersionedAssistantSettingsContent::V1(settings) => {
|
||||||
settings.dock = Some(dock);
|
settings.dock = Some(dock);
|
||||||
}
|
}
|
||||||
|
VersionedAssistantSettingsContent::V2(settings) => {
|
||||||
|
settings.dock = Some(dock);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
AssistantSettingsContent::Legacy(settings) => {
|
AssistantSettingsContent::Legacy(settings) => {
|
||||||
settings.dock = Some(dock);
|
settings.dock = Some(dock);
|
||||||
|
@ -168,74 +241,78 @@ impl AssistantSettingsContent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_model(&mut self, new_model: LanguageModel) {
|
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
|
||||||
|
let model = language_model.id().0.to_string();
|
||||||
|
let provider = language_model.provider_name().0.to_string();
|
||||||
|
|
||||||
match self {
|
match self {
|
||||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||||
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
|
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
|
||||||
Some(AssistantProviderContent::ZedDotDev {
|
"zed.dev" => {
|
||||||
default_model: model,
|
settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
|
||||||
}) => {
|
default_model: CloudModel::from_id(&model).ok(),
|
||||||
if let LanguageModel::Cloud(new_model) = new_model {
|
});
|
||||||
*model = Some(new_model);
|
|
||||||
}
|
}
|
||||||
}
|
"anthropic" => {
|
||||||
Some(AssistantProviderContent::OpenAi {
|
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
|
||||||
default_model: model,
|
Some(AssistantProviderContentV1::Anthropic {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
..
|
..
|
||||||
}) => {
|
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
|
||||||
if let LanguageModel::OpenAi(new_model) = new_model {
|
_ => (None, None),
|
||||||
*model = Some(new_model);
|
};
|
||||||
|
settings.provider = Some(AssistantProviderContentV1::Anthropic {
|
||||||
|
default_model: AnthropicModel::from_id(&model).ok(),
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
"ollama" => {
|
||||||
Some(AssistantProviderContent::Anthropic {
|
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
|
||||||
default_model: model,
|
Some(AssistantProviderContentV1::Ollama {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
..
|
..
|
||||||
}) => {
|
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
|
||||||
if let LanguageModel::Anthropic(new_model) = new_model {
|
_ => (None, None),
|
||||||
*model = Some(new_model);
|
};
|
||||||
|
settings.provider = Some(AssistantProviderContentV1::Ollama {
|
||||||
|
default_model: Some(ollama::Model::new(&model)),
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
"openai" => {
|
||||||
Some(AssistantProviderContent::Ollama {
|
let (api_url, low_speed_timeout_in_seconds, available_models) =
|
||||||
default_model: model,
|
match &settings.provider {
|
||||||
|
Some(AssistantProviderContentV1::OpenAi {
|
||||||
|
api_url,
|
||||||
|
low_speed_timeout_in_seconds,
|
||||||
|
available_models,
|
||||||
..
|
..
|
||||||
}) => {
|
}) => (
|
||||||
if let LanguageModel::Ollama(new_model) = new_model {
|
api_url.clone(),
|
||||||
*model = Some(new_model);
|
*low_speed_timeout_in_seconds,
|
||||||
}
|
available_models.clone(),
|
||||||
}
|
),
|
||||||
provider => match new_model {
|
_ => (None, None, None),
|
||||||
LanguageModel::Cloud(model) => {
|
};
|
||||||
*provider = Some(AssistantProviderContent::ZedDotDev {
|
settings.provider = Some(AssistantProviderContentV1::OpenAi {
|
||||||
default_model: Some(model),
|
default_model: open_ai::Model::from_id(&model).ok(),
|
||||||
})
|
api_url,
|
||||||
}
|
low_speed_timeout_in_seconds,
|
||||||
LanguageModel::OpenAi(model) => {
|
available_models,
|
||||||
*provider = Some(AssistantProviderContent::OpenAi {
|
});
|
||||||
default_model: Some(model),
|
|
||||||
api_url: None,
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
available_models: Some(Default::default()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
LanguageModel::Anthropic(model) => {
|
|
||||||
*provider = Some(AssistantProviderContent::Anthropic {
|
|
||||||
default_model: Some(model),
|
|
||||||
api_url: None,
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
LanguageModel::Ollama(model) => {
|
|
||||||
*provider = Some(AssistantProviderContent::Ollama {
|
|
||||||
default_model: Some(model),
|
|
||||||
api_url: None,
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
},
|
},
|
||||||
},
|
VersionedAssistantSettingsContent::V2(settings) => {
|
||||||
|
settings.default_model = Some(AssistantDefaultModel { provider, model });
|
||||||
|
}
|
||||||
},
|
},
|
||||||
AssistantSettingsContent::Legacy(settings) => {
|
AssistantSettingsContent::Legacy(settings) => {
|
||||||
if let LanguageModel::OpenAi(model) = new_model {
|
if let Ok(model) = open_ai::Model::from_id(&language_model.id().0) {
|
||||||
settings.default_open_ai_model = Some(model);
|
settings.default_open_ai_model = Some(model);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -248,21 +325,78 @@ impl AssistantSettingsContent {
|
||||||
pub enum VersionedAssistantSettingsContent {
|
pub enum VersionedAssistantSettingsContent {
|
||||||
#[serde(rename = "1")]
|
#[serde(rename = "1")]
|
||||||
V1(AssistantSettingsContentV1),
|
V1(AssistantSettingsContentV1),
|
||||||
|
#[serde(rename = "2")]
|
||||||
|
V2(AssistantSettingsContentV2),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for VersionedAssistantSettingsContent {
|
impl Default for VersionedAssistantSettingsContent {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self::V1(AssistantSettingsContentV1 {
|
Self::V2(AssistantSettingsContentV2 {
|
||||||
enabled: None,
|
enabled: None,
|
||||||
button: None,
|
button: None,
|
||||||
dock: None,
|
dock: None,
|
||||||
default_width: None,
|
default_width: None,
|
||||||
default_height: None,
|
default_height: None,
|
||||||
provider: None,
|
default_model: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||||
|
pub struct AssistantSettingsContentV2 {
|
||||||
|
/// Whether the Assistant is enabled.
|
||||||
|
///
|
||||||
|
/// Default: true
|
||||||
|
enabled: Option<bool>,
|
||||||
|
/// Whether to show the assistant panel button in the status bar.
|
||||||
|
///
|
||||||
|
/// Default: true
|
||||||
|
button: Option<bool>,
|
||||||
|
/// Where to dock the assistant.
|
||||||
|
///
|
||||||
|
/// Default: right
|
||||||
|
dock: Option<AssistantDockPosition>,
|
||||||
|
/// Default width in pixels when the assistant is docked to the left or right.
|
||||||
|
///
|
||||||
|
/// Default: 640
|
||||||
|
default_width: Option<f32>,
|
||||||
|
/// Default height in pixels when the assistant is docked to the bottom.
|
||||||
|
///
|
||||||
|
/// Default: 320
|
||||||
|
default_height: Option<f32>,
|
||||||
|
/// The default model to use when creating new contexts.
|
||||||
|
default_model: Option<AssistantDefaultModel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||||
|
pub struct AssistantDefaultModel {
|
||||||
|
#[schemars(schema_with = "providers_schema")]
|
||||||
|
pub provider: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
|
||||||
|
schemars::schema::SchemaObject {
|
||||||
|
enum_values: Some(vec![
|
||||||
|
"anthropic".into(),
|
||||||
|
"ollama".into(),
|
||||||
|
"openai".into(),
|
||||||
|
"zed.dev".into(),
|
||||||
|
]),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AssistantDefaultModel {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
provider: "openai".to_string(),
|
||||||
|
model: "gpt-4".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||||
pub struct AssistantSettingsContentV1 {
|
pub struct AssistantSettingsContentV1 {
|
||||||
/// Whether the Assistant is enabled.
|
/// Whether the Assistant is enabled.
|
||||||
|
@ -289,7 +423,7 @@ pub struct AssistantSettingsContentV1 {
|
||||||
///
|
///
|
||||||
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
||||||
/// each with their respective default models and configurations.
|
/// each with their respective default models and configurations.
|
||||||
provider: Option<AssistantProviderContent>,
|
provider: Option<AssistantProviderContentV1>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||||
|
@ -332,6 +466,10 @@ impl Settings for AssistantSettings {
|
||||||
let mut settings = AssistantSettings::default();
|
let mut settings = AssistantSettings::default();
|
||||||
|
|
||||||
for value in sources.defaults_and_customizations() {
|
for value in sources.defaults_and_customizations() {
|
||||||
|
if value.is_version_outdated() {
|
||||||
|
settings.using_outdated_settings_version = true;
|
||||||
|
}
|
||||||
|
|
||||||
let value = value.upgrade();
|
let value = value.upgrade();
|
||||||
merge(&mut settings.enabled, value.enabled);
|
merge(&mut settings.enabled, value.enabled);
|
||||||
merge(&mut settings.button, value.button);
|
merge(&mut settings.button, value.button);
|
||||||
|
@ -344,123 +482,10 @@ impl Settings for AssistantSettings {
|
||||||
&mut settings.default_height,
|
&mut settings.default_height,
|
||||||
value.default_height.map(Into::into),
|
value.default_height.map(Into::into),
|
||||||
);
|
);
|
||||||
if let Some(provider) = value.provider.clone() {
|
merge(
|
||||||
match (&mut settings.provider, provider) {
|
&mut settings.default_model,
|
||||||
(
|
value.default_model.map(Into::into),
|
||||||
AssistantProvider::ZedDotDev { model },
|
);
|
||||||
AssistantProviderContent::ZedDotDev {
|
|
||||||
default_model: model_override,
|
|
||||||
},
|
|
||||||
) => {
|
|
||||||
merge(model, model_override);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
AssistantProvider::OpenAi {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
available_models,
|
|
||||||
},
|
|
||||||
AssistantProviderContent::OpenAi {
|
|
||||||
default_model: model_override,
|
|
||||||
api_url: api_url_override,
|
|
||||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
|
||||||
available_models: available_models_override,
|
|
||||||
},
|
|
||||||
) => {
|
|
||||||
merge(model, model_override);
|
|
||||||
merge(api_url, api_url_override);
|
|
||||||
merge(available_models, available_models_override);
|
|
||||||
if let Some(low_speed_timeout_in_seconds_override) =
|
|
||||||
low_speed_timeout_in_seconds_override
|
|
||||||
{
|
|
||||||
*low_speed_timeout_in_seconds =
|
|
||||||
Some(low_speed_timeout_in_seconds_override);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(
|
|
||||||
AssistantProvider::Ollama {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
},
|
|
||||||
AssistantProviderContent::Ollama {
|
|
||||||
default_model: model_override,
|
|
||||||
api_url: api_url_override,
|
|
||||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
|
||||||
},
|
|
||||||
) => {
|
|
||||||
merge(model, model_override);
|
|
||||||
merge(api_url, api_url_override);
|
|
||||||
if let Some(low_speed_timeout_in_seconds_override) =
|
|
||||||
low_speed_timeout_in_seconds_override
|
|
||||||
{
|
|
||||||
*low_speed_timeout_in_seconds =
|
|
||||||
Some(low_speed_timeout_in_seconds_override);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(
|
|
||||||
AssistantProvider::Anthropic {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
},
|
|
||||||
AssistantProviderContent::Anthropic {
|
|
||||||
default_model: model_override,
|
|
||||||
api_url: api_url_override,
|
|
||||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
|
||||||
},
|
|
||||||
) => {
|
|
||||||
merge(model, model_override);
|
|
||||||
merge(api_url, api_url_override);
|
|
||||||
if let Some(low_speed_timeout_in_seconds_override) =
|
|
||||||
low_speed_timeout_in_seconds_override
|
|
||||||
{
|
|
||||||
*low_speed_timeout_in_seconds =
|
|
||||||
Some(low_speed_timeout_in_seconds_override);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(provider, provider_override) => {
|
|
||||||
*provider = match provider_override {
|
|
||||||
AssistantProviderContent::ZedDotDev {
|
|
||||||
default_model: model,
|
|
||||||
} => AssistantProvider::ZedDotDev {
|
|
||||||
model: model.unwrap_or_default(),
|
|
||||||
},
|
|
||||||
AssistantProviderContent::OpenAi {
|
|
||||||
default_model: model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
available_models,
|
|
||||||
} => AssistantProvider::OpenAi {
|
|
||||||
model: model.unwrap_or_default(),
|
|
||||||
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
available_models: available_models.unwrap_or_default(),
|
|
||||||
},
|
|
||||||
AssistantProviderContent::Anthropic {
|
|
||||||
default_model: model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
} => AssistantProvider::Anthropic {
|
|
||||||
model: model.unwrap_or_default(),
|
|
||||||
api_url: api_url
|
|
||||||
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
},
|
|
||||||
AssistantProviderContent::Ollama {
|
|
||||||
default_model: model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
} => AssistantProvider::Ollama {
|
|
||||||
model: model.unwrap_or_default(),
|
|
||||||
api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(settings)
|
Ok(settings)
|
||||||
|
@ -473,221 +498,103 @@ fn merge<T>(target: &mut T, value: Option<T>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_completion_provider_settings(
|
// #[cfg(test)]
|
||||||
provider: &mut CompletionProvider,
|
// mod tests {
|
||||||
version: usize,
|
// use gpui::{AppContext, UpdateGlobal};
|
||||||
cx: &mut AppContext,
|
// use settings::SettingsStore;
|
||||||
) {
|
|
||||||
let updated = match &AssistantSettings::get_global(cx).provider {
|
|
||||||
AssistantProvider::ZedDotDev { model } => provider
|
|
||||||
.update_current_as::<_, CloudCompletionProvider>(|provider| {
|
|
||||||
provider.update(model.clone(), version);
|
|
||||||
}),
|
|
||||||
AssistantProvider::OpenAi {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
available_models,
|
|
||||||
} => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
|
||||||
provider.update(
|
|
||||||
choose_openai_model(&model, &available_models),
|
|
||||||
api_url.clone(),
|
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
|
||||||
version,
|
|
||||||
);
|
|
||||||
}),
|
|
||||||
AssistantProvider::Anthropic {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
} => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
|
||||||
provider.update(
|
|
||||||
model.clone(),
|
|
||||||
api_url.clone(),
|
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
|
||||||
version,
|
|
||||||
);
|
|
||||||
}),
|
|
||||||
AssistantProvider::Ollama {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
} => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
|
||||||
provider.update(
|
|
||||||
model.clone(),
|
|
||||||
api_url.clone(),
|
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
|
||||||
version,
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Previously configured provider was changed to another one
|
// use super::*;
|
||||||
if updated.is_none() {
|
|
||||||
provider.update_provider(|client| create_provider_from_settings(client, version, cx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn create_provider_from_settings(
|
// #[gpui::test]
|
||||||
client: Arc<Client>,
|
// fn test_deserialize_assistant_settings(cx: &mut AppContext) {
|
||||||
settings_version: usize,
|
// let store = settings::SettingsStore::test(cx);
|
||||||
cx: &mut AppContext,
|
// cx.set_global(store);
|
||||||
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
|
|
||||||
match &AssistantSettings::get_global(cx).provider {
|
|
||||||
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
|
|
||||||
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
|
|
||||||
)),
|
|
||||||
AssistantProvider::OpenAi {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
available_models,
|
|
||||||
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
|
|
||||||
choose_openai_model(&model, &available_models),
|
|
||||||
api_url.clone(),
|
|
||||||
client.http_client(),
|
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
|
||||||
settings_version,
|
|
||||||
available_models.clone(),
|
|
||||||
))),
|
|
||||||
AssistantProvider::Anthropic {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
|
|
||||||
model.clone(),
|
|
||||||
api_url.clone(),
|
|
||||||
client.http_client(),
|
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
|
||||||
settings_version,
|
|
||||||
))),
|
|
||||||
AssistantProvider::Ollama {
|
|
||||||
model,
|
|
||||||
api_url,
|
|
||||||
low_speed_timeout_in_seconds,
|
|
||||||
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
|
|
||||||
model.clone(),
|
|
||||||
api_url.clone(),
|
|
||||||
client.http_client(),
|
|
||||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
|
||||||
settings_version,
|
|
||||||
cx,
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Choose which model to use for openai provider.
|
// // Settings default to gpt-4-turbo.
|
||||||
/// If the model is not available, try to use the first available model, or fallback to the original model.
|
// AssistantSettings::register(cx);
|
||||||
fn choose_openai_model(
|
// assert_eq!(
|
||||||
model: &::open_ai::Model,
|
// AssistantSettings::get_global(cx).provider,
|
||||||
available_models: &[::open_ai::Model],
|
// AssistantProvider::OpenAi {
|
||||||
) -> ::open_ai::Model {
|
// model: OpenAiModel::FourOmni,
|
||||||
available_models
|
// api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||||
.iter()
|
// low_speed_timeout_in_seconds: None,
|
||||||
.find(|&m| m == model)
|
// available_models: Default::default(),
|
||||||
.or_else(|| available_models.first())
|
// }
|
||||||
.unwrap_or_else(|| model)
|
// );
|
||||||
.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
// // Ensure backward-compatibility.
|
||||||
mod tests {
|
// SettingsStore::update_global(cx, |store, cx| {
|
||||||
use gpui::{AppContext, UpdateGlobal};
|
// store
|
||||||
use settings::SettingsStore;
|
// .set_user_settings(
|
||||||
|
// r#"{
|
||||||
|
// "assistant": {
|
||||||
|
// "openai_api_url": "test-url",
|
||||||
|
// }
|
||||||
|
// }"#,
|
||||||
|
// cx,
|
||||||
|
// )
|
||||||
|
// .unwrap();
|
||||||
|
// });
|
||||||
|
// assert_eq!(
|
||||||
|
// AssistantSettings::get_global(cx).provider,
|
||||||
|
// AssistantProvider::OpenAi {
|
||||||
|
// model: OpenAiModel::FourOmni,
|
||||||
|
// api_url: "test-url".into(),
|
||||||
|
// low_speed_timeout_in_seconds: None,
|
||||||
|
// available_models: Default::default(),
|
||||||
|
// }
|
||||||
|
// );
|
||||||
|
// SettingsStore::update_global(cx, |store, cx| {
|
||||||
|
// store
|
||||||
|
// .set_user_settings(
|
||||||
|
// r#"{
|
||||||
|
// "assistant": {
|
||||||
|
// "default_open_ai_model": "gpt-4-0613"
|
||||||
|
// }
|
||||||
|
// }"#,
|
||||||
|
// cx,
|
||||||
|
// )
|
||||||
|
// .unwrap();
|
||||||
|
// });
|
||||||
|
// assert_eq!(
|
||||||
|
// AssistantSettings::get_global(cx).provider,
|
||||||
|
// AssistantProvider::OpenAi {
|
||||||
|
// model: OpenAiModel::Four,
|
||||||
|
// api_url: open_ai::OPEN_AI_API_URL.into(),
|
||||||
|
// low_speed_timeout_in_seconds: None,
|
||||||
|
// available_models: Default::default(),
|
||||||
|
// }
|
||||||
|
// );
|
||||||
|
|
||||||
use super::*;
|
// // The new version supports setting a custom model when using zed.dev.
|
||||||
|
// SettingsStore::update_global(cx, |store, cx| {
|
||||||
#[gpui::test]
|
// store
|
||||||
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
|
// .set_user_settings(
|
||||||
let store = settings::SettingsStore::test(cx);
|
// r#"{
|
||||||
cx.set_global(store);
|
// "assistant": {
|
||||||
|
// "version": "1",
|
||||||
// Settings default to gpt-4-turbo.
|
// "provider": {
|
||||||
AssistantSettings::register(cx);
|
// "name": "zed.dev",
|
||||||
assert_eq!(
|
// "default_model": {
|
||||||
AssistantSettings::get_global(cx).provider,
|
// "custom": {
|
||||||
AssistantProvider::OpenAi {
|
// "name": "custom-provider"
|
||||||
model: OpenAiModel::FourOmni,
|
// }
|
||||||
api_url: open_ai::OPEN_AI_API_URL.into(),
|
// }
|
||||||
low_speed_timeout_in_seconds: None,
|
// }
|
||||||
available_models: Default::default(),
|
// }
|
||||||
}
|
// }"#,
|
||||||
);
|
// cx,
|
||||||
|
// )
|
||||||
// Ensure backward-compatibility.
|
// .unwrap();
|
||||||
SettingsStore::update_global(cx, |store, cx| {
|
// });
|
||||||
store
|
// assert_eq!(
|
||||||
.set_user_settings(
|
// AssistantSettings::get_global(cx).provider,
|
||||||
r#"{
|
// AssistantProvider::ZedDotDev {
|
||||||
"assistant": {
|
// model: CloudModel::Custom {
|
||||||
"openai_api_url": "test-url",
|
// name: "custom-provider".into(),
|
||||||
}
|
// max_tokens: None
|
||||||
}"#,
|
// }
|
||||||
cx,
|
// }
|
||||||
)
|
// );
|
||||||
.unwrap();
|
// }
|
||||||
});
|
// }
|
||||||
assert_eq!(
|
|
||||||
AssistantSettings::get_global(cx).provider,
|
|
||||||
AssistantProvider::OpenAi {
|
|
||||||
model: OpenAiModel::FourOmni,
|
|
||||||
api_url: "test-url".into(),
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
available_models: Default::default(),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
SettingsStore::update_global(cx, |store, cx| {
|
|
||||||
store
|
|
||||||
.set_user_settings(
|
|
||||||
r#"{
|
|
||||||
"assistant": {
|
|
||||||
"default_open_ai_model": "gpt-4-0613"
|
|
||||||
}
|
|
||||||
}"#,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
});
|
|
||||||
assert_eq!(
|
|
||||||
AssistantSettings::get_global(cx).provider,
|
|
||||||
AssistantProvider::OpenAi {
|
|
||||||
model: OpenAiModel::Four,
|
|
||||||
api_url: open_ai::OPEN_AI_API_URL.into(),
|
|
||||||
low_speed_timeout_in_seconds: None,
|
|
||||||
available_models: Default::default(),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
// The new version supports setting a custom model when using zed.dev.
|
|
||||||
SettingsStore::update_global(cx, |store, cx| {
|
|
||||||
store
|
|
||||||
.set_user_settings(
|
|
||||||
r#"{
|
|
||||||
"assistant": {
|
|
||||||
"version": "1",
|
|
||||||
"provider": {
|
|
||||||
"name": "zed.dev",
|
|
||||||
"default_model": {
|
|
||||||
"custom": {
|
|
||||||
"name": "custom-provider"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}"#,
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
});
|
|
||||||
assert_eq!(
|
|
||||||
AssistantSettings::get_global(cx).provider,
|
|
||||||
AssistantProvider::ZedDotDev {
|
|
||||||
model: CloudModel::Custom {
|
|
||||||
name: "custom-provider".into(),
|
|
||||||
max_tokens: None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
|
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
||||||
MessageStatus,
|
MessageId, MessageStatus,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use assistant_slash_command::{
|
use assistant_slash_command::{
|
||||||
|
@ -1124,7 +1124,9 @@ impl Context {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let token_count = cx
|
let token_count = cx
|
||||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
.update(|cx| {
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||||
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
|
@ -1308,7 +1310,9 @@ impl Context {
|
||||||
});
|
});
|
||||||
|
|
||||||
let raw_output = cx
|
let raw_output = cx
|
||||||
.update(|cx| CompletionProvider::global(cx).complete(request, cx))?
|
.update(|cx| {
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
||||||
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let operations = Self::parse_edit_operations(&raw_output);
|
let operations = Self::parse_edit_operations(&raw_output);
|
||||||
|
@ -1612,13 +1616,14 @@ impl Context {
|
||||||
.then_some(message.id)
|
.then_some(message.id)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if !CompletionProvider::global(cx).is_authenticated() {
|
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||||
log::info!("completion provider has no credentials");
|
log::info!("completion provider has no credentials");
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = self.to_completion_request(cx);
|
let request = self.to_completion_request(cx);
|
||||||
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
|
let stream =
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -1698,11 +1703,14 @@ impl Context {
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||||
let model = CompletionProvider::global(cx).model();
|
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||||
|
.active_model()
|
||||||
|
.map(|m| m.telemetry_id())
|
||||||
|
.unwrap_or_default();
|
||||||
telemetry.report_assistant_event(
|
telemetry.report_assistant_event(
|
||||||
Some(this.id.0.clone()),
|
Some(this.id.0.clone()),
|
||||||
AssistantKind::Panel,
|
AssistantKind::Panel,
|
||||||
model.telemetry_id(),
|
model_telemetry_id,
|
||||||
response_latency,
|
response_latency,
|
||||||
error_message,
|
error_message,
|
||||||
);
|
);
|
||||||
|
@ -1727,7 +1735,6 @@ impl Context {
|
||||||
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
.map(|message| message.to_request_message(self.buffer.read(cx)));
|
||||||
|
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
model: CompletionProvider::global(cx).model(),
|
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
|
@ -1970,7 +1977,7 @@ impl Context {
|
||||||
|
|
||||||
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
|
||||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||||
if !CompletionProvider::global(cx).is_authenticated() {
|
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1982,13 +1989,13 @@ impl Context {
|
||||||
content: "Summarize the context into a short title without punctuation.".into(),
|
content: "Summarize the context into a short title without punctuation.".into(),
|
||||||
}));
|
}));
|
||||||
let request = LanguageModelRequest {
|
let request = LanguageModelRequest {
|
||||||
model: CompletionProvider::global(cx).model(),
|
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
|
let stream =
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||||
async move {
|
async move {
|
||||||
let mut messages = stream.await?;
|
let mut messages = stream.await?;
|
||||||
|
@ -2504,7 +2511,6 @@ mod tests {
|
||||||
MessageId,
|
MessageId,
|
||||||
};
|
};
|
||||||
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
|
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
|
||||||
use completion::FakeCompletionProvider;
|
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
use gpui::{AppContext, TestAppContext, WeakView};
|
use gpui::{AppContext, TestAppContext, WeakView};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
|
@ -2524,7 +2530,8 @@ mod tests {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
FakeCompletionProvider::setup_test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
|
completion::LanguageModelCompletionProvider::test(cx);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
|
@ -2656,7 +2663,8 @@ mod tests {
|
||||||
fn test_message_splitting(cx: &mut AppContext) {
|
fn test_message_splitting(cx: &mut AppContext) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
FakeCompletionProvider::setup_test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
|
completion::LanguageModelCompletionProvider::test(cx);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
|
|
||||||
|
@ -2749,7 +2757,8 @@ mod tests {
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_messages_for_offsets(cx: &mut AppContext) {
|
fn test_messages_for_offsets(cx: &mut AppContext) {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
FakeCompletionProvider::setup_test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
|
completion::LanguageModelCompletionProvider::test(cx);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||||
|
@ -2834,7 +2843,8 @@ mod tests {
|
||||||
async fn test_slash_commands(cx: &mut TestAppContext) {
|
async fn test_slash_commands(cx: &mut TestAppContext) {
|
||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
cx.update(FakeCompletionProvider::setup_test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
|
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||||
cx.update(Project::init_settings);
|
cx.update(Project::init_settings);
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let fs = FakeFs::new(cx.background_executor.clone());
|
let fs = FakeFs::new(cx.background_executor.clone());
|
||||||
|
@ -2959,7 +2969,11 @@ mod tests {
|
||||||
cx.update(prompt_library::init);
|
cx.update(prompt_library::init);
|
||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
let fake_provider = cx.update(FakeCompletionProvider::setup_test);
|
|
||||||
|
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
|
||||||
|
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||||
|
|
||||||
|
let fake_model = fake_provider.test_model();
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
|
|
||||||
|
@ -3025,8 +3039,8 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Simulate the LLM completion
|
// Simulate the LLM completion
|
||||||
fake_provider.send_last_completion_chunk(llm_response.to_string());
|
fake_model.send_last_completion_chunk(llm_response.to_string());
|
||||||
fake_provider.finish_last_completion();
|
fake_model.finish_last_completion();
|
||||||
|
|
||||||
// Wait for the completion to be processed
|
// Wait for the completion to be processed
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -3107,7 +3121,8 @@ mod tests {
|
||||||
async fn test_serialization(cx: &mut TestAppContext) {
|
async fn test_serialization(cx: &mut TestAppContext) {
|
||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
cx.update(FakeCompletionProvider::setup_test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
|
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
|
||||||
|
@ -3183,7 +3198,9 @@ mod tests {
|
||||||
|
|
||||||
let settings_store = cx.update(SettingsStore::test);
|
let settings_store = cx.update(SettingsStore::test);
|
||||||
cx.set_global(settings_store);
|
cx.set_global(settings_store);
|
||||||
cx.update(FakeCompletionProvider::setup_test);
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
|
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||||
|
|
||||||
cx.update(assistant_panel::init);
|
cx.update(assistant_panel::init);
|
||||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||||
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
|
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
|
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
|
||||||
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
|
AssistantPanel, AssistantPanelEvent, Hunk, LanguageModelCompletionProvider, StreamingDiff,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use client::telemetry::Telemetry;
|
use client::telemetry::Telemetry;
|
||||||
|
@ -27,7 +27,9 @@ use gpui::{
|
||||||
WindowContext,
|
WindowContext,
|
||||||
};
|
};
|
||||||
use language::{Buffer, Point, Selection, TransactionId};
|
use language::{Buffer, Point, Selection, TransactionId};
|
||||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
use multi_buffer::MultiBufferRow;
|
use multi_buffer::MultiBufferRow;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rope::Rope;
|
use rope::Rope;
|
||||||
|
@ -844,7 +846,10 @@ impl InlineAssistant {
|
||||||
}
|
}
|
||||||
|
|
||||||
let codegen = assist.codegen.clone();
|
let codegen = assist.codegen.clone();
|
||||||
let telemetry_id = CompletionProvider::global(cx).model().telemetry_id();
|
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||||
|
.active_model()
|
||||||
|
.map(|m| m.telemetry_id())
|
||||||
|
.unwrap_or_default();
|
||||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
||||||
if user_prompt.trim().to_lowercase() == "delete" {
|
if user_prompt.trim().to_lowercase() == "delete" {
|
||||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||||
|
@ -854,7 +859,10 @@ impl InlineAssistant {
|
||||||
async move {
|
async move {
|
||||||
let request = request.await?;
|
let request = request.await?;
|
||||||
let chunks = cx
|
let chunks = cx
|
||||||
.update(|cx| CompletionProvider::global(cx).stream_completion(request, cx))?
|
.update(|cx| {
|
||||||
|
LanguageModelCompletionProvider::read_global(cx)
|
||||||
|
.stream_completion(request, cx)
|
||||||
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
Ok(chunks.boxed())
|
Ok(chunks.boxed())
|
||||||
}
|
}
|
||||||
|
@ -871,8 +879,8 @@ impl InlineAssistant {
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> Task<Result<LanguageModelRequest>> {
|
) -> Task<Result<LanguageModelRequest>> {
|
||||||
cx.spawn(|mut cx| async move {
|
cx.spawn(|mut cx| async move {
|
||||||
let (user_prompt, context_request, project_name, buffer, range, model) = cx
|
let (user_prompt, context_request, project_name, buffer, range) =
|
||||||
.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
||||||
let assist = this.assists.get(&assist_id).context("invalid assist")?;
|
let assist = this.assists.get(&assist_id).context("invalid assist")?;
|
||||||
let decorations = assist.decorations.as_ref().context("invalid assist")?;
|
let decorations = assist.decorations.as_ref().context("invalid assist")?;
|
||||||
let editor = assist.editor.upgrade().context("invalid assist")?;
|
let editor = assist.editor.upgrade().context("invalid assist")?;
|
||||||
|
@ -906,15 +914,7 @@ impl InlineAssistant {
|
||||||
});
|
});
|
||||||
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
|
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||||
let range = assist.codegen.read(cx).range.clone();
|
let range = assist.codegen.read(cx).range.clone();
|
||||||
let model = CompletionProvider::global(cx).model();
|
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
|
||||||
anyhow::Ok((
|
|
||||||
user_prompt,
|
|
||||||
context_request,
|
|
||||||
project_name,
|
|
||||||
buffer,
|
|
||||||
range,
|
|
||||||
model,
|
|
||||||
))
|
|
||||||
})??;
|
})??;
|
||||||
|
|
||||||
let language = buffer.language_at(range.start);
|
let language = buffer.language_at(range.start);
|
||||||
|
@ -973,7 +973,6 @@ impl InlineAssistant {
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(LanguageModelRequest {
|
Ok(LanguageModelRequest {
|
||||||
model,
|
|
||||||
messages,
|
messages,
|
||||||
stop: vec!["|END|>".to_string()],
|
stop: vec!["|END|>".to_string()],
|
||||||
temperature,
|
temperature,
|
||||||
|
@ -1432,24 +1431,39 @@ impl Render for PromptEditor {
|
||||||
PopoverMenu::new("model-switcher")
|
PopoverMenu::new("model-switcher")
|
||||||
.menu(move |cx| {
|
.menu(move |cx| {
|
||||||
ContextMenu::build(cx, |mut menu, cx| {
|
ContextMenu::build(cx, |mut menu, cx| {
|
||||||
for model in CompletionProvider::global(cx).available_models() {
|
for available_model in
|
||||||
|
LanguageModelRegistry::read_global(cx).available_models(cx)
|
||||||
|
{
|
||||||
menu = menu.custom_entry(
|
menu = menu.custom_entry(
|
||||||
{
|
{
|
||||||
let model = model.clone();
|
let model_name = available_model.name().0.clone();
|
||||||
|
let provider =
|
||||||
|
available_model.provider_name().0.clone();
|
||||||
move |_| {
|
move |_| {
|
||||||
Label::new(model.display_name())
|
h_flex()
|
||||||
.into_any_element()
|
.w_full()
|
||||||
|
.justify_between()
|
||||||
|
.child(Label::new(model_name.clone()))
|
||||||
|
.child(
|
||||||
|
div().ml_4().child(
|
||||||
|
Label::new(provider.clone())
|
||||||
|
.color(Color::Muted),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
let fs = fs.clone();
|
let fs = fs.clone();
|
||||||
let model = model.clone();
|
let model = available_model.clone();
|
||||||
move |cx| {
|
move |cx| {
|
||||||
let model = model.clone();
|
let model = model.clone();
|
||||||
update_settings_file::<AssistantSettings>(
|
update_settings_file::<AssistantSettings>(
|
||||||
fs.clone(),
|
fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| settings.set_model(model),
|
move |settings, _| {
|
||||||
|
settings.set_model(model)
|
||||||
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1468,9 +1482,10 @@ impl Render for PromptEditor {
|
||||||
Tooltip::with_meta(
|
Tooltip::with_meta(
|
||||||
format!(
|
format!(
|
||||||
"Using {}",
|
"Using {}",
|
||||||
CompletionProvider::global(cx)
|
LanguageModelCompletionProvider::read_global(cx)
|
||||||
.model()
|
.active_model()
|
||||||
.display_name()
|
.map(|model| model.name().0)
|
||||||
|
.unwrap_or_else(|| "No model selected".into()),
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
"Change Model",
|
"Change Model",
|
||||||
|
@ -1668,7 +1683,9 @@ impl PromptEditor {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let token_count = cx
|
let token_count = cx
|
||||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
.update(|cx| {
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||||
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
|
@ -1796,7 +1813,7 @@ impl PromptEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||||
let model = CompletionProvider::global(cx).model();
|
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||||
let token_count = self.token_count?;
|
let token_count = self.token_count?;
|
||||||
let max_token_count = model.max_token_count();
|
let max_token_count = model.max_token_count();
|
||||||
|
|
||||||
|
@ -2601,7 +2618,6 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use completion::FakeCompletionProvider;
|
|
||||||
use futures::stream::{self};
|
use futures::stream::{self};
|
||||||
use gpui::{Context, TestAppContext};
|
use gpui::{Context, TestAppContext};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
|
@ -2622,7 +2638,8 @@ mod tests {
|
||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||||
cx.set_global(cx.update(SettingsStore::test));
|
cx.set_global(cx.update(SettingsStore::test));
|
||||||
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
|
cx.update(language_model::LanguageModelRegistry::test);
|
||||||
|
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||||
cx.update(language_settings::init);
|
cx.update(language_settings::init);
|
||||||
|
|
||||||
let text = indoc! {"
|
let text = indoc! {"
|
||||||
|
@ -2749,7 +2766,8 @@ mod tests {
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
mut rng: StdRng,
|
mut rng: StdRng,
|
||||||
) {
|
) {
|
||||||
cx.update(|cx| FakeCompletionProvider::setup_test(cx));
|
cx.update(LanguageModelRegistry::test);
|
||||||
|
cx.update(completion::LanguageModelCompletionProvider::test);
|
||||||
cx.set_global(cx.update(SettingsStore::test));
|
cx.set_global(cx.update(SettingsStore::test));
|
||||||
cx.update(language_settings::init);
|
cx.update(language_settings::init);
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
|
use crate::{
|
||||||
|
assistant_settings::AssistantSettings, LanguageModelCompletionProvider, ToggleModelSelector,
|
||||||
|
};
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
|
use language_model::LanguageModelRegistry;
|
||||||
use settings::update_settings_file;
|
use settings::update_settings_file;
|
||||||
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip};
|
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip};
|
||||||
|
|
||||||
|
@ -23,26 +26,65 @@ impl RenderOnce for ModelSelector {
|
||||||
.with_handle(self.handle)
|
.with_handle(self.handle)
|
||||||
.menu(move |cx| {
|
.menu(move |cx| {
|
||||||
ContextMenu::build(cx, |mut menu, cx| {
|
ContextMenu::build(cx, |mut menu, cx| {
|
||||||
for model in CompletionProvider::global(cx).available_models() {
|
for (provider, available_models) in LanguageModelRegistry::global(cx)
|
||||||
|
.read(cx)
|
||||||
|
.available_models_grouped_by_provider(cx)
|
||||||
|
{
|
||||||
|
menu = menu.header(provider.0.clone());
|
||||||
|
|
||||||
|
if available_models.is_empty() {
|
||||||
menu = menu.custom_entry(
|
menu = menu.custom_entry(
|
||||||
{
|
{
|
||||||
let model = model.clone();
|
move |_| {
|
||||||
move |_| Label::new(model.display_name()).into_any_element()
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.gap_1()
|
||||||
|
.child(Icon::new(IconName::Settings))
|
||||||
|
.child(Label::new("Configure"))
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
let provider = provider.clone();
|
||||||
|
move |cx| {
|
||||||
|
LanguageModelCompletionProvider::global(cx).update(
|
||||||
|
cx,
|
||||||
|
|completion_provider, cx| {
|
||||||
|
completion_provider
|
||||||
|
.set_active_provider(provider.clone(), cx)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for available_model in available_models {
|
||||||
|
menu = menu.custom_entry(
|
||||||
|
{
|
||||||
|
let model_name = available_model.name().0.clone();
|
||||||
|
move |_| {
|
||||||
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.child(Label::new(model_name.clone()))
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
let model = model.clone();
|
let model = available_model.clone();
|
||||||
move |cx| {
|
move |cx| {
|
||||||
let model = model.clone();
|
let model = model.clone();
|
||||||
update_settings_file::<AssistantSettings>(
|
update_settings_file::<AssistantSettings>(
|
||||||
fs.clone(),
|
fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| settings.set_model(model),
|
move |settings, _| settings.set_model(model),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
menu
|
menu
|
||||||
})
|
})
|
||||||
.into()
|
.into()
|
||||||
|
@ -61,7 +103,10 @@ impl RenderOnce for ModelSelector {
|
||||||
.whitespace_nowrap()
|
.whitespace_nowrap()
|
||||||
.child(
|
.child(
|
||||||
Label::new(
|
Label::new(
|
||||||
CompletionProvider::global(cx).model().display_name(),
|
LanguageModelCompletionProvider::read_global(cx)
|
||||||
|
.active_model()
|
||||||
|
.map(|model| model.name().0)
|
||||||
|
.unwrap_or_else(|| "No model selected".into()),
|
||||||
)
|
)
|
||||||
.size(LabelSize::Small)
|
.size(LabelSize::Small)
|
||||||
.color(Color::Muted),
|
.color(Color::Muted),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
|
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
|
||||||
InlineAssist, InlineAssistant,
|
LanguageModelCompletionProvider,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use assets::Assets;
|
use assets::Assets;
|
||||||
|
@ -636,9 +636,9 @@ impl PromptLibrary {
|
||||||
};
|
};
|
||||||
|
|
||||||
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
||||||
let provider = CompletionProvider::global(cx);
|
let provider = LanguageModelCompletionProvider::read_global(cx);
|
||||||
let initial_prompt = action.prompt.clone();
|
let initial_prompt = action.prompt.clone();
|
||||||
if provider.is_authenticated() {
|
if provider.is_authenticated(cx) {
|
||||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||||
assistant.assist(&prompt_editor, None, None, initial_prompt, cx)
|
assistant.assist(&prompt_editor, None, None, initial_prompt, cx)
|
||||||
})
|
})
|
||||||
|
@ -736,11 +736,8 @@ impl PromptLibrary {
|
||||||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||||
let token_count = cx
|
let token_count = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
let provider = CompletionProvider::global(cx);
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(
|
||||||
let model = provider.model();
|
|
||||||
provider.count_tokens(
|
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
model,
|
|
||||||
messages: vec![LanguageModelRequestMessage {
|
messages: vec![LanguageModelRequestMessage {
|
||||||
role: Role::System,
|
role: Role::System,
|
||||||
content: body.to_string(),
|
content: body.to_string(),
|
||||||
|
@ -806,7 +803,7 @@ impl PromptLibrary {
|
||||||
let prompt_metadata = self.store.metadata(prompt_id)?;
|
let prompt_metadata = self.store.metadata(prompt_id)?;
|
||||||
let prompt_editor = &self.prompt_editors[&prompt_id];
|
let prompt_editor = &self.prompt_editors[&prompt_id];
|
||||||
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
||||||
let current_model = CompletionProvider::global(cx).model();
|
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
|
||||||
let settings = ThemeSettings::get_global(cx);
|
let settings = ThemeSettings::get_global(cx);
|
||||||
|
|
||||||
Some(
|
Some(
|
||||||
|
@ -917,7 +914,11 @@ impl PromptLibrary {
|
||||||
format!(
|
format!(
|
||||||
"Model: {}",
|
"Model: {}",
|
||||||
current_model
|
current_model
|
||||||
.display_name()
|
.as_ref()
|
||||||
|
.map(|model| model
|
||||||
|
.name()
|
||||||
|
.0)
|
||||||
|
.unwrap_or_default()
|
||||||
),
|
),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
assistant_settings::AssistantSettings, humanize_token_count,
|
assistant_settings::AssistantSettings, humanize_token_count,
|
||||||
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
|
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
|
||||||
CompletionProvider,
|
LanguageModelCompletionProvider,
|
||||||
};
|
};
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use client::telemetry::Telemetry;
|
use client::telemetry::Telemetry;
|
||||||
|
@ -17,7 +17,9 @@ use gpui::{
|
||||||
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
|
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
|
||||||
};
|
};
|
||||||
use language::Buffer;
|
use language::Buffer;
|
||||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
|
use language_model::{
|
||||||
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
use settings::{update_settings_file, Settings};
|
use settings::{update_settings_file, Settings};
|
||||||
use std::{
|
use std::{
|
||||||
cmp,
|
cmp,
|
||||||
|
@ -215,8 +217,6 @@ impl TerminalInlineAssistant {
|
||||||
) -> Result<LanguageModelRequest> {
|
) -> Result<LanguageModelRequest> {
|
||||||
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
||||||
|
|
||||||
let model = CompletionProvider::global(cx).model();
|
|
||||||
|
|
||||||
let shell = std::env::var("SHELL").ok();
|
let shell = std::env::var("SHELL").ok();
|
||||||
let working_directory = assist
|
let working_directory = assist
|
||||||
.terminal
|
.terminal
|
||||||
|
@ -268,7 +268,6 @@ impl TerminalInlineAssistant {
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(LanguageModelRequest {
|
Ok(LanguageModelRequest {
|
||||||
model,
|
|
||||||
messages,
|
messages,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
|
@ -559,24 +558,39 @@ impl Render for PromptEditor {
|
||||||
PopoverMenu::new("model-switcher")
|
PopoverMenu::new("model-switcher")
|
||||||
.menu(move |cx| {
|
.menu(move |cx| {
|
||||||
ContextMenu::build(cx, |mut menu, cx| {
|
ContextMenu::build(cx, |mut menu, cx| {
|
||||||
for model in CompletionProvider::global(cx).available_models() {
|
for available_model in
|
||||||
|
LanguageModelRegistry::read_global(cx).available_models(cx)
|
||||||
|
{
|
||||||
menu = menu.custom_entry(
|
menu = menu.custom_entry(
|
||||||
{
|
{
|
||||||
let model = model.clone();
|
let model_name = available_model.name().0.clone();
|
||||||
|
let provider =
|
||||||
|
available_model.provider_name().0.clone();
|
||||||
move |_| {
|
move |_| {
|
||||||
Label::new(model.display_name())
|
h_flex()
|
||||||
.into_any_element()
|
.w_full()
|
||||||
|
.justify_between()
|
||||||
|
.child(Label::new(model_name.clone()))
|
||||||
|
.child(
|
||||||
|
div().ml_4().child(
|
||||||
|
Label::new(provider.clone())
|
||||||
|
.color(Color::Muted),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
let fs = fs.clone();
|
let fs = fs.clone();
|
||||||
let model = model.clone();
|
let model = available_model.clone();
|
||||||
move |cx| {
|
move |cx| {
|
||||||
let model = model.clone();
|
let model = model.clone();
|
||||||
update_settings_file::<AssistantSettings>(
|
update_settings_file::<AssistantSettings>(
|
||||||
fs.clone(),
|
fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| settings.set_model(model),
|
move |settings, _| {
|
||||||
|
settings.set_model(model)
|
||||||
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -595,9 +609,10 @@ impl Render for PromptEditor {
|
||||||
Tooltip::with_meta(
|
Tooltip::with_meta(
|
||||||
format!(
|
format!(
|
||||||
"Using {}",
|
"Using {}",
|
||||||
CompletionProvider::global(cx)
|
LanguageModelCompletionProvider::read_global(cx)
|
||||||
.model()
|
.active_model()
|
||||||
.display_name()
|
.map(|model| model.name().0)
|
||||||
|
.unwrap_or_else(|| "No model selected".into())
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
"Change Model",
|
"Change Model",
|
||||||
|
@ -748,7 +763,9 @@ impl PromptEditor {
|
||||||
})??;
|
})??;
|
||||||
|
|
||||||
let token_count = cx
|
let token_count = cx
|
||||||
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
|
.update(|cx| {
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||||
|
})?
|
||||||
.await?;
|
.await?;
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.token_count = Some(token_count);
|
this.token_count = Some(token_count);
|
||||||
|
@ -878,7 +895,7 @@ impl PromptEditor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
|
||||||
let model = CompletionProvider::global(cx).model();
|
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
|
||||||
let token_count = self.token_count?;
|
let token_count = self.token_count?;
|
||||||
let max_token_count = model.max_token_count();
|
let max_token_count = model.max_token_count();
|
||||||
|
|
||||||
|
@ -1023,8 +1040,12 @@ impl Codegen {
|
||||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||||
|
|
||||||
let telemetry = self.telemetry.clone();
|
let telemetry = self.telemetry.clone();
|
||||||
let model_telemetry_id = prompt.model.telemetry_id();
|
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||||
let response = CompletionProvider::global(cx).stream_completion(prompt, cx);
|
.active_model()
|
||||||
|
.map(|m| m.telemetry_id())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let response =
|
||||||
|
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
|
||||||
|
|
||||||
self.generation = cx.spawn(|this, mut cx| async move {
|
self.generation = cx.spawn(|this, mut cx| async move {
|
||||||
let response = response.await;
|
let response = response.await;
|
||||||
|
|
|
@ -90,6 +90,7 @@ git_hosting_providers.workspace = true
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
indoc.workspace = true
|
indoc.workspace = true
|
||||||
language = { workspace = true, features = ["test-support"] }
|
language = { workspace = true, features = ["test-support"] }
|
||||||
|
language_model = { workspace = true, features = ["test-support"] }
|
||||||
live_kit_client = { workspace = true, features = ["test-support"] }
|
live_kit_client = { workspace = true, features = ["test-support"] }
|
||||||
lsp = { workspace = true, features = ["test-support"] }
|
lsp = { workspace = true, features = ["test-support"] }
|
||||||
menu.workspace = true
|
menu.workspace = true
|
||||||
|
|
|
@ -157,6 +157,8 @@ impl TestServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
|
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
if cx.has_global::<SettingsStore>() {
|
if cx.has_global::<SettingsStore>() {
|
||||||
panic!("Same cx used to create two test clients")
|
panic!("Same cx used to create two test clients")
|
||||||
|
@ -265,7 +267,6 @@ impl TestServer {
|
||||||
git_hosting_provider_registry
|
git_hosting_provider_registry
|
||||||
.register_hosting_provider(Arc::new(git_hosting_providers::Github));
|
.register_hosting_provider(Arc::new(git_hosting_providers::Github));
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
|
||||||
let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
|
let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
|
||||||
let workspace_store = cx.new_model(|cx| WorkspaceStore::new(client.clone(), cx));
|
let workspace_store = cx.new_model(|cx| WorkspaceStore::new(client.clone(), cx));
|
||||||
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||||
|
@ -297,7 +298,8 @@ impl TestServer {
|
||||||
menu::init();
|
menu::init();
|
||||||
dev_server_projects::init(client.clone(), cx);
|
dev_server_projects::init(client.clone(), cx);
|
||||||
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
|
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
|
||||||
completion::FakeCompletionProvider::setup_test(cx);
|
language_model::LanguageModelRegistry::test(cx);
|
||||||
|
completion::init(cx);
|
||||||
assistant::context_store::init(&client);
|
assistant::context_store::init(&client);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -1107,9 +1107,11 @@ impl Panel for ChatPanel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
||||||
settings::update_settings_file::<ChatPanelSettings>(self.fs.clone(), cx, move |settings| {
|
settings::update_settings_file::<ChatPanelSettings>(
|
||||||
settings.dock = Some(position)
|
self.fs.clone(),
|
||||||
});
|
cx,
|
||||||
|
move |settings, _| settings.dock = Some(position),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size(&self, cx: &gpui::WindowContext) -> Pixels {
|
fn size(&self, cx: &gpui::WindowContext) -> Pixels {
|
||||||
|
|
|
@ -2806,7 +2806,7 @@ impl Panel for CollabPanel {
|
||||||
settings::update_settings_file::<CollaborationPanelSettings>(
|
settings::update_settings_file::<CollaborationPanelSettings>(
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| settings.dock = Some(position),
|
move |settings, _| settings.dock = Some(position),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -672,7 +672,7 @@ impl Panel for NotificationPanel {
|
||||||
settings::update_settings_file::<NotificationPanelSettings>(
|
settings::update_settings_file::<NotificationPanelSettings>(
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| settings.dock = Some(position),
|
move |settings, _| settings.dock = Some(position),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,34 +16,20 @@ doctest = false
|
||||||
test-support = [
|
test-support = [
|
||||||
"editor/test-support",
|
"editor/test-support",
|
||||||
"language/test-support",
|
"language/test-support",
|
||||||
|
"language_model/test-support",
|
||||||
"project/test-support",
|
"project/test-support",
|
||||||
"text/test-support",
|
"text/test-support",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anthropic = { workspace = true, features = ["schemars"] }
|
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
client.workspace = true
|
|
||||||
collections.workspace = true
|
|
||||||
editor.workspace = true
|
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
http.workspace = true
|
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
log.workspace = true
|
|
||||||
menu.workspace = true
|
|
||||||
ollama = { workspace = true, features = ["schemars"] }
|
|
||||||
open_ai = { workspace = true, features = ["schemars"] }
|
|
||||||
parking_lot.workspace = true
|
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
strum.workspace = true
|
|
||||||
theme.workspace = true
|
|
||||||
tiktoken-rs.workspace = true
|
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
util.workspace = true
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
|
@ -51,6 +37,7 @@ editor = { workspace = true, features = ["test-support"] }
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
language = { workspace = true, features = ["test-support"] }
|
language = { workspace = true, features = ["test-support"] }
|
||||||
project = { workspace = true, features = ["test-support"] }
|
project = { workspace = true, features = ["test-support"] }
|
||||||
|
language_model = { workspace = true, features = ["test-support"] }
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
text = { workspace = true, features = ["test-support"] }
|
text = { workspace = true, features = ["test-support"] }
|
||||||
unindent.workspace = true
|
unindent.workspace = true
|
||||||
|
|
|
@ -1,318 +0,0 @@
|
||||||
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
|
|
||||||
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
|
|
||||||
use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
|
||||||
use gpui::{AnyView, AppContext, Task, TextStyle, View};
|
|
||||||
use http::HttpClient;
|
|
||||||
use language_model::Role;
|
|
||||||
use settings::Settings;
|
|
||||||
use std::time::Duration;
|
|
||||||
use std::{env, sync::Arc};
|
|
||||||
use strum::IntoEnumIterator;
|
|
||||||
use theme::ThemeSettings;
|
|
||||||
use ui::prelude::*;
|
|
||||||
use util::ResultExt;
|
|
||||||
|
|
||||||
pub struct AnthropicCompletionProvider {
|
|
||||||
api_key: Option<String>,
|
|
||||||
api_url: String,
|
|
||||||
model: AnthropicModel,
|
|
||||||
http_client: Arc<dyn HttpClient>,
|
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
|
|
||||||
fn available_models(&self) -> Vec<LanguageModel> {
|
|
||||||
AnthropicModel::iter()
|
|
||||||
.map(LanguageModel::Anthropic)
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn settings_version(&self) -> usize {
|
|
||||||
self.settings_version
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_authenticated(&self) -> bool {
|
|
||||||
self.api_key.is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
if self.is_authenticated() {
|
|
||||||
Task::ready(Ok(()))
|
|
||||||
} else {
|
|
||||||
let api_url = self.api_url.clone();
|
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
|
|
||||||
api_key
|
|
||||||
} else {
|
|
||||||
let (_, api_key) = cx
|
|
||||||
.update(|cx| cx.read_credentials(&api_url))?
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
|
||||||
String::from_utf8(api_key)?
|
|
||||||
};
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
|
||||||
provider.api_key = Some(api_key);
|
|
||||||
});
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
delete_credentials.await.log_err();
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
|
||||||
provider.api_key = None;
|
|
||||||
});
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
|
||||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
|
||||||
.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model(&self) -> LanguageModel {
|
|
||||||
LanguageModel::Anthropic(self.model.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn count_tokens(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
let request = self.to_anthropic_request(request);
|
|
||||||
|
|
||||||
let http_client = self.http_client.clone();
|
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
let api_url = self.api_url.clone();
|
|
||||||
let low_speed_timeout = self.low_speed_timeout;
|
|
||||||
async move {
|
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
|
||||||
let request = stream_completion(
|
|
||||||
http_client.as_ref(),
|
|
||||||
&api_url,
|
|
||||||
&api_key,
|
|
||||||
request,
|
|
||||||
low_speed_timeout,
|
|
||||||
);
|
|
||||||
let response = request.await?;
|
|
||||||
let stream = response
|
|
||||||
.filter_map(|response| async move {
|
|
||||||
match response {
|
|
||||||
Ok(response) => match response {
|
|
||||||
anthropic::ResponseEvent::ContentBlockStart {
|
|
||||||
content_block, ..
|
|
||||||
} => match content_block {
|
|
||||||
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
|
||||||
},
|
|
||||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
|
||||||
match delta {
|
|
||||||
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
Err(error) => Some(Err(error)),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed();
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AnthropicCompletionProvider {
|
|
||||||
pub fn new(
|
|
||||||
model: AnthropicModel,
|
|
||||||
api_url: String,
|
|
||||||
http_client: Arc<dyn HttpClient>,
|
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
api_key: None,
|
|
||||||
api_url,
|
|
||||||
model,
|
|
||||||
http_client,
|
|
||||||
low_speed_timeout,
|
|
||||||
settings_version,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update(
|
|
||||||
&mut self,
|
|
||||||
model: AnthropicModel,
|
|
||||||
api_url: String,
|
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
|
||||||
) {
|
|
||||||
self.model = model;
|
|
||||||
self.api_url = api_url;
|
|
||||||
self.low_speed_timeout = low_speed_timeout;
|
|
||||||
self.settings_version = settings_version;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
|
|
||||||
request.preprocess_anthropic();
|
|
||||||
|
|
||||||
let model = match request.model {
|
|
||||||
LanguageModel::Anthropic(model) => model,
|
|
||||||
_ => self.model.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut system_message = String::new();
|
|
||||||
if request
|
|
||||||
.messages
|
|
||||||
.first()
|
|
||||||
.map_or(false, |message| message.role == Role::System)
|
|
||||||
{
|
|
||||||
system_message = request.messages.remove(0).content;
|
|
||||||
}
|
|
||||||
|
|
||||||
Request {
|
|
||||||
model,
|
|
||||||
messages: request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.map(|msg| RequestMessage {
|
|
||||||
role: match msg.role {
|
|
||||||
Role::User => anthropic::Role::User,
|
|
||||||
Role::Assistant => anthropic::Role::Assistant,
|
|
||||||
Role::System => unreachable!("filtered out by preprocess_request"),
|
|
||||||
},
|
|
||||||
content: msg.content.clone(),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
stream: true,
|
|
||||||
system: system_message,
|
|
||||||
max_tokens: 4092,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct AuthenticationPrompt {
|
|
||||||
api_key: View<Editor>,
|
|
||||||
api_url: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AuthenticationPrompt {
|
|
||||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
|
||||||
Self {
|
|
||||||
api_key: cx.new_view(|cx| {
|
|
||||||
let mut editor = Editor::single_line(cx);
|
|
||||||
editor.set_placeholder_text(
|
|
||||||
"sk-000000000000000000000000000000000000000000000000",
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
editor
|
|
||||||
}),
|
|
||||||
api_url,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
|
||||||
let api_key = self.api_key.read(cx).text(cx);
|
|
||||||
if api_key.is_empty() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
|
||||||
cx.spawn(|_, mut cx| async move {
|
|
||||||
write_credentials.await?;
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
|
|
||||||
provider.api_key = Some(api_key);
|
|
||||||
});
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.detach_and_log_err(cx);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
|
||||||
let settings = ThemeSettings::get_global(cx);
|
|
||||||
let text_style = TextStyle {
|
|
||||||
color: cx.theme().colors().text,
|
|
||||||
font_family: settings.ui_font.family.clone(),
|
|
||||||
font_features: settings.ui_font.features.clone(),
|
|
||||||
font_size: rems(0.875).into(),
|
|
||||||
font_weight: settings.ui_font.weight,
|
|
||||||
line_height: relative(1.3),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
EditorElement::new(
|
|
||||||
&self.api_key,
|
|
||||||
EditorStyle {
|
|
||||||
background: cx.theme().colors().editor_background,
|
|
||||||
local_player: cx.theme().players().local(),
|
|
||||||
text: text_style,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Render for AuthenticationPrompt {
|
|
||||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
|
||||||
const INSTRUCTIONS: [&str; 4] = [
|
|
||||||
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
|
||||||
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
|
||||||
"",
|
|
||||||
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
|
||||||
];
|
|
||||||
|
|
||||||
v_flex()
|
|
||||||
.p_4()
|
|
||||||
.size_full()
|
|
||||||
.on_action(cx.listener(Self::save_api_key))
|
|
||||||
.children(
|
|
||||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
|
||||||
)
|
|
||||||
.child(
|
|
||||||
h_flex()
|
|
||||||
.w_full()
|
|
||||||
.my_2()
|
|
||||||
.px_2()
|
|
||||||
.py_1()
|
|
||||||
.bg(cx.theme().colors().editor_background)
|
|
||||||
.rounded_md()
|
|
||||||
.child(self.render_api_key_editor(cx)),
|
|
||||||
)
|
|
||||||
.child(
|
|
||||||
Label::new(
|
|
||||||
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
|
||||||
)
|
|
||||||
.size(LabelSize::Small),
|
|
||||||
)
|
|
||||||
.child(
|
|
||||||
h_flex()
|
|
||||||
.gap_2()
|
|
||||||
.child(Label::new("Click on").size(LabelSize::Small))
|
|
||||||
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
|
||||||
.child(
|
|
||||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.into_any()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,214 +0,0 @@
|
||||||
use crate::{
|
|
||||||
count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
|
|
||||||
LanguageModelRequest,
|
|
||||||
};
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use client::{proto, Client};
|
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
|
||||||
use gpui::{AnyView, AppContext, Task};
|
|
||||||
use language_model::CloudModel;
|
|
||||||
use std::{future, sync::Arc};
|
|
||||||
use strum::IntoEnumIterator;
|
|
||||||
use ui::prelude::*;
|
|
||||||
|
|
||||||
pub struct CloudCompletionProvider {
|
|
||||||
client: Arc<Client>,
|
|
||||||
model: CloudModel,
|
|
||||||
settings_version: usize,
|
|
||||||
status: client::Status,
|
|
||||||
_maintain_client_status: Task<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CloudCompletionProvider {
|
|
||||||
pub fn new(
|
|
||||||
model: CloudModel,
|
|
||||||
client: Arc<Client>,
|
|
||||||
settings_version: usize,
|
|
||||||
cx: &mut AppContext,
|
|
||||||
) -> Self {
|
|
||||||
let mut status_rx = client.status();
|
|
||||||
let status = *status_rx.borrow();
|
|
||||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
|
||||||
while let Some(status) = status_rx.next().await {
|
|
||||||
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, Self>(|provider| {
|
|
||||||
provider.status = status;
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
|
||||||
Self {
|
|
||||||
client,
|
|
||||||
model,
|
|
||||||
settings_version,
|
|
||||||
status,
|
|
||||||
_maintain_client_status: maintain_client_status,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
|
|
||||||
self.model = model;
|
|
||||||
self.settings_version = settings_version;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider for CloudCompletionProvider {
|
|
||||||
fn available_models(&self) -> Vec<LanguageModel> {
|
|
||||||
let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
|
|
||||||
Some(self.model.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
CloudModel::iter()
|
|
||||||
.filter_map(move |model| {
|
|
||||||
if let CloudModel::Custom { .. } = model {
|
|
||||||
custom_model.take()
|
|
||||||
} else {
|
|
||||||
Some(model)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.map(LanguageModel::Cloud)
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn settings_version(&self) -> usize {
|
|
||||||
self.settings_version
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_authenticated(&self) -> bool {
|
|
||||||
self.status.is_connected()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
let client = self.client.clone();
|
|
||||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
|
||||||
cx.new_view(|_cx| AuthenticationPrompt).into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
Task::ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model(&self) -> LanguageModel {
|
|
||||||
LanguageModel::Cloud(self.model.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn count_tokens(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
|
||||||
match &request.model {
|
|
||||||
LanguageModel::Cloud(CloudModel::Gpt4)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
|
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
|
||||||
}
|
|
||||||
LanguageModel::Cloud(
|
|
||||||
CloudModel::Claude3_5Sonnet
|
|
||||||
| CloudModel::Claude3Opus
|
|
||||||
| CloudModel::Claude3Sonnet
|
|
||||||
| CloudModel::Claude3Haiku,
|
|
||||||
) => {
|
|
||||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
|
||||||
}
|
|
||||||
LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
|
|
||||||
if name.starts_with("anthropic/") {
|
|
||||||
// Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation.
|
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
|
||||||
} else {
|
|
||||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
|
||||||
model: name.clone(),
|
|
||||||
messages: request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.map(|message| message.to_proto())
|
|
||||||
.collect(),
|
|
||||||
});
|
|
||||||
async move {
|
|
||||||
let response = request.await?;
|
|
||||||
Ok(response.token_count as usize)
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion(
|
|
||||||
&self,
|
|
||||||
mut request: LanguageModelRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
request.preprocess();
|
|
||||||
|
|
||||||
let request = proto::CompleteWithLanguageModel {
|
|
||||||
model: request.model.id().to_string(),
|
|
||||||
messages: request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.map(|message| message.to_proto())
|
|
||||||
.collect(),
|
|
||||||
stop: request.stop,
|
|
||||||
temperature: request.temperature,
|
|
||||||
tools: Vec::new(),
|
|
||||||
tool_choice: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.client
|
|
||||||
.request_stream(request)
|
|
||||||
.map_ok(|stream| {
|
|
||||||
stream
|
|
||||||
.filter_map(|response| async move {
|
|
||||||
match response {
|
|
||||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
|
||||||
Err(error) => Some(Err(error)),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
})
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct AuthenticationPrompt;
|
|
||||||
|
|
||||||
impl Render for AuthenticationPrompt {
|
|
||||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
|
||||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
|
||||||
|
|
||||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
|
||||||
v_flex()
|
|
||||||
.gap_2()
|
|
||||||
.child(
|
|
||||||
Button::new("sign_in", "Sign in")
|
|
||||||
.icon_color(Color::Muted)
|
|
||||||
.icon(IconName::Github)
|
|
||||||
.icon_position(IconPosition::Start)
|
|
||||||
.style(ButtonStyle::Filled)
|
|
||||||
.full_width()
|
|
||||||
.on_click(|_, cx| {
|
|
||||||
CompletionProvider::global(cx)
|
|
||||||
.authenticate(cx)
|
|
||||||
.detach_and_log_err(cx);
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.child(
|
|
||||||
div().flex().w_full().items_center().child(
|
|
||||||
Label::new("Sign in to enable collaboration.")
|
|
||||||
.color(Color::Muted)
|
|
||||||
.size(LabelSize::Small),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,31 +1,37 @@
|
||||||
mod anthropic;
|
use anyhow::{anyhow, Result};
|
||||||
mod cloud;
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
||||||
mod fake;
|
use language_model::{
|
||||||
mod ollama;
|
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
|
||||||
mod open_ai;
|
LanguageModelRequest,
|
||||||
|
};
|
||||||
pub use anthropic::*;
|
|
||||||
use anyhow::Result;
|
|
||||||
use client::Client;
|
|
||||||
pub use cloud::*;
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
|
||||||
pub use fake::*;
|
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
|
||||||
use gpui::{AnyView, AppContext, Task, WindowContext};
|
|
||||||
use language_model::{LanguageModel, LanguageModelRequest};
|
|
||||||
pub use ollama::*;
|
|
||||||
pub use open_ai::*;
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||||
use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
|
use std::{pin::Pin, sync::Arc, task::Poll};
|
||||||
|
use ui::Context;
|
||||||
|
|
||||||
pub struct CompletionResponse {
|
pub fn init(cx: &mut AppContext) {
|
||||||
inner: BoxStream<'static, Result<String>>,
|
let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
|
||||||
|
cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
|
||||||
|
|
||||||
|
impl Global for GlobalLanguageModelCompletionProvider {}
|
||||||
|
|
||||||
|
pub struct LanguageModelCompletionProvider {
|
||||||
|
active_provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||||
|
active_model: Option<Arc<dyn LanguageModel>>,
|
||||||
|
request_limiter: Arc<Semaphore>,
|
||||||
|
}
|
||||||
|
|
||||||
|
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
||||||
|
|
||||||
|
pub struct LanguageModelCompletionResponse {
|
||||||
|
pub inner: BoxStream<'static, Result<String>>,
|
||||||
_lock: SemaphoreGuardArc,
|
_lock: SemaphoreGuardArc,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl futures::Stream for CompletionResponse {
|
impl futures::Stream for LanguageModelCompletionResponse {
|
||||||
type Item = Result<String>;
|
type Item = Result<String>;
|
||||||
|
|
||||||
fn poll_next(
|
fn poll_next(
|
||||||
|
@ -36,73 +42,96 @@ impl futures::Stream for CompletionResponse {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait LanguageModelCompletionProvider: Send + Sync {
|
impl LanguageModelCompletionProvider {
|
||||||
fn available_models(&self) -> Vec<LanguageModel>;
|
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||||
fn settings_version(&self) -> usize;
|
cx.global::<GlobalLanguageModelCompletionProvider>()
|
||||||
fn is_authenticated(&self) -> bool;
|
.0
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
.clone()
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
}
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
|
||||||
fn model(&self) -> LanguageModel;
|
|
||||||
fn count_tokens(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> BoxFuture<'static, Result<usize>>;
|
|
||||||
fn stream_completion(
|
|
||||||
&self,
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
|
||||||
|
|
||||||
fn as_any_mut(&mut self) -> &mut dyn Any;
|
pub fn read_global(cx: &AppContext) -> &Self {
|
||||||
}
|
cx.global::<GlobalLanguageModelCompletionProvider>()
|
||||||
|
.0
|
||||||
|
.read(cx)
|
||||||
|
}
|
||||||
|
|
||||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub fn test(cx: &mut AppContext) {
|
||||||
|
let provider = cx.new_model(|cx| {
|
||||||
|
let mut this = Self::new(cx);
|
||||||
|
let available_model = LanguageModelRegistry::read_global(cx)
|
||||||
|
.available_models(cx)
|
||||||
|
.first()
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
this.set_active_model(available_model, cx);
|
||||||
|
this
|
||||||
|
});
|
||||||
|
cx.set_global(GlobalLanguageModelCompletionProvider(provider));
|
||||||
|
}
|
||||||
|
|
||||||
pub struct CompletionProvider {
|
pub fn new(cx: &mut ModelContext<Self>) -> Self {
|
||||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
|
||||||
client: Option<Arc<Client>>,
|
cx.notify();
|
||||||
request_limiter: Arc<Semaphore>,
|
})
|
||||||
}
|
.detach();
|
||||||
|
|
||||||
impl CompletionProvider {
|
|
||||||
pub fn new(
|
|
||||||
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
|
||||||
client: Option<Arc<Client>>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
provider,
|
active_provider: None,
|
||||||
client,
|
active_model: None,
|
||||||
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
|
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn available_models(&self) -> Vec<LanguageModel> {
|
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||||
self.provider.read().available_models()
|
self.active_provider.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn settings_version(&self) -> usize {
|
pub fn set_active_provider(
|
||||||
self.provider.read().settings_version()
|
&mut self,
|
||||||
|
provider_name: LanguageModelProviderName,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
|
||||||
|
self.active_model = None;
|
||||||
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_authenticated(&self) -> bool {
|
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||||
self.provider.read().is_authenticated()
|
self.active_model.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
|
||||||
|
if self.active_model.as_ref().map_or(false, |m| {
|
||||||
|
m.id() == model.id() && m.provider_name() == model.provider_name()
|
||||||
|
}) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.active_provider =
|
||||||
|
LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
|
||||||
|
self.active_model = Some(model);
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||||
|
self.active_provider
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
self.provider.read().authenticate(cx)
|
self.active_provider
|
||||||
}
|
.as_ref()
|
||||||
|
.map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
|
||||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
|
||||||
self.provider.read().authentication_prompt(cx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
self.provider.read().reset_credentials(cx)
|
self.active_provider
|
||||||
}
|
.as_ref()
|
||||||
|
.map_or(Task::ready(Ok(())), |provider| {
|
||||||
pub fn model(&self) -> LanguageModel {
|
provider.reset_credentials(cx)
|
||||||
self.provider.read().model()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_tokens(
|
pub fn count_tokens(
|
||||||
|
@ -110,25 +139,31 @@ impl CompletionProvider {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
self.provider.read().count_tokens(request, cx)
|
if let Some(model) = self.active_model() {
|
||||||
|
model.count_tokens(request, cx)
|
||||||
|
} else {
|
||||||
|
std::future::ready(Err(anyhow!("No active model set"))).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stream_completion(
|
pub fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> Task<Result<CompletionResponse>> {
|
) -> Task<Result<LanguageModelCompletionResponse>> {
|
||||||
|
if let Some(language_model) = self.active_model() {
|
||||||
let rate_limiter = self.request_limiter.clone();
|
let rate_limiter = self.request_limiter.clone();
|
||||||
let provider = self.provider.clone();
|
cx.spawn(|cx| async move {
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
let lock = rate_limiter.acquire_arc().await;
|
let lock = rate_limiter.acquire_arc().await;
|
||||||
let response = provider.read().stream_completion(request);
|
let response = language_model.stream_completion(request, &cx).await?;
|
||||||
let response = response.await?;
|
Ok(LanguageModelCompletionResponse {
|
||||||
Ok(CompletionResponse {
|
|
||||||
inner: response,
|
inner: response,
|
||||||
_lock: lock,
|
_lock: lock,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
Task::ready(Err(anyhow!("No active model set")))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
|
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
|
||||||
|
@ -143,63 +178,43 @@ impl CompletionProvider {
|
||||||
Ok(completion)
|
Ok(completion)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_provider(
|
|
||||||
&mut self,
|
|
||||||
get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
|
|
||||||
) {
|
|
||||||
if let Some(client) = &self.client {
|
|
||||||
self.provider = get_provider(Arc::clone(client));
|
|
||||||
} else {
|
|
||||||
log::warn!("completion provider cannot be updated because its client was not set");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl gpui::Global for CompletionProvider {}
|
|
||||||
|
|
||||||
impl CompletionProvider {
|
|
||||||
pub fn global(cx: &AppContext) -> &Self {
|
|
||||||
cx.global::<Self>()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
|
|
||||||
&mut self,
|
|
||||||
update: impl FnOnce(&mut T) -> R,
|
|
||||||
) -> Option<R> {
|
|
||||||
let mut provider = self.provider.write();
|
|
||||||
if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
|
|
||||||
Some(update(provider))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::sync::Arc;
|
use futures::StreamExt;
|
||||||
|
|
||||||
use gpui::AppContext;
|
use gpui::AppContext;
|
||||||
use parking_lot::RwLock;
|
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt;
|
use ui::Context;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
|
LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use language_model::LanguageModelRegistry;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_rate_limiting(cx: &mut AppContext) {
|
fn test_rate_limiting(cx: &mut AppContext) {
|
||||||
SettingsStore::test(cx);
|
SettingsStore::test(cx);
|
||||||
let fake_provider = FakeCompletionProvider::setup_test(cx);
|
let fake_provider = LanguageModelRegistry::test(cx);
|
||||||
|
|
||||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
|
let model = LanguageModelRegistry::read_global(cx)
|
||||||
|
.available_models(cx)
|
||||||
|
.first()
|
||||||
|
.cloned()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let provider = cx.new_model(|cx| {
|
||||||
|
let mut provider = LanguageModelCompletionProvider::new(cx);
|
||||||
|
provider.set_active_model(model.clone(), cx);
|
||||||
|
provider
|
||||||
|
});
|
||||||
|
|
||||||
|
let fake_model = fake_provider.test_model();
|
||||||
|
|
||||||
// Enqueue some requests
|
// Enqueue some requests
|
||||||
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
|
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
|
||||||
let response = provider.stream_completion(
|
let response = provider.read(cx).stream_completion(
|
||||||
LanguageModelRequest {
|
LanguageModelRequest {
|
||||||
temperature: i as f32 / 10.0,
|
temperature: i as f32 / 10.0,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -216,23 +231,18 @@ mod tests {
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
cx.background_executor().run_until_parked();
|
cx.background_executor().run_until_parked();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
fake_provider.completion_count(),
|
fake_model.completion_count(),
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||||
);
|
);
|
||||||
|
|
||||||
// Get the first completion request that is in flight and mark it as completed.
|
// Get the first completion request that is in flight and mark it as completed.
|
||||||
let completion = fake_provider
|
let completion = fake_model.pending_completions().into_iter().next().unwrap();
|
||||||
.pending_completions()
|
fake_model.finish_completion(&completion);
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
fake_provider.finish_completion(&completion);
|
|
||||||
|
|
||||||
// Ensure that the number of in-flight completion requests is reduced.
|
// Ensure that the number of in-flight completion requests is reduced.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
fake_provider.completion_count(),
|
fake_model.completion_count(),
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -240,32 +250,32 @@ mod tests {
|
||||||
|
|
||||||
// Ensure that another completion request was allowed to acquire the lock.
|
// Ensure that another completion request was allowed to acquire the lock.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
fake_provider.completion_count(),
|
fake_model.completion_count(),
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS
|
MAX_CONCURRENT_COMPLETION_REQUESTS
|
||||||
);
|
);
|
||||||
|
|
||||||
// Mark all completion requests as finished that are in flight.
|
// Mark all completion requests as finished that are in flight.
|
||||||
for request in fake_provider.pending_completions() {
|
for request in fake_model.pending_completions() {
|
||||||
fake_provider.finish_completion(&request);
|
fake_model.finish_completion(&request);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(fake_provider.completion_count(), 0);
|
assert_eq!(fake_model.completion_count(), 0);
|
||||||
|
|
||||||
// Wait until the background tasks acquire the lock again.
|
// Wait until the background tasks acquire the lock again.
|
||||||
cx.background_executor().run_until_parked();
|
cx.background_executor().run_until_parked();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
fake_provider.completion_count(),
|
fake_model.completion_count(),
|
||||||
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
|
||||||
);
|
);
|
||||||
|
|
||||||
// Finish all remaining completion requests.
|
// Finish all remaining completion requests.
|
||||||
for request in fake_provider.pending_completions() {
|
for request in fake_model.pending_completions() {
|
||||||
fake_provider.finish_completion(&request);
|
fake_model.finish_completion(&request);
|
||||||
}
|
}
|
||||||
|
|
||||||
cx.background_executor().run_until_parked();
|
cx.background_executor().run_until_parked();
|
||||||
|
|
||||||
assert_eq!(fake_provider.completion_count(), 0);
|
assert_eq!(fake_model.completion_count(), 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,115 +0,0 @@
|
||||||
use anyhow::Result;
|
|
||||||
use collections::HashMap;
|
|
||||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
|
||||||
use gpui::{AnyView, AppContext, Task};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use ui::WindowContext;
|
|
||||||
|
|
||||||
use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
|
|
||||||
|
|
||||||
#[derive(Clone, Default)]
|
|
||||||
pub struct FakeCompletionProvider {
|
|
||||||
current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FakeCompletionProvider {
|
|
||||||
pub fn setup_test(cx: &mut AppContext) -> Self {
|
|
||||||
use crate::CompletionProvider;
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
|
|
||||||
let this = Self::default();
|
|
||||||
let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
|
|
||||||
cx.set_global(provider);
|
|
||||||
this
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
|
||||||
self.current_completion_txs
|
|
||||||
.lock()
|
|
||||||
.keys()
|
|
||||||
.map(|k| serde_json::from_str(k).unwrap())
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn completion_count(&self) -> usize {
|
|
||||||
self.current_completion_txs.lock().len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
|
||||||
let json = serde_json::to_string(request).unwrap();
|
|
||||||
self.current_completion_txs
|
|
||||||
.lock()
|
|
||||||
.get(&json)
|
|
||||||
.unwrap()
|
|
||||||
.unbounded_send(chunk)
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_last_completion_chunk(&self, chunk: String) {
|
|
||||||
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
|
||||||
self.current_completion_txs
|
|
||||||
.lock()
|
|
||||||
.remove(&serde_json::to_string(request).unwrap())
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn finish_last_completion(&self) {
|
|
||||||
self.finish_completion(self.pending_completions().last().unwrap());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider for FakeCompletionProvider {
|
|
||||||
fn available_models(&self) -> Vec<LanguageModel> {
|
|
||||||
vec![LanguageModel::default()]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn settings_version(&self) -> usize {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_authenticated(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
Task::ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
Task::ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model(&self) -> LanguageModel {
|
|
||||||
LanguageModel::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn count_tokens(
|
|
||||||
&self,
|
|
||||||
_request: LanguageModelRequest,
|
|
||||||
_cx: &AppContext,
|
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
|
||||||
futures::future::ready(Ok(0)).boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion(
|
|
||||||
&self,
|
|
||||||
_request: LanguageModelRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
let (tx, rx) = mpsc::unbounded();
|
|
||||||
self.current_completion_txs
|
|
||||||
.lock()
|
|
||||||
.insert(serde_json::to_string(&_request).unwrap(), tx);
|
|
||||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -10384,7 +10384,7 @@ impl Editor {
|
||||||
};
|
};
|
||||||
let fs = workspace.read(cx).app_state().fs.clone();
|
let fs = workspace.read(cx).app_state().fs.clone();
|
||||||
let current_show = TabBarSettings::get_global(cx).show;
|
let current_show = TabBarSettings::get_global(cx).show;
|
||||||
update_settings_file::<TabBarSettings>(fs, cx, move |setting| {
|
update_settings_file::<TabBarSettings>(fs, cx, move |setting, _| {
|
||||||
setting.show = Some(!current_show);
|
setting.show = Some(!current_show);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,7 +178,7 @@ impl PickerDelegate for ExtensionVersionSelectorDelegate {
|
||||||
|
|
||||||
update_settings_file::<ExtensionSettings>(self.fs.clone(), cx, {
|
update_settings_file::<ExtensionSettings>(self.fs.clone(), cx, {
|
||||||
let extension_id = extension_id.clone();
|
let extension_id = extension_id.clone();
|
||||||
move |settings| {
|
move |settings, _| {
|
||||||
settings.auto_update_extensions.insert(extension_id, false);
|
settings.auto_update_extensions.insert(extension_id, false);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -910,7 +910,7 @@ impl ExtensionsPage {
|
||||||
if let Some(workspace) = self.workspace.upgrade() {
|
if let Some(workspace) = self.workspace.upgrade() {
|
||||||
let fs = workspace.read(cx).app_state().fs.clone();
|
let fs = workspace.read(cx).app_state().fs.clone();
|
||||||
let selection = *selection;
|
let selection = *selection;
|
||||||
settings::update_settings_file::<T>(fs, cx, move |settings| {
|
settings::update_settings_file::<T>(fs, cx, move |settings, _| {
|
||||||
let value = match selection {
|
let value = match selection {
|
||||||
Selection::Unselected => false,
|
Selection::Unselected => false,
|
||||||
Selection::Selected => true,
|
Selection::Selected => true,
|
||||||
|
|
|
@ -29,6 +29,11 @@ impl FeatureFlag for Remoting {
|
||||||
const NAME: &'static str = "remoting";
|
const NAME: &'static str = "remoting";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct LanguageModels {}
|
||||||
|
impl FeatureFlag for LanguageModels {
|
||||||
|
const NAME: &'static str = "language-models";
|
||||||
|
}
|
||||||
|
|
||||||
pub struct TerminalInlineAssist {}
|
pub struct TerminalInlineAssist {}
|
||||||
impl FeatureFlag for TerminalInlineAssist {
|
impl FeatureFlag for TerminalInlineAssist {
|
||||||
const NAME: &'static str = "terminal-inline-assist";
|
const NAME: &'static str = "terminal-inline-assist";
|
||||||
|
@ -65,6 +70,10 @@ pub trait FeatureFlagAppExt {
|
||||||
fn set_staff(&mut self, staff: bool);
|
fn set_staff(&mut self, staff: bool);
|
||||||
fn has_flag<T: FeatureFlag>(&self) -> bool;
|
fn has_flag<T: FeatureFlag>(&self) -> bool;
|
||||||
fn is_staff(&self) -> bool;
|
fn is_staff(&self) -> bool;
|
||||||
|
|
||||||
|
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
|
||||||
|
where
|
||||||
|
F: Fn(bool, &mut AppContext) + 'static;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FeatureFlagAppExt for AppContext {
|
impl FeatureFlagAppExt for AppContext {
|
||||||
|
@ -90,4 +99,14 @@ impl FeatureFlagAppExt for AppContext {
|
||||||
.map(|flags| flags.staff)
|
.map(|flags| flags.staff)
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
|
||||||
|
where
|
||||||
|
F: Fn(bool, &mut AppContext) + 'static,
|
||||||
|
{
|
||||||
|
self.observe_global::<FeatureFlags>(move |cx| {
|
||||||
|
let feature_flags = cx.global::<FeatureFlags>();
|
||||||
|
callback(feature_flags.has_flag(<T as FeatureFlag>::NAME), cx);
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -420,7 +420,7 @@ async fn configure_disabled_globs(
|
||||||
fn toggle_inline_completions_globally(fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
fn toggle_inline_completions_globally(fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
||||||
let show_inline_completions =
|
let show_inline_completions =
|
||||||
all_language_settings(None, cx).inline_completions_enabled(None, None);
|
all_language_settings(None, cx).inline_completions_enabled(None, None);
|
||||||
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
|
update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| {
|
||||||
file.defaults.show_inline_completions = Some(!show_inline_completions)
|
file.defaults.show_inline_completions = Some(!show_inline_completions)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -432,7 +432,7 @@ fn toggle_inline_completions_for_language(
|
||||||
) {
|
) {
|
||||||
let show_inline_completions =
|
let show_inline_completions =
|
||||||
all_language_settings(None, cx).inline_completions_enabled(Some(&language), None);
|
all_language_settings(None, cx).inline_completions_enabled(Some(&language), None);
|
||||||
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
|
update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| {
|
||||||
file.languages
|
file.languages
|
||||||
.entry(language.name())
|
.entry(language.name())
|
||||||
.or_default()
|
.or_default()
|
||||||
|
@ -441,7 +441,7 @@ fn toggle_inline_completions_for_language(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hide_copilot(fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
fn hide_copilot(fs: Arc<dyn Fs>, cx: &mut AppContext) {
|
||||||
update_settings_file::<AllLanguageSettings>(fs, cx, move |file| {
|
update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| {
|
||||||
file.features
|
file.features
|
||||||
.get_or_insert(Default::default())
|
.get_or_insert(Default::default())
|
||||||
.inline_completion_provider = Some(InlineCompletionProvider::None);
|
.inline_completion_provider = Some(InlineCompletionProvider::None);
|
||||||
|
|
|
@ -22,12 +22,27 @@ test-support = [
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anthropic = { workspace = true, features = ["schemars"] }
|
anthropic = { workspace = true, features = ["schemars"] }
|
||||||
|
anyhow.workspace = true
|
||||||
|
client.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
|
editor.workspace = true
|
||||||
|
feature_flags.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
|
http.workspace = true
|
||||||
|
menu.workspace = true
|
||||||
ollama = { workspace = true, features = ["schemars"] }
|
ollama = { workspace = true, features = ["schemars"] }
|
||||||
open_ai = { workspace = true, features = ["schemars"] }
|
open_ai = { workspace = true, features = ["schemars"] }
|
||||||
|
proto = { workspace = true, features = ["test-support"] }
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
settings.workspace = true
|
||||||
strum.workspace = true
|
strum.workspace = true
|
||||||
proto = { workspace = true, features = ["test-support"] }
|
theme.workspace = true
|
||||||
|
tiktoken-rs.workspace = true
|
||||||
|
ui.workspace = true
|
||||||
|
util.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
|
|
|
@ -1,7 +1,84 @@
|
||||||
mod model;
|
mod model;
|
||||||
|
pub mod provider;
|
||||||
|
mod registry;
|
||||||
mod request;
|
mod request;
|
||||||
mod role;
|
mod role;
|
||||||
|
pub mod settings;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use client::Client;
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream};
|
||||||
|
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
|
||||||
|
|
||||||
pub use model::*;
|
pub use model::*;
|
||||||
|
pub use registry::*;
|
||||||
pub use request::*;
|
pub use request::*;
|
||||||
pub use role::*;
|
pub use role::*;
|
||||||
|
|
||||||
|
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
|
settings::init(cx);
|
||||||
|
registry::init(client, cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModel: Send + Sync {
|
||||||
|
fn id(&self) -> LanguageModelId;
|
||||||
|
fn name(&self) -> LanguageModelName;
|
||||||
|
fn provider_name(&self) -> LanguageModelProviderName;
|
||||||
|
fn telemetry_id(&self) -> String;
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize;
|
||||||
|
|
||||||
|
fn count_tokens(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>>;
|
||||||
|
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModelProvider: 'static {
|
||||||
|
fn name(&self) -> LanguageModelProviderName;
|
||||||
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
|
||||||
|
fn is_authenticated(&self, cx: &AppContext) -> bool;
|
||||||
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||||
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
|
||||||
|
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModelProviderState: 'static {
|
||||||
|
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||||
|
pub struct LanguageModelId(pub SharedString);
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||||
|
pub struct LanguageModelName(pub SharedString);
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
|
||||||
|
pub struct LanguageModelProviderName(pub SharedString);
|
||||||
|
|
||||||
|
impl From<String> for LanguageModelId {
|
||||||
|
fn from(value: String) -> Self {
|
||||||
|
Self(SharedString::from(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for LanguageModelName {
|
||||||
|
fn from(value: String) -> Self {
|
||||||
|
Self(SharedString::from(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for LanguageModelProviderName {
|
||||||
|
fn from(value: String) -> Self {
|
||||||
|
Self(SharedString::from(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
pub use anthropic::Model as AnthropicModel;
|
pub use anthropic::Model as AnthropicModel;
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
pub use ollama::Model as OllamaModel;
|
pub use ollama::Model as OllamaModel;
|
||||||
pub use open_ai::Model as OpenAiModel;
|
pub use open_ai::Model as OpenAiModel;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
|
@ -38,6 +39,23 @@ pub enum CloudModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CloudModel {
|
impl CloudModel {
|
||||||
|
pub fn from_id(value: &str) -> Result<Self> {
|
||||||
|
match value {
|
||||||
|
"gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo),
|
||||||
|
"gpt-4" => Ok(Self::Gpt4),
|
||||||
|
"gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo),
|
||||||
|
"gpt-4o" => Ok(Self::Gpt4Omni),
|
||||||
|
"gpt-4o-mini" => Ok(Self::Gpt4OmniMini),
|
||||||
|
"claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
|
||||||
|
"claude-3-opus" => Ok(Self::Claude3Opus),
|
||||||
|
"claude-3-sonnet" => Ok(Self::Claude3Sonnet),
|
||||||
|
"claude-3-haiku" => Ok(Self::Claude3Haiku),
|
||||||
|
"gemini-1.5-pro" => Ok(Self::Gemini15Pro),
|
||||||
|
"gemini-1.5-flash" => Ok(Self::Gemini15Flash),
|
||||||
|
_ => Err(anyhow!("invalid model id")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> &str {
|
pub fn id(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||||
|
|
|
@ -4,57 +4,3 @@ pub use anthropic::Model as AnthropicModel;
|
||||||
pub use cloud_model::*;
|
pub use cloud_model::*;
|
||||||
pub use ollama::Model as OllamaModel;
|
pub use ollama::Model as OllamaModel;
|
||||||
pub use open_ai::Model as OpenAiModel;
|
pub use open_ai::Model as OpenAiModel;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
|
||||||
pub enum LanguageModel {
|
|
||||||
Cloud(CloudModel),
|
|
||||||
OpenAi(OpenAiModel),
|
|
||||||
Anthropic(AnthropicModel),
|
|
||||||
Ollama(OllamaModel),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for LanguageModel {
|
|
||||||
fn default() -> Self {
|
|
||||||
LanguageModel::Cloud(CloudModel::default())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModel {
|
|
||||||
pub fn telemetry_id(&self) -> String {
|
|
||||||
match self {
|
|
||||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
|
||||||
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
|
|
||||||
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
|
|
||||||
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn display_name(&self) -> String {
|
|
||||||
match self {
|
|
||||||
LanguageModel::OpenAi(model) => model.display_name().into(),
|
|
||||||
LanguageModel::Anthropic(model) => model.display_name().into(),
|
|
||||||
LanguageModel::Cloud(model) => model.display_name().into(),
|
|
||||||
LanguageModel::Ollama(model) => model.display_name().into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn max_token_count(&self) -> usize {
|
|
||||||
match self {
|
|
||||||
LanguageModel::OpenAi(model) => model.max_token_count(),
|
|
||||||
LanguageModel::Anthropic(model) => model.max_token_count(),
|
|
||||||
LanguageModel::Cloud(model) => model.max_token_count(),
|
|
||||||
LanguageModel::Ollama(model) => model.max_token_count(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn id(&self) -> &str {
|
|
||||||
match self {
|
|
||||||
LanguageModel::OpenAi(model) => model.id(),
|
|
||||||
LanguageModel::Anthropic(model) => model.id(),
|
|
||||||
LanguageModel::Cloud(model) => model.id(),
|
|
||||||
LanguageModel::Ollama(model) => model.id(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
6
crates/language_model/src/provider.rs
Normal file
6
crates/language_model/src/provider.rs
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
pub mod anthropic;
|
||||||
|
pub mod cloud;
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub mod fake;
|
||||||
|
pub mod ollama;
|
||||||
|
pub mod open_ai;
|
454
crates/language_model/src/provider/anthropic.rs
Normal file
454
crates/language_model/src/provider/anthropic.rs
Normal file
|
@ -0,0 +1,454 @@
|
||||||
|
use anthropic::{stream_completion, Request, RequestMessage};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use collections::HashMap;
|
||||||
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
use gpui::{
|
||||||
|
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
|
||||||
|
WhiteSpace,
|
||||||
|
};
|
||||||
|
use http::HttpClient;
|
||||||
|
use settings::{Settings, SettingsStore};
|
||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
|
use theme::ThemeSettings;
|
||||||
|
use ui::prelude::*;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
|
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
|
LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
|
|
||||||
|
const PROVIDER_NAME: &str = "anthropic";
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
|
pub struct AnthropicSettings {
|
||||||
|
pub api_url: String,
|
||||||
|
pub low_speed_timeout: Option<Duration>,
|
||||||
|
pub available_models: Vec<anthropic::Model>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AnthropicLanguageModelProvider {
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
api_key: Option<String>,
|
||||||
|
settings: AnthropicSettings,
|
||||||
|
_subscription: Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnthropicLanguageModelProvider {
|
||||||
|
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||||
|
let state = cx.new_model(|cx| State {
|
||||||
|
api_key: None,
|
||||||
|
settings: AnthropicSettings::default(),
|
||||||
|
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||||
|
this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
Self { http_client, state }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
||||||
|
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||||
|
Some(cx.observe(&self.state, |_, _, cx| {
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||||
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
|
let mut models = HashMap::default();
|
||||||
|
|
||||||
|
// Add base models from anthropic::Model::iter()
|
||||||
|
for model in anthropic::Model::iter() {
|
||||||
|
if !matches!(model, anthropic::Model::Custom { .. }) {
|
||||||
|
models.insert(model.id().to_string(), model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with available models from settings
|
||||||
|
for model in &self.state.read(cx).settings.available_models {
|
||||||
|
models.insert(model.id().to_string(), model.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
models
|
||||||
|
.into_values()
|
||||||
|
.map(|model| {
|
||||||
|
Arc::new(AnthropicModel {
|
||||||
|
id: LanguageModelId::from(model.id().to_string()),
|
||||||
|
model,
|
||||||
|
state: self.state.clone(),
|
||||||
|
http_client: self.http_client.clone(),
|
||||||
|
}) as Arc<dyn LanguageModel>
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||||
|
self.state.read(cx).api_key.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
if self.is_authenticated(cx) {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||||
|
let state = self.state.clone();
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
|
||||||
|
api_key
|
||||||
|
} else {
|
||||||
|
let (_, api_key) = cx
|
||||||
|
.update(|cx| cx.read_credentials(&api_url))?
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||||
|
String::from_utf8(api_key)?
|
||||||
|
};
|
||||||
|
|
||||||
|
state.update(&mut cx, |this, cx| {
|
||||||
|
this.api_key = Some(api_key);
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
|
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
let state = self.state.clone();
|
||||||
|
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
delete_credentials.await.log_err();
|
||||||
|
state.update(&mut cx, |this, cx| {
|
||||||
|
this.api_key = None;
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AnthropicModel {
|
||||||
|
id: LanguageModelId,
|
||||||
|
model: anthropic::Model,
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnthropicModel {
|
||||||
|
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
|
||||||
|
preprocess_anthropic_request(&mut request);
|
||||||
|
|
||||||
|
let mut system_message = String::new();
|
||||||
|
if request
|
||||||
|
.messages
|
||||||
|
.first()
|
||||||
|
.map_or(false, |message| message.role == Role::System)
|
||||||
|
{
|
||||||
|
system_message = request.messages.remove(0).content;
|
||||||
|
}
|
||||||
|
|
||||||
|
Request {
|
||||||
|
model: self.model.clone(),
|
||||||
|
messages: request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.map(|msg| RequestMessage {
|
||||||
|
role: match msg.role {
|
||||||
|
Role::User => anthropic::Role::User,
|
||||||
|
Role::Assistant => anthropic::Role::Assistant,
|
||||||
|
Role::System => unreachable!("filtered out by preprocess_request"),
|
||||||
|
},
|
||||||
|
content: msg.content.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
stream: true,
|
||||||
|
system: system_message,
|
||||||
|
max_tokens: 4092,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn count_anthropic_tokens(
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
cx.background_executor()
|
||||||
|
.spawn(async move {
|
||||||
|
let messages = request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||||
|
role: match message.role {
|
||||||
|
Role::User => "user".into(),
|
||||||
|
Role::Assistant => "assistant".into(),
|
||||||
|
Role::System => "system".into(),
|
||||||
|
},
|
||||||
|
content: Some(message.content),
|
||||||
|
name: None,
|
||||||
|
function_call: None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Tiktoken doesn't yet support these models, so we manually use the
|
||||||
|
// same tokenizer as GPT-4.
|
||||||
|
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for AnthropicModel {
|
||||||
|
fn id(&self) -> LanguageModelId {
|
||||||
|
self.id.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> LanguageModelName {
|
||||||
|
LanguageModelName::from(self.model.display_name().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn telemetry_id(&self) -> String {
|
||||||
|
format!("anthropic/{}", self.model.id())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize {
|
||||||
|
self.model.max_token_count()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_tokens(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
count_anthropic_tokens(request, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
let request = self.to_anthropic_request(request);
|
||||||
|
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||||
|
(
|
||||||
|
state.api_key.clone(),
|
||||||
|
state.settings.api_url.clone(),
|
||||||
|
state.settings.low_speed_timeout,
|
||||||
|
)
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
|
let request = stream_completion(
|
||||||
|
http_client.as_ref(),
|
||||||
|
&api_url,
|
||||||
|
&api_key,
|
||||||
|
request,
|
||||||
|
low_speed_timeout,
|
||||||
|
);
|
||||||
|
let response = request.await?;
|
||||||
|
let stream = response
|
||||||
|
.filter_map(|response| async move {
|
||||||
|
match response {
|
||||||
|
Ok(response) => match response {
|
||||||
|
anthropic::ResponseEvent::ContentBlockStart {
|
||||||
|
content_block, ..
|
||||||
|
} => match content_block {
|
||||||
|
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||||
|
},
|
||||||
|
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||||
|
match delta {
|
||||||
|
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
},
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
|
||||||
|
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
||||||
|
let mut system_message = String::new();
|
||||||
|
|
||||||
|
for message in request.messages.drain(..) {
|
||||||
|
if message.content.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match message.role {
|
||||||
|
Role::User | Role::Assistant => {
|
||||||
|
if let Some(last_message) = new_messages.last_mut() {
|
||||||
|
if last_message.role == message.role {
|
||||||
|
last_message.content.push_str("\n\n");
|
||||||
|
last_message.content.push_str(&message.content);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
new_messages.push(message);
|
||||||
|
}
|
||||||
|
Role::System => {
|
||||||
|
if !system_message.is_empty() {
|
||||||
|
system_message.push_str("\n\n");
|
||||||
|
}
|
||||||
|
system_message.push_str(&message.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !system_message.is_empty() {
|
||||||
|
new_messages.insert(
|
||||||
|
0,
|
||||||
|
LanguageModelRequestMessage {
|
||||||
|
role: Role::System,
|
||||||
|
content: system_message,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
request.messages = new_messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AuthenticationPrompt {
|
||||||
|
api_key: View<Editor>,
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthenticationPrompt {
|
||||||
|
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
|
||||||
|
Self {
|
||||||
|
api_key: cx.new_view(|cx| {
|
||||||
|
let mut editor = Editor::single_line(cx);
|
||||||
|
editor.set_placeholder_text(
|
||||||
|
"sk-000000000000000000000000000000000000000000000000",
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
editor
|
||||||
|
}),
|
||||||
|
state,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||||
|
let api_key = self.api_key.read(cx).text(cx);
|
||||||
|
if api_key.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let write_credentials = cx.write_credentials(
|
||||||
|
&self.state.read(cx).settings.api_url,
|
||||||
|
"Bearer",
|
||||||
|
api_key.as_bytes(),
|
||||||
|
);
|
||||||
|
let state = self.state.clone();
|
||||||
|
cx.spawn(|_, mut cx| async move {
|
||||||
|
write_credentials.await?;
|
||||||
|
|
||||||
|
state.update(&mut cx, |this, cx| {
|
||||||
|
this.api_key = Some(api_key);
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
let settings = ThemeSettings::get_global(cx);
|
||||||
|
let text_style = TextStyle {
|
||||||
|
color: cx.theme().colors().text,
|
||||||
|
font_family: settings.ui_font.family.clone(),
|
||||||
|
font_features: settings.ui_font.features.clone(),
|
||||||
|
font_size: rems(0.875).into(),
|
||||||
|
font_weight: settings.ui_font.weight,
|
||||||
|
font_style: FontStyle::Normal,
|
||||||
|
line_height: relative(1.3),
|
||||||
|
background_color: None,
|
||||||
|
underline: None,
|
||||||
|
strikethrough: None,
|
||||||
|
white_space: WhiteSpace::Normal,
|
||||||
|
};
|
||||||
|
EditorElement::new(
|
||||||
|
&self.api_key,
|
||||||
|
EditorStyle {
|
||||||
|
background: cx.theme().colors().editor_background,
|
||||||
|
local_player: cx.theme().players().local(),
|
||||||
|
text: text_style,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for AuthenticationPrompt {
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
const INSTRUCTIONS: [&str; 4] = [
|
||||||
|
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||||
|
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||||
|
"",
|
||||||
|
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||||
|
];
|
||||||
|
|
||||||
|
v_flex()
|
||||||
|
.p_4()
|
||||||
|
.size_full()
|
||||||
|
.on_action(cx.listener(Self::save_api_key))
|
||||||
|
.children(
|
||||||
|
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.w_full()
|
||||||
|
.my_2()
|
||||||
|
.px_2()
|
||||||
|
.py_1()
|
||||||
|
.bg(cx.theme().colors().editor_background)
|
||||||
|
.rounded_md()
|
||||||
|
.child(self.render_api_key_editor(cx)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
Label::new(
|
||||||
|
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||||
|
)
|
||||||
|
.size(LabelSize::Small),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_2()
|
||||||
|
.child(Label::new("Click on").size(LabelSize::Small))
|
||||||
|
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
|
||||||
|
.child(
|
||||||
|
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
|
}
|
287
crates/language_model/src/provider/cloud.rs
Normal file
287
crates/language_model/src/provider/cloud.rs
Normal file
|
@ -0,0 +1,287 @@
|
||||||
|
use super::open_ai::count_open_ai_tokens;
|
||||||
|
use crate::{
|
||||||
|
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||||
|
LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||||
|
};
|
||||||
|
use anyhow::Result;
|
||||||
|
use client::Client;
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||||
|
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||||
|
use settings::{Settings, SettingsStore};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use strum::IntoEnumIterator;
|
||||||
|
use ui::prelude::*;
|
||||||
|
|
||||||
|
use crate::LanguageModelProvider;
|
||||||
|
|
||||||
|
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
|
||||||
|
|
||||||
|
pub const PROVIDER_NAME: &str = "zed.dev";
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
|
pub struct ZedDotDevSettings {
|
||||||
|
pub available_models: Vec<CloudModel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CloudLanguageModelProvider {
|
||||||
|
client: Arc<Client>,
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
_maintain_client_status: Task<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
client: Arc<Client>,
|
||||||
|
status: client::Status,
|
||||||
|
settings: ZedDotDevSettings,
|
||||||
|
_subscription: Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
let client = self.client.clone();
|
||||||
|
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CloudLanguageModelProvider {
|
||||||
|
pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
|
||||||
|
let mut status_rx = client.status();
|
||||||
|
let status = *status_rx.borrow();
|
||||||
|
|
||||||
|
let state = cx.new_model(|cx| State {
|
||||||
|
client: client.clone(),
|
||||||
|
status,
|
||||||
|
settings: ZedDotDevSettings::default(),
|
||||||
|
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||||
|
this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
let state_ref = state.downgrade();
|
||||||
|
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||||
|
while let Some(status) = status_rx.next().await {
|
||||||
|
if let Some(this) = state_ref.upgrade() {
|
||||||
|
_ = this.update(&mut cx, |this, cx| {
|
||||||
|
this.status = status;
|
||||||
|
cx.notify();
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
client,
|
||||||
|
state,
|
||||||
|
_maintain_client_status: maintain_client_status,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProviderState for CloudLanguageModelProvider {
|
||||||
|
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||||
|
Some(cx.observe(&self.state, |_, _, cx| {
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||||
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
|
let mut models = HashMap::default();
|
||||||
|
|
||||||
|
// Add base models from CloudModel::iter()
|
||||||
|
for model in CloudModel::iter() {
|
||||||
|
if !matches!(model, CloudModel::Custom { .. }) {
|
||||||
|
models.insert(model.id().to_string(), model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with available models from settings
|
||||||
|
for model in &self.state.read(cx).settings.available_models {
|
||||||
|
models.insert(model.id().to_string(), model.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
models
|
||||||
|
.into_values()
|
||||||
|
.map(|model| {
|
||||||
|
Arc::new(CloudLanguageModel {
|
||||||
|
id: LanguageModelId::from(model.id().to_string()),
|
||||||
|
model,
|
||||||
|
client: self.client.clone(),
|
||||||
|
}) as Arc<dyn LanguageModel>
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||||
|
self.state.read(cx).status.is_connected()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
self.state.read(cx).authenticate(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
|
cx.new_view(|_cx| AuthenticationPrompt {
|
||||||
|
state: self.state.clone(),
|
||||||
|
})
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CloudLanguageModel {
|
||||||
|
id: LanguageModelId,
|
||||||
|
model: CloudModel,
|
||||||
|
client: Arc<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for CloudLanguageModel {
|
||||||
|
fn id(&self) -> LanguageModelId {
|
||||||
|
self.id.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> LanguageModelName {
|
||||||
|
LanguageModelName::from(self.model.display_name().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn telemetry_id(&self) -> String {
|
||||||
|
format!("zed.dev/{}", self.model.id())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize {
|
||||||
|
self.model.max_token_count()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_tokens(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
match &self.model {
|
||||||
|
CloudModel::Gpt3Point5Turbo => {
|
||||||
|
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
|
||||||
|
}
|
||||||
|
CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
|
||||||
|
CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
|
||||||
|
CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
|
||||||
|
CloudModel::Gpt4OmniMini => {
|
||||||
|
count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
|
||||||
|
}
|
||||||
|
CloudModel::Claude3_5Sonnet
|
||||||
|
| CloudModel::Claude3Opus
|
||||||
|
| CloudModel::Claude3Sonnet
|
||||||
|
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
|
||||||
|
_ => {
|
||||||
|
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||||
|
model: self.model.id().to_string(),
|
||||||
|
messages: request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| message.to_proto())
|
||||||
|
.collect(),
|
||||||
|
});
|
||||||
|
async move {
|
||||||
|
let response = request.await?;
|
||||||
|
Ok(response.token_count as usize)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
mut request: LanguageModelRequest,
|
||||||
|
_: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
match &self.model {
|
||||||
|
CloudModel::Claude3Opus
|
||||||
|
| CloudModel::Claude3Sonnet
|
||||||
|
| CloudModel::Claude3Haiku
|
||||||
|
| CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
|
||||||
|
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||||
|
preprocess_anthropic_request(&mut request)
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = proto::CompleteWithLanguageModel {
|
||||||
|
model: self.id.0.to_string(),
|
||||||
|
messages: request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| message.to_proto())
|
||||||
|
.collect(),
|
||||||
|
stop: request.stop,
|
||||||
|
temperature: request.temperature,
|
||||||
|
tools: Vec::new(),
|
||||||
|
tool_choice: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.client
|
||||||
|
.request_stream(request)
|
||||||
|
.map_ok(|stream| {
|
||||||
|
stream
|
||||||
|
.filter_map(|response| async move {
|
||||||
|
match response {
|
||||||
|
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AuthenticationPrompt {
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for AuthenticationPrompt {
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||||
|
|
||||||
|
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||||
|
v_flex()
|
||||||
|
.gap_2()
|
||||||
|
.child(
|
||||||
|
Button::new("sign_in", "Sign in")
|
||||||
|
.icon_color(Color::Muted)
|
||||||
|
.icon(IconName::Github)
|
||||||
|
.icon_position(IconPosition::Start)
|
||||||
|
.style(ButtonStyle::Filled)
|
||||||
|
.full_width()
|
||||||
|
.on_click(cx.listener(move |this, _, cx| {
|
||||||
|
this.state.update(cx, |provider, cx| {
|
||||||
|
provider.authenticate(cx).detach_and_log_err(cx);
|
||||||
|
cx.notify();
|
||||||
|
});
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
div().flex().w_full().items_center().child(
|
||||||
|
Label::new("Sign in to enable collaboration.")
|
||||||
|
.color(Color::Muted)
|
||||||
|
.size(LabelSize::Small),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
160
crates/language_model/src/provider/fake.rs
Normal file
160
crates/language_model/src/provider/fake.rs
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||||
|
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||||
|
};
|
||||||
|
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||||
|
use http::Result;
|
||||||
|
use ui::WindowContext;
|
||||||
|
|
||||||
|
pub fn language_model_id() -> LanguageModelId {
|
||||||
|
LanguageModelId::from("fake".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn language_model_name() -> LanguageModelName {
|
||||||
|
LanguageModelName::from("Fake".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn provider_name() -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName::from("fake".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct FakeLanguageModelProvider {
|
||||||
|
current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProviderState for FakeLanguageModelProvider {
|
||||||
|
fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProvider for FakeLanguageModelProvider {
|
||||||
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
|
provider_name()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provided_models(&self, _: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
|
vec![Arc::new(FakeLanguageModel {
|
||||||
|
current_completion_txs: self.current_completion_txs.clone(),
|
||||||
|
})]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_authenticated(&self, _: &AppContext) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authentication_prompt(&self, _: &mut WindowContext) -> AnyView {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeLanguageModelProvider {
|
||||||
|
pub fn test_model(&self) -> FakeLanguageModel {
|
||||||
|
FakeLanguageModel {
|
||||||
|
current_completion_txs: self.current_completion_txs.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeLanguageModel {
|
||||||
|
current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeLanguageModel {
|
||||||
|
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||||
|
self.current_completion_txs
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.keys()
|
||||||
|
.map(|k| serde_json::from_str(k).unwrap())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn completion_count(&self) -> usize {
|
||||||
|
self.current_completion_txs.lock().unwrap().len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
|
||||||
|
let json = serde_json::to_string(request).unwrap();
|
||||||
|
self.current_completion_txs
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.get(&json)
|
||||||
|
.unwrap()
|
||||||
|
.unbounded_send(chunk)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_last_completion_chunk(&self, chunk: String) {
|
||||||
|
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish_completion(&self, request: &LanguageModelRequest) {
|
||||||
|
self.current_completion_txs
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.remove(&serde_json::to_string(request).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish_last_completion(&self) {
|
||||||
|
self.finish_completion(self.pending_completions().last().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for FakeLanguageModel {
|
||||||
|
fn id(&self) -> LanguageModelId {
|
||||||
|
language_model_id()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> LanguageModelName {
|
||||||
|
language_model_name()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
provider_name()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn telemetry_id(&self) -> String {
|
||||||
|
"fake".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize {
|
||||||
|
1000000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_tokens(
|
||||||
|
&self,
|
||||||
|
_: LanguageModelRequest,
|
||||||
|
_: &AppContext,
|
||||||
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
|
futures::future::ready(Ok(0)).boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
_: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
let (tx, rx) = mpsc::unbounded();
|
||||||
|
self.current_completion_txs
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.insert(serde_json::to_string(&request).unwrap(), tx);
|
||||||
|
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,49 +1,148 @@
|
||||||
use crate::LanguageModelCompletionProvider;
|
use anyhow::{anyhow, Result};
|
||||||
use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
use anyhow::Result;
|
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
|
||||||
use futures::StreamExt as _;
|
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
|
|
||||||
use gpui::{AnyView, AppContext, Task};
|
|
||||||
use http::HttpClient;
|
use http::HttpClient;
|
||||||
use language_model::Role;
|
use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
|
||||||
use ollama::Model as OllamaModel;
|
use settings::{Settings, SettingsStore};
|
||||||
use ollama::{
|
use std::{sync::Arc, time::Duration};
|
||||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
|
||||||
};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
|
||||||
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
use ui::{prelude::*, ButtonLike, ElevationIndex};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
|
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
|
LanguageModelRequest, Role,
|
||||||
|
};
|
||||||
|
|
||||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||||
|
|
||||||
pub struct OllamaCompletionProvider {
|
const PROVIDER_NAME: &str = "ollama";
|
||||||
api_url: String,
|
|
||||||
model: OllamaModel,
|
#[derive(Default, Debug, Clone, PartialEq)]
|
||||||
http_client: Arc<dyn HttpClient>,
|
pub struct OllamaSettings {
|
||||||
low_speed_timeout: Option<Duration>,
|
pub api_url: String,
|
||||||
settings_version: usize,
|
pub low_speed_timeout: Option<Duration>,
|
||||||
available_models: Vec<OllamaModel>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
pub struct OllamaLanguageModelProvider {
|
||||||
fn available_models(&self) -> Vec<LanguageModel> {
|
http_client: Arc<dyn HttpClient>,
|
||||||
self.available_models
|
state: gpui::Model<State>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
available_models: Vec<ollama::Model>,
|
||||||
|
settings: OllamaSettings,
|
||||||
|
_subscription: Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
let api_url = self.settings.api_url.clone();
|
||||||
|
|
||||||
|
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||||
|
cx.spawn(|this, mut cx| async move {
|
||||||
|
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||||
|
|
||||||
|
let mut models: Vec<ollama::Model> = models
|
||||||
|
.into_iter()
|
||||||
|
// Since there is no metadata from the Ollama API
|
||||||
|
// indicating which models are embedding models,
|
||||||
|
// simply filter out models with "-embed" in their name
|
||||||
|
.filter(|model| !model.name.contains("-embed"))
|
||||||
|
.map(|model| ollama::Model::new(&model.name))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
|
|
||||||
|
this.update(&mut cx, |this, cx| {
|
||||||
|
this.available_models = models;
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaLanguageModelProvider {
|
||||||
|
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||||
|
Self {
|
||||||
|
http_client: http_client.clone(),
|
||||||
|
state: cx.new_model(|cx| State {
|
||||||
|
http_client,
|
||||||
|
available_models: Default::default(),
|
||||||
|
settings: OllamaSettings::default(),
|
||||||
|
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||||
|
this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||||
|
|
||||||
|
let state = self.state.clone();
|
||||||
|
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||||
|
|
||||||
|
let mut models: Vec<ollama::Model> = models
|
||||||
|
.into_iter()
|
||||||
|
// Since there is no metadata from the Ollama API
|
||||||
|
// indicating which models are embedding models,
|
||||||
|
// simply filter out models with "-embed" in their name
|
||||||
|
.filter(|model| !model.name.contains("-embed"))
|
||||||
|
.map(|model| ollama::Model::new(&model.name))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
|
|
||||||
|
state.update(&mut cx, |this, cx| {
|
||||||
|
this.available_models = models;
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||||
|
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||||
|
Some(cx.observe(&self.state, |_, _, cx| {
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||||
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
|
self.state
|
||||||
|
.read(cx)
|
||||||
|
.available_models
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| LanguageModel::Ollama(m.clone()))
|
.map(|model| {
|
||||||
|
Arc::new(OllamaLanguageModel {
|
||||||
|
id: LanguageModelId::from(model.name.clone()),
|
||||||
|
model: model.clone(),
|
||||||
|
http_client: self.http_client.clone(),
|
||||||
|
state: self.state.clone(),
|
||||||
|
}) as Arc<dyn LanguageModel>
|
||||||
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn settings_version(&self) -> usize {
|
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||||
self.settings_version
|
!self.state.read(cx).available_models.is_empty()
|
||||||
}
|
|
||||||
|
|
||||||
fn is_authenticated(&self) -> bool {
|
|
||||||
!self.available_models.is_empty()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
if self.is_authenticated() {
|
if self.is_authenticated(cx) {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
self.fetch_models(cx)
|
self.fetch_models(cx)
|
||||||
|
@ -51,14 +150,9 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
|
let state = self.state.clone();
|
||||||
let fetch_models = Box::new(move |cx: &mut WindowContext| {
|
let fetch_models = Box::new(move |cx: &mut WindowContext| {
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
state.update(cx, |this, cx| this.fetch_models(cx))
|
||||||
provider
|
|
||||||
.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
|
||||||
provider.fetch_models(cx)
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|| Task::ready(Ok(())))
|
|
||||||
})
|
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
|
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
|
||||||
|
@ -68,9 +162,65 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
self.fetch_models(cx)
|
self.fetch_models(cx)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn model(&self) -> LanguageModel {
|
pub struct OllamaLanguageModel {
|
||||||
LanguageModel::Ollama(self.model.clone())
|
id: LanguageModelId,
|
||||||
|
model: ollama::Model,
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaLanguageModel {
|
||||||
|
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
||||||
|
ChatRequest {
|
||||||
|
model: self.model.name.clone(),
|
||||||
|
messages: request
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|msg| match msg.role {
|
||||||
|
Role::User => ChatMessage::User {
|
||||||
|
content: msg.content,
|
||||||
|
},
|
||||||
|
Role::Assistant => ChatMessage::Assistant {
|
||||||
|
content: msg.content,
|
||||||
|
},
|
||||||
|
Role::System => ChatMessage::System {
|
||||||
|
content: msg.content,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
|
||||||
|
stream: true,
|
||||||
|
options: Some(ChatOptions {
|
||||||
|
num_ctx: Some(self.model.max_tokens),
|
||||||
|
stop: Some(request.stop),
|
||||||
|
temperature: Some(request.temperature),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for OllamaLanguageModel {
|
||||||
|
fn id(&self) -> LanguageModelId {
|
||||||
|
self.id.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> LanguageModelName {
|
||||||
|
LanguageModelName::from(self.model.display_name().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize {
|
||||||
|
self.model.max_token_count()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn telemetry_id(&self) -> String {
|
||||||
|
format!("ollama/{}", self.model.id())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn count_tokens(
|
fn count_tokens(
|
||||||
|
@ -93,12 +243,20 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let request = self.to_ollama_request(request);
|
let request = self.to_ollama_request(request);
|
||||||
|
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let api_url = self.api_url.clone();
|
let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||||
let low_speed_timeout = self.low_speed_timeout;
|
(
|
||||||
|
state.settings.api_url.clone(),
|
||||||
|
state.settings.low_speed_timeout,
|
||||||
|
)
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let request =
|
let request =
|
||||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
||||||
|
@ -122,143 +280,6 @@ impl LanguageModelCompletionProvider for OllamaCompletionProvider {
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OllamaCompletionProvider {
|
|
||||||
pub fn new(
|
|
||||||
model: OllamaModel,
|
|
||||||
api_url: String,
|
|
||||||
http_client: Arc<dyn HttpClient>,
|
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
|
||||||
cx: &AppContext,
|
|
||||||
) -> Self {
|
|
||||||
cx.spawn({
|
|
||||||
let api_url = api_url.clone();
|
|
||||||
let client = http_client.clone();
|
|
||||||
let model = model.name.clone();
|
|
||||||
|
|
||||||
|_| async move {
|
|
||||||
if model.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
preload_model(client.as_ref(), &api_url, &model).await
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.detach_and_log_err(cx);
|
|
||||||
|
|
||||||
Self {
|
|
||||||
api_url,
|
|
||||||
model,
|
|
||||||
http_client,
|
|
||||||
low_speed_timeout,
|
|
||||||
settings_version,
|
|
||||||
available_models: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update(
|
|
||||||
&mut self,
|
|
||||||
model: OllamaModel,
|
|
||||||
api_url: String,
|
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
|
||||||
cx: &AppContext,
|
|
||||||
) {
|
|
||||||
cx.spawn({
|
|
||||||
let api_url = api_url.clone();
|
|
||||||
let client = self.http_client.clone();
|
|
||||||
let model = model.name.clone();
|
|
||||||
|
|
||||||
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
|
|
||||||
})
|
|
||||||
.detach_and_log_err(cx);
|
|
||||||
|
|
||||||
if model.name.is_empty() {
|
|
||||||
self.select_first_available_model()
|
|
||||||
} else {
|
|
||||||
self.model = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.api_url = api_url;
|
|
||||||
self.low_speed_timeout = low_speed_timeout;
|
|
||||||
self.settings_version = settings_version;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn select_first_available_model(&mut self) {
|
|
||||||
if let Some(model) = self.available_models.first() {
|
|
||||||
self.model = model.clone();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
|
||||||
let http_client = self.http_client.clone();
|
|
||||||
let api_url = self.api_url.clone();
|
|
||||||
|
|
||||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
|
||||||
|
|
||||||
let mut models: Vec<OllamaModel> = models
|
|
||||||
.into_iter()
|
|
||||||
// Since there is no metadata from the Ollama API
|
|
||||||
// indicating which models are embedding models,
|
|
||||||
// simply filter out models with "-embed" in their name
|
|
||||||
.filter(|model| !model.name.contains("-embed"))
|
|
||||||
.map(|model| OllamaModel::new(&model.name))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
|
||||||
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
|
|
||||||
provider.available_models = models;
|
|
||||||
|
|
||||||
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
|
|
||||||
provider.select_first_available_model()
|
|
||||||
}
|
|
||||||
});
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
|
|
||||||
let model = match request.model {
|
|
||||||
LanguageModel::Ollama(model) => model,
|
|
||||||
_ => self.model.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
ChatRequest {
|
|
||||||
model: model.name,
|
|
||||||
messages: request
|
|
||||||
.messages
|
|
||||||
.into_iter()
|
|
||||||
.map(|msg| match msg.role {
|
|
||||||
Role::User => ChatMessage::User {
|
|
||||||
content: msg.content,
|
|
||||||
},
|
|
||||||
Role::Assistant => ChatMessage::Assistant {
|
|
||||||
content: msg.content,
|
|
||||||
},
|
|
||||||
Role::System => ChatMessage::System {
|
|
||||||
content: msg.content,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
keep_alive: model.keep_alive.unwrap_or_default(),
|
|
||||||
stream: true,
|
|
||||||
options: Some(ChatOptions {
|
|
||||||
num_ctx: Some(model.max_tokens),
|
|
||||||
stop: Some(request.stop),
|
|
||||||
temperature: Some(request.temperature),
|
|
||||||
..Default::default()
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DownloadOllamaMessage {
|
struct DownloadOllamaMessage {
|
|
@ -1,72 +1,159 @@
|
||||||
use crate::CompletionProvider;
|
|
||||||
use crate::LanguageModelCompletionProvider;
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
use collections::HashMap;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||||
use gpui::{AnyView, AppContext, Task, TextStyle, View};
|
use gpui::{
|
||||||
|
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
|
||||||
|
WhiteSpace,
|
||||||
|
};
|
||||||
use http::HttpClient;
|
use http::HttpClient;
|
||||||
use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
|
|
||||||
use open_ai::Model as OpenAiModel;
|
|
||||||
use open_ai::{stream_completion, Request, RequestMessage};
|
use open_ai::{stream_completion, Request, RequestMessage};
|
||||||
use settings::Settings;
|
use settings::{Settings, SettingsStore};
|
||||||
use std::time::Duration;
|
use std::{sync::Arc, time::Duration};
|
||||||
use std::{env, sync::Arc};
|
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
pub struct OpenAiCompletionProvider {
|
use crate::{
|
||||||
api_key: Option<String>,
|
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||||
api_url: String,
|
LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
model: OpenAiModel,
|
LanguageModelRequest, Role,
|
||||||
http_client: Arc<dyn HttpClient>,
|
};
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
const PROVIDER_NAME: &str = "openai";
|
||||||
available_models_from_settings: Vec<OpenAiModel>,
|
|
||||||
|
#[derive(Default, Clone, Debug, PartialEq)]
|
||||||
|
pub struct OpenAiSettings {
|
||||||
|
pub api_url: String,
|
||||||
|
pub low_speed_timeout: Option<Duration>,
|
||||||
|
pub available_models: Vec<open_ai::Model>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiCompletionProvider {
|
pub struct OpenAiLanguageModelProvider {
|
||||||
pub fn new(
|
|
||||||
model: OpenAiModel,
|
|
||||||
api_url: String,
|
|
||||||
http_client: Arc<dyn HttpClient>,
|
http_client: Arc<dyn HttpClient>,
|
||||||
low_speed_timeout: Option<Duration>,
|
state: gpui::Model<State>,
|
||||||
settings_version: usize,
|
}
|
||||||
available_models_from_settings: Vec<OpenAiModel>,
|
|
||||||
) -> Self {
|
struct State {
|
||||||
Self {
|
api_key: Option<String>,
|
||||||
|
settings: OpenAiSettings,
|
||||||
|
_subscription: Subscription,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAiLanguageModelProvider {
|
||||||
|
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
|
||||||
|
let state = cx.new_model(|cx| State {
|
||||||
api_key: None,
|
api_key: None,
|
||||||
api_url,
|
settings: OpenAiSettings::default(),
|
||||||
|
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||||
|
this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
|
||||||
|
cx.notify();
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
Self { http_client, state }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
|
||||||
|
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
|
||||||
|
Some(cx.observe(&self.state, |_, _, cx| {
|
||||||
|
cx.notify();
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||||
|
fn name(&self) -> LanguageModelProviderName {
|
||||||
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
|
let mut models = HashMap::default();
|
||||||
|
|
||||||
|
// Add base models from open_ai::Model::iter()
|
||||||
|
for model in open_ai::Model::iter() {
|
||||||
|
if !matches!(model, open_ai::Model::Custom { .. }) {
|
||||||
|
models.insert(model.id().to_string(), model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with available models from settings
|
||||||
|
for model in &self.state.read(cx).settings.available_models {
|
||||||
|
models.insert(model.id().to_string(), model.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
models
|
||||||
|
.into_values()
|
||||||
|
.map(|model| {
|
||||||
|
Arc::new(OpenAiLanguageModel {
|
||||||
|
id: LanguageModelId::from(model.id().to_string()),
|
||||||
model,
|
model,
|
||||||
http_client,
|
state: self.state.clone(),
|
||||||
low_speed_timeout,
|
http_client: self.http_client.clone(),
|
||||||
settings_version,
|
}) as Arc<dyn LanguageModel>
|
||||||
available_models_from_settings,
|
})
|
||||||
}
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update(
|
fn is_authenticated(&self, cx: &AppContext) -> bool {
|
||||||
&mut self,
|
self.state.read(cx).api_key.is_some()
|
||||||
model: OpenAiModel,
|
|
||||||
api_url: String,
|
|
||||||
low_speed_timeout: Option<Duration>,
|
|
||||||
settings_version: usize,
|
|
||||||
) {
|
|
||||||
self.model = model;
|
|
||||||
self.api_url = api_url;
|
|
||||||
self.low_speed_timeout = low_speed_timeout;
|
|
||||||
self.settings_version = settings_version;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
let model = match request.model {
|
if self.is_authenticated(cx) {
|
||||||
LanguageModel::OpenAi(model) => model,
|
Task::ready(Ok(()))
|
||||||
_ => self.model.clone(),
|
} else {
|
||||||
|
let api_url = self.state.read(cx).settings.api_url.clone();
|
||||||
|
let state = self.state.clone();
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
|
||||||
|
api_key
|
||||||
|
} else {
|
||||||
|
let (_, api_key) = cx
|
||||||
|
.update(|cx| cx.read_credentials(&api_url))?
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||||
|
String::from_utf8(api_key)?
|
||||||
};
|
};
|
||||||
|
state.update(&mut cx, |this, cx| {
|
||||||
|
this.api_key = Some(api_key);
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
|
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||||
|
let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
|
||||||
|
let state = self.state.clone();
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
delete_credentials.await.log_err();
|
||||||
|
state.update(&mut cx, |this, cx| {
|
||||||
|
this.api_key = None;
|
||||||
|
cx.notify();
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OpenAiLanguageModel {
|
||||||
|
id: LanguageModelId,
|
||||||
|
model: open_ai::Model,
|
||||||
|
state: gpui::Model<State>,
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAiLanguageModel {
|
||||||
|
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||||
Request {
|
Request {
|
||||||
model,
|
model: self.model.clone(),
|
||||||
messages: request
|
messages: request
|
||||||
.messages
|
.messages
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -92,80 +179,25 @@ impl OpenAiCompletionProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
|
impl LanguageModel for OpenAiLanguageModel {
|
||||||
fn available_models(&self) -> Vec<LanguageModel> {
|
fn id(&self) -> LanguageModelId {
|
||||||
if self.available_models_from_settings.is_empty() {
|
self.id.clone()
|
||||||
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
|
|
||||||
vec![self.model.clone()]
|
|
||||||
} else {
|
|
||||||
OpenAiModel::iter()
|
|
||||||
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
|
|
||||||
.collect()
|
|
||||||
};
|
|
||||||
available_models
|
|
||||||
.into_iter()
|
|
||||||
.map(LanguageModel::OpenAi)
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
self.available_models_from_settings
|
|
||||||
.iter()
|
|
||||||
.cloned()
|
|
||||||
.map(LanguageModel::OpenAi)
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn settings_version(&self) -> usize {
|
fn name(&self) -> LanguageModelName {
|
||||||
self.settings_version
|
LanguageModelName::from(self.model.display_name().to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_authenticated(&self) -> bool {
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
self.api_key.is_some()
|
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn telemetry_id(&self) -> String {
|
||||||
if self.is_authenticated() {
|
format!("openai/{}", self.model.id())
|
||||||
Task::ready(Ok(()))
|
|
||||||
} else {
|
|
||||||
let api_url = self.api_url.clone();
|
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
|
||||||
api_key
|
|
||||||
} else {
|
|
||||||
let (_, api_key) = cx
|
|
||||||
.update(|cx| cx.read_credentials(&api_url))?
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
|
||||||
String::from_utf8(api_key)?
|
|
||||||
};
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, Self>(|provider| {
|
|
||||||
provider.api_key = Some(api_key);
|
|
||||||
});
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
fn max_token_count(&self) -> usize {
|
||||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
self.model.max_token_count()
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
delete_credentials.await.log_err();
|
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
|
||||||
provider.update_current_as::<_, Self>(|provider| {
|
|
||||||
provider.api_key = None;
|
|
||||||
});
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
|
||||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
|
||||||
.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model(&self) -> LanguageModel {
|
|
||||||
LanguageModel::OpenAi(self.model.clone())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn count_tokens(
|
fn count_tokens(
|
||||||
|
@ -173,19 +205,27 @@ impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
count_open_ai_tokens(request, cx.background_executor())
|
count_open_ai_tokens(request, self.model.clone(), cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
cx: &AsyncAppContext,
|
||||||
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
let request = self.to_open_ai_request(request);
|
let request = self.to_open_ai_request(request);
|
||||||
|
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let api_key = self.api_key.clone();
|
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
|
||||||
let api_url = self.api_url.clone();
|
(
|
||||||
let low_speed_timeout = self.low_speed_timeout;
|
state.api_key.clone(),
|
||||||
|
state.settings.api_url.clone(),
|
||||||
|
state.settings.low_speed_timeout,
|
||||||
|
)
|
||||||
|
}) else {
|
||||||
|
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||||
|
};
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||||
let request = stream_completion(
|
let request = stream_completion(
|
||||||
|
@ -208,17 +248,14 @@ impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_open_ai_tokens(
|
pub fn count_open_ai_tokens(
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
background_executor: &gpui::BackgroundExecutor,
|
model: open_ai::Model,
|
||||||
|
cx: &AppContext,
|
||||||
) -> BoxFuture<'static, Result<usize>> {
|
) -> BoxFuture<'static, Result<usize>> {
|
||||||
background_executor
|
cx.background_executor()
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
let messages = request
|
let messages = request
|
||||||
.messages
|
.messages
|
||||||
|
@ -235,19 +272,10 @@ pub fn count_open_ai_tokens(
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
match request.model {
|
if let open_ai::Model::Custom { .. } = model {
|
||||||
LanguageModel::Anthropic(_)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3Opus)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
|
|
||||||
| LanguageModel::Cloud(CloudModel::Custom { .. })
|
|
||||||
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
|
|
||||||
// Tiktoken doesn't yet support these models, so we manually use the
|
|
||||||
// same tokenizer as GPT-4.
|
|
||||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||||
}
|
} else {
|
||||||
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
|
tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
|
@ -255,11 +283,11 @@ pub fn count_open_ai_tokens(
|
||||||
|
|
||||||
struct AuthenticationPrompt {
|
struct AuthenticationPrompt {
|
||||||
api_key: View<Editor>,
|
api_key: View<Editor>,
|
||||||
api_url: String,
|
state: gpui::Model<State>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthenticationPrompt {
|
impl AuthenticationPrompt {
|
||||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
|
||||||
Self {
|
Self {
|
||||||
api_key: cx.new_view(|cx| {
|
api_key: cx.new_view(|cx| {
|
||||||
let mut editor = Editor::single_line(cx);
|
let mut editor = Editor::single_line(cx);
|
||||||
|
@ -269,7 +297,7 @@ impl AuthenticationPrompt {
|
||||||
);
|
);
|
||||||
editor
|
editor
|
||||||
}),
|
}),
|
||||||
api_url,
|
state,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -279,13 +307,17 @@ impl AuthenticationPrompt {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
let write_credentials = cx.write_credentials(
|
||||||
|
&self.state.read(cx).settings.api_url,
|
||||||
|
"Bearer",
|
||||||
|
api_key.as_bytes(),
|
||||||
|
);
|
||||||
|
let state = self.state.clone();
|
||||||
cx.spawn(|_, mut cx| async move {
|
cx.spawn(|_, mut cx| async move {
|
||||||
write_credentials.await?;
|
write_credentials.await?;
|
||||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
state.update(&mut cx, |this, cx| {
|
||||||
provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
|
this.api_key = Some(api_key);
|
||||||
provider.api_key = Some(api_key);
|
cx.notify();
|
||||||
});
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
|
@ -299,8 +331,12 @@ impl AuthenticationPrompt {
|
||||||
font_features: settings.ui_font.features.clone(),
|
font_features: settings.ui_font.features.clone(),
|
||||||
font_size: rems(0.875).into(),
|
font_size: rems(0.875).into(),
|
||||||
font_weight: settings.ui_font.weight,
|
font_weight: settings.ui_font.weight,
|
||||||
|
font_style: FontStyle::Normal,
|
||||||
line_height: relative(1.3),
|
line_height: relative(1.3),
|
||||||
..Default::default()
|
background_color: None,
|
||||||
|
underline: None,
|
||||||
|
strikethrough: None,
|
||||||
|
white_space: WhiteSpace::Normal,
|
||||||
};
|
};
|
||||||
EditorElement::new(
|
EditorElement::new(
|
||||||
&self.api_key,
|
&self.api_key,
|
172
crates/language_model/src/registry.rs
Normal file
172
crates/language_model/src/registry.rs
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
use client::Client;
|
||||||
|
use collections::HashMap;
|
||||||
|
use gpui::{AppContext, Global, Model, ModelContext};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use ui::Context;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
provider::{
|
||||||
|
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||||
|
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||||
|
},
|
||||||
|
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
|
let registry = cx.new_model(|cx| {
|
||||||
|
let mut registry = LanguageModelRegistry::default();
|
||||||
|
register_language_model_providers(&mut registry, client, cx);
|
||||||
|
registry
|
||||||
|
});
|
||||||
|
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_language_model_providers(
|
||||||
|
registry: &mut LanguageModelRegistry,
|
||||||
|
client: Arc<Client>,
|
||||||
|
cx: &mut ModelContext<LanguageModelRegistry>,
|
||||||
|
) {
|
||||||
|
use feature_flags::FeatureFlagAppExt;
|
||||||
|
|
||||||
|
registry.register_provider(
|
||||||
|
AnthropicLanguageModelProvider::new(client.http_client(), cx),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
registry.register_provider(
|
||||||
|
OpenAiLanguageModelProvider::new(client.http_client(), cx),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
registry.register_provider(
|
||||||
|
OllamaLanguageModelProvider::new(client.http_client(), cx),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
|
||||||
|
let client = client.clone();
|
||||||
|
LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
|
||||||
|
if enabled {
|
||||||
|
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
|
||||||
|
} else {
|
||||||
|
registry.unregister_provider(
|
||||||
|
&LanguageModelProviderName::from(
|
||||||
|
crate::provider::cloud::PROVIDER_NAME.to_string(),
|
||||||
|
),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GlobalLanguageModelRegistry(Model<LanguageModelRegistry>);
|
||||||
|
|
||||||
|
impl Global for GlobalLanguageModelRegistry {}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct LanguageModelRegistry {
|
||||||
|
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelRegistry {
|
||||||
|
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||||
|
cx.global::<GlobalLanguageModelRegistry>().0.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_global(cx: &AppContext) -> &Self {
|
||||||
|
cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub fn test(cx: &mut AppContext) -> crate::provider::fake::FakeLanguageModelProvider {
|
||||||
|
let fake_provider = crate::provider::fake::FakeLanguageModelProvider::default();
|
||||||
|
let registry = cx.new_model(|cx| {
|
||||||
|
let mut registry = Self::default();
|
||||||
|
registry.register_provider(fake_provider.clone(), cx);
|
||||||
|
registry
|
||||||
|
});
|
||||||
|
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||||
|
fake_provider
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
|
||||||
|
&mut self,
|
||||||
|
provider: T,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
let name = provider.name();
|
||||||
|
|
||||||
|
if let Some(subscription) = provider.subscribe(cx) {
|
||||||
|
subscription.detach();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.providers.insert(name, Arc::new(provider));
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unregister_provider(
|
||||||
|
&mut self,
|
||||||
|
name: &LanguageModelProviderName,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) {
|
||||||
|
if self.providers.remove(name).is_some() {
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn providers(
|
||||||
|
&self,
|
||||||
|
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
|
||||||
|
self.providers.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
|
self.providers
|
||||||
|
.values()
|
||||||
|
.flat_map(|provider| provider.provided_models(cx))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn available_models_grouped_by_provider(
|
||||||
|
&self,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
|
||||||
|
self.providers
|
||||||
|
.iter()
|
||||||
|
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn provider(
|
||||||
|
&self,
|
||||||
|
name: &LanguageModelProviderName,
|
||||||
|
) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||||
|
self.providers.get(name).cloned()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::provider::fake::FakeLanguageModelProvider;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
fn test_register_providers(cx: &mut AppContext) {
|
||||||
|
let registry = cx.new_model(|_| LanguageModelRegistry::default());
|
||||||
|
|
||||||
|
registry.update(cx, |registry, cx| {
|
||||||
|
registry.register_provider(FakeLanguageModelProvider::default(), cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||||
|
assert_eq!(providers.len(), 1);
|
||||||
|
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
|
||||||
|
|
||||||
|
registry.update(cx, |registry, cx| {
|
||||||
|
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||||
|
assert!(providers.is_empty());
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,4 @@
|
||||||
use crate::{
|
use crate::{role::Role, LanguageModelId};
|
||||||
model::{CloudModel, LanguageModel},
|
|
||||||
role::Role,
|
|
||||||
};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
@ -23,16 +20,15 @@ impl LanguageModelRequestMessage {
|
||||||
|
|
||||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||||
pub struct LanguageModelRequest {
|
pub struct LanguageModelRequest {
|
||||||
pub model: LanguageModel,
|
|
||||||
pub messages: Vec<LanguageModelRequestMessage>,
|
pub messages: Vec<LanguageModelRequestMessage>,
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelRequest {
|
impl LanguageModelRequest {
|
||||||
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
|
pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel {
|
||||||
proto::CompleteWithLanguageModel {
|
proto::CompleteWithLanguageModel {
|
||||||
model: self.model.id().to_string(),
|
model: model_id.0.to_string(),
|
||||||
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
|
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
|
||||||
stop: self.stop.clone(),
|
stop: self.stop.clone(),
|
||||||
temperature: self.temperature,
|
temperature: self.temperature,
|
||||||
|
@ -40,70 +36,6 @@ impl LanguageModelRequest {
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
|
|
||||||
pub fn preprocess(&mut self) {
|
|
||||||
match &self.model {
|
|
||||||
LanguageModel::OpenAi(_) => {}
|
|
||||||
LanguageModel::Anthropic(_) => self.preprocess_anthropic(),
|
|
||||||
LanguageModel::Ollama(_) => {}
|
|
||||||
LanguageModel::Cloud(model) => match model {
|
|
||||||
CloudModel::Claude3Opus
|
|
||||||
| CloudModel::Claude3Sonnet
|
|
||||||
| CloudModel::Claude3Haiku
|
|
||||||
| CloudModel::Claude3_5Sonnet => {
|
|
||||||
self.preprocess_anthropic();
|
|
||||||
}
|
|
||||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
|
||||||
self.preprocess_anthropic();
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn preprocess_anthropic(&mut self) {
|
|
||||||
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
|
|
||||||
let mut system_message = String::new();
|
|
||||||
|
|
||||||
for message in self.messages.drain(..) {
|
|
||||||
if message.content.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match message.role {
|
|
||||||
Role::User | Role::Assistant => {
|
|
||||||
if let Some(last_message) = new_messages.last_mut() {
|
|
||||||
if last_message.role == message.role {
|
|
||||||
last_message.content.push_str("\n\n");
|
|
||||||
last_message.content.push_str(&message.content);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
new_messages.push(message);
|
|
||||||
}
|
|
||||||
Role::System => {
|
|
||||||
if !system_message.is_empty() {
|
|
||||||
system_message.push_str("\n\n");
|
|
||||||
}
|
|
||||||
system_message.push_str(&message.content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !system_message.is_empty() {
|
|
||||||
new_messages.insert(
|
|
||||||
0,
|
|
||||||
LanguageModelRequestMessage {
|
|
||||||
role: Role::System,
|
|
||||||
content: system_message,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.messages = new_messages;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
|
143
crates/language_model/src/settings.rs
Normal file
143
crates/language_model/src/settings.rs
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use gpui::AppContext;
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use settings::{Settings, SettingsSources};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
provider::{
|
||||||
|
anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings,
|
||||||
|
open_ai::OpenAiSettings,
|
||||||
|
},
|
||||||
|
CloudModel,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Initializes the language model settings.
|
||||||
|
pub fn init(cx: &mut AppContext) {
|
||||||
|
AllLanguageModelSettings::register(cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct AllLanguageModelSettings {
|
||||||
|
pub open_ai: OpenAiSettings,
|
||||||
|
pub anthropic: AnthropicSettings,
|
||||||
|
pub ollama: OllamaSettings,
|
||||||
|
pub zed_dot_dev: ZedDotDevSettings,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||||
|
pub struct AllLanguageModelSettingsContent {
|
||||||
|
pub anthropic: Option<AnthropicSettingsContent>,
|
||||||
|
pub ollama: Option<OllamaSettingsContent>,
|
||||||
|
pub open_ai: Option<OpenAiSettingsContent>,
|
||||||
|
#[serde(rename = "zed.dev")]
|
||||||
|
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||||
|
pub struct AnthropicSettingsContent {
|
||||||
|
pub api_url: Option<String>,
|
||||||
|
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||||
|
pub available_models: Option<Vec<anthropic::Model>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||||
|
pub struct OllamaSettingsContent {
|
||||||
|
pub api_url: Option<String>,
|
||||||
|
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||||
|
pub struct OpenAiSettingsContent {
|
||||||
|
pub api_url: Option<String>,
|
||||||
|
pub low_speed_timeout_in_seconds: Option<u64>,
|
||||||
|
pub available_models: Option<Vec<open_ai::Model>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||||
|
pub struct ZedDotDevSettingsContent {
|
||||||
|
available_models: Option<Vec<CloudModel>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl settings::Settings for AllLanguageModelSettings {
|
||||||
|
const KEY: Option<&'static str> = Some("language_models");
|
||||||
|
|
||||||
|
type FileContent = AllLanguageModelSettingsContent;
|
||||||
|
|
||||||
|
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||||
|
fn merge<T>(target: &mut T, value: Option<T>) {
|
||||||
|
if let Some(value) = value {
|
||||||
|
*target = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut settings = AllLanguageModelSettings::default();
|
||||||
|
|
||||||
|
for value in sources.defaults_and_customizations() {
|
||||||
|
merge(
|
||||||
|
&mut settings.anthropic.api_url,
|
||||||
|
value.anthropic.as_ref().and_then(|s| s.api_url.clone()),
|
||||||
|
);
|
||||||
|
if let Some(low_speed_timeout_in_seconds) = value
|
||||||
|
.anthropic
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||||
|
{
|
||||||
|
settings.anthropic.low_speed_timeout =
|
||||||
|
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||||
|
}
|
||||||
|
merge(
|
||||||
|
&mut settings.anthropic.available_models,
|
||||||
|
value
|
||||||
|
.anthropic
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| s.available_models.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
merge(
|
||||||
|
&mut settings.ollama.api_url,
|
||||||
|
value.ollama.as_ref().and_then(|s| s.api_url.clone()),
|
||||||
|
);
|
||||||
|
if let Some(low_speed_timeout_in_seconds) = value
|
||||||
|
.ollama
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||||
|
{
|
||||||
|
settings.ollama.low_speed_timeout =
|
||||||
|
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||||
|
}
|
||||||
|
|
||||||
|
merge(
|
||||||
|
&mut settings.open_ai.api_url,
|
||||||
|
value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
|
||||||
|
);
|
||||||
|
if let Some(low_speed_timeout_in_seconds) = value
|
||||||
|
.open_ai
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| s.low_speed_timeout_in_seconds)
|
||||||
|
{
|
||||||
|
settings.open_ai.low_speed_timeout =
|
||||||
|
Some(Duration::from_secs(low_speed_timeout_in_seconds));
|
||||||
|
}
|
||||||
|
merge(
|
||||||
|
&mut settings.open_ai.available_models,
|
||||||
|
value
|
||||||
|
.open_ai
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| s.available_models.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
merge(
|
||||||
|
&mut settings.zed_dot_dev.available_models,
|
||||||
|
value
|
||||||
|
.zed_dot_dev
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| s.available_models.clone()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(settings)
|
||||||
|
}
|
||||||
|
}
|
|
@ -77,14 +77,14 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> &'static str {
|
pub fn id(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||||
Self::Four => "gpt-4",
|
Self::Four => "gpt-4",
|
||||||
Self::FourTurbo => "gpt-4-turbo-preview",
|
Self::FourTurbo => "gpt-4-turbo-preview",
|
||||||
Self::FourOmni => "gpt-4o",
|
Self::FourOmni => "gpt-4o",
|
||||||
Self::FourOmniMini => "gpt-4o-mini",
|
Self::FourOmniMini => "gpt-4o-mini",
|
||||||
Self::Custom { .. } => "custom",
|
Self::Custom { name, .. } => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2785,7 +2785,7 @@ impl Panel for OutlinePanel {
|
||||||
settings::update_settings_file::<OutlinePanelSettings>(
|
settings::update_settings_file::<OutlinePanelSettings>(
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| {
|
move |settings, _| {
|
||||||
let dock = match position {
|
let dock = match position {
|
||||||
DockPosition::Left | DockPosition::Bottom => OutlinePanelDockPosition::Left,
|
DockPosition::Left | DockPosition::Bottom => OutlinePanelDockPosition::Left,
|
||||||
DockPosition::Right => OutlinePanelDockPosition::Right,
|
DockPosition::Right => OutlinePanelDockPosition::Right,
|
||||||
|
|
|
@ -2572,7 +2572,7 @@ impl Panel for ProjectPanel {
|
||||||
settings::update_settings_file::<ProjectPanelSettings>(
|
settings::update_settings_file::<ProjectPanelSettings>(
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
cx,
|
cx,
|
||||||
move |settings| {
|
move |settings, _| {
|
||||||
let dock = match position {
|
let dock = match position {
|
||||||
DockPosition::Left | DockPosition::Bottom => ProjectPanelDockPosition::Left,
|
DockPosition::Left | DockPosition::Bottom => ProjectPanelDockPosition::Left,
|
||||||
DockPosition::Right => ProjectPanelDockPosition::Right,
|
DockPosition::Right => ProjectPanelDockPosition::Right,
|
||||||
|
|
|
@ -27,7 +27,7 @@ pub struct HeadlessProject {
|
||||||
|
|
||||||
impl HeadlessProject {
|
impl HeadlessProject {
|
||||||
pub fn init(cx: &mut AppContext) {
|
pub fn init(cx: &mut AppContext) {
|
||||||
cx.set_global(SettingsStore::default());
|
cx.set_global(SettingsStore::new(cx));
|
||||||
WorktreeSettings::register(cx);
|
WorktreeSettings::register(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1263,4 +1263,4 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
|
// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
|
||||||
type _TODO = completion::CompletionProvider;
|
type _TODO = completion::LanguageModelCompletionProvider;
|
||||||
|
|
|
@ -21,7 +21,7 @@ pub use settings_store::{
|
||||||
pub struct SettingsAssets;
|
pub struct SettingsAssets;
|
||||||
|
|
||||||
pub fn init(cx: &mut AppContext) {
|
pub fn init(cx: &mut AppContext) {
|
||||||
let mut settings = SettingsStore::default();
|
let mut settings = SettingsStore::new(cx);
|
||||||
settings
|
settings
|
||||||
.set_default_settings(&default_settings(), cx)
|
.set_default_settings(&default_settings(), cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
use crate::{settings_store::SettingsStore, Settings};
|
use crate::{settings_store::SettingsStore, Settings};
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{channel::mpsc, StreamExt};
|
use futures::{channel::mpsc, StreamExt};
|
||||||
use gpui::{AppContext, BackgroundExecutor, UpdateGlobal};
|
use gpui::{AppContext, BackgroundExecutor, ReadGlobal, UpdateGlobal};
|
||||||
use std::{io::ErrorKind, path::PathBuf, sync::Arc, time::Duration};
|
use std::{path::PathBuf, sync::Arc, time::Duration};
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
pub const EMPTY_THEME_NAME: &str = "empty-theme";
|
pub const EMPTY_THEME_NAME: &str = "empty-theme";
|
||||||
|
@ -91,46 +90,10 @@ pub fn handle_settings_file_changes(
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_settings(fs: &Arc<dyn Fs>) -> Result<String> {
|
|
||||||
match fs.load(paths::settings_file()).await {
|
|
||||||
result @ Ok(_) => result,
|
|
||||||
Err(err) => {
|
|
||||||
if let Some(e) = err.downcast_ref::<std::io::Error>() {
|
|
||||||
if e.kind() == ErrorKind::NotFound {
|
|
||||||
return Ok(crate::initial_user_settings_content().to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update_settings_file<T: Settings>(
|
pub fn update_settings_file<T: Settings>(
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
cx: &mut AppContext,
|
cx: &AppContext,
|
||||||
update: impl 'static + Send + FnOnce(&mut T::FileContent),
|
update: impl 'static + Send + FnOnce(&mut T::FileContent, &AppContext),
|
||||||
) {
|
) {
|
||||||
cx.spawn(|cx| async move {
|
SettingsStore::global(cx).update_settings_file::<T>(fs, update);
|
||||||
let old_text = load_settings(&fs).await?;
|
|
||||||
let new_text = cx.read_global(|store: &SettingsStore, _cx| {
|
|
||||||
store.new_text_for_update::<T>(old_text, update)
|
|
||||||
})?;
|
|
||||||
let initial_path = paths::settings_file().as_path();
|
|
||||||
if fs.is_file(initial_path).await {
|
|
||||||
let resolved_path = fs.canonicalize(initial_path).await.with_context(|| {
|
|
||||||
format!("Failed to canonicalize settings path {:?}", initial_path)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
fs.atomic_write(resolved_path.clone(), new_text)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("Failed to write settings to file {:?}", resolved_path))?;
|
|
||||||
} else {
|
|
||||||
fs.atomic_write(initial_path.to_path_buf(), new_text)
|
|
||||||
.await
|
|
||||||
.with_context(|| format!("Failed to write settings to file {:?}", initial_path))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
})
|
|
||||||
.detach_and_log_err(cx);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use collections::{btree_map, hash_map, BTreeMap, HashMap};
|
use collections::{btree_map, hash_map, BTreeMap, HashMap};
|
||||||
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Global, UpdateGlobal};
|
use fs::Fs;
|
||||||
|
use futures::{channel::mpsc, future::LocalBoxFuture, FutureExt, StreamExt};
|
||||||
|
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Global, Task, UpdateGlobal};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use schemars::{gen::SchemaGenerator, schema::RootSchema, JsonSchema};
|
use schemars::{gen::SchemaGenerator, schema::RootSchema, JsonSchema};
|
||||||
use serde::{de::DeserializeOwned, Deserialize as _, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize as _, Serialize};
|
||||||
|
@ -161,23 +163,14 @@ pub struct SettingsStore {
|
||||||
TypeId,
|
TypeId,
|
||||||
Box<dyn Fn(&dyn Any) -> Option<usize> + Send + Sync + 'static>,
|
Box<dyn Fn(&dyn Any) -> Option<usize> + Send + Sync + 'static>,
|
||||||
)>,
|
)>,
|
||||||
|
_setting_file_updates: Task<()>,
|
||||||
|
setting_file_updates_tx: mpsc::UnboundedSender<
|
||||||
|
Box<dyn FnOnce(AsyncAppContext) -> LocalBoxFuture<'static, Result<()>>>,
|
||||||
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Global for SettingsStore {}
|
impl Global for SettingsStore {}
|
||||||
|
|
||||||
impl Default for SettingsStore {
|
|
||||||
fn default() -> Self {
|
|
||||||
SettingsStore {
|
|
||||||
setting_values: Default::default(),
|
|
||||||
raw_default_settings: serde_json::json!({}),
|
|
||||||
raw_user_settings: serde_json::json!({}),
|
|
||||||
raw_extension_settings: serde_json::json!({}),
|
|
||||||
raw_local_settings: Default::default(),
|
|
||||||
tab_size_callback: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct SettingValue<T> {
|
struct SettingValue<T> {
|
||||||
global_value: Option<T>,
|
global_value: Option<T>,
|
||||||
|
@ -207,6 +200,24 @@ trait AnySettingValue: 'static + Send + Sync {
|
||||||
struct DeserializedSetting(Box<dyn Any>);
|
struct DeserializedSetting(Box<dyn Any>);
|
||||||
|
|
||||||
impl SettingsStore {
|
impl SettingsStore {
|
||||||
|
pub fn new(cx: &AppContext) -> Self {
|
||||||
|
let (setting_file_updates_tx, mut setting_file_updates_rx) = mpsc::unbounded();
|
||||||
|
Self {
|
||||||
|
setting_values: Default::default(),
|
||||||
|
raw_default_settings: serde_json::json!({}),
|
||||||
|
raw_user_settings: serde_json::json!({}),
|
||||||
|
raw_extension_settings: serde_json::json!({}),
|
||||||
|
raw_local_settings: Default::default(),
|
||||||
|
tab_size_callback: Default::default(),
|
||||||
|
setting_file_updates_tx,
|
||||||
|
_setting_file_updates: cx.spawn(|cx| async move {
|
||||||
|
while let Some(setting_file_update) = setting_file_updates_rx.next().await {
|
||||||
|
(setting_file_update)(cx.clone()).await.log_err();
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R
|
pub fn update<C, R>(cx: &mut C, f: impl FnOnce(&mut Self, &mut C) -> R) -> R
|
||||||
where
|
where
|
||||||
C: BorrowAppContext,
|
C: BorrowAppContext,
|
||||||
|
@ -301,7 +312,7 @@ impl SettingsStore {
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub fn test(cx: &mut AppContext) -> Self {
|
pub fn test(cx: &mut AppContext) -> Self {
|
||||||
let mut this = Self::default();
|
let mut this = Self::new(cx);
|
||||||
this.set_default_settings(&crate::test_settings(), cx)
|
this.set_default_settings(&crate::test_settings(), cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
this.set_user_settings("{}", cx).unwrap();
|
this.set_user_settings("{}", cx).unwrap();
|
||||||
|
@ -323,6 +334,59 @@ impl SettingsStore {
|
||||||
self.set_user_settings(&new_text, cx).unwrap();
|
self.set_user_settings(&new_text, cx).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn load_settings(fs: &Arc<dyn Fs>) -> Result<String> {
|
||||||
|
match fs.load(paths::settings_file()).await {
|
||||||
|
result @ Ok(_) => result,
|
||||||
|
Err(err) => {
|
||||||
|
if let Some(e) = err.downcast_ref::<std::io::Error>() {
|
||||||
|
if e.kind() == std::io::ErrorKind::NotFound {
|
||||||
|
return Ok(crate::initial_user_settings_content().to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_settings_file<T: Settings>(
|
||||||
|
&self,
|
||||||
|
fs: Arc<dyn Fs>,
|
||||||
|
update: impl 'static + Send + FnOnce(&mut T::FileContent, &AppContext),
|
||||||
|
) {
|
||||||
|
self.setting_file_updates_tx
|
||||||
|
.unbounded_send(Box::new(move |cx: AsyncAppContext| {
|
||||||
|
async move {
|
||||||
|
let old_text = Self::load_settings(&fs).await?;
|
||||||
|
let new_text = cx.read_global(|store: &SettingsStore, cx| {
|
||||||
|
store.new_text_for_update::<T>(old_text, |content| update(content, cx))
|
||||||
|
})?;
|
||||||
|
let initial_path = paths::settings_file().as_path();
|
||||||
|
if fs.is_file(initial_path).await {
|
||||||
|
let resolved_path =
|
||||||
|
fs.canonicalize(initial_path).await.with_context(|| {
|
||||||
|
format!("Failed to canonicalize settings path {:?}", initial_path)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
fs.atomic_write(resolved_path.clone(), new_text)
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
|
format!("Failed to write settings to file {:?}", resolved_path)
|
||||||
|
})?;
|
||||||
|
} else {
|
||||||
|
fs.atomic_write(initial_path.to_path_buf(), new_text)
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
|
format!("Failed to write settings to file {:?}", initial_path)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
.boxed_local()
|
||||||
|
}))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
/// Updates the value of a setting in a JSON file, returning the new text
|
/// Updates the value of a setting in a JSON file, returning the new text
|
||||||
/// for that JSON file.
|
/// for that JSON file.
|
||||||
pub fn new_text_for_update<T: Settings>(
|
pub fn new_text_for_update<T: Settings>(
|
||||||
|
@ -1019,7 +1083,7 @@ mod tests {
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_settings_store_basic(cx: &mut AppContext) {
|
fn test_settings_store_basic(cx: &mut AppContext) {
|
||||||
let mut store = SettingsStore::default();
|
let mut store = SettingsStore::new(cx);
|
||||||
store.register_setting::<UserSettings>(cx);
|
store.register_setting::<UserSettings>(cx);
|
||||||
store.register_setting::<TurboSetting>(cx);
|
store.register_setting::<TurboSetting>(cx);
|
||||||
store.register_setting::<MultiKeySettings>(cx);
|
store.register_setting::<MultiKeySettings>(cx);
|
||||||
|
@ -1148,7 +1212,7 @@ mod tests {
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_setting_store_assign_json_before_register(cx: &mut AppContext) {
|
fn test_setting_store_assign_json_before_register(cx: &mut AppContext) {
|
||||||
let mut store = SettingsStore::default();
|
let mut store = SettingsStore::new(cx);
|
||||||
store
|
store
|
||||||
.set_default_settings(
|
.set_default_settings(
|
||||||
r#"{
|
r#"{
|
||||||
|
@ -1191,7 +1255,7 @@ mod tests {
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
fn test_setting_store_update(cx: &mut AppContext) {
|
fn test_setting_store_update(cx: &mut AppContext) {
|
||||||
let mut store = SettingsStore::default();
|
let mut store = SettingsStore::new(cx);
|
||||||
store.register_setting::<MultiKeySettings>(cx);
|
store.register_setting::<MultiKeySettings>(cx);
|
||||||
store.register_setting::<UserSettings>(cx);
|
store.register_setting::<UserSettings>(cx);
|
||||||
store.register_setting::<LanguageSettings>(cx);
|
store.register_setting::<LanguageSettings>(cx);
|
||||||
|
|
|
@ -760,14 +760,18 @@ impl Panel for TerminalPanel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
|
||||||
settings::update_settings_file::<TerminalSettings>(self.fs.clone(), cx, move |settings| {
|
settings::update_settings_file::<TerminalSettings>(
|
||||||
|
self.fs.clone(),
|
||||||
|
cx,
|
||||||
|
move |settings, _| {
|
||||||
let dock = match position {
|
let dock = match position {
|
||||||
DockPosition::Left => TerminalDockPosition::Left,
|
DockPosition::Left => TerminalDockPosition::Left,
|
||||||
DockPosition::Bottom => TerminalDockPosition::Bottom,
|
DockPosition::Bottom => TerminalDockPosition::Bottom,
|
||||||
DockPosition::Right => TerminalDockPosition::Right,
|
DockPosition::Right => TerminalDockPosition::Right,
|
||||||
};
|
};
|
||||||
settings.dock = Some(dock);
|
settings.dock = Some(dock);
|
||||||
});
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size(&self, cx: &WindowContext) -> Pixels {
|
fn size(&self, cx: &WindowContext) -> Pixels {
|
||||||
|
|
|
@ -196,7 +196,7 @@ impl PickerDelegate for ThemeSelectorDelegate {
|
||||||
|
|
||||||
let appearance = Appearance::from(cx.appearance());
|
let appearance = Appearance::from(cx.appearance());
|
||||||
|
|
||||||
update_settings_file::<ThemeSettings>(self.fs.clone(), cx, move |settings| {
|
update_settings_file::<ThemeSettings>(self.fs.clone(), cx, move |settings, _| {
|
||||||
if let Some(selection) = settings.theme.as_mut() {
|
if let Some(selection) = settings.theme.as_mut() {
|
||||||
let theme_to_update = match selection {
|
let theme_to_update = match selection {
|
||||||
ThemeSelection::Static(theme) => theme,
|
ThemeSelection::Static(theme) => theme,
|
||||||
|
|
|
@ -147,7 +147,7 @@ fn register(workspace: &mut Workspace, cx: &mut ViewContext<Workspace>) {
|
||||||
workspace.register_action(|workspace: &mut Workspace, _: &ToggleVimMode, cx| {
|
workspace.register_action(|workspace: &mut Workspace, _: &ToggleVimMode, cx| {
|
||||||
let fs = workspace.app_state().fs.clone();
|
let fs = workspace.app_state().fs.clone();
|
||||||
let currently_enabled = VimModeSetting::get_global(cx).0;
|
let currently_enabled = VimModeSetting::get_global(cx).0;
|
||||||
update_settings_file::<VimModeSetting>(fs, cx, move |setting| {
|
update_settings_file::<VimModeSetting>(fs, cx, move |setting, _| {
|
||||||
*setting = Some(!currently_enabled)
|
*setting = Some(!currently_enabled)
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
|
@ -176,7 +176,7 @@ impl PickerDelegate for BaseKeymapSelectorDelegate {
|
||||||
self.telemetry
|
self.telemetry
|
||||||
.report_setting_event("keymap", base_keymap.to_string());
|
.report_setting_event("keymap", base_keymap.to_string());
|
||||||
|
|
||||||
update_settings_file::<BaseKeymap>(self.fs.clone(), cx, move |setting| {
|
update_settings_file::<BaseKeymap>(self.fs.clone(), cx, move |setting, _| {
|
||||||
*setting = Some(base_keymap)
|
*setting = Some(base_keymap)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -279,7 +279,7 @@ impl WelcomePage {
|
||||||
if let Some(workspace) = self.workspace.upgrade() {
|
if let Some(workspace) = self.workspace.upgrade() {
|
||||||
let fs = workspace.read(cx).app_state().fs.clone();
|
let fs = workspace.read(cx).app_state().fs.clone();
|
||||||
let selection = *selection;
|
let selection = *selection;
|
||||||
settings::update_settings_file::<T>(fs, cx, move |settings| {
|
settings::update_settings_file::<T>(fs, cx, move |settings, _| {
|
||||||
let value = match selection {
|
let value = match selection {
|
||||||
Selection::Unselected => false,
|
Selection::Unselected => false,
|
||||||
Selection::Selected => true,
|
Selection::Selected => true,
|
||||||
|
|
|
@ -56,6 +56,7 @@ install_cli.workspace = true
|
||||||
isahc.workspace = true
|
isahc.workspace = true
|
||||||
journal.workspace = true
|
journal.workspace = true
|
||||||
language.workspace = true
|
language.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
language_selector.workspace = true
|
language_selector.workspace = true
|
||||||
language_tools.workspace = true
|
language_tools.workspace = true
|
||||||
languages.workspace = true
|
languages.workspace = true
|
||||||
|
|
|
@ -164,6 +164,7 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) {
|
||||||
SystemAppearance::init(cx);
|
SystemAppearance::init(cx);
|
||||||
theme::init(theme::LoadThemes::All(Box::new(Assets)), cx);
|
theme::init(theme::LoadThemes::All(Box::new(Assets)), cx);
|
||||||
command_palette::init(cx);
|
command_palette::init(cx);
|
||||||
|
language_model::init(app_state.client.clone(), cx);
|
||||||
snippet_provider::init(cx);
|
snippet_provider::init(cx);
|
||||||
supermaven::init(app_state.client.clone(), cx);
|
supermaven::init(app_state.client.clone(), cx);
|
||||||
inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
|
inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
|
||||||
|
|
|
@ -3436,6 +3436,7 @@ mod tests {
|
||||||
project_panel::init((), cx);
|
project_panel::init((), cx);
|
||||||
outline_panel::init((), cx);
|
outline_panel::init((), cx);
|
||||||
terminal_view::init(cx);
|
terminal_view::init(cx);
|
||||||
|
language_model::init(app_state.client.clone(), cx);
|
||||||
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
|
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
|
||||||
repl::init(app_state.fs.clone(), cx);
|
repl::init(app_state.fs.clone(), cx);
|
||||||
tasks_ui::init(cx);
|
tasks_ui::init(cx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue