Support multiple OpenAI compatible providers (#34212)
TODO - [x] OpenAI Compatible API Icon - [x] Docs - [x] Link to docs in OpenAI provider section about configuring OpenAI API compatible providers Closes #33992 Related to #30010 Release Notes: - agent: Add support for adding multiple OpenAI API compatible providers --------- Co-authored-by: MrSubidubi <dev@bahn.sh> Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
parent
1a76a6b0bf
commit
230061a6cb
23 changed files with 1450 additions and 191 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -231,6 +231,7 @@ dependencies = [
|
|||
"jsonschema",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"languages",
|
||||
"log",
|
||||
"lsp",
|
||||
|
@ -269,6 +270,7 @@ dependencies = [
|
|||
"time_format",
|
||||
"tree-sitter-md",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"unindent",
|
||||
"urlencoding",
|
||||
"util",
|
||||
|
@ -9097,11 +9099,11 @@ dependencies = [
|
|||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
"convert_case 0.8.0",
|
||||
"copilot",
|
||||
"credentials_provider",
|
||||
"deepseek",
|
||||
"editor",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
"gpui",
|
||||
|
|
4
assets/icons/ai_open_ai_compat.svg
Normal file
4
assets/icons/ai_open_ai_compat.svg
Normal file
|
@ -0,0 +1,4 @@
|
|||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M7.25669 0.999943C8.27509 0.993825 9.24655 1.42125 9.9227 2.17279C11.4427 1.85079 12.9991 2.53518 13.7733 3.86518C14.159 4.5149 14.3171 5.26409 14.2372 5.99994H13.2967C13.3789 5.42185 13.265 4.8321 12.9686 4.32514C12.2353 3.06961 10.6088 2.63919 9.33676 3.36322L6.48032 4.98822C6.46926 4.99697 6.46284 5.01135 6.46372 5.02533V6.38568L9.91294 4.42084C10.0565 4.33818 10.2336 4.33823 10.3768 4.42084L13.1502 5.99994H11.2948L9.88364 5.19623C9.87034 5.19054 9.85459 5.19128 9.84262 5.19916L8.64926 5.87983L8.8602 5.99994H7.99985C6.89539 6.00004 5.99988 6.89547 5.99985 7.99994V9.34955L3.90219 8.15522C3.75815 8.07431 3.66897 7.92228 3.66977 7.75873V4.53803C3.66977 4.50828 3.67172 4.4654 3.67172 4.44135C3.08836 4.65262 2.59832 5.0599 2.28794 5.59174C1.55635 6.84647 1.99122 8.44936 3.26059 9.17475L5.99985 10.7363V11.6162C5.87564 11.6568 5.73827 11.6456 5.6229 11.579L2.7977 9.96869C2.77156 9.95382 2.73449 9.9311 2.71372 9.91889C2.60687 10.5231 2.7194 11.1466 3.0311 11.6777C3.6435 12.7209 4.87159 13.1902 5.99985 12.9023V13.8398C4.50443 14.1233 2.98758 13.4424 2.22641 12.1347C1.71174 11.2677 1.60096 10.2237 1.9227 9.27045C0.880739 8.13295 0.703328 6.46023 1.48325 5.13373C1.98739 4.26024 2.84863 3.64401 3.84653 3.44233C4.3245 1.9837 5.70306 0.996447 7.25669 0.999943ZM7.25766 1.91498C5.78932 1.9143 4.59839 3.08914 4.59751 4.53803V7.79193C4.59926 7.80578 4.60735 7.81796 4.61997 7.82416L5.8143 8.50483L5.81626 4.57611C5.81537 4.41216 5.90431 4.2606 6.04868 4.17963L8.87387 2.56928C8.89868 2.55441 8.93612 2.53379 8.95786 2.5224C8.48035 2.13046 7.8788 1.91498 7.25766 1.91498Z" fill="black"/>
|
||||
<path d="M13.5 6C14.6046 6 15.5 6.89543 15.5 8V13.5C15.5 14.6046 14.6046 15.5 13.5 15.5H8C6.89543 15.5 6 14.6046 6 13.5V8C6 6.89543 6.89543 6 8 6H13.5ZM10.8916 8.02539C10.0563 8.02539 9.33453 8.27982 8.81934 8.76562C8.30213 9.25335 8.02547 9.94371 8.02539 10.748C8.02539 11.557 8.29852 12.2492 8.81543 12.7373C9.33013 13.2232 10.0521 13.4746 10.8916 13.4746C11.9865 13.4745 12.8545 13.1022 13.3076 12.3525C13.3894 12.2176 13.4521 12.0693 13.4521 11.8857C13.4521 11.4795 13.0933 11.2773 12.7842 11.2773C12.6604 11.2774 12.5292 11.3025 12.4072 11.3779C12.2862 11.4529 12.2058 11.5586 12.1494 11.666L12.1475 11.6689C11.9677 12.0213 11.5535 12.246 10.8955 12.2461C10.4219 12.2461 10.0667 12.0932 9.83008 11.8506C9.59255 11.607 9.44141 11.2389 9.44141 10.748C9.44148 10.264 9.59319 9.89628 9.83203 9.65137C10.0702 9.40725 10.4255 9.25391 10.8916 9.25391C11.4912 9.25399 11.9415 9.50614 12.1289 9.8916V9.89062C12.1888 10.0157 12.276 10.1311 12.4023 10.2129C12.5303 10.2956 12.6724 10.3271 12.8115 10.3271C12.9661 10.3271 13.1303 10.2857 13.2627 10.1758C13.4018 10.0603 13.4746 9.89383 13.4746 9.71582C13.4746 9.61857 13.4542 9.52036 13.4199 9.42773L13.3818 9.33691C12.9749 8.49175 11.9927 8.02548 10.8916 8.02539ZM10.3203 8.97852L10.1494 9.03516C10.2095 9.01178 10.2716 8.99089 10.3359 8.97363C10.3307 8.97505 10.3256 8.97706 10.3203 8.97852ZM10.4814 8.94141C10.4969 8.9385 10.5126 8.93616 10.5283 8.93359C10.5126 8.93617 10.4969 8.9385 10.4814 8.94141ZM10.6709 8.91504C10.6819 8.91399 10.693 8.913 10.7041 8.91211C10.693 8.913 10.6819 8.91399 10.6709 8.91504Z" fill="black" fill-opacity="0.95"/>
|
||||
</svg>
|
After Width: | Height: | Size: 3.2 KiB |
|
@ -1712,6 +1712,7 @@
|
|||
"openai": {
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
},
|
||||
"openai_compatible": {},
|
||||
"open_router": {
|
||||
"api_url": "https://openrouter.ai/api/v1"
|
||||
},
|
||||
|
|
|
@ -5490,7 +5490,7 @@ fn main() {{
|
|||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
||||
let provider = Arc::new(FakeLanguageModelProvider);
|
||||
let provider = Arc::new(FakeLanguageModelProvider::default());
|
||||
let model = provider.test_model();
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(model);
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ itertools.workspace = true
|
|||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
markdown.workspace = true
|
||||
|
@ -87,6 +88,7 @@ theme.workspace = true
|
|||
time.workspace = true
|
||||
time_format.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
urlencoding.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
|
|
@ -3895,7 +3895,7 @@ mod tests {
|
|||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: Arc::new(FakeLanguageModelProvider),
|
||||
provider: Arc::new(FakeLanguageModelProvider::default()),
|
||||
model,
|
||||
}),
|
||||
cx,
|
||||
|
@ -3979,7 +3979,7 @@ mod tests {
|
|||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: Arc::new(FakeLanguageModelProvider),
|
||||
provider: Arc::new(FakeLanguageModelProvider::default()),
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
mod add_llm_provider_modal;
|
||||
mod configure_context_server_modal;
|
||||
mod manage_profiles_modal;
|
||||
mod tool_picker;
|
||||
|
@ -37,7 +38,10 @@ use zed_actions::ExtensionCategoryFilter;
|
|||
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
|
||||
pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::AddContextServer;
|
||||
use crate::{
|
||||
AddContextServer,
|
||||
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
|
||||
};
|
||||
|
||||
pub struct AgentConfiguration {
|
||||
fs: Arc<dyn Fs>,
|
||||
|
@ -304,16 +308,55 @@ impl AgentConfiguration {
|
|||
|
||||
v_flex()
|
||||
.child(
|
||||
v_flex()
|
||||
h_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.pb_0()
|
||||
.mb_2p5()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
.items_start()
|
||||
.justify_between()
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers"))
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
PopoverMenu::new("add-provider-popover")
|
||||
.trigger(
|
||||
Button::new("add-provider", "Add Provider")
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Plus)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small),
|
||||
)
|
||||
.anchor(gpui::Corner::TopRight)
|
||||
.menu({
|
||||
let workspace = self.workspace.clone();
|
||||
move |window, cx| {
|
||||
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
|
||||
menu.header("Compatible APIs").entry("OpenAI", None, {
|
||||
let workspace = workspace.clone();
|
||||
move |window, cx| {
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
AddLlmProviderModal::toggle(
|
||||
LlmCompatibleProvider::OpenAi,
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
}))
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
|
|
|
@ -0,0 +1,639 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_models::{
|
||||
AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
|
||||
provider::open_ai_compatible::AvailableModel,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum LlmCompatibleProvider {
|
||||
OpenAi,
|
||||
}
|
||||
|
||||
impl LlmCompatibleProvider {
|
||||
fn name(&self) -> &'static str {
|
||||
match self {
|
||||
LlmCompatibleProvider::OpenAi => "OpenAI",
|
||||
}
|
||||
}
|
||||
|
||||
fn api_url(&self) -> &'static str {
|
||||
match self {
|
||||
LlmCompatibleProvider::OpenAi => "https://api.openai.com/v1",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AddLlmProviderInput {
|
||||
provider_name: Entity<SingleLineInput>,
|
||||
api_url: Entity<SingleLineInput>,
|
||||
api_key: Entity<SingleLineInput>,
|
||||
models: Vec<ModelInput>,
|
||||
}
|
||||
|
||||
impl AddLlmProviderInput {
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
|
||||
let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
|
||||
let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
|
||||
let api_key = single_line_input(
|
||||
"API Key",
|
||||
"000000000000000000000000000000000000000000000000",
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
Self {
|
||||
provider_name,
|
||||
api_url,
|
||||
api_key,
|
||||
models: vec![ModelInput::new(window, cx)],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_model(&mut self, window: &mut Window, cx: &mut App) {
|
||||
self.models.push(ModelInput::new(window, cx));
|
||||
}
|
||||
|
||||
fn remove_model(&mut self, index: usize) {
|
||||
self.models.remove(index);
|
||||
}
|
||||
}
|
||||
|
||||
struct ModelInput {
|
||||
name: Entity<SingleLineInput>,
|
||||
max_completion_tokens: Entity<SingleLineInput>,
|
||||
max_output_tokens: Entity<SingleLineInput>,
|
||||
max_tokens: Entity<SingleLineInput>,
|
||||
}
|
||||
|
||||
impl ModelInput {
|
||||
fn new(window: &mut Window, cx: &mut App) -> Self {
|
||||
let model_name = single_line_input(
|
||||
"Model Name",
|
||||
"e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_completion_tokens = single_line_input(
|
||||
"Max Completion Tokens",
|
||||
"200000",
|
||||
Some("200000"),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_output_tokens = single_line_input(
|
||||
"Max Output Tokens",
|
||||
"Max Output Tokens",
|
||||
Some("32000"),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
|
||||
Self {
|
||||
name: model_name,
|
||||
max_completion_tokens,
|
||||
max_output_tokens,
|
||||
max_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
|
||||
let name = self.name.read(cx).text(cx);
|
||||
if name.is_empty() {
|
||||
return Err(SharedString::from("Model Name cannot be empty"));
|
||||
}
|
||||
Ok(AvailableModel {
|
||||
name,
|
||||
display_name: None,
|
||||
max_completion_tokens: Some(
|
||||
self.max_completion_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
|
||||
),
|
||||
max_output_tokens: Some(
|
||||
self.max_output_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
|
||||
),
|
||||
max_tokens: self
|
||||
.max_tokens
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.parse::<u64>()
|
||||
.map_err(|_| SharedString::from("Max Tokens must be a number"))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn single_line_input(
|
||||
label: impl Into<SharedString>,
|
||||
placeholder: impl Into<SharedString>,
|
||||
text: Option<&str>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<SingleLineInput> {
|
||||
cx.new(|cx| {
|
||||
let input = SingleLineInput::new(window, cx, placeholder).label(label);
|
||||
if let Some(text) = text {
|
||||
input
|
||||
.editor()
|
||||
.update(cx, |editor, cx| editor.set_text(text, window, cx));
|
||||
}
|
||||
input
|
||||
})
|
||||
}
|
||||
|
||||
fn save_provider_to_settings(
|
||||
input: &AddLlmProviderInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(), SharedString>> {
|
||||
let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
|
||||
if provider_name.is_empty() {
|
||||
return Task::ready(Err("Provider Name cannot be empty".into()));
|
||||
}
|
||||
|
||||
if LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.any(|provider| {
|
||||
provider.id().0.as_ref() == provider_name.as_ref()
|
||||
|| provider.name().0.as_ref() == provider_name.as_ref()
|
||||
})
|
||||
{
|
||||
return Task::ready(Err(
|
||||
"Provider Name is already taken by another provider".into()
|
||||
));
|
||||
}
|
||||
|
||||
let api_url = input.api_url.read(cx).text(cx);
|
||||
if api_url.is_empty() {
|
||||
return Task::ready(Err("API URL cannot be empty".into()));
|
||||
}
|
||||
|
||||
let api_key = input.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return Task::ready(Err("API Key cannot be empty".into()));
|
||||
}
|
||||
|
||||
let mut models = Vec::new();
|
||||
let mut model_names: HashSet<String> = HashSet::default();
|
||||
for model in &input.models {
|
||||
match model.parse(cx) {
|
||||
Ok(model) => {
|
||||
if !model_names.insert(model.name.clone()) {
|
||||
return Task::ready(Err("Model Names must be unique".into()));
|
||||
}
|
||||
models.push(model)
|
||||
}
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
}
|
||||
}
|
||||
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(async move |cx| {
|
||||
task.await
|
||||
.map_err(|_| "Failed to write API key to keychain")?;
|
||||
cx.update(|cx| {
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
|
||||
settings.openai_compatible.get_or_insert_default().insert(
|
||||
provider_name,
|
||||
OpenAiCompatibleSettingsContent {
|
||||
api_url,
|
||||
available_models: models,
|
||||
},
|
||||
);
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub struct AddLlmProviderModal {
|
||||
provider: LlmCompatibleProvider,
|
||||
input: AddLlmProviderInput,
|
||||
focus_handle: FocusHandle,
|
||||
last_error: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl AddLlmProviderModal {
|
||||
pub fn toggle(
|
||||
provider: LlmCompatibleProvider,
|
||||
workspace: &mut Workspace,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
|
||||
}
|
||||
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
input: AddLlmProviderInput::new(provider, window, cx),
|
||||
provider,
|
||||
last_error: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
}
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let task = save_provider_to_settings(&self.input, cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| match result {
|
||||
Ok(_) => {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
Err(error) => {
|
||||
this.last_error = Some(error);
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_section(&self) -> Section {
|
||||
Section::new()
|
||||
.child(self.input.provider_name.clone())
|
||||
.child(self.input.api_url.clone())
|
||||
.child(self.input.api_key.clone())
|
||||
}
|
||||
|
||||
fn render_model_section(&self, cx: &mut Context<Self>) -> Section {
|
||||
Section::new().child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
.child(Label::new("Models").size(LabelSize::Small))
|
||||
.child(
|
||||
Button::new("add-model", "Add Model")
|
||||
.icon(IconName::Plus)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.input.add_model(window, cx);
|
||||
cx.notify();
|
||||
})),
|
||||
),
|
||||
)
|
||||
.children(
|
||||
self.input
|
||||
.models
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, _)| self.render_model(ix, cx)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
|
||||
let has_more_than_one_model = self.input.models.len() > 1;
|
||||
let model = &self.input.models[ix];
|
||||
|
||||
v_flex()
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.rounded_sm()
|
||||
.border_1()
|
||||
.border_dashed()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.bg(cx.theme().colors().element_active.opacity(0.15))
|
||||
.child(model.name.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(model.max_completion_tokens.clone())
|
||||
.child(model.max_output_tokens.clone()),
|
||||
)
|
||||
.child(model.max_tokens.clone())
|
||||
.when(has_more_than_one_model, |this| {
|
||||
this.child(
|
||||
Button::new(("remove-model", ix), "Remove Model")
|
||||
.icon(IconName::Trash)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.label_size(LabelSize::Small)
|
||||
.style(ButtonStyle::Outlined)
|
||||
.full_width()
|
||||
.on_click(cx.listener(move |this, _, _window, cx| {
|
||||
this.input.remove_model(ix);
|
||||
cx.notify();
|
||||
})),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
|
||||
|
||||
impl Focusable for AddLlmProviderModal {
|
||||
fn focus_handle(&self, _cx: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl ModalView for AddLlmProviderModal {}
|
||||
|
||||
impl Render for AddLlmProviderModal {
|
||||
fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
|
||||
let focus_handle = self.focus_handle(cx);
|
||||
|
||||
div()
|
||||
.id("add-llm-provider-modal")
|
||||
.key_context("AddLlmProviderModal")
|
||||
.w(rems(34.))
|
||||
.elevation_3(cx)
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
|
||||
this.focus_handle(cx).focus(window);
|
||||
}))
|
||||
.child(
|
||||
Modal::new("configure-context-server", None)
|
||||
.header(ModalHeader::new().headline("Add LLM Provider").description(
|
||||
match self.provider {
|
||||
LlmCompatibleProvider::OpenAi => {
|
||||
"This provider will use an OpenAI compatible API."
|
||||
}
|
||||
},
|
||||
))
|
||||
.when_some(self.last_error.clone(), |this, error| {
|
||||
this.section(
|
||||
Section::new().child(
|
||||
Banner::new()
|
||||
.severity(ui::Severity::Warning)
|
||||
.child(div().text_xs().child(error)),
|
||||
),
|
||||
)
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.id("modal_content")
|
||||
.max_h_128()
|
||||
.overflow_y_scroll()
|
||||
.gap_2()
|
||||
.child(self.render_section())
|
||||
.child(self.render_model_section(cx)),
|
||||
)
|
||||
.footer(
|
||||
ModalFooter::new().end_slot(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Button::new("cancel", "Cancel")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
this.cancel(&menu::Cancel, window, cx)
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("save-server", "Save Provider")
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
this.confirm(&menu::Confirm, window, cx)
|
||||
})),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::{TestAppContext, VisualTestContext};
|
||||
use language::language_settings;
|
||||
use language_model::{
|
||||
LanguageModelProviderId, LanguageModelProviderName,
|
||||
fake_provider::FakeLanguageModelProvider,
|
||||
};
|
||||
use project::Project;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
|
||||
Some("Provider Name cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
|
||||
Some("API URL cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
|
||||
Some("API Key cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("", "200000", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Model Name cannot be empty".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "abc", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "200000", "abc", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Completion Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![("somemodel", "200000", "200000", "abc")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Max Output Tokens must be a number".into())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"somekey",
|
||||
vec![
|
||||
("somemodel", "200000", "200000", "32000"),
|
||||
("somemodel", "200000", "200000", "32000"),
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Model Names must be unique".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_save_provider_name_conflict(cx: &mut TestAppContext) {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
cx.update(|_window, cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
FakeLanguageModelProvider::new(
|
||||
LanguageModelProviderId::new("someprovider"),
|
||||
LanguageModelProviderName::new("Some Provider"),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
save_provider_validation_errors(
|
||||
"someprovider",
|
||||
"someurl",
|
||||
"someapikey",
|
||||
vec![("somemodel", "200000", "200000", "32000")],
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
Some("Provider Name is already taken by another provider".into())
|
||||
);
|
||||
}
|
||||
|
||||
async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
|
||||
cx.update(|cx| {
|
||||
let store = SettingsStore::test(cx);
|
||||
cx.set_global(store);
|
||||
workspace::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
language_settings::init(cx);
|
||||
EditorSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
language_models::init_settings(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
|
||||
let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
|
||||
let (_, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
cx
|
||||
}
|
||||
|
||||
async fn save_provider_validation_errors(
|
||||
provider_name: &str,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
models: Vec<(&str, &str, &str, &str)>,
|
||||
cx: &mut VisualTestContext,
|
||||
) -> Option<SharedString> {
|
||||
fn set_text(
|
||||
input: &Entity<SingleLineInput>,
|
||||
text: &str,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
input.update(cx, |input, cx| {
|
||||
input.editor().update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
let task = cx.update(|window, cx| {
|
||||
let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
|
||||
set_text(&input.provider_name, provider_name, window, cx);
|
||||
set_text(&input.api_url, api_url, window, cx);
|
||||
set_text(&input.api_key, api_key, window, cx);
|
||||
|
||||
for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
|
||||
models.iter().enumerate()
|
||||
{
|
||||
if i >= input.models.len() {
|
||||
input.models.push(ModelInput::new(window, cx));
|
||||
}
|
||||
let model = &mut input.models[i];
|
||||
set_text(&model.name, name, window, cx);
|
||||
set_text(&model.max_tokens, max_tokens, window, cx);
|
||||
set_text(
|
||||
&model.max_completion_tokens,
|
||||
max_completion_tokens,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
set_text(&model.max_output_tokens, max_output_tokens, window, cx);
|
||||
}
|
||||
save_provider_to_settings(&input, cx)
|
||||
});
|
||||
|
||||
task.await.err()
|
||||
}
|
||||
}
|
|
@ -1323,7 +1323,7 @@ fn setup_context_editor_with_fake_model(
|
|||
) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
|
||||
|
||||
let fake_provider = Arc::new(FakeLanguageModelProvider);
|
||||
let fake_provider = Arc::new(FakeLanguageModelProvider::default());
|
||||
let fake_model = Arc::new(fake_provider.test_model());
|
||||
|
||||
cx.update(|cx| {
|
||||
|
|
|
@ -200,7 +200,7 @@ mod tests {
|
|||
|
||||
// Run the tool before any changes
|
||||
let tool = Arc::new(ProjectNotificationsTool);
|
||||
let provider = Arc::new(FakeLanguageModelProvider);
|
||||
let provider = Arc::new(FakeLanguageModelProvider::default());
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
|
||||
let request = Arc::new(LanguageModelRequest::default());
|
||||
let tool_input = json!({});
|
||||
|
|
|
@ -20,6 +20,7 @@ pub enum IconName {
|
|||
AiMistral,
|
||||
AiOllama,
|
||||
AiOpenAi,
|
||||
AiOpenAiCompat,
|
||||
AiOpenRouter,
|
||||
AiVZero,
|
||||
AiXAi,
|
||||
|
|
|
@ -10,25 +10,21 @@ use http_client::Result;
|
|||
use parking_lot::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn language_model_id() -> LanguageModelId {
|
||||
LanguageModelId::from("fake".to_string())
|
||||
#[derive(Clone)]
|
||||
pub struct FakeLanguageModelProvider {
|
||||
id: LanguageModelProviderId,
|
||||
name: LanguageModelProviderName,
|
||||
}
|
||||
|
||||
pub fn language_model_name() -> LanguageModelName {
|
||||
LanguageModelName::from("Fake".to_string())
|
||||
impl Default for FakeLanguageModelProvider {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
id: LanguageModelProviderId::from("fake".to_string()),
|
||||
name: LanguageModelProviderName::from("Fake".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provider_id() -> LanguageModelProviderId {
|
||||
LanguageModelProviderId::from("fake".to_string())
|
||||
}
|
||||
|
||||
pub fn provider_name() -> LanguageModelProviderName {
|
||||
LanguageModelProviderName::from("Fake".to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct FakeLanguageModelProvider;
|
||||
|
||||
impl LanguageModelProviderState for FakeLanguageModelProvider {
|
||||
type ObservableEntity = ();
|
||||
|
||||
|
@ -39,11 +35,11 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
|
|||
|
||||
impl LanguageModelProvider for FakeLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
provider_id()
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
provider_name()
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
|
@ -76,6 +72,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||
}
|
||||
|
||||
impl FakeLanguageModelProvider {
|
||||
pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self {
|
||||
Self { id, name }
|
||||
}
|
||||
|
||||
pub fn test_model(&self) -> FakeLanguageModel {
|
||||
FakeLanguageModel::default()
|
||||
}
|
||||
|
@ -89,11 +89,22 @@ pub struct ToolUseRequest {
|
|||
pub schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct FakeLanguageModel {
|
||||
provider_id: LanguageModelProviderId,
|
||||
provider_name: LanguageModelProviderName,
|
||||
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
|
||||
}
|
||||
|
||||
impl Default for FakeLanguageModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider_id: LanguageModelProviderId::from("fake".to_string()),
|
||||
provider_name: LanguageModelProviderName::from("Fake".to_string()),
|
||||
current_completion_txs: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeLanguageModel {
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
|
@ -138,19 +149,19 @@ impl FakeLanguageModel {
|
|||
|
||||
impl LanguageModel for FakeLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
language_model_id()
|
||||
LanguageModelId::from("fake".to_string())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
language_model_name()
|
||||
LanguageModelName::from("Fake".to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
provider_id()
|
||||
self.provider_id.clone()
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
provider_name()
|
||||
self.provider_name.clone()
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
|
|
|
@ -735,6 +735,18 @@ impl From<String> for LanguageModelProviderName {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Arc<str>> for LanguageModelProviderId {
|
||||
fn from(value: Arc<str>) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Arc<str>> for LanguageModelProviderName {
|
||||
fn from(value: Arc<str>) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -125,7 +125,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
|
||||
let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
|
||||
let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
|
||||
let registry = cx.new(|cx| {
|
||||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
|
@ -403,16 +403,17 @@ mod tests {
|
|||
fn test_register_providers(cx: &mut App) {
|
||||
let registry = cx.new(|_| LanguageModelRegistry::default());
|
||||
|
||||
let provider = FakeLanguageModelProvider::default();
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.register_provider(FakeLanguageModelProvider, cx);
|
||||
registry.register_provider(provider.clone(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers();
|
||||
assert_eq!(providers.len(), 1);
|
||||
assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
|
||||
assert_eq!(providers[0].id(), provider.id());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.unregister_provider(crate::fake_provider::provider_id(), cx);
|
||||
registry.unregister_provider(provider.id(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers();
|
||||
|
|
|
@ -26,10 +26,10 @@ client.workspace = true
|
|||
collections.workspace = true
|
||||
component.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
convert_case.workspace = true
|
||||
copilot.workspace = true
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
editor.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
google_ai = { workspace = true, features = ["schemars"] }
|
||||
gpui.workspace = true
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use ::settings::{Settings, SettingsStore};
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashSet;
|
||||
use gpui::{App, Context, Entity};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
|
||||
use provider::deepseek::DeepSeekLanguageModelProvider;
|
||||
|
||||
pub mod provider;
|
||||
|
@ -18,17 +20,81 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider;
|
|||
use crate::provider::mistral::MistralLanguageModelProvider;
|
||||
use crate::provider::ollama::OllamaLanguageModelProvider;
|
||||
use crate::provider::open_ai::OpenAiLanguageModelProvider;
|
||||
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
|
||||
use crate::provider::open_router::OpenRouterLanguageModelProvider;
|
||||
use crate::provider::vercel::VercelLanguageModelProvider;
|
||||
use crate::provider::x_ai::XAiLanguageModelProvider;
|
||||
pub use crate::settings::*;
|
||||
|
||||
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
||||
crate::settings::init(cx);
|
||||
crate::settings::init_settings(cx);
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_language_model_providers(registry, user_store, client, cx);
|
||||
register_language_model_providers(registry, user_store, client.clone(), cx);
|
||||
});
|
||||
|
||||
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
|
||||
.openai_compatible
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_openai_compatible_providers(
|
||||
registry,
|
||||
&HashSet::default(),
|
||||
&openai_compatible_providers,
|
||||
client.clone(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
|
||||
.openai_compatible
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect::<HashSet<_>>();
|
||||
if openai_compatible_providers_new != openai_compatible_providers {
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_openai_compatible_providers(
|
||||
registry,
|
||||
&openai_compatible_providers,
|
||||
&openai_compatible_providers_new,
|
||||
client.clone(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
openai_compatible_providers = openai_compatible_providers_new;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn register_openai_compatible_providers(
|
||||
registry: &mut LanguageModelRegistry,
|
||||
old: &HashSet<Arc<str>>,
|
||||
new: &HashSet<Arc<str>>,
|
||||
client: Arc<Client>,
|
||||
cx: &mut Context<LanguageModelRegistry>,
|
||||
) {
|
||||
for provider_id in old {
|
||||
if !new.contains(provider_id) {
|
||||
registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
|
||||
}
|
||||
}
|
||||
|
||||
for provider_id in new {
|
||||
if !old.contains(provider_id) {
|
||||
registry.register_provider(
|
||||
OpenAiCompatibleLanguageModelProvider::new(
|
||||
provider_id.clone(),
|
||||
client.http_client(),
|
||||
cx,
|
||||
),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_language_model_providers(
|
||||
|
|
|
@ -8,6 +8,7 @@ pub mod lmstudio;
|
|||
pub mod mistral;
|
||||
pub mod ollama;
|
||||
pub mod open_ai;
|
||||
pub mod open_ai_compatible;
|
||||
pub mod open_router;
|
||||
pub mod vercel;
|
||||
pub mod x_ai;
|
||||
|
|
|
@ -2,7 +2,6 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
|
||||
use fs::Fs;
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
|
@ -18,7 +17,7 @@ use menu;
|
|||
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore, update_settings_file};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
|
@ -28,7 +27,6 @@ use ui::{ElevationIndex, List, Tooltip, prelude::*};
|
|||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::OpenAiSettingsContent;
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
|
||||
|
@ -621,26 +619,32 @@ struct RawToolCall {
|
|||
arguments: String,
|
||||
}
|
||||
|
||||
pub(crate) fn collect_tiktoken_messages(
|
||||
request: LanguageModelRequest,
|
||||
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
|
||||
request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
model: Model,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
cx.background_spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let messages = collect_tiktoken_messages(request);
|
||||
|
||||
match model {
|
||||
Model::Custom { max_tokens, .. } => {
|
||||
|
@ -678,7 +682,6 @@ pub fn count_open_ai_tokens(
|
|||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<SingleLineInput>,
|
||||
api_url_editor: Entity<SingleLineInput>,
|
||||
state: gpui::Entity<State>,
|
||||
load_credentials_task: Option<Task<()>>,
|
||||
}
|
||||
|
@ -691,23 +694,6 @@ impl ConfigurationView {
|
|||
cx,
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
.label("API key")
|
||||
});
|
||||
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
let api_url_editor = cx.new(|cx| {
|
||||
let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");
|
||||
|
||||
if !api_url.is_empty() {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text(&*api_url, window, cx);
|
||||
});
|
||||
}
|
||||
input
|
||||
});
|
||||
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
|
@ -735,7 +721,6 @@ impl ConfigurationView {
|
|||
|
||||
Self {
|
||||
api_key_editor,
|
||||
api_url_editor,
|
||||
state,
|
||||
load_credentials_task,
|
||||
}
|
||||
|
@ -783,57 +768,6 @@ impl ConfigurationView {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
fn save_api_url(&mut self, cx: &mut Context<Self>) {
|
||||
let api_url = self
|
||||
.api_url_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
let current_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
let effective_current_url = if current_url.is_empty() {
|
||||
open_ai::OPEN_AI_API_URL
|
||||
} else {
|
||||
¤t_url
|
||||
};
|
||||
|
||||
if !api_url.is_empty() && api_url != effective_current_url {
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
|
||||
if let Some(settings) = settings.openai.as_mut() {
|
||||
settings.api_url = Some(api_url.clone());
|
||||
} else {
|
||||
settings.openai = Some(OpenAiSettingsContent {
|
||||
api_url: Some(api_url.clone()),
|
||||
available_models: None,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_url_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
|
||||
if let Some(settings) = settings.openai.as_mut() {
|
||||
settings.api_url = None;
|
||||
}
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
!self.state.read(cx).is_authenticated()
|
||||
}
|
||||
|
@ -846,7 +780,6 @@ impl Render for ConfigurationView {
|
|||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
|
||||
.child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
|
||||
.child(
|
||||
List::new()
|
||||
|
@ -910,59 +843,34 @@ impl Render for ConfigurationView {
|
|||
.into_any()
|
||||
};
|
||||
|
||||
let custom_api_url_set =
|
||||
AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;
|
||||
|
||||
let api_url_section = if custom_api_url_set {
|
||||
h_flex()
|
||||
.mt_1()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new("Custom API URL configured.")),
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-url", "Reset API URL")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.on_click(
|
||||
cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
v_flex()
|
||||
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
|
||||
this.save_api_url(cx);
|
||||
cx.notify();
|
||||
}))
|
||||
.mt_2()
|
||||
.pt_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.gap_1()
|
||||
.child(
|
||||
List::new()
|
||||
.child(InstructionListItem::text_only(
|
||||
"Optionally, you can change the base URL for the OpenAI API request.",
|
||||
))
|
||||
.child(InstructionListItem::text_only(
|
||||
"Paste the new API endpoint below and hit enter",
|
||||
)),
|
||||
)
|
||||
.child(self.api_url_editor.clone())
|
||||
.into_any()
|
||||
};
|
||||
let compatible_api_section = h_flex()
|
||||
.mt_1p5()
|
||||
.gap_0p5()
|
||||
.flex_wrap()
|
||||
.when(self.should_render_editor(cx), |this| {
|
||||
this.pt_1p5()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Icon::new(IconName::Info)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new("Zed also supports OpenAI-compatible models.")),
|
||||
)
|
||||
.child(
|
||||
Button::new("docs", "Learn More")
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _window, cx| {
|
||||
cx.open_url("https://zed.dev/docs/ai/configuration#openai-api-compatible")
|
||||
}),
|
||||
);
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials…")).into_any()
|
||||
|
@ -970,7 +878,7 @@ impl Render for ConfigurationView {
|
|||
v_flex()
|
||||
.size_full()
|
||||
.child(api_key_section)
|
||||
.child(api_url_section)
|
||||
.child(compatible_api_section)
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
|
522
crates/language_models/src/provider/open_ai_compatible.rs
Normal file
522
crates/language_models/src/provider/open_ai_compatible.rs
Normal file
|
@ -0,0 +1,522 @@
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
|
||||
use convert_case::{Case, Casing};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, RateLimiter,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::{ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
|
||||
use ui::{ElevationIndex, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct OpenAiCompatibleSettings {
|
||||
pub api_url: String,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub display_name: Option<String>,
|
||||
pub max_tokens: u64,
|
||||
pub max_output_tokens: Option<u64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
pub struct OpenAiCompatibleLanguageModelProvider {
|
||||
id: LanguageModelProviderId,
|
||||
name: LanguageModelProviderName,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
state: gpui::Entity<State>,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
id: Arc<str>,
|
||||
env_var_name: Arc<str>,
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
settings: OpenAiCompatibleSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = self.settings.api_url.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, &cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = self.settings.api_url.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let env_var_name = self.env_var_name.clone();
|
||||
let api_url = self.settings.api_url.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, &cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleLanguageModelProvider {
|
||||
pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
|
||||
AllLanguageModelSettings::get_global(cx)
|
||||
.openai_compatible
|
||||
.get(id)
|
||||
}
|
||||
|
||||
let state = cx.new(|cx| State {
|
||||
id: id.clone(),
|
||||
env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
|
||||
settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let Some(settings) = resolve_settings(&this.id, cx) else {
|
||||
return;
|
||||
};
|
||||
if &this.settings != settings {
|
||||
this.settings = settings.clone();
|
||||
cx.notify();
|
||||
}
|
||||
}),
|
||||
});
|
||||
|
||||
Self {
|
||||
id: id.clone().into(),
|
||||
name: id.into(),
|
||||
http_client,
|
||||
state,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(OpenAiCompatibleLanguageModel {
|
||||
id: LanguageModelId::from(model.name.clone()),
|
||||
provider_id: self.id.clone(),
|
||||
provider_name: self.name.clone(),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::AiOpenAiCompat
|
||||
}
|
||||
|
||||
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.state
|
||||
.read(cx)
|
||||
.settings
|
||||
.available_models
|
||||
.first()
|
||||
.map(|model| self.create_language_model(model.clone()))
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
self.state
|
||||
.read(cx)
|
||||
.settings
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|model| self.create_language_model(model.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &App) -> bool {
|
||||
self.state.read(cx).is_authenticated()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenAiCompatibleLanguageModel {
|
||||
id: LanguageModelId,
|
||||
provider_id: LanguageModelProviderId,
|
||||
provider_name: LanguageModelProviderName,
|
||||
model: AvailableModel,
|
||||
state: gpui::Entity<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleLanguageModel {
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: open_ai::Request,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
|
||||
(state.api_key.clone(), state.settings.api_url.clone())
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let provider = self.provider_name.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
||||
};
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAiCompatibleLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(
|
||||
self.model
|
||||
.display_name
|
||||
.clone()
|
||||
.unwrap_or_else(|| self.model.name.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
self.provider_id.clone()
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
self.provider_name.clone()
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto => true,
|
||||
LanguageModelToolChoice::Any => true,
|
||||
LanguageModelToolChoice::None => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("openai/{}", self.model.name)
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
self.model.max_tokens
|
||||
}
|
||||
|
||||
fn max_output_tokens(&self) -> Option<u64> {
|
||||
self.model.max_output_tokens
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
let max_token_count = self.max_token_count();
|
||||
cx.background_spawn(async move {
|
||||
let messages = super::open_ai::collect_tiktoken_messages(request);
|
||||
let model = if max_token_count >= 100_000 {
|
||||
// If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
|
||||
"gpt-4o"
|
||||
} else {
|
||||
// Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
|
||||
// supported with this tiktoken method
|
||||
"gpt-4"
|
||||
};
|
||||
tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = OpenAiEventMapper::new();
|
||||
Ok(mapper.map_stream(completions.await?).boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<SingleLineInput>,
|
||||
state: gpui::Entity<State>,
|
||||
load_credentials_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let api_key_editor = cx.new(|cx| {
|
||||
SingleLineInput::new(
|
||||
window,
|
||||
cx,
|
||||
"000000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
});
|
||||
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let load_credentials_task = Some(cx.spawn_in(window, {
|
||||
let state = state.clone();
|
||||
async move |this, cx| {
|
||||
if let Some(task) = state
|
||||
.update(cx, |state, cx| state.authenticate(cx))
|
||||
.log_err()
|
||||
{
|
||||
// We don't log an error, because "not signed in" is also an error.
|
||||
let _ = task.await;
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.load_credentials_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}));
|
||||
|
||||
Self {
|
||||
api_key_editor,
|
||||
state,
|
||||
load_credentials_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
return;
|
||||
}
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
!self.state.read(cx).is_authenticated()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_name = self.state.read(cx).env_var_name.clone();
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new("To use Zed's assistant with an OpenAI compatible provider, you need to add an API key."))
|
||||
.child(
|
||||
div()
|
||||
.pt(DynamicSpacing::Base04.rems(cx))
|
||||
.child(self.api_key_editor.clone())
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {env_var_name} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small).color(Color::Muted),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
h_flex()
|
||||
.mt_1()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {env_var_name} environment variable.")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-key", "Reset API Key")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
.into_any()
|
||||
};
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials…")).into_any()
|
||||
} else {
|
||||
v_flex().size_full().child(api_key_section).into_any()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::App;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -15,13 +18,14 @@ use crate::provider::{
|
|||
mistral::MistralSettings,
|
||||
ollama::OllamaSettings,
|
||||
open_ai::OpenAiSettings,
|
||||
open_ai_compatible::OpenAiCompatibleSettings,
|
||||
open_router::OpenRouterSettings,
|
||||
vercel::VercelSettings,
|
||||
x_ai::XAiSettings,
|
||||
};
|
||||
|
||||
/// Initializes the language model settings.
|
||||
pub fn init(cx: &mut App) {
|
||||
pub fn init_settings(cx: &mut App) {
|
||||
AllLanguageModelSettings::register(cx);
|
||||
}
|
||||
|
||||
|
@ -36,6 +40,7 @@ pub struct AllLanguageModelSettings {
|
|||
pub ollama: OllamaSettings,
|
||||
pub open_router: OpenRouterSettings,
|
||||
pub openai: OpenAiSettings,
|
||||
pub openai_compatible: HashMap<Arc<str>, OpenAiCompatibleSettings>,
|
||||
pub vercel: VercelSettings,
|
||||
pub x_ai: XAiSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
|
@ -52,6 +57,7 @@ pub struct AllLanguageModelSettingsContent {
|
|||
pub ollama: Option<OllamaSettingsContent>,
|
||||
pub open_router: Option<OpenRouterSettingsContent>,
|
||||
pub openai: Option<OpenAiSettingsContent>,
|
||||
pub openai_compatible: Option<HashMap<Arc<str>, OpenAiCompatibleSettingsContent>>,
|
||||
pub vercel: Option<VercelSettingsContent>,
|
||||
pub x_ai: Option<XAiSettingsContent>,
|
||||
#[serde(rename = "zed.dev")]
|
||||
|
@ -103,6 +109,12 @@ pub struct OpenAiSettingsContent {
|
|||
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct OpenAiCompatibleSettingsContent {
|
||||
pub api_url: String,
|
||||
pub available_models: Vec<provider::open_ai_compatible::AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct VercelSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
|
@ -226,6 +238,19 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
openai.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
// OpenAI Compatible
|
||||
if let Some(openai_compatible) = value.openai_compatible.clone() {
|
||||
for (id, openai_compatible_settings) in openai_compatible {
|
||||
settings.openai_compatible.insert(
|
||||
id,
|
||||
OpenAiCompatibleSettings {
|
||||
api_url: openai_compatible_settings.api_url,
|
||||
available_models: openai_compatible_settings.available_models,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Vercel
|
||||
let vercel = value.vercel.clone();
|
||||
merge(
|
||||
|
|
|
@ -93,6 +93,7 @@ impl RenderOnce for Modal {
|
|||
#[derive(IntoElement)]
|
||||
pub struct ModalHeader {
|
||||
headline: Option<SharedString>,
|
||||
description: Option<SharedString>,
|
||||
children: SmallVec<[AnyElement; 2]>,
|
||||
show_dismiss_button: bool,
|
||||
show_back_button: bool,
|
||||
|
@ -108,6 +109,7 @@ impl ModalHeader {
|
|||
pub fn new() -> Self {
|
||||
Self {
|
||||
headline: None,
|
||||
description: None,
|
||||
children: SmallVec::new(),
|
||||
show_dismiss_button: false,
|
||||
show_back_button: false,
|
||||
|
@ -123,6 +125,11 @@ impl ModalHeader {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn description(mut self, description: impl Into<SharedString>) -> Self {
|
||||
self.description = Some(description.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn show_dismiss_button(mut self, show: bool) -> Self {
|
||||
self.show_dismiss_button = show;
|
||||
self
|
||||
|
@ -171,7 +178,14 @@ impl RenderOnce for ModalHeader {
|
|||
}),
|
||||
)
|
||||
})
|
||||
.child(div().flex_1().children(children))
|
||||
.child(
|
||||
v_flex().flex_1().children(children).when_some(
|
||||
self.description,
|
||||
|this, description| {
|
||||
this.child(Label::new(description).color(Color::Muted).mb_2())
|
||||
},
|
||||
),
|
||||
)
|
||||
.when(self.show_dismiss_button, |this| {
|
||||
this.child(
|
||||
IconButton::new("dismiss", IconName::Close)
|
||||
|
|
|
@ -97,6 +97,10 @@ impl SingleLineInput {
|
|||
pub fn editor(&self) -> &Entity<Editor> {
|
||||
&self.editor
|
||||
}
|
||||
|
||||
pub fn text(&self, cx: &App) -> String {
|
||||
self.editor().read(cx).text(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for SingleLineInput {
|
||||
|
|
|
@ -444,14 +444,17 @@ Custom models will be listed in the model dropdown in the Agent Panel.
|
|||
|
||||
### OpenAI API Compatible {#openai-api-compatible}
|
||||
|
||||
Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider.
|
||||
Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
|
||||
|
||||
Zed supports using OpenAI compatible APIs by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
|
||||
To configure a compatible API, you can add a custom API URL for OpenAI either via the UI (currently available only in Preview) or by editing your `settings.json`.
|
||||
|
||||
To configure a compatible API, you can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. For example, to connect to [Together AI](https://www.together.ai/):
|
||||
For example, to connect to [Together AI](https://www.together.ai/) via the UI:
|
||||
|
||||
1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys).
|
||||
2. Add the following to your `settings.json`:
|
||||
1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys).
|
||||
2. Go to the Agent Panel's settings view, click on the "Add Provider" button, and then on the "OpenAI" menu item
|
||||
3. Add the requested fields, such as `api_url`, `api_key`, available models, and others
|
||||
|
||||
Alternatively, you can also add it via the `settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue