Systematically optimize agentic editing performance (#28961)
Now that we've established a proper eval in tree, this PR is reboots of our agent loop back to a set of minimal tools and simpler prompts. We should aim to get this branch feeling subjectively competitive with what's on main and then merge it, and build from there. Let's invest in our eval and use it to drive better performance of the agent loop. How you can help: Pick an example, and then make the outcome faster or better. It's fine to even use your own subjective judgment, as our evaluation criteria likely need tuning as well at this point. Focus on making the agent work better in your own subjective experience first. Let's focus on simple/practical improvements to make this thing work better, then determine how we can craft our judgment criteria to lock those improvements in. Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev> Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Agus <agus@zed.dev> Co-authored-by: Richard <richard@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Michael Sloan <mgsloan@gmail.com>
This commit is contained in:
parent
8102a16747
commit
bab28560ef
68 changed files with 1575 additions and 478 deletions
|
@ -10,14 +10,16 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
|
|||
use handlebars::Handlebars;
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||
StopReason, TokenUsage,
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::{LspStore, Project, ProjectPath};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Write as _;
|
||||
use std::fs::File;
|
||||
use std::io::Write as _;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
|
@ -45,6 +47,19 @@ pub struct ExampleBase {
|
|||
pub insert_id: Option<String>,
|
||||
#[serde(default = "default_true")]
|
||||
pub require_lsp: bool,
|
||||
#[serde(default)]
|
||||
pub allow_preexisting_diagnostics: bool,
|
||||
}
|
||||
|
||||
impl ExampleBase {
|
||||
pub fn repo_name(&self) -> String {
|
||||
self.url
|
||||
.split('/')
|
||||
.next_back()
|
||||
.unwrap_or(&"")
|
||||
.trim_end_matches(".git")
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -54,14 +69,12 @@ pub struct Example {
|
|||
pub base: ExampleBase,
|
||||
/// Content of `prompt.md`
|
||||
pub prompt: String,
|
||||
/// Content of `criteria.md`
|
||||
pub criteria: String,
|
||||
/// Markdown output file to append to
|
||||
pub output_file: Option<Arc<Mutex<File>>>,
|
||||
/// Path to the output run directory.
|
||||
pub run_dir: PathBuf,
|
||||
/// Path to markdown output file
|
||||
pub output_file_path: PathBuf,
|
||||
/// Content of `diff_criteria.md`
|
||||
pub diff_criteria: String,
|
||||
/// Content of `thread_criteria.md`, if that file exists (it's optional)
|
||||
pub thread_criteria: Option<String>,
|
||||
/// Path to the directory containing the requests and responses for the agentic loop
|
||||
pub run_directory_path: PathBuf,
|
||||
/// Prefix used for logging that identifies this example
|
||||
pub log_prefix: String,
|
||||
}
|
||||
|
@ -69,41 +82,65 @@ pub struct Example {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RunOutput {
|
||||
pub repository_diff: String,
|
||||
pub diagnostics: String,
|
||||
pub ran_diagnostics_check: bool,
|
||||
pub diagnostics_before: Option<String>,
|
||||
pub diagnostics_after: Option<String>,
|
||||
pub response_count: usize,
|
||||
pub token_usage: TokenUsage,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub last_request: LanguageModelRequest,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeInput {
|
||||
pub struct JudgeDiffInput {
|
||||
pub repository_diff: String,
|
||||
pub ran_diagnostics_check: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub diagnostics_before: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub diagnostics_after: Option<String>,
|
||||
pub criteria: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeOutput {
|
||||
pub struct JudgeThreadInput {
|
||||
pub messages: String,
|
||||
pub criteria: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeResponse {
|
||||
pub analysis: String,
|
||||
pub score: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeOutput {
|
||||
pub thread: Option<JudgeResponse>,
|
||||
pub diff: JudgeResponse,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
/// Load an example from a directory containing base.toml, prompt.md, and criteria.md
|
||||
pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
|
||||
let name = Self::name_from_path(dir_path);
|
||||
let base_path = dir_path.join("base.toml");
|
||||
let prompt_path = dir_path.join("prompt.md");
|
||||
let criteria_path = dir_path.join("criteria.md");
|
||||
let output_file_path = run_dir.join(format!("{}.md", name));
|
||||
let diff_criteria_path = dir_path.join("diff_criteria.md");
|
||||
let thread_criteria_path = dir_path.join("thread_criteria.md");
|
||||
let thread_criteria = if thread_criteria_path.exists() {
|
||||
Some(fs::read_to_string(thread_criteria_path.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Example {
|
||||
name: name.clone(),
|
||||
base: toml::from_str(&fs::read_to_string(&base_path)?)?,
|
||||
prompt: fs::read_to_string(prompt_path.clone())?,
|
||||
criteria: fs::read_to_string(criteria_path.clone())?,
|
||||
run_dir: run_dir.to_path_buf(),
|
||||
output_file: None,
|
||||
output_file_path,
|
||||
thread_criteria,
|
||||
diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
|
||||
run_directory_path: run_dir.to_path_buf(),
|
||||
log_prefix: name,
|
||||
})
|
||||
}
|
||||
|
@ -111,10 +148,13 @@ impl Example {
|
|||
pub fn set_repetition_number(&mut self, repetition_number: u32) {
|
||||
if repetition_number > 0 {
|
||||
self.name = format!("{}-{}", self.name, repetition_number);
|
||||
self.output_file_path = self.run_dir.join(format!("{}.md", self.name));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn example_output_directory(&self) -> PathBuf {
|
||||
self.run_directory_path.join(&self.name)
|
||||
}
|
||||
|
||||
pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
|
||||
self.log_prefix = format!(
|
||||
"{}{:<width$}\x1b[0m | ",
|
||||
|
@ -134,6 +174,7 @@ impl Example {
|
|||
.context(format!("No such directory {WORKTREES_DIR}"))
|
||||
.unwrap()
|
||||
.join(&self.name)
|
||||
.join(self.base.repo_name())
|
||||
}
|
||||
|
||||
/// Set up the example by checking out the specified Git revision
|
||||
|
@ -187,20 +228,11 @@ impl Example {
|
|||
.await?;
|
||||
}
|
||||
|
||||
// Create the output file
|
||||
let output_file = Arc::new(Mutex::new(File::create(&self.output_file_path)?));
|
||||
self.output_file = Some(output_file);
|
||||
std::fs::create_dir_all(self.example_output_directory())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the output file, panicking if it's not set
|
||||
fn output_file(&self) -> Arc<Mutex<File>> {
|
||||
self.output_file
|
||||
.clone()
|
||||
.expect("Output file not created. Call setup() first.")
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
@ -305,6 +337,11 @@ impl Example {
|
|||
None
|
||||
};
|
||||
|
||||
let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
|
||||
if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
|
||||
return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
|
||||
}
|
||||
|
||||
if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
|
||||
return Err(anyhow!("Setup only mode"));
|
||||
}
|
||||
|
@ -312,15 +349,32 @@ impl Example {
|
|||
let thread_store = thread_store.await?;
|
||||
let thread =
|
||||
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
|
||||
let last_request = Rc::new(RefCell::new(None));
|
||||
|
||||
{
|
||||
let output_file_ref = this.output_file();
|
||||
let mut output_file = output_file_ref.lock().unwrap();
|
||||
writeln!(&mut output_file, "👤 USER:").log_err();
|
||||
writeln!(&mut output_file, "{}", this.prompt).log_err();
|
||||
writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
|
||||
output_file.flush().log_err();
|
||||
}
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let mut request_count = 0;
|
||||
let example_dir_path = this.example_output_directory();
|
||||
|
||||
let last_request = Rc::clone(&last_request);
|
||||
thread.set_request_callback(move |request, response_events| {
|
||||
*last_request.borrow_mut() = Some(request.clone());
|
||||
|
||||
request_count += 1;
|
||||
let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md"));
|
||||
let last_messages_file_path = example_dir_path.join("last.messages.md");
|
||||
let request_markdown = RequestMarkdown::new(request);
|
||||
let response_events_markdown = response_events_to_markdown(response_events);
|
||||
|
||||
let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
|
||||
fs::write(messages_file_path, messages.clone()).expect("failed to write messages file");
|
||||
fs::write(last_messages_file_path, messages).expect("failed to write last messages file");
|
||||
|
||||
if request_count == 1 {
|
||||
let tools_file_path = example_dir_path.join("tools.md");
|
||||
fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
|
||||
}
|
||||
});
|
||||
})?;
|
||||
|
||||
let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
|
||||
Mutex::new(HashMap::default()).into();
|
||||
|
@ -332,8 +386,6 @@ impl Example {
|
|||
});
|
||||
|
||||
let event_handler_task = cx.spawn({
|
||||
// Need to clone the Arc here because the reference from output_file() won't live long enough
|
||||
let output_file = this.output_file.clone().unwrap();
|
||||
let log_prefix = this.log_prefix.clone();
|
||||
let tool_use_counts = tool_use_counts.clone();
|
||||
let thread = thread.downgrade();
|
||||
|
@ -349,8 +401,6 @@ impl Example {
|
|||
return Err(anyhow!("ThreadEvent channel ended early"));
|
||||
};
|
||||
|
||||
let mut output_file = output_file.lock().unwrap();
|
||||
|
||||
match event {
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn) => {
|
||||
|
@ -371,18 +421,7 @@ impl Example {
|
|||
ThreadEvent::ShowError(thread_error) => {
|
||||
break Err(anyhow!(thread_error.clone()));
|
||||
}
|
||||
ThreadEvent::StreamedAssistantText(_, chunk) => {
|
||||
write!(&mut output_file, "{}", chunk).log_err();
|
||||
}
|
||||
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
|
||||
write!(&mut output_file, "{}", chunk).log_err();
|
||||
}
|
||||
ThreadEvent::UsePendingTools { tool_uses } => {
|
||||
writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
|
||||
for tool_use in tool_uses {
|
||||
writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
|
||||
}
|
||||
ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
|
@ -398,8 +437,6 @@ impl Example {
|
|||
format!("TOOL FINISHED: {}", tool_use.name)
|
||||
};
|
||||
println!("{log_prefix}{message}");
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
|
||||
let mut tool_use_counts = tool_use_counts.lock().unwrap();
|
||||
*tool_use_counts
|
||||
.entry(tool_result.tool_name.clone())
|
||||
|
@ -407,7 +444,6 @@ impl Example {
|
|||
} else {
|
||||
let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
|
||||
println!("{log_prefix}{message}");
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
@ -428,8 +464,6 @@ impl Example {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
output_file.flush().log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -451,21 +485,35 @@ impl Example {
|
|||
println!("{}Getting repository diff", this.log_prefix);
|
||||
let repository_diff = this.repository_diff().await?;
|
||||
|
||||
let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name));
|
||||
let example_output_dir = this.example_output_directory();
|
||||
let repository_diff_path = example_output_dir.join("patch.diff");
|
||||
let mut repository_diff_output_file = File::create(&repository_diff_path)?;
|
||||
writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();
|
||||
|
||||
println!("{}Getting diagnostics", this.log_prefix);
|
||||
let diagnostics = cx
|
||||
let diagnostics_after = cx
|
||||
.update(move |cx| {
|
||||
cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
|
||||
})?
|
||||
.await?;
|
||||
println!("{}Got diagnostics", this.log_prefix);
|
||||
|
||||
let Some(last_request) = last_request.borrow_mut().take() else {
|
||||
return Err(anyhow!("No requests ran."));
|
||||
};
|
||||
|
||||
drop(subscription);
|
||||
drop(lsp_open_handle_and_store);
|
||||
|
||||
if let Some(diagnostics_before) = &diagnostics_before {
|
||||
fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
|
||||
}
|
||||
|
||||
if let Some(diagnostics_after) = &diagnostics_after {
|
||||
fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
|
||||
}
|
||||
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let response_count = thread
|
||||
.messages()
|
||||
|
@ -473,31 +521,38 @@ impl Example {
|
|||
.count();
|
||||
RunOutput {
|
||||
repository_diff,
|
||||
diagnostics,
|
||||
ran_diagnostics_check: this.base.require_lsp,
|
||||
diagnostics_before,
|
||||
diagnostics_after,
|
||||
response_count,
|
||||
token_usage: thread.cumulative_token_usage(),
|
||||
tool_use_counts: tool_use_counts.lock().unwrap().clone(),
|
||||
last_request,
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn judge(
|
||||
async fn judge_diff(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
repository_diff: String,
|
||||
judge_repetitions: u32,
|
||||
run_output: &RunOutput,
|
||||
judge_number: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let judge_prompt = include_str!("judge_prompt.hbs");
|
||||
let judge_prompt_name = "judge_prompt";
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
|
||||
let prompt = handlebars.render(
|
||||
judge_prompt_name,
|
||||
&JudgeInput {
|
||||
repository_diff,
|
||||
criteria: self.criteria.clone(),
|
||||
) -> Result<(String, JudgeResponse)> {
|
||||
let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
|
||||
let judge_diff_prompt_name = "judge_diff_prompt";
|
||||
let mut hbs = Handlebars::new();
|
||||
hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
|
||||
|
||||
let diff_prompt = hbs.render(
|
||||
judge_diff_prompt_name,
|
||||
&JudgeDiffInput {
|
||||
repository_diff: run_output.repository_diff.clone(),
|
||||
ran_diagnostics_check: run_output.ran_diagnostics_check,
|
||||
diagnostics_before: run_output.diagnostics_before.clone(),
|
||||
diagnostics_after: run_output.diagnostics_after.clone(),
|
||||
criteria: self.diff_criteria.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
|
@ -506,7 +561,7 @@ impl Example {
|
|||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(prompt)],
|
||||
content: vec![MessageContent::Text(diff_prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
temperature: None,
|
||||
|
@ -514,24 +569,106 @@ impl Example {
|
|||
stop: Vec::new(),
|
||||
};
|
||||
|
||||
let response = send_language_model_request(model, request, cx).await?;
|
||||
let diff_response = send_language_model_request(model, request, cx).await?;
|
||||
let diff_output = JudgeResponse::parse(&diff_response)?;
|
||||
|
||||
let judge_file_path = self.run_dir.join(format!(
|
||||
"{}_judge_{}.md",
|
||||
self.name, // This is the eval_name
|
||||
judge_repetitions
|
||||
));
|
||||
println!(
|
||||
"{}Judge #{judge_number} - Diff score: {}",
|
||||
self.log_prefix, diff_output.score
|
||||
);
|
||||
|
||||
let mut judge_output_file = File::create(&judge_file_path)?;
|
||||
writeln!(&mut judge_output_file, "{}", &response).log_err();
|
||||
|
||||
parse_judge_output(&response)
|
||||
Ok((diff_response, diff_output))
|
||||
}
|
||||
|
||||
pub async fn repository_diff(&self) -> Result<String> {
|
||||
async fn judge_thread(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
run_output: &RunOutput,
|
||||
judge_number: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<(String, Option<JudgeResponse>)> {
|
||||
if let Some(criteria) = self.thread_criteria.clone() {
|
||||
let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
|
||||
let judge_thread_prompt_name = "judge_thread_prompt";
|
||||
let mut hbs = Handlebars::new();
|
||||
hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
|
||||
|
||||
let request_markdown = RequestMarkdown::new(&run_output.last_request);
|
||||
let thread_prompt = hbs.render(
|
||||
judge_thread_prompt_name,
|
||||
&JudgeThreadInput {
|
||||
messages: request_markdown.messages,
|
||||
criteria,
|
||||
},
|
||||
)?;
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(thread_prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
temperature: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
};
|
||||
|
||||
let thread_response = send_language_model_request(model, request, cx).await?;
|
||||
let thread_output = JudgeResponse::parse(&thread_response)?;
|
||||
|
||||
println!(
|
||||
"{}Judge #{judge_number} - Thread score: {}",
|
||||
self.log_prefix, thread_output.score
|
||||
);
|
||||
|
||||
Ok((thread_response, Some(thread_output)))
|
||||
} else {
|
||||
let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
|
||||
Ok((msg, None))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn judge(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
run_output: &RunOutput,
|
||||
judge_number: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let mut output_file = File::create(
|
||||
self.example_output_directory()
|
||||
.join(format!("judge_{}.md", judge_number)),
|
||||
)
|
||||
.expect("failed to create judge.md");
|
||||
|
||||
println!("{}Running judge #{judge_number}", self.log_prefix);
|
||||
|
||||
let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
|
||||
let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
|
||||
|
||||
let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
|
||||
|
||||
let (diff_response, diff_output) = diff_result?;
|
||||
let (thread_response, thread_output) = thread_result?;
|
||||
|
||||
writeln!(
|
||||
&mut output_file,
|
||||
"# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
|
||||
)
|
||||
.log_err();
|
||||
|
||||
Ok(JudgeOutput {
|
||||
thread: thread_output,
|
||||
diff: diff_output,
|
||||
})
|
||||
}
|
||||
|
||||
async fn repository_diff(&self) -> Result<String> {
|
||||
let worktree_path = self.worktree_path();
|
||||
run_git(&worktree_path, &["add", "-N"]).await?;
|
||||
run_git(&worktree_path, &["diff"]).await
|
||||
run_git(&worktree_path, &["add", "."]).await?;
|
||||
run_git(&worktree_path, &["diff", "--staged"]).await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -599,7 +736,10 @@ fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool
|
|||
.any(|(_, status)| !status.pending_work.is_empty())
|
||||
}
|
||||
|
||||
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
|
||||
async fn query_lsp_diagnostics(
|
||||
project: Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Option<String>> {
|
||||
let paths_with_diagnostics = project.update(cx, |project, cx| {
|
||||
project
|
||||
.diagnostic_summaries(true, cx)
|
||||
|
@ -608,6 +748,10 @@ async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> R
|
|||
.collect::<Vec<_>>()
|
||||
})?;
|
||||
|
||||
if paths_with_diagnostics.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for project_path in paths_with_diagnostics {
|
||||
let buffer = project
|
||||
|
@ -633,16 +777,18 @@ async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> R
|
|||
)?;
|
||||
}
|
||||
}
|
||||
anyhow::Ok(output)
|
||||
anyhow::Ok(Some(output))
|
||||
}
|
||||
|
||||
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
|
||||
let analysis = get_tag("analysis", response)?.to_string();
|
||||
let score = get_tag("score", response)?
|
||||
.parse()
|
||||
.context("error parsing score")?;
|
||||
impl JudgeResponse {
|
||||
fn parse(response: &str) -> Result<Self> {
|
||||
let analysis = get_tag("analysis", response)?.to_string();
|
||||
let score = get_tag("score", response)?
|
||||
.parse()
|
||||
.context("error parsing score")?;
|
||||
|
||||
Ok(JudgeOutput { analysis, score })
|
||||
Ok(Self { analysis, score })
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tag(name: &'static str, response: &str) -> Result<String> {
|
||||
|
@ -724,9 +870,135 @@ pub async fn send_language_model_request(
|
|||
}
|
||||
}
|
||||
|
||||
struct RequestMarkdown {
|
||||
tools: String,
|
||||
messages: String,
|
||||
}
|
||||
|
||||
impl RequestMarkdown {
|
||||
fn new(request: &LanguageModelRequest) -> Self {
|
||||
let mut tools = String::new();
|
||||
let mut messages = String::new();
|
||||
|
||||
// Print the tools
|
||||
if !request.tools.is_empty() {
|
||||
for tool in &request.tools {
|
||||
write!(&mut tools, "# {}\n\n", tool.name).unwrap();
|
||||
write!(&mut tools, "{}\n\n", tool.description).unwrap();
|
||||
write!(
|
||||
&mut tools,
|
||||
"```json\n{}\n```\n\n",
|
||||
serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Print the messages
|
||||
for message in &request.messages {
|
||||
let role_str = match message.role {
|
||||
Role::User => "👤 USER",
|
||||
Role::Assistant => "🤖 ASSISTANT",
|
||||
Role::System => "⚙️ SYSTEM",
|
||||
};
|
||||
|
||||
messages.push_str(&format!("# {}\n\n", role_str));
|
||||
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
messages.push_str(text);
|
||||
messages.push_str("\n\n");
|
||||
}
|
||||
MessageContent::Image(_) => {
|
||||
messages.push_str("[IMAGE DATA]\n\n");
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
messages.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
tool_use.name, tool_use.id
|
||||
));
|
||||
messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
messages.push_str(&format!(
|
||||
"**Tool Result**: {} (ID: {})\n\n",
|
||||
tool_result.tool_name, tool_result.tool_use_id
|
||||
));
|
||||
if tool_result.is_error {
|
||||
messages.push_str("**ERROR:**\n");
|
||||
}
|
||||
messages.push_str(&format!("{}\n", tool_result.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self { tools, messages }
|
||||
}
|
||||
}
|
||||
|
||||
fn response_events_to_markdown(
|
||||
response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
|
||||
) -> String {
|
||||
let mut response = String::new();
|
||||
// Print the response events if any
|
||||
response.push_str("# Response\n\n");
|
||||
let mut text_buffer = String::new();
|
||||
let mut thinking_buffer = String::new();
|
||||
|
||||
let flush_buffers =
|
||||
|output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
|
||||
if !text_buffer.is_empty() {
|
||||
output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
|
||||
text_buffer.clear();
|
||||
}
|
||||
if !thinking_buffer.is_empty() {
|
||||
output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
|
||||
thinking_buffer.clear();
|
||||
}
|
||||
};
|
||||
|
||||
for event in response_events {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
text_buffer.push_str(text);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Thinking(text)) => {
|
||||
thinking_buffer.push_str(text);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!("**Stop**: {:?}\n\n", reason));
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
tool_use.name, tool_use.id
|
||||
));
|
||||
response.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
|
||||
}
|
||||
Ok(
|
||||
LanguageModelCompletionEvent::UsageUpdate(_)
|
||||
| LanguageModelCompletionEvent::StartMessage { .. },
|
||||
) => {}
|
||||
Err(error) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!("**Error**: {}\n\n", error));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use handlebars::Handlebars;
|
||||
|
||||
#[test]
|
||||
fn test_parse_judge_output() {
|
||||
|
@ -736,7 +1008,7 @@ mod test {
|
|||
"#
|
||||
.unindent();
|
||||
|
||||
let output = parse_judge_output(&response).unwrap();
|
||||
let output = JudgeResponse::parse(&response).unwrap();
|
||||
assert_eq!(
|
||||
output.analysis,
|
||||
"The model did a good job but there were still compilations errors."
|
||||
|
@ -756,8 +1028,158 @@ mod test {
|
|||
"#
|
||||
.unindent();
|
||||
|
||||
let output = parse_judge_output(&response).unwrap();
|
||||
let output = JudgeResponse::parse(&response).unwrap();
|
||||
assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
|
||||
assert_eq!(output.score, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_with_diagnostics() {
|
||||
// Case 1: Both diagnostics before and after are present
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: Some("Error at line 10: variable not found".to_string()),
|
||||
diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
Error at line 10: variable not found
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
Error at line 15: missing semicolon
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_with_empty_diagnostics() {
|
||||
// Case 2: Diagnostics check run but no diagnostics found
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: None,
|
||||
diagnostics_after: None,
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
No diagnostics before applying the edits.
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
No diagnostics after applying the edits.
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_with_mixed_diagnostics() {
|
||||
let templates = templates();
|
||||
|
||||
// Case 3: Before diagnostics present, after diagnostics absent
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: Some("Error at line 10: variable not found".to_string()),
|
||||
diagnostics_after: None,
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
Error at line 10: variable not found
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
No diagnostics after applying the edits.
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
|
||||
// Case 4: Before diagnostics absent, after diagnostics present
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: None,
|
||||
diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
No diagnostics before applying the edits.
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
Error at line 15: missing semicolon
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_without_diagnostics() {
|
||||
let templates = templates();
|
||||
|
||||
// Case 5: No diagnostics check run
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: false,
|
||||
diagnostics_before: None,
|
||||
diagnostics_after: None,
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
// Check for the message when no diagnostics were performed
|
||||
let diagnostics_message = "No diagnostic checks were performed.";
|
||||
|
||||
assert!(rendered.contains(diagnostics_message));
|
||||
assert!(!rendered.contains("<diagnostics_before>"));
|
||||
assert!(!rendered.contains("<diagnostics_after>"));
|
||||
}
|
||||
|
||||
const JUDGE_PROMPT_NAME: &str = "judge_prompt";
|
||||
|
||||
fn templates() -> Handlebars<'static> {
|
||||
let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
|
||||
language::LineEnding::normalize(&mut judge_prompt);
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars
|
||||
.register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
|
||||
.unwrap();
|
||||
handlebars
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue