Remember max mode setting per-thread and add a user setting (#30042)

Supersedes: https://github.com/zed-industries/zed/pull/29936

Thanks for your contribution @imumesh18, but we had a slightly different
take on it :)

Release Notes:

- N/A

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Mikayla Maki 2025-05-06 13:11:21 -07:00 committed by GitHub
parent 6bb6e48171
commit 0055a20512
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 52 additions and 17 deletions

View file

@ -8,6 +8,7 @@ use crate::ui::{
AnimatedLabel, MaxModeTooltip,
preview::{AgentPreview, UsageCallout},
};
use assistant_settings::CompletionMode;
use buffer_diff::BufferDiff;
use client::UserStore;
use collections::{HashMap, HashSet};
@ -42,7 +43,6 @@ use ui::{Disclosure, DocumentationSide, KeyBinding, PopoverMenuHandle, Tooltip,
use util::{ResultExt as _, maybe};
use workspace::dock::DockPosition;
use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionMode;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_store::ContextStore;

View file

@ -5,7 +5,7 @@ use std::sync::Arc;
use std::time::Instant;
use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_settings::{AssistantSettings, CompletionMode};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
@ -37,7 +37,7 @@ use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
use zed_llm_client::CompletionRequestStatus;
use crate::ThreadStore;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@ -312,14 +312,6 @@ pub enum TokenUsageRatio {
Exceeded,
}
fn default_completion_mode(cx: &App) -> CompletionMode {
if cx.is_staff() {
CompletionMode::Max
} else {
CompletionMode::Normal
}
}
#[derive(Debug, Clone, Copy)]
pub enum QueueState {
Sending,
@ -336,7 +328,7 @@ pub struct Thread {
detailed_summary_task: Task<Option<()>>,
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
completion_mode: CompletionMode,
completion_mode: assistant_settings::CompletionMode,
messages: Vec<Message>,
next_message_id: MessageId,
last_prompt_id: PromptId,
@ -395,7 +387,7 @@ impl Thread {
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: default_completion_mode(cx),
completion_mode: AssistantSettings::get_global(cx).preferred_completion_mode,
messages: Vec::new(),
next_message_id: MessageId(0),
last_prompt_id: PromptId::new(),
@ -464,6 +456,10 @@ impl Thread {
.or_else(|| registry.default_model())
});
let completion_mode = serialized
.completion_mode
.unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);
Self {
id,
updated_at: serialized.updated_at,
@ -472,7 +468,7 @@ impl Thread {
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: default_completion_mode(cx),
completion_mode,
messages: serialized
.messages
.into_iter()
@ -1095,6 +1091,7 @@ impl Thread {
provider: model.provider.id().0.to_string(),
model: model.model.id().0.to_string(),
}),
completion_mode: Some(this.completion_mode),
})
})
}
@ -1246,9 +1243,9 @@ impl Thread {
request.tools = available_tools;
request.mode = if model.supports_max_mode() {
Some(self.completion_mode)
Some(self.completion_mode.into())
} else {
Some(CompletionMode::Normal)
Some(CompletionMode::Normal.into())
};
request

View file

@ -5,7 +5,7 @@ use std::rc::Rc;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
@ -651,6 +651,8 @@ pub struct SerializedThread {
pub exceeded_window_error: Option<ExceededWindowError>,
#[serde(default)]
pub model: Option<SerializedLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
}
#[derive(Serialize, Deserialize, Debug)]
@ -794,6 +796,7 @@ impl LegacySerializedThread {
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
}
}
}