Allow clients to run Zed tasks on remote projects (#12199)

Release Notes:

- Enabled Zed tasks on remote projects with ssh connection string
specified

---------

Co-authored-by: Conrad Irwin <conrad@zed.dev>
This commit is contained in:
Kirill Bulatov 2024-05-24 22:26:57 +03:00 committed by GitHub
parent df35fd0026
commit 055a13a9b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1250 additions and 600 deletions

View file

@ -36,7 +36,7 @@ use git::{blame::Blame, repository::GitRepository};
use globset::{Glob, GlobSet, GlobSetBuilder};
use gpui::{
AnyModel, AppContext, AsyncAppContext, BackgroundExecutor, BorrowAppContext, Context, Entity,
EventEmitter, Model, ModelContext, PromptLevel, Task, WeakModel,
EventEmitter, Model, ModelContext, PromptLevel, SharedString, Task, WeakModel,
};
use itertools::Itertools;
use language::{
@ -47,10 +47,10 @@ use language::{
serialize_version, split_operations,
},
range_from_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, Capability, CodeLabel,
Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Documentation, Event as BufferEvent,
File as _, Language, LanguageRegistry, LanguageServerName, LocalFile, LspAdapterDelegate,
Operation, Patch, PendingLanguageServer, PointUtf16, TextBufferSnapshot, ToOffset,
ToPointUtf16, Transaction, Unclipped,
ContextProvider, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Documentation,
Event as BufferEvent, File as _, Language, LanguageRegistry, LanguageServerName, LocalFile,
LspAdapterDelegate, Operation, Patch, PendingLanguageServer, PointUtf16, TextBufferSnapshot,
ToOffset, ToPointUtf16, Transaction, Unclipped,
};
use log::error;
use lsp::{
@ -80,6 +80,7 @@ use similar::{ChangeTag, TextDiff};
use smol::channel::{Receiver, Sender};
use smol::lock::Semaphore;
use std::{
borrow::Cow,
cmp::{self, Ordering},
convert::TryInto,
env,
@ -97,7 +98,10 @@ use std::{
},
time::{Duration, Instant},
};
use task::static_source::{StaticSource, TrackedFile};
use task::{
static_source::{StaticSource, TrackedFile},
RevealStrategy, TaskContext, TaskTemplate, TaskVariables, VariableName,
};
use terminals::Terminals;
use text::{Anchor, BufferId, LineEnding};
use util::{
@ -676,6 +680,8 @@ impl Project {
client.add_model_request_handler(Self::handle_lsp_command::<lsp_ext_command::ExpandMacro>);
client.add_model_request_handler(Self::handle_blame_buffer);
client.add_model_request_handler(Self::handle_multi_lsp_query);
client.add_model_request_handler(Self::handle_task_context_for_location);
client.add_model_request_handler(Self::handle_task_templates);
}
pub fn local(
@ -1257,6 +1263,19 @@ impl Project {
self.dev_server_project_id
}
pub fn ssh_connection_string(&self, cx: &ModelContext<Self>) -> Option<SharedString> {
if self.is_local() {
return None;
}
let dev_server_id = self.dev_server_project_id()?;
dev_server_projects::Store::global(cx)
.read(cx)
.dev_server_for_project(dev_server_id)?
.ssh_connection_string
.clone()
}
pub fn replica_id(&self) -> ReplicaId {
match self.client_state {
ProjectClientState::Remote { replica_id, .. } => replica_id,
@ -7892,7 +7911,7 @@ impl Project {
TaskSourceKind::Worktree {
id: remote_worktree_id,
abs_path,
id_base: "local_tasks_for_worktree",
id_base: "local_tasks_for_worktree".into(),
},
|tx, cx| StaticSource::new(TrackedFile::new(tasks_file_rx, tx, cx)),
cx,
@ -7912,7 +7931,7 @@ impl Project {
TaskSourceKind::Worktree {
id: remote_worktree_id,
abs_path,
id_base: "local_vscode_tasks_for_worktree",
id_base: "local_vscode_tasks_for_worktree".into(),
},
|tx, cx| {
StaticSource::new(TrackedFile::new_convertible::<
@ -9424,6 +9443,122 @@ impl Project {
})
}
async fn handle_task_context_for_location(
project: Model<Self>,
envelope: TypedEnvelope<proto::TaskContextForLocation>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<proto::TaskContext> {
let location = envelope
.payload
.location
.context("no location given for task context handling")?;
let location = cx
.update(|cx| deserialize_location(&project, location, cx))?
.await?;
let context_task = project.update(&mut cx, |project, cx| {
let captured_variables = {
let mut variables = TaskVariables::default();
for range in location
.buffer
.read(cx)
.snapshot()
.runnable_ranges(location.range.clone())
{
for (capture_name, value) in range.extra_captures {
variables.insert(VariableName::Custom(capture_name.into()), value);
}
}
variables
};
project.task_context_for_location(captured_variables, location, cx)
})?;
let task_context = context_task.await.unwrap_or_default();
Ok(proto::TaskContext {
cwd: task_context
.cwd
.map(|cwd| cwd.to_string_lossy().to_string()),
task_variables: task_context
.task_variables
.into_iter()
.map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
.collect(),
})
}
async fn handle_task_templates(
project: Model<Self>,
envelope: TypedEnvelope<proto::TaskTemplates>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<proto::TaskTemplatesResponse> {
let worktree = envelope.payload.worktree_id.map(WorktreeId::from_proto);
let location = match envelope.payload.location {
Some(location) => Some(
cx.update(|cx| deserialize_location(&project, location, cx))?
.await
.context("task templates request location deserializing")?,
),
None => None,
};
let templates = project
.update(&mut cx, |project, cx| {
project.task_templates(worktree, location, cx)
})?
.await
.context("receiving task templates")?
.into_iter()
.map(|(kind, template)| {
let kind = Some(match kind {
TaskSourceKind::UserInput => proto::task_source_kind::Kind::UserInput(
proto::task_source_kind::UserInput {},
),
TaskSourceKind::Worktree {
id,
abs_path,
id_base,
} => {
proto::task_source_kind::Kind::Worktree(proto::task_source_kind::Worktree {
id: id.to_proto(),
abs_path: abs_path.to_string_lossy().to_string(),
id_base: id_base.to_string(),
})
}
TaskSourceKind::AbsPath { id_base, abs_path } => {
proto::task_source_kind::Kind::AbsPath(proto::task_source_kind::AbsPath {
abs_path: abs_path.to_string_lossy().to_string(),
id_base: id_base.to_string(),
})
}
TaskSourceKind::Language { name } => {
proto::task_source_kind::Kind::Language(proto::task_source_kind::Language {
name: name.to_string(),
})
}
});
let kind = Some(proto::TaskSourceKind { kind });
let template = Some(proto::TaskTemplate {
label: template.label,
command: template.command,
args: template.args,
env: template.env.into_iter().collect(),
cwd: template.cwd,
use_new_terminal: template.use_new_terminal,
allow_concurrent_runs: template.allow_concurrent_runs,
reveal: match template.reveal {
RevealStrategy::Always => proto::RevealStrategy::Always as i32,
RevealStrategy::Never => proto::RevealStrategy::Never as i32,
},
tags: template.tags,
});
proto::TemplatePair { kind, template }
})
.collect();
Ok(proto::TaskTemplatesResponse { templates })
}
async fn try_resolve_code_action(
lang_server: &LanguageServer,
action: &mut CodeAction,
@ -10410,6 +10545,223 @@ impl Project {
Vec::new()
}
}
pub fn task_context_for_location(
&self,
captured_variables: TaskVariables,
location: Location,
cx: &mut ModelContext<'_, Project>,
) -> Task<Option<TaskContext>> {
if self.is_local() {
let cwd = self.task_cwd(cx).log_err().flatten();
cx.spawn(|project, cx| async move {
let mut task_variables = cx
.update(|cx| {
combine_task_variables(
captured_variables,
location,
BasicContextProvider::new(project.upgrade()?),
cx,
)
.log_err()
})
.ok()
.flatten()?;
// Remove all custom entries starting with _, as they're not intended for use by the end user.
task_variables.sweep();
Some(TaskContext {
cwd,
task_variables,
})
})
} else if let Some(project_id) = self
.remote_id()
.filter(|_| self.ssh_connection_string(cx).is_some())
{
let task_context = self.client().request(proto::TaskContextForLocation {
project_id,
location: Some(proto::Location {
buffer_id: location.buffer.read(cx).remote_id().into(),
start: Some(serialize_anchor(&location.range.start)),
end: Some(serialize_anchor(&location.range.end)),
}),
});
cx.background_executor().spawn(async move {
let task_context = task_context.await.log_err()?;
Some(TaskContext {
cwd: task_context.cwd.map(PathBuf::from),
task_variables: task_context
.task_variables
.into_iter()
.filter_map(
|(variable_name, variable_value)| match variable_name.parse() {
Ok(variable_name) => Some((variable_name, variable_value)),
Err(()) => {
log::error!("Unknown variable name: {variable_name}");
None
}
},
)
.collect(),
})
})
} else {
Task::ready(None)
}
}
pub fn task_templates(
&self,
worktree: Option<WorktreeId>,
location: Option<Location>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<(TaskSourceKind, TaskTemplate)>>> {
if self.is_local() {
let language = location
.and_then(|location| location.buffer.read(cx).language_at(location.range.start));
Task::ready(Ok(self
.task_inventory()
.read(cx)
.list_tasks(language, worktree)))
} else if let Some(project_id) = self
.remote_id()
.filter(|_| self.ssh_connection_string(cx).is_some())
{
let remote_templates =
self.query_remote_task_templates(project_id, worktree, location.as_ref(), cx);
cx.background_executor().spawn(remote_templates)
} else {
Task::ready(Ok(Vec::new()))
}
}
pub fn query_remote_task_templates(
&self,
project_id: u64,
worktree: Option<WorktreeId>,
location: Option<&Location>,
cx: &AppContext,
) -> Task<Result<Vec<(TaskSourceKind, TaskTemplate)>>> {
let client = self.client();
let location = location.map(|location| serialize_location(location, cx));
cx.spawn(|_| async move {
let response = client
.request(proto::TaskTemplates {
project_id,
worktree_id: worktree.map(|id| id.to_proto()),
location,
})
.await?;
Ok(response
.templates
.into_iter()
.filter_map(|template_pair| {
let task_source_kind = match template_pair.kind?.kind? {
proto::task_source_kind::Kind::UserInput(_) => TaskSourceKind::UserInput,
proto::task_source_kind::Kind::Worktree(worktree) => {
TaskSourceKind::Worktree {
id: WorktreeId::from_proto(worktree.id),
abs_path: PathBuf::from(worktree.abs_path),
id_base: Cow::Owned(worktree.id_base),
}
}
proto::task_source_kind::Kind::AbsPath(abs_path) => {
TaskSourceKind::AbsPath {
id_base: Cow::Owned(abs_path.id_base),
abs_path: PathBuf::from(abs_path.abs_path),
}
}
proto::task_source_kind::Kind::Language(language) => {
TaskSourceKind::Language {
name: language.name.into(),
}
}
};
let proto_template = template_pair.template?;
let reveal = match proto::RevealStrategy::from_i32(proto_template.reveal)
.unwrap_or(proto::RevealStrategy::Always)
{
proto::RevealStrategy::Always => RevealStrategy::Always,
proto::RevealStrategy::Never => RevealStrategy::Never,
};
let task_template = TaskTemplate {
label: proto_template.label,
command: proto_template.command,
args: proto_template.args,
env: proto_template.env.into_iter().collect(),
cwd: proto_template.cwd,
use_new_terminal: proto_template.use_new_terminal,
allow_concurrent_runs: proto_template.allow_concurrent_runs,
reveal,
tags: proto_template.tags,
};
Some((task_source_kind, task_template))
})
.collect())
})
}
fn task_cwd(&self, cx: &AppContext) -> anyhow::Result<Option<PathBuf>> {
let available_worktrees = self
.worktrees()
.filter(|worktree| {
let worktree = worktree.read(cx);
worktree.is_visible()
&& worktree.is_local()
&& worktree.root_entry().map_or(false, |e| e.is_dir())
})
.collect::<Vec<_>>();
let cwd = match available_worktrees.len() {
0 => None,
1 => Some(available_worktrees[0].read(cx).abs_path()),
_ => {
let cwd_for_active_entry = self.active_entry().and_then(|entry_id| {
available_worktrees.into_iter().find_map(|worktree| {
let worktree = worktree.read(cx);
if worktree.contains_entry(entry_id) {
Some(worktree.abs_path())
} else {
None
}
})
});
anyhow::ensure!(
cwd_for_active_entry.is_some(),
"Cannot determine task cwd for multiple worktrees"
);
cwd_for_active_entry
}
};
Ok(cwd.map(|path| path.to_path_buf()))
}
}
fn combine_task_variables(
mut captured_variables: TaskVariables,
location: Location,
baseline: BasicContextProvider,
cx: &mut AppContext,
) -> anyhow::Result<TaskVariables> {
let language_context_provider = location
.buffer
.read(cx)
.language()
.and_then(|language| language.context_provider());
let baseline = baseline
.build_context(&captured_variables, &location, cx)
.context("building basic default context")?;
captured_variables.extend(baseline);
if let Some(provider) = language_context_provider {
captured_variables.extend(
provider
.build_context(&captured_variables, &location, cx)
.context("building provider context")?,
);
}
Ok(captured_variables)
}
async fn populate_labels_for_symbols(
@ -11238,3 +11590,40 @@ impl std::fmt::Display for NoRepositoryError {
}
impl std::error::Error for NoRepositoryError {}
fn serialize_location(location: &Location, cx: &AppContext) -> proto::Location {
proto::Location {
buffer_id: location.buffer.read(cx).remote_id().into(),
start: Some(serialize_anchor(&location.range.start)),
end: Some(serialize_anchor(&location.range.end)),
}
}
fn deserialize_location(
project: &Model<Project>,
location: proto::Location,
cx: &mut AppContext,
) -> Task<Result<Location>> {
let buffer_id = match BufferId::new(location.buffer_id) {
Ok(id) => id,
Err(e) => return Task::ready(Err(e)),
};
let buffer_task = project.update(cx, |project, cx| {
project.wait_for_remote_buffer(buffer_id, cx)
});
cx.spawn(|_| async move {
let buffer = buffer_task.await?;
let start = location
.start
.and_then(deserialize_anchor)
.context("missing task context location start")?;
let end = location
.end
.and_then(deserialize_anchor)
.context("missing task context location end")?;
Ok(Location {
buffer,
range: start..end,
})
})
}