Use futures::future::join_all
instead of futures::stream
in assistant_eval (#26974)
Release Notes: - N/A
This commit is contained in:
parent
a5621662b2
commit
f61d3d28e0
1 changed files with 13 additions and 9 deletions
|
@ -4,7 +4,7 @@ mod judge;
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use eval::{Eval, EvalOutput};
|
use eval::{Eval, EvalOutput};
|
||||||
use futures::{stream, StreamExt};
|
use futures::future;
|
||||||
use gpui::{Application, AsyncApp};
|
use gpui::{Application, AsyncApp};
|
||||||
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
|
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
@ -126,12 +126,13 @@ fn main() {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let loaded_evals = stream::iter(evals_to_run)
|
let eval_load_futures = evals_to_run
|
||||||
|
.into_iter()
|
||||||
.map(|eval_name| {
|
.map(|eval_name| {
|
||||||
let eval_path = evaluation_data_dir.join(&eval_name);
|
let eval_path = evaluation_data_dir.join(&eval_name);
|
||||||
let repos_dir = repos_dir.clone();
|
let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
|
||||||
async move {
|
async move {
|
||||||
match Eval::load(eval_name.clone(), eval_path, &repos_dir).await {
|
match load_future.await {
|
||||||
Ok(eval) => Some(eval),
|
Ok(eval) => Some(eval),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
// TODO: Persist errors / surface errors at the end.
|
// TODO: Persist errors / surface errors at the end.
|
||||||
|
@ -141,8 +142,9 @@ fn main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.buffer_unordered(args.concurrency)
|
.collect::<Vec<_>>();
|
||||||
.collect::<Vec<_>>()
|
|
||||||
|
let loaded_evals = future::join_all(eval_load_futures)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flatten()
|
.flatten()
|
||||||
|
@ -160,7 +162,8 @@ fn main() {
|
||||||
// Sort groups in descending order, so that bigger groups start first.
|
// Sort groups in descending order, so that bigger groups start first.
|
||||||
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
|
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
|
||||||
|
|
||||||
let results = stream::iter(evals_grouped_by_url)
|
let result_futures = evals_grouped_by_url
|
||||||
|
.into_iter()
|
||||||
.map(|evals| {
|
.map(|evals| {
|
||||||
let model = model.clone();
|
let model = model.clone();
|
||||||
let judge_model = judge_model.clone();
|
let judge_model = judge_model.clone();
|
||||||
|
@ -185,8 +188,9 @@ fn main() {
|
||||||
results
|
results
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.buffer_unordered(args.concurrency)
|
.collect::<Vec<_>>();
|
||||||
.collect::<Vec<_>>()
|
|
||||||
|
let results = future::join_all(result_futures)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flatten()
|
.flatten()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue