Merge branch 'main' of github.com:zed-industries/zed into zed2

Commit 204aba07f6: 32 changed files with 1212 additions and 970 deletions

@@ -8,6 +8,9 @@ publish = false
 path = "src/ai.rs"
 doctest = false
 
+[features]
+test-support = []
+
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }

@@ -1,4 +1,8 @@
+pub mod auth;
 pub mod completion;
 pub mod embedding;
 pub mod models;
-pub mod templates;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;

crates/ai/src/auth.rs (new file)
@@ -0,0 +1,15 @@
+use gpui::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+    Credentials { api_key: String },
+    NoCredentials,
+    NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+    fn has_credentials(&self) -> bool;
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
+    fn delete_credentials(&self, cx: &AppContext);
+}
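
Note: any type implementing CredentialProvider can be plugged into the providers below. As a minimal sketch (not part of this commit; the type name is hypothetical), a provider that needs no authentication could look like:

    // Hypothetical example: a provider that requires no auth at all.
    struct NullCredentialProvider;

    impl CredentialProvider for NullCredentialProvider {
        fn has_credentials(&self) -> bool {
            true
        }
        fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
            ProviderCredential::NotNeeded
        }
        fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
        fn delete_credentials(&self, _cx: &AppContext) {}
    }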

@@ -1,214 +1,23 @@
-use anyhow::{anyhow, Result};
-use futures::{
-    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
-    Stream, StreamExt,
-};
-use gpui::executor::Background;
-use isahc::{http::StatusCode, Request, RequestExt};
-use serde::{Deserialize, Serialize};
-use std::{
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
 
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+use crate::{auth::CredentialProvider, models::LanguageModel};
 
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
+pub trait CompletionRequest: Send + Sync {
+    fn data(&self) -> serde_json::Result<String>;
 }
 
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAIError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAIResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
-
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    fn box_clone(&self) -> Box<dyn CompletionProvider>;
 }
 
-pub struct OpenAICompletionProvider {
-    api_key: String,
-    executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
-    }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
-    fn complete(
-        &self,
-        prompt: OpenAIRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
+impl Clone for Box<dyn CompletionProvider> {
+    fn clone(&self) -> Box<dyn CompletionProvider> {
+        self.box_clone()
     }
 }
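
With completion.rs reduced to the two traits above, a request type only has to know how to serialize itself. A minimal sketch of a custom request (hypothetical, not in this commit):

    // Hypothetical example: any serde-serializable type can implement CompletionRequest.
    #[derive(serde::Serialize)]
    struct SimpleRequest {
        model: String,
        prompt: String,
    }

    impl CompletionRequest for SimpleRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }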

@@ -1,32 +1,13 @@
-use anyhow::{anyhow, Result};
+use std::time::Instant;
+
+use anyhow::Result;
 use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::{serde_json, AppContext};
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
 use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
 
-use crate::completion::OPENAI_API_URL;
-
-lazy_static! {
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
 
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -87,301 +68,14 @@ impl Embedding {
     }
 }
 
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
-    pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
 #[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        api_key: Option<String>,
-    ) -> Result<Vec<Embedding>>;
+pub trait EmbeddingProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
-    fn truncate(&self, span: &str) -> (String, usize);
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
-
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Dummy API KEY".to_string())
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let token_count = tokens.len();
-        let output = if token_count > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            new_input.ok().unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        OpenAIEmbeddings {
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post("https://api.openai.com/v1/embeddings")
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
-        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        }
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            OPENAI_BPE_TOKENIZER
-                .decode(tokens.clone())
-                .ok()
-                .unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let Some(api_key) = api_key else {
-            return Err(anyhow!("no open ai key provided"));
-        };
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    &api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
-                        "openai embedding completed. tokens: {:?}",
-                        response.usage.total_tokens
-                    );
-
-                    // If we complete a request successfully that was previously rate_limited
-                    // resolve the rate limit
-                    if rate_limiting {
-                        self.resolve_rate_limit()
-                    }
-
-                    return Ok(response
-                        .data
-                        .into_iter()
-                        .map(|embedding| Embedding::from(embedding.embedding))
-                        .collect());
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limiting = true;
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-
-                    let delay_duration = {
-                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                        if let Some(time_to_reset) =
-                            response.headers().get("x-ratelimit-reset-tokens")
-                        {
-                            if let Ok(time_str) = time_to_reset.to_str() {
-                                parse(time_str).unwrap_or(delay)
-                            } else {
-                                delay
-                            }
-                        } else {
-                            delay
-                        }
-                    };
-
-                    // If we've previously rate limited, increment the duration but not the count
-                    let reset_time = Instant::now().add(delay_duration);
-                    self.update_reset_time(reset_time);
-
-                    log::trace!(
-                        "openai rate limiting: waiting {:?} until lifted",
-                        &delay_duration
-                    );
-
-                    self.executor.timer(delay_duration).await;
-                }
-                _ => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    return Err(anyhow!(
-                        "open ai bad request: {:?} {:?}",
-                        &response.status(),
-                        body
-                    ));
-                }
-            }
-        }
-        Err(anyhow!("openai max retries"))
-    }
-}
 
 #[cfg(test)]
 mod tests {
     use super::*;
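
Since embed_batch no longer takes an api_key argument (credentials now flow through the CredentialProvider supertrait), call sites simplify accordingly. A hedged sketch of a caller, with hypothetical names:

    // Hypothetical call site for the new signature.
    async fn embed_all(provider: &dyn EmbeddingProvider, spans: Vec<String>) -> anyhow::Result<()> {
        for batch in spans.chunks(8) {
            let embeddings = provider.embed_batch(batch.to_vec()).await?;
            println!("embedded {} spans", embeddings.len());
        }
        Ok(())
    }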

@@ -1,66 +1,16 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
+pub enum TruncationDirection {
+    Start,
+    End,
+}
 
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
-
-pub struct OpenAILanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
-            name: model_name.to_string(),
-            bpe,
-        }
-    }
-}
-
-impl LanguageModel for OpenAILanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}
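
The merged truncate signature replaces the old truncate/truncate_start pair. A hedged usage sketch (the helper name is hypothetical):

    // Hypothetical helper: clip a prompt to the model's capacity, keeping its tail.
    fn clip_prompt(model: &dyn LanguageModel, prompt: &str) -> anyhow::Result<String> {
        let budget = model.capacity()?;
        if model.count_tokens(prompt)? > budget {
            // TruncationDirection::Start drops tokens from the front.
            model.truncate(prompt, budget, TruncationDirection::Start)
        } else {
            Ok(prompt.to_string())
        }
    }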

@@ -6,7 +6,7 @@ use language::BufferSnapshot;
 use util::ResultExt;
 
 use crate::models::LanguageModel;
-use crate::templates::repository_context::PromptCodeSnippet;
+use crate::prompts::repository_context::PromptCodeSnippet;
 
 pub(crate) enum PromptFileType {
     Text,
@@ -125,6 +125,9 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+    use crate::test::FakeLanguageModel;
+
     use super::*;
 
     #[test]
@@ -141,7 +144,11 @@ pub(crate) mod tests {
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::End,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -162,7 +169,11 @@ pub(crate) mod tests {
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::End,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -171,38 +182,7 @@ pub(crate) mod tests {
         }
     }
 
-    #[derive(Clone)]
-    struct DummyLanguageModel {
-        capacity: usize,
-    }
-
-    impl LanguageModel for DummyLanguageModel {
-        fn name(&self) -> String {
-            "dummy".to_string()
-        }
-        fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-            anyhow::Ok(content.chars().collect::<Vec<char>>().len())
-        }
-        fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[..length]
-                    .into_iter()
-                    .collect::<String>(),
-            )
-        }
-        fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[length..]
-                    .into_iter()
-                    .collect::<String>(),
-            )
-        }
-        fn capacity(&self) -> anyhow::Result<usize> {
-            anyhow::Ok(self.capacity)
-        }
-    }
-
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -238,7 +218,7 @@ pub(crate) mod tests {
 
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -275,7 +255,7 @@ pub(crate) mod tests {
         // Testing with Truncation Off
        // Should ignore capacity and return all prompts
         let capacity = 20;
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -311,7 +291,7 @@ pub(crate) mod tests {
         // Change Ordering of Prompts Based on Priority
         let capacity = 120;
         let reserved_tokens = 10;
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,

@@ -3,8 +3,9 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
-use crate::templates::base::PromptArguments;
-use crate::templates::base::PromptTemplate;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
 use std::fmt::Write;
 use std::ops::Range;
 use std::sync::Arc;
@@ -70,8 +71,9 @@ fn retrieve_context(
         };
 
         let truncated_start_window =
-            model.truncate_start(&start_window, start_goal_tokens)?;
-        let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+            model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+        let truncated_end_window =
+            model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
         writeln!(
             prompt,
             "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
     if let Some(max_token_count) = max_token_count {
         if model.count_tokens(&prompt)? > max_token_count {
             truncated = true;
-            prompt = model.truncate(&prompt, max_token_count)?;
+            prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
         }
     }
 }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args
+                .model
+                .truncate(&prompt, max_tokens, TruncationDirection::End)?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;

@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use anyhow::anyhow;
 use std::fmt::Write;
 
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;

@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use std::fmt::Write;
 
 pub struct EngineerPreamble {}

@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptTemplate};
 use std::fmt::Write;
 use std::{ops::Range, path::PathBuf};
 

crates/ai/src/providers/mod.rs (new file)
@@ -0,0 +1 @@
+pub mod open_ai;

crates/ai/src/providers/open_ai/completion.rs (new file)
@@ -0,0 +1,298 @@
+use anyhow::{anyhow, Result};
+use futures::{
+    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+    Stream, StreamExt,
+};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+    env,
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+    auth::{CredentialProvider, ProviderCredential},
+    completion::{CompletionProvider, CompletionRequest},
+    models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "User"),
+            Role::Assistant => write!(f, "Assistant"),
+            Role::System => write!(f, "System"),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+    pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChatChoiceDelta>,
+    pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+    credential: ProviderCredential,
+    executor: Arc<Background>,
+    request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    let api_key = match credential {
+        ProviderCredential::Credentials { api_key } => api_key,
+        _ => {
+            return Err(anyhow!("no credentials provider for completion"));
+        }
+    };
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = request.data()?;
+    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
+        let model = OpenAILanguageModel::load(model_name);
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+        Self {
+            model,
+            credential,
+            executor,
+        }
+    }
+}
+
+impl CredentialProvider for OpenAICompletionProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
+    fn complete(
+        &self,
+        prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+        // This means that the model is determined by the CompletionRequest and not
+        // the CompletionProvider, which is currently model-based, due to the
+        // language model. At some point in the future we should rectify this.
+        let credential = self.credential.read().clone();
+        let request = stream_completion(credential, self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
+}
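
The box_clone method above is what makes Box<dyn CompletionProvider> cloneable via the Clone impl added in crates/ai/src/completion.rs earlier in this diff. A hedged construction sketch (the model name and executor source are assumptions):

    // Hypothetical call site; `executor` comes from the running gpui app.
    fn make_provider(executor: Arc<Background>) -> Box<dyn CompletionProvider> {
        let provider = OpenAICompletionProvider::new("gpt-3.5-turbo", executor);
        let boxed: Box<dyn CompletionProvider> = Box::new(provider);
        let _copy = boxed.clone(); // dispatches through box_clone
        boxed
    }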

crates/ai/src/providers/open_ai/embedding.rs (new file)
@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+        OpenAIEmbeddingProvider {
+            model,
+            credential,
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn get_api_key(&self) -> Result<String> {
+        match self.credential.read().clone() {
+            ProviderCredential::Credentials { api_key } => Ok(api_key),
+            _ => Err(anyhow!("api credentials not provided")),
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans.clone(),
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+impl CredentialProvider for OpenAIEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = self.get_api_key()?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    &api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}

crates/ai/src/providers/open_ai/mod.rs (new file)
@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

crates/ai/src/providers/open_ai/model.rs (new file)
@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
11
crates/ai/src/providers/open_ai/new.rs
Normal file
11
crates/ai/src/providers/open_ai/new.rs
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
pub trait LanguageModel {
|
||||||
|
fn name(&self) -> String;
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String>;
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize>;
|
||||||
|
}
|
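A rough usage sketch for the model above (not part of the diff); the model name and token budget here are examples only:

    use ai::models::{LanguageModel, TruncationDirection};
    use ai::providers::open_ai::OpenAILanguageModel;

    fn demo() -> anyhow::Result<()> {
        // Loads the tiktoken BPE matching the model name, if one exists.
        let model = OpenAILanguageModel::load("gpt-4");
        let prompt = "fn main() { println!(\"hello\"); }";
        let tokens = model.count_tokens(prompt)?;
        assert!(tokens <= model.capacity()?);
        // Keep the first 8 tokens, dropping the rest from the end.
        let truncated = model.truncate(prompt, 8, TruncationDirection::End)?;
        println!("{truncated}");
        Ok(())
    }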
191  crates/ai/src/test.rs  Normal file
@@ -0,0 +1,191 @@
use std::{
    sync::atomic::{self, AtomicUsize, Ordering},
    time::Instant,
};

use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    embedding::{Embedding, EmbeddingProvider},
    models::{LanguageModel, TruncationDirection},
};

#[derive(Clone)]
pub struct FakeLanguageModel {
    pub capacity: usize,
}

impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "dummy".to_string()
    }

    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
    }

    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        println!("TRYING TO TRUNCATE: {:?}", length.clone());

        if length > self.count_tokens(content)? {
            println!("NOT TRUNCATING");
            return anyhow::Ok(content.to_string());
        }

        anyhow::Ok(match direction {
            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                .into_iter()
                .collect::<String>(),
            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                .into_iter()
                .collect::<String>(),
        })
    }

    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(self.capacity)
    }
}

pub struct FakeEmbeddingProvider {
    pub embedding_count: AtomicUsize,
}

impl Clone for FakeEmbeddingProvider {
    fn clone(&self) -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
        }
    }
}

impl Default for FakeEmbeddingProvider {
    fn default() -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::default(),
        }
    }
}

impl FakeEmbeddingProvider {
    pub fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }

    pub fn embed_sync(&self, span: &str) -> Embedding {
        let mut result = vec![1.0; 26];
        for letter in span.chars() {
            let letter = letter.to_ascii_lowercase();
            if letter as u32 >= 'a' as u32 {
                let ix = (letter as u32) - ('a' as u32);
                if ix < 26 {
                    result[ix as usize] += 1.0;
                }
            }
        }

        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut result {
            *x /= norm;
        }

        result.into()
    }
}

impl CredentialProvider for FakeEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        true
    }

    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }

    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}

    fn delete_credentials(&self, _cx: &AppContext) {}
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        Box::new(FakeLanguageModel { capacity: 1000 })
    }

    fn max_tokens_per_batch(&self) -> usize {
        1000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }

    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);

        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
    }
}

pub struct FakeCompletionProvider {
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}

impl Clone for FakeCompletionProvider {
    fn clone(&self) -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }
}

impl FakeCompletionProvider {
    pub fn new() -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }

    pub fn send_completion(&self, completion: impl Into<String>) {
        let mut tx = self.last_completion_tx.lock();
        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
    }

    pub fn finish_completion(&self) {
        self.last_completion_tx.lock().take().unwrap();
    }
}

impl CredentialProvider for FakeCompletionProvider {
    fn has_credentials(&self) -> bool {
        true
    }

    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }

    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}

    fn delete_credentials(&self, _cx: &AppContext) {}
}

impl CompletionProvider for FakeCompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
        model
    }

    fn complete(
        &self,
        _prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
        let (tx, rx) = mpsc::channel(1);
        *self.last_completion_tx.lock() = Some(tx);
        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
    }

    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}
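The fake provider above is driven by hand in tests: complete() installs a channel sender, send_completion pushes chunks through it, and finish_completion drops the sender to end the stream. A sketch of that round trip, assuming any CompletionRequest value (the fake ignores its contents):

    use ai::completion::{CompletionProvider, CompletionRequest};
    use ai::test::FakeCompletionProvider;
    use futures::StreamExt;

    async fn demo(request: Box<dyn CompletionRequest>) -> anyhow::Result<String> {
        let provider = FakeCompletionProvider::new();
        let stream = provider.complete(request); // installs the sender synchronously
        provider.send_completion("Hello, world!");
        provider.finish_completion(); // drop the sender so the stream terminates
        let mut chunks = stream.await?;
        let mut out = String::new();
        while let Some(chunk) = chunks.next().await {
            out.push_str(&chunk?);
        }
        Ok(out)
    }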
@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
editor = { path = "../editor", features = ["test-support"] }
|
editor = { path = "../editor", features = ["test-support"] }
|
||||||
project = { path = "../project", features = ["test-support"] }
|
project = { path = "../project", features = ["test-support"] }
|
||||||
|
ai = { path = "../ai", features = ["test-support"]}
|
||||||
|
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
|
|
@@ -4,7 +4,7 @@ mod codegen;
 mod prompts;
 mod streaming_diff;

-use ai::completion::Role;
+use ai::providers::open_ai::Role;
 use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
 use assistant_settings::OpenAIModel;
@@ -5,12 +5,14 @@ use crate::{
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
 use ai::{
-    completion::{
-        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
-    },
-    templates::repository_context::PromptCodeSnippet,
+    auth::ProviderCredential,
+    completion::{CompletionProvider, CompletionRequest},
+    providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
 };
+use ai::prompts::repository_context::PromptCodeSnippet;
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};

@@ -43,8 +45,8 @@ use search::BufferSearchBar;
 use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::SettingsStore;
 use std::{
-    cell::{Cell, RefCell},
-    cmp, env,
+    cell::Cell,
+    cmp,
     fmt::Write,
     iter,
     ops::Range,

@@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
     cx.capture_action(ConversationEditor::copy);
     cx.add_action(ConversationEditor::split);
     cx.capture_action(ConversationEditor::cycle_message_role);
-    cx.add_action(AssistantPanel::save_api_key);
-    cx.add_action(AssistantPanel::reset_api_key);
+    cx.add_action(AssistantPanel::save_credentials);
+    cx.add_action(AssistantPanel::reset_credentials);
     cx.add_action(AssistantPanel::toggle_zoom);
     cx.add_action(AssistantPanel::deploy);
     cx.add_action(AssistantPanel::select_next_match);
@@ -140,9 +142,8 @@ pub struct AssistantPanel {
     zoomed: bool,
     has_focus: bool,
     toolbar: ViewHandle<Toolbar>,
-    api_key: Rc<RefCell<Option<String>>>,
+    completion_provider: Box<dyn CompletionProvider>,
     api_key_editor: Option<ViewHandle<Editor>>,
-    has_read_credentials: bool,
     languages: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
     subscriptions: Vec<Subscription>,

@@ -202,6 +203,11 @@ impl AssistantPanel {
         });

         let semantic_index = SemanticIndex::global(cx);
+        // Defaulting currently to GPT4, allow for this to be set via config.
+        let completion_provider = Box::new(OpenAICompletionProvider::new(
+            "gpt-4",
+            cx.background().clone(),
+        ));

         let mut this = Self {
             workspace: workspace_handle,

@@ -213,9 +219,8 @@ impl AssistantPanel {
             zoomed: false,
             has_focus: false,
             toolbar,
-            api_key: Rc::new(RefCell::new(None)),
+            completion_provider,
             api_key_editor: None,
-            has_read_credentials: false,
             languages: workspace.app_state().languages.clone(),
             fs: workspace.app_state().fs.clone(),
             width: None,

@@ -254,10 +259,7 @@ impl AssistantPanel {
         cx: &mut ViewContext<Workspace>,
     ) {
         let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
-            if this
-                .update(cx, |assistant, cx| assistant.load_api_key(cx))
-                .is_some()
-            {
+            if this.update(cx, |assistant, _| assistant.has_credentials()) {
                 this
             } else {
                 workspace.focus_panel::<AssistantPanel>(cx);
@@ -289,12 +291,6 @@ impl AssistantPanel {
         cx: &mut ViewContext<Self>,
         project: &ModelHandle<Project>,
     ) {
-        let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
-            api_key
-        } else {
-            return;
-        };
-
         let selection = editor.read(cx).selections.newest_anchor().clone();
         if selection.start.excerpt_id != selection.end.excerpt_id {
             return;

@@ -325,10 +321,13 @@ impl AssistantPanel {

         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
-            api_key,
+            "gpt-4",
             cx.background().clone(),
         ));

+        // Retrieve Credentials Authenticates the Provider
+        // provider.retrieve_credentials(cx);
+
         let codegen = cx.add_model(|cx| {
             Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
         });

@@ -745,13 +744,14 @@ impl AssistantPanel {
             content: prompt,
         });

-        let request = OpenAIRequest {
+        let request = Box::new(OpenAIRequest {
             model: model.full_name().into(),
             messages,
             stream: true,
             stop: vec!["|END|>".to_string()],
             temperature,
-        };
+        });

         codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
         anyhow::Ok(())
     })
@@ -811,7 +811,7 @@ impl AssistantPanel {
     fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
         let editor = cx.add_view(|cx| {
             ConversationEditor::new(
-                self.api_key.clone(),
+                self.completion_provider.clone(),
                 self.languages.clone(),
                 self.fs.clone(),
                 self.workspace.clone(),

@@ -870,17 +870,19 @@ impl AssistantPanel {
         }
     }

-    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+    fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
         if let Some(api_key) = self
             .api_key_editor
             .as_ref()
             .map(|editor| editor.read(cx).text(cx))
         {
             if !api_key.is_empty() {
-                cx.platform()
-                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
-                    .log_err();
-                *self.api_key.borrow_mut() = Some(api_key);
+                let credential = ProviderCredential::Credentials {
+                    api_key: api_key.clone(),
+                };
+                self.completion_provider.save_credentials(cx, credential);
+
                 self.api_key_editor.take();
                 cx.focus_self();
                 cx.notify();

@@ -890,9 +892,8 @@ impl AssistantPanel {
         }
     }

-    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
-        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
-        self.api_key.take();
+    fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+        self.completion_provider.delete_credentials(cx);
         self.api_key_editor = Some(build_api_key_editor(cx));
         cx.focus_self();
         cx.notify();
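The two handlers above reduce keychain handling to the CredentialProvider trait from crates/ai/src/auth.rs. The same round trip in isolation (a sketch; the helper names here are hypothetical):

    use ai::auth::{CredentialProvider, ProviderCredential};
    use gpui::AppContext;

    // The provider, not the panel, now owns credential storage.
    fn store_key(provider: &dyn CredentialProvider, cx: &AppContext, api_key: String) {
        provider.save_credentials(cx, ProviderCredential::Credentials { api_key });
    }

    fn clear_key(provider: &dyn CredentialProvider, cx: &AppContext) {
        provider.delete_credentials(cx);
    }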
@@ -1151,13 +1152,12 @@ impl AssistantPanel {

         let fs = self.fs.clone();
         let workspace = self.workspace.clone();
-        let api_key = self.api_key.clone();
         let languages = self.languages.clone();
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
             let conversation = cx.add_model(|cx| {
-                Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+                Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
             });
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened

@@ -1181,30 +1181,12 @@ impl AssistantPanel {
             .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
     }

-    fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
-        if self.api_key.borrow().is_none() && !self.has_read_credentials {
-            self.has_read_credentials = true;
-            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                Some(api_key)
-            } else if let Some((_, api_key)) = cx
-                .platform()
-                .read_credentials(OPENAI_API_URL)
-                .log_err()
-                .flatten()
-            {
-                String::from_utf8(api_key).log_err()
-            } else {
-                None
-            };
-            if let Some(api_key) = api_key {
-                *self.api_key.borrow_mut() = Some(api_key);
-            } else if self.api_key_editor.is_none() {
-                self.api_key_editor = Some(build_api_key_editor(cx));
-                cx.notify();
-            }
-        }
-
-        self.api_key.borrow().clone()
+    fn has_credentials(&mut self) -> bool {
+        self.completion_provider.has_credentials()
+    }
+
+    fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
+        self.completion_provider.retrieve_credentials(cx);
     }
 }

@@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {

     fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
         if active {
-            self.load_api_key(cx);
+            self.load_credentials(cx);

             if self.editors.is_empty() {
                 self.new_conversation(cx);
@@ -1454,10 +1436,10 @@ struct Conversation {
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
-    api_key: Rc<RefCell<Option<String>>>,
     pending_save: Task<Result<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
+    completion_provider: Box<dyn CompletionProvider>,
 }

 impl Entity for Conversation {

@@ -1466,9 +1448,9 @@ impl Entity for Conversation {

 impl Conversation {
     fn new(
-        api_key: Rc<RefCell<Option<String>>>,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
+        completion_provider: Box<dyn CompletionProvider>,
     ) -> Self {
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {

@@ -1507,8 +1489,8 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
-            api_key,
             buffer,
+            completion_provider,
         };
         let message = MessageAnchor {
             id: MessageId(post_inc(&mut this.next_message_id.0)),

@@ -1554,7 +1536,6 @@ impl Conversation {
     fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
-        api_key: Rc<RefCell<Option<String>>>,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {

@@ -1563,6 +1544,10 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let completion_provider: Box<dyn CompletionProvider> = Box::new(
+            OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
+        );
+        completion_provider.retrieve_credentials(cx);
         let markdown = language_registry.language_for_name("Markdown");
         let mut message_anchors = Vec::new();
         let mut next_message_id = MessageId(0);

@@ -1609,8 +1594,8 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: Some(path),
-            api_key,
             buffer,
+            completion_provider,
         };
         this.count_remaining_tokens(cx);
         this
@@ -1731,11 +1716,11 @@ impl Conversation {
         }

         if should_assist {
-            let Some(api_key) = self.api_key.borrow().clone() else {
+            if !self.completion_provider.has_credentials() {
                 return Default::default();
-            };
+            }

-            let request = OpenAIRequest {
+            let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
                 model: self.model.full_name().to_string(),
                 messages: self
                     .messages(cx)

@@ -1745,9 +1730,9 @@ impl Conversation {
                 stream: true,
                 stop: vec![],
                 temperature: 1.0,
-            };
+            });

-            let stream = stream_completion(api_key, cx.background().clone(), request);
+            let stream = self.completion_provider.complete(request);
             let assistant_message = self
                 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
                 .unwrap();

@@ -1765,33 +1750,28 @@ impl Conversation {
                     let mut messages = stream.await?;

                     while let Some(message) = messages.next().await {
-                        let mut message = message?;
-                        if let Some(choice) = message.choices.pop() {
-                            this.upgrade(&cx)
-                                .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                .update(&mut cx, |this, cx| {
-                                    let text: Arc<str> = choice.delta.content?.into();
-                                    let message_ix =
-                                        this.message_anchors.iter().position(|message| {
-                                            message.id == assistant_message_id
-                                        })?;
-                                    this.buffer.update(cx, |buffer, cx| {
-                                        let offset = this.message_anchors[message_ix + 1..]
-                                            .iter()
-                                            .find(|message| message.start.is_valid(buffer))
-                                            .map_or(buffer.len(), |message| {
-                                                message
-                                                    .start
-                                                    .to_offset(buffer)
-                                                    .saturating_sub(1)
-                                            });
-                                        buffer.edit([(offset..offset, text)], None, cx);
-                                    });
-                                    cx.emit(ConversationEvent::StreamedCompletion);
-
-                                    Some(())
-                                });
-                        }
+                        let text = message?;
+
+                        this.upgrade(&cx)
+                            .ok_or_else(|| anyhow!("conversation was dropped"))?
+                            .update(&mut cx, |this, cx| {
+                                let message_ix = this
+                                    .message_anchors
+                                    .iter()
+                                    .position(|message| message.id == assistant_message_id)?;
+                                this.buffer.update(cx, |buffer, cx| {
+                                    let offset = this.message_anchors[message_ix + 1..]
+                                        .iter()
+                                        .find(|message| message.start.is_valid(buffer))
+                                        .map_or(buffer.len(), |message| {
+                                            message.start.to_offset(buffer).saturating_sub(1)
+                                        });
+                                    buffer.edit([(offset..offset, text)], None, cx);
+                                });
+                                cx.emit(ConversationEvent::StreamedCompletion);
+
+                                Some(())
+                            });
                         smol::future::yield_now().await;
                     }
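With complete() the stream now yields plain anyhow::Result<String> chunks, so the OpenAI-specific choices/delta unwrapping disappears from the consumer. The consumer shape in isolation (a sketch, not the panel's exact code):

    use ai::completion::{CompletionProvider, CompletionRequest};
    use futures::StreamExt;

    async fn collect(
        provider: &dyn CompletionProvider,
        request: Box<dyn CompletionRequest>,
    ) -> anyhow::Result<String> {
        let mut chunks = provider.complete(request).await?;
        let mut text = String::new();
        while let Some(chunk) = chunks.next().await {
            text.push_str(&chunk?); // each item is already plain text
        }
        Ok(text)
    }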
@@ -2013,57 +1993,54 @@ impl Conversation {

     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
         if self.message_anchors.len() >= 2 && self.summary.is_none() {
-            let api_key = self.api_key.borrow().clone();
-            if let Some(api_key) = api_key {
-                let messages = self
-                    .messages(cx)
-                    .take(2)
-                    .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
-                    .chain(Some(RequestMessage {
-                        role: Role::User,
-                        content:
-                            "Summarize the conversation into a short title without punctuation"
-                                .into(),
-                    }));
-                let request = OpenAIRequest {
-                    model: self.model.full_name().to_string(),
-                    messages: messages.collect(),
-                    stream: true,
-                    stop: vec![],
-                    temperature: 1.0,
-                };
-
-                let stream = stream_completion(api_key, cx.background().clone(), request);
-                self.pending_summary = cx.spawn(|this, mut cx| {
-                    async move {
-                        let mut messages = stream.await?;
-
-                        while let Some(message) = messages.next().await {
-                            let mut message = message?;
-                            if let Some(choice) = message.choices.pop() {
-                                let text = choice.delta.content.unwrap_or_default();
-                                this.update(&mut cx, |this, cx| {
-                                    this.summary
-                                        .get_or_insert(Default::default())
-                                        .text
-                                        .push_str(&text);
-                                    cx.emit(ConversationEvent::SummaryChanged);
-                                });
-                            }
-                        }
-
-                        this.update(&mut cx, |this, cx| {
-                            if let Some(summary) = this.summary.as_mut() {
-                                summary.done = true;
-                                cx.emit(ConversationEvent::SummaryChanged);
-                            }
-                        });
-
-                        anyhow::Ok(())
-                    }
-                    .log_err()
-                });
-            }
+            if !self.completion_provider.has_credentials() {
+                return;
+            }
+
+            let messages = self
+                .messages(cx)
+                .take(2)
+                .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+                .chain(Some(RequestMessage {
+                    role: Role::User,
+                    content: "Summarize the conversation into a short title without punctuation"
+                        .into(),
+                }));
+            let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+                model: self.model.full_name().to_string(),
+                messages: messages.collect(),
+                stream: true,
+                stop: vec![],
+                temperature: 1.0,
+            });
+
+            let stream = self.completion_provider.complete(request);
+            self.pending_summary = cx.spawn(|this, mut cx| {
+                async move {
+                    let mut messages = stream.await?;
+
+                    while let Some(message) = messages.next().await {
+                        let text = message?;
+                        this.update(&mut cx, |this, cx| {
+                            this.summary
+                                .get_or_insert(Default::default())
+                                .text
+                                .push_str(&text);
+                            cx.emit(ConversationEvent::SummaryChanged);
+                        });
+                    }
+
+                    this.update(&mut cx, |this, cx| {
+                        if let Some(summary) = this.summary.as_mut() {
+                            summary.done = true;
+                            cx.emit(ConversationEvent::SummaryChanged);
+                        }
+                    });
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            });
         }
     }
@@ -2224,13 +2201,14 @@ struct ConversationEditor {

 impl ConversationEditor {
     fn new(
-        api_key: Rc<RefCell<Option<String>>>,
+        completion_provider: Box<dyn CompletionProvider>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         workspace: WeakViewHandle<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+        let conversation =
+            cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
         Self::for_conversation(conversation, fs, workspace, cx)
     }

@@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
 mod tests {
     use super::*;
     use crate::MessageId;
+    use ai::test::FakeCompletionProvider;
     use gpui::AppContext;

     #[gpui::test]

@@ -3426,7 +3405,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();

         let message_1 = conversation.read(cx).message_anchors[0].clone();

@@ -3554,7 +3535,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();

         let message_1 = conversation.read(cx).message_anchors[0].clone();

@@ -3650,7 +3633,8 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();

         let message_1 = conversation.read(cx).message_anchors[0].clone();

@@ -3732,8 +3716,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
+        let completion_provider = Box::new(FakeCompletionProvider::new());
         let conversation =
-            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+            cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
         let message_0 = conversation.read(cx).message_anchors[0].id;
         let message_1 = conversation.update(cx, |conversation, cx| {

@@ -3770,7 +3755,6 @@ mod tests {
             Conversation::deserialize(
                 conversation.read(cx).serialize(cx),
                 Default::default(),
-                Default::default(),
                 registry.clone(),
                 cx,
             )
@@ -1,5 +1,5 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use ai::completion::{CompletionProvider, CompletionRequest};
 use anyhow::Result;
 use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};

@@ -96,7 +96,7 @@ impl Codegen {
         self.error.as_ref()
     }

-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot

@@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::{
-        future::BoxFuture,
-        stream::{self, BoxStream},
-    };
+    use ai::test::FakeCompletionProvider;
+    use futures::stream::{self};
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
-    use parking_lot::Mutex;
     use rand::prelude::*;
+    use serde::Serialize;
     use settings::SettingsStore;
-    use smol::future::FutureExt;
+
+    #[derive(Serialize)]
+    pub struct DummyCompletionRequest {
+        pub name: String,
+    }
+
+    impl CompletionRequest for DummyCompletionRequest {
+        fn data(&self) -> serde_json::Result<String> {
+            serde_json::to_string(self)
+        }
+    }

     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(

@@ -372,7 +380,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),

@@ -381,7 +389,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

         let mut new_text = concat!(
             "    let mut x = 0;\n",

@@ -434,7 +446,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),

@@ -443,7 +455,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

         let mut new_text = concat!(
             "t mut x = 0;\n",

@@ -496,7 +512,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),

@@ -505,7 +521,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

         let mut new_text = concat!(
             "let mut x = 0;\n",

@@ -593,38 +613,6 @@ mod tests {
         }
     }

-    struct TestCompletionProvider {
-        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
-    }
-
-    impl TestCompletionProvider {
-        fn new() -> Self {
-            Self {
-                last_completion_tx: Mutex::new(None),
-            }
-        }
-
-        fn send_completion(&self, completion: impl Into<String>) {
-            let mut tx = self.last_completion_tx.lock();
-            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
-        }
-
-        fn finish_completion(&self) {
-            self.last_completion_tx.lock().take().unwrap();
-        }
-    }
-
-    impl CompletionProvider for TestCompletionProvider {
-        fn complete(
-            &self,
-            _prompt: OpenAIRequest,
-        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-            let (tx, rx) = mpsc::channel(1);
-            *self.last_completion_tx.lock() = Some(tx);
-            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
-        }
-    }
-
     fn rust_lang() -> Language {
         Language::new(
             LanguageConfig {
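The updated codegen tests all follow the same shape; as a fragment (it assumes the cx, buffer, and codegen setup from the tests above):

    // Start codegen with a dummy request, then stream fake chunks by hand.
    let provider = Arc::new(FakeCompletionProvider::new());
    let request = Box::new(DummyCompletionRequest {
        name: "test".to_string(),
    });
    codegen.update(cx, |codegen, cx| codegen.start(request, cx));
    provider.send_completion("let mut x = 0;");
    provider.finish_completion();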
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::models::LanguageModel;
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;
@@ -967,7 +967,6 @@ impl CompletionsMenu {
             self.selected_item -= 1;
         } else {
             self.selected_item = self.matches.len() - 1;
-            self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         }
         self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         self.attempt_resolve_selected_completion_documentation(project, cx);

@@ -1538,7 +1537,6 @@ impl CodeActionsMenu {
             self.selected_item -= 1;
         } else {
             self.selected_item = self.actions.len() - 1;
-            self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         }
         self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         cx.notify();

@@ -1547,11 +1545,10 @@ impl CodeActionsMenu {
     fn select_next(&mut self, cx: &mut ViewContext<Editor>) {
         if self.selected_item + 1 < self.actions.len() {
             self.selected_item += 1;
-            self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         } else {
             self.selected_item = 0;
-            self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         }
+        self.list.scroll_to(ScrollTarget::Show(self.selected_item));
         cx.notify();
     }
@ -42,6 +42,7 @@ sha1 = "0.10.5"
|
||||||
ndarray = { version = "0.15.0" }
|
ndarray = { version = "0.15.0" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
ai = { path = "../ai", features = ["test-support"] }
|
||||||
collections = { path = "../collections", features = ["test-support"] }
|
collections = { path = "../collections", features = ["test-support"] }
|
||||||
gpui = { path = "../gpui", features = ["test-support"] }
|
gpui = { path = "../gpui", features = ["test-support"] }
|
||||||
language = { path = "../language", features = ["test-support"] }
|
language = { path = "../language", features = ["test-support"] }
|
||||||
|
|
|
@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    api_key: Option<String>,
 }

 #[derive(Clone)]

@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
 }

 impl EmbeddingQueue {
-    pub fn new(
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        executor: Arc<Background>,
-        api_key: Option<String>,
-    ) -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,

@@ -64,14 +59,9 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
-            api_key,
         }
     }

-    pub fn set_api_key(&mut self, api_key: Option<String>) {
-        self.api_key = api_key
-    }
-
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();

@@ -118,7 +108,6 @@ impl EmbeddingQueue {

         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();

         self.executor
             .spawn(async move {

@@ -143,7 +132,7 @@ impl EmbeddingQueue {
                     return;
                 };

-                match embedding_provider.embed_batch(spans, api_key).await {
+                match embedding_provider.embed_batch(spans).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {
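Queue construction no longer threads an API key through; credentials are now entirely the embedding provider's concern. A sketch of the new constructor call, assuming handles to a provider and a background executor (and the surrounding crate's EmbeddingQueue type):

    use std::sync::Arc;
    use ai::embedding::EmbeddingProvider;
    use gpui::executor::Background;

    fn build_queue(
        provider: Arc<dyn EmbeddingProvider>,
        executor: Arc<Background>,
    ) -> EmbeddingQueue {
        EmbeddingQueue::new(provider, executor)
    }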
@@ -1,4 +1,7 @@
-use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::TruncationDirection,
+};
 use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
 use rusqlite::{

@@ -108,7 +111,14 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;

         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,

@@ -131,7 +141,15 @@ impl CodeContextRetriever {
         )
         .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;

         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,

@@ -222,8 +240,13 @@ impl CodeContextRetriever {
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &span.content);

-            let (document_content, token_count) =
-                self.embedding_provider.truncate(&document_content);
+            let model = self.embedding_provider.base_model();
+            let document_content = model.truncate(
+                &document_content,
+                model.capacity()?,
+                TruncationDirection::End,
+            )?;
+            let token_count = model.count_tokens(&document_content)?;

             span.content = document_content;
             span.token_count = token_count;
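The retriever now truncates through the provider's base model and recounts tokens afterwards. The same two-step pattern in isolation (a sketch built only from the traits introduced in this commit):

    use ai::embedding::EmbeddingProvider;
    use ai::models::TruncationDirection;

    fn fit_to_context(
        provider: &dyn EmbeddingProvider,
        text: &str,
    ) -> anyhow::Result<(String, usize)> {
        let model = provider.base_model();
        // Clamp the span to the model's context window, then recount tokens.
        let truncated = model.truncate(text, model.capacity()?, TruncationDirection::End)?;
        let token_count = model.count_tokens(&truncated)?;
        Ok((truncated, token_count))
    }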
@ -7,7 +7,8 @@ pub mod semantic_index_settings;
|
||||||
mod semantic_index_tests;
|
mod semantic_index_tests;
|
||||||
|
|
||||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||||
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
|
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||||
|
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use db::VectorDatabase;
|
use db::VectorDatabase;
|
||||||
|
@ -88,7 +89,7 @@ pub fn init(
|
||||||
let semantic_index = SemanticIndex::new(
|
let semantic_index = SemanticIndex::new(
|
||||||
fs,
|
fs,
|
||||||
db_file_path,
|
db_file_path,
|
||||||
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
|
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
|
||||||
language_registry,
|
language_registry,
|
||||||
cx.clone(),
|
cx.clone(),
|
||||||
)
|
)
|
||||||
|
@ -123,8 +124,6 @@ pub struct SemanticIndex {
|
||||||
_embedding_task: Task<()>,
|
_embedding_task: Task<()>,
|
||||||
_parsing_files_tasks: Vec<Task<()>>,
|
_parsing_files_tasks: Vec<Task<()>>,
|
||||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||||
api_key: Option<String>,
|
|
||||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ProjectState {
|
struct ProjectState {
|
||||||
|
@ -278,18 +277,18 @@ impl SemanticIndex {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate(&mut self, cx: &AppContext) {
|
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||||
if self.api_key.is_none() {
|
if !self.embedding_provider.has_credentials() {
|
||||||
self.api_key = self.embedding_provider.retrieve_credentials(cx);
|
self.embedding_provider.retrieve_credentials(cx);
|
||||||
|
} else {
|
||||||
self.embedding_queue
|
return true;
|
||||||
.lock()
|
|
||||||
.set_api_key(self.api_key.clone());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.embedding_provider.has_credentials()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_authenticated(&self) -> bool {
|
pub fn is_authenticated(&self) -> bool {
|
||||||
self.api_key.is_some()
|
self.embedding_provider.has_credentials()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn enabled(cx: &AppContext) -> bool {
|
pub fn enabled(cx: &AppContext) -> bool {
|
||||||
@@ -339,7 +338,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();

@@ -404,8 +403,6 @@ impl SemanticIndex {
             _embedding_task,
             _parsing_files_tasks,
             projects: Default::default(),
-            api_key: None,
-            embedding_queue,
         }
     }))
 }

@@ -720,13 +717,13 @@ impl SemanticIndex {

         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();

         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();

             let query = embedding_provider
-                .embed_batch(vec![query], api_key)
+                .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
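
Note that `embed_batch` loses its trailing `api_key` argument here and at every other call site in this diff; the provider is now expected to authenticate itself. The trait shape implied by these call sites is roughly the following (a sketch inferred from usage, not the crate's actual definition, which also carries methods such as `truncate` and `max_tokens_per_batch` seen below):

    use anyhow::Result;
    use async_trait::async_trait;

    // Placeholder standing in for the crate's Embedding type.
    struct Embedding(Vec<f32>);

    #[async_trait]
    trait EmbeddingProviderSketch: Send + Sync {
        // No Option<String> api_key parameter any more; credentials
        // are the provider's own concern.
        async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    }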
@@ -944,7 +941,6 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
-        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();

@@ -959,15 +955,10 @@ impl SemanticIndex {
                 .parse_file_with_template(None, &snapshot.text(), language)
                 .log_err()
                 .unwrap_or_default();
-            if Self::embed_spans(
-                &mut spans,
-                embedding_provider.as_ref(),
-                &db,
-                api_key.clone(),
-            )
-            .await
-            .log_err()
-            .is_some()
+            if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                .await
+                .log_err()
+                .is_some()
             {
                 for span in spans {
                     let similarity = span.embedding.unwrap().similarity(&query);

@@ -1007,9 +998,8 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if self.api_key.is_none() {
-            self.authenticate(cx);
-            if self.api_key.is_none() {
+        if !self.is_authenticated() {
+            if !self.authenticate(cx) {
                 return Task::ready(Err(anyhow!("user is not authenticated")));
             }
         }
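
The guard in `index_project` now composes the two methods above: the cheap `is_authenticated()` check short-circuits, a failed `authenticate(cx)` produces the error immediately, and there is no second read of an `api_key` field after the fact. As a free-standing sketch, again using the stand-ins from the earlier sketches:

    use anyhow::{anyhow, Result};

    // One retrieval attempt, then fail fast: the same shape as the hunk.
    fn ensure_authenticated(provider: &CachedCredentialProvider) -> Result<()> {
        if !provider.has_credentials() && !authenticate(provider) {
            return Err(anyhow!("user is not authenticated"));
        }
        Ok(())
    }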
@@ -1192,7 +1182,6 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
-        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;

@@ -1215,7 +1204,7 @@ impl SemanticIndex {

             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch), api_key.clone())
+                    .embed_batch(mem::take(&mut batch))
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;

@@ -1227,7 +1216,7 @@ impl SemanticIndex {

         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch), api_key)
+                .embed_batch(mem::take(&mut batch))
                 .await?;

             embeddings.extend(batch_embeddings);
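
The two `embed_batch` calls above are the flush points of `embed_spans`' token-budget batching: spans accumulate until the next one would exceed `max_tokens_per_batch()`, the batch is flushed with `mem::take`, and whatever remains is flushed once after the loop. The pattern in isolation, as a runnable sketch with `(text, token_count)` pairs standing in for `Span`s:

    // Flush when the next item would exceed the token budget, then
    // flush the remainder, as embed_spans does around embed_batch.
    fn batch_by_tokens(spans: Vec<(String, usize)>, max_tokens: usize) -> Vec<Vec<String>> {
        let mut batches = Vec::new();
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        for (text, token_count) in spans {
            if batch_tokens + token_count > max_tokens && !batch.is_empty() {
                batches.push(std::mem::take(&mut batch));
                batch_tokens = 0;
            }
            batch.push(text);
            batch_tokens += token_count;
        }
        if !batch.is_empty() {
            batches.push(batch);
        }
        batches
    }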
@@ -4,10 +4,9 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+use gpui::{executor::Deterministic, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;

@@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
 use settings::SettingsStore;
-use std::{
-    path::Path,
-    sync::{
-        atomic::{self, AtomicUsize},
-        Arc,
-    },
-    time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
 use unindent::Unindent;
 use util::RandomCharIter;
@@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {

     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());

-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
     for file in &files {
         queue.push(file.clone());
     }

@@ -280,7 +272,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = "

@@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"

@@ -466,7 +458,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = "

@@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"

@@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"

@@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = "

@@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"

@@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);

     let text = r#"
@@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
     );
 }

-#[derive(Default)]
-struct FakeEmbeddingProvider {
-    embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
-    fn embedding_count(&self) -> usize {
-        self.embedding_count.load(atomic::Ordering::SeqCst)
-    }
-
-    fn embed_sync(&self, span: &str) -> Embedding {
-        let mut result = vec![1.0; 26];
-        for letter in span.chars() {
-            let letter = letter.to_ascii_lowercase();
-            if letter as u32 >= 'a' as u32 {
-                let ix = (letter as u32) - ('a' as u32);
-                if ix < 26 {
-                    result[ix as usize] += 1.0;
-                }
-            }
-        }
-
-        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-        for x in &mut result {
-            *x /= norm;
-        }
-
-        result.into()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Fake Credentials".to_string())
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        (span.to_string(), 1)
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        200
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        self.embedding_count
-            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
-    }
-}
-
 fn js_lang() -> Arc<Language> {
     Arc::new(
         Language::new(
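
The fake provider deleted above is the one the test file now imports from `ai::test`, presumably minus the dropped `api_key` plumbing, so its deterministic embedding is shared across crates instead of being redefined per test file. Its scheme, reconstructed from the deleted lines: start from a 26-element vector of ones, bump the slot for each ASCII letter, and L2-normalize so cosine similarity reduces to a dot product. As a standalone function:

    // The deleted fake's embedding, in isolation: a deterministic
    // letter-frequency vector, L2-normalized.
    fn embed_sync(span: &str) -> Vec<f32> {
        let mut result = vec![1.0_f32; 26];
        for letter in span.chars() {
            let letter = letter.to_ascii_lowercase();
            if let Some(ix) = (letter as u32).checked_sub('a' as u32) {
                if (ix as usize) < result.len() {
                    result[ix as usize] += 1.0;
                }
            }
        }
        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut result {
            *x /= norm;
        }
        result
    }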
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};

@@ -475,7 +475,7 @@ fn main() {
     let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         languages.clone(),
         cx.clone(),
     )

@@ -321,8 +321,8 @@ impl LspAdapter for NextLspAdapter {
         latest_github_release("elixir-tools/next-ls", false, delegate.http_client()).await?;
     let version = release.name.clone();
     let platform = match consts::ARCH {
-        "x86_64" => "darwin_arm64",
-        "aarch64" => "darwin_amd64",
+        "x86_64" => "darwin_amd64",
+        "aarch64" => "darwin_arm64",
         other => bail!("Running on unsupported platform: {other}"),
     };
     let asset_name = format!("next_ls_{}", platform);
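
This last hunk is a straightforward bug fix: the next-ls release assets were mapped to the wrong architectures, so x86_64 hosts downloaded the darwin_arm64 binary and aarch64 hosts the darwin_amd64 one. A hypothetical helper shaped after the corrected hunk:

    // Corrected mapping: each macOS architecture selects its own asset.
    fn next_ls_asset_name(arch: &str) -> Option<String> {
        let platform = match arch {
            "x86_64" => "darwin_amd64",
            "aarch64" => "darwin_arm64",
            _ => return None,
        };
        Some(format!("next_ls_{}", platform))
    }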