From 880f3ff2431379a906514a19f0694c65b37af5b3 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 16 Jan 2025 11:13:25 +0100 Subject: [PATCH] Timeout if completion takes longer than 2s (#23215) Release Notes: - N/A --- crates/collab/src/llm.rs | 103 ++++++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 8bf3fa6443..2cb9c6182a 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, Utc}; use collections::HashMap; use db::TokenUsage; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; -use futures::{Stream, StreamExt as _}; +use futures::{FutureExt, Stream, StreamExt as _}; use reqwest_client::ReqwestClient; use rpc::{ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, @@ -475,7 +475,11 @@ async fn predict_edits( .replace("", &params.input_excerpt); let request_start = std::time::Instant::now(); - let mut response = fireworks::complete( + let timeout = state + .executor + .sleep(std::time::Duration::from_secs(2)) + .fuse(); + let response = fireworks::complete( &state.http_client, api_url, api_key, @@ -490,41 +494,72 @@ async fn predict_edits( rewrite_speculation: Some(true), }, ) - .await?; - let duration = request_start.elapsed(); + .fuse(); + futures::pin_mut!(timeout); + futures::pin_mut!(response); - let choice = response - .completion - .choices - .pop() - .context("no output from completion response")?; + futures::select! 
{ + _ = timeout => { + state.executor.spawn_detached({ + let kinesis_client = state.kinesis_client.clone(); + let kinesis_stream = state.config.kinesis_stream.clone(); + let model = model.clone(); + async move { + SnowflakeRow::new( + "Fireworks Completion Timeout", + claims.metrics_id, + claims.is_staff, + claims.system_id.clone(), + json!({ + "model": model.to_string(), + "prompt": prompt, + }), + ) + .write(&kinesis_client, &kinesis_stream) + .await + .log_err(); + } + }); + Err(anyhow!("request timed out"))? + }, + response = response => { + let duration = request_start.elapsed(); - state.executor.spawn_detached({ - let kinesis_client = state.kinesis_client.clone(); - let kinesis_stream = state.config.kinesis_stream.clone(); - let model = model.clone(); - async move { - SnowflakeRow::new( - "Fireworks Completion Requested", - claims.metrics_id, - claims.is_staff, - claims.system_id.clone(), - json!({ - "model": model.to_string(), - "headers": response.headers, - "usage": response.completion.usage, - "duration": duration.as_secs_f64(), - }), - ) - .write(&kinesis_client, &kinesis_stream) - .await - .log_err(); - } - }); + let mut response = response?; + let choice = response + .completion + .choices + .pop() + .context("no output from completion response")?; - Ok(Json(PredictEditsResponse { - output_excerpt: choice.text, - })) + state.executor.spawn_detached({ + let kinesis_client = state.kinesis_client.clone(); + let kinesis_stream = state.config.kinesis_stream.clone(); + let model = model.clone(); + async move { + SnowflakeRow::new( + "Fireworks Completion Requested", + claims.metrics_id, + claims.is_staff, + claims.system_id.clone(), + json!({ + "model": model.to_string(), + "headers": response.headers, + "usage": response.completion.usage, + "duration": duration.as_secs_f64(), + }), + ) + .write(&kinesis_client, &kinesis_stream) + .await + .log_err(); + } + }); + + Ok(Json(PredictEditsResponse { + output_excerpt: choice.text, + })) + }, + } } /// The maximum 
monthly spending an individual user can reach on the free tier