moved TestCompletionProvider into ai
This commit is contained in:
parent
ec9d79b6fe
commit
7af77b1cf9
3 changed files with 41 additions and 37 deletions
|
@ -4,9 +4,12 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
|
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
|
||||||
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
embedding::{Embedding, EmbeddingProvider},
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
models::{LanguageModel, TruncationDirection},
|
models::{LanguageModel, TruncationDirection},
|
||||||
};
|
};
|
||||||
|
@ -125,3 +128,39 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct TestCompletionProvider {
|
||||||
|
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestCompletionProvider {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
last_completion_tx: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||||
|
let mut tx = self.last_completion_tx.lock();
|
||||||
|
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish_completion(&self) {
|
||||||
|
self.last_completion_tx.lock().take().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionProvider for TestCompletionProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||||
|
model
|
||||||
|
}
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
_prompt: Box<dyn CompletionRequest>,
|
||||||
|
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||||
|
let (tx, rx) = mpsc::channel(1);
|
||||||
|
*self.last_completion_tx.lock() = Some(tx);
|
||||||
|
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -44,6 +44,7 @@ tiktoken-rs = "0.5"
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
editor = { path = "../editor", features = ["test-support"] }
|
editor = { path = "../editor", features = ["test-support"] }
|
||||||
project = { path = "../project", features = ["test-support"] }
|
project = { path = "../project", features = ["test-support"] }
|
||||||
|
ai = { path = "../ai", features = ["test-support"]}
|
||||||
|
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
|
|
@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ai::{models::LanguageModel, test::FakeLanguageModel};
|
use ai::test::TestCompletionProvider;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::BoxFuture,
|
future::BoxFuture,
|
||||||
stream::{self, BoxStream},
|
stream::{self, BoxStream},
|
||||||
|
@ -617,42 +617,6 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestCompletionProvider {
|
|
||||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestCompletionProvider {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
last_completion_tx: Mutex::new(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send_completion(&self, completion: impl Into<String>) {
|
|
||||||
let mut tx = self.last_completion_tx.lock();
|
|
||||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finish_completion(&self) {
|
|
||||||
self.last_completion_tx.lock().take().unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionProvider for TestCompletionProvider {
|
|
||||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
|
||||||
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
|
||||||
model
|
|
||||||
}
|
|
||||||
fn complete(
|
|
||||||
&self,
|
|
||||||
_prompt: Box<dyn CompletionRequest>,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
let (tx, rx) = mpsc::channel(1);
|
|
||||||
*self.last_completion_tx.lock() = Some(tx);
|
|
||||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rust_lang() -> Language {
|
fn rust_lang() -> Language {
|
||||||
Language::new(
|
Language::new(
|
||||||
LanguageConfig {
|
LanguageConfig {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue