Introduce staff-only inline completion provider (#21739)
Release Notes: - N/A --------- Co-authored-by: Thorsten Ball <mrnugget@gmail.com> Co-authored-by: Bennet <bennet@zed.dev> Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
parent
39e8944dcc
commit
77b8296fbb
39 changed files with 2890 additions and 356 deletions
|
@ -29,7 +29,10 @@ use reqwest_client::ReqwestClient;
|
|||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use rpc::{
|
||||
ListModelsResponse, PredictEditsParams, PredictEditsResponse,
|
||||
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
||||
};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
|
@ -126,6 +129,7 @@ pub fn routes() -> Router<(), Body> {
|
|||
Router::new()
|
||||
.route("/models", get(list_models))
|
||||
.route("/completion", post(perform_completion))
|
||||
.route("/predict_edits", post(predict_edits))
|
||||
.layer(middleware::from_fn(validate_api_token))
|
||||
}
|
||||
|
||||
|
@ -439,6 +443,59 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
async fn predict_edits(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
_country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
Json(params): Json<PredictEditsParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
if !claims.is_staff {
|
||||
return Err(anyhow!("not found"))?;
|
||||
}
|
||||
|
||||
let api_url = state
|
||||
.config
|
||||
.prediction_api_url
|
||||
.as_ref()
|
||||
.context("no PREDICTION_API_URL configured on the server")?;
|
||||
let api_key = state
|
||||
.config
|
||||
.prediction_api_key
|
||||
.as_ref()
|
||||
.context("no PREDICTION_API_KEY configured on the server")?;
|
||||
let model = state
|
||||
.config
|
||||
.prediction_model
|
||||
.as_ref()
|
||||
.context("no PREDICTION_MODEL configured on the server")?;
|
||||
let prompt = include_str!("./llm/prediction_prompt.md")
|
||||
.replace("<events>", ¶ms.input_events)
|
||||
.replace("<excerpt>", ¶ms.input_excerpt);
|
||||
let mut response = open_ai::complete_text(
|
||||
&state.http_client,
|
||||
api_url,
|
||||
api_key,
|
||||
open_ai::CompletionRequest {
|
||||
model: model.to_string(),
|
||||
prompt: prompt.clone(),
|
||||
max_tokens: 1024,
|
||||
temperature: 0.,
|
||||
prediction: Some(open_ai::Prediction::Content {
|
||||
content: params.input_excerpt,
|
||||
}),
|
||||
rewrite_speculation: Some(true),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let choice = response
|
||||
.choices
|
||||
.pop()
|
||||
.context("no output from completion response")?;
|
||||
Ok(Json(PredictEditsResponse {
|
||||
output_excerpt: choice.text,
|
||||
}))
|
||||
}
|
||||
|
||||
/// The maximum monthly spending an individual user can reach on the free tier
|
||||
/// before they have to pay.
|
||||
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue