diff --git a/Cargo.lock b/Cargo.lock
index e824bbe8e1..d9a6ac60eb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -595,6 +595,7 @@ version = "0.1.0"
 dependencies = [
  "anthropic",
  "anyhow",
+ "collections",
  "deepseek",
  "feature_flags",
  "fs",
@@ -3006,6 +3007,7 @@ dependencies = [
  "anyhow",
  "assistant",
  "assistant_context_editor",
+ "assistant_settings",
  "assistant_slash_command",
  "assistant_tool",
  "async-stripe",
diff --git a/assets/settings/default.json b/assets/settings/default.json
index eb961095db..f7bc3a4dc3 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -614,13 +614,11 @@
     //
     // Default: main
     "fallback_branch_name": "main",
-
     // Whether to sort entries in the panel by path
     // or by status (the default).
     //
     // Default: false
     "sort_by_path": false,
-
     "scrollbar": {
       // When to show the scrollbar in the git panel.
       //
@@ -670,6 +668,28 @@
       // The model to use.
       "model": "claude-3-7-sonnet-latest"
     },
+    // Additional parameters for language model requests. When making a request to a model, parameters will be taken
+    // from the last entry in this list that matches the model's provider and name. In each entry, both provider
+    // and model are optional, so that you can specify parameters for either one.
+    "model_parameters": [
+      // To set parameters for all requests to OpenAI models:
+      // {
+      //   "provider": "openai",
+      //   "temperature": 0.5
+      // }
+      //
+      // To set parameters for all requests in general:
+      // {
+      //   "temperature": 0
+      // }
+      //
+      // To set parameters for a specific provider and model:
+      // {
+      //   "provider": "zed.dev",
+      //   "model": "claude-3-7-sonnet-latest",
+      //   "temperature": 1.0
+      // }
+    ],
     // When enabled, the agent can run potentially destructive actions without asking for your confirmation.
     "always_allow_tool_actions": false,
     // When enabled, the agent will stream edits.
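An aside on the matching rule the comments above describe: a request takes its parameters from the *last* entry whose optional `provider`/`model` filters both match. The following standalone Rust sketch is not part of the patch — `ModelParams` and `temperature_for` are hypothetical, simplified stand-ins for the real `LanguageModelParameters` and `AssistantSettings::temperature_for_model` below — but it illustrates the same last-entry-wins resolution the patch implements with `rfind`:

```rust
#[derive(Clone)]
struct ModelParams {
    provider: Option<String>,
    model: Option<String>,
    temperature: Option<f32>,
}

/// Scan from the end so later entries override earlier ones, and treat
/// a missing provider/model filter as "matches anything".
fn temperature_for(params: &[ModelParams], provider: &str, model: &str) -> Option<f32> {
    params
        .iter()
        .rfind(|p| {
            p.provider.as_deref().map_or(true, |v| v == provider)
                && p.model.as_deref().map_or(true, |v| v == model)
        })
        .and_then(|p| p.temperature)
}

fn main() {
    let params = vec![
        // Applies to all requests.
        ModelParams { provider: None, model: None, temperature: Some(0.0) },
        // A later, more specific entry wins for this provider/model pair.
        ModelParams {
            provider: Some("zed.dev".into()),
            model: Some("claude-3-7-sonnet-latest".into()),
            temperature: Some(1.0),
        },
    ];
    assert_eq!(
        temperature_for(&params, "zed.dev", "claude-3-7-sonnet-latest"),
        Some(1.0)
    );
    assert_eq!(temperature_for(&params, "openai", "gpt-4"), Some(0.0));
}
```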
diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs
index 140294cbbc..39061f57f3 100644
--- a/crates/agent/src/active_thread.rs
+++ b/crates/agent/src/active_thread.rs
@@ -1417,7 +1417,10 @@ impl ActiveThread {
                     messages: vec![request_message],
                     tools: vec![],
                     stop: vec![],
-                    temperature: None,
+                    temperature: AssistantSettings::temperature_for_model(
+                        &configured_model.model,
+                        cx,
+                    ),
                 };

                 Some(configured_model.model.count_tokens(request, cx))
diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs
index 11fa641c03..ccb402996b 100644
--- a/crates/agent/src/buffer_codegen.rs
+++ b/crates/agent/src/buffer_codegen.rs
@@ -2,6 +2,7 @@ use crate::context::ContextLoadResult;
 use crate::inline_prompt_editor::CodegenStatus;
 use crate::{context::load_context, context_store::ContextStore};
 use anyhow::Result;
+use assistant_settings::AssistantSettings;
 use client::telemetry::Telemetry;
 use collections::HashSet;
 use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
@@ -383,7 +384,7 @@ impl CodegenAlternative {
             if user_prompt.trim().to_lowercase() == "delete" {
                 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
             } else {
-                let request = self.build_request(user_prompt, cx)?;
+                let request = self.build_request(&model, user_prompt, cx)?;
                 cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
                     .boxed_local()
             };
@@ -393,6 +394,7 @@ impl CodegenAlternative {

     fn build_request(
         &self,
+        model: &Arc<dyn LanguageModel>,
         user_prompt: String,
         cx: &mut App,
     ) -> Result<Task<LanguageModelRequest>> {
@@ -441,6 +443,8 @@ impl CodegenAlternative {
             }
         });

+        let temperature = AssistantSettings::temperature_for_model(&model, cx);
+
         Ok(cx.spawn(async move |_cx| {
             let mut request_message = LanguageModelRequestMessage {
                 role: Role::User,
@@ -463,7 +467,7 @@ impl CodegenAlternative {
                 mode: None,
                 tools: Vec::new(),
                 stop: Vec::new(),
-                temperature: None,
+                temperature,
                 messages: vec![request_message],
             }
         }))
diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs
index 3ba8bdb1ac..912824a88f 100644
--- a/crates/agent/src/message_editor.rs
+++ b/crates/agent/src/message_editor.rs
@@ -8,7 +8,7 @@ use crate::ui::{
     AnimatedLabel, MaxModeTooltip,
     preview::{AgentPreview, UsageCallout},
 };
-use assistant_settings::CompletionMode;
+use assistant_settings::{AssistantSettings, CompletionMode};
 use buffer_diff::BufferDiff;
 use client::UserStore;
 use collections::{HashMap, HashSet};
@@ -1273,7 +1273,7 @@ impl MessageEditor {
             messages: vec![request_message],
             tools: vec![],
             stop: vec![],
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(&model.model, cx),
         };

         Some(model.model.count_tokens(request, cx))
diff --git a/crates/agent/src/terminal_inline_assistant.rs b/crates/agent/src/terminal_inline_assistant.rs
index e9c24b50da..4d54ce5241 100644
--- a/crates/agent/src/terminal_inline_assistant.rs
+++ b/crates/agent/src/terminal_inline_assistant.rs
@@ -6,6 +6,7 @@ use crate::inline_prompt_editor::{
 use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
 use crate::thread_store::{TextThreadStore, ThreadStore};
 use anyhow::{Context as _, Result};
+use assistant_settings::AssistantSettings;
 use client::telemetry::Telemetry;
 use collections::{HashMap, VecDeque};
 use editor::{MultiBuffer, actions::SelectAll};
@@ -266,6 +267,12 @@ impl TerminalInlineAssistant {
             load_context(contexts, project, &assist.prompt_store, cx)
         })?;

+        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
+            .inline_assistant_model()
+            .context("No inline assistant model")?;
+
+        let temperature = AssistantSettings::temperature_for_model(&model, cx);
+
         Ok(cx.background_spawn(async move {
             let mut request_message = LanguageModelRequestMessage {
                 role: Role::User,
@@ -287,7 +294,7 @@ impl TerminalInlineAssistant {
             messages: vec![request_message],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature,
         }
     }))
 }
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index a2ed103c87..dacde7cda4 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -1145,7 +1145,7 @@ impl Thread {
             messages: vec![],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(&model, cx),
         };

         let available_tools = self.available_tools(cx, model.clone());
@@ -1251,7 +1251,12 @@ impl Thread {
         request
     }

-    fn to_summarize_request(&self, added_user_message: String) -> LanguageModelRequest {
+    fn to_summarize_request(
+        &self,
+        model: &Arc<dyn LanguageModel>,
+        added_user_message: String,
+        cx: &App,
+    ) -> LanguageModelRequest {
         let mut request = LanguageModelRequest {
             thread_id: None,
             prompt_id: None,
@@ -1259,7 +1264,7 @@ impl Thread {
             messages: vec![],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(model, cx),
         };

         for message in &self.messages {
@@ -1696,7 +1701,7 @@ impl Thread {
             If the conversation is about a specific subject, include it in the title. \
             Be descriptive. DO NOT speak in the first person.";

-        let request = self.to_summarize_request(added_user_message.into());
+        let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);

         self.pending_summary = cx.spawn(async move |this, cx| {
             async move {
@@ -1782,7 +1787,7 @@ impl Thread {
             4. Any action items or next steps if any\n\
             Format it in Markdown with headings and bullet points.";

-        let request = self.to_summarize_request(added_user_message.into());
+        let request = self.to_summarize_request(&model, added_user_message.into(), cx);

         *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
             message_id: last_message_id,
@@ -2655,7 +2660,7 @@ struct PendingCompletion {
 mod tests {
     use super::*;
     use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
-    use assistant_settings::AssistantSettings;
+    use assistant_settings::{AssistantSettings, LanguageModelParameters};
     use assistant_tool::ToolRegistry;
     use editor::EditorSettings;
     use gpui::TestAppContext;
@@ -3066,6 +3071,100 @@ fn main() {{
         );
     }

+    #[gpui::test]
+    async fn test_temperature_setting(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
+        )
+        .await;
+
+        let (_workspace, _thread_store, thread, _context_store, model) =
+            setup_test_environment(cx, project.clone()).await;
+
+        // Both model and provider
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: Some(model.provider_id().0.to_string().into()),
+                        model: Some(model.id().0.clone()),
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, Some(0.66));
+
+        // Only model
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: None,
+                        model: Some(model.id().0.clone()),
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, Some(0.66));
+
+        // Only provider
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: Some(model.provider_id().0.to_string().into()),
+                        model: None,
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, Some(0.66));
+
+        // Same model name, different provider
+        cx.update(|cx| {
+            AssistantSettings::override_global(
+                AssistantSettings {
+                    model_parameters: vec![LanguageModelParameters {
+                        provider: Some("anthropic".into()),
+                        model: Some(model.id().0.clone()),
+                        temperature: Some(0.66),
+                    }],
+                    ..AssistantSettings::get_global(cx).clone()
+                },
+                cx,
+            );
+        });
+
+        let request = thread.update(cx, |thread, cx| {
+            thread.to_completion_request(model.clone(), cx)
+        });
+        assert_eq!(request.temperature, None);
+    }
+
     fn init_test_settings(cx: &mut TestAppContext) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index f54a761037..1d759091e8 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -163,7 +163,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {

 fn to_selected_model(selection: &LanguageModelSelection) -> language_model::SelectedModel {
     language_model::SelectedModel {
-        provider: LanguageModelProviderId::from(selection.provider.clone()),
+        provider: LanguageModelProviderId::from(selection.provider.0.clone()),
         model: LanguageModelId::from(selection.model.clone()),
     }
 }
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 9d1ffe9e2d..ac8d09330d 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -2484,7 +2484,7 @@ impl InlineAssist {
                 .read(cx)
                 .active_context(cx)?
                 .read(cx)
-                .to_completion_request(RequestType::Chat, cx),
+                .to_completion_request(None, RequestType::Chat, cx),
             )
         } else {
             None
@@ -2870,7 +2870,8 @@ impl CodegenAlternative {
         if let Some(ConfiguredModel { model, .. }) =
             LanguageModelRegistry::read_global(cx).inline_assistant_model()
         {
-            let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
+            let request =
+                self.build_request(&model, user_prompt, assistant_panel_context.clone(), cx);
             match request {
                 Ok(request) => {
                     let total_count = model.count_tokens(request.clone(), cx);
@@ -2915,7 +2916,8 @@ impl CodegenAlternative {
             if user_prompt.trim().to_lowercase() == "delete" {
                 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
             } else {
-                let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
+                let request =
+                    self.build_request(&model, user_prompt, assistant_panel_context, cx)?;
                 self.request = Some(request.clone());

                 cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
@@ -2927,6 +2929,7 @@ impl CodegenAlternative {

     fn build_request(
         &self,
+        model: &Arc<dyn LanguageModel>,
         user_prompt: String,
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &App,
@@ -2981,7 +2984,7 @@ impl CodegenAlternative {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: AssistantSettings::temperature_for_model(&model, cx),
         })
     }
diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs
index 19ace0877e..0844deb60e 100644
--- a/crates/assistant/src/terminal_inline_assistant.rs
+++ b/crates/assistant/src/terminal_inline_assistant.rs
@@ -261,7 +261,7 @@ impl TerminalInlineAssistant {
                     .read(cx)
                     .active_context(cx)?
                     .read(cx)
-                    .to_completion_request(RequestType::Chat, cx),
+                    .to_completion_request(None, RequestType::Chat, cx),
                 )
             })
         } else {
diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs
index d0b8c32cc2..31b37ae5fd 100644
--- a/crates/assistant_context_editor/src/context.rs
+++ b/crates/assistant_context_editor/src/context.rs
@@ -3,6 +3,7 @@ mod context_tests;

 use crate::patch::{AssistantEdit, AssistantPatch, AssistantPatchStatus};
 use anyhow::{Context as _, Result, anyhow};
+use assistant_settings::AssistantSettings;
 use assistant_slash_command::{
     SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
     SlashCommandResult, SlashCommandWorkingSet,
@@ -1273,10 +1274,10 @@ impl AssistantContext {
     pub(crate) fn count_remaining_tokens(&mut self, cx: &mut Context<Self>) {
         // Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
         // because otherwise you see in the UI that your empty message has a bunch of tokens already used.
-        let request = self.to_completion_request(RequestType::Chat, cx);
         let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
             return;
         };
+        let request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
         let debounce = self.token_count.is_some();
         self.pending_token_count = cx.spawn(async move |this, cx| {
             async move {
@@ -1422,7 +1423,7 @@ impl AssistantContext {
         }

         let request = {
-            let mut req = self.to_completion_request(RequestType::Chat, cx);
+            let mut req = self.to_completion_request(Some(&model), RequestType::Chat, cx);
             // Skip the last message because it's likely to change and
             // therefore would be a waste to cache.
             req.messages.pop();
@@ -2321,7 +2322,7 @@ impl AssistantContext {
         // Compute which messages to cache, including the last one.
         self.mark_cache_anchors(&model.cache_configuration(), false, cx);

-        let request = self.to_completion_request(request_type, cx);
+        let request = self.to_completion_request(Some(&model), request_type, cx);

         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
@@ -2561,6 +2562,7 @@ impl AssistantContext {

     pub fn to_completion_request(
         &self,
+        model: Option<&Arc<dyn LanguageModel>>,
         request_type: RequestType,
         cx: &App,
     ) -> LanguageModelRequest {
@@ -2584,7 +2586,8 @@ impl AssistantContext {
             messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: None,
+            temperature: model
+                .and_then(|model| AssistantSettings::temperature_for_model(model, cx)),
         };
         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {
@@ -2981,7 +2984,7 @@ impl AssistantContext {
             return;
         }

-        let mut request = self.to_completion_request(RequestType::Chat, cx);
+        let mut request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
         request.messages.push(LanguageModelRequestMessage {
             role: Role::User,
             content: vec![
diff --git a/crates/assistant_context_editor/src/context/context_tests.rs b/crates/assistant_context_editor/src/context/context_tests.rs
index 12c8003f6b..66da203886 100644
--- a/crates/assistant_context_editor/src/context/context_tests.rs
+++ b/crates/assistant_context_editor/src/context/context_tests.rs
@@ -43,9 +43,8 @@ use workspace::Workspace;

 #[gpui::test]
 fn test_inserting_and_removing_messages(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    LanguageModelRegistry::test(cx);
-    cx.set_global(settings_store);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -182,9 +181,8 @@ fn test_inserting_and_removing_messages(cx: &mut App) {

 #[gpui::test]
 fn test_message_splitting(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    cx.set_global(settings_store);
-    LanguageModelRegistry::test(cx);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());

@@ -285,9 +283,8 @@ fn test_message_splitting(cx: &mut App) {

 #[gpui::test]
 fn test_messages_for_offsets(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    LanguageModelRegistry::test(cx);
-    cx.set_global(settings_store);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -378,10 +375,8 @@

 #[gpui::test]
 async fn test_slash_commands(cx: &mut TestAppContext) {
-    let settings_store = cx.update(SettingsStore::test);
-    cx.set_global(settings_store);
-    cx.update(LanguageModelRegistry::test);
-    cx.update(Project::init_settings);
+    cx.update(init_test);
+
     let fs = FakeFs::new(cx.background_executor.clone());

     fs.insert_tree(
@@ -671,22 +666,19 @@ async fn test_slash_commands(cx: &mut TestAppContext) {

 #[gpui::test]
 async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
-    cx.update(prompt_store::init);
-    let mut settings_store = cx.update(SettingsStore::test);
     cx.update(|cx| {
-        settings_store
-            .set_user_settings(
-                r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
-                cx,
-            )
-            .unwrap()
+        init_test(cx);
+        cx.update_global(|settings_store: &mut SettingsStore, cx| {
+            settings_store
+                .set_user_settings(
+                    r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
+                    cx,
+                )
+                .unwrap()
+        })
     });
-    cx.set_global(settings_store);
-    cx.update(language::init);
-    cx.update(Project::init_settings);

     let fs = FakeFs::new(cx.executor());
     let project = Project::test(fs, [Path::new("/root")], cx).await;
-    cx.update(LanguageModelRegistry::test);

     let registry = Arc::new(LanguageRegistry::test(cx.executor()));
@@ -1069,9 +1061,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {

 #[gpui::test]
 async fn test_serialization(cx: &mut TestAppContext) {
-    let settings_store = cx.update(SettingsStore::test);
-    cx.set_global(settings_store);
-    cx.update(LanguageModelRegistry::test);
+    cx.update(init_test);
+
     let registry = Arc::new(LanguageRegistry::test(cx.executor()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -1147,6 +1138,8 @@ async fn test_serialization(cx: &mut TestAppContext) {

 #[gpui::test(iterations = 100)]
 async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
+    cx.update(init_test);
+
     let min_peers = env::var("MIN_PEERS")
         .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
         .unwrap_or(2);
@@ -1157,10 +1150,6 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
         .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
         .unwrap_or(50);

-    let settings_store = cx.update(SettingsStore::test);
-    cx.set_global(settings_store);
-    cx.update(LanguageModelRegistry::test);
-
     let slash_commands = cx.update(SlashCommandRegistry::default_global);
     slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
     slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
@@ -1429,9 +1418,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std

 #[gpui::test]
 fn test_mark_cache_anchors(cx: &mut App) {
-    let settings_store = SettingsStore::test(cx);
-    LanguageModelRegistry::test(cx);
-    cx.set_global(settings_store);
+    init_test(cx);
+
     let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
     let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
     let context = cx.new(|cx| {
@@ -1606,6 +1594,16 @@ fn messages_cache(
         .collect()
 }

+fn init_test(cx: &mut App) {
+    let settings_store = SettingsStore::test(cx);
+    prompt_store::init(cx);
+    LanguageModelRegistry::test(cx);
+    cx.set_global(settings_store);
+    language::init(cx);
+    assistant_settings::init(cx);
+    Project::init_settings(cx);
+}
+
 #[derive(Clone)]
 struct FakeSlashCommand(String);
diff --git a/crates/assistant_settings/Cargo.toml b/crates/assistant_settings/Cargo.toml
index d38845bc77..98eb8954df 100644
--- a/crates/assistant_settings/Cargo.toml
+++ b/crates/assistant_settings/Cargo.toml
@@ -14,6 +14,7 @@ path = "src/assistant_settings.rs"
 [dependencies]
 anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
+collections.workspace = true
 feature_flags.workspace = true
 gpui.workspace = true
 indexmap.workspace = true
diff --git a/crates/assistant_settings/src/agent_profile.rs b/crates/assistant_settings/src/agent_profile.rs
index 7c9b3bf2a7..df6b4b21c2 100644
--- a/crates/assistant_settings/src/agent_profile.rs
+++ b/crates/assistant_settings/src/agent_profile.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;

+use collections::IndexMap;
 use gpui::SharedString;
-use indexmap::IndexMap;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs
index a51040861f..00106ed490 100644
--- a/crates/assistant_settings/src/assistant_settings.rs
+++ b/crates/assistant_settings/src/assistant_settings.rs
@@ -5,10 +5,10 @@ use std::sync::Arc;
 use ::open_ai::Model as OpenAiModel;
 use anthropic::Model as AnthropicModel;
 use anyhow::{Result, bail};
+use collections::IndexMap;
 use deepseek::Model as DeepseekModel;
 use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
-use gpui::{App, Pixels};
-use indexmap::IndexMap;
+use gpui::{App, Pixels, SharedString};
 use language_model::{CloudModel, LanguageModel};
 use lmstudio::Model as LmStudioModel;
 use ollama::Model as OllamaModel;
@@ -18,6 +18,10 @@ use settings::{Settings, SettingsSources};

 pub use crate::agent_profile::*;

+pub fn init(cx: &mut App) {
+    AssistantSettings::register(cx);
+}
+
 #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum AssistantDockPosition {
@@ -89,10 +93,20 @@ pub struct AssistantSettings {
     pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
     pub stream_edits: bool,
     pub single_file_review: bool,
+    pub model_parameters: Vec<LanguageModelParameters>,
     pub preferred_completion_mode: CompletionMode,
 }

 impl AssistantSettings {
+    pub fn temperature_for_model(model: &Arc<dyn LanguageModel>, cx: &App) -> Option<f32> {
+        let settings = Self::get_global(cx);
+        settings
+            .model_parameters
+            .iter()
+            .rfind(|setting| setting.matches(model))
+            .and_then(|m| m.temperature)
+    }
+
     pub fn stream_edits(&self, cx: &App) -> bool {
         cx.has_flag::<AgentStreamEditsFeatureFlag>() || self.stream_edits
     }
@@ -106,15 +120,47 @@ impl AssistantSettings {
     }

     pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
-        self.inline_assistant_model = Some(LanguageModelSelection { provider, model });
+        self.inline_assistant_model = Some(LanguageModelSelection {
+            provider: provider.into(),
+            model,
+        });
     }

     pub fn set_commit_message_model(&mut self, provider: String, model: String) {
-        self.commit_message_model = Some(LanguageModelSelection { provider, model });
+        self.commit_message_model = Some(LanguageModelSelection {
+            provider: provider.into(),
+            model,
+        });
     }

     pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
-        self.thread_summary_model = Some(LanguageModelSelection { provider, model });
+        self.thread_summary_model = Some(LanguageModelSelection {
+            provider: provider.into(),
+            model,
+        });
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+pub struct LanguageModelParameters {
+    pub provider: Option<LanguageModelProviderSetting>,
+    pub model: Option<SharedString>,
+    pub temperature: Option<f32>,
+}
+
+impl LanguageModelParameters {
+    pub fn matches(&self, model: &Arc<dyn LanguageModel>) -> bool {
+        if let Some(provider) = &self.provider {
+            if provider.0 != model.provider_id().0 {
+                return false;
+            }
+        }
+        if let Some(setting_model) = &self.model {
+            if *setting_model != model.id().0 {
+                return false;
+            }
+        }
+        true
     }
 }
@@ -181,37 +227,37 @@
                 .and_then(|provider| match provider {
                     AssistantProviderContentV1::ZedDotDev { default_model } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "zed.dev".to_string(),
+                            provider: "zed.dev".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::OpenAi { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "openai".to_string(),
+                            provider: "openai".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::Anthropic { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "anthropic".to_string(),
+                            provider: "anthropic".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::Ollama { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "ollama".to_string(),
+                            provider: "ollama".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::LmStudio { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "lmstudio".to_string(),
+                            provider: "lmstudio".into(),
                             model: model.id().to_string(),
                         })
                     }
                     AssistantProviderContentV1::DeepSeek { default_model, .. } => {
                         default_model.map(|model| LanguageModelSelection {
-                            provider: "deepseek".to_string(),
+                            provider: "deepseek".into(),
                             model: model.id().to_string(),
                         })
                     }
@@ -227,6 +273,7 @@
                 notify_when_agent_waiting: None,
                 stream_edits: None,
                 single_file_review: None,
+                model_parameters: Vec::new(),
                 preferred_completion_mode: None,
             },
             VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
@@ -238,7 +285,7 @@
             default_width: settings.default_width,
             default_height: settings.default_height,
             default_model: Some(LanguageModelSelection {
-                provider: "openai".to_string(),
+                provider: "openai".into(),
                 model: settings
                     .default_open_ai_model
                     .clone()
@@ -257,6 +304,7 @@
                 notify_when_agent_waiting: None,
                 stream_edits: None,
                 single_file_review: None,
+                model_parameters: Vec::new(),
                 preferred_completion_mode: None,
             },
             None => AssistantSettingsContentV2::default(),
@@ -370,7 +418,10 @@
                     }
                 }
                 VersionedAssistantSettingsContent::V2(ref mut settings) => {
-                    settings.default_model = Some(LanguageModelSelection { provider, model });
+                    settings.default_model = Some(LanguageModelSelection {
+                        provider: provider.into(),
+                        model,
+                    });
                 }
             },
             Some(AssistantSettingsContentInner::Legacy(settings)) => {
@@ -381,7 +432,10 @@
             None => {
                 self.inner = Some(AssistantSettingsContentInner::for_v2(
                     AssistantSettingsContentV2 {
-                        default_model: Some(LanguageModelSelection { provider, model }),
+                        default_model: Some(LanguageModelSelection {
+                            provider: provider.into(),
+                            model,
+                        }),
                         ..Default::default()
                     },
                 ));
@@ -391,7 +445,10 @@
     pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
         self.v2_setting(|setting| {
-            setting.inline_assistant_model = Some(LanguageModelSelection { provider, model });
+            setting.inline_assistant_model = Some(LanguageModelSelection {
+                provider: provider.into(),
+                model,
+            });
             Ok(())
         })
         .ok();
@@ -399,7 +456,10 @@
     pub fn set_commit_message_model(&mut self, provider: String, model: String) {
         self.v2_setting(|setting| {
-            setting.commit_message_model = Some(LanguageModelSelection { provider, model });
+            setting.commit_message_model = Some(LanguageModelSelection {
+                provider: provider.into(),
+                model,
+            });
             Ok(())
         })
         .ok();
@@ -427,7 +487,10 @@
     pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
         self.v2_setting(|setting| {
-            setting.thread_summary_model = Some(LanguageModelSelection { provider, model });
+            setting.thread_summary_model = Some(LanguageModelSelection {
+                provider: provider.into(),
+                model,
+            });
             Ok(())
         })
         .ok();
@@ -523,6 +586,7 @@
             notify_when_agent_waiting: None,
             stream_edits: None,
             single_file_review: None,
+            model_parameters: Vec::new(),
             preferred_completion_mode: None,
         })
     }
@@ -587,6 +651,15 @@ pub struct AssistantSettingsContentV2 {
     ///
     /// Default: true
     single_file_review: Option<bool>,
+    /// Additional parameters for language model requests. When making a request
+    /// to a model, parameters will be taken from the last entry in this list
+    /// that matches the model's provider and name. In each entry, both provider
+    /// and model are optional, so that you can specify parameters for either
+    /// one.
+    ///
+    /// Default: []
+    #[serde(default)]
+    model_parameters: Vec<LanguageModelParameters>,
     /// What completion mode to enable for new threads
     ///
@@ -613,33 +686,53 @@ impl From<CompletionMode> for zed_llm_client::CompletionMode {
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
 pub struct LanguageModelSelection {
-    #[schemars(schema_with = "providers_schema")]
-    pub provider: String,
+    pub provider: LanguageModelProviderSetting,
     pub model: String,
 }

-fn providers_schema(_: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
-    schemars::schema::SchemaObject {
-        enum_values: Some(vec![
-            "anthropic".into(),
-            "bedrock".into(),
-            "google".into(),
-            "lmstudio".into(),
-            "ollama".into(),
-            "openai".into(),
-            "zed.dev".into(),
-            "copilot_chat".into(),
-            "deepseek".into(),
-        ]),
-        ..Default::default()
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub struct LanguageModelProviderSetting(pub String);
+
+impl JsonSchema for LanguageModelProviderSetting {
+    fn schema_name() -> String {
+        "LanguageModelProviderSetting".into()
+    }
+
+    fn json_schema(_: &mut schemars::r#gen::SchemaGenerator) -> Schema {
+        schemars::schema::SchemaObject {
+            enum_values: Some(vec![
+                "anthropic".into(),
+                "bedrock".into(),
+                "google".into(),
+                "lmstudio".into(),
+                "ollama".into(),
+                "openai".into(),
+                "zed.dev".into(),
+                "copilot_chat".into(),
+                "deepseek".into(),
+            ]),
+            ..Default::default()
+        }
+        .into()
+    }
+}
+
+impl From<String> for LanguageModelProviderSetting {
+    fn from(provider: String) -> Self {
+        Self(provider)
+    }
+}
+
+impl From<&str> for LanguageModelProviderSetting {
+    fn from(provider: &str) -> Self {
+        Self(provider.to_string())
     }
-    .into()
 }

 impl Default for LanguageModelSelection {
     fn default() -> Self {
         Self {
-            provider: "openai".to_string(),
+            provider: LanguageModelProviderSetting("openai".to_string()),
             model: "gpt-4".to_string(),
         }
     }
@@ -781,6 +874,10 @@ impl Settings for AssistantSettings {
             value.preferred_completion_mode,
         );

+        settings
+            .model_parameters
+            .extend_from_slice(&value.model_parameters);
+
         if let Some(profiles) = value.profiles {
             settings
                 .profiles
@@ -913,6 +1010,7 @@ mod tests {
                         notify_when_agent_waiting: None,
                         stream_edits: None,
                        single_file_review: None,
+                        model_parameters: Vec::new(),
                         preferred_completion_mode: None,
                     },
                 )),
@@ -976,7 +1074,7 @@
                     AssistantSettingsContentV2 {
                         enabled: Some(false),
                         default_model: Some(LanguageModelSelection {
-                            provider: "xai".to_owned(),
+                            provider: "xai".to_owned().into(),
                             model: "grok".to_owned(),
                         }),
                         ..Default::default()
diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml
index 0fc1422129..1a78394f5f 100644
--- a/crates/collab/Cargo.toml
+++ b/crates/collab/Cargo.toml
@@ -78,6 +78,7 @@ zed_llm_client.workspace = true
 [dev-dependencies]
 assistant = { workspace = true, features = ["test-support"] }
 assistant_context_editor.workspace = true
+assistant_settings.workspace = true
 assistant_slash_command.workspace = true
 assistant_tool.workspace = true
 async-trait.workspace = true
diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs
index 38134dd863..da02e6c32a 100644
--- a/crates/collab/src/tests/test_server.rs
+++ b/crates/collab/src/tests/test_server.rs
@@ -307,6 +307,7 @@ impl TestServer {
             );
             language_model::LanguageModelRegistry::test(cx);
             assistant_context_editor::init(client.clone(), cx);
+            assistant_settings::init(cx);
         });

         client
diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs
index e35a25e025..6ef0fbea5b 100644
--- a/crates/git_ui/src/git_panel.rs
+++ b/crates/git_ui/src/git_panel.rs
@@ -1735,6 +1735,8 @@ impl GitPanel {
             }
         });

+        let temperature = AssistantSettings::temperature_for_model(&model, cx);
+
         self.generate_commit_message_task = Some(cx.spawn(async move |this, cx| {
             async move {
                 let _defer = cx.on_drop(&this, |this, _cx| {
@@ -1773,7 +1775,7 @@ impl GitPanel {
                 }],
                 tools: Vec::new(),
                 stop: Vec::new(),
-                temperature: None,
+                temperature,
             };

             let stream = model.stream_completion_text(request, &cx);
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index 899d325689..abbb237b4f 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -87,8 +87,8 @@ pub struct AllLanguageModelSettingsContent {
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 #[serde(untagged)]
 pub enum AnthropicSettingsContent {
-    Legacy(LegacyAnthropicSettingsContent),
     Versioned(VersionedAnthropicSettingsContent),
+    Legacy(LegacyAnthropicSettingsContent),
 }

 impl AnthropicSettingsContent {
@@ -197,8 +197,8 @@ pub struct MistralSettingsContent {
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 #[serde(untagged)]
 pub enum OpenAiSettingsContent {
-    Legacy(LegacyOpenAiSettingsContent),
     Versioned(VersionedOpenAiSettingsContent),
+    Legacy(LegacyOpenAiSettingsContent),
 }

 impl OpenAiSettingsContent {
diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs
index 464cc0a42d..be76898601 100644
--- a/crates/project/src/lsp_store.rs
+++ b/crates/project/src/lsp_store.rs
@@ -3530,7 +3530,7 @@ impl LspStore {
             )
             .detach();
         } else {
-            log::info!("No extension events global found. Skipping JSON schema auto-reload setup");
+            log::debug!("No extension events global found. Skipping JSON schema auto-reload setup");
         }
         cx.observe_global::<SettingsStore>(Self::on_settings_changed)
             .detach();
diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs
index a114f384ba..bda946498d 100644
--- a/crates/worktree/src/worktree.rs
+++ b/crates/worktree/src/worktree.rs
@@ -3871,7 +3871,7 @@ impl BackgroundScanner {
                 Some(ancestor_dot_git)
             });

-        log::info!("containing git repository: {containing_git_repository:?}");
+        log::trace!("containing git repository: {containing_git_repository:?}");

         let (scan_job_tx, scan_job_rx) = channel::unbounded();
         {
diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md
index b5fd910c43..92ad474dfd 100644
--- a/docs/src/configuring-zed.md
+++ b/docs/src/configuring-zed.md
@@ -3070,14 +3070,14 @@ Run the `theme selector: toggle` action in the command palette to see a current
 }
 ```

-## Assistant Panel
+## Agent

-- Description: Customize assistant panel
-- Setting: `assistant`
+- Description: Customize agent behavior
+- Setting: `agent`
 - Default:

 ```json
-"assistant": {
+"agent": {
   "version": "2",
   "enabled": true,
   "button": true,