zeta: Add CLI tool for querying edit predictions and related context (#35491)

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Michael Sloan 2025-08-01 15:08:09 -06:00 committed by GitHub
parent 561ccf86aa
commit 6052115825
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 719 additions and 102 deletions

View file

@ -0,0 +1,128 @@
use client::{Client, ProxySettings, UserStore};
use extension::ExtensionHostProxy;
use fs::RealFs;
use gpui::http_client::read_proxy_from_env;
use gpui::{App, AppContext, Entity};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_extension::LspAccess;
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::path::PathBuf;
use std::sync::Arc;
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
pub struct ZetaCliAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
}
// TODO: dedupe with crates/eval/src/eval.rs
pub fn init(cx: &mut App) -> ZetaCliAppState {
let app_version = AppVersion::load(env!("ZED_PKG_VERSION"));
release_channel::init(app_version, cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
// Set User-Agent so we can download language servers from GitHub
let user_agent = format!(
"Zed/{} ({}; {})",
app_version,
std::env::consts::OS,
std::env::consts::ARCH
);
let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
let proxy_url = proxy_str
.as_ref()
.and_then(|input| input.parse().ok())
.or_else(read_proxy_from_env);
let http = {
let _guard = Tokio::handle(cx).enter();
ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
.expect("could not start HTTP client")
};
cx.set_http_client(Arc::new(http));
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let mut languages = LanguageRegistry::new(cx.background_executor().clone());
languages.set_language_server_download_dir(paths::languages_dir().clone());
let languages = Arc::new(languages);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
extension::init(cx);
let (mut tx, rx) = watch::channel(None);
cx.observe_global::<SettingsStore>(move |cx| {
let settings = &ProjectSettings::get_global(cx).node;
let options = NodeBinaryOptions {
allow_path_lookup: !settings.ignore_system_version,
allow_binary_download: true,
use_paths: settings.path.as_ref().map(|node_path| {
let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
let npm_path = settings
.npm_path
.as_ref()
.map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
(
node_path.clone(),
npm_path.unwrap_or_else(|| {
let base_path = PathBuf::new();
node_path.parent().unwrap_or(&base_path).join("npm")
}),
)
}),
};
tx.send(Some(options)).log_err();
})
.detach();
let node_runtime = NodeRuntime::new(client.http_client(), None, rx);
let extension_host_proxy = ExtensionHostProxy::global(cx);
language::init(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(
LspAccess::Noop,
extension_host_proxy.clone(),
languages.clone(),
);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
terminal_view::init(cx);
ZetaCliAppState {
languages,
client,
user_store,
fs,
node_runtime,
}
}

376
crates/zeta_cli/src/main.rs Normal file
View file

@ -0,0 +1,376 @@
mod headless;
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, Application, AsyncApp};
use gpui::{Entity, Task};
use language::Bias;
use language::Buffer;
use language::Point;
use language_model::LlmApiToken;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use std::path::{Path, PathBuf};
use std::process::exit;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
use crate::headless::ZetaCliAppState;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand, Debug)]
enum Commands {
Context(ContextArgs),
Predict {
#[arg(long)]
predict_edits_body: Option<FileOrStdin>,
#[clap(flatten)]
context_args: Option<ContextArgs>,
},
}
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
#[arg(long)]
worktree: PathBuf,
#[arg(long)]
cursor: CursorPosition,
#[arg(long)]
use_language_server: bool,
#[arg(long)]
events: Option<FileOrStdin>,
}
#[derive(Debug, Clone)]
enum FileOrStdin {
File(PathBuf),
Stdin,
}
impl FileOrStdin {
async fn read_to_string(&self) -> Result<String, std::io::Error> {
match self {
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
}
}
}
impl FromStr for FileOrStdin {
type Err = <PathBuf as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"-" => Ok(Self::Stdin),
_ => Ok(Self::File(PathBuf::from_str(s)?)),
}
}
}
#[derive(Debug, Clone)]
struct CursorPosition {
path: PathBuf,
point: Point,
}
impl FromStr for CursorPosition {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
let parts: Vec<&str> = s.split(':').collect();
if parts.len() != 3 {
return Err(anyhow!(
"Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
s
));
}
let path = PathBuf::from(parts[0]);
let line: u32 = parts[1]
.parse()
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
let column: u32 = parts[2]
.parse()
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
// Convert from 1-based to 0-based indexing
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
Ok(CursorPosition { path, point })
}
}
async fn get_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<GatherContextOutput> {
let ContextArgs {
worktree: worktree_path,
cursor,
use_language_server,
events,
} = args;
let worktree_path = worktree_path.canonicalize()?;
if cursor.path.is_absolute() {
return Err(anyhow!("Absolute paths are not supported in --cursor"));
}
let (project, _lsp_open_handle, buffer) = if use_language_server {
let (project, lsp_open_handle, buffer) =
open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?;
(Some(project), Some(lsp_open_handle), buffer)
} else {
let abs_path = worktree_path.join(&cursor.path);
let content = smol::fs::read_to_string(&abs_path).await?;
let buffer = cx.new(|cx| Buffer::local(content, cx))?;
(None, None, buffer)
};
let worktree_name = worktree_path
.file_name()
.ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
let full_path_str = PathBuf::from(worktree_name)
.join(&cursor.path)
.to_string_lossy()
.to_string();
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
if clipped_cursor != cursor.point {
let max_row = snapshot.max_point().row;
if cursor.point.row < max_row {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (line length is {})",
cursor.point,
snapshot.line_len(cursor.point.row)
));
} else {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (max row is {})",
cursor.point,
max_row
));
}
}
let events = match events {
Some(events) => events.read_to_string().await?,
None => String::new(),
};
let can_collect_data = false;
cx.update(|cx| {
gather_context(
project.as_ref(),
full_path_str,
&snapshot,
clipped_cursor,
move || events,
can_collect_data,
cx,
)
})?
.await
}
pub async fn open_buffer_with_language_server(
worktree_path: &Path,
path: &Path,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> {
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(worktree_path, true, cx)
})?
.await?;
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
worktree_id: worktree.id(),
path: path.to_path_buf().into(),
})?;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let lsp_open_handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
})?;
let log_prefix = path.to_string_lossy().to_string();
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
Ok((project, lsp_open_handle, buffer))
}
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
log_prefix: String,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
println!("{}⏵ Waiting for language server", log_prefix);
let (mut tx, mut rx) = mpsc::channel(1);
let lsp_store = project
.read_with(cx, |project, _| project.lsp_store())
.unwrap();
let has_lang_server = buffer
.update(cx, |buffer, cx| {
lsp_store.update(cx, |lsp_store, cx| {
lsp_store
.language_servers_for_local_buffer(&buffer, cx)
.next()
.is_some()
})
})
.unwrap_or(false);
if has_lang_server {
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.unwrap()
.detach();
}
let subscriptions = [
cx.subscribe(&lsp_store, {
let log_prefix = log_prefix.clone();
move |_, event, _| match event {
project::LspStoreEvent::LanguageServerUpdate {
message:
client::proto::update_language_server::Variant::WorkProgress(
client::proto::LspWorkProgress {
message: Some(message),
..
},
),
..
} => println!("{}{message}", log_prefix),
_ => {}
}
}),
cx.subscribe(&project, {
let buffer = buffer.clone();
move |project, event, cx| match event {
project::Event::LanguageServerAdded(_, _, _) => {
let buffer = buffer.clone();
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))
.detach();
}
project::Event::DiskBasedDiagnosticsFinished { .. } => {
tx.try_send(()).ok();
}
_ => {}
}
}),
];
cx.spawn(async move |cx| {
let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
let result = futures::select! {
_ = rx.next() => {
println!("{}⚑ Language server idle", log_prefix);
anyhow::Ok(())
},
_ = timeout.fuse() => {
anyhow::bail!("LSP wait timed out after 5 minutes");
}
};
drop(subscriptions);
result
})
}
fn main() {
let args = ZetaCliArgs::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
cx.spawn(async move |cx| {
let result = match args.command {
Commands::Context(context_args) => get_context(context_args, &app_state, cx)
.await
.map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
Commands::Predict {
predict_edits_body,
context_args,
} => {
cx.spawn(async move |cx| {
let app_version = cx.update(|cx| AppVersion::global(cx))?;
app_state.client.sign_in(true, cx).await?;
let llm_token = LlmApiToken::default();
llm_token.refresh(&app_state.client).await?;
let predict_edits_body =
if let Some(predict_edits_body) = predict_edits_body {
serde_json::from_str(&predict_edits_body.read_to_string().await?)?
} else if let Some(context_args) = context_args {
get_context(context_args, &app_state, cx).await?.body
} else {
return Err(anyhow!(
"Expected either --predict-edits-body-file \
or the required args of the `context` command."
));
};
let (response, _usage) =
Zeta::perform_predict_edits(PerformPredictEditsParams {
client: app_state.client.clone(),
llm_token,
app_version,
body: predict_edits_body,
})
.await?;
Ok(response.output_excerpt)
})
.await
}
};
match result {
Ok(output) => {
println!("{}", output);
let _ = cx.update(|cx| cx.quit());
}
Err(e) => {
eprintln!("Failed: {:?}", e);
exit(1);
}
}
})
.detach();
});
}