diff --git a/crates/languages/src/rust.rs b/crates/languages/src/rust.rs index 39a9835bc7..28a5d426f6 100644 --- a/crates/languages/src/rust.rs +++ b/crates/languages/src/rust.rs @@ -346,10 +346,17 @@ pub(crate) struct RustContextProvider; const RUST_PACKAGE_TASK_VARIABLE: VariableName = VariableName::Custom(Cow::Borrowed("RUST_PACKAGE")); +/// The bin name corresponding to the current file in Cargo.toml +const RUST_BIN_NAME_TASK_VARIABLE: VariableName = + VariableName::Custom(Cow::Borrowed("RUST_BIN_NAME")); + +const RUST_MAIN_FUNCTION_TASK_VARIABLE: VariableName = + VariableName::Custom(Cow::Borrowed("_rust_main_function_end")); + impl ContextProvider for RustContextProvider { fn build_context( &self, - _: &TaskVariables, + task_variables: &TaskVariables, location: &Location, cx: &mut gpui::AppContext, ) -> Result { @@ -358,17 +365,35 @@ impl ContextProvider for RustContextProvider { .read(cx) .file() .and_then(|file| Some(file.as_local()?.abs_path(cx))); - Ok( - if let Some(package_name) = local_abs_path - .as_deref() - .and_then(|local_abs_path| local_abs_path.parent()) - .and_then(human_readable_package_name) + + let local_abs_path = local_abs_path.as_deref(); + + let is_main_function = task_variables + .get(&RUST_MAIN_FUNCTION_TASK_VARIABLE) + .is_some(); + + if is_main_function { + if let Some((package_name, bin_name)) = local_abs_path + .and_then(|local_abs_path| package_name_and_bin_name_from_abs_path(local_abs_path)) { - TaskVariables::from_iter(Some((RUST_PACKAGE_TASK_VARIABLE.clone(), package_name))) - } else { - TaskVariables::default() - }, - ) + return Ok(TaskVariables::from_iter([ + (RUST_PACKAGE_TASK_VARIABLE.clone(), package_name), + (RUST_BIN_NAME_TASK_VARIABLE.clone(), bin_name), + ])); + } + } + + if let Some(package_name) = local_abs_path + .and_then(|local_abs_path| local_abs_path.parent()) + .and_then(human_readable_package_name) + { + return Ok(TaskVariables::from_iter([( + RUST_PACKAGE_TASK_VARIABLE.clone(), + package_name, + )])); + } + + Ok(TaskVariables::default()) } fn associated_tasks(&self) -> Option { @@ -426,6 +451,23 @@ impl ContextProvider for RustContextProvider { tags: vec!["rust-mod-test".to_owned()], ..TaskTemplate::default() }, + TaskTemplate { + label: format!( + "cargo run -p {} --bin {}", + RUST_PACKAGE_TASK_VARIABLE.template_value(), + RUST_BIN_NAME_TASK_VARIABLE.template_value(), + ), + command: "cargo".into(), + args: vec![ + "run".into(), + "-p".into(), + RUST_PACKAGE_TASK_VARIABLE.template_value(), + "--bin".into(), + RUST_BIN_NAME_TASK_VARIABLE.template_value(), + ], + tags: vec!["rust-main".to_owned()], + ..TaskTemplate::default() + }, TaskTemplate { label: format!( "cargo test -p {}", @@ -455,6 +497,65 @@ impl ContextProvider for RustContextProvider { } } +/// Part of the data structure of Cargo metadata +#[derive(serde::Deserialize)] +struct CargoMetadata { + packages: Vec, +} + +#[derive(serde::Deserialize)] +struct CargoPackage { + id: String, + targets: Vec, +} + +#[derive(serde::Deserialize)] +struct CargoTarget { + name: String, + kind: Vec, + src_path: String, +} + +fn package_name_and_bin_name_from_abs_path(abs_path: &Path) -> Option<(String, String)> { + let output = std::process::Command::new("cargo") + .current_dir(abs_path.parent()?) + .arg("metadata") + .arg("--no-deps") + .arg("--format-version") + .arg("1") + .output() + .log_err()? + .stdout; + + let metadata: CargoMetadata = serde_json::from_slice(&output).log_err()?; + + retrieve_package_id_and_bin_name_from_metadata(metadata, abs_path).and_then( + |(package_id, bin_name)| { + let package_name = package_name_from_pkgid(&package_id); + + package_name.map(|package_name| (package_name.to_owned(), bin_name)) + }, + ) +} + +fn retrieve_package_id_and_bin_name_from_metadata( + metadata: CargoMetadata, + abs_path: &Path, +) -> Option<(String, String)> { + let abs_path = abs_path.to_str()?; + + for package in metadata.packages { + for target in package.targets { + let is_bin = target.kind.iter().any(|kind| kind == "bin"); + if target.src_path == abs_path && is_bin { + return Some((package.id, target.name)); + } + } + } + + None +} + fn human_readable_package_name(package_directory: &Path) -> Option { let pkgid = String::from_utf8( std::process::Command::new("cargo") @@ -815,4 +916,37 @@ mod tests { assert_eq!(package_name_from_pkgid(input), Some(expected)); } } + + #[test] + fn test_retrieve_package_id_and_bin_name_from_metadata() { + for (input, absolute_path, expected) in [ + ( + r#"{"packages":[{"id":"path+file:///path/to/zed/crates/zed#0.131.0","targets":[{"name":"zed","kind":["bin"],"src_path":"/path/to/zed/src/main.rs"}]}]}"#, + "/path/to/zed/src/main.rs", + Some(("path+file:///path/to/zed/crates/zed#0.131.0", "zed")), + ), + ( + r#"{"packages":[{"id":"path+file:///path/to/custom-package#my-custom-package@0.1.0","targets":[{"name":"my-custom-bin","kind":["bin"],"src_path":"/path/to/custom-package/src/main.rs"}]}]}"#, + "/path/to/custom-package/src/main.rs", + Some(( + "path+file:///path/to/custom-package#my-custom-package@0.1.0", + "my-custom-bin", + )), + ), + ( + r#"{"packages":[{"id":"path+file:///path/to/custom-package#my-custom-package@0.1.0","targets":[{"name":"my-custom-package","kind":["lib"],"src_path":"/path/to/custom-package/src/main.rs"}]}]}"#, + "/path/to/custom-package/src/main.rs", + None, + ), + ] { + let metadata: CargoMetadata = serde_json::from_str(input).unwrap(); + + let absolute_path = Path::new(absolute_path); + + assert_eq!( + retrieve_package_id_and_bin_name_from_metadata(metadata, absolute_path), + expected.map(|(pkgid, bin)| (pkgid.to_owned(), bin.to_owned())) + ); + } + } } diff --git a/crates/languages/src/rust/runnables.scm b/crates/languages/src/rust/runnables.scm index 90e4900188..963009abc0 100644 --- a/crates/languages/src/rust/runnables.scm +++ b/crates/languages/src/rust/runnables.scm @@ -25,3 +25,15 @@ ) (#set! tag rust-test) ) + +; Rust main function +( + ( + (function_item + name: (_) @run + body: _ + ) @_rust_main_function_end + (#eq? @run "main") + ) + (#set! tag rust-main) +)