diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
index 170b2268f9..de6ce9da71 100644
--- a/crates/ai/src/completion.rs
+++ b/crates/ai/src/completion.rs
@@ -53,6 +53,8 @@ pub struct OpenAIRequest {
     pub model: String,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
+    pub stop: Vec<String>,
+    pub temperature: f32,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs
index 19334340c8..1eeb197f93 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/templates/generate.rs
@@ -78,7 +78,7 @@ impl PromptTemplate for GenerateInlineContent {
 
         match file_type {
             PromptFileType::Code => {
-                writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
             }
             _ => {}
         }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 4dd4e2a983..ca8c54a285 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -661,6 +661,19 @@ impl AssistantPanel {
             None
         };
 
+        // A higher temperature increases the randomness of model outputs.
+        // If the language is Markdown or unknown, raise the temperature for more creative output;
+        // if it is code, lower the temperature for more deterministic output.
+        let temperature = if let Some(language) = language_name.clone() {
+            if language.to_string() != "Markdown".to_string() {
+                0.5
+            } else {
+                1.0
+            }
+        } else {
+            1.0
+        };
+
         let user_prompt = user_prompt.to_string();
 
         let snippets = if retrieve_context {
@@ -731,10 +744,13 @@ impl AssistantPanel {
                     role: Role::User,
                     content: prompt,
                 });
+
                 let request = OpenAIRequest {
                     model: model.full_name().into(),
                     messages,
                     stream: true,
+                    stop: vec!["|END|>".to_string()],
+                    temperature,
                 };
                 codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
                 anyhow::Ok(())
@@ -1727,6 +1743,8 @@ impl Conversation {
                 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
                 .collect(),
             stream: true,
+            stop: vec![],
+            temperature: 1.0,
         };
 
         let stream = stream_completion(api_key, cx.background().clone(), request);
@@ -2011,6 +2029,8 @@ impl Conversation {
             model: self.model.full_name().to_string(),
             messages: messages.collect(),
             stream: true,
+            stop: vec![],
+            temperature: 1.0,
         };
 
         let stream = stream_completion(api_key, cx.background().clone(), request);
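For reference, here is a minimal sketch (not part of the diff) of how the extended request could serialize over the wire, assuming OpenAIRequest derives serde's Serialize as the neighboring derive attribute in completion.rs suggests. The messages field is omitted so the example compiles without the RequestMessage type, and the "gpt-4" model name is a placeholder, not a value taken from this change:

    use serde::Serialize;

    // Minimal stand-in for the struct patched above; `messages` is left out
    // so the example is self-contained.
    #[derive(Serialize)]
    struct OpenAIRequest {
        model: String,
        stream: bool,
        stop: Vec<String>,
        temperature: f32,
    }

    fn main() {
        let request = OpenAIRequest {
            model: "gpt-4".to_string(), // hypothetical model name
            stream: true,
            stop: vec!["|END|>".to_string()],
            temperature: 0.5, // the value this diff picks for code buffers
        };
        // Prints: {"model":"gpt-4","stream":true,"stop":["|END|>"],"temperature":0.5}
        println!("{}", serde_json::to_string(&request).unwrap());
    }

Note how the two call sites in Conversation pass stop: vec![] and temperature: 1.0, preserving the API's default sampling behavior, while the inline-assist path in AssistantPanel opts into the "|END|>" stop sequence and the language-dependent temperature.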