Support multiple OpenAI compatible providers (#34212)
TODO - [x] OpenAI Compatible API Icon - [x] Docs - [x] Link to docs in OpenAI provider section about configuring OpenAI API compatible providers Closes #33992 Related to #30010 Release Notes: - agent: Add support for adding multiple OpenAI API compatible providers --------- Co-authored-by: MrSubidubi <dev@bahn.sh> Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
parent
1a76a6b0bf
commit
230061a6cb
23 changed files with 1450 additions and 191 deletions
|
@ -53,6 +53,7 @@ itertools.workspace = true
|
|||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
markdown.workspace = true
|
||||
|
@ -87,6 +88,7 @@ theme.workspace = true
|
|||
time.workspace = true
|
||||
time_format.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
urlencoding.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
|
|
@ -3895,7 +3895,7 @@ mod tests {
|
|||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: Arc::new(FakeLanguageModelProvider),
|
||||
provider: Arc::new(FakeLanguageModelProvider::default()),
|
||||
model,
|
||||
}),
|
||||
cx,
|
||||
|
@ -3979,7 +3979,7 @@ mod tests {
|
|||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: Arc::new(FakeLanguageModelProvider),
|
||||
provider: Arc::new(FakeLanguageModelProvider::default()),
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
mod add_llm_provider_modal;
|
||||
mod configure_context_server_modal;
|
||||
mod manage_profiles_modal;
|
||||
mod tool_picker;
|
||||
|
@ -37,7 +38,10 @@ use zed_actions::ExtensionCategoryFilter;
|
|||
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
|
||||
pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::AddContextServer;
|
||||
use crate::{
|
||||
AddContextServer,
|
||||
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
|
||||
};
|
||||
|
||||
pub struct AgentConfiguration {
|
||||
fs: Arc<dyn Fs>,
|
||||
|
@ -304,16 +308,55 @@ impl AgentConfiguration {
|
|||
|
||||
v_flex()
|
||||
.child(
|
||||
v_flex()
|
||||
h_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.pb_0()
|
||||
.mb_2p5()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
.items_start()
|
||||
.justify_between()
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
PopoverMenu::new("add-provider-popover")
|
||||
.trigger(
|
||||
Button::new("add-provider", "Add Provider")
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Plus)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small),
|
||||
)
|
||||
.anchor(gpui::Corner::TopRight)
|
||||
.menu({
|
||||
let workspace = self.workspace.clone();
|
||||
move |window, cx| {
|
||||
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
|
||||
menu.header("Compatible APIs").entry("OpenAI", None, {
|
||||
let workspace = workspace.clone();
|
||||
move |window, cx| {
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
AddLlmProviderModal::toggle(
|
||||
LlmCompatibleProvider::OpenAi,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
}))
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
|
|
|
@ -0,0 +1,639 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_models::{
|
||||
AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
|
||||
provider::open_ai_compatible::AvailableModel,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum LlmCompatibleProvider {
|
||||
OpenAi,
|
||||
}
|
||||
|
||||
impl LlmCompatibleProvider {
|
||||
fn name(&self) -> &'static str {
|
||||
match self {
|
||||
LlmCompatibleProvider::OpenAi => "OpenAI",
|
||||
}
|
||||
}
|
||||
|
||||
fn api_url(&self) -> &'static str {
|
||||
match self {
|
||||
LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AddLlmProviderInput {
|
||||
provider_name: Entity<SingleLineInput>,
|
||||
api_url: Entity<SingleLineInput>,
|
||||
api_key: Entity<SingleLineInput>,
|
||||
models: Vec<ModelInput>,
|
||||
}
|
||||
|
||||
impl AddLlmProviderInput {
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
|
||||
let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
|
||||
let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
|
||||
let api_key = single_line_input(
|
||||
"API Key",
|
||||
"000000000000000000000000000000000000000000000000",
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
Self {
|
||||
provider_name,
|
||||
api_url,
|
||||
api_key,
|
||||
models: vec![ModelInput::new(window, cx)],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_model(&mut self, window: &mut Window, cx: &mut App) {
|
||||
self.models.push(ModelInput::new(window, cx));
|
||||
}
|
||||
|
||||
fn remove_model(&mut self, index: usize) {
|
||||
self.models.remove(index);
|
||||
}
|
||||
}
|
||||
|
||||
struct ModelInput {
|
||||
name: Entity<SingleLineInput>,
|
||||
max_completion_tokens: Entity<SingleLineInput>,
|
||||
max_output_tokens: Entity<SingleLineInput>,
|
||||
max_tokens: Entity<SingleLineInput>,
|
||||
}
|
||||
|
||||
impl ModelInput {
|
||||
fn new(window: &mut Window, cx: &mut App) -> Self {
|
||||
let model_name = single_line_input(
|
||||
"Model Name",
|
||||
"e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_completion_tokens = single_line_input(
|
||||
"Max Completion Tokens",
|
||||
"200000",
|
||||
Some("200000"),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_output_tokens = single_line_input(
|
||||
"Max Output Tokens",
|
||||
"Max Output Tokens",
|
||||
Some("32000"),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
|
||||
Self {
|
||||
name: model_name,
|
||||
max_completion_tokens,
|
||||
max_output_tokens,
|
||||
max_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
|
||||
let name = self.name.read(cx).text(cx);
|
||||
if name.is_empty() {
|
||||
return Err(SharedString::from("Model Name cannot be empty"));
|
||||
}
|
||||
Ok(AvailableModel {
|
||||
name,
|
||||
display_name: None,
|
||||
max_completion_tokens: Some(
|
||||
self.max_completion_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
|
||||
),
|
||||
max_output_tokens: Some(
|
||||
self.max_output_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
|
||||
),
|
||||
max_tokens: self
|
||||
.max_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Tokens must be a number"))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn single_line_input(
|
||||
label: impl Into<SharedString>,
|
||||
placeholder: impl Into<SharedString>,
|
||||
text: Option<&str>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<SingleLineInput> {
|
||||
cx.new(|cx| {
|
||||
let input = SingleLineInput::new(window, cx, placeholder).label(label);
|
||||
if let Some(text) = text {
|
||||
input
|
||||
.editor()
|
||||
.update(cx, |editor, cx| editor.set_text(text, window, cx));
|
||||
}
|
||||
input
|
||||
})
|
||||
}
|
||||
|
||||
fn save_provider_to_settings(
|
||||
input: &AddLlmProviderInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), SharedString>> {
|
||||
let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
|
||||
if provider_name.is_empty() {
|
||||
return Task::ready(Err("Provider Name cannot be empty".into()));
|
||||
}
|
||||
|
||||
if LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.any(|provider| {
|
||||
provider.id().0.as_ref() == provider_name.as_ref()
|
||||
|| provider.name().0.as_ref() == provider_name.as_ref()
|
||||
})
|
||||
{
|
||||
return Task::ready(Err(
|
||||
"Provider Name is already taken by another provider".into()
|
||||
));
|
||||
}
|
||||
|
||||
let api_url = input.api_url.read(cx).text(cx);
|
||||
if api_url.is_empty() {
|
||||
return Task::ready(Err("API URL cannot be empty".into()));
|
||||
}
|
||||
|
||||
let api_key = input.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return Task::ready(Err("API Key cannot be empty".into()));
|
||||
}
|
||||
|
||||
let mut models = Vec::new();
|
||||
let mut model_names: HashSet<String> = HashSet::default();
|
||||
for model in &input.models {
|
||||
match model.parse(cx) {
|
||||
Ok(model) => {
|
||||
if !model_names.insert(model.name.clone()) {
|
||||
return Task::ready(Err("Model Names must be unique".into()));
|
||||
}
|
||||
models.push(model)
|
||||
}
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
}
|
||||
}
|
||||
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(async move |cx| {
|
||||
task.await
|
||||
.map_err(|_| "Failed to write API key to keychain")?;
|
||||
cx.update(|cx| {
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
|
||||
settings.openai_compatible.get_or_insert_default().insert(
|
||||
provider_name,
|
||||
OpenAiCompatibleSettingsContent {
|
||||
api_url,
|
||||
available_models: models,
|
||||
},
|
||||
);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub struct AddLlmProviderModal {
|
||||
provider: LlmCompatibleProvider,
|
||||
input: AddLlmProviderInput,
|
||||
focus_handle: FocusHandle,
|
||||
last_error: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl AddLlmProviderModal {
|
||||
pub fn toggle(
|
||||
provider: LlmCompatibleProvider,
|
||||
workspace: &mut Workspace,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
|
||||
}
|
||||
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
input: AddLlmProviderInput::new(provider, window, cx),
|
||||
provider,
|
||||
last_error: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
}
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let task = save_provider_to_settings(&self.input, cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| match result {
|
||||
Ok(_) => {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
Err(error) => {
|
||||
this.last_error = Some(error);
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_section(&self) -> Section {
|
||||
Section::new()
|
||||
.child(self.input.provider_name.clone())
|
||||
.child(self.input.api_url.clone())
|
||||
.child(self.input.api_key.clone())
|
||||
}
|
||||
|
||||
fn render_model_section(&self, cx: &mut Context<Self>) -> Section {
|
||||
Section::new().child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
.child(Label::new("Models").size(LabelSize::Small))
|
||||
.child(
|
||||
Button::new("add-model", "Add Model")
|
||||
.icon(IconName::Plus)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.input.add_model(window, cx);
|
||||
cx.notify();
|
||||
})),
|
||||
),
|
||||
)
|
||||
.children(
|
||||
self.input
|
||||
.models
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, _)| self.render_model(ix, cx)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
|
||||
let has_more_than_one_model = self.input.models.len() > 1;
|
||||
let model = &self.input.models[ix];
|
||||
|
||||
v_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.rounded_sm()
|
||||
.border_1()
|
||||
.border_dashed()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.bg(cx.theme().colors().element_active.opacity(0.15))
|
||||
.child(model.name.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(model.max_completion_tokens.clone())
|
||||
.child(model.max_output_tokens.clone()),
|
||||
)
|
||||
.child(model.max_tokens.clone())
|
||||
.when(has_more_than_one_model, |this| {
|
||||
this.child(
|
||||
Button::new(("remove-model", ix), "Remove Model")
|
||||
.icon(IconName::Trash)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small)
|
||||
.style(ButtonStyle::Outlined)
|
||||
.full_width()
|
||||
.on_click(cx.listener(move |this, _, _window, cx| {
|
||||
this.input.remove_model(ix);
|
||||
cx.notify();
|
||||
})),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
|
||||
|
||||
impl Focusable for AddLlmProviderModal {
|
||||
fn focus_handle(&self, _cx: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModalView for AddLlmProviderModal {}
|
||||
|
||||
impl Render for AddLlmProviderModal {
|
||||
fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
|
||||
let focus_handle = self.focus_handle(cx);
|
||||
|
||||
div()
|
||||
.id("add-llm-provider-modal")
|
||||
.key_context("AddLlmProviderModal")
|
||||
.w(rems(34.))
|
||||
.elevation_3(cx)
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
|
||||
this.focus_handle(cx).focus(window);
|
||||
}))
|
||||
.child(
|
||||
Modal::new("configure-context-server", None)
|
||||
.header(ModalHeader::new().headline("Add LLM Provider").description(
|
||||
match self.provider {
|
||||
LlmCompatibleProvider::OpenAi => {
|
||||
"This provider will use an OpenAI compatible API."
|
||||
}
|
||||
},
|
||||
))
|
||||
.when_some(self.last_error.clone(), |this, error| {
|
||||
this.section(
|
||||
Section::new().child(
|
||||
Banner::new()
|
||||
.severity(ui::Severity::Warning)
|
||||
.child(div().text_xs().child(error)),
|
||||
),
|
||||
)
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.id("modal_content")
|
||||
.max_h_128()
|
||||
.overflow_y_scroll()
|
||||
.gap_2()
|
||||
.child(self.render_section())
|
||||
.child(self.render_model_section(cx)),
|
||||
)
|
||||
.footer(
|
||||
ModalFooter::new().end_slot(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("cancel", "Cancel")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
this.cancel(&menu::Cancel, window, cx)
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("save-server", "Save Provider")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
this.confirm(&menu::Confirm, window, cx)
|
||||
})),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::{TestAppContext, VisualTestContext};
|
||||
use language::language_settings;
|
||||
use language_model::{
|
||||
LanguageModelProviderId, LanguageModelProviderName,
|
||||
fake_provider::FakeLanguageModelProvider,
|
||||
};
|
||||
use project::Project;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
|
||||
Some("Provider Name cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
|
||||
Some("API URL cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
|
||||
Some("API Key cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("", "200000", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Model Name cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "abc", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "200000", "abc", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Completion Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "200000", "200000", "abc")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Output Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![
|
||||
("somemodel", "200000", "200000", "32000"),
|
||||
("somemodel", "200000", "200000", "32000"),
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Model Names must be unique".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
cx.update(|_window, cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
FakeLanguageModelProvider::new(
|
||||
LanguageModelProviderId::new("someprovider"),
|
||||
LanguageModelProviderName::new("Some Provider"),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"someapikey",
|
||||
vec![("somemodel", "200000", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Provider Name is already taken by another provider".into())
|
||||
);
|
||||
}
|
||||
|
||||
async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
|
||||
cx.update(|cx| {
|
||||
let store = SettingsStore::test(cx);
|
||||
cx.set_global(store);
|
||||
workspace::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
language_settings::init(cx);
|
||||
EditorSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
language_models::init_settings(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
|
||||
let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
|
||||
let (_, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
cx
|
||||
}
|
||||
|
||||
async fn save_provider_validation_errors(
|
||||
provider_name: &str,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
models: Vec<(&str, &str, &str, &str)>,
|
||||
cx: &mut VisualTestContext,
|
||||
) -> Option<SharedString> {
|
||||
fn set_text(
|
||||
input: &Entity<SingleLineInput>,
|
||||
text: &str,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
input.update(cx, |input, cx| {
|
||||
input.editor().update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
let task = cx.update(|window, cx| {
|
||||
let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
|
||||
set_text(&input.provider_name, provider_name, window, cx);
|
||||
set_text(&input.api_url, api_url, window, cx);
|
||||
set_text(&input.api_key, api_key, window, cx);
|
||||
|
||||
for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
|
||||
models.iter().enumerate()
|
||||
{
|
||||
if i >= input.models.len() {
|
||||
input.models.push(ModelInput::new(window, cx));
|
||||
}
|
||||
let model = &mut input.models[i];
|
||||
set_text(&model.name, name, window, cx);
|
||||
set_text(&model.max_tokens, max_tokens, window, cx);
|
||||
set_text(
|
||||
&model.max_completion_tokens,
|
||||
max_completion_tokens,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
set_text(&model.max_output_tokens, max_output_tokens, window, cx);
|
||||
}
|
||||
save_provider_to_settings(&input, cx)
|
||||
});
|
||||
|
||||
task.await.err()
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue