agent: Allow customizing temperature by provider/model (#30033)
Adds a new `agent.model_parameters` setting that allows the user to specify a custom temperature for a provider AND/OR model: ```json5 "model_parameters": [ // To set parameters for all requests to OpenAI models: { "provider": "openai", "temperature": 0.5 }, // To set parameters for all requests in general: { "temperature": 0 }, // To set parameters for a specific provider and model: { "provider": "zed.dev", "model": "claude-3-7-sonnet-latest", "temperature": 1.0 } ], ``` Release Notes: - agent: Allow customizing temperature by provider/model --------- Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com> Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
0055a20512
commit
3cdf5ce947
22 changed files with 348 additions and 106 deletions
|
@ -3,6 +3,7 @@ mod context_tests;
|
|||
|
||||
use crate::patch::{AssistantEdit, AssistantPatch, AssistantPatchStatus};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_slash_command::{
|
||||
SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
|
||||
SlashCommandResult, SlashCommandWorkingSet,
|
||||
|
@ -1273,10 +1274,10 @@ impl AssistantContext {
|
|||
pub(crate) fn count_remaining_tokens(&mut self, cx: &mut Context<Self>) {
|
||||
// Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
|
||||
// because otherwise you see in the UI that your empty message has a bunch of tokens already used.
|
||||
let request = self.to_completion_request(RequestType::Chat, cx);
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
let request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
|
||||
let debounce = self.token_count.is_some();
|
||||
self.pending_token_count = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
|
@ -1422,7 +1423,7 @@ impl AssistantContext {
|
|||
}
|
||||
|
||||
let request = {
|
||||
let mut req = self.to_completion_request(RequestType::Chat, cx);
|
||||
let mut req = self.to_completion_request(Some(&model), RequestType::Chat, cx);
|
||||
// Skip the last message because it's likely to change and
|
||||
// therefore would be a waste to cache.
|
||||
req.messages.pop();
|
||||
|
@ -2321,7 +2322,7 @@ impl AssistantContext {
|
|||
// Compute which messages to cache, including the last one.
|
||||
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
||||
|
||||
let request = self.to_completion_request(request_type, cx);
|
||||
let request = self.to_completion_request(Some(&model), request_type, cx);
|
||||
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
|
@ -2561,6 +2562,7 @@ impl AssistantContext {
|
|||
|
||||
pub fn to_completion_request(
|
||||
&self,
|
||||
model: Option<&Arc<dyn LanguageModel>>,
|
||||
request_type: RequestType,
|
||||
cx: &App,
|
||||
) -> LanguageModelRequest {
|
||||
|
@ -2584,7 +2586,8 @@ impl AssistantContext {
|
|||
messages: Vec::new(),
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
temperature: model
|
||||
.and_then(|model| AssistantSettings::temperature_for_model(model, cx)),
|
||||
};
|
||||
for message in self.messages(cx) {
|
||||
if message.status != MessageStatus::Done {
|
||||
|
@ -2981,7 +2984,7 @@ impl AssistantContext {
|
|||
return;
|
||||
}
|
||||
|
||||
let mut request = self.to_completion_request(RequestType::Chat, cx);
|
||||
let mut request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![
|
||||
|
|
|
@ -43,9 +43,8 @@ use workspace::Workspace;
|
|||
|
||||
#[gpui::test]
|
||||
fn test_inserting_and_removing_messages(cx: &mut App) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
init_test(cx);
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context = cx.new(|cx| {
|
||||
|
@ -182,9 +181,8 @@ fn test_inserting_and_removing_messages(cx: &mut App) {
|
|||
|
||||
#[gpui::test]
|
||||
fn test_message_splitting(cx: &mut App) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
LanguageModelRegistry::test(cx);
|
||||
init_test(cx);
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
|
@ -285,9 +283,8 @@ fn test_message_splitting(cx: &mut App) {
|
|||
|
||||
#[gpui::test]
|
||||
fn test_messages_for_offsets(cx: &mut App) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
init_test(cx);
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context = cx.new(|cx| {
|
||||
|
@ -378,10 +375,8 @@ fn test_messages_for_offsets(cx: &mut App) {
|
|||
|
||||
#[gpui::test]
|
||||
async fn test_slash_commands(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(init_test);
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
|
||||
fs.insert_tree(
|
||||
|
@ -671,22 +666,19 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
|
|||
|
||||
#[gpui::test]
|
||||
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
cx.update(prompt_store::init);
|
||||
let mut settings_store = cx.update(SettingsStore::test);
|
||||
cx.update(|cx| {
|
||||
settings_store
|
||||
.set_user_settings(
|
||||
r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
init_test(cx);
|
||||
cx.update_global(|settings_store: &mut SettingsStore, cx| {
|
||||
settings_store
|
||||
.set_user_settings(
|
||||
r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
});
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language::init);
|
||||
cx.update(Project::init_settings);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [Path::new("/root")], cx).await;
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
|
||||
|
@ -1069,9 +1061,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
|||
|
||||
#[gpui::test]
|
||||
async fn test_serialization(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
cx.update(init_test);
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context = cx.new(|cx| {
|
||||
|
@ -1147,6 +1138,8 @@ async fn test_serialization(cx: &mut TestAppContext) {
|
|||
|
||||
#[gpui::test(iterations = 100)]
|
||||
async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
cx.update(init_test);
|
||||
|
||||
let min_peers = env::var("MIN_PEERS")
|
||||
.map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
|
||||
.unwrap_or(2);
|
||||
|
@ -1157,10 +1150,6 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
|
|||
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
|
||||
.unwrap_or(50);
|
||||
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
|
||||
let slash_commands = cx.update(SlashCommandRegistry::default_global);
|
||||
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
|
||||
slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
|
||||
|
@ -1429,9 +1418,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
|
|||
|
||||
#[gpui::test]
|
||||
fn test_mark_cache_anchors(cx: &mut App) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
init_test(cx);
|
||||
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context = cx.new(|cx| {
|
||||
|
@ -1606,6 +1594,16 @@ fn messages_cache(
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut App) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
prompt_store::init(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
assistant_settings::init(cx);
|
||||
Project::init_settings(cx);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FakeSlashCommand(String);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue