replace OpenAIRequest with more generalized Box<dyn CompletionRequest>

This commit is contained in:
KCaverly 2023-10-22 14:33:19 +02:00
parent 05ae978cb7
commit d813ae8845
6 changed files with 58 additions and 23 deletions

View file

@ -1,6 +1,5 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::CompletionProvider;
use ai::providers::open_ai::OpenAIRequest;
use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@ -96,7 +95,7 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@ -336,6 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
use ai::providers::dummy::DummyCompletionRequest;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
@ -381,7 +381,10 @@ mod tests {
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
" let mut x = 0;\n",
@ -443,7 +446,11 @@ mod tests {
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"t mut x = 0;\n",
@ -505,7 +512,11 @@ mod tests {
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"let mut x = 0;\n",
@ -617,7 +628,7 @@ mod tests {
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);