diff --git a/Cargo.lock b/Cargo.lock index 121e9a28dd..bf0ed9b163 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,7 @@ dependencies = [ "log", "menu", "ordered-float", + "parking_lot 0.11.2", "project", "rand 0.8.5", "regex", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 4438f88108..d96e470d5c 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -27,6 +27,7 @@ futures.workspace = true indoc.workspace = true isahc.workspace = true ordered-float.workspace = true +parking_lot.workspace = true regex.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 2e8eca80e3..7d9b93b0a7 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -27,7 +27,7 @@ use util::paths::CONVERSATIONS_DIR; const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; // Data types for chat completion requests -#[derive(Debug, Serialize)] +#[derive(Debug, Default, Serialize)] pub struct OpenAIRequest { model: String, messages: Vec, diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs index b24c0f9435..9657d9a492 100644 --- a/crates/ai/src/codegen.rs +++ b/crates/ai/src/codegen.rs @@ -406,9 +406,68 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { - use futures::stream; - use super::*; + use futures::stream; + use gpui::{executor::Deterministic, TestAppContext}; + use indoc::indoc; + use language::{tree_sitter_rust, Buffer, Language, LanguageConfig}; + use parking_lot::Mutex; + use rand::prelude::*; + + #[gpui::test(iterations = 10)] + async fn test_autoindent( + cx: &mut TestAppContext, + mut rng: StdRng, + deterministic: Arc, + ) { + let text = indoc! {" + fn main() { + let x = 0; + for _ in 0..10 { + x += 1; + } + } + "}; + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) + }); + let provider = Arc::new(TestCompletionProvider::new()); + let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx)); + codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let mut new_text = indoc! {" + let mut x = 0; + while x < 10 { + x += 1; + } + "}; + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk); + new_text = suffix; + deterministic.run_until_parked(); + } + provider.finish_completion(); + deterministic.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } #[gpui::test] async fn test_strip_markdown_codeblock() { @@ -465,4 +524,56 @@ mod tests { ) } } + + struct TestCompletionProvider { + last_completion_tx: Mutex>>, + } + + impl TestCompletionProvider { + fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } + } + + impl CompletionProvider for TestCompletionProvider { + fn complete( + &self, + _prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + } }