Meter edit predictions by acceptance in free plan (#30984)

TODO:

- [x] Release  a new version of `zed_llm_client`

Release Notes:

- N/A

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Max Brunsfeld 2025-05-21 10:11:42 -07:00 committed by GitHub
parent afe23cf85a
commit cfd3b0ff7b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 115 additions and 8 deletions

4
Cargo.lock generated
View file

@ -19847,9 +19847,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.1"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16d993fc42f9ec43ab76fa46c6eb579a66e116bb08cd2bc9a67f3afcaa05d39d"
checksum = "9be71e2f9b271e1eb8eb3e0d986075e770d1a0a299fb036abc3f1fc13a2fa7eb"
dependencies = [
"anyhow",
"serde",

View file

@ -615,7 +615,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.8.1"
zed_llm_client = "0.8.2"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View file

@ -14,7 +14,7 @@ use license_detection::LICENSE_FILES_TO_CHECK;
pub use license_detection::is_license_eligible_for_data_collection;
pub use rate_completion_modal::*;
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use collections::{HashMap, HashSet, VecDeque};
@ -23,7 +23,7 @@ use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
Subscription, Task, WeakEntity, actions,
};
use http_client::{HttpClient, Method};
use http_client::{AsyncBody, HttpClient, Method, Request, Response};
use input_excerpt::excerpt_for_cursor_position;
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
@ -54,8 +54,8 @@ use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId};
use worktree::Worktree;
use zed_llm_client::{
EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
PredictEditsResponse, ZED_VERSION_HEADER_NAME,
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
};
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
@ -823,6 +823,74 @@ and then another
}
}
fn accept_edit_prediction(
&mut self,
request_id: InlineCompletionId,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
cx.spawn(async move |this, cx| {
let http_client = client.http_client();
let mut response = llm_token_retry(&llm_token, &client, |token| {
let request_builder = http_client::Request::builder().method(Method::POST);
let request_builder =
if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
request_builder.uri(accept_prediction_url)
} else {
request_builder.uri(
http_client
.build_zed_llm_url("/predict_edits/accept", &[])?
.as_ref(),
)
};
Ok(request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
.body(
serde_json::to_string(&AcceptEditPredictionBody {
request_id: request_id.0,
})?
.into(),
)?)
})
.await?;
if let Some(minimum_required_version) = response
.headers()
.get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
.and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
{
if app_version < minimum_required_version {
return Err(anyhow!(ZedUpdateRequiredError {
minimum_version: minimum_required_version
}));
}
}
if response.status().is_success() {
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
this.update(cx, |this, cx| {
this.last_usage = Some(usage);
cx.notify();
})?;
}
Ok(())
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Err(anyhow!(
"error accepting edit prediction.\nStatus: {:?}\nBody: {}",
response.status(),
body
))
}
})
}
fn process_completion_response(
prediction_response: PredictEditsResponse,
buffer: Entity<Buffer>,
@ -1381,6 +1449,34 @@ impl ProviderDataCollection {
}
}
async fn llm_token_retry(
llm_token: &LlmApiToken,
client: &Arc<Client>,
build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
) -> Result<Response<AsyncBody>> {
let mut did_retry = false;
let http_client = client.http_client();
let mut token = llm_token.acquire(client).await?;
loop {
let request = build_request(token.clone())?;
let response = http_client.send(request).await?;
if !did_retry
&& !response.status().is_success()
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
token = llm_token.refresh(client).await?;
continue;
}
return Ok(response);
}
}
pub struct ZetaInlineCompletionProvider {
zeta: Entity<Zeta>,
pending_completions: ArrayVec<PendingCompletion, 2>,
@ -1597,7 +1693,18 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
// Right now we don't support cycling.
}
fn accept(&mut self, _cx: &mut Context<Self>) {
fn accept(&mut self, cx: &mut Context<Self>) {
let completion_id = self
.current_completion
.as_ref()
.map(|completion| completion.completion.id);
if let Some(completion_id) = completion_id {
self.zeta
.update(cx, |zeta, cx| {
zeta.accept_edit_prediction(completion_id, cx)
})
.detach();
}
self.pending_completions.clear();
}