Associate each thread with a model (#29573)
This PR makes it possible to use different LLM models simultaneously in the agent panels of two different project windows. It also restores a thread's original model when that thread is reopened from the history, rather than falling back to the default model. As before, newly created threads use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent panel
- Enhanced the agent panel so that when revisiting old threads, their original model will be used

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent 5102c4c002
commit 17903a0999
15 changed files with 168 additions and 114 deletions
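At its core, the change moves model selection off the single global default (previously read from `LanguageModelRegistry` at every call site) and onto each `Thread`, which now stores an `Option<ConfiguredModel>`, serializes it, and lazily falls back to the default. The sketch below is only a minimal, self-contained illustration of that ownership model; the `Registry` struct and the plain-function signatures are simplified stand-ins, not Zed's actual `gpui` or `LanguageModelRegistry` API.

```rust
// Simplified stand-ins for Zed's ConfiguredModel / LanguageModelRegistry.
#[derive(Clone, Debug, PartialEq)]
struct ConfiguredModel {
    provider: String,
    model: String,
}

struct Registry {
    default: Option<ConfiguredModel>,
}

struct Thread {
    // Each thread remembers its own model; `None` means "not chosen yet".
    configured_model: Option<ConfiguredModel>,
}

impl Thread {
    // New threads start from the current default, mirroring `Thread::new`.
    fn new(registry: &Registry) -> Self {
        Thread {
            configured_model: registry.default.clone(),
        }
    }

    // Threads restored from history keep their saved model if one exists,
    // falling back to the default otherwise (mirrors `deserialize`'s
    // `select_model(...).or_else(|| registry.default_model())`).
    fn deserialize(saved: Option<ConfiguredModel>, registry: &Registry) -> Self {
        Thread {
            configured_model: saved.or_else(|| registry.default.clone()),
        }
    }

    // Lazily fills in the default, like `get_or_init_configured_model`.
    fn get_or_init_configured_model(&mut self, registry: &Registry) -> Option<ConfiguredModel> {
        if self.configured_model.is_none() {
            self.configured_model = registry.default.clone();
        }
        self.configured_model.clone()
    }
}

fn main() {
    let registry = Registry {
        default: Some(ConfiguredModel {
            provider: "zed.dev".into(),
            model: "default-model".into(), // illustrative value, not a real model id
        }),
    };

    // Two threads can now hold different models at the same time.
    let mut a = Thread::new(&registry);
    let b = Thread::deserialize(
        Some(ConfiguredModel {
            provider: "other-provider".into(), // illustrative value
            model: "saved-model".into(),
        }),
        &registry,
    );

    assert_eq!(a.get_or_init_configured_model(&registry), registry.default);
    assert_ne!(b.configured_model, a.configured_model);
    if let Some(m) = &b.configured_model {
        println!("thread b keeps its saved model: {}/{}", m.provider, m.model);
    }
}
```
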
@@ -25,8 +25,8 @@ use gpui::{
 };
 use language::{Buffer, LanguageRegistry};
 use language_model::{
-    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
-    RequestUsage, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, RequestUsage, Role,
+    StopReason,
 };
 use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
 use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@@ -1252,7 +1252,7 @@ impl ActiveThread {
         cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
         state._update_token_count_task.take();

-        let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(configured_model) = self.thread.read(cx).configured_model() else {
             state.last_estimated_token_count.take();
             return;
         };
@@ -1305,7 +1305,7 @@ impl ActiveThread {
             temperature: None,
         };

-        Some(default_model.model.count_tokens(request, cx))
+        Some(configured_model.model.count_tokens(request, cx))
     })? {
         task.await?
     } else {
@@ -1338,7 +1338,7 @@ impl ActiveThread {
             return;
         };
         let edited_text = state.editor.read(cx).text(cx);
-        self.thread.update(cx, |thread, cx| {
+        let thread_model = self.thread.update(cx, |thread, cx| {
             thread.edit_message(
                 message_id,
                 Role::User,
@@ -1348,9 +1348,10 @@ impl ActiveThread {
             for message_id in self.messages_after(message_id) {
                 thread.delete_message(*message_id, cx);
             }
+            thread.get_or_init_configured_model(cx)
         });

-        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(model) = thread_model else {
            return;
        };

@@ -951,6 +951,7 @@ mod tests {
             ThemeSettings::register(cx);
             ContextServerSettings::register(cx);
             EditorSettings::register(cx);
+            language_model::init_settings(cx);
         });

         let fs = FakeFs::new(cx.executor());

@@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::{Entity, FocusHandle, SharedString};

+use crate::Thread;
+use language_model::{ConfiguredModel, LanguageModelRegistry};
 use language_model_selector::{
     LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
@@ -9,7 +11,11 @@ use settings::update_settings_file;
 use std::sync::Arc;
 use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};

-pub use language_model_selector::ModelType;
+#[derive(Clone)]
+pub enum ModelType {
+    Default(Entity<Thread>),
+    InlineAssistant,
+}

 pub struct AssistantModelSelector {
     selector: Entity<LanguageModelSelector>,
@@ -24,18 +30,39 @@ impl AssistantModelSelector {
         focus_handle: FocusHandle,
         model_type: ModelType,
         window: &mut Window,
-        cx: &mut App,
+        cx: &mut Context<Self>,
     ) -> Self {
         Self {
-            selector: cx.new(|cx| {
+            selector: cx.new(move |cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    {
+                        let model_type = model_type.clone();
+                        move |cx| match &model_type {
+                            ModelType::Default(thread) => thread.read(cx).configured_model(),
+                            ModelType::InlineAssistant => {
+                                LanguageModelRegistry::read_global(cx).inline_assistant_model()
+                            }
+                        }
+                    },
                     move |model, cx| {
                         let provider = model.provider_id().0.to_string();
                         let model_id = model.id().0.to_string();
-
-                        match model_type {
-                            ModelType::Default => {
+                        match &model_type {
+                            ModelType::Default(thread) => {
+                                thread.update(cx, |thread, cx| {
+                                    let registry = LanguageModelRegistry::read_global(cx);
+                                    if let Some(provider) = registry.provider(&model.provider_id())
+                                    {
+                                        thread.set_configured_model(
+                                            Some(ConfiguredModel {
+                                                provider,
+                                                model: model.clone(),
+                                            }),
+                                            cx,
+                                        );
+                                    }
+                                });
                                 update_settings_file::<AssistantSettings>(
                                     fs.clone(),
                                     cx,
@@ -58,7 +85,6 @@ impl AssistantModelSelector {
                             }
                         }
                     },
-                    model_type,
                     window,
                     cx,
                 )

@@ -1274,12 +1274,12 @@ impl AssistantPanel {
         let is_generating = thread.is_generating();
         let message_editor = self.message_editor.read(cx);

-        let conversation_token_usage = thread.total_token_usage(cx);
+        let conversation_token_usage = thread.total_token_usage();
         let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
             self.thread.read(cx).editing_message_id()
         {
             let combined = thread
-                .token_usage_up_to_message(editing_message_id, cx)
+                .token_usage_up_to_message(editing_message_id)
                 .add(unsent_tokens);

             (combined, unsent_tokens > 0)

@@ -1,4 +1,4 @@
-use crate::assistant_model_selector::AssistantModelSelector;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
 use crate::buffer_codegen::BufferCodegen;
 use crate::context_picker::ContextPicker;
 use crate::context_store::ContextStore;
@@ -20,7 +20,7 @@ use gpui::{
     Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
 };
 use language_model::{LanguageModel, LanguageModelRegistry};
-use language_model_selector::{ModelType, ToggleModelSelector};
+use language_model_selector::ToggleModelSelector;
 use parking_lot::Mutex;
 use settings::Settings;
 use std::cmp;

@@ -1,7 +1,7 @@
 use std::collections::BTreeMap;
 use std::sync::Arc;

-use crate::assistant_model_selector::ModelType;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
 use crate::context::{ContextLoadResult, load_context};
 use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
 use buffer_diff::BufferDiff;
@@ -21,9 +21,7 @@ use gpui::{
     Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
 };
 use language::{Buffer, Language};
-use language_model::{
-    ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
-};
+use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
 use language_model_selector::ToggleModelSelector;
 use multi_buffer;
 use project::Project;
@@ -36,7 +34,6 @@ use util::ResultExt as _;
 use workspace::Workspace;
 use zed_llm_client::CompletionMode;

-use crate::assistant_model_selector::AssistantModelSelector;
 use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
 use crate::context_store::ContextStore;
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -153,6 +150,17 @@ impl MessageEditor {
             }),
         ];

+        let model_selector = cx.new(|cx| {
+            AssistantModelSelector::new(
+                fs.clone(),
+                model_selector_menu_handle,
+                editor.focus_handle(cx),
+                ModelType::Default(thread.clone()),
+                window,
+                cx,
+            )
+        });
+
         Self {
             editor: editor.clone(),
             project: thread.read(cx).project().clone(),
@@ -165,16 +173,7 @@ impl MessageEditor {
             context_picker_menu_handle,
             load_context_task: None,
             last_loaded_context: None,
-            model_selector: cx.new(|cx| {
-                AssistantModelSelector::new(
-                    fs.clone(),
-                    model_selector_menu_handle,
-                    editor.focus_handle(cx),
-                    ModelType::Default,
-                    window,
-                    cx,
-                )
-            }),
+            model_selector,
             edits_expanded: false,
             editor_is_expanded: false,
             profile_selector: cx
@@ -263,15 +262,11 @@ impl MessageEditor {
         self.editor.read(cx).text(cx).trim().is_empty()
     }

-    fn is_model_selected(&self, cx: &App) -> bool {
-        LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .is_some()
-    }
-
     fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
+        let Some(ConfiguredModel { model, provider }) = self
+            .thread
+            .update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
+        else {
             return;
         };

@@ -408,14 +403,13 @@ impl MessageEditor {
             return None;
         }

-        let model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|default| default.model.clone())?;
-        if !model.supports_max_mode() {
+        let thread = self.thread.read(cx);
+        let model = thread.configured_model();
+        if !model?.model.supports_max_mode() {
             return None;
         }

-        let active_completion_mode = self.thread.read(cx).completion_mode();
+        let active_completion_mode = thread.completion_mode();

         Some(
             IconButton::new("max-mode", IconName::SquarePlus)
@@ -442,24 +436,21 @@ impl MessageEditor {
         cx: &mut Context<Self>,
     ) -> Div {
         let thread = self.thread.read(cx);
+        let model = thread.configured_model();

         let editor_bg_color = cx.theme().colors().editor_background;
         let is_generating = thread.is_generating();
         let focus_handle = self.editor.focus_handle(cx);

-        let is_model_selected = self.is_model_selected(cx);
+        let is_model_selected = model.is_some();
         let is_editor_empty = self.is_editor_empty(cx);

-        let model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|default| default.model.clone());
-
         let incompatible_tools = model
             .as_ref()
             .map(|model| {
                 self.incompatible_tools_state.update(cx, |state, cx| {
                     state
-                        .incompatible_tools(model, cx)
+                        .incompatible_tools(&model.model, cx)
                         .iter()
                         .cloned()
                         .collect::<Vec<_>>()
@@ -1058,7 +1049,7 @@ impl MessageEditor {
         cx.emit(MessageEditorEvent::Changed);
         self.update_token_count_task.take();

-        let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(model) = self.thread.read(cx).configured_model() else {
             self.last_estimated_token_count.take();
             return;
         };
@@ -1111,7 +1102,7 @@ impl MessageEditor {
             temperature: None,
         };

-        Some(default_model.model.count_tokens(request, cx))
+        Some(model.model.count_tokens(request, cx))
     })? {
         task.await?
     } else {
@@ -1143,7 +1134,7 @@ impl Focusable for MessageEditor {
 impl Render for MessageEditor {
     fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let thread = self.thread.read(cx);
-        let total_token_usage = thread.total_token_usage(cx);
+        let total_token_usage = thread.total_token_usage();
         let token_usage_ratio = total_token_usage.ratio();

         let action_log = self.thread.read(cx).action_log();

@@ -22,8 +22,8 @@ use language_model::{
     LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
-    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
-    TokenUsage,
+    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
+    StopReason, TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::Project;
@@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
 use crate::ThreadStore;
 use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
 use crate::thread_store::{
-    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
-    SerializedToolUse, SharedProjectContext,
+    SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
+    SerializedToolResult, SerializedToolUse, SharedProjectContext,
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};

@@ -332,6 +332,7 @@ pub struct Thread {
         Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
     >,
     remaining_turns: u32,
+    configured_model: Option<ConfiguredModel>,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -351,6 +352,8 @@ impl Thread {
         cx: &mut Context<Self>,
     ) -> Self {
         let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
+        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
+
         Self {
             id: ThreadId::new(),
             updated_at: Utc::now(),
@@ -388,6 +391,7 @@ impl Thread {
             last_auto_capture_at: None,
             request_callback: None,
             remaining_turns: u32::MAX,
+            configured_model,
         }
     }

@@ -411,6 +415,19 @@ impl Thread {
         let (detailed_summary_tx, detailed_summary_rx) =
             postage::watch::channel_with(serialized.detailed_summary_state);

+        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            serialized
+                .model
+                .and_then(|model| {
+                    let model = SelectedModel {
+                        provider: model.provider.clone().into(),
+                        model: model.model.clone().into(),
+                    };
+                    registry.select_model(&model, cx)
+                })
+                .or_else(|| registry.default_model())
+        });
+
         Self {
             id,
             updated_at: serialized.updated_at,
@@ -468,6 +485,7 @@ impl Thread {
             last_auto_capture_at: None,
             request_callback: None,
             remaining_turns: u32::MAX,
+            configured_model,
         }
     }

@@ -507,6 +525,22 @@ impl Thread {
         self.project_context.clone()
     }

+    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
+        if self.configured_model.is_none() {
+            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
+        }
+        self.configured_model.clone()
+    }
+
+    pub fn configured_model(&self) -> Option<ConfiguredModel> {
+        self.configured_model.clone()
+    }
+
+    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
+        self.configured_model = model;
+        cx.notify();
+    }
+
     pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");

     pub fn summary_or_default(&self) -> SharedString {
@@ -952,6 +986,13 @@ impl Thread {
                 request_token_usage: this.request_token_usage.clone(),
                 detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                 exceeded_window_error: this.exceeded_window_error.clone(),
+                model: this
+                    .configured_model
+                    .as_ref()
+                    .map(|model| SerializedLanguageModel {
+                        provider: model.provider.id().0.to_string(),
+                        model: model.model.id().0.to_string(),
+                    }),
             })
         })
     }
@@ -1733,7 +1774,7 @@ impl Thread {
             tool_use_id.clone(),
             tool_name,
             Err(anyhow!("Error parsing input JSON: {error}")),
-            cx,
+            self.configured_model.as_ref(),
         );
         let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
             pending_tool_use.ui_text.clone()
@@ -1808,7 +1849,7 @@ impl Thread {
             tool_use_id.clone(),
             tool_name,
             output,
-            cx,
+            thread.configured_model.as_ref(),
         );
         thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
     })
@@ -1826,10 +1867,9 @@ impl Thread {
         cx: &mut Context<Self>,
     ) {
         if self.all_tools_finished() {
-            let model_registry = LanguageModelRegistry::read_global(cx);
-            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
+            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                 if !canceled {
-                    self.send_to_model(model, window, cx);
+                    self.send_to_model(model.clone(), window, cx);
                 }
                 self.auto_capture_telemetry(cx);
             }
@@ -2254,8 +2294,8 @@ impl Thread {
         self.cumulative_token_usage
     }

-    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
-        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
+        let Some(model) = self.configured_model.as_ref() else {
             return TotalTokenUsage::default();
         };

@@ -2283,9 +2323,8 @@ impl Thread {
         }
     }

-    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let Some(model) = model_registry.default_model() else {
+    pub fn total_token_usage(&self) -> TotalTokenUsage {
+        let Some(model) = self.configured_model.as_ref() else {
             return TotalTokenUsage::default();
         };

@@ -2336,8 +2375,12 @@ impl Thread {
                 "Permission to run tool action denied by user"
             ));

-            self.tool_use
-                .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
+            self.tool_use.insert_tool_output(
+                tool_use_id.clone(),
+                tool_name,
+                err,
+                self.configured_model.as_ref(),
+            );
             self.tool_finished(tool_use_id.clone(), None, true, window, cx);
         }
     }
@@ -2769,6 +2812,7 @@ fn main() {{
         prompt_store::init(cx);
         thread_store::init(cx);
         workspace::init_settings(cx);
+        language_model::init_settings(cx);
         ThemeSettings::register(cx);
         ContextServerSettings::register(cx);
         EditorSettings::register(cx);

@@ -640,6 +640,14 @@ pub struct SerializedThread {
     pub detailed_summary_state: DetailedSummaryState,
     #[serde(default)]
     pub exceeded_window_error: Option<ExceededWindowError>,
+    #[serde(default)]
+    pub model: Option<SerializedLanguageModel>,
 }

+#[derive(Serialize, Deserialize, Debug)]
+pub struct SerializedLanguageModel {
+    pub provider: String,
+    pub model: String,
+}
+
 impl SerializedThread {
@@ -774,6 +782,7 @@ impl LegacySerializedThread {
             request_token_usage: Vec::new(),
             detailed_summary_state: DetailedSummaryState::default(),
             exceeded_window_error: None,
+            model: None,
         }
     }
 }

@@ -7,7 +7,7 @@ use futures::FutureExt as _;
 use futures::future::Shared;
 use gpui::{App, Entity, SharedString, Task};
 use language_model::{
-    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
+    ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
     LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 };
 use ui::IconName;
@@ -353,7 +353,7 @@ impl ToolUseState {
         tool_use_id: LanguageModelToolUseId,
         tool_name: Arc<str>,
         output: Result<String>,
-        cx: &App,
+        configured_model: Option<&ConfiguredModel>,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);

@@ -373,13 +373,10 @@ impl ToolUseState {

         match output {
             Ok(tool_result) => {
-                let model_registry = LanguageModelRegistry::read_global(cx);
-
                 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

                 // Protect from clearly large output
-                let tool_output_limit = model_registry
-                    .default_model()
+                let tool_output_limit = configured_model
                     .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
                     .unwrap_or(usize::MAX);

@@ -37,7 +37,7 @@ use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use project::{CodeAction, LspAction, ProjectTransaction};
@@ -1759,6 +1759,7 @@ impl PromptEditor {
             language_model_selector: cx.new(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -1766,7 +1767,6 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )

@@ -19,7 +19,7 @@ use language_model::{
     ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
     Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use prompt_store::PromptBuilder;
 use settings::{Settings, update_settings_file};
 use std::{
@@ -749,6 +749,7 @@ impl PromptEditor {
             language_model_selector: cx.new(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -756,7 +757,6 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )

@@ -39,7 +39,7 @@ use language_model::{
     Role,
 };
 use language_model_selector::{
-    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
+    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
 use multi_buffer::MultiBufferRow;
 use picker::Picker;
@@ -291,6 +291,7 @@ impl ContextEditor {
             dragged_file_worktrees: Vec::new(),
             language_model_selector: cx.new(|cx| {
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -298,7 +299,6 @@ impl ContextEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )

@@ -39,10 +39,14 @@ pub use crate::telemetry::*;
 pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

 pub fn init(client: Arc<Client>, cx: &mut App) {
-    registry::init(cx);
+    init_settings(cx);
     RefreshLlmTokenListener::register(client.clone(), cx);
 }

+pub fn init_settings(cx: &mut App) {
+    registry::init(cx);
+}
+
 /// The availability of a [`LanguageModel`].
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 pub enum LanguageModelAvailability {

@@ -188,7 +188,7 @@ impl LanguageModelRegistry {
             .collect::<Vec<_>>();
     }

-    fn select_model(
+    pub fn select_model(
         &mut self,
         selected_model: &SelectedModel,
         cx: &mut Context<Self>,

@@ -22,7 +22,8 @@ action_with_deprecated_aliases!(

 const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

-type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
+type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
+type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

 pub struct LanguageModelSelector {
     picker: Entity<Picker<LanguageModelPickerDelegate>>,
@@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
     _subscriptions: Vec<Subscription>,
 }

-#[derive(Clone, Copy)]
-pub enum ModelType {
-    Default,
-    InlineAssistant,
-}
-
 impl LanguageModelSelector {
     pub fn new(
-        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
-        model_type: ModelType,
+        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
+        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -52,9 +47,9 @@ impl LanguageModelSelector {
             language_model_selector: cx.entity().downgrade(),
             on_model_changed: on_model_changed.clone(),
             all_models: Arc::new(all_models),
-            selected_index: Self::get_active_model_index(&entries, model_type, cx),
+            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
             filtered_entries: entries,
-            model_type,
+            get_active_model: Arc::new(get_active_model),
         };

         let picker = cx.new(|cx| {
@@ -204,26 +199,13 @@ impl LanguageModelSelector {
     }

     pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
-        let model_type = self.picker.read(cx).delegate.model_type;
-        Self::active_model_by_type(model_type, cx)
-    }
-
-    fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
-        match model_type {
-            ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
-            ModelType::InlineAssistant => {
-                LanguageModelRegistry::read_global(cx).inline_assistant_model()
-            }
-        }
+        (self.picker.read(cx).delegate.get_active_model)(cx)
     }

     fn get_active_model_index(
         entries: &[LanguageModelPickerEntry],
-        model_type: ModelType,
-        cx: &App,
+        active_model: Option<ConfiguredModel>,
     ) -> usize {
-        let active_model = Self::active_model_by_type(model_type, cx);
-
         entries
             .iter()
             .position(|entry| {
@@ -232,7 +214,7 @@ impl LanguageModelSelector {
                     .as_ref()
                     .map(|active_model| {
                         active_model.model.id() == model.model.id()
-                            && active_model.model.provider_id() == model.model.provider_id()
+                            && active_model.provider.id() == model.model.provider_id()
                     })
                     .unwrap_or_default()
             } else {
@@ -325,10 +307,10 @@ struct ModelInfo {
 pub struct LanguageModelPickerDelegate {
     language_model_selector: WeakEntity<LanguageModelSelector>,
     on_model_changed: OnModelChanged,
+    get_active_model: GetActiveModel,
     all_models: Arc<GroupedModels>,
     filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
-    model_type: ModelType,
 }

 struct GroupedModels {
@@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
                 .into_any_element(),
             ),
             LanguageModelPickerEntry::Model(model_info) => {
-                let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
-
+                let active_model = (self.get_active_model)(cx);
                 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                 let active_model_id = active_model.map(|m| m.model.id());

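The `SerializedThread` hunk above keeps previously saved threads loading by making the new `model` field optional and marking it `#[serde(default)]`. Below is a minimal sketch of that backward-compatibility behavior only; the structs are simplified stand-ins rather than Zed's full definitions, the provider/model values are made up, and the example assumes the `serde` (with `derive`) and `serde_json` crates.

```rust
use serde::{Deserialize, Serialize};

// Simplified stand-in for the new optional model reference.
#[derive(Serialize, Deserialize, Debug)]
struct SerializedLanguageModel {
    provider: String,
    model: String,
}

// Simplified stand-in for SerializedThread: only the fields needed to show
// that payloads written before the `model` field existed still deserialize.
#[derive(Serialize, Deserialize, Debug)]
struct SerializedThread {
    summary: String,
    #[serde(default)]
    model: Option<SerializedLanguageModel>,
}

fn main() -> serde_json::Result<()> {
    // Old payload, saved before the `model` field was introduced.
    let old = r#"{ "summary": "Fix build" }"#;
    // New payload, carrying the thread's configured model (illustrative values).
    let new = r#"{
        "summary": "Fix build",
        "model": { "provider": "example-provider", "model": "example-model" }
    }"#;

    let old_thread: SerializedThread = serde_json::from_str(old)?;
    let new_thread: SerializedThread = serde_json::from_str(new)?;

    // Missing field defaults to None, so old threads still load.
    assert!(old_thread.model.is_none());
    println!("old thread '{}' has no saved model", old_thread.summary);

    if let Some(m) = &new_thread.model {
        println!("new thread '{}' restores {}/{}", new_thread.summary, m.provider, m.model);
    }
    Ok(())
}
```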