replace OpenAIRequest with more generalized Box<dyn CompletionRequest>
This commit is contained in:
parent
05ae978cb7
commit
d813ae8845
6 changed files with 58 additions and 23 deletions
|
@ -1,11 +1,13 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures::{future::BoxFuture, stream::BoxStream};
|
use futures::{future::BoxFuture, stream::BoxStream};
|
||||||
|
|
||||||
use crate::providers::open_ai::completion::OpenAIRequest;
|
pub trait CompletionRequest: Send + Sync {
|
||||||
|
fn data(&self) -> serde_json::Result<String>;
|
||||||
|
}
|
||||||
|
|
||||||
pub trait CompletionProvider {
|
pub trait CompletionProvider {
|
||||||
fn complete(
|
fn complete(
|
||||||
&self,
|
&self,
|
||||||
prompt: OpenAIRequest,
|
prompt: Box<dyn CompletionRequest>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
}
|
}
|
||||||
|
|
13
crates/ai/src/providers/dummy.rs
Normal file
13
crates/ai/src/providers/dummy.rs
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
use crate::completion::CompletionRequest;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct DummyCompletionRequest {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest for DummyCompletionRequest {
|
||||||
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
|
serde_json::to_string(self)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1 +1,2 @@
|
||||||
|
pub mod dummy;
|
||||||
pub mod open_ai;
|
pub mod open_ai;
|
||||||
|
|
|
@ -12,7 +12,7 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::completion::CompletionProvider;
|
use crate::completion::{CompletionProvider, CompletionRequest};
|
||||||
|
|
||||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||||
|
|
||||||
|
@ -59,6 +59,12 @@ pub struct OpenAIRequest {
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest for OpenAIRequest {
|
||||||
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
|
serde_json::to_string(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
pub struct ResponseMessage {
|
pub struct ResponseMessage {
|
||||||
pub role: Option<Role>,
|
pub role: Option<Role>,
|
||||||
|
@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
|
||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
api_key: String,
|
api_key: String,
|
||||||
executor: Arc<Background>,
|
executor: Arc<Background>,
|
||||||
mut request: OpenAIRequest,
|
request: Box<dyn CompletionRequest>,
|
||||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||||
request.stream = true;
|
|
||||||
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||||
|
|
||||||
let json_data = serde_json::to_string(&request)?;
|
let json_data = request.data()?;
|
||||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
|
||||||
impl CompletionProvider for OpenAICompletionProvider {
|
impl CompletionProvider for OpenAICompletionProvider {
|
||||||
fn complete(
|
fn complete(
|
||||||
&self,
|
&self,
|
||||||
prompt: OpenAIRequest,
|
prompt: Box<dyn CompletionRequest>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||||
async move {
|
async move {
|
||||||
|
|
|
@ -6,8 +6,11 @@ use crate::{
|
||||||
SavedMessage,
|
SavedMessage,
|
||||||
};
|
};
|
||||||
|
|
||||||
use ai::providers::open_ai::{
|
use ai::{
|
||||||
|
completion::CompletionRequest,
|
||||||
|
providers::open_ai::{
|
||||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use ai::prompts::repository_context::PromptCodeSnippet;
|
use ai::prompts::repository_context::PromptCodeSnippet;
|
||||||
|
@ -745,13 +748,14 @@ impl AssistantPanel {
|
||||||
content: prompt,
|
content: prompt,
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = OpenAIRequest {
|
let request = Box::new(OpenAIRequest {
|
||||||
model: model.full_name().into(),
|
model: model.full_name().into(),
|
||||||
messages,
|
messages,
|
||||||
stream: true,
|
stream: true,
|
||||||
stop: vec!["|END|>".to_string()],
|
stop: vec!["|END|>".to_string()],
|
||||||
temperature,
|
temperature,
|
||||||
};
|
});
|
||||||
|
|
||||||
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
|
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
|
||||||
anyhow::Ok(())
|
anyhow::Ok(())
|
||||||
})
|
})
|
||||||
|
@ -1735,7 +1739,7 @@ impl Conversation {
|
||||||
return Default::default();
|
return Default::default();
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = OpenAIRequest {
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||||
model: self.model.full_name().to_string(),
|
model: self.model.full_name().to_string(),
|
||||||
messages: self
|
messages: self
|
||||||
.messages(cx)
|
.messages(cx)
|
||||||
|
@ -1745,7 +1749,7 @@ impl Conversation {
|
||||||
stream: true,
|
stream: true,
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
};
|
});
|
||||||
|
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
|
@ -2025,13 +2029,13 @@ impl Conversation {
|
||||||
"Summarize the conversation into a short title without punctuation"
|
"Summarize the conversation into a short title without punctuation"
|
||||||
.into(),
|
.into(),
|
||||||
}));
|
}));
|
||||||
let request = OpenAIRequest {
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||||
model: self.model.full_name().to_string(),
|
model: self.model.full_name().to_string(),
|
||||||
messages: messages.collect(),
|
messages: messages.collect(),
|
||||||
stream: true,
|
stream: true,
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
};
|
});
|
||||||
|
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||||
use ai::completion::CompletionProvider;
|
use ai::completion::{CompletionProvider, CompletionRequest};
|
||||||
use ai::providers::open_ai::OpenAIRequest;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||||
|
@ -96,7 +95,7 @@ impl Codegen {
|
||||||
self.error.as_ref()
|
self.error.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
|
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
|
||||||
let range = self.range();
|
let range = self.range();
|
||||||
let snapshot = self.snapshot.clone();
|
let snapshot = self.snapshot.clone();
|
||||||
let selected_text = snapshot
|
let selected_text = snapshot
|
||||||
|
@ -336,6 +335,7 @@ fn strip_markdown_codeblock(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use ai::providers::dummy::DummyCompletionRequest;
|
||||||
use futures::{
|
use futures::{
|
||||||
future::BoxFuture,
|
future::BoxFuture,
|
||||||
stream::{self, BoxStream},
|
stream::{self, BoxStream},
|
||||||
|
@ -381,7 +381,10 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
let request = Box::new(DummyCompletionRequest {
|
||||||
|
name: "test".to_string(),
|
||||||
|
});
|
||||||
|
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||||
|
|
||||||
let mut new_text = concat!(
|
let mut new_text = concat!(
|
||||||
" let mut x = 0;\n",
|
" let mut x = 0;\n",
|
||||||
|
@ -443,7 +446,11 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
|
||||||
|
let request = Box::new(DummyCompletionRequest {
|
||||||
|
name: "test".to_string(),
|
||||||
|
});
|
||||||
|
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||||
|
|
||||||
let mut new_text = concat!(
|
let mut new_text = concat!(
|
||||||
"t mut x = 0;\n",
|
"t mut x = 0;\n",
|
||||||
|
@ -505,7 +512,11 @@ mod tests {
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
|
||||||
|
let request = Box::new(DummyCompletionRequest {
|
||||||
|
name: "test".to_string(),
|
||||||
|
});
|
||||||
|
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||||
|
|
||||||
let mut new_text = concat!(
|
let mut new_text = concat!(
|
||||||
"let mut x = 0;\n",
|
"let mut x = 0;\n",
|
||||||
|
@ -617,7 +628,7 @@ mod tests {
|
||||||
impl CompletionProvider for TestCompletionProvider {
|
impl CompletionProvider for TestCompletionProvider {
|
||||||
fn complete(
|
fn complete(
|
||||||
&self,
|
&self,
|
||||||
_prompt: OpenAIRequest,
|
_prompt: Box<dyn CompletionRequest>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let (tx, rx) = mpsc::channel(1);
|
let (tx, rx) = mpsc::channel(1);
|
||||||
*self.last_completion_tx.lock() = Some(tx);
|
*self.last_completion_tx.lock() = Some(tx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue