Use futures::future::join_all instead of futures::stream in assistant_eval (#26974)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-03-18 02:22:18 -06:00 committed by GitHub
parent a5621662b2
commit f61d3d28e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,7 +4,7 @@ mod judge;
use clap::Parser; use clap::Parser;
use eval::{Eval, EvalOutput}; use eval::{Eval, EvalOutput};
use futures::{stream, StreamExt}; use futures::future;
use gpui::{Application, AsyncApp}; use gpui::{Application, AsyncApp};
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState}; use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
use itertools::Itertools; use itertools::Itertools;
@ -126,12 +126,13 @@ fn main() {
.await .await
.unwrap(); .unwrap();
let loaded_evals = stream::iter(evals_to_run) let eval_load_futures = evals_to_run
.into_iter()
.map(|eval_name| { .map(|eval_name| {
let eval_path = evaluation_data_dir.join(&eval_name); let eval_path = evaluation_data_dir.join(&eval_name);
let repos_dir = repos_dir.clone(); let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
async move { async move {
match Eval::load(eval_name.clone(), eval_path, &repos_dir).await { match load_future.await {
Ok(eval) => Some(eval), Ok(eval) => Some(eval),
Err(err) => { Err(err) => {
// TODO: Persist errors / surface errors at the end. // TODO: Persist errors / surface errors at the end.
@ -141,8 +142,9 @@ fn main() {
} }
} }
}) })
.buffer_unordered(args.concurrency) .collect::<Vec<_>>();
.collect::<Vec<_>>()
let loaded_evals = future::join_all(eval_load_futures)
.await .await
.into_iter() .into_iter()
.flatten() .flatten()
@ -160,7 +162,8 @@ fn main() {
// Sort groups in descending order, so that bigger groups start first. // Sort groups in descending order, so that bigger groups start first.
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len())); evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
let results = stream::iter(evals_grouped_by_url) let result_futures = evals_grouped_by_url
.into_iter()
.map(|evals| { .map(|evals| {
let model = model.clone(); let model = model.clone();
let judge_model = judge_model.clone(); let judge_model = judge_model.clone();
@ -185,8 +188,9 @@ fn main() {
results results
} }
}) })
.buffer_unordered(args.concurrency) .collect::<Vec<_>>();
.collect::<Vec<_>>()
let results = future::join_all(result_futures)
.await .await
.into_iter() .into_iter()
.flatten() .flatten()