Acquire LLM token from Cloud instead of Collab for Edit Predictions (#35431)
This PR updates the Zed Edit Prediction provider to acquire the LLM token from Cloud instead of Collab to allow using Edit Predictions even when disconnected from or unable to connect to the Collab server. Release Notes: - N/A --------- Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
8e7f1899e1
commit
410348deb0
8 changed files with 211 additions and 125 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -20570,6 +20570,7 @@ dependencies = [
|
||||||
"call",
|
"call",
|
||||||
"client",
|
"client",
|
||||||
"clock",
|
"clock",
|
||||||
|
"cloud_api_types",
|
||||||
"cloud_llm_client",
|
"cloud_llm_client",
|
||||||
"collections",
|
"collections",
|
||||||
"command_palette_hooks",
|
"command_palette_hooks",
|
||||||
|
@ -20590,7 +20591,6 @@ dependencies = [
|
||||||
"menu",
|
"menu",
|
||||||
"postage",
|
"postage",
|
||||||
"project",
|
"project",
|
||||||
"proto",
|
|
||||||
"regex",
|
"regex",
|
||||||
"release_channel",
|
"release_channel",
|
||||||
"reqwest_client",
|
"reqwest_client",
|
||||||
|
|
|
@ -8,13 +8,14 @@ use cloud_llm_client::Plan;
|
||||||
use gpui::{Context, Entity, Subscription, Task};
|
use gpui::{Context, Entity, Subscription, Task};
|
||||||
use util::{ResultExt as _, maybe};
|
use util::{ResultExt as _, maybe};
|
||||||
|
|
||||||
use crate::UserStore;
|
|
||||||
use crate::user::Event as RpcUserStoreEvent;
|
use crate::user::Event as RpcUserStoreEvent;
|
||||||
|
use crate::{EditPredictionUsage, RequestUsage, UserStore};
|
||||||
|
|
||||||
pub struct CloudUserStore {
|
pub struct CloudUserStore {
|
||||||
cloud_client: Arc<CloudApiClient>,
|
cloud_client: Arc<CloudApiClient>,
|
||||||
authenticated_user: Option<Arc<AuthenticatedUser>>,
|
authenticated_user: Option<Arc<AuthenticatedUser>>,
|
||||||
plan_info: Option<Arc<PlanInfo>>,
|
plan_info: Option<Arc<PlanInfo>>,
|
||||||
|
edit_prediction_usage: Option<EditPredictionUsage>,
|
||||||
_maintain_authenticated_user_task: Task<()>,
|
_maintain_authenticated_user_task: Task<()>,
|
||||||
_rpc_plan_updated_subscription: Subscription,
|
_rpc_plan_updated_subscription: Subscription,
|
||||||
}
|
}
|
||||||
|
@ -32,6 +33,7 @@ impl CloudUserStore {
|
||||||
cloud_client: cloud_client.clone(),
|
cloud_client: cloud_client.clone(),
|
||||||
authenticated_user: None,
|
authenticated_user: None,
|
||||||
plan_info: None,
|
plan_info: None,
|
||||||
|
edit_prediction_usage: None,
|
||||||
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
|
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
|
||||||
maybe!(async move {
|
maybe!(async move {
|
||||||
loop {
|
loop {
|
||||||
|
@ -102,8 +104,48 @@ impl CloudUserStore {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn has_accepted_tos(&self) -> bool {
|
||||||
|
self.authenticated_user
|
||||||
|
.as_ref()
|
||||||
|
.map(|user| user.accepted_tos_at.is_some())
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether the user's account is too new to use the service.
|
||||||
|
pub fn account_too_young(&self) -> bool {
|
||||||
|
self.plan_info
|
||||||
|
.as_ref()
|
||||||
|
.map(|plan| plan.is_account_too_young)
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||||
|
pub fn has_overdue_invoices(&self) -> bool {
|
||||||
|
self.plan_info
|
||||||
|
.as_ref()
|
||||||
|
.map(|plan| plan.has_overdue_invoices)
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
|
||||||
|
self.edit_prediction_usage
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_edit_prediction_usage(
|
||||||
|
&mut self,
|
||||||
|
usage: EditPredictionUsage,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
self.edit_prediction_usage = Some(usage);
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
|
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
|
||||||
self.authenticated_user = Some(Arc::new(response.user));
|
self.authenticated_user = Some(Arc::new(response.user));
|
||||||
|
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
|
||||||
|
limit: response.plan.usage.edit_predictions.limit,
|
||||||
|
amount: response.plan.usage.edit_predictions.used as i32,
|
||||||
|
}));
|
||||||
self.plan_info = Some(Arc::new(response.plan));
|
self.plan_info = Some(Arc::new(response.plan));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,6 @@ pub struct UserStore {
|
||||||
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
|
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
|
||||||
trial_started_at: Option<DateTime<Utc>>,
|
trial_started_at: Option<DateTime<Utc>>,
|
||||||
model_request_usage: Option<ModelRequestUsage>,
|
model_request_usage: Option<ModelRequestUsage>,
|
||||||
edit_prediction_usage: Option<EditPredictionUsage>,
|
|
||||||
is_usage_based_billing_enabled: Option<bool>,
|
is_usage_based_billing_enabled: Option<bool>,
|
||||||
account_too_young: Option<bool>,
|
account_too_young: Option<bool>,
|
||||||
has_overdue_invoices: Option<bool>,
|
has_overdue_invoices: Option<bool>,
|
||||||
|
@ -193,7 +192,6 @@ impl UserStore {
|
||||||
subscription_period: None,
|
subscription_period: None,
|
||||||
trial_started_at: None,
|
trial_started_at: None,
|
||||||
model_request_usage: None,
|
model_request_usage: None,
|
||||||
edit_prediction_usage: None,
|
|
||||||
is_usage_based_billing_enabled: None,
|
is_usage_based_billing_enabled: None,
|
||||||
account_too_young: None,
|
account_too_young: None,
|
||||||
has_overdue_invoices: None,
|
has_overdue_invoices: None,
|
||||||
|
@ -381,12 +379,6 @@ impl UserStore {
|
||||||
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
||||||
})
|
})
|
||||||
.map(ModelRequestUsage);
|
.map(ModelRequestUsage);
|
||||||
this.edit_prediction_usage = usage
|
|
||||||
.edit_predictions_usage_limit
|
|
||||||
.and_then(|limit| {
|
|
||||||
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
|
||||||
})
|
|
||||||
.map(EditPredictionUsage);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cx.emit(Event::PlanUpdated);
|
cx.emit(Event::PlanUpdated);
|
||||||
|
@ -400,15 +392,6 @@ impl UserStore {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_edit_prediction_usage(
|
|
||||||
&mut self,
|
|
||||||
usage: EditPredictionUsage,
|
|
||||||
cx: &mut Context<Self>,
|
|
||||||
) {
|
|
||||||
self.edit_prediction_usage = Some(usage);
|
|
||||||
cx.notify();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
||||||
match message {
|
match message {
|
||||||
UpdateContacts::Wait(barrier) => {
|
UpdateContacts::Wait(barrier) => {
|
||||||
|
@ -797,10 +780,6 @@ impl UserStore {
|
||||||
self.model_request_usage
|
self.model_request_usage
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
|
|
||||||
self.edit_prediction_usage
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
|
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
|
||||||
self.current_user.clone()
|
self.current_user.clone()
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,9 +64,14 @@ impl LlmApiToken {
|
||||||
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
||||||
client: &Arc<Client>,
|
client: &Arc<Client>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let response = client.request(proto::GetLlmToken {}).await?;
|
let system_id = client
|
||||||
*lock = Some(response.token.clone());
|
.telemetry()
|
||||||
Ok(response.token.clone())
|
.system_id()
|
||||||
|
.map(|system_id| system_id.to_string());
|
||||||
|
|
||||||
|
let response = client.cloud_client().create_llm_token(system_id).await?;
|
||||||
|
*lock = Some(response.token.0.clone());
|
||||||
|
Ok(response.token.0.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -564,7 +564,7 @@ pub fn main() {
|
||||||
snippet_provider::init(cx);
|
snippet_provider::init(cx);
|
||||||
inline_completion_registry::init(
|
inline_completion_registry::init(
|
||||||
app_state.client.clone(),
|
app_state.client.clone(),
|
||||||
app_state.user_store.clone(),
|
app_state.cloud_user_store.clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
|
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use client::{Client, DisableAiSettings, UserStore};
|
use client::{Client, CloudUserStore, DisableAiSettings};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use copilot::{Copilot, CopilotCompletionProvider};
|
use copilot::{Copilot, CopilotCompletionProvider};
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
|
@ -13,12 +13,12 @@ use util::ResultExt;
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
|
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &mut App) {
|
||||||
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
|
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
|
||||||
cx.observe_new({
|
cx.observe_new({
|
||||||
let editors = editors.clone();
|
let editors = editors.clone();
|
||||||
let client = client.clone();
|
let client = client.clone();
|
||||||
let user_store = user_store.clone();
|
let cloud_user_store = cloud_user_store.clone();
|
||||||
move |editor: &mut Editor, window, cx: &mut Context<Editor>| {
|
move |editor: &mut Editor, window, cx: &mut Context<Editor>| {
|
||||||
if !editor.mode().is_full() {
|
if !editor.mode().is_full() {
|
||||||
return;
|
return;
|
||||||
|
@ -48,7 +48,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
editor,
|
editor,
|
||||||
provider,
|
provider,
|
||||||
&client,
|
&client,
|
||||||
user_store.clone(),
|
cloud_user_store.clone(),
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
@ -60,7 +60,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
|
|
||||||
let mut provider = all_language_settings(None, cx).edit_predictions.provider;
|
let mut provider = all_language_settings(None, cx).edit_predictions.provider;
|
||||||
cx.spawn({
|
cx.spawn({
|
||||||
let user_store = user_store.clone();
|
let cloud_user_store = cloud_user_store.clone();
|
||||||
let editors = editors.clone();
|
let editors = editors.clone();
|
||||||
let client = client.clone();
|
let client = client.clone();
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
&editors,
|
&editors,
|
||||||
provider,
|
provider,
|
||||||
&client,
|
&client,
|
||||||
user_store.clone(),
|
cloud_user_store.clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
|
@ -85,15 +85,12 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
cx.observe_global::<SettingsStore>({
|
cx.observe_global::<SettingsStore>({
|
||||||
let editors = editors.clone();
|
let editors = editors.clone();
|
||||||
let client = client.clone();
|
let client = client.clone();
|
||||||
let user_store = user_store.clone();
|
let cloud_user_store = cloud_user_store.clone();
|
||||||
move |cx| {
|
move |cx| {
|
||||||
let new_provider = all_language_settings(None, cx).edit_predictions.provider;
|
let new_provider = all_language_settings(None, cx).edit_predictions.provider;
|
||||||
|
|
||||||
if new_provider != provider {
|
if new_provider != provider {
|
||||||
let tos_accepted = user_store
|
let tos_accepted = cloud_user_store.read(cx).has_accepted_tos();
|
||||||
.read(cx)
|
|
||||||
.current_user_has_accepted_terms()
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
telemetry::event!(
|
telemetry::event!(
|
||||||
"Edit Prediction Provider Changed",
|
"Edit Prediction Provider Changed",
|
||||||
|
@ -107,7 +104,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
&editors,
|
&editors,
|
||||||
provider,
|
provider,
|
||||||
&client,
|
&client,
|
||||||
user_store.clone(),
|
cloud_user_store.clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -148,7 +145,7 @@ fn assign_edit_prediction_providers(
|
||||||
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
||||||
provider: EditPredictionProvider,
|
provider: EditPredictionProvider,
|
||||||
client: &Arc<Client>,
|
client: &Arc<Client>,
|
||||||
user_store: Entity<UserStore>,
|
cloud_user_store: Entity<CloudUserStore>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) {
|
) {
|
||||||
for (editor, window) in editors.borrow().iter() {
|
for (editor, window) in editors.borrow().iter() {
|
||||||
|
@ -158,7 +155,7 @@ fn assign_edit_prediction_providers(
|
||||||
editor,
|
editor,
|
||||||
provider,
|
provider,
|
||||||
&client,
|
&client,
|
||||||
user_store.clone(),
|
cloud_user_store.clone(),
|
||||||
window,
|
window,
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
@ -213,7 +210,7 @@ fn assign_edit_prediction_provider(
|
||||||
editor: &mut Editor,
|
editor: &mut Editor,
|
||||||
provider: EditPredictionProvider,
|
provider: EditPredictionProvider,
|
||||||
client: &Arc<Client>,
|
client: &Arc<Client>,
|
||||||
user_store: Entity<UserStore>,
|
cloud_user_store: Entity<CloudUserStore>,
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Editor>,
|
cx: &mut Context<Editor>,
|
||||||
) {
|
) {
|
||||||
|
@ -244,7 +241,7 @@ fn assign_edit_prediction_provider(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EditPredictionProvider::Zed => {
|
EditPredictionProvider::Zed => {
|
||||||
if client.status().borrow().is_connected() {
|
if cloud_user_store.read(cx).is_authenticated() {
|
||||||
let mut worktree = None;
|
let mut worktree = None;
|
||||||
|
|
||||||
if let Some(buffer) = &singleton_buffer {
|
if let Some(buffer) = &singleton_buffer {
|
||||||
|
@ -266,7 +263,7 @@ fn assign_edit_prediction_provider(
|
||||||
.map(|workspace| workspace.downgrade());
|
.map(|workspace| workspace.downgrade());
|
||||||
|
|
||||||
let zeta =
|
let zeta =
|
||||||
zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx);
|
zeta::Zeta::register(workspace, worktree, client.clone(), cloud_user_store, cx);
|
||||||
|
|
||||||
if let Some(buffer) = &singleton_buffer {
|
if let Some(buffer) = &singleton_buffer {
|
||||||
if buffer.read(cx).file().is_some() {
|
if buffer.read(cx).file().is_some() {
|
||||||
|
|
|
@ -40,7 +40,6 @@ log.workspace = true
|
||||||
menu.workspace = true
|
menu.workspace = true
|
||||||
postage.workspace = true
|
postage.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
proto.workspace = true
|
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
release_channel.workspace = true
|
release_channel.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
@ -59,9 +58,11 @@ worktree.workspace = true
|
||||||
zed_actions.workspace = true
|
zed_actions.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
call = { workspace = true, features = ["test-support"] }
|
||||||
client = { workspace = true, features = ["test-support"] }
|
client = { workspace = true, features = ["test-support"] }
|
||||||
clock = { workspace = true, features = ["test-support"] }
|
clock = { workspace = true, features = ["test-support"] }
|
||||||
|
cloud_api_types.workspace = true
|
||||||
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
editor = { workspace = true, features = ["test-support"] }
|
editor = { workspace = true, features = ["test-support"] }
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
|
@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true
|
||||||
unindent.workspace = true
|
unindent.workspace = true
|
||||||
workspace = { workspace = true, features = ["test-support"] }
|
workspace = { workspace = true, features = ["test-support"] }
|
||||||
worktree = { workspace = true, features = ["test-support"] }
|
worktree = { workspace = true, features = ["test-support"] }
|
||||||
call = { workspace = true, features = ["test-support"] }
|
|
||||||
zlog.workspace = true
|
zlog.workspace = true
|
||||||
|
|
|
@ -16,7 +16,7 @@ pub use rate_completion_modal::*;
|
||||||
|
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use arrayvec::ArrayVec;
|
use arrayvec::ArrayVec;
|
||||||
use client::{Client, EditPredictionUsage, UserStore};
|
use client::{Client, CloudUserStore, EditPredictionUsage, UserStore};
|
||||||
use cloud_llm_client::{
|
use cloud_llm_client::{
|
||||||
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
|
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
|
||||||
PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
|
PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
|
||||||
|
@ -226,12 +226,9 @@ pub struct Zeta {
|
||||||
data_collection_choice: Entity<DataCollectionChoice>,
|
data_collection_choice: Entity<DataCollectionChoice>,
|
||||||
llm_token: LlmApiToken,
|
llm_token: LlmApiToken,
|
||||||
_llm_token_subscription: Subscription,
|
_llm_token_subscription: Subscription,
|
||||||
/// Whether the terms of service have been accepted.
|
|
||||||
tos_accepted: bool,
|
|
||||||
/// Whether an update to a newer version of Zed is required to continue using Zeta.
|
/// Whether an update to a newer version of Zed is required to continue using Zeta.
|
||||||
update_required: bool,
|
update_required: bool,
|
||||||
user_store: Entity<UserStore>,
|
cloud_user_store: Entity<CloudUserStore>,
|
||||||
_user_store_subscription: Subscription,
|
|
||||||
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
|
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,11 +241,11 @@ impl Zeta {
|
||||||
workspace: Option<WeakEntity<Workspace>>,
|
workspace: Option<WeakEntity<Workspace>>,
|
||||||
worktree: Option<Entity<Worktree>>,
|
worktree: Option<Entity<Worktree>>,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Entity<UserStore>,
|
cloud_user_store: Entity<CloudUserStore>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Entity<Self> {
|
) -> Entity<Self> {
|
||||||
let this = Self::global(cx).unwrap_or_else(|| {
|
let this = Self::global(cx).unwrap_or_else(|| {
|
||||||
let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
|
let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx));
|
||||||
cx.set_global(ZetaGlobal(entity.clone()));
|
cx.set_global(ZetaGlobal(entity.clone()));
|
||||||
entity
|
entity
|
||||||
});
|
});
|
||||||
|
@ -271,13 +268,13 @@ impl Zeta {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
||||||
self.user_store.read(cx).edit_prediction_usage()
|
self.cloud_user_store.read(cx).edit_prediction_usage()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new(
|
fn new(
|
||||||
workspace: Option<WeakEntity<Workspace>>,
|
workspace: Option<WeakEntity<Workspace>>,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Entity<UserStore>,
|
cloud_user_store: Entity<CloudUserStore>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||||
|
@ -306,24 +303,9 @@ impl Zeta {
|
||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
tos_accepted: user_store
|
|
||||||
.read(cx)
|
|
||||||
.current_user_has_accepted_terms()
|
|
||||||
.unwrap_or(false),
|
|
||||||
update_required: false,
|
update_required: false,
|
||||||
_user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
|
|
||||||
match event {
|
|
||||||
client::user::Event::PrivateUserInfoUpdated => {
|
|
||||||
this.tos_accepted = user_store
|
|
||||||
.read(cx)
|
|
||||||
.current_user_has_accepted_terms()
|
|
||||||
.unwrap_or(false);
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
license_detection_watchers: HashMap::default(),
|
license_detection_watchers: HashMap::default(),
|
||||||
user_store,
|
cloud_user_store,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -552,8 +534,8 @@ impl Zeta {
|
||||||
|
|
||||||
if let Some(usage) = usage {
|
if let Some(usage) = usage {
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.user_store.update(cx, |user_store, cx| {
|
this.cloud_user_store.update(cx, |cloud_user_store, cx| {
|
||||||
user_store.update_edit_prediction_usage(usage, cx);
|
cloud_user_store.update_edit_prediction_usage(usage, cx);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
|
@ -894,8 +876,8 @@ and then another
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
|
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.user_store.update(cx, |user_store, cx| {
|
this.cloud_user_store.update(cx, |cloud_user_store, cx| {
|
||||||
user_store.update_edit_prediction_usage(usage, cx);
|
cloud_user_store.update_edit_prediction_usage(usage, cx);
|
||||||
});
|
});
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
@ -1573,7 +1555,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
fn needs_terms_acceptance(&self, cx: &App) -> bool {
|
fn needs_terms_acceptance(&self, cx: &App) -> bool {
|
||||||
!self.zeta.read(cx).tos_accepted
|
!self
|
||||||
|
.zeta
|
||||||
|
.read(cx)
|
||||||
|
.cloud_user_store
|
||||||
|
.read(cx)
|
||||||
|
.has_accepted_tos()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_refreshing(&self) -> bool {
|
fn is_refreshing(&self) -> bool {
|
||||||
|
@ -1588,7 +1575,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
|
||||||
_debounce: bool,
|
_debounce: bool,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
if !self.zeta.read(cx).tos_accepted {
|
if self.needs_terms_acceptance(cx) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1599,9 +1586,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
|
||||||
if self
|
if self
|
||||||
.zeta
|
.zeta
|
||||||
.read(cx)
|
.read(cx)
|
||||||
.user_store
|
.cloud_user_store
|
||||||
.read_with(cx, |user_store, _| {
|
.read_with(cx, |cloud_user_store, _cx| {
|
||||||
user_store.account_too_young() || user_store.has_overdue_invoices()
|
cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices()
|
||||||
})
|
})
|
||||||
{
|
{
|
||||||
return;
|
return;
|
||||||
|
@ -1819,15 +1806,51 @@ fn tokens_for_bytes(bytes: usize) -> usize {
|
||||||
mod tests {
|
mod tests {
|
||||||
use client::test::FakeServer;
|
use client::test::FakeServer;
|
||||||
use clock::FakeSystemClock;
|
use clock::FakeSystemClock;
|
||||||
|
use cloud_api_types::{
|
||||||
|
AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo,
|
||||||
|
};
|
||||||
|
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
use http_client::FakeHttpClient;
|
use http_client::FakeHttpClient;
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language::Point;
|
use language::Point;
|
||||||
use rpc::proto;
|
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse {
|
||||||
|
GetAuthenticatedUserResponse {
|
||||||
|
user: AuthenticatedUser {
|
||||||
|
id: 1,
|
||||||
|
metrics_id: "metrics-id-1".to_string(),
|
||||||
|
avatar_url: "".to_string(),
|
||||||
|
github_login: "".to_string(),
|
||||||
|
name: None,
|
||||||
|
is_staff: false,
|
||||||
|
accepted_tos_at: None,
|
||||||
|
},
|
||||||
|
feature_flags: vec![],
|
||||||
|
plan: PlanInfo {
|
||||||
|
plan: Plan::ZedPro,
|
||||||
|
subscription_period: None,
|
||||||
|
usage: CurrentUsage {
|
||||||
|
model_requests: UsageData {
|
||||||
|
used: 0,
|
||||||
|
limit: UsageLimit::Limited(500),
|
||||||
|
},
|
||||||
|
edit_predictions: UsageData {
|
||||||
|
used: 250,
|
||||||
|
limit: UsageLimit::Unlimited,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
trial_started_at: None,
|
||||||
|
is_usage_based_billing_enabled: false,
|
||||||
|
is_account_too_young: false,
|
||||||
|
has_overdue_invoices: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
|
async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
|
||||||
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
|
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
|
||||||
|
@ -2027,8 +2050,27 @@ mod tests {
|
||||||
<|editable_region_end|>
|
<|editable_region_end|>
|
||||||
```"};
|
```"};
|
||||||
|
|
||||||
let http_client = FakeHttpClient::create(move |_| async move {
|
let http_client = FakeHttpClient::create(move |req| async move {
|
||||||
Ok(http_client::Response::builder()
|
match (req.method(), req.uri().path()) {
|
||||||
|
(&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&make_get_authenticated_user_response())
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.unwrap()),
|
||||||
|
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&CreateLlmTokenResponse {
|
||||||
|
token: LlmToken("the-llm-token".to_string()),
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.unwrap()),
|
||||||
|
(&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
|
||||||
.status(200)
|
.status(200)
|
||||||
.body(
|
.body(
|
||||||
serde_json::to_string(&PredictEditsResponse {
|
serde_json::to_string(&PredictEditsResponse {
|
||||||
|
@ -2039,16 +2081,24 @@ mod tests {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into(),
|
.into(),
|
||||||
)
|
)
|
||||||
.unwrap())
|
.unwrap()),
|
||||||
|
_ => Ok(http_client::Response::builder()
|
||||||
|
.status(404)
|
||||||
|
.body("Not Found".into())
|
||||||
|
.unwrap()),
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
RefreshLlmTokenListener::register(client.clone(), cx);
|
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||||
});
|
});
|
||||||
let server = FakeServer::for_client(42, &client, cx).await;
|
// Construct the fake server to authenticate.
|
||||||
|
let _server = FakeServer::for_client(42, &client, cx).await;
|
||||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||||
let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
|
let cloud_user_store =
|
||||||
|
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
|
||||||
|
let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
|
||||||
|
|
||||||
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
|
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
|
||||||
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
|
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
|
||||||
|
@ -2056,13 +2106,6 @@ mod tests {
|
||||||
zeta.request_completion(None, &buffer, cursor, false, cx)
|
zeta.request_completion(None, &buffer, cursor, false, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
server.receive::<proto::GetUsers>().await.unwrap();
|
|
||||||
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
|
|
||||||
server.respond(
|
|
||||||
token_request.receipt(),
|
|
||||||
proto::GetLlmTokenResponse { token: "".into() },
|
|
||||||
);
|
|
||||||
|
|
||||||
let completion = completion_task.await.unwrap().unwrap();
|
let completion = completion_task.await.unwrap().unwrap();
|
||||||
buffer.update(cx, |buffer, cx| {
|
buffer.update(cx, |buffer, cx| {
|
||||||
buffer.edit(completion.edits.iter().cloned(), None, cx)
|
buffer.edit(completion.edits.iter().cloned(), None, cx)
|
||||||
|
@ -2079,10 +2122,29 @@ mod tests {
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
) -> Vec<(Range<Point>, String)> {
|
) -> Vec<(Range<Point>, String)> {
|
||||||
let completion_response = completion_response.to_string();
|
let completion_response = completion_response.to_string();
|
||||||
let http_client = FakeHttpClient::create(move |_| {
|
let http_client = FakeHttpClient::create(move |req| {
|
||||||
let completion = completion_response.clone();
|
let completion = completion_response.clone();
|
||||||
async move {
|
async move {
|
||||||
Ok(http_client::Response::builder()
|
match (req.method(), req.uri().path()) {
|
||||||
|
(&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&make_get_authenticated_user_response())
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.unwrap()),
|
||||||
|
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&CreateLlmTokenResponse {
|
||||||
|
token: LlmToken("the-llm-token".to_string()),
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.unwrap()),
|
||||||
|
(&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
|
||||||
.status(200)
|
.status(200)
|
||||||
.body(
|
.body(
|
||||||
serde_json::to_string(&PredictEditsResponse {
|
serde_json::to_string(&PredictEditsResponse {
|
||||||
|
@ -2092,7 +2154,12 @@ mod tests {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into(),
|
.into(),
|
||||||
)
|
)
|
||||||
.unwrap())
|
.unwrap()),
|
||||||
|
_ => Ok(http_client::Response::builder()
|
||||||
|
.status(404)
|
||||||
|
.body("Not Found".into())
|
||||||
|
.unwrap()),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -2100,9 +2167,12 @@ mod tests {
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
RefreshLlmTokenListener::register(client.clone(), cx);
|
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||||
});
|
});
|
||||||
let server = FakeServer::for_client(42, &client, cx).await;
|
// Construct the fake server to authenticate.
|
||||||
|
let _server = FakeServer::for_client(42, &client, cx).await;
|
||||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||||
let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
|
let cloud_user_store =
|
||||||
|
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
|
||||||
|
let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
|
||||||
|
|
||||||
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
|
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
|
||||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||||
|
@ -2111,13 +2181,6 @@ mod tests {
|
||||||
zeta.request_completion(None, &buffer, cursor, false, cx)
|
zeta.request_completion(None, &buffer, cursor, false, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
server.receive::<proto::GetUsers>().await.unwrap();
|
|
||||||
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
|
|
||||||
server.respond(
|
|
||||||
token_request.receipt(),
|
|
||||||
proto::GetLlmTokenResponse { token: "".into() },
|
|
||||||
);
|
|
||||||
|
|
||||||
let completion = completion_task.await.unwrap().unwrap();
|
let completion = completion_task.await.unwrap().unwrap();
|
||||||
completion
|
completion
|
||||||
.edits
|
.edits
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue