ai: Separate model settings for each feature (#28088)
Closes: https://github.com/zed-industries/zed/issues/20582 Allows users to select a specific model for each AI-powered feature: - Agent panel - Inline assistant - Thread summarization - Commit message generation If unspecified for a given feature, it will use the `default_model` setting. Release Notes: - Added support for configuring a specific model for each AI-powered feature --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent
cf0d1e4229
commit
43cb925a59
27 changed files with 670 additions and 381 deletions
|
@ -21,7 +21,7 @@ use gpui::{
|
|||
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||
use markdown::{Markdown, MarkdownStyle};
|
||||
use project::ProjectItem as _;
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
|
@ -606,7 +606,7 @@ impl ActiveThread {
|
|||
|
||||
if self.thread.read(cx).all_tools_finished() {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(model) = model_registry.active_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.attach_tool_results(cx);
|
||||
if !canceled {
|
||||
|
@ -814,21 +814,17 @@ impl ActiveThread {
|
|||
}
|
||||
});
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
if provider
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx))
|
||||
{
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(model) = model_registry.active_model() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if model.provider.must_accept_terms(cx) {
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.send_to_model(model, RequestKind::Chat, cx)
|
||||
thread.send_to_model(model.model, RequestKind::Chat, cx)
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
|
|
@ -202,43 +202,43 @@ impl PickerDelegate for ToolPickerDelegate {
|
|||
let default_profile = self.profile.clone();
|
||||
let tool = tool.clone();
|
||||
move |settings, _cx| match settings {
|
||||
AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
|
||||
settings,
|
||||
)) => {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
let profile =
|
||||
profiles
|
||||
.entry(profile_id)
|
||||
.or_insert_with(|| AgentProfileContent {
|
||||
name: default_profile.name.into(),
|
||||
tools: default_profile.tools,
|
||||
enable_all_context_servers: Some(
|
||||
default_profile.enable_all_context_servers,
|
||||
),
|
||||
context_servers: default_profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
.map(|(server_id, preset)| {
|
||||
(
|
||||
server_id,
|
||||
ContextServerPresetContent {
|
||||
tools: preset.tools,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
});
|
||||
AssistantSettingsContent::Versioned(boxed) => {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
let profile =
|
||||
profiles
|
||||
.entry(profile_id)
|
||||
.or_insert_with(|| AgentProfileContent {
|
||||
name: default_profile.name.into(),
|
||||
tools: default_profile.tools,
|
||||
enable_all_context_servers: Some(
|
||||
default_profile.enable_all_context_servers,
|
||||
),
|
||||
context_servers: default_profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
.map(|(server_id, preset)| {
|
||||
(
|
||||
server_id,
|
||||
ContextServerPresetContent {
|
||||
tools: preset.tools,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
});
|
||||
|
||||
match tool.source {
|
||||
ToolSource::Native => {
|
||||
*profile.tools.entry(tool.name).or_default() = is_enabled;
|
||||
}
|
||||
ToolSource::ContextServer { id } => {
|
||||
let preset = profile
|
||||
.context_servers
|
||||
.entry(id.clone().into())
|
||||
.or_default();
|
||||
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
|
||||
match tool.source {
|
||||
ToolSource::Native => {
|
||||
*profile.tools.entry(tool.name).or_default() = is_enabled;
|
||||
}
|
||||
ToolSource::ContextServer { id } => {
|
||||
let preset = profile
|
||||
.context_servers
|
||||
.entry(id.clone().into())
|
||||
.or_default();
|
||||
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,10 +9,17 @@ use settings::update_settings_file;
|
|||
use std::sync::Arc;
|
||||
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ModelType {
|
||||
Default,
|
||||
InlineAssistant,
|
||||
}
|
||||
|
||||
pub struct AssistantModelSelector {
|
||||
selector: Entity<LanguageModelSelector>,
|
||||
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
model_type: ModelType,
|
||||
}
|
||||
|
||||
impl AssistantModelSelector {
|
||||
|
@ -20,6 +27,7 @@ impl AssistantModelSelector {
|
|||
fs: Arc<dyn Fs>,
|
||||
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
model_type: ModelType,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
|
@ -28,11 +36,32 @@ impl AssistantModelSelector {
|
|||
let fs = fs.clone();
|
||||
LanguageModelSelector::new(
|
||||
move |model, cx| {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _cx| settings.set_model(model.clone()),
|
||||
);
|
||||
let provider = model.provider_id().0.to_string();
|
||||
let model_id = model.id().0.to_string();
|
||||
|
||||
match model_type {
|
||||
ModelType::Default => {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _cx| {
|
||||
settings.set_model(model.clone());
|
||||
},
|
||||
);
|
||||
}
|
||||
ModelType::InlineAssistant => {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
move |settings, _cx| {
|
||||
settings.set_inline_assistant_model(
|
||||
provider.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
|
@ -40,6 +69,7 @@ impl AssistantModelSelector {
|
|||
}),
|
||||
menu_handle,
|
||||
focus_handle,
|
||||
model_type,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,10 +80,16 @@ impl AssistantModelSelector {
|
|||
|
||||
impl Render for AssistantModelSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let active_model = LanguageModelRegistry::read_global(cx).active_model();
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
|
||||
let model = match self.model_type {
|
||||
ModelType::Default => model_registry.default_model(),
|
||||
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
|
||||
};
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
let model_name = match active_model {
|
||||
Some(model) => model.name().0,
|
||||
let model_name = match model {
|
||||
Some(model) => model.model.name().0,
|
||||
_ => SharedString::from("No model selected"),
|
||||
};
|
||||
|
||||
|
|
|
@ -571,10 +571,8 @@ impl AssistantPanel {
|
|||
match event {
|
||||
AssistantConfigurationEvent::NewThread(provider) => {
|
||||
if LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(true, |active_provider| {
|
||||
active_provider.id() != provider.id()
|
||||
})
|
||||
.default_model()
|
||||
.map_or(true, |model| model.provider.id() != provider.id())
|
||||
{
|
||||
if let Some(model) = provider.default_model(cx) {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
|
@ -922,16 +920,18 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn configuration_error(&self, cx: &App) -> Option<ConfigurationError> {
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return Some(ConfigurationError::NoProvider);
|
||||
};
|
||||
|
||||
if !provider.is_authenticated(cx) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
return Some(ConfigurationError::ProviderNotAuthenticated);
|
||||
}
|
||||
|
||||
if provider.must_accept_terms(cx) {
|
||||
return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider));
|
||||
if model.provider.must_accept_terms(cx) {
|
||||
return Some(ConfigurationError::ProviderPendingTermsAcceptance(
|
||||
model.provider,
|
||||
));
|
||||
}
|
||||
|
||||
None
|
||||
|
|
|
@ -156,8 +156,9 @@ impl BufferCodegen {
|
|||
}
|
||||
|
||||
let primary_model = LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.context("no active model")?;
|
||||
.default_model()
|
||||
.context("no active model")?
|
||||
.model;
|
||||
|
||||
for (model, alternative) in iter::once(primary_model)
|
||||
.chain(alternative_models)
|
||||
|
|
|
@ -239,8 +239,8 @@ impl InlineAssistant {
|
|||
|
||||
let is_authenticated = || {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
.inline_assistant_model()
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx))
|
||||
};
|
||||
|
||||
let thread_store = workspace
|
||||
|
@ -279,8 +279,8 @@ impl InlineAssistant {
|
|||
cx.spawn_in(window, async move |_workspace, cx| {
|
||||
let Some(task) = cx.update(|_, cx| {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(None, |provider| Some(provider.authenticate(cx)))
|
||||
.inline_assistant_model()
|
||||
.map_or(None, |model| Some(model.provider.authenticate(cx)))
|
||||
})?
|
||||
else {
|
||||
let answer = cx
|
||||
|
@ -401,14 +401,14 @@ impl InlineAssistant {
|
|||
|
||||
codegen_ranges.push(anchor_range);
|
||||
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
|
||||
self.telemetry.report_assistant_event(AssistantEvent {
|
||||
conversation_id: None,
|
||||
kind: AssistantKind::Inline,
|
||||
phase: AssistantPhase::Invoked,
|
||||
message_id: None,
|
||||
model: model.telemetry_id(),
|
||||
model_provider: model.provider_id().to_string(),
|
||||
model: model.model.telemetry_id(),
|
||||
model_provider: model.provider.id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: buffer.language().map(|language| language.name().to_proto()),
|
||||
|
@ -976,7 +976,7 @@ impl InlineAssistant {
|
|||
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
|
||||
let message_id = active_alternative.read(cx).message_id.clone();
|
||||
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
|
||||
let language_name = assist.editor.upgrade().and_then(|editor| {
|
||||
let multibuffer = editor.read(cx).buffer().read(cx);
|
||||
let snapshot = multibuffer.snapshot(cx);
|
||||
|
@ -996,15 +996,15 @@ impl InlineAssistant {
|
|||
} else {
|
||||
AssistantPhase::Accepted
|
||||
},
|
||||
model: model.telemetry_id(),
|
||||
model_provider: model.provider_id().to_string(),
|
||||
model: model.model.telemetry_id(),
|
||||
model_provider: model.model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
},
|
||||
Some(self.telemetry.clone()),
|
||||
cx.http_client(),
|
||||
model.api_key(cx),
|
||||
model.model.api_key(cx),
|
||||
cx.background_executor(),
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::assistant_model_selector::AssistantModelSelector;
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::buffer_codegen::BufferCodegen;
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
|
@ -582,7 +582,7 @@ impl<T: 'static> PromptEditor<T> {
|
|||
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let default_model = model_registry.active_model();
|
||||
let default_model = model_registry.default_model().map(|default| default.model);
|
||||
let alternative_models = model_registry.inline_alternative_models();
|
||||
|
||||
let get_model_name = |index: usize| -> String {
|
||||
|
@ -890,6 +890,7 @@ impl PromptEditor<BufferCodegen> {
|
|||
fs,
|
||||
model_selector_menu_handle,
|
||||
prompt_editor.focus_handle(cx),
|
||||
ModelType::InlineAssistant,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
@ -1042,6 +1043,7 @@ impl PromptEditor<TerminalCodegen> {
|
|||
fs,
|
||||
model_selector_menu_handle.clone(),
|
||||
prompt_editor.focus_handle(cx),
|
||||
ModelType::InlineAssistant,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::assistant_model_selector::ModelType;
|
||||
use collections::HashSet;
|
||||
use editor::actions::MoveUp;
|
||||
use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle};
|
||||
|
@ -10,7 +11,7 @@ use gpui::{
|
|||
WeakEntity, linear_color_stop, linear_gradient, point,
|
||||
};
|
||||
use language::Buffer;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_model::{ConfiguredModel, LanguageModelRegistry};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
|
@ -139,6 +140,7 @@ impl MessageEditor {
|
|||
fs.clone(),
|
||||
model_selector_menu_handle,
|
||||
editor.focus_handle(cx),
|
||||
ModelType::Default,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
@ -191,7 +193,7 @@ impl MessageEditor {
|
|||
|
||||
fn is_model_selected(&self, cx: &App) -> bool {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.default_model()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
|
@ -201,20 +203,16 @@ impl MessageEditor {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
if provider
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx))
|
||||
{
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if provider.must_accept_terms(cx) {
|
||||
cx.notify();
|
||||
return;
|
||||
}
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(model) = model_registry.active_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let user_message = self.editor.update(cx, |editor, cx| {
|
||||
let text = editor.text(cx);
|
||||
editor.clear(window, cx);
|
||||
|
|
|
@ -130,8 +130,8 @@ impl Render for ProfileSelector {
|
|||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let supports_tools = model_registry
|
||||
.active_model()
|
||||
.map_or(false, |model| model.supports_tools());
|
||||
.default_model()
|
||||
.map_or(false, |default| default.model.supports_tools());
|
||||
|
||||
let icon = match profile_id.as_str() {
|
||||
"write" => IconName::Pencil,
|
||||
|
|
|
@ -2,7 +2,9 @@ use crate::inline_prompt_editor::CodegenStatus;
|
|||
use client::telemetry::Telemetry;
|
||||
use futures::{SinkExt, StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
|
||||
use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
|
||||
};
|
||||
use std::{sync::Arc, time::Instant};
|
||||
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
|
||||
use terminal::Terminal;
|
||||
|
@ -31,7 +33,9 @@ impl TerminalCodegen {
|
|||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
|
|
@ -13,8 +13,8 @@ use fs::Fs;
|
|||
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
|
||||
use language::Buffer;
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
report_assistant_event,
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use prompt_store::PromptBuilder;
|
||||
use std::sync::Arc;
|
||||
|
@ -286,7 +286,9 @@ impl TerminalInlineAssistant {
|
|||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
{
|
||||
let codegen = assist.codegen.read(cx);
|
||||
let executor = cx.background_executor().clone();
|
||||
report_assistant_event(
|
||||
|
|
|
@ -14,10 +14,10 @@ use futures::{FutureExt, StreamExt as _};
|
|||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||
Role, StopReason, TokenUsage,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
use project::{Project, Worktree};
|
||||
|
@ -1250,14 +1250,11 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn summarize(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
return;
|
||||
};
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if !provider.is_authenticated(cx) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1276,7 +1273,7 @@ impl Thread {
|
|||
|
||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let stream = model.stream_completion_text(request, &cx);
|
||||
let stream = model.model.stream_completion_text(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
let mut new_summary = String::new();
|
||||
|
@ -1320,8 +1317,8 @@ impl Thread {
|
|||
_ => {}
|
||||
}
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let ConfiguredModel { model, provider } =
|
||||
LanguageModelRegistry::read_global(cx).thread_summary_model()?;
|
||||
|
||||
if !provider.is_authenticated(cx) {
|
||||
return None;
|
||||
|
@ -1782,11 +1779,11 @@ impl Thread {
|
|||
|
||||
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(model) = model_registry.active_model() else {
|
||||
let Some(model) = model_registry.default_model() else {
|
||||
return TotalTokenUsage::default();
|
||||
};
|
||||
|
||||
let max = model.max_token_count();
|
||||
let max = model.model.max_token_count();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
|
||||
|
|
|
@ -161,12 +161,38 @@ fn init_language_model_settings(cx: &mut App) {
|
|||
|
||||
fn update_active_language_model_from_settings(cx: &mut App) {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
// Default model - used as fallback
|
||||
let active_model_provider_name =
|
||||
LanguageModelProviderId::from(settings.default_model.provider.clone());
|
||||
let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
|
||||
let editor_provider_name =
|
||||
LanguageModelProviderId::from(settings.editor_model.provider.clone());
|
||||
let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
|
||||
|
||||
// Inline assistant model
|
||||
let inline_assistant_model = settings
|
||||
.inline_assistant_model
|
||||
.as_ref()
|
||||
.unwrap_or(&settings.default_model);
|
||||
let inline_assistant_provider_name =
|
||||
LanguageModelProviderId::from(inline_assistant_model.provider.clone());
|
||||
let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
|
||||
|
||||
// Commit message model
|
||||
let commit_message_model = settings
|
||||
.commit_message_model
|
||||
.as_ref()
|
||||
.unwrap_or(&settings.default_model);
|
||||
let commit_message_provider_name =
|
||||
LanguageModelProviderId::from(commit_message_model.provider.clone());
|
||||
let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
|
||||
|
||||
// Thread summary model
|
||||
let thread_summary_model = settings
|
||||
.thread_summary_model
|
||||
.as_ref()
|
||||
.unwrap_or(&settings.default_model);
|
||||
let thread_summary_provider_name =
|
||||
LanguageModelProviderId::from(thread_summary_model.provider.clone());
|
||||
let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
|
||||
|
||||
let inline_alternatives = settings
|
||||
.inline_alternatives
|
||||
.iter()
|
||||
|
@ -177,9 +203,29 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
|||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
|
||||
registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
|
||||
// Set the default model
|
||||
registry.select_default_model(&active_model_provider_name, &active_model_id, cx);
|
||||
|
||||
// Set the specific models
|
||||
registry.select_inline_assistant_model(
|
||||
&inline_assistant_provider_name,
|
||||
&inline_assistant_model_id,
|
||||
cx,
|
||||
);
|
||||
registry.select_commit_message_model(
|
||||
&commit_message_provider_name,
|
||||
&commit_message_model_id,
|
||||
cx,
|
||||
);
|
||||
registry.select_thread_summary_model(
|
||||
&thread_summary_provider_name,
|
||||
&thread_summary_model_id,
|
||||
cx,
|
||||
);
|
||||
|
||||
// Set the alternatives
|
||||
registry.select_inline_alternative_models(inline_alternatives, cx);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -22,7 +22,8 @@ use gpui::{
|
|||
};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
|
||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
|
@ -298,8 +299,10 @@ impl AssistantPanel {
|
|||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|this, _, event: &language_model::Event, window, cx| match event {
|
||||
language_model::Event::ActiveModelChanged
|
||||
| language_model::Event::EditorModelChanged => {
|
||||
language_model::Event::DefaultModelChanged
|
||||
| language_model::Event::InlineAssistantModelChanged
|
||||
| language_model::Event::CommitMessageModelChanged
|
||||
| language_model::Event::ThreadSummaryModelChanged => {
|
||||
this.completion_provider_changed(window, cx);
|
||||
}
|
||||
language_model::Event::ProviderStateChanged => {
|
||||
|
@ -468,12 +471,12 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
fn update_zed_ai_notice_visibility(&mut self, client_status: Status, cx: &mut Context<Self>) {
|
||||
let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
|
||||
// the provider, we want to show a nudge to sign in.
|
||||
let show_zed_ai_notice = client_status.is_signed_out()
|
||||
&& active_provider.map_or(true, |provider| provider.id().0 == ZED_CLOUD_PROVIDER_ID);
|
||||
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
|
||||
|
||||
self.show_zed_ai_notice = show_zed_ai_notice;
|
||||
cx.notify();
|
||||
|
@ -541,8 +544,8 @@ impl AssistantPanel {
|
|||
}
|
||||
|
||||
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map(|p| p.id())
|
||||
.default_model()
|
||||
.map(|default| default.provider.id())
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
@ -568,7 +571,9 @@ impl AssistantPanel {
|
|||
return;
|
||||
}
|
||||
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
let Some(ConfiguredModel { provider, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).default_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -976,8 +981,8 @@ impl AssistantPanel {
|
|||
|this, _, event: &ConfigurationViewEvent, window, cx| match event {
|
||||
ConfigurationViewEvent::NewProviderContextEditor(provider) => {
|
||||
if LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(true, |p| p.id() != provider.id())
|
||||
.default_model()
|
||||
.map_or(true, |default| default.provider.id() != provider.id())
|
||||
{
|
||||
if let Some(model) = provider.default_model(cx) {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
|
@ -1155,8 +1160,8 @@ impl AssistantPanel {
|
|||
|
||||
fn is_authenticated(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
.default_model()
|
||||
.map_or(false, |default| default.provider.is_authenticated(cx))
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
|
@ -1164,8 +1169,8 @@ impl AssistantPanel {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<Result<(), AuthenticateError>>> {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map_or(None, |provider| Some(provider.authenticate(cx)))
|
||||
.default_model()
|
||||
.map_or(None, |default| Some(default.provider.authenticate(cx)))
|
||||
}
|
||||
|
||||
fn restart_context_servers(
|
||||
|
|
|
@ -34,8 +34,8 @@ use gpui::{
|
|||
};
|
||||
use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelTextStream, Role, report_assistant_event,
|
||||
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
|
||||
};
|
||||
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
|
@ -312,7 +312,9 @@ impl InlineAssistant {
|
|||
start..end,
|
||||
));
|
||||
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).default_model()
|
||||
{
|
||||
self.telemetry.report_assistant_event(AssistantEvent {
|
||||
conversation_id: None,
|
||||
kind: AssistantKind::Inline,
|
||||
|
@ -877,7 +879,9 @@ impl InlineAssistant {
|
|||
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
|
||||
let message_id = active_alternative.read(cx).message_id.clone();
|
||||
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).default_model()
|
||||
{
|
||||
let language_name = assist.editor.upgrade().and_then(|editor| {
|
||||
let multibuffer = editor.read(cx).buffer().read(cx);
|
||||
let multibuffer_snapshot = multibuffer.snapshot(cx);
|
||||
|
@ -1629,8 +1633,8 @@ impl Render for PromptEditor {
|
|||
format!(
|
||||
"Using {}",
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.default_model()
|
||||
.map(|default| default.model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
),
|
||||
None,
|
||||
|
@ -2077,7 +2081,7 @@ impl PromptEditor {
|
|||
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let default_model = model_registry.active_model();
|
||||
let default_model = model_registry.default_model().map(|default| default.model);
|
||||
let alternative_models = model_registry.inline_alternative_models();
|
||||
|
||||
let get_model_name = |index: usize| -> String {
|
||||
|
@ -2183,7 +2187,9 @@ impl PromptEditor {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()?
|
||||
.model;
|
||||
let token_counts = self.token_counts?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -2638,8 +2644,9 @@ impl Codegen {
|
|||
}
|
||||
|
||||
let primary_model = LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.context("no active model")?;
|
||||
.default_model()
|
||||
.context("no active model")?
|
||||
.model;
|
||||
|
||||
for (model, alternative) in iter::once(primary_model)
|
||||
.chain(alternative_models)
|
||||
|
@ -2863,7 +2870,9 @@ impl CodegenAlternative {
|
|||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<TokenCounts>> {
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
{
|
||||
let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
|
||||
match request {
|
||||
Ok(request) => {
|
||||
|
|
|
@ -16,8 +16,8 @@ use gpui::{
|
|||
};
|
||||
use language::Buffer;
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
report_assistant_event,
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
|
||||
use prompt_store::PromptBuilder;
|
||||
|
@ -318,7 +318,9 @@ impl TerminalInlineAssistant {
|
|||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
{
|
||||
let codegen = assist.codegen.read(cx);
|
||||
let executor = cx.background_executor().clone();
|
||||
report_assistant_event(
|
||||
|
@ -652,8 +654,8 @@ impl Render for PromptEditor {
|
|||
format!(
|
||||
"Using {}",
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.inline_assistant_model()
|
||||
.map(|inline_assistant| inline_assistant.model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
),
|
||||
None,
|
||||
|
@ -822,7 +824,9 @@ impl PromptEditor {
|
|||
|
||||
fn count_tokens(&mut self, cx: &mut Context<Self>) {
|
||||
let assist_id = self.id;
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
self.pending_token_count = cx.spawn(async move |this, cx| {
|
||||
|
@ -980,7 +984,9 @@ impl PromptEditor {
|
|||
}
|
||||
|
||||
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.inline_assistant_model()?
|
||||
.model;
|
||||
let token_count = self.token_count?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -1131,7 +1137,9 @@ impl Codegen {
|
|||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
|
|
@ -1272,7 +1272,7 @@ impl AssistantContext {
|
|||
// Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
|
||||
// because otherwise you see in the UI that your empty message has a bunch of tokens already used.
|
||||
let request = self.to_completion_request(RequestType::Chat, cx);
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
let debounce = self.token_count.is_some();
|
||||
|
@ -1284,10 +1284,12 @@ impl AssistantContext {
|
|||
.await;
|
||||
}
|
||||
|
||||
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
|
||||
let token_count = cx
|
||||
.update(|cx| model.model.count_tokens(request, cx))?
|
||||
.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
this.start_cache_warming(&model, cx);
|
||||
this.start_cache_warming(&model.model, cx);
|
||||
cx.notify()
|
||||
})
|
||||
}
|
||||
|
@ -2304,14 +2306,16 @@ impl AssistantContext {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Option<MessageAnchor> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let provider = model_registry.active_provider()?;
|
||||
let model = model_registry.active_model()?;
|
||||
let model = model_registry.default_model()?;
|
||||
let last_message_id = self.get_last_valid_message_id(cx)?;
|
||||
|
||||
if !provider.is_authenticated(cx) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
log::info!("completion provider has no credentials");
|
||||
return None;
|
||||
}
|
||||
|
||||
let model = model.model;
|
||||
|
||||
// Compute which messages to cache, including the last one.
|
||||
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
||||
|
||||
|
@ -2940,15 +2944,12 @@ impl AssistantContext {
|
|||
}
|
||||
|
||||
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
return;
|
||||
};
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2964,7 +2965,7 @@ impl AssistantContext {
|
|||
|
||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let stream = model.stream_completion_text(request, &cx);
|
||||
let stream = model.model.stream_completion_text(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
let mut replaced = !replace_old;
|
||||
|
|
|
@ -384,7 +384,9 @@ impl ContextEditor {
|
|||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
if provider
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx))
|
||||
|
@ -2395,13 +2397,13 @@ impl ContextEditor {
|
|||
None => (ButtonStyle::Filled, None),
|
||||
};
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
&& provider
|
||||
&& model
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx));
|
||||
.map_or(false, |model| model.provider.must_accept_terms(cx));
|
||||
let disabled = has_configuration_error || needs_to_accept_terms;
|
||||
|
||||
ButtonLike::new("send_button")
|
||||
|
@ -2454,7 +2456,9 @@ impl ContextEditor {
|
|||
None => (ButtonStyle::Filled, None),
|
||||
};
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
|
@ -2500,7 +2504,9 @@ impl ContextEditor {
|
|||
}
|
||||
|
||||
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let active_model = LanguageModelRegistry::read_global(cx).active_model();
|
||||
let active_model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.model);
|
||||
let focus_handle = self.editor().focus_handle(cx).clone();
|
||||
let model_name = match active_model {
|
||||
Some(model) => model.name().0,
|
||||
|
@ -3020,7 +3026,9 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
|
|||
|
||||
impl Render for ContextEditor {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
let accept_terms = if self.show_accept_terms {
|
||||
provider.as_ref().and_then(|provider| {
|
||||
provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
|
||||
|
@ -3616,7 +3624,9 @@ enum TokenState {
|
|||
fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
|
||||
const WARNING_TOKEN_THRESHOLD: f32 = 0.8;
|
||||
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()?
|
||||
.model;
|
||||
let token_count = context.read(cx).token_count()?;
|
||||
let max_token_count = model.max_token_count();
|
||||
|
||||
|
@ -3669,16 +3679,16 @@ pub enum ConfigurationError {
|
|||
}
|
||||
|
||||
fn configuration_error(cx: &App) -> Option<ConfigurationError> {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||
let is_authenticated = provider
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let is_authenticated = model
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.is_authenticated(cx));
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx));
|
||||
|
||||
if provider.is_some() && is_authenticated {
|
||||
if model.is_some() && is_authenticated {
|
||||
return None;
|
||||
}
|
||||
|
||||
if provider.is_none() {
|
||||
if model.is_none() {
|
||||
return Some(ConfigurationError::NoProvider);
|
||||
}
|
||||
|
||||
|
|
|
@ -156,10 +156,10 @@ impl HeadlessAssistant {
|
|||
}
|
||||
if thread.read(cx).all_tools_finished() {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(model) = model_registry.active_model() {
|
||||
if let Some(model) = model_registry.default_model() {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.attach_tool_results(cx);
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
thread.send_to_model(model.model, RequestKind::Chat, cx);
|
||||
});
|
||||
} else {
|
||||
println!(
|
||||
|
|
|
@ -37,9 +37,6 @@ struct Args {
|
|||
/// Name of the model (default: "claude-3-7-sonnet-latest")
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
model_name: String,
|
||||
/// Name of the editor model (default: value of `--model_name`).
|
||||
#[arg(long)]
|
||||
editor_model_name: Option<String>,
|
||||
/// Name of the judge model (default: value of `--model_name`).
|
||||
#[arg(long)]
|
||||
judge_model_name: Option<String>,
|
||||
|
@ -79,11 +76,6 @@ fn main() {
|
|||
let app_state = headless_assistant::init(cx);
|
||||
|
||||
let model = find_model(&args.model_name, cx).unwrap();
|
||||
let editor_model = if let Some(model_name) = &args.editor_model_name {
|
||||
find_model(model_name, cx).unwrap()
|
||||
} else {
|
||||
model.clone()
|
||||
};
|
||||
let judge_model = if let Some(model_name) = &args.judge_model_name {
|
||||
find_model(model_name, cx).unwrap()
|
||||
} else {
|
||||
|
@ -91,12 +83,10 @@ fn main() {
|
|||
};
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_active_model(Some(model.clone()), cx);
|
||||
registry.set_editor_model(Some(editor_model.clone()), cx);
|
||||
registry.set_default_model(Some(model.clone()), cx);
|
||||
});
|
||||
|
||||
let model_provider_id = model.provider_id();
|
||||
let editor_model_provider_id = editor_model.provider_id();
|
||||
let judge_model_provider_id = judge_model.provider_id();
|
||||
|
||||
let framework_path_clone = framework_path.clone();
|
||||
|
@ -110,10 +100,6 @@ fn main() {
|
|||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
|
|
|
@ -77,7 +77,9 @@ pub struct AssistantSettings {
|
|||
pub default_width: Pixels,
|
||||
pub default_height: Pixels,
|
||||
pub default_model: LanguageModelSelection,
|
||||
pub editor_model: LanguageModelSelection,
|
||||
pub inline_assistant_model: Option<LanguageModelSelection>,
|
||||
pub commit_message_model: Option<LanguageModelSelection>,
|
||||
pub thread_summary_model: Option<LanguageModelSelection>,
|
||||
pub inline_alternatives: Vec<LanguageModelSelection>,
|
||||
pub using_outdated_settings_version: bool,
|
||||
pub enable_experimental_live_diffs: bool,
|
||||
|
@ -95,13 +97,25 @@ impl AssistantSettings {
|
|||
|
||||
cx.is_staff() || self.enable_experimental_live_diffs
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
|
||||
self.inline_assistant_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
|
||||
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
|
||||
self.commit_message_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
|
||||
self.thread_summary_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum AssistantSettingsContent {
|
||||
Versioned(VersionedAssistantSettingsContent),
|
||||
Versioned(Box<VersionedAssistantSettingsContent>),
|
||||
Legacy(LegacyAssistantSettingsContent),
|
||||
}
|
||||
|
||||
|
@ -121,14 +135,14 @@ impl JsonSchema for AssistantSettingsContent {
|
|||
|
||||
impl Default for AssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::Versioned(VersionedAssistantSettingsContent::default())
|
||||
Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantSettingsContent {
|
||||
pub fn is_version_outdated(&self) -> bool {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(_) => true,
|
||||
VersionedAssistantSettingsContent::V2(_) => false,
|
||||
},
|
||||
|
@ -138,8 +152,8 @@ impl AssistantSettingsContent {
|
|||
|
||||
fn upgrade(&self) -> AssistantSettingsContentV2 {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
|
||||
enabled: settings.enabled,
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
|
@ -186,7 +200,9 @@ impl AssistantSettingsContent {
|
|||
})
|
||||
}
|
||||
}),
|
||||
editor_model: None,
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
|
@ -194,7 +210,7 @@ impl AssistantSettingsContent {
|
|||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
|
||||
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
|
||||
enabled: None,
|
||||
|
@ -211,7 +227,9 @@ impl AssistantSettingsContent {
|
|||
.id()
|
||||
.to_string(),
|
||||
}),
|
||||
editor_model: None,
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
|
@ -224,11 +242,11 @@ impl AssistantSettingsContent {
|
|||
|
||||
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref mut settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
VersionedAssistantSettingsContent::V2(settings) => {
|
||||
VersionedAssistantSettingsContent::V2(ref mut settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
},
|
||||
|
@ -243,77 +261,79 @@ impl AssistantSettingsContent {
|
|||
let provider = language_model.provider_id().0.to_string();
|
||||
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
|
||||
"zed.dev" => {
|
||||
log::warn!("attempted to set zed.dev model on outdated settings");
|
||||
}
|
||||
"anthropic" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::Anthropic {
|
||||
default_model: AnthropicModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
"ollama" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::Ollama {
|
||||
default_model: Some(ollama::Model::new(&model, None, None)),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
"lmstudio" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::LmStudio {
|
||||
default_model: Some(lmstudio::Model::new(&model, None, None)),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
"openai" => {
|
||||
let (api_url, available_models) = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::OpenAi {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref mut settings) => {
|
||||
match provider.as_ref() {
|
||||
"zed.dev" => {
|
||||
log::warn!("attempted to set zed.dev model on outdated settings");
|
||||
}
|
||||
"anthropic" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::Anthropic {
|
||||
default_model: AnthropicModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
"ollama" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::Ollama {
|
||||
default_model: Some(ollama::Model::new(&model, None, None)),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
"lmstudio" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::LmStudio {
|
||||
default_model: Some(lmstudio::Model::new(&model, None, None)),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
"openai" => {
|
||||
let (api_url, available_models) = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::OpenAi {
|
||||
api_url,
|
||||
available_models,
|
||||
..
|
||||
}) => (api_url.clone(), available_models.clone()),
|
||||
_ => (None, None),
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::OpenAi {
|
||||
default_model: OpenAiModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
available_models,
|
||||
..
|
||||
}) => (api_url.clone(), available_models.clone()),
|
||||
_ => (None, None),
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::OpenAi {
|
||||
default_model: OpenAiModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
available_models,
|
||||
});
|
||||
});
|
||||
}
|
||||
"deepseek" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::DeepSeek {
|
||||
default_model: DeepseekModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
"deepseek" => {
|
||||
let api_url = match &settings.provider {
|
||||
Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
|
||||
api_url.clone()
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AssistantProviderContentV1::DeepSeek {
|
||||
default_model: DeepseekModel::from_id(&model).ok(),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(settings) => {
|
||||
}
|
||||
VersionedAssistantSettingsContent::V2(ref mut settings) => {
|
||||
settings.default_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
},
|
||||
|
@ -325,23 +345,48 @@ impl AssistantSettingsContent {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.commit_message_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
|
||||
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
|
||||
self
|
||||
else {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return;
|
||||
};
|
||||
settings.always_allow_tool_actions = Some(allow);
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.always_allow_tool_actions = Some(allow);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
|
||||
self
|
||||
else {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return;
|
||||
};
|
||||
|
||||
settings.default_profile = Some(profile_id);
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.default_profile = Some(profile_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_profile(
|
||||
|
@ -349,37 +394,37 @@ impl AssistantSettingsContent {
|
|||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
) -> Result<()> {
|
||||
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
|
||||
self
|
||||
else {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
if profiles.contains_key(&profile_id) {
|
||||
bail!("profile with ID '{profile_id}' already exists");
|
||||
}
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
if profiles.contains_key(&profile_id) {
|
||||
bail!("profile with ID '{profile_id}' already exists");
|
||||
}
|
||||
|
||||
profiles.insert(
|
||||
profile_id,
|
||||
AgentProfileContent {
|
||||
name: profile.name.into(),
|
||||
tools: profile.tools,
|
||||
enable_all_context_servers: Some(profile.enable_all_context_servers),
|
||||
context_servers: profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
.map(|(server_id, preset)| {
|
||||
(
|
||||
server_id,
|
||||
ContextServerPresetContent {
|
||||
tools: preset.tools,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
);
|
||||
profiles.insert(
|
||||
profile_id,
|
||||
AgentProfileContent {
|
||||
name: profile.name.into(),
|
||||
tools: profile.tools,
|
||||
enable_all_context_servers: Some(profile.enable_all_context_servers),
|
||||
context_servers: profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
.map(|(server_id, preset)| {
|
||||
(
|
||||
server_id,
|
||||
ContextServerPresetContent {
|
||||
tools: preset.tools,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -403,7 +448,9 @@ impl Default for VersionedAssistantSettingsContent {
|
|||
default_width: None,
|
||||
default_height: None,
|
||||
default_model: None,
|
||||
editor_model: None,
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
|
@ -436,10 +483,14 @@ pub struct AssistantSettingsContentV2 {
|
|||
///
|
||||
/// Default: 320
|
||||
default_height: Option<f32>,
|
||||
/// The default model to use when creating new chats.
|
||||
/// The default model to use when creating new chats and for other features when a specific model is not specified.
|
||||
default_model: Option<LanguageModelSelection>,
|
||||
/// The model to use when applying edits from the assistant.
|
||||
editor_model: Option<LanguageModelSelection>,
|
||||
/// Model to use for the inline assistant. Defaults to default_model when not specified.
|
||||
inline_assistant_model: Option<LanguageModelSelection>,
|
||||
/// Model to use for generating git commit messages. Defaults to default_model when not specified.
|
||||
commit_message_model: Option<LanguageModelSelection>,
|
||||
/// Model to use for generating thread summaries. Defaults to default_model when not specified.
|
||||
thread_summary_model: Option<LanguageModelSelection>,
|
||||
/// Additional models with which to generate alternatives when performing inline assists.
|
||||
inline_alternatives: Option<Vec<LanguageModelSelection>>,
|
||||
/// Enable experimental live diffs in the assistant panel.
|
||||
|
@ -601,7 +652,15 @@ impl Settings for AssistantSettings {
|
|||
value.default_height.map(Into::into),
|
||||
);
|
||||
merge(&mut settings.default_model, value.default_model);
|
||||
merge(&mut settings.editor_model, value.editor_model);
|
||||
settings.inline_assistant_model = value
|
||||
.inline_assistant_model
|
||||
.or(settings.inline_assistant_model.take());
|
||||
settings.commit_message_model = value
|
||||
.commit_message_model
|
||||
.or(settings.commit_message_model.take());
|
||||
settings.thread_summary_model = value
|
||||
.thread_summary_model
|
||||
.or(settings.thread_summary_model.take());
|
||||
merge(&mut settings.inline_alternatives, value.inline_alternatives);
|
||||
merge(
|
||||
&mut settings.enable_experimental_live_diffs,
|
||||
|
@ -692,16 +751,15 @@ mod tests {
|
|||
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
|settings, _| {
|
||||
*settings = AssistantSettingsContent::Versioned(
|
||||
*settings = AssistantSettingsContent::Versioned(Box::new(
|
||||
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
editor_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enabled: None,
|
||||
button: None,
|
||||
|
@ -714,7 +772,7 @@ mod tests {
|
|||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
}),
|
||||
)
|
||||
))
|
||||
},
|
||||
);
|
||||
});
|
||||
|
|
|
@ -9,7 +9,7 @@ use collections::HashSet;
|
|||
use edit_action::{EditAction, EditActionParser, edit_model_prompt};
|
||||
use futures::{SinkExt, StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use language_model::{ConfiguredModel, LanguageModelToolSchemaFormat};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||
};
|
||||
|
@ -205,8 +205,8 @@ impl EditToolRequest {
|
|||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(model) = model_registry.editor_model() else {
|
||||
return Task::ready(Err(anyhow!("No editor model configured")));
|
||||
let Some(ConfiguredModel { model, .. }) = model_registry.default_model() else {
|
||||
return Task::ready(Err(anyhow!("No model configured")));
|
||||
};
|
||||
|
||||
let mut messages = messages.to_vec();
|
||||
|
|
|
@ -37,7 +37,8 @@ use gpui::{
|
|||
use itertools::Itertools;
|
||||
use language::{Buffer, File};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use menu::{Confirm, SecondaryConfirm, SelectFirst, SelectLast, SelectNext, SelectPrevious};
|
||||
use multi_buffer::ExcerptInfo;
|
||||
|
@ -3764,8 +3765,9 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
|
|||
assistant_settings::AssistantSettings::get_global(cx)
|
||||
.enabled
|
||||
.then(|| {
|
||||
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model()?;
|
||||
let ConfiguredModel { provider, model } =
|
||||
LanguageModelRegistry::read_global(cx).commit_message_model()?;
|
||||
|
||||
provider.is_authenticated(cx).then(|| model)
|
||||
})
|
||||
.flatten()
|
||||
|
|
|
@ -17,20 +17,25 @@ impl Global for GlobalLanguageModelRegistry {}
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
active_model: Option<ActiveModel>,
|
||||
editor_model: Option<ActiveModel>,
|
||||
default_model: Option<ConfiguredModel>,
|
||||
inline_assistant_model: Option<ConfiguredModel>,
|
||||
commit_message_model: Option<ConfiguredModel>,
|
||||
thread_summary_model: Option<ConfiguredModel>,
|
||||
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
|
||||
}
|
||||
|
||||
pub struct ActiveModel {
|
||||
provider: Arc<dyn LanguageModelProvider>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
#[derive(Clone)]
|
||||
pub struct ConfiguredModel {
|
||||
pub provider: Arc<dyn LanguageModelProvider>,
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ActiveModelChanged,
|
||||
EditorModelChanged,
|
||||
DefaultModelChanged,
|
||||
InlineAssistantModelChanged,
|
||||
CommitMessageModelChanged,
|
||||
ThreadSummaryModelChanged,
|
||||
ProviderStateChanged,
|
||||
AddedProvider(LanguageModelProviderId),
|
||||
RemovedProvider(LanguageModelProviderId),
|
||||
|
@ -54,7 +59,7 @@ impl LanguageModelRegistry {
|
|||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
let model = fake_provider.provided_models(cx)[0].clone();
|
||||
registry.set_active_model(Some(model), cx);
|
||||
registry.set_default_model(Some(model), cx);
|
||||
registry
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
|
@ -114,7 +119,7 @@ impl LanguageModelRegistry {
|
|||
self.providers.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn select_active_model(
|
||||
pub fn select_default_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
|
@ -126,11 +131,11 @@ impl LanguageModelRegistry {
|
|||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_active_model(Some(model), cx);
|
||||
self.set_default_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_editor_model(
|
||||
pub fn select_inline_assistant_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
|
@ -142,23 +147,43 @@ impl LanguageModelRegistry {
|
|||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_editor_model(Some(model), cx);
|
||||
self.set_inline_assistant_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_provider(
|
||||
pub fn select_commit_message_model(
|
||||
&mut self,
|
||||
provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.active_model = provider.map(|provider| ActiveModel {
|
||||
provider,
|
||||
model: None,
|
||||
});
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_commit_message_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_model(
|
||||
pub fn select_thread_summary_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_thread_summary_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_default_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -166,21 +191,18 @@ impl LanguageModelRegistry {
|
|||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.active_model = Some(ActiveModel {
|
||||
provider,
|
||||
model: Some(model),
|
||||
});
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
self.default_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::DefaultModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.active_model = None;
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
self.default_model = None;
|
||||
cx.emit(Event::DefaultModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_editor_model(
|
||||
pub fn set_inline_assistant_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -188,35 +210,80 @@ impl LanguageModelRegistry {
|
|||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.editor_model = Some(ActiveModel {
|
||||
provider,
|
||||
model: Some(model),
|
||||
});
|
||||
cx.emit(Event::EditorModelChanged);
|
||||
self.inline_assistant_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::InlineAssistantModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
log::warn!("Inline assistant model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.editor_model = None;
|
||||
cx.emit(Event::EditorModelChanged);
|
||||
self.inline_assistant_model = None;
|
||||
cx.emit(Event::InlineAssistantModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
pub fn set_commit_message_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.commit_message_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::CommitMessageModelChanged);
|
||||
} else {
|
||||
log::warn!("Commit message model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.commit_message_model = None;
|
||||
cx.emit(Event::CommitMessageModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.thread_summary_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::ThreadSummaryModelChanged);
|
||||
} else {
|
||||
log::warn!("Thread summary model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.thread_summary_model = None;
|
||||
cx.emit(Event::ThreadSummaryModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> Option<ConfiguredModel> {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.active_model.as_ref()?.provider.clone())
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.active_model.as_ref()?.model.clone()
|
||||
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
|
||||
self.inline_assistant_model
|
||||
.clone()
|
||||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.editor_model.as_ref()?.model.clone()
|
||||
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
|
||||
self.commit_message_model
|
||||
.clone()
|
||||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
|
||||
self.thread_summary_model
|
||||
.clone()
|
||||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
/// Selects and sets the inline alternatives for language models based on
|
||||
|
|
|
@ -168,11 +168,11 @@ impl LanguageModelSelector {
|
|||
}
|
||||
|
||||
fn get_active_model_index(cx: &App) -> usize {
|
||||
let active_model = LanguageModelRegistry::read_global(cx).active_model();
|
||||
let active_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
Self::all_models(cx)
|
||||
.iter()
|
||||
.position(|model_info| {
|
||||
Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
|
||||
Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
|
||||
})
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
@ -406,13 +406,10 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
let model_info = self.filtered_models.get(ix)?;
|
||||
let provider_name: String = model_info.model.provider_name().0.clone().into();
|
||||
|
||||
let active_provider_id = LanguageModelRegistry::read_global(cx)
|
||||
.active_provider()
|
||||
.map(|m| m.id());
|
||||
let active_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
let active_model_id = LanguageModelRegistry::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.id());
|
||||
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
|
||||
let active_model_id = active_model.map(|m| m.model.id());
|
||||
|
||||
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
|
||||
&& Some(model_info.model.id()) == active_model_id;
|
||||
|
|
|
@ -9,7 +9,7 @@ use gpui::{
|
|||
};
|
||||
use language::{Buffer, LanguageRegistry, language_settings::SoftWrap};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||
};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use release_channel::ReleaseChannel;
|
||||
|
@ -777,7 +777,9 @@ impl PromptLibrary {
|
|||
};
|
||||
|
||||
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
|
||||
let Some(ConfiguredModel { provider, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -880,7 +882,9 @@ impl PromptLibrary {
|
|||
}
|
||||
|
||||
fn count_tokens(&mut self, prompt_id: PromptId, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).default_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
|
||||
|
@ -967,7 +971,9 @@ impl PromptLibrary {
|
|||
let prompt_metadata = self.store.metadata(prompt_id)?;
|
||||
let prompt_editor = &self.prompt_editors[&prompt_id];
|
||||
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
|
||||
let model = LanguageModelRegistry::read_global(cx).active_model();
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.model);
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
|
||||
Some(
|
||||
|
|
|
@ -19,7 +19,8 @@ To further customize providers, you can use `settings.json` to do that as follow
|
|||
|
||||
- [Configuring endpoints](#custom-endpoint)
|
||||
- [Configuring timeouts](#provider-timeout)
|
||||
- [Configuring default model](#default-model)
|
||||
- [Configuring models](#default-model)
|
||||
- [Configuring feature-specific models](#feature-specific-models)
|
||||
- [Configuring alternative models for inline assists](#alternative-assists)
|
||||
|
||||
### Zed AI {#zed-ai}
|
||||
|
@ -281,8 +282,24 @@ Example configuration for using X.ai Grok with Zed:
|
|||
"enabled": true,
|
||||
"default_model": {
|
||||
"provider": "zed.dev",
|
||||
"model": "claude-3-7-sonnet"
|
||||
},
|
||||
"editor_model": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o"
|
||||
},
|
||||
"inline_assistant_model": {
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-5-sonnet"
|
||||
},
|
||||
"commit_message_model": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o-mini"
|
||||
},
|
||||
"thread_summary_model": {
|
||||
"provider": "google",
|
||||
"model": "gemini-1.5-flash"
|
||||
},
|
||||
"version": "2",
|
||||
"button": true,
|
||||
"default_width": 480,
|
||||
|
@ -328,7 +345,7 @@ To do so, add the following to your Zed `settings.json`:
|
|||
|
||||
Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`.
|
||||
|
||||
#### Configuring the default model {#default-model}
|
||||
#### Configuring models {#default-model}
|
||||
|
||||
The default model can be set via the model dropdown in the assistant panel's top-right corner. Selecting a model saves it as the default.
|
||||
You can also manually edit the `default_model` object in your settings:
|
||||
|
@ -345,6 +362,47 @@ You can also manually edit the `default_model` object in your settings:
|
|||
}
|
||||
```
|
||||
|
||||
#### Feature-specific models {#feature-specific-models}
|
||||
|
||||
> Currently only available in [Preview](https://zed.dev/releases/preview).
|
||||
|
||||
Zed allows you to configure different models for specific features.
|
||||
This provides flexibility to use more powerful models for certain tasks while using faster or more efficient models for others.
|
||||
|
||||
If a feature-specific model is not set, it will fall back to using the default model, which is the one you set on the Agent Panel.
|
||||
|
||||
You can configure the following feature-specific models:
|
||||
|
||||
- Thread summary model: Used for generating thread summaries
|
||||
- Inline assistant model: Used for the inline assistant feature
|
||||
- Commit message model: Used for generating Git commit messages
|
||||
|
||||
Example configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"assistant": {
|
||||
"version": "2",
|
||||
"default_model": {
|
||||
"provider": "zed.dev",
|
||||
"model": "claude-3-7-sonnet"
|
||||
},
|
||||
"inline_assistant_model": {
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-5-sonnet"
|
||||
},
|
||||
"commit_message_model": {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o-mini"
|
||||
},
|
||||
"thread_summary_model": {
|
||||
"provider": "google",
|
||||
"model": "gemini-2.0-flash"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Configuring alternative models for inline assists {#alternative-assists}
|
||||
|
||||
You can configure additional models that will be used to perform inline assists in parallel. When you do this,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue