From 7af77b1cf95da45314092aa35f7bcc04fa4fd3bc Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 12:26:01 +0200 Subject: [PATCH] moved TestCompletionProvider into ai --- crates/ai/src/test.rs | 39 +++++++++++++++++++++++++++++++++ crates/assistant/Cargo.toml | 1 + crates/assistant/src/codegen.rs | 38 +------------------------------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index bc143e3c21..2c78027b62 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -4,9 +4,12 @@ use std::{ }; use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use parking_lot::Mutex; use crate::{ auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, embedding::{Embedding, EmbeddingProvider}, models::{LanguageModel, TruncationDirection}, }; @@ -125,3 +128,39 @@ impl EmbeddingProvider for FakeEmbeddingProvider { anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } + +pub struct TestCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl TestCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +impl CompletionProvider for TestCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 9cfdd3301a..6b0ce659e3 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -44,6 +44,7 @@ tiktoken-rs = "0.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +ai = { path = "../ai", features = ["test-support"]} ctor.workspace = true env_logger.workspace = true diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 33adb2e570..3516fc3708 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::{models::LanguageModel, test::FakeLanguageModel}; + use ai::test::TestCompletionProvider; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -617,42 +617,6 @@ mod tests { } } - struct TestCompletionProvider { - last_completion_tx: Mutex>>, - } - - impl TestCompletionProvider { - fn new() -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } - - fn send_completion(&self, completion: impl Into) { - let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); - } - - fn finish_completion(&self) { - self.last_completion_tx.lock().take().unwrap(); - } - } - - impl CompletionProvider for TestCompletionProvider { - fn base_model(&self) -> Box { - let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); - model - } - fn complete( - &self, - _prompt: Box, - ) -> BoxFuture<'static, Result>>> { - let (tx, rx) = mpsc::channel(1); - *self.last_completion_tx.lock() = Some(tx); - async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() - } - } - fn rust_lang() -> Language { Language::new( LanguageConfig {