Add capabilities to OpenAI-compatible model settings (#36370)
### TL;DR

* Adds `capabilities` configuration for OpenAI-compatible models
* Relates to https://github.com/zed-industries/zed/issues/36215#issuecomment-3193920491

### Summary

This PR introduces support for configuring model capabilities for OpenAI-compatible language models. It addresses the fact that not all OpenAI-compatible APIs support the same features; for example, Cerebras' API explicitly does not support `parallel_tool_calls`, as documented in their [OpenAI compatibility guide](https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features).

### Changes

1. **Model Capabilities Structure**
   - Added a `ModelCapabilityToggles` struct for the UI representation, with boolean toggle states
   - Implemented parsing of the capability toggles into `ModelCapabilities`
2. **UI Updates**
   - Modified the "Add LLM Provider" modal to include a checkbox for each capability
   - Each OpenAI-compatible model can now be configured with its specific capabilities through the UI
3. **Configuration File Structure**
   - Updated the settings schema to support a `capabilities` object for each `openai_compatible` model
   - Each capability (`tools`, `images`, `parallel_tool_calls`, `prompt_cache_key`) can be specified individually per model

### Example Configuration

```json
{
  "openai_compatible": {
    "Cerebras": {
      "api_url": "https://api.cerebras.ai/v1",
      "available_models": [
        {
          "name": "gpt-oss-120b",
          "max_tokens": 131000,
          "capabilities": {
            "tools": true,
            "images": false,
            "parallel_tool_calls": false,
            "prompt_cache_key": false
          }
        }
      ]
    }
  }
}
```

### Tests Added

- Tests verifying that default capability values are applied correctly
- Tests verifying that deselected toggles are parsed as `false`
- Tests verifying that mixed capability selections work correctly

Thanks to @osyvokon for the desired `capabilities` configuration structure!

Release Notes:

- OpenAI-compatible models now have configurable capabilities (#36370; thanks @calesennett)

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
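For readers wondering what happens when the `capabilities` block is omitted: the new field is marked `#[serde(default)]` (see the provider diff below), so existing configurations keep working and pick up the defaults (`tools: true`, everything else `false`). Here is a minimal self-contained sketch of that behavior, using simplified mirrors of the new structs rather than the actual Zed code, with `serde` (derive feature) and `serde_json` assumed as dependencies:

```rust
use serde::Deserialize;

// Simplified mirror of the `AvailableModel` added in this PR
// (Serialize/JsonSchema derives and the other optional fields are omitted here).
#[derive(Debug, Deserialize)]
struct AvailableModel {
    name: String,
    max_tokens: u64,
    #[serde(default)]
    capabilities: ModelCapabilities,
}

// Simplified mirror of the new `ModelCapabilities` struct and its Default impl.
#[derive(Debug, PartialEq, Deserialize)]
struct ModelCapabilities {
    tools: bool,
    images: bool,
    parallel_tool_calls: bool,
    prompt_cache_key: bool,
}

impl Default for ModelCapabilities {
    fn default() -> Self {
        Self {
            tools: true,
            images: false,
            parallel_tool_calls: false,
            prompt_cache_key: false,
        }
    }
}

fn main() -> serde_json::Result<()> {
    // With an explicit capabilities block, the stated values win.
    let explicit: AvailableModel = serde_json::from_str(
        r#"{
            "name": "gpt-oss-120b",
            "max_tokens": 131000,
            "capabilities": {
                "tools": true,
                "images": false,
                "parallel_tool_calls": false,
                "prompt_cache_key": false
            }
        }"#,
    )?;
    assert!(explicit.capabilities.tools);
    assert!(!explicit.capabilities.parallel_tool_calls);

    // With the block omitted, `#[serde(default)]` falls back to the defaults:
    // tools on, everything else off.
    let implicit: AvailableModel =
        serde_json::from_str(r#"{ "name": "gpt-oss-120b", "max_tokens": 131000 }"#)?;
    assert_eq!(implicit.capabilities, ModelCapabilities::default());
    println!(
        "{} ({} tokens): {:?}",
        implicit.name, implicit.max_tokens, implicit.capabilities
    );
    Ok(())
}
```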
commit c2f0df9b8e (parent 2bd61668dc)

3 changed files with 208 additions and 12 deletions
Changes to the "Add LLM Provider" modal:

````diff
@@ -7,10 +7,12 @@ use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, T
 use language_model::LanguageModelRegistry;
 use language_models::{
     AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
-    provider::open_ai_compatible::AvailableModel,
+    provider::open_ai_compatible::{AvailableModel, ModelCapabilities},
 };
 use settings::update_settings_file;
-use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
+use ui::{
+    Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState, prelude::*,
+};
 use ui_input::SingleLineInput;
 use workspace::{ModalView, Workspace};
@@ -69,11 +71,19 @@ impl AddLlmProviderInput {
     }
 }
 
+struct ModelCapabilityToggles {
+    pub supports_tools: ToggleState,
+    pub supports_images: ToggleState,
+    pub supports_parallel_tool_calls: ToggleState,
+    pub supports_prompt_cache_key: ToggleState,
+}
+
 struct ModelInput {
     name: Entity<SingleLineInput>,
     max_completion_tokens: Entity<SingleLineInput>,
     max_output_tokens: Entity<SingleLineInput>,
     max_tokens: Entity<SingleLineInput>,
+    capabilities: ModelCapabilityToggles,
 }
 
 impl ModelInput {
@@ -100,11 +110,23 @@ impl ModelInput {
             cx,
         );
         let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
+        let ModelCapabilities {
+            tools,
+            images,
+            parallel_tool_calls,
+            prompt_cache_key,
+        } = ModelCapabilities::default();
         Self {
             name: model_name,
             max_completion_tokens,
             max_output_tokens,
             max_tokens,
+            capabilities: ModelCapabilityToggles {
+                supports_tools: tools.into(),
+                supports_images: images.into(),
+                supports_parallel_tool_calls: parallel_tool_calls.into(),
+                supports_prompt_cache_key: prompt_cache_key.into(),
+            },
         }
     }
@@ -136,6 +158,12 @@ impl ModelInput {
                 .text(cx)
                 .parse::<u64>()
                 .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
+            capabilities: ModelCapabilities {
+                tools: self.capabilities.supports_tools.selected(),
+                images: self.capabilities.supports_images.selected(),
+                parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
+                prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
+            },
         })
     }
 }
@@ -322,6 +350,55 @@ impl AddLlmProviderModal {
                     .child(model.max_output_tokens.clone()),
             )
             .child(model.max_tokens.clone())
+            .child(
+                v_flex()
+                    .gap_1()
+                    .child(
+                        Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
+                            .label("Supports tools")
+                            .on_click(cx.listener(move |this, checked, _window, cx| {
+                                this.input.models[ix].capabilities.supports_tools = *checked;
+                                cx.notify();
+                            })),
+                    )
+                    .child(
+                        Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
+                            .label("Supports images")
+                            .on_click(cx.listener(move |this, checked, _window, cx| {
+                                this.input.models[ix].capabilities.supports_images = *checked;
+                                cx.notify();
+                            })),
+                    )
+                    .child(
+                        Checkbox::new(
+                            ("supports-parallel-tool-calls", ix),
+                            model.capabilities.supports_parallel_tool_calls,
+                        )
+                        .label("Supports parallel_tool_calls")
+                        .on_click(cx.listener(
+                            move |this, checked, _window, cx| {
+                                this.input.models[ix]
+                                    .capabilities
+                                    .supports_parallel_tool_calls = *checked;
+                                cx.notify();
+                            },
+                        )),
+                    )
+                    .child(
+                        Checkbox::new(
+                            ("supports-prompt-cache-key", ix),
+                            model.capabilities.supports_prompt_cache_key,
+                        )
+                        .label("Supports prompt_cache_key")
+                        .on_click(cx.listener(
+                            move |this, checked, _window, cx| {
+                                this.input.models[ix].capabilities.supports_prompt_cache_key =
+                                    *checked;
+                                cx.notify();
+                            },
+                        )),
+                    ),
+            )
             .when(has_more_than_one_model, |this| {
                 this.child(
                     Button::new(("remove-model", ix), "Remove Model")
@@ -562,6 +639,93 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
+        let cx = setup_test(cx).await;
+
+        cx.update(|window, cx| {
+            let model_input = ModelInput::new(window, cx);
+            model_input.name.update(cx, |input, cx| {
+                input.editor().update(cx, |editor, cx| {
+                    editor.set_text("somemodel", window, cx);
+                });
+            });
+            assert_eq!(
+                model_input.capabilities.supports_tools,
+                ToggleState::Selected
+            );
+            assert_eq!(
+                model_input.capabilities.supports_images,
+                ToggleState::Unselected
+            );
+            assert_eq!(
+                model_input.capabilities.supports_parallel_tool_calls,
+                ToggleState::Unselected
+            );
+            assert_eq!(
+                model_input.capabilities.supports_prompt_cache_key,
+                ToggleState::Unselected
+            );
+
+            let parsed_model = model_input.parse(cx).unwrap();
+            assert_eq!(parsed_model.capabilities.tools, true);
+            assert_eq!(parsed_model.capabilities.images, false);
+            assert_eq!(parsed_model.capabilities.parallel_tool_calls, false);
+            assert_eq!(parsed_model.capabilities.prompt_cache_key, false);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
+        let cx = setup_test(cx).await;
+
+        cx.update(|window, cx| {
+            let mut model_input = ModelInput::new(window, cx);
+            model_input.name.update(cx, |input, cx| {
+                input.editor().update(cx, |editor, cx| {
+                    editor.set_text("somemodel", window, cx);
+                });
+            });
+
+            model_input.capabilities.supports_tools = ToggleState::Unselected;
+            model_input.capabilities.supports_images = ToggleState::Unselected;
+            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected;
+            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
+
+            let parsed_model = model_input.parse(cx).unwrap();
+            assert_eq!(parsed_model.capabilities.tools, false);
+            assert_eq!(parsed_model.capabilities.images, false);
+            assert_eq!(parsed_model.capabilities.parallel_tool_calls, false);
+            assert_eq!(parsed_model.capabilities.prompt_cache_key, false);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
+        let cx = setup_test(cx).await;
+
+        cx.update(|window, cx| {
+            let mut model_input = ModelInput::new(window, cx);
+            model_input.name.update(cx, |input, cx| {
+                input.editor().update(cx, |editor, cx| {
+                    editor.set_text("somemodel", window, cx);
+                });
+            });
+
+            model_input.capabilities.supports_tools = ToggleState::Selected;
+            model_input.capabilities.supports_images = ToggleState::Unselected;
+            model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected;
+            model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
+
+            let parsed_model = model_input.parse(cx).unwrap();
+            assert_eq!(parsed_model.name, "somemodel");
+            assert_eq!(parsed_model.capabilities.tools, true);
+            assert_eq!(parsed_model.capabilities.images, false);
+            assert_eq!(parsed_model.capabilities.parallel_tool_calls, true);
+            assert_eq!(parsed_model.capabilities.prompt_cache_key, false);
+        });
+    }
+
     async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
         cx.update(|cx| {
             let store = SettingsStore::test(cx);
````
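The modal stores each capability as a `ui::ToggleState` and `ModelInput::parse` turns it back into a plain bool. Below is a rough standalone sketch of that round trip, using a simplified stand-in for `ToggleState` (the real type lives in Zed's `ui` crate and has additional states), not the actual modal code:

```rust
// Simplified stand-in for `ui::ToggleState`, only to illustrate the
// bool -> toggle -> bool round trip used by `ModelCapabilityToggles`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ToggleState {
    Selected,
    Unselected,
}

impl From<bool> for ToggleState {
    fn from(value: bool) -> Self {
        if value { Self::Selected } else { Self::Unselected }
    }
}

impl ToggleState {
    fn selected(self) -> bool {
        self == Self::Selected
    }
}

fn main() {
    // ModelInput::new seeds the toggles from ModelCapabilities::default(),
    // so "tools" starts checked and the rest start unchecked.
    let supports_tools: ToggleState = true.into();
    let supports_images: ToggleState = false.into();

    // ModelInput::parse reads the booleans back out via `.selected()`.
    assert!(supports_tools.selected());
    assert!(!supports_images.selected());
    println!("tools checked: {}", supports_tools.selected());
}
```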
Changes to the OpenAI-compatible provider:

````diff
@@ -38,6 +38,27 @@ pub struct AvailableModel {
     pub max_tokens: u64,
     pub max_output_tokens: Option<u64>,
     pub max_completion_tokens: Option<u64>,
+    #[serde(default)]
+    pub capabilities: ModelCapabilities,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct ModelCapabilities {
+    pub tools: bool,
+    pub images: bool,
+    pub parallel_tool_calls: bool,
+    pub prompt_cache_key: bool,
+}
+
+impl Default for ModelCapabilities {
+    fn default() -> Self {
+        Self {
+            tools: true,
+            images: false,
+            parallel_tool_calls: false,
+            prompt_cache_key: false,
+        }
+    }
 }
 
 pub struct OpenAiCompatibleLanguageModelProvider {
@@ -293,17 +314,17 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        true
+        self.model.capabilities.tools
     }
 
     fn supports_images(&self) -> bool {
-        false
+        self.model.capabilities.images
     }
 
     fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
         match choice {
-            LanguageModelToolChoice::Auto => true,
-            LanguageModelToolChoice::Any => true,
+            LanguageModelToolChoice::Auto => self.model.capabilities.tools,
+            LanguageModelToolChoice::Any => self.model.capabilities.tools,
             LanguageModelToolChoice::None => true,
         }
     }
@@ -355,13 +376,11 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
-        let supports_parallel_tool_call = true;
-        let supports_prompt_cache_key = false;
         let request = into_open_ai(
             request,
             &self.model.name,
-            supports_parallel_tool_call,
-            supports_prompt_cache_key,
+            self.model.capabilities.parallel_tool_calls,
+            self.model.capabilities.prompt_cache_key,
             self.max_output_tokens(),
             None,
         );
````
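To spell out the tool-choice change above: `Auto` and `Any` now require the model's `tools` capability, while `None` remains always supported. A minimal standalone sketch of that rule, with simplified stand-in types (not Zed's actual `LanguageModelToolChoice` or capability structs):

```rust
// Simplified stand-ins for the real tool-choice and capability types.
#[derive(Clone, Copy, Debug)]
enum ToolChoice {
    Auto,
    Any,
    None,
}

struct Capabilities {
    tools: bool,
}

// Mirrors the new supports_tool_choice logic: Auto/Any require tool support,
// while None (i.e. "don't call tools") is always accepted.
fn supports_tool_choice(caps: &Capabilities, choice: ToolChoice) -> bool {
    match choice {
        ToolChoice::Auto | ToolChoice::Any => caps.tools,
        ToolChoice::None => true,
    }
}

fn main() {
    let no_tools = Capabilities { tools: false };
    assert!(!supports_tool_choice(&no_tools, ToolChoice::Auto));
    assert!(supports_tool_choice(&no_tools, ToolChoice::None));
    println!("tool-choice gating behaves as expected");
}
```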
Changes to the documentation:

````diff
@@ -427,7 +427,7 @@ Custom models will be listed in the model dropdown in the Agent Panel.
 Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider.
 This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
 
-You can add a custom, OpenAI-compatible model via either via the UI or by editing your `settings.json`.
+You can add a custom, OpenAI-compatible model either via the UI or by editing your `settings.json`.
 
 To do it via the UI, go to the Agent Panel settings (`agent: open settings`) and look for the "Add Provider" button to the right of the "LLM Providers" section title.
 Then, fill up the input fields available in the modal.
@@ -443,7 +443,13 @@ To do it via your `settings.json`, add the following snippet under `language_mod
       {
         "name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
         "display_name": "Together Mixtral 8x7B",
-        "max_tokens": 32768
+        "max_tokens": 32768,
+        "capabilities": {
+          "tools": true,
+          "images": false,
+          "parallel_tool_calls": false,
+          "prompt_cache_key": false
+        }
       }
     ]
   }
@@ -451,6 +457,13 @@ To do it via your `settings.json`, add the following snippet under `language_mod
 }
 ```
 
+By default, OpenAI-compatible models inherit the following capabilities:
+
+- `tools`: true (supports tool/function calling)
+- `images`: false (does not support image inputs)
+- `parallel_tool_calls`: false (does not support `parallel_tool_calls` parameter)
+- `prompt_cache_key`: false (does not support `prompt_cache_key` parameter)
+
 Note that LLM API keys aren't stored in your settings file.
 So, ensure you have it set in your environment variables (`OPENAI_API_KEY=<your api key>`) so your settings can pick it up.
````