Acquire LLM token from Cloud instead of Collab for Edit Predictions (#35431)

This PR updates the Zed Edit Prediction provider to acquire the LLM
token from Cloud instead of Collab, so that Edit Predictions keep working
even when the client is disconnected from, or unable to reach, the Collab server.

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Marshall Bowers 2025-07-31 18:12:04 -04:00 committed by GitHub
parent 8e7f1899e1
commit 410348deb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 211 additions and 125 deletions

View file

@ -40,7 +40,6 @@ log.workspace = true
menu.workspace = true
postage.workspace = true
project.workspace = true
proto.workspace = true
regex.workspace = true
release_channel.workspace = true
serde.workspace = true
@ -59,9 +58,11 @@ worktree.workspace = true
zed_actions.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
call = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
cloud_api_types.workspace = true
collections = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
call = { workspace = true, features = ["test-support"] }
zlog.workspace = true

View file

@ -16,7 +16,7 @@ pub use rate_completion_modal::*;
use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use client::{Client, CloudUserStore, EditPredictionUsage, UserStore};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
@ -226,12 +226,9 @@ pub struct Zeta {
data_collection_choice: Entity<DataCollectionChoice>,
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
/// Whether the terms of service have been accepted.
tos_accepted: bool,
/// Whether an update to a newer version of Zed is required to continue using Zeta.
update_required: bool,
user_store: Entity<UserStore>,
_user_store_subscription: Subscription,
cloud_user_store: Entity<CloudUserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
@ -244,11 +241,11 @@ impl Zeta {
workspace: Option<WeakEntity<Workspace>>,
worktree: Option<Entity<Worktree>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
cx: &mut App,
) -> Entity<Self> {
let this = Self::global(cx).unwrap_or_else(|| {
let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx));
cx.set_global(ZetaGlobal(entity.clone()));
entity
});
@ -271,13 +268,13 @@ impl Zeta {
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
self.user_store.read(cx).edit_prediction_usage()
self.cloud_user_store.read(cx).edit_prediction_usage()
}
fn new(
workspace: Option<WeakEntity<Workspace>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
@ -306,24 +303,9 @@ impl Zeta {
.detach_and_log_err(cx);
},
),
tos_accepted: user_store
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false),
update_required: false,
_user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
match event {
client::user::Event::PrivateUserInfoUpdated => {
this.tos_accepted = user_store
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false);
}
_ => {}
}
}),
license_detection_watchers: HashMap::default(),
user_store,
cloud_user_store,
}
}
@ -552,8 +534,8 @@ impl Zeta {
if let Some(usage) = usage {
this.update(cx, |this, cx| {
this.user_store.update(cx, |user_store, cx| {
user_store.update_edit_prediction_usage(usage, cx);
this.cloud_user_store.update(cx, |cloud_user_store, cx| {
cloud_user_store.update_edit_prediction_usage(usage, cx);
});
})
.ok();
@ -894,8 +876,8 @@ and then another
if response.status().is_success() {
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
this.update(cx, |this, cx| {
this.user_store.update(cx, |user_store, cx| {
user_store.update_edit_prediction_usage(usage, cx);
this.cloud_user_store.update(cx, |cloud_user_store, cx| {
cloud_user_store.update_edit_prediction_usage(usage, cx);
});
})?;
}
@ -1573,7 +1555,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
}
fn needs_terms_acceptance(&self, cx: &App) -> bool {
!self.zeta.read(cx).tos_accepted
!self
.zeta
.read(cx)
.cloud_user_store
.read(cx)
.has_accepted_tos()
}
fn is_refreshing(&self) -> bool {
@ -1588,7 +1575,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
_debounce: bool,
cx: &mut Context<Self>,
) {
if !self.zeta.read(cx).tos_accepted {
if self.needs_terms_acceptance(cx) {
return;
}
@ -1599,9 +1586,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
if self
.zeta
.read(cx)
.user_store
.read_with(cx, |user_store, _| {
user_store.account_too_young() || user_store.has_overdue_invoices()
.cloud_user_store
.read_with(cx, |cloud_user_store, _cx| {
cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices()
})
{
return;
@ -1819,15 +1806,51 @@ fn tokens_for_bytes(bytes: usize) -> usize {
mod tests {
use client::test::FakeServer;
use clock::FakeSystemClock;
use cloud_api_types::{
AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo,
};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use indoc::indoc;
use language::Point;
use rpc::proto;
use settings::SettingsStore;
use super::*;
/// Builds a canned `GetAuthenticatedUserResponse` for tests: a non-staff
/// user (id 1) on the Zed Pro plan who has not yet accepted the terms of
/// service, with no feature flags and no active subscription period.
fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse {
    // `accepted_tos_at: None` means the ToS is still unaccepted for this user.
    let user = AuthenticatedUser {
        id: 1,
        metrics_id: "metrics-id-1".to_string(),
        avatar_url: String::new(),
        github_login: String::new(),
        name: None,
        is_staff: false,
        accepted_tos_at: None,
    };

    // Usage snapshot: 0 of 500 model requests consumed, and 250 edit
    // predictions used against an unlimited allowance.
    let usage = CurrentUsage {
        model_requests: UsageData {
            used: 0,
            limit: UsageLimit::Limited(500),
        },
        edit_predictions: UsageData {
            used: 250,
            limit: UsageLimit::Unlimited,
        },
    };

    let plan = PlanInfo {
        plan: Plan::ZedPro,
        subscription_period: None,
        usage,
        trial_started_at: None,
        is_usage_based_billing_enabled: false,
        is_account_too_young: false,
        has_overdue_invoices: false,
    };

    GetAuthenticatedUserResponse {
        user,
        feature_flags: Vec::new(),
        plan,
    }
}
#[gpui::test]
async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
@ -2027,28 +2050,55 @@ mod tests {
<|editable_region_end|>
```"};
let http_client = FakeHttpClient::create(move |_| async move {
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&PredictEditsResponse {
request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
.unwrap(),
output_excerpt: completion_response.to_string(),
})
.unwrap()
.into(),
)
.unwrap())
let http_client = FakeHttpClient::create(move |req| async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response())
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&CreateLlmTokenResponse {
token: LlmToken("the-llm-token".to_string()),
})
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&PredictEditsResponse {
request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
.unwrap(),
output_excerpt: completion_response.to_string(),
})
.unwrap()
.into(),
)
.unwrap()),
_ => Ok(http_client::Response::builder()
.status(404)
.body("Not Found".into())
.unwrap()),
}
});
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
});
let server = FakeServer::for_client(42, &client, cx).await;
// Construct the fake server to authenticate.
let _server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
@ -2056,13 +2106,6 @@ mod tests {
zeta.request_completion(None, &buffer, cursor, false, cx)
});
server.receive::<proto::GetUsers>().await.unwrap();
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
server.respond(
token_request.receipt(),
proto::GetLlmTokenResponse { token: "".into() },
);
let completion = completion_task.await.unwrap().unwrap();
buffer.update(cx, |buffer, cx| {
buffer.edit(completion.edits.iter().cloned(), None, cx)
@ -2079,20 +2122,44 @@ mod tests {
cx: &mut TestAppContext,
) -> Vec<(Range<Point>, String)> {
let completion_response = completion_response.to_string();
let http_client = FakeHttpClient::create(move |_| {
let http_client = FakeHttpClient::create(move |req| {
let completion = completion_response.clone();
async move {
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&PredictEditsResponse {
request_id: Uuid::new_v4(),
output_excerpt: completion,
})
.unwrap()
.into(),
)
.unwrap())
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response())
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&CreateLlmTokenResponse {
token: LlmToken("the-llm-token".to_string()),
})
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&PredictEditsResponse {
request_id: Uuid::new_v4(),
output_excerpt: completion,
})
.unwrap()
.into(),
)
.unwrap()),
_ => Ok(http_client::Response::builder()
.status(404)
.body("Not Found".into())
.unwrap()),
}
}
});
@ -2100,9 +2167,12 @@ mod tests {
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
});
let server = FakeServer::for_client(42, &client, cx).await;
// Construct the fake server to authenticate.
let _server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
@ -2111,13 +2181,6 @@ mod tests {
zeta.request_completion(None, &buffer, cursor, false, cx)
});
server.receive::<proto::GetUsers>().await.unwrap();
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
server.respond(
token_request.receipt(),
proto::GetLlmTokenResponse { token: "".into() },
);
let completion = completion_task.await.unwrap().unwrap();
completion
.edits