diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index 825b11db2d..250b80cc32 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -348,6 +348,10 @@ const PYTHON_TEST_TARGET_TASK_VARIABLE: VariableName = const PYTHON_ACTIVE_TOOLCHAIN_PATH: VariableName = VariableName::Custom(Cow::Borrowed("PYTHON_ACTIVE_ZED_TOOLCHAIN")); + +const PYTHON_MODULE_NAME_TASK_VARIABLE: VariableName = + VariableName::Custom(Cow::Borrowed("PYTHON_MODULE_NAME")); + impl ContextProvider for PythonContextProvider { fn build_context( &self, @@ -362,7 +366,9 @@ impl ContextProvider for PythonContextProvider { TestRunner::PYTEST => self.build_pytest_target(variables), }; + let module_target = self.build_module_target(variables); let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx)); + cx.spawn(async move |cx| { let active_toolchain = if let Some(worktree_id) = worktree_id { toolchains @@ -376,8 +382,12 @@ impl ContextProvider for PythonContextProvider { String::from("python3") }; let toolchain = (PYTHON_ACTIVE_TOOLCHAIN_PATH, active_toolchain); + Ok(task::TaskVariables::from_iter( - test_target.into_iter().chain([toolchain]), + test_target + .into_iter() + .chain(module_target.into_iter()) + .chain([toolchain]), )) }) } @@ -407,6 +417,17 @@ impl ContextProvider for PythonContextProvider { args: vec![VariableName::File.template_value_with_whitespace()], ..TaskTemplate::default() }, + // Execute a file as module + TaskTemplate { + label: format!("run module '{}'", VariableName::File.template_value()), + command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(), + args: vec![ + "-m".to_owned(), + PYTHON_MODULE_NAME_TASK_VARIABLE.template_value(), + ], + tags: vec!["python-module-main-method".to_owned()], + ..TaskTemplate::default() + }, ]; tasks.extend(match test_runner { @@ -544,6 +565,19 @@ impl PythonContextProvider { Some((PYTHON_TEST_TARGET_TASK_VARIABLE.clone(), pytest_target_str)) } + + fn build_module_target( + &self, + variables: &task::TaskVariables, + ) -> Result<(VariableName, String)> { + let python_module_name = python_module_name_from_relative_path( + variables.get(&VariableName::RelativeFile).unwrap_or(""), + ); + + let module_target = (PYTHON_MODULE_NAME_TASK_VARIABLE.clone(), python_module_name); + + Ok(module_target) + } } fn python_module_name_from_relative_path(relative_path: &str) -> String { diff --git a/crates/languages/src/python/runnables.scm b/crates/languages/src/python/runnables.scm index 8cdb0d77eb..3b32556707 100644 --- a/crates/languages/src/python/runnables.scm +++ b/crates/languages/src/python/runnables.scm @@ -82,3 +82,19 @@ ) ) ) + +; module main method +( + (module + (if_statement + condition: (comparison_operator + (identifier) @run @_lhs + operators: "==" + (string) @_rhs + ) + (#eq? @_lhs "__name__") + (#match? @_rhs "^[\"']__main__[\"']$") + (#set! tag python-module-main-method) + ) + ) +)