diff --git a/crates/assistant_tools/src/edit_files_tool/edit_action.rs b/crates/assistant_tools/src/edit_files_tool/edit_action.rs index 9e9a58737b..7b5481e1b7 100644 --- a/crates/assistant_tools/src/edit_files_tool/edit_action.rs +++ b/crates/assistant_tools/src/edit_files_tool/edit_action.rs @@ -355,6 +355,7 @@ impl std::fmt::Display for ParseError { mod tests { use super::*; use rand::prelude::*; + use util::line_endings; #[test] fn test_simple_edit_action() { @@ -798,19 +799,17 @@ fn new_utils_func() {} EditAction::Replace { file_path: PathBuf::from("mathweb/flask/app.py"), old: "from flask import Flask".to_string(), - new: "import math\nfrom flask import Flask".to_string(), - } - .fix_lf(), + new: line_endings!("import math\nfrom flask import Flask").to_string(), + }, ); assert_eq!( actions[1], EditAction::Replace { file_path: PathBuf::from("mathweb/flask/app.py"), - old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(), + old: line_endings!("def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n").to_string(), new: "".to_string(), } - .fix_lf() ); assert_eq!( @@ -819,28 +818,30 @@ fn new_utils_func() {} file_path: PathBuf::from("mathweb/flask/app.py"), old: " return str(factorial(n))".to_string(), new: " return str(math.factorial(n))".to_string(), - } - .fix_lf(), + }, ); assert_eq!( actions[3], EditAction::Write { file_path: PathBuf::from("hello.py"), - content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")" - .to_string(), - } - .fix_lf(), + content: line_endings!( + "def hello():\n \"print a greeting\"\n\n print(\"hello\")" + ) + .to_string(), + }, ); assert_eq!( actions[4], EditAction::Replace { file_path: PathBuf::from("main.py"), - old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(), + old: line_endings!( + "def hello():\n \"print a greeting\"\n\n print(\"hello\")" + ) + .to_string(), new: "from hello import hello".to_string(), - } - .fix_lf(), + }, ); // The system prompt includes some text that would produce errors @@ -860,29 +861,6 @@ fn new_utils_func() {} ); } - impl EditAction { - fn fix_lf(self: EditAction) -> EditAction { - #[cfg(windows)] - match self { - EditAction::Replace { - file_path, - old, - new, - } => EditAction::Replace { - file_path: file_path.clone(), - old: old.replace("\n", "\r\n"), - new: new.replace("\n", "\r\n"), - }, - EditAction::Write { file_path, content } => EditAction::Write { - file_path: file_path.clone(), - content: content.replace("\n", "\r\n"), - }, - } - #[cfg(not(windows))] - self - } - } - #[test] fn test_print_error() { let input = r#"src/main.rs diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index ece84bb46b..98ade249fc 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -29,7 +29,7 @@ use anyhow::{anyhow, Context as _}; pub use take_until::*; #[cfg(any(test, feature = "test-support"))] -pub use util_macros::{separator, uri}; +pub use util_macros::{line_endings, separator, uri}; #[macro_export] macro_rules! debug_panic { diff --git a/crates/util_macros/src/util_macros.rs b/crates/util_macros/src/util_macros.rs index 2baba2f473..df48be8c99 100644 --- a/crates/util_macros/src/util_macros.rs +++ b/crates/util_macros/src/util_macros.rs @@ -54,3 +54,29 @@ pub fn uri(input: TokenStream) -> TokenStream { #uri }) } + +/// This macro replaces the line endings `\n` with `\r\n` for Windows. +/// But if the target OS is not Windows, the line endings are returned as is. +/// +/// # Example +/// ```rust +/// use util_macros::line_endings; +/// +/// let text = line_endings!("Hello\nWorld"); +/// #[cfg(target_os = "windows")] +/// assert_eq!(text, "Hello\r\nWorld"); +/// #[cfg(not(target_os = "windows"))] +/// assert_eq!(text, "Hello\nWorld"); +/// ``` +#[proc_macro] +pub fn line_endings(input: TokenStream) -> TokenStream { + let text = parse_macro_input!(input as LitStr); + let text = text.value(); + + #[cfg(target_os = "windows")] + let text = text.replace("\n", "\r\n"); + + TokenStream::from(quote! { + #text + }) +}