diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index f78293dbf1..7831707af2 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -199,7 +199,7 @@ fn main() { future::join_all(clone_tasks).await; - for example in examples.iter() { + for example in examples.iter_mut() { example.setup().await?; } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 70140a360a..bbd00bb449 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -57,7 +57,7 @@ pub struct Example { /// Content of `criteria.md` pub criteria: String, /// Markdown output file to append to - pub output_file: Arc>, + pub output_file: Option>>, /// Path to markdown output file pub output_file_path: PathBuf, /// Prefix used for logging that identifies this example @@ -97,14 +97,13 @@ impl Example { "{}.md", dir_path.file_name().unwrap().to_str().unwrap() )); - let output_file = Arc::new(Mutex::new(File::create(&output_file_path).unwrap())); Ok(Example { name: name.clone(), base: toml::from_str(&fs::read_to_string(&base_path)?)?, prompt: fs::read_to_string(prompt_path.clone())?, criteria: fs::read_to_string(criteria_path.clone())?, - output_file, + output_file: None, output_file_path, log_prefix: name, }) @@ -132,7 +131,7 @@ impl Example { } /// Set up the example by checking out the specified Git revision - pub async fn setup(&self) -> Result<()> { + pub async fn setup(&mut self) -> Result<()> { let repo_path = repo_path_for_url(&self.base.url); println!("{}Fetching", self.log_prefix); @@ -171,9 +170,20 @@ impl Example { .await?; } + // Create the output file + let output_file = Arc::new(Mutex::new(File::create(&self.output_file_path)?)); + self.output_file = Some(output_file); + Ok(()) } + /// Returns the output file, panicking if it's not set + fn output_file(&self) -> Arc> { + self.output_file + .clone() + .expect("Output file not created. Call setup() first.") + } + pub fn run( &self, model: Arc, @@ -287,7 +297,8 @@ impl Example { thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; { - let mut output_file = this.output_file.lock().unwrap(); + let output_file_ref = this.output_file(); + let mut output_file = output_file_ref.lock().unwrap(); writeln!(&mut output_file, "👤 USER:").log_err(); writeln!(&mut output_file, "{}", this.prompt).log_err(); writeln!(&mut output_file, "🤖 ASSISTANT:").log_err(); @@ -304,7 +315,8 @@ impl Example { }); let event_handler_task = cx.spawn({ - let output_file = this.output_file.clone(); + // Need to clone the Arc here because the reference from output_file() won't live long enough + let output_file = this.output_file.clone().unwrap(); let log_prefix = this.log_prefix.clone(); let tool_use_counts = tool_use_counts.clone(); let thread = thread.downgrade(); @@ -471,7 +483,8 @@ impl Example { let response = send_language_model_request(model, request, cx).await?; - let mut output_file = self.output_file.lock().unwrap(); + let output_file_ref = self.output_file(); + let mut output_file = output_file_ref.lock().unwrap(); writeln!(&mut output_file, "\n\n").log_err(); writeln!(&mut output_file, "========================================").log_err();