collab: Remove unused POST /predict_edits endpoint from LLM service (#23997)

This PR removes the `POST /predict_edits` endpoint from the LLM service,
as it has been superseded by the corresponding endpoint running in
Cloudflare Workers.

All traffic is already being routed to the Cloudflare Workers via the
Workers route, so nothing is hitting this endpoint running in the LLM
service anymore.

You can see the drop off in requests to this endpoint on this graph when
the Workers route was added:

<img width="472" alt="Screenshot 2025-01-30 at 9 18 04 PM"
src="https://github.com/user-attachments/assets/fa60f7c8-2737-4329-88a3-17093bdb5a29"
/>

We also don't use the `fireworks` crate anymore in this repo, so it has
been removed.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-01-30 22:21:40 -05:00 committed by GitHub
parent 35fbe1ef3d
commit 8be73bf187
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 2 additions and 380 deletions

12
Cargo.lock generated
View file

@ -2718,7 +2718,6 @@ dependencies = [
"envy",
"extension",
"file_finder",
"fireworks",
"fs",
"futures 0.3.31",
"git",
@ -4657,17 +4656,6 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "fireworks"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.31",
"http_client",
"serde",
"serde_json",
]
[[package]]
name = "fixedbitset"
version = "0.4.2"

View file

@ -44,7 +44,6 @@ members = [
"crates/feedback",
"crates/file_finder",
"crates/file_icons",
"crates/fireworks",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
@ -240,7 +239,6 @@ feature_flags = { path = "crates/feature_flags" }
feedback = { path = "crates/feedback" }
file_finder = { path = "crates/file_finder" }
file_icons = { path = "crates/file_icons" }
fireworks = { path = "crates/fireworks" }
fs = { path = "crates/fs" }
fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }

View file

@ -34,7 +34,6 @@ collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
fireworks.workspace = true
futures.workspace = true
google_ai.workspace = true
hex.workspace = true

View file

@ -21,15 +21,12 @@ use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{FutureExt, Stream, StreamExt as _};
use futures::{Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{
ListModelsResponse, PredictEditsParams, PredictEditsResponse,
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use serde_json::json;
use std::{
pin::Pin,
@ -44,9 +41,6 @@ pub use token::*;
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
/// Output token limit. A copy of this constant is also in `crates/zeta/src/zeta.rs`.
const MAX_OUTPUT_TOKENS: u32 = 2048;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
@ -123,7 +117,6 @@ pub fn routes() -> Router<(), Body> {
Router::new()
.route("/models", get(list_models))
.route("/completion", post(perform_completion))
.route("/predict_edits", post(predict_edits))
.layer(middleware::from_fn(validate_api_token))
}
@ -437,156 +430,6 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
}
}
async fn predict_edits(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
_country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PredictEditsParams>,
) -> Result<impl IntoResponse> {
if !claims.is_staff && !claims.has_predict_edits_feature_flag {
return Err(Error::http(
StatusCode::FORBIDDEN,
"no access to Zed's edit prediction feature".to_string(),
));
}
let should_sample = claims.is_staff || params.can_collect_data;
let api_url = state
.config
.prediction_api_url
.as_ref()
.context("no PREDICTION_API_URL configured on the server")?;
let api_key = state
.config
.prediction_api_key
.as_ref()
.context("no PREDICTION_API_KEY configured on the server")?;
let model = state
.config
.prediction_model
.as_ref()
.context("no PREDICTION_MODEL configured on the server")?;
let outline_prefix = params
.outline
.as_ref()
.map(|outline| format!("### Outline for current file:\n{}\n", outline))
.unwrap_or_default();
let prompt = include_str!("./llm/prediction_prompt.md")
.replace("<outline>", &outline_prefix)
.replace("<events>", &params.input_events)
.replace("<excerpt>", &params.input_excerpt);
let request_start = std::time::Instant::now();
let timeout = state
.executor
.sleep(std::time::Duration::from_secs(2))
.fuse();
let response = fireworks::complete(
&state.http_client,
api_url,
api_key,
fireworks::CompletionRequest {
model: model.to_string(),
prompt: prompt.clone(),
max_tokens: MAX_OUTPUT_TOKENS,
temperature: 0.,
prediction: Some(fireworks::Prediction::Content {
content: params.input_excerpt.clone(),
}),
rewrite_speculation: Some(true),
},
)
.fuse();
futures::pin_mut!(timeout);
futures::pin_mut!(response);
futures::select! {
_ = timeout => {
state.executor.spawn_detached({
let kinesis_client = state.kinesis_client.clone();
let kinesis_stream = state.config.kinesis_stream.clone();
let model = model.clone();
async move {
SnowflakeRow::new(
"Fireworks Completion Timeout",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"model": model.to_string(),
"prompt": prompt,
}),
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});
Err(anyhow!("request timed out"))?
},
response = response => {
let duration = request_start.elapsed();
let mut response = response?;
let choice = response
.completion
.choices
.pop()
.context("no output from completion response")?;
state.executor.spawn_detached({
let kinesis_client = state.kinesis_client.clone();
let kinesis_stream = state.config.kinesis_stream.clone();
let model = model.clone();
let output = choice.text.clone();
async move {
let properties = if should_sample {
json!({
"model": model.to_string(),
"headers": response.headers,
"usage": response.completion.usage,
"duration": duration.as_secs_f64(),
"prompt": prompt,
"input_excerpt": params.input_excerpt,
"input_events": params.input_events,
"outline": params.outline,
"output": output,
"is_sampled": true,
})
} else {
json!({
"model": model.to_string(),
"headers": response.headers,
"usage": response.completion.usage,
"duration": duration.as_secs_f64(),
"is_sampled": false,
})
};
SnowflakeRow::new(
"Fireworks Completion Requested",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
properties,
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});
Ok(Json(PredictEditsResponse {
output_excerpt: choice.text,
}))
},
}
}
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);

View file

@ -1,13 +0,0 @@
<outline>## Task
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
### Events:
<events>
### Input:
<excerpt>
### Response:

View file

@ -1,19 +0,0 @@
[package]
name = "fireworks"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/fireworks.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -1 +0,0 @@
../../LICENSE-GPL

View file

@ -1,173 +0,0 @@
use anyhow::{anyhow, Result};
use futures::AsyncReadExt;
use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
pub const FIREWORKS_API_URL: &str = "https://api.openai.com/v1";
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
pub model: String,
pub prompt: String,
pub max_tokens: u32,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prediction: Option<Prediction>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rewrite_speculation: Option<bool>,
}
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
Content { content: String },
}
#[derive(Debug)]
pub struct Response {
pub completion: CompletionResponse,
pub headers: Headers,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Usage,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
pub text: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct Headers {
pub server_processing_time: Option<f64>,
pub request_id: Option<String>,
pub prompt_tokens: Option<u32>,
pub speculation_generated_tokens: Option<u32>,
pub cached_prompt_tokens: Option<u32>,
pub backend_host: Option<String>,
pub num_concurrent_requests: Option<u32>,
pub deployment: Option<String>,
pub tokenizer_queue_duration: Option<f64>,
pub tokenizer_duration: Option<f64>,
pub prefill_queue_duration: Option<f64>,
pub prefill_duration: Option<f64>,
pub generation_queue_duration: Option<f64>,
}
impl Headers {
pub fn parse(headers: &HeaderMap) -> Self {
Headers {
request_id: headers
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(String::from),
server_processing_time: headers
.get("fireworks-server-processing-time")
.and_then(|v| v.to_str().ok()?.parse().ok()),
prompt_tokens: headers
.get("fireworks-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok()),
speculation_generated_tokens: headers
.get("fireworks-speculation-generated-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok()),
cached_prompt_tokens: headers
.get("fireworks-cached-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok()),
backend_host: headers
.get("fireworks-backend-host")
.and_then(|v| v.to_str().ok())
.map(String::from),
num_concurrent_requests: headers
.get("fireworks-num-concurrent-requests")
.and_then(|v| v.to_str().ok()?.parse().ok()),
deployment: headers
.get("fireworks-deployment")
.and_then(|v| v.to_str().ok())
.map(String::from),
tokenizer_queue_duration: headers
.get("fireworks-tokenizer-queue-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
tokenizer_duration: headers
.get("fireworks-tokenizer-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
prefill_queue_duration: headers
.get("fireworks-prefill-queue-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
prefill_duration: headers
.get("fireworks-prefill-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
generation_queue_duration: headers
.get("fireworks-generation-queue-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
}
}
}
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: CompletionRequest,
) -> Result<Response> {
let uri = format!("{api_url}/completions");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key));
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let headers = Headers::parse(response.headers());
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Ok(Response {
completion: serde_json::from_str(&body)?,
headers,
})
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct FireworksResponse {
error: FireworksError,
}
#[derive(Deserialize)]
struct FireworksError {
message: String,
}
match serde_json::from_str::<FireworksResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to Fireworks API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to Fireworks API: {} {}",
response.status(),
body,
)),
}
}
}