zeta: Respect x-zed-minimum-required-version header (#24771)

This PR makes it so Zeta respects the `x-zed-minimum-required-version`
header sent back from the server.

If the current Zed version is strictly less than the indicated minimum
required version, we show an error indicating that an update is required
in order to continue using Zeta:

<img width="472" alt="Screenshot 2025-02-12 at 6 15 44 PM"
src="https://github.com/user-attachments/assets/51b85dff-23a0-464c-ae4b-5b8f46b5915c"
/>

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-12 18:58:38 -05:00 committed by GitHub
parent 0e42a69490
commit 277fb54632
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 132 additions and 19 deletions

6
Cargo.lock generated
View file

@ -16852,9 +16852,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.4.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614669bead4741b2fc352ae1967318be16949cf46f59013e548c6dbfdfc01252"
checksum = "1bf21350eced858d129840589158a8f6895c4fa4327ae56dd8c7d6a98495bed4"
dependencies = [
"serde",
"serde_json",
@ -17076,6 +17076,7 @@ dependencies = [
"postage",
"project",
"regex",
"release_channel",
"reqwest_client",
"rpc",
"serde",
@ -17085,6 +17086,7 @@ dependencies = [
"telemetry",
"telemetry_events",
"theme",
"thiserror 1.0.69",
"tree-sitter-go",
"tree-sitter-rust",
"ui",

View file

@ -264,7 +264,13 @@ fn assign_edit_prediction_provider(
}
}
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
let zeta = zeta::Zeta::register(
Some(cx.entity()),
worktree,
client.clone(),
user_store,
cx,
);
if let Some(buffer) = &singleton_buffer {
if buffer.read(cx).file().is_some() {

View file

@ -39,6 +39,7 @@ menu.workspace = true
postage.workspace = true
project.workspace = true
regex.workspace = true
release_channel.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@ -46,6 +47,7 @@ similar.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
theme.workspace = true
thiserror.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true

View file

@ -9,6 +9,7 @@ mod rate_completion_modal;
pub(crate) use completion_diff_element::*;
use db::kvp::KEY_VALUE_STORE;
use editor::Editor;
pub use init::*;
use inline_completion::DataCollectionState;
pub use license_detection::is_license_eligible_for_data_collection;
@ -20,10 +21,10 @@ use anyhow::{anyhow, Context as _, Result};
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use collections::{HashMap, HashSet, VecDeque};
use feature_flags::FeatureFlagAppExt as _;
use futures::AsyncReadExt;
use gpui::{
actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
Subscription, Task,
};
use http_client::{HttpClient, Method};
use input_excerpt::excerpt_for_cursor_position;
@ -34,7 +35,9 @@ use language::{
use language_models::LlmApiToken;
use postage::watch;
use project::Project;
use release_channel::AppVersion;
use settings::WorktreeId;
use std::str::FromStr;
use std::{
borrow::Cow,
cmp,
@ -48,10 +51,16 @@ use std::{
time::{Duration, Instant},
};
use telemetry_events::InlineCompletionRating;
use thiserror::Error;
use util::ResultExt;
use uuid::Uuid;
use workspace::notifications::{ErrorMessagePrompt, NotificationId};
use workspace::Workspace;
use worktree::Worktree;
use zed_llm_client::{PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
use zed_llm_client::{
PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
MINIMUM_REQUIRED_VERSION_HEADER_NAME,
};
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
@ -178,6 +187,7 @@ impl std::fmt::Debug for InlineCompletion {
}
pub struct Zeta {
editor: Option<Entity<Editor>>,
client: Arc<Client>,
events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@ -188,6 +198,8 @@ pub struct Zeta {
_llm_token_subscription: Subscription,
/// Whether the terms of service have been accepted.
tos_accepted: bool,
/// Whether an update to a newer version of Zed is required to continue using Zeta.
update_required: bool,
_user_store_subscription: Subscription,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
@ -198,13 +210,14 @@ impl Zeta {
}
pub fn register(
editor: Option<Entity<Editor>>,
worktree: Option<Entity<Worktree>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
cx: &mut App,
) -> Entity<Self> {
let this = Self::global(cx).unwrap_or_else(|| {
let entity = cx.new(|cx| Self::new(client, user_store, cx));
let entity = cx.new(|cx| Self::new(editor, client, user_store, cx));
cx.set_global(ZetaGlobal(entity.clone()));
entity
});
@ -226,13 +239,19 @@ impl Zeta {
self.events.clear();
}
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
fn new(
editor: Option<Entity<Editor>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choices();
let data_collection_choice = cx.new(|_| data_collection_choice);
Self {
editor,
client,
events: VecDeque::new(),
shown_completions: VecDeque::new(),
@ -256,6 +275,7 @@ impl Zeta {
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false),
update_required: false,
_user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
match event {
client::user::Event::PrivateUserInfoUpdated => {
@ -335,8 +355,10 @@ impl Zeta {
}
}
pub fn request_completion_impl<F, R>(
#[allow(clippy::too_many_arguments)]
fn request_completion_impl<F, R>(
&mut self,
workspace: Option<Entity<Workspace>>,
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
@ -345,7 +367,7 @@ impl Zeta {
perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
where
F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsBody) -> R + 'static,
F: FnOnce(PerformPredictEditsParams) -> R + 'static,
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
let snapshot = self.report_changes_for_buffer(&buffer, cx);
@ -358,9 +380,10 @@ impl Zeta {
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
let zeta = cx.entity();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let is_staff = cx.is_staff();
let app_version = AppVersion::global(cx);
let buffer = buffer.clone();
@ -447,7 +470,46 @@ impl Zeta {
}),
};
let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
let response = perform_predict_edits(PerformPredictEditsParams {
client,
llm_token,
app_version,
body,
})
.await;
let response = match response {
Ok(response) => response,
Err(err) => {
if err.is::<ZedUpdateRequiredError>() {
cx.update(|cx| {
zeta.update(cx, |zeta, _cx| {
zeta.update_required = true;
});
if let Some(workspace) = workspace {
workspace.update(cx, |workspace, cx| {
workspace.show_notification(
NotificationId::unique::<ZedUpdateRequiredError>(),
cx,
|cx| {
cx.new(|_| {
ErrorMessagePrompt::new(err.to_string())
.with_link_button(
"Update Zed",
"https://zed.dev/releases",
)
})
},
);
});
}
})
.ok();
}
return Err(err);
}
};
log::debug!("completion response: {}", &response.output_excerpt);
@ -632,7 +694,7 @@ and then another
) -> Task<Result<Option<InlineCompletion>>> {
use std::future::ready;
self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| {
self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
ready(Ok(response))
})
}
@ -645,7 +707,12 @@ and then another
can_collect_data: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
let workspace = self
.editor
.as_ref()
.and_then(|editor| editor.read(cx).workspace());
self.request_completion_impl(
workspace,
project,
buffer,
position,
@ -656,12 +723,17 @@ and then another
}
fn perform_predict_edits(
client: Arc<Client>,
llm_token: LlmApiToken,
_is_staff: bool,
body: PredictEditsBody,
params: PerformPredictEditsParams,
) -> impl Future<Output = Result<PredictEditsResponse>> {
async move {
let PerformPredictEditsParams {
client,
llm_token,
app_version,
body,
..
} = params;
let http_client = client.http_client();
let mut token = llm_token.acquire(&client).await?;
let mut did_retry = false;
@ -685,6 +757,18 @@ and then another
let mut response = http_client.send(request).await?;
if let Some(minimum_required_version) = response
.headers()
.get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
.and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
{
if app_version < minimum_required_version {
return Err(anyhow!(ZedUpdateRequiredError {
minimum_version: minimum_required_version
}));
}
}
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@ -1011,6 +1095,21 @@ and then another
}
}
struct PerformPredictEditsParams {
pub client: Arc<Client>,
pub llm_token: LlmApiToken,
pub app_version: SemanticVersion,
pub body: PredictEditsBody,
}
#[derive(Error, Debug)]
#[error(
"You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
minimum_version: SemanticVersion,
}
struct LicenseDetectionWatcher {
is_open_source_rx: watch::Receiver<bool>,
_is_open_source_task: Task<()>,
@ -1406,6 +1505,10 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
return;
}
if self.zeta.read(cx).update_required {
return;
}
if let Some(current_completion) = self.current_completion.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if current_completion
@ -1837,7 +1940,7 @@ mod tests {
});
let server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
@ -1890,7 +1993,7 @@ mod tests {
});
let server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());