zeta: Add CLI tool for querying edit predictions and related context (#35491)
Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
561ccf86aa
commit
6052115825
9 changed files with 719 additions and 102 deletions
36
Cargo.lock
generated
36
Cargo.lock
generated
|
@ -20606,6 +20606,42 @@ dependencies = [
|
|||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeta_cli"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"client",
|
||||
"debug_adapter_extension",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"language",
|
||||
"language_extension",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"languages",
|
||||
"node_runtime",
|
||||
"paths",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"shellexpand 2.1.2",
|
||||
"smol",
|
||||
"terminal_view",
|
||||
"util",
|
||||
"watch",
|
||||
"workspace-hack",
|
||||
"zeta",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.6"
|
||||
|
|
|
@ -189,6 +189,7 @@ members = [
|
|||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zeta",
|
||||
"crates/zeta_cli",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ use collections::{HashMap, HashSet};
|
|||
use extension::ExtensionHostProxy;
|
||||
use futures::future;
|
||||
use gpui::http_client::read_proxy_from_env;
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
|
||||
|
@ -337,7 +337,8 @@ pub struct AgentAppState {
|
|||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
let app_version = AppVersion::global(cx);
|
||||
release_channel::init(app_version, cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
|
@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
// Set User-Agent so we can download language servers from GitHub
|
||||
let user_agent = format!(
|
||||
"Zed/{} ({}; {})",
|
||||
AppVersion::global(cx),
|
||||
app_version,
|
||||
std::env::consts::OS,
|
||||
std::env::consts::ARCH
|
||||
);
|
||||
|
|
|
@ -146,14 +146,14 @@ pub struct InlineCompletion {
|
|||
input_events: Arc<str>,
|
||||
input_excerpt: Arc<str>,
|
||||
output_excerpt: Arc<str>,
|
||||
request_sent_at: Instant,
|
||||
buffer_snapshotted_at: Instant,
|
||||
response_received_at: Instant,
|
||||
}
|
||||
|
||||
impl InlineCompletion {
|
||||
fn latency(&self) -> Duration {
|
||||
self.response_received_at
|
||||
.duration_since(self.request_sent_at)
|
||||
.duration_since(self.buffer_snapshotted_at)
|
||||
}
|
||||
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
|
@ -391,104 +391,48 @@ impl Zeta {
|
|||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let snapshot = self.report_changes_for_buffer(&buffer, cx);
|
||||
let diagnostic_groups = snapshot.diagnostic_groups(None);
|
||||
let cursor_point = cursor.to_point(&snapshot);
|
||||
let cursor_offset = cursor_point.to_offset(&snapshot);
|
||||
let events = self.events.clone();
|
||||
let path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|f| Arc::from(f.full_path(cx).as_path()))
|
||||
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
|
||||
|
||||
let zeta = cx.entity();
|
||||
let events = self.events.clone();
|
||||
let client = self.client.clone();
|
||||
let llm_token = self.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let buffer = buffer.clone();
|
||||
|
||||
let local_lsp_store =
|
||||
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
|
||||
let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
|
||||
Some(
|
||||
diagnostic_groups
|
||||
.into_iter()
|
||||
.filter_map(|(language_server_id, diagnostic_group)| {
|
||||
let language_server =
|
||||
local_lsp_store.running_language_server_for_id(language_server_id)?;
|
||||
|
||||
Some((
|
||||
language_server.name(),
|
||||
diagnostic_group.resolve::<usize>(&snapshot),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|f| Arc::from(f.full_path(cx).as_path()))
|
||||
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
|
||||
let full_path_str = full_path.to_string_lossy().to_string();
|
||||
let cursor_point = cursor.to_point(&snapshot);
|
||||
let cursor_offset = cursor_point.to_offset(&snapshot);
|
||||
let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
|
||||
let gather_task = gather_context(
|
||||
project,
|
||||
full_path_str,
|
||||
&snapshot,
|
||||
cursor_point,
|
||||
make_events_prompt,
|
||||
can_collect_data,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let request_sent_at = Instant::now();
|
||||
|
||||
struct BackgroundValues {
|
||||
input_events: String,
|
||||
input_excerpt: String,
|
||||
speculated_output: String,
|
||||
editable_range: Range<usize>,
|
||||
input_outline: String,
|
||||
}
|
||||
|
||||
let values = cx
|
||||
.background_spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let path = path.to_string_lossy();
|
||||
let input_excerpt = excerpt_for_cursor_position(
|
||||
cursor_point,
|
||||
&path,
|
||||
&snapshot,
|
||||
MAX_REWRITE_TOKENS,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
);
|
||||
let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
|
||||
let input_outline = prompt_for_outline(&snapshot);
|
||||
|
||||
anyhow::Ok(BackgroundValues {
|
||||
input_events,
|
||||
input_excerpt: input_excerpt.prompt,
|
||||
speculated_output: input_excerpt.speculated_output,
|
||||
editable_range: input_excerpt.editable_range.to_offset(&snapshot),
|
||||
input_outline,
|
||||
})
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
let GatherContextOutput {
|
||||
body,
|
||||
editable_range,
|
||||
} = gather_task.await?;
|
||||
|
||||
log::debug!(
|
||||
"Events:\n{}\nExcerpt:\n{:?}",
|
||||
values.input_events,
|
||||
values.input_excerpt
|
||||
body.input_events,
|
||||
body.input_excerpt
|
||||
);
|
||||
|
||||
let body = PredictEditsBody {
|
||||
input_events: values.input_events.clone(),
|
||||
input_excerpt: values.input_excerpt.clone(),
|
||||
speculated_output: Some(values.speculated_output),
|
||||
outline: Some(values.input_outline.clone()),
|
||||
can_collect_data,
|
||||
diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
|
||||
diagnostic_groups
|
||||
.into_iter()
|
||||
.map(|(name, diagnostic_group)| {
|
||||
Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()
|
||||
.log_err()
|
||||
}),
|
||||
};
|
||||
let input_outline = body.outline.clone().unwrap_or_default();
|
||||
let input_events = body.input_events.clone();
|
||||
let input_excerpt = body.input_excerpt.clone();
|
||||
|
||||
let response = perform_predict_edits(PerformPredictEditsParams {
|
||||
client,
|
||||
|
@ -546,13 +490,13 @@ impl Zeta {
|
|||
response,
|
||||
buffer,
|
||||
&snapshot,
|
||||
values.editable_range,
|
||||
editable_range,
|
||||
cursor_offset,
|
||||
path,
|
||||
values.input_outline,
|
||||
values.input_events,
|
||||
values.input_excerpt,
|
||||
request_sent_at,
|
||||
full_path,
|
||||
input_outline,
|
||||
input_events,
|
||||
input_excerpt,
|
||||
buffer_snapshotted_at,
|
||||
&cx,
|
||||
)
|
||||
.await
|
||||
|
@ -751,7 +695,7 @@ and then another
|
|||
)
|
||||
}
|
||||
|
||||
fn perform_predict_edits(
|
||||
pub fn perform_predict_edits(
|
||||
params: PerformPredictEditsParams,
|
||||
) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
|
||||
async move {
|
||||
|
@ -906,7 +850,7 @@ and then another
|
|||
input_outline: String,
|
||||
input_events: String,
|
||||
input_excerpt: String,
|
||||
request_sent_at: Instant,
|
||||
buffer_snapshotted_at: Instant,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<InlineCompletion>>> {
|
||||
let snapshot = snapshot.clone();
|
||||
|
@ -952,7 +896,7 @@ and then another
|
|||
input_events: input_events.into(),
|
||||
input_excerpt: input_excerpt.into(),
|
||||
output_excerpt,
|
||||
request_sent_at,
|
||||
buffer_snapshotted_at,
|
||||
response_received_at: Instant::now(),
|
||||
}))
|
||||
})
|
||||
|
@ -1136,7 +1080,7 @@ and then another
|
|||
}
|
||||
}
|
||||
|
||||
struct PerformPredictEditsParams {
|
||||
pub struct PerformPredictEditsParams {
|
||||
pub client: Arc<Client>,
|
||||
pub llm_token: LlmApiToken,
|
||||
pub app_version: SemanticVersion,
|
||||
|
@ -1211,6 +1155,77 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
|
|||
.sum()
|
||||
}
|
||||
|
||||
pub struct GatherContextOutput {
|
||||
pub body: PredictEditsBody,
|
||||
pub editable_range: Range<usize>,
|
||||
}
|
||||
|
||||
pub fn gather_context(
|
||||
project: Option<&Entity<Project>>,
|
||||
full_path_str: String,
|
||||
snapshot: &BufferSnapshot,
|
||||
cursor_point: language::Point,
|
||||
make_events_prompt: impl FnOnce() -> String + Send + 'static,
|
||||
can_collect_data: bool,
|
||||
cx: &App,
|
||||
) -> Task<Result<GatherContextOutput>> {
|
||||
let local_lsp_store =
|
||||
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
|
||||
let diagnostic_groups: Vec<(String, serde_json::Value)> =
|
||||
if let Some(local_lsp_store) = local_lsp_store {
|
||||
snapshot
|
||||
.diagnostic_groups(None)
|
||||
.into_iter()
|
||||
.filter_map(|(language_server_id, diagnostic_group)| {
|
||||
let language_server =
|
||||
local_lsp_store.running_language_server_for_id(language_server_id)?;
|
||||
let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot);
|
||||
let language_server_name = language_server.name().to_string();
|
||||
let serialized = serde_json::to_value(diagnostic_group).unwrap();
|
||||
Some((language_server_name, serialized))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
cx.background_spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
async move {
|
||||
let diagnostic_groups = if diagnostic_groups.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(diagnostic_groups)
|
||||
};
|
||||
|
||||
let input_excerpt = excerpt_for_cursor_position(
|
||||
cursor_point,
|
||||
&full_path_str,
|
||||
&snapshot,
|
||||
MAX_REWRITE_TOKENS,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
);
|
||||
let input_events = make_events_prompt();
|
||||
let input_outline = prompt_for_outline(&snapshot);
|
||||
let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
|
||||
|
||||
let body = PredictEditsBody {
|
||||
input_events,
|
||||
input_excerpt: input_excerpt.prompt,
|
||||
speculated_output: Some(input_excerpt.speculated_output),
|
||||
outline: Some(input_outline),
|
||||
can_collect_data,
|
||||
diagnostic_groups,
|
||||
};
|
||||
|
||||
Ok(GatherContextOutput {
|
||||
body,
|
||||
editable_range,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
|
||||
let mut input_outline = String::new();
|
||||
|
||||
|
@ -1261,7 +1276,7 @@ struct RegisteredBuffer {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Event {
|
||||
pub enum Event {
|
||||
BufferChange {
|
||||
old_snapshot: BufferSnapshot,
|
||||
new_snapshot: BufferSnapshot,
|
||||
|
@ -1845,7 +1860,7 @@ mod tests {
|
|||
input_events: "".into(),
|
||||
input_excerpt: "".into(),
|
||||
output_excerpt: "".into(),
|
||||
request_sent_at: Instant::now(),
|
||||
buffer_snapshotted_at: Instant::now(),
|
||||
response_received_at: Instant::now(),
|
||||
};
|
||||
|
||||
|
|
45
crates/zeta_cli/Cargo.toml
Normal file
45
crates/zeta_cli/Cargo.toml
Normal file
|
@ -0,0 +1,45 @@
|
|||
[package]
|
||||
name = "zeta_cli"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "zeta"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
language.workspace = true
|
||||
language_extension.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
languages = { workspace = true, features = ["load-grammars"] }
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
release_channel.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
shellexpand.workspace = true
|
||||
terminal_view.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zeta.workspace = true
|
||||
smol.workspace = true
|
1
crates/zeta_cli/LICENSE-GPL
Symbolic link
1
crates/zeta_cli/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
14
crates/zeta_cli/build.rs
Normal file
14
crates/zeta_cli/build.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
fn main() {
|
||||
let cargo_toml =
|
||||
std::fs::read_to_string("../zed/Cargo.toml").expect("Failed to read Cargo.toml");
|
||||
let version = cargo_toml
|
||||
.lines()
|
||||
.find(|line| line.starts_with("version = "))
|
||||
.expect("Version not found in crates/zed/Cargo.toml")
|
||||
.split('=')
|
||||
.nth(1)
|
||||
.expect("Invalid version format")
|
||||
.trim()
|
||||
.trim_matches('"');
|
||||
println!("cargo:rustc-env=ZED_PKG_VERSION={}", version);
|
||||
}
|
128
crates/zeta_cli/src/headless.rs
Normal file
128
crates/zeta_cli/src/headless.rs
Normal file
|
@ -0,0 +1,128 @@
|
|||
use client::{Client, ProxySettings, UserStore};
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::RealFs;
|
||||
use gpui::http_client::read_proxy_from_env;
|
||||
use gpui::{App, AppContext, Entity};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_extension::LspAccess;
|
||||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use project::Project;
|
||||
use project::project_settings::ProjectSettings;
|
||||
use release_channel::AppVersion;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
/// Headless subset of `workspace::AppState`.
|
||||
pub struct ZetaCliAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
}
|
||||
|
||||
// TODO: dedupe with crates/eval/src/eval.rs
|
||||
pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
let app_version = AppVersion::load(env!("ZED_PKG_VERSION"));
|
||||
release_channel::init(app_version, cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
settings_store
|
||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||
.unwrap();
|
||||
cx.set_global(settings_store);
|
||||
client::init_settings(cx);
|
||||
|
||||
// Set User-Agent so we can download language servers from GitHub
|
||||
let user_agent = format!(
|
||||
"Zed/{} ({}; {})",
|
||||
app_version,
|
||||
std::env::consts::OS,
|
||||
std::env::consts::ARCH
|
||||
);
|
||||
let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
|
||||
let proxy_url = proxy_str
|
||||
.as_ref()
|
||||
.and_then(|input| input.parse().ok())
|
||||
.or_else(read_proxy_from_env);
|
||||
let http = {
|
||||
let _guard = Tokio::handle(cx).enter();
|
||||
|
||||
ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
|
||||
.expect("could not start HTTP client")
|
||||
};
|
||||
cx.set_http_client(Arc::new(http));
|
||||
|
||||
Project::init_settings(cx);
|
||||
|
||||
let client = Client::production(cx);
|
||||
cx.set_http_client(client.http_client());
|
||||
|
||||
let git_binary_path = None;
|
||||
let fs = Arc::new(RealFs::new(
|
||||
git_binary_path,
|
||||
cx.background_executor().clone(),
|
||||
));
|
||||
|
||||
let mut languages = LanguageRegistry::new(cx.background_executor().clone());
|
||||
languages.set_language_server_download_dir(paths::languages_dir().clone());
|
||||
let languages = Arc::new(languages);
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
extension::init(cx);
|
||||
|
||||
let (mut tx, rx) = watch::channel(None);
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
let settings = &ProjectSettings::get_global(cx).node;
|
||||
let options = NodeBinaryOptions {
|
||||
allow_path_lookup: !settings.ignore_system_version,
|
||||
allow_binary_download: true,
|
||||
use_paths: settings.path.as_ref().map(|node_path| {
|
||||
let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
|
||||
let npm_path = settings
|
||||
.npm_path
|
||||
.as_ref()
|
||||
.map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
|
||||
(
|
||||
node_path.clone(),
|
||||
npm_path.unwrap_or_else(|| {
|
||||
let base_path = PathBuf::new();
|
||||
node_path.parent().unwrap_or(&base_path).join("npm")
|
||||
}),
|
||||
)
|
||||
}),
|
||||
};
|
||||
tx.send(Some(options)).log_err();
|
||||
})
|
||||
.detach();
|
||||
let node_runtime = NodeRuntime::new(client.http_client(), None, rx);
|
||||
|
||||
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||
|
||||
language::init(cx);
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(
|
||||
LspAccess::Noop,
|
||||
extension_host_proxy.clone(),
|
||||
languages.clone(),
|
||||
);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||
prompt_store::init(cx);
|
||||
terminal_view::init(cx);
|
||||
|
||||
ZetaCliAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime,
|
||||
}
|
||||
}
|
376
crates/zeta_cli/src/main.rs
Normal file
376
crates/zeta_cli/src/main.rs
Normal file
|
@ -0,0 +1,376 @@
|
|||
mod headless;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use gpui::{AppContext, Application, AsyncApp};
|
||||
use gpui::{Entity, Task};
|
||||
use language::Bias;
|
||||
use language::Buffer;
|
||||
use language::Point;
|
||||
use language_model::LlmApiToken;
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::AppVersion;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::exit;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
|
||||
|
||||
use crate::headless::ZetaCliAppState;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeta")]
|
||||
struct ZetaCliArgs {
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Commands {
|
||||
Context(ContextArgs),
|
||||
Predict {
|
||||
#[arg(long)]
|
||||
predict_edits_body: Option<FileOrStdin>,
|
||||
#[clap(flatten)]
|
||||
context_args: Option<ContextArgs>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
#[group(requires = "worktree")]
|
||||
struct ContextArgs {
|
||||
#[arg(long)]
|
||||
worktree: PathBuf,
|
||||
#[arg(long)]
|
||||
cursor: CursorPosition,
|
||||
#[arg(long)]
|
||||
use_language_server: bool,
|
||||
#[arg(long)]
|
||||
events: Option<FileOrStdin>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum FileOrStdin {
|
||||
File(PathBuf),
|
||||
Stdin,
|
||||
}
|
||||
|
||||
impl FileOrStdin {
|
||||
async fn read_to_string(&self) -> Result<String, std::io::Error> {
|
||||
match self {
|
||||
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
|
||||
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for FileOrStdin {
|
||||
type Err = <PathBuf as FromStr>::Err;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"-" => Ok(Self::Stdin),
|
||||
_ => Ok(Self::File(PathBuf::from_str(s)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CursorPosition {
|
||||
path: PathBuf,
|
||||
point: Point,
|
||||
}
|
||||
|
||||
impl FromStr for CursorPosition {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self> {
|
||||
let parts: Vec<&str> = s.split(':').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(anyhow!(
|
||||
"Invalid cursor format. Expected 'file.rs:line:column', got '{}'",
|
||||
s
|
||||
));
|
||||
}
|
||||
|
||||
let path = PathBuf::from(parts[0]);
|
||||
let line: u32 = parts[1]
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
|
||||
let column: u32 = parts[2]
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
|
||||
|
||||
// Convert from 1-based to 0-based indexing
|
||||
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
|
||||
|
||||
Ok(CursorPosition { path, point })
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<GatherContextOutput> {
|
||||
let ContextArgs {
|
||||
worktree: worktree_path,
|
||||
cursor,
|
||||
use_language_server,
|
||||
events,
|
||||
} = args;
|
||||
|
||||
let worktree_path = worktree_path.canonicalize()?;
|
||||
if cursor.path.is_absolute() {
|
||||
return Err(anyhow!("Absolute paths are not supported in --cursor"));
|
||||
}
|
||||
|
||||
let (project, _lsp_open_handle, buffer) = if use_language_server {
|
||||
let (project, lsp_open_handle, buffer) =
|
||||
open_buffer_with_language_server(&worktree_path, &cursor.path, &app_state, cx).await?;
|
||||
(Some(project), Some(lsp_open_handle), buffer)
|
||||
} else {
|
||||
let abs_path = worktree_path.join(&cursor.path);
|
||||
let content = smol::fs::read_to_string(&abs_path).await?;
|
||||
let buffer = cx.new(|cx| Buffer::local(content, cx))?;
|
||||
(None, None, buffer)
|
||||
};
|
||||
|
||||
let worktree_name = worktree_path
|
||||
.file_name()
|
||||
.ok_or_else(|| anyhow!("--worktree path must end with a folder name"))?;
|
||||
let full_path_str = PathBuf::from(worktree_name)
|
||||
.join(&cursor.path)
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
|
||||
if clipped_cursor != cursor.point {
|
||||
let max_row = snapshot.max_point().row;
|
||||
if cursor.point.row < max_row {
|
||||
return Err(anyhow!(
|
||||
"Cursor position {:?} is out of bounds (line length is {})",
|
||||
cursor.point,
|
||||
snapshot.line_len(cursor.point.row)
|
||||
));
|
||||
} else {
|
||||
return Err(anyhow!(
|
||||
"Cursor position {:?} is out of bounds (max row is {})",
|
||||
cursor.point,
|
||||
max_row
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let events = match events {
|
||||
Some(events) => events.read_to_string().await?,
|
||||
None => String::new(),
|
||||
};
|
||||
let can_collect_data = false;
|
||||
cx.update(|cx| {
|
||||
gather_context(
|
||||
project.as_ref(),
|
||||
full_path_str,
|
||||
&snapshot,
|
||||
clipped_cursor,
|
||||
move || events,
|
||||
can_collect_data,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn open_buffer_with_language_server(
|
||||
worktree_path: &Path,
|
||||
path: &Path,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> {
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: path.to_path_buf().into(),
|
||||
})?;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
|
||||
let lsp_open_handle = project.update(cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&buffer, cx)
|
||||
})?;
|
||||
|
||||
let log_prefix = path.to_string_lossy().to_string();
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
|
||||
|
||||
Ok((project, lsp_open_handle, buffer))
|
||||
}
|
||||
|
||||
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
|
||||
pub fn wait_for_lang_server(
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
log_prefix: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
println!("{}⏵ Waiting for language server", log_prefix);
|
||||
|
||||
let (mut tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _| project.lsp_store())
|
||||
.unwrap();
|
||||
|
||||
let has_lang_server = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
lsp_store.update(cx, |lsp_store, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(&buffer, cx)
|
||||
.next()
|
||||
.is_some()
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
if has_lang_server {
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.unwrap()
|
||||
.detach();
|
||||
}
|
||||
|
||||
let subscriptions = [
|
||||
cx.subscribe(&lsp_store, {
|
||||
let log_prefix = log_prefix.clone();
|
||||
move |_, event, _| match event {
|
||||
project::LspStoreEvent::LanguageServerUpdate {
|
||||
message:
|
||||
client::proto::update_language_server::Variant::WorkProgress(
|
||||
client::proto::LspWorkProgress {
|
||||
message: Some(message),
|
||||
..
|
||||
},
|
||||
),
|
||||
..
|
||||
} => println!("{}⟲ {message}", log_prefix),
|
||||
_ => {}
|
||||
}
|
||||
}),
|
||||
cx.subscribe(&project, {
|
||||
let buffer = buffer.clone();
|
||||
move |project, event, cx| match event {
|
||||
project::Event::LanguageServerAdded(_, _, _) => {
|
||||
let buffer = buffer.clone();
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer, cx))
|
||||
.detach();
|
||||
}
|
||||
project::Event::DiskBasedDiagnosticsFinished { .. } => {
|
||||
tx.try_send(()).ok();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}),
|
||||
];
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
|
||||
let result = futures::select! {
|
||||
_ = rx.next() => {
|
||||
println!("{}⚑ Language server idle", log_prefix);
|
||||
anyhow::Ok(())
|
||||
},
|
||||
_ = timeout.fuse() => {
|
||||
anyhow::bail!("LSP wait timed out after 5 minutes");
|
||||
}
|
||||
};
|
||||
drop(subscriptions);
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args = ZetaCliArgs::parse();
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client);
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = Arc::new(headless::init(cx));
|
||||
cx.spawn(async move |cx| {
|
||||
let result = match args.command {
|
||||
Commands::Context(context_args) => get_context(context_args, &app_state, cx)
|
||||
.await
|
||||
.map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
|
||||
Commands::Predict {
|
||||
predict_edits_body,
|
||||
context_args,
|
||||
} => {
|
||||
cx.spawn(async move |cx| {
|
||||
let app_version = cx.update(|cx| AppVersion::global(cx))?;
|
||||
app_state.client.sign_in(true, cx).await?;
|
||||
let llm_token = LlmApiToken::default();
|
||||
llm_token.refresh(&app_state.client).await?;
|
||||
|
||||
let predict_edits_body =
|
||||
if let Some(predict_edits_body) = predict_edits_body {
|
||||
serde_json::from_str(&predict_edits_body.read_to_string().await?)?
|
||||
} else if let Some(context_args) = context_args {
|
||||
get_context(context_args, &app_state, cx).await?.body
|
||||
} else {
|
||||
return Err(anyhow!(
|
||||
"Expected either --predict-edits-body-file \
|
||||
or the required args of the `context` command."
|
||||
));
|
||||
};
|
||||
|
||||
let (response, _usage) =
|
||||
Zeta::perform_predict_edits(PerformPredictEditsParams {
|
||||
client: app_state.client.clone(),
|
||||
llm_token,
|
||||
app_version,
|
||||
body: predict_edits_body,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(response.output_excerpt)
|
||||
})
|
||||
.await
|
||||
}
|
||||
};
|
||||
match result {
|
||||
Ok(output) => {
|
||||
println!("{}", output);
|
||||
let _ = cx.update(|cx| cx.quit());
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed: {:?}", e);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue