From 1fcc9b36ba3f5bc2982f00c8a25c8d14db4601ce Mon Sep 17 00:00:00 2001 From: Thorsten Ball Date: Fri, 10 Jan 2025 23:40:54 +0100 Subject: [PATCH] zeta: Report Fireworks request data to Snowflake (#22973) Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra Co-authored-by: Conrad --- Cargo.lock | 12 +++ Cargo.toml | 2 + crates/collab/Cargo.toml | 1 + crates/collab/src/llm.rs | 31 +++++- crates/fireworks/Cargo.toml | 19 ++++ crates/fireworks/LICENSE-GPL | 1 + crates/fireworks/src/fireworks.rs | 173 ++++++++++++++++++++++++++++++ 7 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 crates/fireworks/Cargo.toml create mode 120000 crates/fireworks/LICENSE-GPL create mode 100644 crates/fireworks/src/fireworks.rs diff --git a/Cargo.lock b/Cargo.lock index 8029cf7538..7f63e6aceb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2600,6 +2600,7 @@ dependencies = [ "envy", "extension", "file_finder", + "fireworks", "fs", "futures 0.3.31", "git", @@ -4512,6 +4513,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fireworks" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "http_client", + "serde", + "serde_json", +] + [[package]] name = "fixedbitset" version = "0.4.2" diff --git a/Cargo.toml b/Cargo.toml index d7544984ba..b5787abd42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ members = [ "crates/feedback", "crates/file_finder", "crates/file_icons", + "crates/fireworks", "crates/fs", "crates/fsevent", "crates/fuzzy", @@ -222,6 +223,7 @@ feature_flags = { path = "crates/feature_flags" } feedback = { path = "crates/feedback" } file_finder = { path = "crates/file_finder" } file_icons = { path = "crates/file_icons" } +fireworks = { path = "crates/fireworks" } fs = { path = "crates/fs" } fsevent = { path = "crates/fsevent" } fuzzy = { path = "crates/fuzzy" } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 996ee39843..c6e2a34a3c 100644 --- a/crates/collab/Cargo.toml +++ 
b/crates/collab/Cargo.toml @@ -34,6 +34,7 @@ collections.workspace = true dashmap.workspace = true derive_more.workspace = true envy = "0.4.2" +fireworks.workspace = true futures.workspace = true google_ai.workspace = true hex.workspace = true diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 2ece048064..e0115cc5d0 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -470,23 +470,48 @@ async fn predict_edits( .replace("", &outline_prefix) .replace("", ¶ms.input_events) .replace("", ¶ms.input_excerpt); - let mut response = open_ai::complete_text( + let mut response = fireworks::complete( &state.http_client, api_url, api_key, - open_ai::CompletionRequest { + fireworks::CompletionRequest { model: model.to_string(), prompt: prompt.clone(), max_tokens: 2048, temperature: 0., - prediction: Some(open_ai::Prediction::Content { + prediction: Some(fireworks::Prediction::Content { content: params.input_excerpt, }), rewrite_speculation: Some(true), }, ) .await?; + + state.executor.spawn_detached({ + let kinesis_client = state.kinesis_client.clone(); + let kinesis_stream = state.config.kinesis_stream.clone(); + let headers = response.headers.clone(); + let model = model.clone(); + + async move { + SnowflakeRow::new( + "Fireworks Completion Requested", + claims.metrics_id, + claims.is_staff, + claims.system_id.clone(), + json!({ + "model": model.to_string(), + "headers": headers, + }), + ) + .write(&kinesis_client, &kinesis_stream) + .await + .log_err(); + } + }); + let choice = response + .completion .choices .pop() .context("no output from completion response")?; diff --git a/crates/fireworks/Cargo.toml b/crates/fireworks/Cargo.toml new file mode 100644 index 0000000000..d9aa1b7bb6 --- /dev/null +++ b/crates/fireworks/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "fireworks" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/fireworks.rs" + 
+[dependencies] +anyhow.workspace = true +futures.workspace = true +http_client.workspace = true +serde.workspace = true +serde_json.workspace = true diff --git a/crates/fireworks/LICENSE-GPL b/crates/fireworks/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/fireworks/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/fireworks/src/fireworks.rs b/crates/fireworks/src/fireworks.rs new file mode 100644 index 0000000000..5772204747 --- /dev/null +++ b/crates/fireworks/src/fireworks.rs @@ -0,0 +1,173 @@ +use anyhow::{anyhow, Result}; +use futures::AsyncReadExt; +use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; + +/** Base URL of the Fireworks inference API; `complete` appends `/completions` to the base URL it is given. */ pub const FIREWORKS_API_URL: &str = "https://api.fireworks.ai/inference/v1"; + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionRequest { + pub model: String, + pub prompt: String, + pub max_tokens: u32, + pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub prediction: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub rewrite_speculation: Option, +} + +#[derive(Clone, Deserialize, Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Prediction { + Content { content: String }, +} + +#[derive(Debug)] +pub struct Response { + pub completion: CompletionResponse, + pub headers: Headers, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct CompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct CompletionChoice { + pub text: String, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct Headers { + pub
server_processing_time: Option, + pub request_id: Option, + pub prompt_tokens: Option, + pub speculation_generated_tokens: Option, + pub cached_prompt_tokens: Option, + pub backend_host: Option, + pub num_concurrent_requests: Option, + pub deployment: Option, + pub tokenizer_queue_duration: Option, + pub tokenizer_duration: Option, + pub prefill_queue_duration: Option, + pub prefill_duration: Option, + pub generation_queue_duration: Option, +} + +impl Headers { + pub fn parse(headers: &HeaderMap) -> Self { + Headers { + request_id: headers + .get("x-request-id") + .and_then(|v| v.to_str().ok()) + .map(String::from), + server_processing_time: headers + .get("fireworks-server-processing-time") + .and_then(|v| v.to_str().ok()?.parse().ok()), + prompt_tokens: headers + .get("fireworks-prompt-tokens") + .and_then(|v| v.to_str().ok()?.parse().ok()), + speculation_generated_tokens: headers + .get("fireworks-speculation-generated-tokens") + .and_then(|v| v.to_str().ok()?.parse().ok()), + cached_prompt_tokens: headers + .get("fireworks-cached-prompt-tokens") + .and_then(|v| v.to_str().ok()?.parse().ok()), + backend_host: headers + .get("fireworks-backend-host") + .and_then(|v| v.to_str().ok()) + .map(String::from), + num_concurrent_requests: headers + .get("fireworks-num-concurrent-requests") + .and_then(|v| v.to_str().ok()?.parse().ok()), + deployment: headers + .get("fireworks-deployment") + .and_then(|v| v.to_str().ok()) + .map(String::from), + tokenizer_queue_duration: headers + .get("fireworks-tokenizer-queue-duration") + .and_then(|v| v.to_str().ok()?.parse().ok()), + tokenizer_duration: headers + .get("fireworks-tokenizer-duration") + .and_then(|v| v.to_str().ok()?.parse().ok()), + prefill_queue_duration: headers + .get("fireworks-prefill-queue-duration") + .and_then(|v| v.to_str().ok()?.parse().ok()), + prefill_duration: headers + .get("fireworks-prefill-duration") + .and_then(|v| v.to_str().ok()?.parse().ok()), + generation_queue_duration: headers +
.get("fireworks-generation-queue-duration") + .and_then(|v| v.to_str().ok()?.parse().ok()), + } + } +} + +pub async fn complete( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: CompletionRequest, +) -> Result { + let uri = format!("{api_url}/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)); + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let headers = Headers::parse(response.headers()); + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + Ok(Response { + completion: serde_json::from_str(&body)?, + headers, + }) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct FireworksResponse { + error: FireworksError, + } + + #[derive(Deserialize)] + struct FireworksError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to Fireworks API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to Fireworks API: {} {}", + response.status(), + body, + )), + } + } +}