Introduce staff-only inline completion provider (#21739)

Release Notes:

- N/A

---------

Co-authored-by: Thorsten Ball <mrnugget@gmail.com>
Co-authored-by: Bennet <bennet@zed.dev>
Co-authored-by: Thorsten <thorsten@zed.dev>
Antonio Scandurra authored 2024-12-09 14:26:36 +01:00, committed by GitHub
parent 39e8944dcc
commit 77b8296fbb
39 changed files with 2890 additions and 356 deletions

View file

@@ -149,6 +149,21 @@ spec:
secretKeyRef:
name: google-ai
key: api_key
- name: PREDICTION_API_URL
valueFrom:
secretKeyRef:
name: prediction
key: api_url
- name: PREDICTION_API_KEY
valueFrom:
secretKeyRef:
name: prediction
key: api_key
- name: PREDICTION_MODEL
valueFrom:
secretKeyRef:
name: prediction
key: model
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:
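
The three new PREDICTION_* secrets reach the collab server as environment variables and back the optional prediction_* fields added to Config below. A minimal sketch of reading them, assuming a plain std::env lookup rather than the server's actual configuration loader:

```rust
use std::{env, sync::Arc};

/// Hypothetical helper: read an optional env var into the Option<Arc<str>>
/// shape used by the Config fields added in this commit.
fn optional_env(name: &str) -> Option<Arc<str>> {
    env::var(name).ok().map(Arc::from)
}

fn main() {
    let prediction_api_url = optional_env("PREDICTION_API_URL");
    let prediction_api_key = optional_env("PREDICTION_API_KEY");
    let prediction_model = optional_env("PREDICTION_MODEL");

    // predict_edits (below) returns a descriptive error when any of these
    // are still None, so missing secrets fail at request time, not at boot.
    println!(
        "prediction configured: {}",
        prediction_api_url.is_some() && prediction_api_key.is_some() && prediction_model.is_some()
    );
}
```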

View file

@@ -483,7 +483,7 @@ pub async fn post_events(
checksum_matched,
))
}
Event::Cpu(_) | Event::Memory(_) => continue,
Event::Cpu(_) | Event::Memory(_) | Event::InlineCompletionRating(_) => continue,
Event::App(event) => to_upload.app_events.push(AppEventRow::from_event(
event.clone(),
wrapper,
@@ -1406,6 +1406,10 @@ fn for_snowflake(
),
serde_json::to_value(e).unwrap(),
),
Event::InlineCompletionRating(e) => (
"Inline Completion Feedback".to_string(),
serde_json::to_value(e).unwrap(),
),
Event::Call(e) => {
let event_type = match e.operation.trim() {
"unshare project" => "Project Unshared".to_string(),

View file

@@ -180,6 +180,9 @@ pub struct Config {
pub anthropic_api_key: Option<Arc<str>>,
pub anthropic_staff_api_key: Option<Arc<str>>,
pub llm_closed_beta_model_name: Option<Arc<str>>,
pub prediction_api_url: Option<Arc<str>>,
pub prediction_api_key: Option<Arc<str>>,
pub prediction_model: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
@@ -230,6 +233,9 @@ impl Config {
anthropic_api_key: None,
anthropic_staff_api_key: None,
llm_closed_beta_model_name: None,
prediction_api_url: None,
prediction_api_key: None,
prediction_model: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,

View file

@@ -29,7 +29,10 @@ use reqwest_client::ReqwestClient;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use rpc::{
ListModelsResponse, PredictEditsParams, PredictEditsResponse,
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use serde_json::json;
use std::{
pin::Pin,
@@ -126,6 +129,7 @@ pub fn routes() -> Router<(), Body> {
Router::new()
.route("/models", get(list_models))
.route("/completion", post(perform_completion))
.route("/predict_edits", post(predict_edits))
.layer(middleware::from_fn(validate_api_token))
}
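
The /predict_edits route sits behind the same validate_api_token middleware as /models and /completion, so callers need a valid LLM token (and, per the handler below, a staff account). A rough client-side sketch, assuming a bearer Authorization header and the JSON field names the handler reads; this is not the editor's actual client code:

```rust
use serde::Deserialize;
use serde_json::json;

#[derive(Deserialize, Debug)]
struct PredictEditsResponse {
    output_excerpt: String,
}

async fn predict_edits(
    base_url: &str,
    llm_token: &str,
    input_events: &str,
    input_excerpt: &str,
) -> anyhow::Result<PredictEditsResponse> {
    let client = reqwest::Client::new();
    let response = client
        .post(format!("{base_url}/predict_edits"))
        // Assumption: the LLM token is sent as a bearer token, as with the
        // other LLM endpoints behind validate_api_token.
        .header("Authorization", format!("Bearer {llm_token}"))
        .json(&json!({
            "input_events": input_events,
            "input_excerpt": input_excerpt,
        }))
        .send()
        .await?
        .error_for_status()?;
    Ok(response.json::<PredictEditsResponse>().await?)
}
```
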
@@ -439,6 +443,59 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
}
}
async fn predict_edits(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
_country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PredictEditsParams>,
) -> Result<impl IntoResponse> {
if !claims.is_staff {
return Err(anyhow!("not found"))?;
}
let api_url = state
.config
.prediction_api_url
.as_ref()
.context("no PREDICTION_API_URL configured on the server")?;
let api_key = state
.config
.prediction_api_key
.as_ref()
.context("no PREDICTION_API_KEY configured on the server")?;
let model = state
.config
.prediction_model
.as_ref()
.context("no PREDICTION_MODEL configured on the server")?;
let prompt = include_str!("./llm/prediction_prompt.md")
.replace("<events>", &params.input_events)
.replace("<excerpt>", &params.input_excerpt);
let mut response = open_ai::complete_text(
&state.http_client,
api_url,
api_key,
open_ai::CompletionRequest {
model: model.to_string(),
prompt: prompt.clone(),
max_tokens: 1024,
temperature: 0.,
prediction: Some(open_ai::Prediction::Content {
content: params.input_excerpt,
}),
rewrite_speculation: Some(true),
},
)
.await?;
let choice = response
.choices
.pop()
.context("no output from completion response")?;
Ok(Json(PredictEditsResponse {
output_excerpt: choice.text,
}))
}
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
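
The upstream call relies on two provider features: a predicted-output hint (the unchanged excerpt is sent as the expected completion, letting the model copy most of it) and a rewrite_speculation flag. The request shape below is reconstructed only from this call site; the real definitions live in the open_ai crate, and the serialized field and tag names are assumptions:

```rust
use serde::Serialize;

// Shapes inferred from the predict_edits call site; anything beyond the
// fields used there is an assumption.
#[derive(Serialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    max_tokens: u32,
    temperature: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    prediction: Option<Prediction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    rewrite_speculation: Option<bool>,
}

#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum Prediction {
    // Mirrors open_ai::Prediction::Content { content } as used above.
    Content { content: String },
}

fn main() {
    let request = CompletionRequest {
        model: "example-model".into(),
        prompt: "### Events:\n...\n### Input:\n...\n### Response:\n".into(),
        max_tokens: 1024,
        temperature: 0.0,
        prediction: Some(Prediction::Content {
            content: "fn main() {}".into(),
        }),
        rewrite_speculation: Some(true),
    };
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}
```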

View file

@@ -0,0 +1,12 @@
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
### Events:
<events>
### Input:
<excerpt>
### Response:
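
This Alpaca-style template is what predict_edits fills in: the literal <events> and <excerpt> markers are replaced with the request's input_events and input_excerpt before the prompt is sent to the model. A tiny illustration of that substitution, using an abbreviated inline copy of the template and made-up sample data:

```rust
// Abbreviated, inline copy of the placeholder-bearing tail of
// prediction_prompt.md, standing in for include_str!(...).
const PROMPT_TEMPLATE: &str = "\
### Events:
<events>
### Input:
<excerpt>
### Response:
";

fn main() {
    // Made-up example data; the real values come from PredictEditsParams.
    let input_events = "User edited \"src/lib.rs\": renamed `get_user` to `fetch_user`.";
    let input_excerpt = "fn main() {\n    let user = get_user();\n}";

    let prompt = PROMPT_TEMPLATE
        .replace("<events>", input_events)
        .replace("<excerpt>", input_excerpt);

    println!("{prompt}");
}
```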

View file

@@ -546,6 +546,9 @@ impl TestServer {
anthropic_api_key: None,
anthropic_staff_api_key: None,
llm_closed_beta_model_name: None,
prediction_api_url: None,
prediction_api_key: None,
prediction_model: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,