Timeout if completion takes longer than 2s (#23215)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-01-16 11:13:25 +01:00 committed by GitHub
parent a41d72ee81
commit 880f3ff243
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, Utc};
use collections::HashMap; use collections::HashMap;
use db::TokenUsage; use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _}; use futures::{FutureExt, Stream, StreamExt as _};
use reqwest_client::ReqwestClient; use reqwest_client::ReqwestClient;
use rpc::{ use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
@ -475,7 +475,11 @@ async fn predict_edits(
.replace("<excerpt>", &params.input_excerpt); .replace("<excerpt>", &params.input_excerpt);
let request_start = std::time::Instant::now(); let request_start = std::time::Instant::now();
let mut response = fireworks::complete( let timeout = state
.executor
.sleep(std::time::Duration::from_secs(2))
.fuse();
let response = fireworks::complete(
&state.http_client, &state.http_client,
api_url, api_url,
api_key, api_key,
@ -490,41 +494,72 @@ async fn predict_edits(
rewrite_speculation: Some(true), rewrite_speculation: Some(true),
}, },
) )
.await?; .fuse();
let duration = request_start.elapsed(); futures::pin_mut!(timeout);
futures::pin_mut!(response);
let choice = response futures::select! {
.completion _ = timeout => {
.choices state.executor.spawn_detached({
.pop() let kinesis_client = state.kinesis_client.clone();
.context("no output from completion response")?; let kinesis_stream = state.config.kinesis_stream.clone();
let model = model.clone();
async move {
SnowflakeRow::new(
"Fireworks Completion Timeout",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"model": model.to_string(),
"prompt": prompt,
}),
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});
Err(anyhow!("request timed out"))?
},
response = response => {
let duration = request_start.elapsed();
state.executor.spawn_detached({ let mut response = response?;
let kinesis_client = state.kinesis_client.clone(); let choice = response
let kinesis_stream = state.config.kinesis_stream.clone(); .completion
let model = model.clone(); .choices
async move { .pop()
SnowflakeRow::new( .context("no output from completion response")?;
"Fireworks Completion Requested",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"model": model.to_string(),
"headers": response.headers,
"usage": response.completion.usage,
"duration": duration.as_secs_f64(),
}),
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});
Ok(Json(PredictEditsResponse { state.executor.spawn_detached({
output_excerpt: choice.text, let kinesis_client = state.kinesis_client.clone();
})) let kinesis_stream = state.config.kinesis_stream.clone();
let model = model.clone();
async move {
SnowflakeRow::new(
"Fireworks Completion Requested",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"model": model.to_string(),
"headers": response.headers,
"usage": response.completion.usage,
"duration": duration.as_secs_f64(),
}),
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});
Ok(Json(PredictEditsResponse {
output_excerpt: choice.text,
}))
},
}
} }
/// The maximum monthly spending an individual user can reach on the free tier /// The maximum monthly spending an individual user can reach on the free tier