Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)

Co-authored-by: Antonio <antonio@zed.dev>

Resurrected this from some assistant work I did in the spring of 2023.
- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default
(but preserve the API key option for now)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
- We currently disallow cycling when a custom model is set, but we need a
better solution to keep OpenAI models available while testing the Google
ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] CloudFront so we can block abuse
- [x] Rate-limiting
- [x] Fix panic when using inline assistant
- [x] Double-check that the upgraded `AssistantSettings` are
backwards-compatible (see the settings sketch after this list)
- [x] Add hosted LLM interaction behind a `language-models` feature
flag.
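
For reference, a minimal sketch (not the actual implementation) of how a versioned `AssistantSettings` could keep accepting the old flat keys while adding the new `provider` block. The type and field names below are assumptions drawn from the `assets/settings/default.json` changes in this commit, not Zed's real settings code.

```rust
use serde::Deserialize;

// Illustrative only: shapes follow the keys visible in default.json in this
// commit, not the real `AssistantSettings` type.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum AssistantSettingsContent {
    // New shape: `"version": "1"` plus a `provider` block.
    Versioned(VersionedAssistantSettings),
    // Old shape: flat, OpenAI-only keys that must keep deserializing.
    Legacy(LegacyAssistantSettings),
}

#[derive(Debug, Deserialize)]
#[serde(tag = "version")]
enum VersionedAssistantSettings {
    #[serde(rename = "1")]
    V1 { provider: ProviderSettings },
}

#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "lowercase")]
enum ProviderSettings {
    // e.g. { "name": "openai", "default_model": "gpt-4-turbo-preview" }
    OpenAi { default_model: String },
    // Hypothetical Google variant; its exact keys are not shown in this diff.
    Google { default_model: String },
}

#[derive(Debug, Deserialize)]
struct LegacyAssistantSettings {
    // Deprecated keys from the pre-versioned settings.
    openai_api_url: Option<String>,
    default_open_ai_model: Option<String>,
}

impl AssistantSettingsContent {
    // Upgrade whichever shape is on disk into the v1 form. The real upgrade
    // may also map old model names to new ones; that detail is omitted here.
    fn upgrade(self) -> VersionedAssistantSettings {
        match self {
            Self::Versioned(settings) => settings,
            Self::Legacy(legacy) => VersionedAssistantSettings::V1 {
                provider: ProviderSettings::OpenAi {
                    default_model: legacy
                        .default_open_ai_model
                        .unwrap_or_else(|| "gpt-4-turbo-preview".into()),
                },
            },
        }
    }
}

fn main() {
    // A user's existing, pre-upgrade settings still parse and upgrade cleanly.
    let legacy = r#"{ "default_open_ai_model": "gpt-4-1106-preview" }"#;
    let parsed: AssistantSettingsContent = serde_json::from_str(legacy).unwrap();
    println!("{:?}", parsed.upgrade());
}
```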

Release Notes:

- We are temporarily removing the semantic index in order to redesign it
from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>
Nathan Sobo 2024-03-19 12:22:26 -06:00 committed by GitHub
parent 905a24079a
commit 8ae5a3b61a
87 changed files with 3647 additions and 8937 deletions

Cargo.lock generated

@ -85,32 +85,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "ai"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bincode",
"futures 0.3.28",
"gpui",
"isahc",
"language",
"log",
"matrixmultiply",
"ordered-float 2.10.0",
"parking_lot",
"parse_duration",
"postage",
"rand 0.8.5",
"rusqlite",
"schemars",
"serde",
"serde_json",
"tiktoken-rs",
"util",
]
[[package]]
name = "alacritty_terminal"
version = "0.22.1-dev"
@ -339,9 +313,9 @@ dependencies = [
name = "assistant"
version = "0.1.0"
dependencies = [
"ai",
"anyhow",
"chrono",
"client",
"collections",
"ctor",
"editor",
@ -354,13 +328,14 @@ dependencies = [
"log",
"menu",
"multi_buffer",
"open_ai",
"ordered-float 2.10.0",
"parking_lot",
"project",
"rand 0.8.5",
"regex",
"schemars",
"search",
"semantic_index",
"serde",
"serde_json",
"settings",
@ -1339,7 +1314,7 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
dependencies = [
"num-bigint 0.4.4",
"num-bigint",
"num-integer",
"num-traits",
]
@ -2209,11 +2184,11 @@ dependencies = [
"fs",
"futures 0.3.28",
"git",
"google_ai",
"gpui",
"hex",
"indoc",
"language",
"lazy_static",
"live_kit_client",
"live_kit_server",
"log",
@ -2222,6 +2197,7 @@ dependencies = [
"nanoid",
"node_runtime",
"notifications",
"open_ai",
"parking_lot",
"pretty_assertions",
"project",
@ -3554,24 +3530,12 @@ dependencies = [
"workspace",
]
[[package]]
name = "fallible-iterator"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fancy-regex"
version = "0.11.0"
@ -4183,7 +4147,7 @@ version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
dependencies = [
"fallible-iterator 0.3.0",
"fallible-iterator",
"indexmap 2.0.0",
"stable_deref_trait",
]
@ -4279,6 +4243,17 @@ dependencies = [
"workspace",
]
[[package]]
name = "google_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
"serde",
"serde_json",
"util",
]
[[package]]
name = "gpu-alloc"
version = "0.6.0"
@ -5667,16 +5642,6 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matrixmultiply"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "maybe-owned"
version = "0.3.4"
@ -5946,19 +5911,6 @@ dependencies = [
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex 0.4.4",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "ndk"
version = "0.7.0"
@ -6111,45 +6063,20 @@ dependencies = [
"winapi",
]
[[package]]
name = "num"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
dependencies = [
"num-bigint 0.2.6",
"num-complex 0.2.4",
"num-integer",
"num-iter",
"num-rational 0.2.4",
"num-traits",
]
[[package]]
name = "num"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
dependencies = [
"num-bigint 0.4.4",
"num-complex 0.4.4",
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational 0.4.1",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
@ -6196,16 +6123,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-complex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.4"
@ -6247,18 +6164,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
dependencies = [
"autocfg",
"num-bigint 0.2.6",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.3.2"
@ -6277,7 +6182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint 0.4.4",
"num-bigint",
"num-integer",
"num-traits",
]
@ -6436,7 +6341,7 @@ dependencies = [
"futures-util",
"hkdf",
"hmac 0.12.1",
"num 0.4.1",
"num",
"num-bigint-dig 0.8.4",
"pbkdf2 0.12.2",
"rand 0.8.5",
@ -6464,6 +6369,18 @@ dependencies = [
"pathdiff",
]
[[package]]
name = "open_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.28",
"schemars",
"serde",
"serde_json",
"util",
]
[[package]]
name = "openssl"
version = "0.10.57"
@ -6679,17 +6596,6 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "parse_duration"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
dependencies = [
"lazy_static",
"num 0.2.1",
"regex",
]
[[package]]
name = "password-hash"
version = "0.2.1"
@ -7471,12 +7377,6 @@ dependencies = [
"raw-window-handle 0.5.2",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.8.0"
@ -7935,20 +7835,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rusqlite"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
dependencies = [
"bitflags 2.4.2",
"fallible-iterator 0.2.0",
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"smallvec",
]
[[package]]
name = "rust-embed"
version = "8.2.0"
@ -8378,7 +8264,6 @@ dependencies = [
"language",
"menu",
"project",
"semantic_index",
"serde",
"serde_json",
"settings",
@ -8434,52 +8319,6 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
[[package]]
name = "semantic_index"
version = "0.1.0"
dependencies = [
"ai",
"anyhow",
"collections",
"ctor",
"env_logger",
"futures 0.3.28",
"gpui",
"language",
"lazy_static",
"log",
"ndarray",
"ordered-float 2.10.0",
"parking_lot",
"postage",
"pretty_assertions",
"project",
"rand 0.8.5",
"release_channel",
"rpc",
"rusqlite",
"schemars",
"serde",
"serde_json",
"settings",
"sha1",
"smol",
"tempfile",
"tree-sitter",
"tree-sitter-cpp",
"tree-sitter-elixir",
"tree-sitter-json 0.20.0",
"tree-sitter-lua",
"tree-sitter-php",
"tree-sitter-ruby",
"tree-sitter-rust",
"tree-sitter-toml",
"tree-sitter-typescript",
"unindent",
"util",
"workspace",
]
[[package]]
name = "semver"
version = "1.0.18"
@ -8766,7 +8605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
dependencies = [
"chrono",
"num-bigint 0.4.4",
"num-bigint",
"num-traits",
"thiserror",
]
@ -9197,7 +9036,7 @@ dependencies = [
"log",
"md-5",
"memchr",
"num-bigint 0.4.4",
"num-bigint",
"once_cell",
"rand 0.8.5",
"rust_decimal",
@ -12729,7 +12568,6 @@ dependencies = [
"release_channel",
"rope",
"search",
"semantic_index",
"serde",
"serde_json",
"settings",


@ -1,7 +1,6 @@
[workspace]
members = [
"crates/activity_indicator",
"crates/ai",
"crates/assets",
"crates/assistant",
"crates/audio",
@ -34,6 +33,7 @@ members = [
"crates/fuzzy",
"crates/git",
"crates/go_to_line",
"crates/google_ai",
"crates/gpui",
"crates/gpui_macros",
"crates/image_viewer",
@ -52,6 +52,7 @@ members = [
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
"crates/open_ai",
"crates/outline",
"crates/picker",
"crates/prettier",
@ -69,7 +70,6 @@ members = [
"crates/task",
"crates/tasks_ui",
"crates/search",
"crates/semantic_index",
"crates/settings",
"crates/snippet",
"crates/sqlez",
@ -138,6 +138,7 @@ fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }
git = { path = "crates/git" }
go_to_line = { path = "crates/go_to_line" }
google_ai = { path = "crates/google_ai" }
gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
install_cli = { path = "crates/install_cli" }
@ -156,6 +157,7 @@ menu = { path = "crates/menu" }
multi_buffer = { path = "crates/multi_buffer" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
@ -174,7 +176,6 @@ rpc = { path = "crates/rpc" }
task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }
sqlez = { path = "crates/sqlez" }


@ -251,7 +251,6 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@ -276,7 +275,6 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@ -302,7 +300,6 @@
"alt-tab": "search::CycleMode",
"alt-cmd-f": "project_search::ToggleFilters",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},


@ -237,6 +237,8 @@
"default_width": 380
},
"assistant": {
// Version of this setting.
"version": "1",
// Whether to show the assistant panel button in the status bar.
"button": true,
// Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
@ -245,28 +247,16 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
// Deprecated: Please use `provider.api_url` instead.
// The default OpenAI API endpoint to use when starting new conversations.
"openai_api_url": "https://api.openai.com/v1",
// Deprecated: Please use `provider.default_model` instead.
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo-0613""
// 2. "gpt-4-0613""
// 3. "gpt-4-1106-preview"
"default_open_ai_model": "gpt-4-1106-preview",
// AI provider.
"provider": {
"type": "openai",
// The default OpenAI API endpoint to use when starting new conversations.
"api_url": "https://api.openai.com/v1",
// The default OpenAI model to use when starting new conversations. This
"name": "openai",
// The default model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo-0613""
// 2. "gpt-4-0613""
// 3. "gpt-4-1106-preview"
"default_model": "gpt-4-1106-preview"
// 1. "gpt-3.5-turbo"
// 2. "gpt-4"
// 3. "gpt-4-turbo-preview"
"default_model": "gpt-4-turbo-preview"
}
},
// Whether the screen sharing icon is shown in the os status bar.
@ -505,10 +495,6 @@
// Existing terminals will not pick up this change until they are recreated.
// "max_scroll_history_lines": 10000,
},
// Difference settings for semantic_index
"semantic_index": {
"enabled": true
},
// Settings specific to our elixir integration
"elixir": {
// Change the LSP zed uses for elixir.
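
A rough sketch of consuming the new `provider` block, assuming plain serde deserialization: the `name`/`default_model` shape comes from the hunk above, and the `{api_url}/chat/completions` path plus the `https://api.openai.com/v1` base URL appear in the deleted `ai` crate further down in this diff. Whether the new `open_ai` crate builds its requests the same way is an assumption.

```rust
use serde::Deserialize;

// Mirrors the `assistant.provider` object above. Extra fields a user might
// set (such as an `api_url` override) are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct ProviderConfig {
    name: String,
    default_model: String,
}

// The deleted `ai` crate built its chat endpoint as `{api_url}/chat/completions`
// against `https://api.openai.com/v1`; shown here only for illustration.
fn openai_chat_completions_url(api_url: &str) -> String {
    format!("{api_url}/chat/completions")
}

fn main() {
    let json = r#"{ "name": "openai", "default_model": "gpt-4-turbo-preview" }"#;
    let provider: ProviderConfig = serde_json::from_str(json).expect("valid provider block");
    assert_eq!(provider.name, "openai");
    println!(
        "{} via {}",
        provider.default_model,
        openai_chat_completions_url("https://api.openai.com/v1")
    );
}
```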


@ -1,41 +0,0 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/ai.rs"
doctest = false
[features]
test-support = []
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
futures.workspace = true
gpui.workspace = true
isahc.workspace = true
language.workspace = true
log.workspace = true
matrixmultiply = "0.3.7"
ordered-float.workspace = true
parking_lot.workspace = true
parse_duration = "2.1.1"
postage.workspace = true
rand.workspace = true
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
tiktoken-rs.workspace = true
util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }


@ -1 +0,0 @@
../../LICENSE-GPL


@ -1,8 +0,0 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;


@ -1,23 +0,0 @@
use futures::future::BoxFuture;
use gpui::AppContext;
#[derive(Clone, Debug)]
pub enum ProviderCredential {
Credentials { api_key: String },
NoCredentials,
NotNeeded,
}
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
#[must_use]
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
#[must_use]
fn save_credentials(
&self,
cx: &mut AppContext,
credential: ProviderCredential,
) -> BoxFuture<()>;
#[must_use]
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
}


@ -1,23 +0,0 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::{auth::CredentialProvider, models::LanguageModel};
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}


@ -1,121 +0,0 @@
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding =
bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
Ok(Embedding(embedding))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
OrderedFloat(result)
}
}
#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = 10.0_f32.powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}


@ -1,16 +0,0 @@
pub enum TruncationDirection {
Start,
End,
}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}


@ -1,337 +0,0 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
Code,
}
// TODO: Set this up to manage for defaults well
pub struct PromptArguments {
pub model: Arc<dyn LanguageModel>,
pub user_prompt: Option<String>,
pub language_name: Option<String>,
pub project_name: Option<String>,
pub snippets: Vec<PromptCodeSnippet>,
pub reserved_tokens: usize,
pub buffer: Option<BufferSnapshot>,
pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
pub(crate) fn get_file_type(&self) -> PromptFileType {
if self
.language_name
.as_ref()
.map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
.unwrap_or(true)
{
PromptFileType::Code
} else {
PromptFileType::Text
}
}
}
pub trait PromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)>;
}
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
/// Ignores truncation.
Mandatory,
/// Truncates based on priority.
Ordered { order: usize },
}
impl PartialOrd for PromptPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PromptPriority {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match (self, other) {
(Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
(Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
(Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
}
}
}
pub struct PromptChain {
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
pub fn new(
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> Self {
PromptChain { args, templates }
}
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
// Argsort based on Prompt Priority
let separator = "\n";
let separator_tokens = self.args.model.count_tokens(separator)?;
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
let mut tokens_outstanding = if truncate {
Some(self.args.model.capacity()? - self.args.reserved_tokens)
} else {
None
};
let mut prompts = vec!["".to_string(); sorted_indices.len()];
for idx in sorted_indices {
let (_, template) = &self.templates[idx];
if let Some((template_prompt, prompt_token_count)) =
template.generate(&self.args, tokens_outstanding).log_err()
{
if template_prompt != "" {
prompts[idx] = template_prompt;
if let Some(remaining_tokens) = tokens_outstanding {
let new_tokens = prompt_token_count + separator_tokens;
tokens_outstanding = if remaining_tokens > new_tokens {
Some(remaining_tokens - new_tokens)
} else {
Some(0)
};
}
}
}
}
prompts.retain(|x| x != "");
let full_prompt = prompts.join(separator);
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
anyhow::Ok((prompts.join(separator), total_token_count))
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*;
#[test]
pub fn test_prompt_chain() {
struct TestPromptTemplate {}
impl PromptTemplate for TestPromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
struct TestLowPriorityTemplate {}
impl PromptTemplate for TestLowPriorityTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a low priority test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation On
// Should truncate the prompts to fit within the model's capacity
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 2 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(prompt, "This is a test promp".to_string());
assert_eq!(token_count, capacity);
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Mandatory,
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(
prompt,
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
.to_string()
);
assert_eq!(token_count, capacity - reserved_tokens);
}
}


@ -1,164 +0,0 @@
use anyhow::anyhow;
use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we don't have a selected range, include the entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
pub struct FileContext {}
impl PromptTemplate for FileContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
if let Some(buffer) = &args.buffer {
let mut prompt = String::new();
// Add Initial Preamble
// TODO: Do we want to add the path in here?
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
let language_name = args
.language_name
.clone()
.unwrap_or("".to_string())
.to_lowercase();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if truncated {
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
}
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
} else {
Err(anyhow!("no buffer provided to retrieve file context from"))
}
}
}


@ -1,99 +0,0 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
pub fn capitalize(s: &str) -> String {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let Some(user_prompt) = &args.user_prompt else {
return Err(anyhow!("user prompt not provided"));
};
let file_type = args.get_file_type();
let content_type = match &file_type {
PromptFileType::Code => "code",
PromptFileType::Text => "text",
};
let mut prompt = String::new();
if let Some(selected_range) = &args.selected_range {
if selected_range.start == selected_range.end {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
capitalize(content_type)
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
}
} else {
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}"
)
.unwrap();
}
if let Some(language_name) = &args.language_name {
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
match file_type {
PromptFileType::Code => {
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
}
_ => {}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}

View file

@ -1,5 +0,0 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;


@ -1,52 +0,0 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut prompts = Vec::new();
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
"You are an expert {}engineer.",
args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
prompts.push("You are an expert engineer.".to_string());
}
}
if let Some(project_name) = args.project_name.clone() {
prompts.push(format!(
"You are currently working inside the '{project_name}' project in code editor Zed."
));
}
if let Some(mut remaining_tokens) = max_token_length {
let mut prompt = String::new();
let mut total_count = 0;
for prompt_piece in prompts {
let prompt_token_count =
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
if remaining_tokens > prompt_token_count {
writeln!(prompt, "{prompt_piece}").unwrap();
remaining_tokens -= prompt_token_count;
total_count += prompt_token_count;
}
}
anyhow::Ok((prompt, total_count))
} else {
let prompt = prompts.join("\n");
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}
}


@ -1,96 +0,0 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui::{AsyncAppContext, Model};
use language::{Anchor, Buffer};
#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(
buffer: Model<Buffer>,
range: Range<Anchor>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Self> {
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
let language_name = buffer
.language()
.map(|language| language.name().to_string().to_lowercase());
let file_path = buffer.file().map(|file| file.path().to_path_buf());
(content, language_name, file_path)
})?;
anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
})
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.map(|path| path.to_string_lossy().to_string())
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
let mut prompt = String::new();
let mut remaining_tokens = max_token_length;
let separator_token_length = args.model.count_tokens("\n")?;
for snippet in &args.snippets {
let mut snippet_prompt = template.to_string();
let content = snippet.to_string();
writeln!(snippet_prompt, "{content}").unwrap();
let token_count = args.model.count_tokens(&snippet_prompt)?;
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
if let Some(tokens_left) = remaining_tokens {
if tokens_left >= token_count {
writeln!(prompt, "{snippet_prompt}").unwrap();
remaining_tokens = if tokens_left >= (token_count + separator_token_length)
{
Some(tokens_left - token_count - separator_token_length)
} else {
Some(0)
};
}
} else {
writeln!(prompt, "{snippet_prompt}").unwrap();
}
}
}
let total_token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, total_token_count))
}
}


@ -1 +0,0 @@
pub mod open_ai;


@ -1,9 +0,0 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAiLanguageModel;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";


@ -1,421 +0,0 @@
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAiRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
impl CompletionRequest for OpenAiRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAiUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAiResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAiUsage>,
}
async fn stream_completion(
api_url: String,
kind: OpenAiCompletionProviderKind,
credential: ProviderCredential,
executor: BackgroundExecutor,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
let json_data = request.data()?;
let mut response = Request::post(kind.completions_endpoint_url(&api_url))
.header("Content-Type", "application/json")
.header(auth_header_name, auth_header_value)
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAiResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAiResponse {
error: OpenAiError,
}
#[derive(Deserialize)]
struct OpenAiError {
message: String,
}
match serde_json::from_str::<OpenAiResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
pub enum AzureOpenAiApiVersion {
/// Retiring April 2, 2024.
#[serde(rename = "2023-03-15-preview")]
V2023_03_15Preview,
#[serde(rename = "2023-05-15")]
V2023_05_15,
/// Retiring April 2, 2024.
#[serde(rename = "2023-06-01-preview")]
V2023_06_01Preview,
/// Retiring April 2, 2024.
#[serde(rename = "2023-07-01-preview")]
V2023_07_01Preview,
/// Retiring April 2, 2024.
#[serde(rename = "2023-08-01-preview")]
V2023_08_01Preview,
/// Retiring April 2, 2024.
#[serde(rename = "2023-09-01-preview")]
V2023_09_01Preview,
#[serde(rename = "2023-12-01-preview")]
V2023_12_01Preview,
#[serde(rename = "2024-02-15-preview")]
V2024_02_15Preview,
}
impl fmt::Display for AzureOpenAiApiVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
Self::V2023_03_15Preview => "2023-03-15-preview",
Self::V2023_05_15 => "2023-05-15",
Self::V2023_06_01Preview => "2023-06-01-preview",
Self::V2023_07_01Preview => "2023-07-01-preview",
Self::V2023_08_01Preview => "2023-08-01-preview",
Self::V2023_09_01Preview => "2023-09-01-preview",
Self::V2023_12_01Preview => "2023-12-01-preview",
Self::V2024_02_15Preview => "2024-02-15-preview",
}
)
}
}
#[derive(Clone)]
pub enum OpenAiCompletionProviderKind {
OpenAi,
AzureOpenAi {
deployment_id: String,
api_version: AzureOpenAiApiVersion,
},
}
impl OpenAiCompletionProviderKind {
/// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
fn completions_endpoint_url(&self, api_url: &str) -> String {
match self {
Self::OpenAi => {
// https://platform.openai.com/docs/api-reference/chat/create
format!("{api_url}/chat/completions")
}
Self::AzureOpenAi {
deployment_id,
api_version,
} => {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
}
}
}
/// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
fn auth_header(&self, api_key: String) -> (&'static str, String) {
match self {
Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
Self::AzureOpenAi { .. } => ("Api-Key", api_key),
}
}
}
#[derive(Clone)]
pub struct OpenAiCompletionProvider {
api_url: String,
kind: OpenAiCompletionProviderKind,
model: OpenAiLanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: BackgroundExecutor,
}
impl OpenAiCompletionProvider {
pub async fn new(
api_url: String,
kind: OpenAiCompletionProviderKind,
model_name: String,
executor: BackgroundExecutor,
) -> Self {
let model = executor
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
.await;
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
api_url,
kind,
model,
credential,
executor,
}
}
}
impl CredentialProvider for OpenAiCompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => {
return async move { existing_credential }.boxed()
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
async move { ProviderCredential::Credentials { api_key } }.boxed()
} else {
let credentials = cx.read_credentials(OPEN_AI_API_URL);
async move {
if let Some(Some((_, api_key))) = credentials.await.log_err() {
if let Some(api_key) = String::from_utf8(api_key).log_err() {
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
ProviderCredential::NoCredentials
}
}
.boxed()
}
}
};
async move {
let retrieved_credential = retrieved_credential.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
.boxed()
}
fn save_credentials(
&self,
cx: &mut AppContext,
credential: ProviderCredential,
) -> BoxFuture<()> {
*self.credential.write() = credential.clone();
let credential = credential.clone();
let write_credentials = match credential {
ProviderCredential::Credentials { api_key } => {
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
}
_ => None,
};
async move {
if let Some(write_credentials) = write_credentials {
write_credentials.await.log_err();
}
}
.boxed()
}
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
*self.credential.write() = ProviderCredential::NoCredentials;
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
async move {
delete_credentials.await.log_err();
}
.boxed()
}
}
impl CompletionProvider for OpenAiCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the language model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let api_url = self.api_url.clone();
let kind = self.kind.clone();
let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}


@ -1,345 +0,0 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::AsyncReadExt;
use futures::FutureExt;
use gpui::AppContext;
use gpui::BackgroundExecutor;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use serde_json;
use std::env;
use std::ops::Add;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAiLanguageModel;
use crate::providers::open_ai::OPEN_AI_API_URL;
pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
}
#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
api_url: String,
model: OpenAiLanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: BackgroundExecutor,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
data: Vec<OpenAiEmbedding>,
usage: OpenAiEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAiEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
impl OpenAiEmbeddingProvider {
pub async fn new(
api_url: String,
client: Arc<dyn HttpClient>,
executor: BackgroundExecutor,
) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
// Loading the model is expensive, so ensure this runs off the main thread.
let model = executor
.spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
.await;
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAiEmbeddingProvider {
api_url,
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_url: &str,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post(format!("{api_url}/embeddings"))
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAiEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
impl CredentialProvider for OpenAiEmbeddingProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => {
return async move { existing_credential }.boxed()
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
async move { ProviderCredential::Credentials { api_key } }.boxed()
} else {
let credentials = cx.read_credentials(OPEN_AI_API_URL);
async move {
if let Some(Some((_, api_key))) = credentials.await.log_err() {
if let Some(api_key) = String::from_utf8(api_key).log_err() {
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
ProviderCredential::NoCredentials
}
}
.boxed()
}
}
};
async move {
let retrieved_credential = retrieved_credential.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
.boxed()
}
fn save_credentials(
&self,
cx: &mut AppContext,
credential: ProviderCredential,
) -> BoxFuture<()> {
*self.credential.write() = credential.clone();
let credential = credential.clone();
let write_credentials = match credential {
ProviderCredential::Credentials { api_key } => {
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
}
_ => None,
};
async move {
if let Some(write_credentials) = write_credentials {
write_credentials.await.log_err();
}
}
.boxed()
}
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
*self.credential.write() = ProviderCredential::NoCredentials;
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
async move {
delete_credentials.await.log_err();
}
.boxed()
}
}
#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_url = self.api_url.as_str();
let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_url,
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
                    // If we've previously been rate-limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}

View file

@ -1,59 +0,0 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use crate::models::{LanguageModel, TruncationDirection};
use super::open_ai_bpe_tokenizer;
#[derive(Clone)]
pub struct OpenAiLanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAiLanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name)
.unwrap_or(open_ai_bpe_tokenizer().to_owned());
OpenAiLanguageModel {
name: model_name.to_string(),
bpe: Some(bpe),
}
}
}
impl LanguageModel for OpenAiLanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
match direction {
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
}
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View file

@ -1,206 +0,0 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
#[derive(Clone)]
pub struct FakeLanguageModel {
pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
anyhow::Ok(match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
})
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
#[derive(Default)]
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl FakeEmbeddingProvider {
pub fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
pub fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
async { ProviderCredential::NotNeeded }.boxed()
}
fn save_credentials(
&self,
_cx: &mut AppContext,
_credential: ProviderCredential,
) -> BoxFuture<()> {
async {}.boxed()
}
fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
async {}.boxed()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
async { ProviderCredential::NotNeeded }.boxed()
}
fn save_credentials(
&self,
_cx: &mut AppContext,
_credential: ProviderCredential,
) -> BoxFuture<()> {
async {}.boxed()
}
fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
async {}.boxed()
}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View file

@ -5,17 +5,14 @@ edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/assistant.rs"
doctest = false
[dependencies]
ai.workspace = true
anyhow.workspace = true
chrono.workspace = true
client.workspace = true
collections.workspace = true
editor.workspace = true
fs.workspace = true
@ -26,12 +23,13 @@ language.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
open_ai = { workspace = true, features = ["schemars"] }
ordered-float.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
search.workspace = true
semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@ -45,7 +43,6 @@ uuid.workspace = true
workspace.workspace = true
[dev-dependencies]
ai = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true

View file

@ -1,22 +1,24 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
mod completion_provider;
mod prompts;
mod saved_conversation;
mod streaming_diff;
use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAiModel;
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use client::{proto, Client};
pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, SharedString};
use regex::Regex;
pub(crate) use saved_conversation::*;
use serde::{Deserialize, Serialize};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
use settings::Settings;
use std::{
fmt::{self, Display},
sync::Arc,
};
actions!(
assistant,
@ -30,7 +32,6 @@ actions!(
ResetKey,
InlineAssist,
ToggleIncludeConversation,
ToggleRetrieveContext,
]
);
@ -39,6 +40,139 @@ actions!(
)]
struct MessageId(usize);
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
ZedDotDev(ZedDotDevModel),
OpenAi(OpenAiModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::ZedDotDev(ZedDotDevModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
LanguageModel::ZedDotDev(model) => match model {
ZedDotDevModel::GptThreePointFiveTurbo
| ZedDotDevModel::GptFour
| ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
},
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(),
}
}
}
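As a rough sketch (not part of this diff), the new `LanguageModel` wrapper resolves ids and limits like this, following the `id`, `telemetry_id`, and `max_token_count` implementations above:
// Sketch only: assumes the LanguageModel and ZedDotDevModel types above are in scope.
let model = LanguageModel::default(); // LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
assert_eq!(model.id(), "gpt-4-turbo-preview");
assert_eq!(model.telemetry_id(), "zed.dev/gpt-4-turbo-preview");
// For the OpenAI-backed models the context window comes from tiktoken's model table.
let _max_tokens = model.max_token_count();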
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: match self.role {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
} as i32,
content: self.content.clone(),
}
}
}
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
}
}
}
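For context (again a sketch, not part of the diff), a request built with these types and forwarded to Zed's server looks roughly like:
// Sketch only: field values are illustrative.
let request = LanguageModelRequest {
    model: LanguageModel::default(),
    messages: vec![LanguageModelRequestMessage {
        role: Role::User,
        content: "Explain this function.".into(),
    }],
    stop: Vec::new(),
    temperature: 1.0,
};
// The proto form is what the CompleteWithLanguageModel RPC carries to zed.dev.
let proto_request = request.to_proto();
assert_eq!(proto_request.model, "gpt-4-turbo-preview");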
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
pub index: u32,
pub delta: LanguageModelResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
@ -53,71 +187,9 @@ enum MessageStatus {
Error(SharedString),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
api_url: Option<String>,
model: OpenAiModel,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}
pub fn init(cx: &mut AppContext) {
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
AssistantSettings::register(cx);
completion_provider::init(client, cx);
assistant_panel::init(cx);
}

File diff suppressed because it is too large

View file

@ -1,169 +1,296 @@
use ai::providers::open_ai::{
AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
};
use anyhow::anyhow;
use std::fmt;
use gpui::Pixels;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use settings::Settings;
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OpenAiModel {
#[serde(rename = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
#[serde(rename = "gpt-4-0613")]
Four,
#[serde(rename = "gpt-4-1106-preview")]
FourTurbo,
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ZedDotDevModel {
GptThreePointFiveTurbo,
GptFour,
#[default]
GptFourTurbo,
Custom(String),
}
impl OpenAiModel {
pub fn full_name(&self) -> &'static str {
impl Serialize for ZedDotDevModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for ZedDotDevModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = ZedDotDevModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
match value {
"gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
"gpt-4" => Ok(ZedDotDevModel::GptFour),
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
}
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for ZedDotDevModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = vec![
"gpt-3.5-turbo".to_owned(),
"gpt-4".to_owned(),
"gpt-4-turbo-preview".to_owned(),
];
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(serde_json::json!("gpt-4-turbo-preview")),
examples: vec![
serde_json::json!("gpt-3.5-turbo"),
serde_json::json!("gpt-4"),
serde_json::json!("gpt-4-turbo-preview"),
serde_json::json!("custom-model-name"),
],
..Default::default()
})),
..Default::default()
})
}
}
impl ZedDotDevModel {
pub fn id(&self) -> &str {
match self {
Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
Self::Four => "gpt-4-0613",
Self::FourTurbo => "gpt-4-1106-preview",
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
Self::GptFour => "gpt-4",
Self::GptFourTurbo => "gpt-4-turbo-preview",
Self::Custom(id) => id,
}
}
pub fn short_name(&self) -> &'static str {
pub fn display_name(&self) -> &str {
match self {
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo",
}
}
pub fn cycle(&self) -> Self {
match self {
Self::ThreePointFiveTurbo => Self::Four,
Self::Four => Self::FourTurbo,
Self::FourTurbo => Self::ThreePointFiveTurbo,
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
Self::GptFour => "gpt-4",
Self::GptFourTurbo => "gpt-4-turbo",
Self::Custom(id) => id.as_str(),
}
}
}
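A quick sketch (not part of this diff) of the custom-model fallback in the serde implementations above; it assumes `serde_json` is available, as elsewhere in the crate:
// Sketch only: known ids map to named variants, anything else becomes Custom.
let known: ZedDotDevModel = serde_json::from_str("\"gpt-4\"").unwrap();
assert_eq!(known, ZedDotDevModel::GptFour);
let custom: ZedDotDevModel = serde_json::from_str("\"my-private-model\"").unwrap();
assert_eq!(custom, ZedDotDevModel::Custom("my-private-model".into()));
// Serialization round-trips through id(), so custom names are preserved.
assert_eq!(serde_json::to_string(&custom).unwrap(), "\"my-private-model\"");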
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
Left,
#[default]
Right,
Bottom,
}
#[derive(Debug, Deserialize)]
pub struct AssistantSettings {
/// Whether to show the assistant panel button in the status bar.
pub button: bool,
/// Where to dock the assistant.
pub dock: AssistantDockPosition,
/// Default width in pixels when the assistant is docked to the left or right.
pub default_width: Pixels,
/// Default height in pixels when the assistant is docked to the bottom.
pub default_height: Pixels,
/// The default OpenAI model to use when starting new conversations.
#[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: OpenAiModel,
/// OpenAI API base URL to use when starting new conversations.
#[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: String,
/// The settings for the AI provider.
pub provider: AiProviderSettings,
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
#[serde(rename = "zed.dev")]
ZedDotDev {
#[serde(default)]
default_model: ZedDotDevModel,
},
#[serde(rename = "openai")]
OpenAi {
#[serde(default)]
default_model: OpenAiModel,
#[serde(default = "open_ai_url")]
api_url: String,
},
}
impl AssistantSettings {
pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
match &self.provider {
AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
AiProviderSettings::AzureOpenAi(settings) => {
let deployment_id = settings
.deployment_id
.clone()
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
let api_version = settings
.api_version
.ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
Ok(OpenAiCompletionProviderKind::AzureOpenAi {
deployment_id,
api_version,
})
}
impl Default for AssistantProvider {
fn default() -> Self {
Self::ZedDotDev {
default_model: ZedDotDevModel::default(),
}
}
}
pub fn provider_api_url(&self) -> anyhow::Result<String> {
match &self.provider {
AiProviderSettings::OpenAi(settings) => Ok(settings
.api_url
.clone()
.unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
AiProviderSettings::AzureOpenAi(settings) => settings
.api_url
.clone()
.ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
}
fn open_ai_url() -> String {
"https://api.openai.com/v1".into()
}
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings {
pub button: bool,
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub provider: AssistantProvider,
}
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
Versioned(VersionedAssistantSettingsContent),
Legacy(LegacyAssistantSettingsContent),
}
impl JsonSchema for AssistantSettingsContent {
fn schema_name() -> String {
VersionedAssistantSettingsContent::schema_name()
}
pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
match &self.provider {
AiProviderSettings::OpenAi(settings) => {
Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
}
AiProviderSettings::AzureOpenAi(settings) => {
let deployment_id = settings
.deployment_id
.as_deref()
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
VersionedAssistantSettingsContent::json_schema(gen)
}
match deployment_id {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview
"gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four),
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35
"gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => {
Ok(OpenAiModel::ThreePointFiveTurbo)
fn is_referenceable() -> bool {
VersionedAssistantSettingsContent::is_referenceable()
}
}
impl Default for AssistantSettingsContent {
fn default() -> Self {
Self::Versioned(VersionedAssistantSettingsContent::default())
}
}
impl AssistantSettingsContent {
fn upgrade(&self) -> AssistantSettingsContentV1 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => {
if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
AssistantSettingsContentV1 {
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
provider: Some(AssistantProvider::OpenAi {
default_model: settings
.default_open_ai_model
.clone()
.unwrap_or_default(),
api_url: open_ai_api_url.clone(),
}),
}
} else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
AssistantSettingsContentV1 {
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
provider: Some(AssistantProvider::OpenAi {
default_model: open_ai_model,
api_url: open_ai_url(),
}),
}
} else {
AssistantSettingsContentV1 {
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
provider: None,
}
_ => Err(anyhow!(
"no matching OpenAI model found for deployment ID: '{deployment_id}'"
)),
}
}
}
}
pub fn provider_model_name(&self) -> anyhow::Result<String> {
match &self.provider {
AiProviderSettings::OpenAi(settings) => Ok(settings
.default_model
.unwrap_or(OpenAiModel::FourTurbo)
.full_name()
.to_string()),
AiProviderSettings::AzureOpenAi(settings) => settings
.deployment_id
.clone()
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => {
settings.dock = Some(dock);
}
},
AssistantSettingsContent::Legacy(settings) => {
settings.dock = Some(dock);
}
}
}
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
#[serde(rename = "1")]
V1(AssistantSettingsContentV1),
}
type FileContent = AssistantSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
impl Default for VersionedAssistantSettingsContent {
fn default() -> Self {
Self::V1(AssistantSettingsContentV1 {
button: None,
dock: None,
default_width: None,
default_height: None,
provider: None,
})
}
}
/// Assistant panel settings
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContent {
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
button: Option<bool>,
/// Where to dock the assistant.
///
/// Default: right
dock: Option<AssistantDockPosition>,
/// Default width in pixels when the assistant is docked to the left or right.
///
/// Default: 640
default_width: Option<f32>,
/// Default height in pixels when the assistant is docked to the bottom.
///
/// Default: 320
default_height: Option<f32>,
/// The provider of the assistant service.
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
provider: Option<AssistantProvider>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
@ -180,88 +307,164 @@ pub struct AssistantSettingsContent {
///
/// Default: 320
pub default_height: Option<f32>,
/// Deprecated: Please use `provider.default_model` instead.
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
#[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: Option<OpenAiModel>,
/// Deprecated: Please use `provider.api_url` instead.
/// OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
#[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: Option<String>,
/// The settings for the AI provider.
#[serde(default)]
pub provider: AiProviderSettingsContent,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AiProviderSettings {
/// The settings for the OpenAI provider.
#[serde(rename = "openai")]
OpenAi(OpenAiProviderSettings),
/// The settings for the Azure OpenAI provider.
#[serde(rename = "azure_openai")]
AzureOpenAi(AzureOpenAiProviderSettings),
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
/// The settings for the AI provider used by the Zed Assistant.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AiProviderSettingsContent {
/// The settings for the OpenAI provider.
#[serde(rename = "openai")]
OpenAi(OpenAiProviderSettingsContent),
/// The settings for the Azure OpenAI provider.
#[serde(rename = "azure_openai")]
AzureOpenAi(AzureOpenAiProviderSettingsContent),
}
type FileContent = AssistantSettingsContent;
impl Default for AiProviderSettingsContent {
fn default() -> Self {
Self::OpenAi(OpenAiProviderSettingsContent::default())
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
let mut settings = AssistantSettings::default();
for value in [default_value].iter().chain(user_values) {
let value = value.upgrade();
merge(&mut settings.button, value.button);
merge(&mut settings.dock, value.dock);
merge(
&mut settings.default_width,
value.default_width.map(Into::into),
);
merge(
&mut settings.default_height,
value.default_height.map(Into::into),
);
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
AssistantProvider::ZedDotDev { default_model },
AssistantProvider::ZedDotDev {
default_model: default_model_override,
},
) => {
*default_model = default_model_override;
}
(
AssistantProvider::OpenAi {
default_model,
api_url,
},
AssistantProvider::OpenAi {
default_model: default_model_override,
api_url: api_url_override,
},
) => {
*default_model = default_model_override;
*api_url = api_url_override;
}
(merged, provider_override) => {
*merged = provider_override;
}
}
}
}
Ok(settings)
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAiProviderSettings {
/// The OpenAI API base URL to use when starting new conversations.
pub api_url: Option<String>,
/// The default OpenAI model to use when starting new conversations.
pub default_model: Option<OpenAiModel>,
fn merge<T: Copy>(target: &mut T, value: Option<T>) {
if let Some(value) = value {
*target = value;
}
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OpenAiProviderSettingsContent {
/// The OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
pub api_url: Option<String>,
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
pub default_model: Option<OpenAiModel>,
}
#[cfg(test)]
mod tests {
use gpui::AppContext;
use settings::SettingsStore;
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAiProviderSettings {
/// The Azure OpenAI API base URL to use when starting new conversations.
pub api_url: Option<String>,
/// The Azure OpenAI API version.
pub api_version: Option<AzureOpenAiApiVersion>,
/// The Azure OpenAI API deployment ID.
pub deployment_id: Option<String>,
}
use super::*;
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AzureOpenAiProviderSettingsContent {
/// The Azure OpenAI API base URL to use when starting new conversations.
pub api_url: Option<String>,
/// The Azure OpenAI API version.
pub api_version: Option<AzureOpenAiApiVersion>,
/// The Azure OpenAI deployment ID.
pub deployment_id: Option<String>,
#[gpui::test]
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
let store = settings::SettingsStore::test(cx);
cx.set_global(store);
// Settings default to gpt-4-turbo.
AssistantSettings::register(cx);
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourTurbo,
api_url: open_ai_url()
}
);
// Ensure backward-compatibility.
cx.update_global::<SettingsStore, _>(|store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"openai_api_url": "test-url",
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourTurbo,
api_url: "test-url".into()
}
);
cx.update_global::<SettingsStore, _>(|store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"default_open_ai_model": "gpt-4-0613"
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::Four,
api_url: open_ai_url()
}
);
// The new version supports setting a custom model when using zed.dev.
cx.update_global::<SettingsStore, _>(|store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": "custom"
}
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
default_model: ZedDotDevModel::Custom("custom".into())
}
);
}
}

View file

@ -1,12 +1,13 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, CompletionRequest};
use crate::{
streaming_diff::{Hunk, StreamingDiff},
CompletionProvider, LanguageModelRequest,
};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{EventEmitter, Model, ModelContext, Task};
use language::{Rope, TransactionId};
use multi_buffer;
use std::{cmp, future, ops::Range, sync::Arc};
use std::{cmp, future, ops::Range};
pub enum Event {
Finished,
@ -20,7 +21,6 @@ pub enum CodegenKind {
}
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: Model<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
@ -35,15 +35,9 @@ pub struct Codegen {
impl EventEmitter<Event> for Codegen {}
impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
Self {
provider,
buffer: buffer.clone(),
snapshot,
kind,
@ -94,7 +88,7 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@ -108,7 +102,7 @@ impl Codegen {
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
let response = CompletionProvider::global(cx).complete(prompt);
self.generation = cx.spawn(|this, mut cx| {
async move {
let generate = async {
@ -305,7 +299,7 @@ fn strip_invalid_spans_from_codeblock(
}
if first_line {
if buffer == "" || buffer == "`" || buffer == "``" {
if buffer.is_empty() || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_markdown_codeblock = true;
@ -360,8 +354,9 @@ fn strip_invalid_spans_from_codeblock(
mod tests {
use std::sync::Arc;
use crate::FakeCompletionProvider;
use super::*;
use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
@ -378,15 +373,11 @@ mod tests {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
let provider = FakeCompletionProvider::default();
cx.set_global(cx.update(SettingsStore::test));
cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.update(language_settings::init);
let text = indoc! {"
@ -405,19 +396,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@ -430,8 +412,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
println!("CHUNK: {:?}", &chunk);
provider.send_completion(chunk);
provider.send_completion(chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
@ -456,6 +437,8 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
let provider = FakeCompletionProvider::default();
cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@ -472,19 +455,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@ -497,7 +471,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
provider.send_completion(chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
@ -522,6 +496,8 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
let provider = FakeCompletionProvider::default();
cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@ -538,19 +514,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@ -563,8 +530,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
println!("{:?}", &chunk);
provider.send_completion(chunk);
provider.send_completion(chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}

View file

@ -0,0 +1,188 @@
#[cfg(test)]
mod fake;
mod open_ai;
mod zed;
#[cfg(test)]
pub use fake::*;
pub use open_ai::*;
pub use zed::*;
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { default_model } => {
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
settings_version,
cx,
))
}
AssistantProvider::OpenAi {
default_model,
api_url,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
settings_version,
)),
};
cx.set_global(provider);
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
(
CompletionProvider::OpenAi(provider),
AssistantProvider::OpenAi {
default_model,
api_url,
},
) => {
provider.update(default_model.clone(), api_url.clone(), settings_version);
}
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { default_model },
) => {
provider.update(default_model.clone(), settings_version);
}
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
settings_version,
cx,
));
}
(
CompletionProvider::ZedDotDev(_),
AssistantProvider::OpenAi {
default_model,
api_url,
},
) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
settings_version,
));
}
#[cfg(test)]
(CompletionProvider::Fake(_), _) => unimplemented!(),
}
})
})
.detach();
}
pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
}
impl gpui::Global for CompletionProvider {}
impl CompletionProvider {
pub fn global(cx: &AppContext) -> &Self {
cx.global::<Self>()
}
pub fn settings_version(&self) -> usize {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn is_authenticated(&self) -> bool {
match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
}
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
}
pub fn default_model(&self) -> LanguageModel {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
CompletionProvider::ZedDotDev(provider) => {
LanguageModel::ZedDotDev(provider.default_model())
}
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match self {
CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
}
}
}
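Callers no longer hold an `Arc<dyn CompletionProvider>`; they go through the global instead. A minimal consumption sketch (assumed surrounding async context and names, not part of this diff):
// Sketch only: assumes futures::StreamExt is in scope, `cx` is an &AppContext,
// `request` is a LanguageModelRequest, and this runs inside an async task that
// returns anyhow::Result. `response_text` is a hypothetical accumulator.
let mut chunks = CompletionProvider::global(cx).complete(request).await?;
let mut response_text = String::new();
while let Some(chunk) = chunks.next().await {
    response_text.push_str(&chunk?);
}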

View file

@ -0,0 +1,29 @@
use anyhow::Result;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use std::sync::Arc;
#[derive(Clone, Default)]
pub struct FakeCompletionProvider {
current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
}
impl FakeCompletionProvider {
pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::unbounded();
*self.current_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
pub fn send_completion(&self, chunk: String) {
self.current_completion_tx
.lock()
.as_ref()
.unwrap()
.unbounded_send(chunk)
.unwrap();
}
pub fn finish_completion(&self) {
self.current_completion_tx.lock().take();
}
}

View file

@ -0,0 +1,301 @@
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::{env, sync::Arc};
use theme::ThemeSettings;
use ui::prelude::*;
use util::{http::HttpClient, ResultExt};
pub struct OpenAiCompletionProvider {
api_key: Option<String>,
api_url: String,
default_model: OpenAiModel,
http_client: Arc<dyn HttpClient>,
settings_version: usize,
}
impl OpenAiCompletionProvider {
pub fn new(
default_model: OpenAiModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
settings_version: usize,
) -> Self {
Self {
api_key: None,
api_url,
default_model,
http_client,
settings_version,
}
}
pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
self.default_model = default_model;
self.api_url = api_url;
self.settings_version = settings_version;
}
pub fn settings_version(&self) -> usize {
self.settings_version
}
pub fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = self.api_url.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::OpenAi(provider) = provider {
provider.api_key = Some(api_key);
}
})
})
}
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::OpenAi(provider) = provider {
provider.api_key = None;
}
})
})
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
pub fn default_model(&self) -> OpenAiModel {
self.default_model.clone()
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
count_open_ai_tokens(request, cx.background_executor())
}
pub fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
LanguageModel::ZedDotDev(_) => self.default_model(),
LanguageModel::OpenAi(model) => model,
};
Request {
model,
messages: request
.messages
.into_iter()
.map(|msg| RequestMessage {
role: msg.role.into(),
content: msg.content,
})
.collect(),
stream: true,
stop: request.stop,
temperature: request.temperature,
}
}
}
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
background_executor
.spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.content),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
})
.boxed()
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OpenAiRole::User,
Role::Assistant => OpenAiRole::Assistant,
Role::System => OpenAiRole::System,
}
}
}
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,
}
impl AuthenticationPrompt {
fn new(api_url: String, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
editor.set_placeholder_text(
"sk-000000000000000000000000000000000000000000000000",
cx,
);
editor
}),
api_url,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
let api_key = self.api_key.read(cx).text(cx);
if api_key.is_empty() {
return;
}
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::OpenAi(provider) = provider {
provider.api_key = Some(api_key);
}
})
})
.detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features,
font_size: rems(0.875).into(),
font_weight: FontWeight::NORMAL,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.api_key,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
}
impl Render for AuthenticationPrompt {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
const INSTRUCTIONS: [&str; 6] = [
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
" - You can create an API key at: platform.openai.com/api-keys",
" - Make sure your OpenAI account has credits",
" - Having a subscription for another service like GitHub Copilot won't work.",
"",
"Paste your OpenAI API key below and hit enter to use the assistant:",
];
v_flex()
.p_4()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.children(
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
)
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
)
.size(LabelSize::Small),
)
.child(
h_flex()
.gap_2()
.child(Label::new("Click on").size(LabelSize::Small))
.child(Icon::new(IconName::Ai).size(IconSize::XSmall))
.child(
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
),
)
.into_any()
}
}

View file

@ -0,0 +1,167 @@
use crate::{
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc};
use ui::prelude::*;
pub struct ZedDotDevCompletionProvider {
client: Arc<Client>,
default_model: ZedDotDevModel,
settings_version: usize,
status: client::Status,
_maintain_client_status: Task<()>,
}
impl ZedDotDevCompletionProvider {
pub fn new(
default_model: ZedDotDevModel,
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Self {
let mut status_rx = client.status();
let status = *status_rx.borrow();
let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::ZedDotDev(provider) = provider {
provider.status = status;
} else {
unreachable!()
}
});
}
});
Self {
client,
default_model,
settings_version,
status,
_maintain_client_status: maintain_client_status,
}
}
pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
self.default_model = default_model;
self.settings_version = settings_version;
}
pub fn settings_version(&self) -> usize {
self.settings_version
}
pub fn default_model(&self) -> ZedDotDevModel {
self.default_model.clone()
}
pub fn is_authenticated(&self) -> bool {
self.status.is_connected()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|_cx| AuthenticationPrompt).into()
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
});
async move {
let response = request.await?;
Ok(response.token_count as usize)
}
.boxed()
}
}
}
pub fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = proto::CompleteWithLanguageModel {
model: request.model.id().to_string(),
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
stop: request.stop,
temperature: request.temperature,
};
self.client
.request_stream(request)
.map_ok(|stream| {
stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed()
})
.boxed()
}
}
struct AuthenticationPrompt;
impl Render for AuthenticationPrompt {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
v_flex()
.gap_2()
.child(
Button::new("sign_in", "Sign in")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.style(ButtonStyle::Filled)
.full_width()
.on_click(|_, cx| {
CompletionProvider::global(cx)
.authenticate(cx)
.detach_and_log_err(cx);
}),
)
.child(
div().flex().w_full().items_center().child(
Label::new("Sign in to enable collaboration.")
.color(Color::Muted)
.size(LabelSize::Small),
),
),
)
}
}

View file

@ -1,394 +1,95 @@
use ai::models::LanguageModel;
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::prompts::file_context::FileContext;
use ai::prompts::generate::GenerateInlineContent;
use ai::prompts::preamble::EngineerPreamble;
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::providers::open_ai::OpenAiLanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
use std::sync::Arc;
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
struct Match {
collapse: Range<usize>,
keep: Vec<Range<usize>>,
}
let selected_range = selected_range.to_offset(buffer);
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
Some(&grammar.embedding_config.as_ref()?.query)
});
let configs = ts_matches
.grammars()
.iter()
.map(|g| g.embedding_config.as_ref().unwrap())
.collect::<Vec<_>>();
let mut matches = Vec::new();
while let Some(mat) = ts_matches.peek() {
let config = &configs[mat.grammar_index];
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
if Some(cap.index) == config.collapse_capture_ix {
Some(cap.node.byte_range())
} else {
None
}
}) {
let mut keep = Vec::new();
for capture in mat.captures.iter() {
if Some(capture.index) == config.keep_capture_ix {
keep.push(capture.node.byte_range());
} else {
continue;
}
}
ts_matches.advance();
matches.push(Match { collapse, keep });
} else {
ts_matches.advance();
}
}
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
let mut matches = matches.into_iter().peekable();
let mut summary = String::new();
let mut offset = 0;
let mut flushed_selection = false;
while let Some(mat) = matches.next() {
// Keep extending the collapsed range if the next match surrounds
// the current one.
while let Some(next_mat) = matches.peek() {
if mat.collapse.start <= next_mat.collapse.start
&& mat.collapse.end >= next_mat.collapse.end
{
matches.next().unwrap();
} else {
break;
}
}
if offset > mat.collapse.start {
// Skip collapsed nodes that have already been summarized.
offset = cmp::max(offset, mat.collapse.end);
continue;
}
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
if !flushed_selection {
// The collapsed node ends after the selection starts, so we'll flush the selection first.
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
flushed_selection = true;
}
// If the selection intersects the collapsed node, we won't collapse it.
if selected_range.end >= mat.collapse.start {
continue;
}
}
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
for keep in mat.keep {
summary.extend(buffer.text_for_range(keep));
}
offset = mat.collapse.end;
}
// Flush selection if we haven't already done so.
if !flushed_selection && offset <= selected_range.start {
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
}
summary.extend(buffer.text_for_range(offset..buffer.len()));
summary
}
use language::BufferSnapshot;
use std::{fmt::Write, ops::Range};
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
search_results: Vec<PromptCodeSnippet>,
model: &str,
project_name: Option<String>,
) -> anyhow::Result<String> {
// Using new Prompt Templates
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
let lang_name = if let Some(language_name) = language_name {
Some(language_name.to_string())
let mut prompt = String::new();
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => {
writeln!(prompt, "You are an expert engineer.")?;
"Text"
}
Some(language_name) => {
writeln!(prompt, "You are an expert {language_name} engineer.")?;
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)?;
"Code"
}
};
if let Some(project_name) = project_name {
writeln!(
prompt,
"You are currently working inside the '{project_name}' project in code editor Zed."
)?;
}
// Include file content.
for chunk in buffer.text_for_range(0..range.start) {
prompt.push_str(chunk);
}
if range.is_empty() {
prompt.push_str("<|START|>");
} else {
None
};
let args = PromptArguments {
model: openai_model,
language_name: lang_name.clone(),
project_name,
snippets: search_results.clone(),
reserved_tokens: 1000,
buffer: Some(buffer),
selected_range: Some(range),
user_prompt: Some(user_prompt.clone()),
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
(
PromptPriority::Ordered { order: 1 },
Box::new(RepositoryContext {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(FileContext {}),
),
(
PromptPriority::Mandatory,
Box::new(GenerateInlineContent {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, _) = chain.generate(true)?;
anyhow::Ok(prompt)
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use gpui::{AppContext, Context};
use indoc::indoc;
use language::{
language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
LanguageMatcher, Point,
};
use settings::SettingsStore;
use std::sync::Arc;
pub(crate) fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
r#"
(
[(line_comment) (attribute_item)]* @context
.
[
(struct_item
name: (_) @name)
(enum_item
name: (_) @name)
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name)
(trait_item
name: (_) @name)
(function_item
name: (_) @name
body: (block
"{" @keep
"}" @keep) @collapse)
(macro_definition
name: (_) @name)
] @item
)
"#,
)
.unwrap()
prompt.push_str("<|START|");
}
#[gpui::test]
fn test_outline_for_prompt(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_settings::init(cx);
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
self.a
}
pub fn b(&self) -> usize {
self.b
}
}
"};
let buffer = cx.new_model(|cx| {
Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
indoc! {"
struct X {
<|S|>a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let <|S|a |E|>= 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
<|S|>
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
<|S|>"}
);
// Ensure nested functions get collapsed properly.
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
let a = 30;
fn nested() -> usize {
3
}
self.a + nested()
}
pub fn b(&self) -> usize {
self.b
}
}
"};
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
indoc! {"
<|S|>struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
for chunk in buffer.text_for_range(range.clone()) {
prompt.push_str(chunk);
}
if !range.is_empty() {
prompt.push_str("|END|>");
}
for chunk in buffer.text_for_range(range.end..buffer.len()) {
prompt.push_str(chunk);
}
prompt.push('\n');
if range.is_empty() {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(
prompt,
"Double check that you only return code and not the '<|START|' and '|END|'> spans"
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
Ok(prompt)
}

View file

@ -0,0 +1,121 @@
use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
use anyhow::{anyhow, Result};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
cmp::Reverse,
ffi::OsStr,
path::{Path, PathBuf},
sync::Arc,
};
use util::paths::CONVERSATIONS_DIR;
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,
pub start: usize,
}
#[derive(Serialize, Deserialize)]
pub struct SavedConversation {
pub id: Option<String>,
pub zed: String,
pub version: String,
pub text: String,
pub messages: Vec<SavedMessage>,
pub message_metadata: HashMap<MessageId, MessageMetadata>,
pub summary: String,
}
impl SavedConversation {
pub const VERSION: &'static str = "0.2.0";
pub async fn load(path: &Path, fs: &dyn Fs) -> Result<Self> {
let saved_conversation = fs.load(path).await?;
let saved_conversation_json =
serde_json::from_str::<serde_json::Value>(&saved_conversation)?;
match saved_conversation_json
.get("version")
.ok_or_else(|| anyhow!("version not found"))?
{
serde_json::Value::String(version) => match version.as_str() {
Self::VERSION => Ok(serde_json::from_value::<Self>(saved_conversation_json)?),
"0.1.0" => {
let saved_conversation =
serde_json::from_value::<SavedConversationV0_1_0>(saved_conversation_json)?;
Ok(Self {
id: saved_conversation.id,
zed: saved_conversation.zed,
version: saved_conversation.version,
text: saved_conversation.text,
messages: saved_conversation.messages,
message_metadata: saved_conversation.message_metadata,
summary: saved_conversation.summary,
})
}
_ => Err(anyhow!(
"unrecognized saved conversation version: {}",
version
)),
},
_ => Err(anyhow!("version not found on saved conversation")),
}
}
}
#[derive(Serialize, Deserialize)]
struct SavedConversationV0_1_0 {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
api_url: Option<String>,
model: OpenAiModel,
}
pub struct SavedConversationMetadata {
pub title: String,
pub path: PathBuf,
pub mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}

View file

@ -197,12 +197,10 @@ impl StreamingDiff {
} else {
hunks.push(Hunk::Remove { len: char_len })
}
} else if let Some(Hunk::Keep { len }) = hunks.last_mut() {
*len += char_len;
} else {
if let Some(Hunk::Keep { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Keep { len: char_len })
}
hunks.push(Hunk::Keep { len: char_len })
}
}

View file

@ -13,7 +13,7 @@ use async_tungstenite::tungstenite::{
use clock::SystemClock;
use collections::HashMap;
use futures::{
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
TryFutureExt as _, TryStreamExt,
};
use gpui::{
@ -36,7 +36,10 @@ use std::{
future::Future,
marker::PhantomData,
path::PathBuf,
sync::{atomic::AtomicU64, Arc, Weak},
sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
},
time::{Duration, Instant},
};
use telemetry::Telemetry;
@ -442,7 +445,7 @@ impl Client {
}
pub fn id(&self) -> u64 {
self.id.load(std::sync::atomic::Ordering::SeqCst)
self.id.load(Ordering::SeqCst)
}
pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
@ -450,7 +453,7 @@ impl Client {
}
pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, std::sync::atomic::Ordering::SeqCst);
self.id.store(id, Ordering::SeqCst);
self
}
@ -1260,6 +1263,30 @@ impl Client {
.map_ok(|envelope| envelope.payload)
}
pub fn request_stream<T: RequestMessage>(
&self,
request: T,
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
let client_id = self.id.load(Ordering::SeqCst);
log::debug!(
"rpc request start. client_id:{}. name:{}",
client_id,
T::NAME
);
let response = self
.connection_id()
.map(|conn_id| self.peer.request_stream(conn_id, request));
async move {
let response = response?.await;
log::debug!(
"rpc request finish. client_id:{}. name:{}",
client_id,
T::NAME
);
response
}
}
pub fn request_envelope<T: RequestMessage>(
&self,
request: T,

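Client::request_stream (added above) resolves in two stages: the outer future completes once the request has been sent, and the inner stream yields one decoded response per server message until the server ends the stream. A hedged call-site sketch, assuming CompleteWithLanguageModel implements RequestMessage with LanguageModelResponse as its response type and that the field values shown are placeholders:

    use futures::StreamExt;
    use rpc::proto;

    // Hypothetical call site; message contents are invented for illustration.
    async fn stream_language_model(client: &client::Client) -> anyhow::Result<()> {
        let request = proto::CompleteWithLanguageModel {
            model: "gpt-4-turbo-preview".into(),
            messages: vec![proto::LanguageModelRequestMessage {
                role: proto::LanguageModelRole::LanguageModelUser as i32,
                content: "Hello".into(),
            }],
            stop: Vec::new(),
            temperature: 1.0,
        };
        // The first await sends the request; each item is one streamed response,
        // and the stream terminates when the server sends EndStream.
        let mut responses = client.request_stream(request).await?;
        while let Some(response) = responses.next().await {
            let _chunk: proto::LanguageModelResponse = response?;
        }
        Ok(())
    }
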
View file

@ -261,7 +261,7 @@ impl Telemetry {
self: &Arc<Self>,
conversation_id: Option<String>,
kind: AssistantKind,
model: &str,
model: String,
) {
let event = Event::Assistant(AssistantEvent {
conversation_id,

View file

@ -31,10 +31,12 @@ collections.workspace = true
dashmap = "5.4"
envy = "0.4.2"
futures.workspace = true
google_ai.workspace = true
hex.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid = "0.4"
open_ai.workspace = true
parking_lot.workspace = true
prometheus = "0.13"
prost.workspace = true
@ -80,7 +82,6 @@ git = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
lazy_static.workspace = true
live_kit_client = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
menu.workspace = true

View file

@ -379,6 +379,16 @@ CREATE TABLE extension_versions (
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
CREATE TABLE rate_buckets (
user_id INT NOT NULL,
rate_limit_name VARCHAR(255) NOT NULL,
token_count INT NOT NULL,
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
PRIMARY KEY (user_id, rate_limit_name),
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
CREATE TABLE hosted_projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id),

View file

@ -0,0 +1,11 @@
CREATE TABLE IF NOT EXISTS rate_buckets (
user_id INT NOT NULL,
rate_limit_name VARCHAR(255) NOT NULL,
token_count INT NOT NULL,
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
PRIMARY KEY (user_id, rate_limit_name),
CONSTRAINT fk_user
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);

75
crates/collab/src/ai.rs Normal file
View file

@ -0,0 +1,75 @@
use anyhow::{anyhow, Result};
use rpc::proto;
pub fn language_model_request_to_open_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<open_ai::Request> {
Ok(open_ai::Request {
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
messages: request
.messages
.into_iter()
.map(|message| {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(open_ai::RequestMessage {
role: match role {
proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => {
open_ai::Role::Assistant
}
proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
},
content: message.content,
})
})
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
stream: true,
stop: request.stop,
temperature: request.temperature,
})
}
pub fn language_model_request_to_google_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<google_ai::GenerateContentRequest> {
Ok(google_ai::GenerateContentRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
generation_config: None,
safety_settings: None,
})
}
pub fn language_model_request_message_to_google_ai(
message: proto::LanguageModelRequestMessage,
) -> Result<google_ai::Content> {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: message.content,
})],
role: match role {
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
},
})
}
pub fn count_tokens_request_to_google_ai(
request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {
Ok(google_ai::CountTokensRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
})
}

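These converters translate the wire-format request into each provider's native request type; unknown model ids fall back to open_ai::Model::FourTurbo, and the stream flag is always set since the proxy only streams. A small round-trip sketch, assuming the prost-generated field layout shown in the protocol changes below:

    // Illustrative only: build a one-message request and convert it for OpenAI.
    fn convert_example() -> anyhow::Result<()> {
        let request = rpc::proto::CompleteWithLanguageModel {
            model: "gpt-4".into(),
            messages: vec![rpc::proto::LanguageModelRequestMessage {
                role: rpc::proto::LanguageModelRole::LanguageModelUser as i32,
                content: "Summarize this function.".into(),
            }],
            stop: Vec::new(),
            temperature: 0.7,
        };
        let converted = crate::ai::language_model_request_to_open_ai(request)?;
        assert_eq!(converted.model.id(), "gpt-4");
        assert!(converted.stream); // the proxy always streams completions
        Ok(())
    }
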
View file

@ -1,6 +1,5 @@
use crate::{
db::{ExtensionMetadata, NewExtensionVersion},
executor::Executor,
AppState, Error, Result,
};
use anyhow::{anyhow, Context as _};
@ -136,7 +135,7 @@ async fn download_extension(
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
let Some(blob_store_client) = app_state.blob_store_client.clone() else {
log::info!("no blob store client");
return;
@ -146,6 +145,7 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, e
return;
};
let executor = app_state.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {

View file

@ -10,6 +10,7 @@ pub mod hosted_projects;
pub mod messages;
pub mod notifications;
pub mod projects;
pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;

View file

@ -0,0 +1,58 @@
use super::*;
use crate::db::tables::rate_buckets;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
impl Database {
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
/// than the currently saved timestamp.
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
if buckets.is_empty() {
return Ok(());
}
self.transaction(|tx| async move {
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
rate_buckets::ActiveModel {
user_id: ActiveValue::Set(bucket.user_id),
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
token_count: ActiveValue::Set(bucket.token_count),
last_refill: ActiveValue::Set(bucket.last_refill),
}
}))
.on_conflict(
OnConflict::columns([
rate_buckets::Column::UserId,
rate_buckets::Column::RateLimitName,
])
.update_columns([
rate_buckets::Column::TokenCount,
rate_buckets::Column::LastRefill,
])
.to_owned(),
)
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// Retrieves the rate limit for the given user and rate limit name.
pub async fn get_rate_bucket(
&self,
user_id: UserId,
rate_limit_name: &str,
) -> Result<Option<rate_buckets::Model>> {
self.transaction(|tx| async move {
let rate_limit = rate_buckets::Entity::find()
.filter(rate_buckets::Column::UserId.eq(user_id))
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
.one(&*tx)
.await?;
Ok(rate_limit)
})
.await
}
}

View file

@ -22,6 +22,7 @@ pub mod observed_buffer_edits;
pub mod observed_channel_messages;
pub mod project;
pub mod project_collaborator;
pub mod rate_buckets;
pub mod room;
pub mod room_participant;
pub mod server;

View file

@ -0,0 +1,31 @@
use crate::db::UserId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "rate_buckets")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: UserId,
#[sea_orm(primary_key, auto_increment = false)]
pub rate_limit_name: String,
pub token_count: i32,
pub last_refill: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,8 +1,10 @@
pub mod ai;
pub mod api;
pub mod auth;
pub mod db;
pub mod env;
pub mod executor;
mod rate_limiter;
pub mod rpc;
#[cfg(test)]
@ -13,6 +15,7 @@ use aws_config::{BehaviorVersion, Region};
use axum::{http::StatusCode, response::IntoResponse};
use db::{ChannelId, Database};
use executor::Executor;
pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@ -126,6 +129,8 @@ pub struct Config {
pub blob_store_secret_key: Option<String>,
pub blob_store_bucket: Option<String>,
pub zed_environment: Arc<str>,
pub openai_api_key: Option<Arc<str>>,
pub google_ai_api_key: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
@ -147,12 +152,14 @@ pub struct AppState {
pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub clickhouse_client: Option<clickhouse::Client>,
pub config: Config,
}
impl AppState {
pub async fn new(config: Config) -> Result<Arc<Self>> {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
db_options.max_connections(config.database_max_connections);
let mut db = Database::new(db_options, Executor::Production).await?;
@ -173,10 +180,13 @@ impl AppState {
None
};
let db = Arc::new(db);
let this = Self {
db: Arc::new(db),
db: db.clone(),
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
clickhouse_client: config
.clickhouse_url
.as_ref()

View file

@ -7,7 +7,7 @@ use axum::{
};
use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
Config, MigrateConfig, Result,
Config, MigrateConfig, RateLimiter, Result,
};
use db::Database;
use std::{
@ -62,18 +62,27 @@ async fn main() -> Result<()> {
run_migrations().await?;
let state = AppState::new(config).await?;
let state = AppState::new(config, Executor::Production).await?;
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
.expect("failed to bind TCP listener");
let epoch = state
.db
.create_server(&state.config.zed_environment)
.await?;
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
fetch_extensions_from_blob_store_periodically(state.clone());
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
let rpc_server = if is_collab {
let epoch = state
.db
.create_server(&state.config.zed_environment)
.await?;
let rpc_server =
collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
Some(rpc_server)
@ -82,7 +91,7 @@ async fn main() -> Result<()> {
};
if is_api {
fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
fetch_extensions_from_blob_store_periodically(state.clone());
}
let mut app = collab::api::routes(rpc_server.clone(), state.clone());

View file

@ -0,0 +1,274 @@
use crate::{db::UserId, executor::Executor, Database, Error, Result};
use anyhow::anyhow;
use chrono::{DateTime, Duration, Utc};
use dashmap::{DashMap, DashSet};
use sea_orm::prelude::DateTimeUtc;
use std::sync::Arc;
use util::ResultExt;
pub trait RateLimit: 'static {
fn capacity() -> usize;
fn refill_duration() -> Duration;
fn db_name() -> &'static str;
}
/// Used to enforce per-user rate limits
pub struct RateLimiter {
buckets: DashMap<(UserId, String), RateBucket>,
dirty_buckets: DashSet<(UserId, String)>,
db: Arc<Database>,
}
impl RateLimiter {
pub fn new(db: Arc<Database>) -> Self {
RateLimiter {
buckets: DashMap::new(),
dirty_buckets: DashSet::new(),
db,
}
}
/// Spawns a new task that periodically saves rate limit data to the database.
pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
executor.clone().spawn_detached(async move {
loop {
executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
rate_limiter.save().await.log_err();
}
});
}
/// Returns an error if the user has exceeded the specified `RateLimit`.
/// Attempts to read from the database if no cached RateBucket currently exists.
pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
self.check_internal::<T>(user_id, Utc::now()).await
}
async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
let bucket_key = (user_id, T::db_name().to_string());
// Attempt to fetch the bucket from the database if it hasn't been cached.
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
// but this enforces limits across restarts so long as the database is reachable.
if !self.buckets.contains_key(&bucket_key) {
if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
self.buckets.insert(bucket_key.clone(), bucket);
self.dirty_buckets.insert(bucket_key.clone());
}
}
let mut bucket = self
.buckets
.entry(bucket_key.clone())
.or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key);
Ok(())
} else {
Err(anyhow!("rate limit exceeded"))?
}
}
async fn load_bucket<K: RateLimit>(
&self,
user_id: UserId,
) -> Result<Option<RateBucket>, Error> {
Ok(self
.db
.get_rate_bucket(user_id, K::db_name())
.await?
.map(|saved_bucket| RateBucket {
capacity: K::capacity(),
refill_time_per_token: K::refill_duration(),
token_count: saved_bucket.token_count as usize,
last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
}))
}
pub async fn save(&self) -> Result<()> {
let mut buckets = Vec::new();
self.dirty_buckets.retain(|key| {
if let Some(bucket) = self.buckets.get(&key) {
buckets.push(crate::db::rate_buckets::Model {
user_id: key.0,
rate_limit_name: key.1.clone(),
token_count: bucket.token_count as i32,
last_refill: bucket.last_refill.naive_utc(),
});
}
false
});
match self.db.save_rate_buckets(&buckets).await {
Ok(()) => Ok(()),
Err(err) => {
for bucket in buckets {
self.dirty_buckets
.insert((bucket.user_id, bucket.rate_limit_name));
}
Err(err)
}
}
}
}
#[derive(Clone)]
struct RateBucket {
capacity: usize,
token_count: usize,
refill_time_per_token: Duration,
last_refill: DateTimeUtc,
}
impl RateBucket {
fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
RateBucket {
capacity,
token_count: capacity,
refill_time_per_token: refill_duration / capacity as i32,
last_refill: now,
}
}
fn allow(&mut self, now: DateTimeUtc) -> bool {
self.refill(now);
if self.token_count > 0 {
self.token_count -= 1;
true
} else {
false
}
}
fn refill(&mut self, now: DateTimeUtc) {
let elapsed = now - self.last_refill;
if elapsed >= self.refill_time_per_token {
let new_tokens =
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
self.last_refill = now;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::{NewUserParams, TestDb};
use gpui::TestAppContext;
#[gpui::test]
async fn test_rate_limiter(cx: &mut TestAppContext) {
let test_db = TestDb::sqlite(cx.executor().clone());
let db = test_db.db().clone();
let user_1 = db
.create_user(
"user-1@zed.dev",
false,
NewUserParams {
github_login: "user-1".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user-2@zed.dev",
false,
NewUserParams {
github_login: "user-2".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
let mut now = Utc::now();
let rate_limiter = RateLimiter::new(db.clone());
// User 1 can access resource A two times before being rate-limited.
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
// User 2 can access resource A and user 1 can access resource B.
rate_limiter
.check_internal::<RateLimitB>(user_2, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitB>(user_1, now)
.await
.unwrap();
// After one second, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1);
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
rate_limiter.save().await.unwrap();
// Rate limits are reloaded from the database, so user 1 is still rate-limited
// for resource A.
let rate_limiter = RateLimiter::new(db.clone());
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
}
struct RateLimitA;
impl RateLimit for RateLimitA {
fn capacity() -> usize {
2
}
fn refill_duration() -> Duration {
Duration::seconds(2)
}
fn db_name() -> &'static str {
"rate-limit-a"
}
}
struct RateLimitB;
impl RateLimit for RateLimitB {
fn capacity() -> usize {
10
}
fn refill_duration() -> Duration {
Duration::seconds(3)
}
fn db_name() -> &'static str {
"rate-limit-b"
}
}
}

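For reference, the bucket arithmetic above: refill_time_per_token is refill_duration divided by capacity, whole tokens accrue only once at least one token's worth of time has elapsed, and last_refill then snaps to now, so partial progress toward the next token is discarded. A standalone worked example using chrono, mirroring RateBucket::refill:

    use chrono::Duration;

    // Mirrors the refill math: how many whole tokens accrue after `elapsed`.
    fn refilled_tokens(capacity: i32, refill_duration: Duration, elapsed: Duration) -> i64 {
        let refill_time_per_token = refill_duration / capacity;
        if elapsed >= refill_time_per_token {
            elapsed.num_milliseconds() / refill_time_per_token.num_milliseconds()
        } else {
            0
        }
    }

    fn main() {
        // CompleteWithLanguageModelRateLimit defaults to 120 tokens per hour,
        // i.e. one token every 30 seconds.
        assert_eq!(Duration::hours(1) / 120, Duration::seconds(30));
        // 75 seconds after the last refill, two whole tokens have accrued; the
        // remaining 15 seconds of progress is dropped when last_refill resets.
        assert_eq!(refilled_tokens(120, Duration::hours(1), Duration::seconds(75)), 2);
    }
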
View file

@ -9,9 +9,9 @@ use crate::{
User, UserId,
},
executor::Executor,
AppState, Error, Result,
AppState, Error, RateLimit, RateLimiter, Result,
};
use anyhow::anyhow;
use anyhow::{anyhow, Context as _};
use async_tungstenite::tungstenite::{
protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
};
@ -30,6 +30,8 @@ use axum::{
};
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use futures::{
channel::oneshot,
future::{self, BoxFuture},
@ -39,15 +41,14 @@ use futures::{
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
use serde::{Serialize, Serializer};
use std::{
any::TypeId,
fmt,
future::Future,
marker::PhantomData,
mem,
@ -64,7 +65,7 @@ use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{field, info_span, instrument, Instrument};
use util::SemanticVersion;
use util::{http::IsahcHttpClient, SemanticVersion};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
@ -92,6 +93,18 @@ impl<R: RequestMessage> Response<R> {
}
}
struct StreamingResponse<R: RequestMessage> {
peer: Arc<Peer>,
receipt: Receipt<R>,
}
impl<R: RequestMessage> StreamingResponse<R> {
fn send(&self, payload: R::Response) -> Result<()> {
self.peer.respond(self.receipt, payload)?;
Ok(())
}
}
#[derive(Clone)]
struct Session {
user_id: UserId,
@ -100,6 +113,8 @@ struct Session {
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
http_client: IsahcHttpClient,
rate_limiter: Arc<RateLimiter>,
_executor: Executor,
}
@ -124,8 +139,8 @@ impl Session {
}
}
impl fmt::Debug for Session {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl Debug for Session {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Session")
.field("user_id", &self.user_id)
.field("connection_id", &self.connection_id)
@ -148,7 +163,6 @@ pub struct Server {
peer: Arc<Peer>,
pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
executor: Executor,
handlers: HashMap<TypeId, MessageHandler>,
teardown: watch::Sender<bool>,
}
@ -175,12 +189,11 @@ where
}
impl Server {
pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
let mut server = Self {
id: parking_lot::Mutex::new(id),
peer: Peer::new(id.0 as u32),
app_state,
executor,
app_state: app_state.clone(),
connection_pool: Default::default(),
handlers: Default::default(),
teardown: watch::channel(false).0,
@ -280,7 +293,30 @@ impl Server {
.add_message_handler(update_followers)
.add_request_handler(get_private_user_info)
.add_message_handler(acknowledge_channel_message)
.add_message_handler(acknowledge_buffer_version);
.add_message_handler(acknowledge_buffer_version)
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
complete_with_language_model(
request,
response,
session,
app_state.config.openai_api_key.clone(),
app_state.config.google_ai_api_key.clone(),
)
}
})
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
count_tokens_with_language_model(
request,
response,
session,
app_state.config.google_ai_api_key.clone(),
)
}
});
Arc::new(server)
}
@ -289,12 +325,12 @@ impl Server {
let server_id = *self.id.lock();
let app_state = self.app_state.clone();
let peer = self.peer.clone();
let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
let pool = self.connection_pool.clone();
let live_kit_client = self.app_state.live_kit_client.clone();
let span = info_span!("start server");
self.executor.spawn_detached(
self.app_state.executor.spawn_detached(
async move {
tracing::info!("waiting for cleanup timeout");
timeout.await;
@ -536,6 +572,40 @@ impl Server {
})
}
fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
Fut: Send + Future<Output = Result<()>>,
M: RequestMessage,
{
let handler = Arc::new(handler);
self.add_handler(move |envelope, session| {
let receipt = envelope.receipt();
let handler = handler.clone();
async move {
let peer = session.peer.clone();
let response = StreamingResponse {
peer: peer.clone(),
receipt,
};
match (handler)(envelope.payload, response, session).await {
Ok(()) => {
peer.end_stream(receipt)?;
Ok(())
}
Err(error) => {
let proto_err = match &error {
Error::Internal(err) => err.to_proto(),
_ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
};
peer.respond_with_error(receipt, proto_err)?;
Err(error)
}
}
}
})
}
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
self: &Arc<Self>,
@ -569,6 +639,14 @@ impl Server {
tracing::Span::current().record("connection_id", format!("{}", connection_id));
tracing::info!("connection opened");
let http_client = match IsahcHttpClient::new() {
Ok(http_client) => http_client,
Err(error) => {
tracing::error!(?error, "failed to create HTTP client");
return;
}
};
let session = Session {
user_id,
connection_id,
@ -576,7 +654,9 @@ impl Server {
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone(),
_executor: executor.clone()
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
_executor: executor.clone(),
};
if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
@ -3220,6 +3300,207 @@ async fn acknowledge_buffer_version(
Ok(())
}
struct CompleteWithLanguageModelRateLimit;
impl RateLimit for CompleteWithLanguageModelRateLimit {
fn capacity() -> usize {
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"complete-with-language-model"
}
}
async fn complete_with_language_model(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id)
.await?;
if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
}
Ok(())
}
async fn complete_with_open_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
api_key: Arc<str>,
) -> Result<()> {
const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
let mut completion_stream = open_ai::stream_completion(
&session.http_client,
OPEN_AI_API_URL,
&api_key,
crate::ai::language_model_request_to_open_ai(request)?,
)
.await
.context("open_ai::stream_completion request failed")?;
while let Some(event) = completion_stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.choices
.into_iter()
.map(|choice| proto::LanguageModelChoiceDelta {
index: choice.index,
delta: Some(proto::LanguageModelResponseMessage {
role: choice.delta.role.map(|role| match role {
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
} as i32),
content: choice.delta.content,
}),
finish_reason: choice.finish_reason,
})
.collect(),
})?;
}
Ok(())
}
async fn complete_with_google_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
api_key: Arc<str>,
) -> Result<()> {
let mut stream = google_ai::stream_generate_content(
&session.http_client,
google_ai::API_URL,
api_key.as_ref(),
crate::ai::language_model_request_to_google_ai(request)?,
)
.await
.context("google_ai::stream_generate_content request failed")?;
while let Some(event) = stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.candidates
.unwrap_or_default()
.into_iter()
.map(|candidate| proto::LanguageModelChoiceDelta {
index: candidate.index as u32,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(match candidate.content.role {
google_ai::Role::User => LanguageModelRole::LanguageModelUser,
google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
} as i32),
content: Some(
candidate
.content
.parts
.into_iter()
.filter_map(|part| match part {
google_ai::Part::TextPart(part) => Some(part.text),
google_ai::Part::InlineDataPart(_) => None,
})
.collect(),
),
}),
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
})
.collect(),
})?;
}
Ok(())
}
struct CountTokensWithLanguageModelRateLimit;
impl RateLimit for CountTokensWithLanguageModelRateLimit {
fn capacity() -> usize {
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"count-tokens-with-language-model"
}
}
async fn count_tokens_with_language_model(
request: proto::CountTokensWithLanguageModel,
response: Response<proto::CountTokensWithLanguageModel>,
session: Session,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
if !request.model.starts_with("gemini") {
return Err(anyhow!(
"counting tokens for model: {:?} is not supported",
request.model
))?;
}
session
.rate_limiter
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
.await?;
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
let tokens_response = google_ai::count_tokens(
&session.http_client,
google_ai::API_URL,
&api_key,
crate::ai::count_tokens_request_to_google_ai(request)?,
)
.await?;
response.send(proto::CountTokensResponse {
token_count: tokens_response.total_tokens as u32,
})?;
Ok(())
}
async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id).await?;
if flags.iter().any(|flag| flag == "language-models") {
Ok(())
} else {
Err(anyhow!("permission denied"))?
}
}
/// Start receiving chat updates for a channel
async fn join_channel_chat(
request: proto::JoinChannelChat,

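The streaming pieces above fit together as follows: a handler registered with add_streaming_request_handler calls StreamingResponse::send once per upstream event, and when the handler returns Ok(()) the wrapper calls peer.end_stream, which is what terminates the stream the client obtained from Client::request_stream. A hypothetical handler of the same shape (message contents invented for illustration):

    // Echoes each request message back as one streamed assistant delta.
    async fn echo_chunks(
        request: proto::CompleteWithLanguageModel,
        response: StreamingResponse<proto::CompleteWithLanguageModel>,
        _session: Session,
    ) -> Result<()> {
        for message in request.messages {
            response.send(proto::LanguageModelResponse {
                choices: vec![proto::LanguageModelChoiceDelta {
                    index: 0,
                    delta: Some(proto::LanguageModelResponseMessage {
                        role: Some(LanguageModelRole::LanguageModelAssistant as i32),
                        content: Some(message.content),
                    }),
                    finish_reason: None,
                }],
            })?;
        }
        Ok(()) // the wrapper then sends EndStream to close the client's stream
    }
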
View file

@ -2,7 +2,7 @@ use crate::{
db::{tests::TestDb, NewUserParams, UserId},
executor::Executor,
rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
AppState, Config,
AppState, Config, RateLimiter,
};
use anyhow::anyhow;
use call::ActiveCall;
@ -93,17 +93,14 @@ impl TestServer {
deterministic.clone(),
)
.unwrap();
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
let executor = Executor::Deterministic(deterministic.clone());
let app_state = Self::build_app_state(&test_db, &live_kit_server, executor.clone()).await;
let epoch = app_state
.db
.create_server(&app_state.config.zed_environment)
.await
.unwrap();
let server = Server::new(
epoch,
app_state.clone(),
Executor::Deterministic(deterministic.clone()),
);
let server = Server::new(epoch, app_state.clone());
server.start().await.unwrap();
// Advance clock to ensure the server's cleanup task is finished.
deterministic.advance_clock(CLEANUP_TIMEOUT);
@ -482,12 +479,15 @@ impl TestServer {
pub async fn build_app_state(
test_db: &TestDb,
fake_server: &live_kit_client::TestServer,
live_kit_test_server: &live_kit_client::TestServer,
executor: Executor,
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
clickhouse_client: None,
config: Config {
http_port: 0,
@ -506,6 +506,8 @@ impl TestServer {
blob_store_access_key: None,
blob_store_secret_key: None,
blob_store_bucket: None,
openai_api_key: None,
google_ai_api_key: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,

View file

@ -0,0 +1,14 @@
[package]
name = "google_ai"
version = "0.1.0"
edition = "2021"
[lib]
path = "src/google_ai.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
util.workspace = true

View file

@ -0,0 +1,266 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use util::http::HttpClient;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content<T: HttpClient>(
client: &T,
api_url: &str,
api_key: &str,
request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
let uri = format!(
"{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
api_url, api_key
);
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
if let Some(line) = line.strip_prefix("data: ") {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
}
} else {
None
}
}
Err(error) => Some(Err(anyhow!(error))),
}
})
.boxed())
} else {
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
Err(anyhow!(
"error during streamGenerateContent, status code: {:?}, body: {}",
response.status(),
text
))
}
}
pub async fn count_tokens<T: HttpClient>(
client: &T,
api_url: &str,
api_key: &str,
request: CountTokensRequest,
) -> Result<CountTokensResponse> {
let uri = format!(
"{}/v1beta/models/gemini-pro:countTokens?key={}",
api_url, api_key
);
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
if response.status().is_success() {
Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
} else {
Err(anyhow!(
"error during countTokens, status code: {:?}, body: {}",
response.status(),
text
))
}
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Task {
#[serde(rename = "generateContent")]
GenerateContent,
#[serde(rename = "streamGenerateContent")]
StreamGenerateContent,
#[serde(rename = "countTokens")]
CountTokens,
#[serde(rename = "embedContent")]
EmbedContent,
#[serde(rename = "batchEmbedContents")]
BatchEmbedContents,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
pub candidates: Option<Vec<GenerateContentCandidate>>,
pub prompt_feedback: Option<PromptFeedback>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
pub index: usize,
pub content: Content,
pub finish_reason: Option<String>,
pub finish_message: Option<String>,
pub safety_ratings: Option<Vec<SafetyRating>>,
pub citation_metadata: Option<CitationMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
pub parts: Vec<Part>,
pub role: Role,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
User,
Model,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
TextPart(TextPart),
InlineDataPart(InlineDataPart),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
pub inline_data: GenerativeContentBlob,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
pub mime_type: String,
pub data: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
pub start_index: Option<usize>,
pub end_index: Option<usize>,
pub uri: Option<String>,
pub license: Option<String>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
pub citation_sources: Vec<CitationSource>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
pub block_reason: Option<String>,
pub safety_ratings: Vec<SafetyRating>,
pub block_reason_message: Option<String>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
pub candidate_count: Option<usize>,
pub stop_sequences: Option<Vec<String>>,
pub max_output_tokens: Option<usize>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub top_k: Option<usize>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
pub threshold: HarmBlockThreshold,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
Unspecified,
#[serde(rename = "HARM_CATEGORY_DEROGATORY")]
Derogatory,
#[serde(rename = "HARM_CATEGORY_TOXICITY")]
Toxicity,
#[serde(rename = "HARM_CATEGORY_VIOLENCE")]
Violence,
#[serde(rename = "HARM_CATEGORY_SEXUAL")]
Sexual,
#[serde(rename = "HARM_CATEGORY_MEDICAL")]
Medical,
#[serde(rename = "HARM_CATEGORY_DANGEROUS")]
Dangerous,
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
Harassment,
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
HateSpeech,
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
SexuallyExplicit,
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
DangerousContent,
}
#[derive(Debug, Serialize)]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
BlockLowAndAbove,
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
BlockMediumAndAbove,
#[serde(rename = "BLOCK_ONLY_HIGH")]
BlockOnlyHigh,
#[serde(rename = "BLOCK_NONE")]
BlockNone,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
Unspecified,
Negligible,
Low,
Medium,
High,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
pub category: HarmCategory,
pub probability: HarmProbability,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
pub contents: Vec<Content>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: usize,
}

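Because the request URL passes alt=sse, Gemini replies with data:-prefixed server-sent events that decode into GenerateContentResponse values. A hedged usage sketch that concatenates the streamed text parts (client and contents construction assumed):

    use futures::StreamExt;

    async fn generate_text(
        client: &impl util::http::HttpClient,
        api_key: &str,
        contents: Vec<google_ai::Content>,
    ) -> anyhow::Result<String> {
        let request = google_ai::GenerateContentRequest {
            contents,
            generation_config: None,
            safety_settings: None,
        };
        let mut events =
            google_ai::stream_generate_content(client, google_ai::API_URL, api_key, request)
                .await?;
        let mut text = String::new();
        while let Some(event) = events.next().await {
            for candidate in event?.candidates.unwrap_or_default() {
                for part in candidate.content.parts {
                    if let google_ai::Part::TextPart(part) = part {
                        text.push_str(&part.text);
                    }
                }
            }
        }
        Ok(text)
    }
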
19
crates/open_ai/Cargo.toml Normal file
View file

@ -0,0 +1,19 @@
[package]
name = "open_ai"
version = "0.1.0"
edition = "2021"
[lib]
path = "src/open_ai.rs"
[features]
default = []
schemars = ["dep:schemars"]
[dependencies]
anyhow.workspace = true
futures.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
util.workspace = true

View file

@ -0,0 +1,182 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl TryFrom<String> for Role {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self> {
match value.as_str() {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
"system" => Ok(Self::System),
_ => Err(anyhow!("invalid role '{value}'")),
}
}
}
impl From<Role> for String {
fn from(val: Role) -> Self {
match val {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
Role::System => "system".to_owned(),
}
}
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
#[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
#[serde(rename = "gpt-4", alias = "gpt-4-0613")]
Four,
#[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
#[default]
FourTurbo,
}
impl Model {
pub fn from_id(id: &str) -> Result<Self> {
match id {
"gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
"gpt-4" => Ok(Self::Four),
"gpt-4-turbo-preview" => Ok(Self::FourTurbo),
_ => Err(anyhow!("invalid model id")),
}
}
pub fn id(&self) -> &'static str {
match self {
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo-preview",
}
}
pub fn display_name(&self) -> &'static str {
match self {
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo",
}
}
}
#[derive(Debug, Serialize)]
pub struct Request {
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
pub created: u32,
pub model: String,
pub choices: Vec<ChoiceDelta>,
pub usage: Option<Usage>,
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
let uri = format!("{api_url}/chat/completions");
let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
let line = line.strip_prefix("data: ")?;
if line == "[DONE]" {
None
} else {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
}
}
}
Err(error) => Some(Err(anyhow!(error))),
}
})
.boxed())
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAiResponse {
error: OpenAiError,
}
#[derive(Deserialize)]
struct OpenAiError {
message: String,
}
match serde_json::from_str::<OpenAiResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}

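The streaming helper handles OpenAI's SSE framing itself: it strips the data: prefix, stops at [DONE], and surfaces the API's JSON error body on failure. A usage sketch that prints assistant deltas as they arrive (HTTP client construction assumed):

    use futures::StreamExt;

    async fn print_completion(
        client: &dyn util::http::HttpClient,
        api_key: &str,
    ) -> anyhow::Result<()> {
        let request = open_ai::Request {
            model: open_ai::Model::FourTurbo,
            messages: vec![open_ai::RequestMessage {
                role: open_ai::Role::User,
                content: "Say hello.".into(),
            }],
            stream: true,
            stop: Vec::new(),
            temperature: 1.0,
        };
        let mut events =
            open_ai::stream_completion(client, "https://api.openai.com/v1", api_key, request)
                .await?;
        while let Some(event) = events.next().await {
            for choice in event?.choices {
                if let Some(content) = choice.delta.content {
                    print!("{content}");
                }
            }
        }
        Ok(())
    }
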
View file

@ -1,7 +1,7 @@
syntax = "proto3";
package zed.messages;
// Looking for a number? Search "// Current max"
// Looking for a number? Search "// current max"
message PeerId {
uint32 owner_id = 1;
@ -26,6 +26,7 @@ message Envelope {
Error error = 6;
Ping ping = 7;
Test test = 8;
EndStream end_stream = 165;
CreateRoom create_room = 9;
CreateRoomResponse create_room_response = 10;
@ -198,6 +199,11 @@ message Envelope {
GetImplementationResponse get_implementation_response = 163;
JoinHostedProject join_hosted_project = 164;
CompleteWithLanguageModel complete_with_language_model = 166;
LanguageModelResponse language_model_response = 167;
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
CountTokensResponse count_tokens_response = 169; // current max
}
reserved 158 to 161;
@ -236,6 +242,8 @@ enum ErrorCode {
reserved 6;
}
message EndStream {}
message Test {
uint64 id = 1;
}
@ -1718,3 +1726,45 @@ message SetRoomParticipantRole {
uint64 user_id = 2;
ChannelRole role = 3;
}
message CompleteWithLanguageModel {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
repeated string stop = 3;
float temperature = 4;
}
message LanguageModelRequestMessage {
LanguageModelRole role = 1;
string content = 2;
}
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
LanguageModelSystem = 2;
}
message LanguageModelResponseMessage {
optional LanguageModelRole role = 1;
optional string content = 2;
}
message LanguageModelResponse {
repeated LanguageModelChoiceDelta choices = 1;
}
message LanguageModelChoiceDelta {
uint32 index = 1;
LanguageModelResponseMessage delta = 2;
optional string finish_reason = 3;
}
message CountTokensWithLanguageModel {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
}
message CountTokensResponse {
uint32 token_count = 1;
}

View file

@ -80,7 +80,7 @@ pub trait ErrorExt {
fn error_tag(&self, k: &str) -> Option<&str>;
/// to_proto() converts the error into a proto::Error
fn to_proto(&self) -> proto::Error;
///
/// Clones the error and turns it into an [anyhow::Error].
fn cloned(&self) -> anyhow::Error;
}

View file

@ -9,19 +9,21 @@ use collections::HashMap;
use futures::{
channel::{mpsc, oneshot},
stream::BoxStream,
FutureExt, SinkExt, StreamExt, TryFutureExt,
FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
};
use parking_lot::{Mutex, RwLock};
use serde::{ser::SerializeStruct, Serialize};
use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
use std::{
fmt, future,
future::Future,
marker::PhantomData,
sync::atomic::Ordering::SeqCst,
sync::{
atomic::{self, AtomicU32},
Arc,
},
time::Duration,
time::Instant,
};
use tracing::instrument;
@ -118,6 +120,15 @@ pub struct ConnectionState {
>,
>,
>,
#[allow(clippy::type_complexity)]
#[serde(skip)]
stream_response_channels: Arc<
Mutex<
Option<
HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
>,
>,
>,
}
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
@ -171,17 +182,28 @@ impl Peer {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let stream_response_channels = connection_state.stream_response_channels.clone();
let handle_io = async move {
tracing::trace!(%connection_id, "handle io future: start");
let _end_connection = util::defer(|| {
response_channels.lock().take();
if let Some(channels) = stream_response_channels.lock().take() {
for channel in channels.values() {
let _ = channel.unbounded_send((
Err(anyhow!("connection closed")),
oneshot::channel().0,
));
}
}
this.connections.write().remove(&connection_id);
tracing::trace!(%connection_id, "handle io future: end");
});
@ -273,12 +295,14 @@ impl Peer {
};
let response_channels = connection_state.response_channels.clone();
let stream_response_channels = connection_state.stream_response_channels.clone();
self.connections
.write()
.insert(connection_id, connection_state);
let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
let response_channels = response_channels.clone();
let stream_response_channels = stream_response_channels.clone();
async move {
let message_id = incoming.id;
tracing::trace!(?incoming, "incoming message future: start");
@ -293,8 +317,15 @@ impl Peer {
responding_to,
"incoming response: received"
);
let channel = response_channels.lock().as_mut()?.remove(&responding_to);
if let Some(tx) = channel {
let response_channel =
response_channels.lock().as_mut()?.remove(&responding_to);
let stream_response_channel = stream_response_channels
.lock()
.as_ref()?
.get(&responding_to)
.cloned();
if let Some(tx) = response_channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
tracing::trace!(
@ -319,6 +350,31 @@ impl Peer {
responding_to,
"incoming response: requester resumed"
);
} else if let Some(tx) = stream_response_channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
tracing::debug!(
%connection_id,
message_id,
responding_to = responding_to,
?error,
"incoming stream response: request future dropped",
);
}
tracing::debug!(
%connection_id,
message_id,
responding_to,
"incoming stream response: waiting to resume requester"
);
let _ = requester_resumed.1.await;
tracing::debug!(
%connection_id,
message_id,
responding_to,
"incoming stream response: requester resumed"
);
} else {
let message_type =
proto::build_typed_envelope(connection_id, received_at, incoming)
@ -451,6 +507,66 @@ impl Peer {
}
}
pub fn request_stream<T: RequestMessage>(
&self,
receiver_id: ConnectionId,
request: T,
) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
let (tx, rx) = mpsc::unbounded();
let send = self.connection_state(receiver_id).and_then(|connection| {
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
let stream_response_channels = connection.stream_response_channels.clone();
stream_response_channels
.lock()
.as_mut()
.ok_or_else(|| anyhow!("connection was closed"))?
.insert(message_id, tx);
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(
request.into_envelope(message_id, None, None),
))
.map_err(|_| anyhow!("connection was closed"))?;
Ok((message_id, stream_response_channels))
});
async move {
let (message_id, stream_response_channels) = send?;
let stream_response_channels = Arc::downgrade(&stream_response_channels);
Ok(rx.filter_map(move |(response, _barrier)| {
let stream_response_channels = stream_response_channels.clone();
future::ready(match response {
Ok(response) => {
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
Some(Err(anyhow!(
"RPC request {} failed - {}",
T::NAME,
error.message
)))
} else if let Some(proto::envelope::Payload::EndStream(_)) =
&response.payload
{
// Remove the transmitting end of the response channel to end the stream.
if let Some(channels) = stream_response_channels.upgrade() {
if let Some(channels) = channels.lock().as_mut() {
channels.remove(&message_id);
}
}
None
} else {
Some(
T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type")),
)
}
}
Err(error) => Some(Err(error)),
})
}))
}
}
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
@ -503,6 +619,24 @@ impl Peer {
Ok(())
}
pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
let connection = self.connection_state(receipt.sender_id)?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
let message = proto::EndStream {};
connection
.outgoing_tx
.unbounded_send(proto::Message::Envelope(message.into_envelope(
message_id,
Some(receipt.message_id),
None,
)))?;
Ok(())
}
pub fn respond_with_error<T: RequestMessage>(
&self,
receipt: Receipt<T>,

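Taken together, `request_stream`, the per-connection `stream_response_channels` map, and the new `EndStream` payload give the RPC layer multiplexed streaming responses: every envelope whose `responding_to` matches the original message id is forwarded to the caller until an `EndStream` (or `Error`) payload closes the stream and removes the channel. Below is a sketch of the client side, assuming a connected `Peer` and `ConnectionId` and reusing the request builder sketched earlier; the server is assumed to answer with repeated `LanguageModelResponse` envelopes before calling `end_stream`.

use futures::StreamExt;

// Hedged sketch: issue a streaming language-model request and concatenate the
// assistant deltas as they arrive. Module paths (`rpc`, `proto`) are assumed.
async fn stream_completion(
    peer: &rpc::Peer,
    connection_id: rpc::ConnectionId,
    request: proto::CompleteWithLanguageModel,
) -> anyhow::Result<String> {
    let mut responses = peer.request_stream(connection_id, request).await?;
    let mut text = String::new();
    while let Some(response) = responses.next().await {
        // Each chunk carries OpenAI-style choice deltas.
        for choice in response?.choices {
            if let Some(content) = choice.delta.and_then(|delta| delta.content) {
                text.push_str(&content);
            }
        }
    }
    Ok(text)
}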
View file

@ -149,7 +149,10 @@ messages!(
(CallCanceled, Foreground),
(CancelCall, Foreground),
(ChannelMessageSent, Foreground),
(CompleteWithLanguageModel, Background),
(CopyProjectEntry, Foreground),
(CountTokensWithLanguageModel, Background),
(CountTokensResponse, Background),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
@ -160,6 +163,7 @@ messages!(
(DeleteChannel, Foreground),
(DeleteNotification, Foreground),
(DeleteProjectEntry, Foreground),
(EndStream, Foreground),
(Error, Foreground),
(ExpandProjectEntry, Foreground),
(ExpandProjectEntryResponse, Foreground),
@ -211,6 +215,7 @@ messages!(
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
(LanguageModelResponse, Background),
(LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
@ -300,6 +305,8 @@ request_messages!(
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CompleteWithLanguageModel, LanguageModelResponse),
(CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),

View file

@ -22,7 +22,6 @@ gpui.workspace = true
language.workspace = true
menu.workspace = true
project.workspace = true
semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true

View file

@ -705,11 +705,6 @@ impl BufferSearchBar {
option.as_button(is_active, action)
}
pub fn activate_search_mode(&mut self, mode: SearchMode, cx: &mut ViewContext<Self>) {
assert_ne!(
mode,
SearchMode::Semantic,
"Semantic search is not supported in buffer search"
);
if mode == self.current_mode {
return;
}
@ -1022,7 +1017,7 @@ impl BufferSearchBar {
}
}
fn cycle_mode(&mut self, _: &CycleMode, cx: &mut ViewContext<Self>) {
self.activate_search_mode(next_mode(&self.current_mode, false), cx);
self.activate_search_mode(next_mode(&self.current_mode), cx);
}
fn toggle_replace(&mut self, _: &ToggleReplace, cx: &mut ViewContext<Self>) {
if let Some(_) = &self.active_searchable_item {

View file

@ -1,13 +1,12 @@
use gpui::{Action, SharedString};
use crate::{ActivateRegexMode, ActivateSemanticMode, ActivateTextMode};
use crate::{ActivateRegexMode, ActivateTextMode};
// TODO: Update the default search mode to get from config
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub enum SearchMode {
#[default]
Text,
Semantic,
Regex,
}
@ -15,7 +14,6 @@ impl SearchMode {
pub(crate) fn label(&self) -> &'static str {
match self {
SearchMode::Text => "Text",
SearchMode::Semantic => "Semantic",
SearchMode::Regex => "Regex",
}
}
@ -25,22 +23,14 @@ impl SearchMode {
pub(crate) fn action(&self) -> Box<dyn Action> {
match self {
SearchMode::Text => ActivateTextMode.boxed_clone(),
SearchMode::Semantic => ActivateSemanticMode.boxed_clone(),
SearchMode::Regex => ActivateRegexMode.boxed_clone(),
}
}
}
pub(crate) fn next_mode(mode: &SearchMode, semantic_enabled: bool) -> SearchMode {
pub(crate) fn next_mode(mode: &SearchMode) -> SearchMode {
match mode {
SearchMode::Text => SearchMode::Regex,
SearchMode::Regex => {
if semantic_enabled {
SearchMode::Semantic
} else {
SearchMode::Text
}
}
SearchMode::Semantic => SearchMode::Text,
SearchMode::Regex => SearchMode::Text,
}
}
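With the semantic mode removed, cycling is a plain two-state toggle. A test-style sketch of the resulting behavior, using only the items defined above:

#[test]
fn cycle_without_semantic_mode() {
    // Hedged sketch: the simplified cycle is Text -> Regex -> Text.
    assert_eq!(next_mode(&SearchMode::Text), SearchMode::Regex);
    assert_eq!(next_mode(&SearchMode::Regex), SearchMode::Text);
}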

View file

@ -1,33 +1,26 @@
use crate::{
history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateSemanticMode,
ActivateTextMode, CycleMode, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext,
SearchOptions, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored,
ToggleReplace, ToggleWholeWord,
history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateTextMode, CycleMode,
NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOptions,
SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored, ToggleReplace,
ToggleWholeWord,
};
use anyhow::{Context as _, Result};
use collections::HashMap;
use anyhow::Context as _;
use collections::{HashMap, HashSet};
use editor::{
actions::SelectAll,
items::active_match_index,
scroll::{Autoscroll, Axis},
Anchor, Editor, EditorEvent, MultiBuffer, MAX_TAB_TITLE_LEN,
Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer, MAX_TAB_TITLE_LEN,
};
use editor::{EditorElement, EditorStyle};
use gpui::{
actions, div, Action, AnyElement, AnyView, AppContext, Context as _, Element, EntityId,
EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, Hsla,
InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point,
PromptLevel, Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext,
VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point, Render,
SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext,
WeakModel, WeakView, WhiteSpace, WindowContext,
};
use menu::Confirm;
use project::{
search::{SearchInputs, SearchQuery},
Project,
};
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use collections::HashSet;
use project::{search::SearchQuery, Project};
use settings::Settings;
use smol::stream::StreamExt;
use std::{
@ -35,22 +28,20 @@ use std::{
mem,
ops::{Not, Range},
path::{Path, PathBuf},
time::{Duration, Instant},
};
use theme::ThemeSettings;
use workspace::{DeploySearch, NewSearch};
use ui::{
h_flex, prelude::*, v_flex, Icon, IconButton, IconName, Label, LabelCommon, LabelSize,
Selectable, ToggleButton, Tooltip,
};
use util::{paths::PathMatcher, ResultExt as _};
use util::paths::PathMatcher;
use workspace::{
item::{BreadcrumbText, Item, ItemEvent, ItemHandle},
searchable::{Direction, SearchableItem, SearchableItemHandle},
ItemNavHistory, Pane, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
WorkspaceId,
};
use workspace::{DeploySearch, NewSearch};
const MIN_INPUT_WIDTH_REMS: f32 = 15.;
const MAX_INPUT_WIDTH_REMS: f32 = 30.;
@ -86,12 +77,6 @@ pub fn init(cx: &mut AppContext) {
register_workspace_action(workspace, move |search_bar, _: &ActivateTextMode, cx| {
search_bar.activate_search_mode(SearchMode::Text, cx)
});
register_workspace_action(
workspace,
move |search_bar, _: &ActivateSemanticMode, cx| {
search_bar.activate_search_mode(SearchMode::Semantic, cx)
},
);
register_workspace_action(workspace, move |search_bar, action: &CycleMode, cx| {
search_bar.cycle_mode(action, cx)
});
@ -159,8 +144,6 @@ pub struct ProjectSearchView {
query_editor: View<Editor>,
replacement_editor: View<Editor>,
results_editor: View<Editor>,
semantic_state: Option<SemanticState>,
semantic_permissioned: Option<bool>,
search_options: SearchOptions,
panels_with_errors: HashSet<InputPanel>,
active_match_index: Option<usize>,
@ -174,12 +157,6 @@ pub struct ProjectSearchView {
_subscriptions: Vec<Subscription>,
}
struct SemanticState {
index_status: SemanticIndexStatus,
maintain_rate_limit: Option<Task<()>>,
_subscription: Subscription,
}
#[derive(Debug, Clone)]
struct ProjectSearchSettings {
search_options: SearchOptions,
@ -282,68 +259,6 @@ impl ProjectSearch {
}));
cx.notify();
}
fn semantic_search(&mut self, inputs: &SearchInputs, cx: &mut ModelContext<Self>) {
let search = SemanticIndex::global(cx).map(|index| {
index.update(cx, |semantic_index, cx| {
semantic_index.search_project(
self.project.clone(),
inputs.as_str().to_owned(),
10,
inputs.files_to_include().to_vec(),
inputs.files_to_exclude().to_vec(),
cx,
)
})
});
self.search_id += 1;
self.match_ranges.clear();
self.search_history.add(inputs.as_str().to_string());
self.no_results = None;
self.pending_search = Some(cx.spawn(|this, mut cx| async move {
let results = search?.await.log_err()?;
let matches = results
.into_iter()
.map(|result| (result.buffer, vec![result.range.start..result.range.start]));
this.update(&mut cx, |this, cx| {
this.no_results = Some(true);
this.excerpts.update(cx, |excerpts, cx| {
excerpts.clear(cx);
});
})
.ok()?;
for (buffer, ranges) in matches {
let mut match_ranges = this
.update(&mut cx, |this, cx| {
this.no_results = Some(false);
this.excerpts.update(cx, |excerpts, cx| {
excerpts.stream_excerpts_with_context_lines(buffer, ranges, 3, cx)
})
})
.ok()?;
while let Some(match_range) = match_ranges.next().await {
this.update(&mut cx, |this, cx| {
this.match_ranges.push(match_range);
while let Ok(Some(match_range)) = match_ranges.try_next() {
this.match_ranges.push(match_range);
}
cx.notify();
})
.ok()?;
}
}
this.update(&mut cx, |this, cx| {
this.pending_search.take();
cx.notify();
})
.ok()?;
None
}));
cx.notify();
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
@ -358,8 +273,6 @@ impl EventEmitter<ViewEvent> for ProjectSearchView {}
impl Render for ProjectSearchView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
const PLEASE_AUTHENTICATE: &str = "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables. If you authenticated using the Assistant Panel, please restart Zed to Authenticate.";
if self.has_matches() {
div()
.flex_1()
@ -370,7 +283,7 @@ impl Render for ProjectSearchView {
let model = self.model.read(cx);
let has_no_results = model.no_results.unwrap_or(false);
let is_search_underway = model.pending_search.is_some();
let mut major_text = if is_search_underway {
let major_text = if is_search_underway {
Label::new("Searching...")
} else if has_no_results {
Label::new("No results")
@ -378,43 +291,6 @@ impl Render for ProjectSearchView {
Label::new(format!("{} search all files", self.current_mode.label()))
};
let mut show_minor_text = true;
let semantic_status = self.semantic_state.as_ref().and_then(|semantic| {
let status = semantic.index_status;
match status {
SemanticIndexStatus::NotAuthenticated => {
major_text = Label::new("Not Authenticated");
show_minor_text = false;
Some(PLEASE_AUTHENTICATE.to_string())
}
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
SemanticIndexStatus::Indexing {
remaining_files,
rate_limit_expiry,
} => {
if remaining_files == 0 {
Some("Indexing...".to_string())
} else {
if let Some(rate_limit_expiry) = rate_limit_expiry {
let remaining_seconds =
rate_limit_expiry.duration_since(Instant::now());
if remaining_seconds > Duration::from_secs(0) {
Some(format!(
"Remaining files to index (rate limit resets in {}s): {}",
remaining_seconds.as_secs(),
remaining_files
))
} else {
Some(format!("Remaining files to index: {}", remaining_files))
}
} else {
Some(format!("Remaining files to index: {}", remaining_files))
}
}
}
SemanticIndexStatus::NotIndexed => None,
}
});
let major_text = div().justify_center().max_w_96().child(major_text);
let minor_text: Option<SharedString> = if let Some(no_results) = model.no_results {
@ -424,12 +300,7 @@ impl Render for ProjectSearchView {
None
}
} else {
if let Some(mut semantic_status) = semantic_status {
semantic_status.extend(self.landing_text_minor().chars());
Some(semantic_status.into())
} else {
Some(self.landing_text_minor())
}
Some(self.landing_text_minor())
};
let minor_text = minor_text.map(|text| {
div()
@ -676,58 +547,6 @@ impl ProjectSearchView {
});
}
fn index_project(&mut self, cx: &mut ViewContext<Self>) {
if let Some(semantic_index) = SemanticIndex::global(cx) {
// Semantic search uses no options
self.search_options = SearchOptions::none();
let project = self.model.read(cx).project.clone();
semantic_index.update(cx, |semantic_index, cx| {
semantic_index
.index_project(project.clone(), cx)
.detach_and_log_err(cx);
});
self.semantic_state = Some(SemanticState {
index_status: semantic_index.read(cx).status(&project),
maintain_rate_limit: None,
_subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
});
self.semantic_index_changed(semantic_index, cx);
}
}
fn semantic_index_changed(
&mut self,
semantic_index: Model<SemanticIndex>,
cx: &mut ViewContext<Self>,
) {
let project = self.model.read(cx).project.clone();
if let Some(semantic_state) = self.semantic_state.as_mut() {
cx.notify();
semantic_state.index_status = semantic_index.read(cx).status(&project);
if let SemanticIndexStatus::Indexing {
rate_limit_expiry: Some(_),
..
} = &semantic_state.index_status
{
if semantic_state.maintain_rate_limit.is_none() {
semantic_state.maintain_rate_limit =
Some(cx.spawn(|this, mut cx| async move {
loop {
cx.background_executor().timer(Duration::from_secs(1)).await;
this.update(&mut cx, |_, cx| cx.notify()).log_err();
}
}));
return;
}
} else {
semantic_state.maintain_rate_limit = None;
}
}
}
fn clear_search(&mut self, cx: &mut ViewContext<Self>) {
self.model.update(cx, |model, cx| {
model.pending_search = None;
@ -750,63 +569,7 @@ impl ProjectSearchView {
self.clear_search(cx);
self.current_mode = mode;
self.active_match_index = None;
match mode {
SearchMode::Semantic => {
let has_permission = self.semantic_permissioned(cx);
self.active_match_index = None;
cx.spawn(|this, mut cx| async move {
let has_permission = has_permission.await?;
if !has_permission {
let answer = this.update(&mut cx, |this, cx| {
let project = this.model.read(cx).project.clone();
let project_name = project
.read(cx)
.worktree_root_names(cx)
.collect::<Vec<&str>>()
.join("/");
let is_plural =
project_name.chars().filter(|letter| *letter == '/').count() > 0;
let prompt_text = format!("Would you like to index the '{}' project{} for semantic search? This requires sending code to the OpenAI API", project_name,
if is_plural {
"s"
} else {""});
cx.prompt(
PromptLevel::Info,
prompt_text.as_str(),
None,
&["Continue", "Cancel"],
)
})?;
if answer.await? == 0 {
this.update(&mut cx, |this, _| {
this.semantic_permissioned = Some(true);
})?;
} else {
this.update(&mut cx, |this, cx| {
this.semantic_permissioned = Some(false);
debug_assert_ne!(previous_mode, SearchMode::Semantic, "Tried to re-enable semantic search mode after user modal was rejected");
this.activate_search_mode(previous_mode, cx);
})?;
return anyhow::Ok(());
}
}
this.update(&mut cx, |this, cx| {
this.index_project(cx);
})?;
anyhow::Ok(())
}).detach_and_log_err(cx);
}
SearchMode::Regex | SearchMode::Text => {
self.semantic_state = None;
self.active_match_index = None;
self.search(cx);
}
}
self.search(cx);
cx.update_global(|state: &mut ActiveSettings, cx| {
state.0.insert(
@ -973,8 +736,6 @@ impl ProjectSearchView {
model,
query_editor,
results_editor,
semantic_state: None,
semantic_permissioned: None,
search_options: options,
panels_with_errors: HashSet::default(),
active_match_index: None,
@ -990,19 +751,6 @@ impl ProjectSearchView {
this
}
fn semantic_permissioned(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
if let Some(value) = self.semantic_permissioned {
return Task::ready(Ok(value));
}
SemanticIndex::global(cx)
.map(|semantic| {
let project = self.model.read(cx).project.clone();
semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
})
.unwrap_or(Task::ready(Ok(false)))
}
pub fn new_search_in_directory(
workspace: &mut Workspace,
dir_path: &Path,
@ -1126,22 +874,8 @@ impl ProjectSearchView {
}
fn search(&mut self, cx: &mut ViewContext<Self>) {
let mode = self.current_mode;
match mode {
SearchMode::Semantic => {
if self.semantic_state.is_some() {
if let Some(query) = self.build_search_query(cx) {
self.model
.update(cx, |model, cx| model.semantic_search(query.as_inner(), cx));
}
}
}
_ => {
if let Some(query) = self.build_search_query(cx) {
self.model.update(cx, |model, cx| model.search(query, cx));
}
}
if let Some(query) = self.build_search_query(cx) {
self.model.update(cx, |model, cx| model.search(query, cx));
}
}
@ -1356,7 +1090,6 @@ impl ProjectSearchView {
fn landing_text_minor(&self) -> SharedString {
match self.current_mode {
SearchMode::Text | SearchMode::Regex => "Include/exclude specific paths with the filter option. Matching exact word and/or casing is available too.".into(),
SearchMode::Semantic => "\nSimply explain the code you are looking to find. ex. 'prompt user for permissions to index their project'".into()
}
}
fn border_color_for(&self, panel: InputPanel, cx: &WindowContext) -> Hsla {
@ -1387,8 +1120,7 @@ impl ProjectSearchBar {
fn cycle_mode(&self, _: &CycleMode, cx: &mut ViewContext<Self>) {
if let Some(view) = self.active_project_search.as_ref() {
view.update(cx, |this, cx| {
let new_mode =
crate::mode::next_mode(&this.current_mode, SemanticIndex::enabled(cx));
let new_mode = crate::mode::next_mode(&this.current_mode);
this.activate_search_mode(new_mode, cx);
let editor_handle = this.query_editor.focus_handle(cx);
cx.focus(&editor_handle);
@ -1681,7 +1413,6 @@ impl Render for ProjectSearchBar {
});
}
let search = search.read(cx);
let semantic_is_available = SemanticIndex::enabled(cx);
let query_column = h_flex()
.flex_1()
@ -1711,12 +1442,8 @@ impl Render for ProjectSearchBar {
.unwrap_or_default(),
),
)
.when(search.current_mode != SearchMode::Semantic, |this| {
this.child(
IconButton::new(
"project-search-case-sensitive",
IconName::CaseSensitive,
)
.child(
IconButton::new("project-search-case-sensitive", IconName::CaseSensitive)
.tooltip(|cx| {
Tooltip::for_action(
"Toggle case sensitive",
@ -1728,18 +1455,17 @@ impl Render for ProjectSearchBar {
.on_click(cx.listener(|this, _, cx| {
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
})),
)
.child(
IconButton::new("project-search-whole-word", IconName::WholeWord)
.tooltip(|cx| {
Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
})
.selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
.on_click(cx.listener(|this, _, cx| {
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
})),
)
}),
)
.child(
IconButton::new("project-search-whole-word", IconName::WholeWord)
.tooltip(|cx| {
Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
})
.selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
.on_click(cx.listener(|this, _, cx| {
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
})),
),
);
let mode_column = v_flex().items_start().justify_start().child(
@ -1775,33 +1501,8 @@ impl Render for ProjectSearchBar {
cx,
)
})
.map(|this| {
if semantic_is_available {
this.middle()
} else {
this.last()
}
}),
)
.when(semantic_is_available, |this| {
this.child(
ToggleButton::new("project-search-semantic-button", "Semantic")
.style(ButtonStyle::Filled)
.size(ButtonSize::Large)
.selected(search.current_mode == SearchMode::Semantic)
.on_click(cx.listener(|this, _, cx| {
this.activate_search_mode(SearchMode::Semantic, cx)
}))
.tooltip(|cx| {
Tooltip::for_action(
"Toggle semantic search",
&ActivateSemanticMode,
cx,
)
})
.last(),
)
}),
.last(),
),
)
.child(
IconButton::new("project-search-toggle-replace", IconName::Replace)
@ -1929,21 +1630,16 @@ impl Render for ProjectSearchBar {
.border_color(search.border_color_for(InputPanel::Include, cx))
.rounded_lg()
.child(self.render_text_input(&search.included_files_editor, cx))
.when(search.current_mode != SearchMode::Semantic, |this| {
this.child(
SearchOptions::INCLUDE_IGNORED.as_button(
search
.search_options
.contains(SearchOptions::INCLUDE_IGNORED),
cx.listener(|this, _, cx| {
this.toggle_search_option(
SearchOptions::INCLUDE_IGNORED,
cx,
);
}),
),
)
}),
.child(
SearchOptions::INCLUDE_IGNORED.as_button(
search
.search_options
.contains(SearchOptions::INCLUDE_IGNORED),
cx.listener(|this, _, cx| {
this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
}),
),
),
)
.child(
h_flex()
@ -1972,9 +1668,6 @@ impl Render for ProjectSearchBar {
.on_action(cx.listener(|this, _: &ActivateRegexMode, cx| {
this.activate_search_mode(SearchMode::Regex, cx)
}))
.on_action(cx.listener(|this, _: &ActivateSemanticMode, cx| {
this.activate_search_mode(SearchMode::Semantic, cx)
}))
.capture_action(cx.listener(|this, action, cx| {
this.tab(action, cx);
cx.stop_propagation();
@ -1987,35 +1680,33 @@ impl Render for ProjectSearchBar {
.on_action(cx.listener(|this, action, cx| {
this.cycle_mode(action, cx);
}))
.when(search.current_mode != SearchMode::Semantic, |this| {
this.on_action(cx.listener(|this, action, cx| {
this.toggle_replace(action, cx);
.on_action(cx.listener(|this, action, cx| {
this.toggle_replace(action, cx);
}))
.on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
}))
.on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
}))
.on_action(cx.listener(|this, action, cx| {
if let Some(search) = this.active_project_search.as_ref() {
search.update(cx, |this, cx| {
this.replace_next(action, cx);
})
}
}))
.on_action(cx.listener(|this, action, cx| {
if let Some(search) = this.active_project_search.as_ref() {
search.update(cx, |this, cx| {
this.replace_all(action, cx);
})
}
}))
.when(search.filters_enabled, |this| {
this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
}))
.on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
}))
.on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
}))
.on_action(cx.listener(|this, action, cx| {
if let Some(search) = this.active_project_search.as_ref() {
search.update(cx, |this, cx| {
this.replace_next(action, cx);
})
}
}))
.on_action(cx.listener(|this, action, cx| {
if let Some(search) = this.active_project_search.as_ref() {
search.update(cx, |this, cx| {
this.replace_all(action, cx);
})
}
}))
.when(search.filters_enabled, |this| {
this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
}))
})
})
.on_action(cx.listener(Self::select_next_match))
.on_action(cx.listener(Self::select_prev_match))
@ -2039,12 +1730,6 @@ impl ToolbarItemView for ProjectSearchBar {
self.subscription = None;
self.active_project_search = None;
if let Some(search) = active_pane_item.and_then(|i| i.downcast::<ProjectSearchView>()) {
search.update(cx, |search, cx| {
if search.current_mode == SearchMode::Semantic {
search.index_project(cx);
}
});
self.subscription = Some(cx.observe(&search, |_, _, cx| cx.notify()));
self.active_project_search = Some(search);
ToolbarItemLocation::PrimaryLeft {}
@ -2123,9 +1808,8 @@ pub mod tests {
use editor::DisplayPoint;
use gpui::{Action, TestAppContext, WindowHandle};
use project::FakeFs;
use semantic_index::semantic_index_settings::SemanticIndexSettings;
use serde_json::json;
use settings::{Settings, SettingsStore};
use settings::SettingsStore;
use std::sync::Arc;
use workspace::DeploySearch;
@ -3446,8 +3130,6 @@ pub mod tests {
let settings = SettingsStore::test(cx);
cx.set_global(settings);
SemanticIndexSettings::register(cx);
theme::init(theme::LoadThemes::JustBase, cx);
language::init(cx);

View file

@ -33,7 +33,6 @@ actions!(
NextHistoryQuery,
PreviousHistoryQuery,
ActivateTextMode,
ActivateSemanticMode,
ActivateRegexMode,
ReplaceAll,
ReplaceNext,

View file

@ -1,66 +0,0 @@
[package]
name = "semantic_index"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/semantic_index.rs"
doctest = false
[dependencies]
ai.workspace = true
anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
lazy_static.workspace = true
log.workspace = true
ndarray = { version = "0.15.0" }
ordered-float.workspace = true
parking_lot.workspace = true
postage.workspace = true
project.workspace = true
rand.workspace = true
release_channel.workspace = true
rpc.workspace = true
rusqlite.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
sha1 = "0.10.5"
smol.workspace = true
tree-sitter.workspace = true
util.workspace = true
workspace.workspace = true
[dev-dependencies]
ai = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"]}
tempfile.workspace = true
tree-sitter-cpp.workspace = true
tree-sitter-elixir.workspace = true
tree-sitter-json.workspace = true
tree-sitter-lua.workspace = true
tree-sitter-php.workspace = true
tree-sitter-ruby.workspace = true
tree-sitter-rust.workspace = true
tree-sitter-toml.workspace = true
tree-sitter-typescript.workspace = true
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] }

View file

@ -1 +0,0 @@
../../LICENSE-GPL

View file

@ -1,20 +0,0 @@
# Semantic Index
## Evaluation
### Metrics
nDCG@k:
- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return.
- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
MRR@k:
- "Mean reciprocal rank quantifies the rank of the first relevant item found in the recommendation list."
MAP@k:
- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list.
Resources:
- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)
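For reference, the textbook formulas behind the quoted descriptions (standard definitions, not taken from this repository; one common normalization is shown for AP@k, with R the number of relevant items and rank_q the position of the first relevant result for query q):

DCG@k = \sum_{i=1}^{k} \frac{rel_i}{\log_2(i + 1)}, \qquad nDCG@k = \frac{DCG@k}{IDCG@k}
MRR@k = \frac{1}{|Q|} \sum_{q \in Q} \frac{1}{\mathrm{rank}_q}
AP@k = \frac{1}{\min(R, k)} \sum_{i=1}^{k} P(i) \cdot rel_i, \qquad MAP@k = \frac{1}{|Q|} \sum_{q \in Q} AP@k(q)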

View file

@ -1,114 +0,0 @@
{
"repo": "https://github.com/AntonOsika/gpt-engineer.git",
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
"assertions": [
{
"query": "How do I contribute to this project?",
"matches": [
".github/CONTRIBUTING.md:1",
"ROADMAP.md:48"
]
},
{
"query": "What version of the openai package is active?",
"matches": [
"pyproject.toml:14"
]
},
{
"query": "Ask user for clarification",
"matches": [
"gpt_engineer/steps.py:69"
]
},
{
"query": "generate tests for python code",
"matches": [
"gpt_engineer/steps.py:153"
]
},
{
"query": "get item from database based on key",
"matches": [
"gpt_engineer/db.py:42",
"gpt_engineer/db.py:68"
]
},
{
"query": "prompt user to select files",
"matches": [
"gpt_engineer/file_selector.py:171",
"gpt_engineer/file_selector.py:306",
"gpt_engineer/file_selector.py:289",
"gpt_engineer/file_selector.py:234"
]
},
{
"query": "send to rudderstack",
"matches": [
"gpt_engineer/collect.py:11",
"gpt_engineer/collect.py:38"
]
},
{
"query": "parse code blocks from chat messages",
"matches": [
"gpt_engineer/chat_to_files.py:10",
"docs/intro/chat_parsing.md:1"
]
},
{
"query": "how do I use the docker cli?",
"matches": [
"docker/README.md:1"
]
},
{
"query": "ask the user if the code ran successfully?",
"matches": [
"gpt_engineer/learning.py:54"
]
},
{
"query": "how is consent granted by the user?",
"matches": [
"gpt_engineer/learning.py:107",
"gpt_engineer/learning.py:130",
"gpt_engineer/learning.py:152"
]
},
{
"query": "what are all the different steps the agent can take?",
"matches": [
"docs/intro/steps_module.md:1",
"gpt_engineer/steps.py:391"
]
},
{
"query": "ask the user for clarification?",
"matches": [
"gpt_engineer/steps.py:69"
]
},
{
"query": "what models are available?",
"matches": [
"gpt_engineer/ai.py:315",
"gpt_engineer/ai.py:341",
"docs/open-models.md:1"
]
},
{
"query": "what is the current focus of the project?",
"matches": [
"ROADMAP.md:11"
]
},
{
"query": "does the agent know how to fix code?",
"matches": [
"gpt_engineer/steps.py:367"
]
}
]
}

View file

@ -1,104 +0,0 @@
{
"repo": "https://github.com/tree-sitter/tree-sitter.git",
"commit": "46af27796a76c72d8466627d499f2bca4af958ee",
"assertions": [
{
"query": "What attributes are available for the tags configuration struct?",
"matches": [
"tags/src/lib.rs:24"
]
},
{
"query": "create a new tag configuration",
"matches": [
"tags/src/lib.rs:119"
]
},
{
"query": "generate tags based on config",
"matches": [
"tags/src/lib.rs:261"
]
},
{
"query": "match on ts quantifier in rust",
"matches": [
"lib/binding_rust/lib.rs:139"
]
},
{
"query": "cli command to generate tags",
"matches": [
"cli/src/tags.rs:10"
]
},
{
"query": "what version of the tree-sitter-tags package is active?",
"matches": [
"tags/Cargo.toml:4"
]
},
{
"query": "Insert a new parse state",
"matches": [
"cli/src/generate/build_tables/build_parse_table.rs:153"
]
},
{
"query": "Handle conflict when numerous actions occur on the same symbol",
"matches": [
"cli/src/generate/build_tables/build_parse_table.rs:363",
"cli/src/generate/build_tables/build_parse_table.rs:442"
]
},
{
"query": "Match based on associativity of actions",
"matches": [
"cri/src/generate/build_tables/build_parse_table.rs:542"
]
},
{
"query": "Format token set display",
"matches": [
"cli/src/generate/build_tables/item.rs:246"
]
},
{
"query": "extract choices from rule",
"matches": [
"cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
]
},
{
"query": "How do we identify if a symbol is being used?",
"matches": [
"cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
]
},
{
"query": "How do we launch the playground?",
"matches": [
"cli/src/playground.rs:46"
]
},
{
"query": "How do we test treesitter query matches in rust?",
"matches": [
"cli/src/query_testing.rs:152",
"cli/src/tests/query_test.rs:781",
"cli/src/tests/query_test.rs:2163",
"cli/src/tests/query_test.rs:3781",
"cli/src/tests/query_test.rs:887"
]
},
{
"query": "What does the CLI do?",
"matches": [
"cli/README.md:10",
"cli/loader/README.md:3",
"docs/section-5-implementation.md:14",
"docs/section-5-implementation.md:18"
]
}
]
}

View file

@ -1,594 +0,0 @@
use crate::{
parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION,
};
use ai::embedding::Embedding;
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::BackgroundExecutor;
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::params;
use rusqlite::types::Value;
use std::{
future::Future,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::SystemTime,
};
use util::{paths::PathMatcher, TryFutureExt};
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
let mut indices = (0..data.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &data[i]);
indices.reverse();
indices
}
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
pub relative_path: String,
pub mtime: Timestamp,
}
#[derive(Clone)]
pub struct VectorDatabase {
path: Arc<Path>,
transactions:
smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
impl VectorDatabase {
pub async fn new(
fs: Arc<dyn Fs>,
path: Arc<Path>,
executor: BackgroundExecutor,
) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;
}
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
>();
executor
.spawn({
let path = path.clone();
async move {
let mut connection = rusqlite::Connection::open(&path)?;
connection.pragma_update(None, "journal_mode", "wal")?;
connection.pragma_update(None, "synchronous", "normal")?;
connection.pragma_update(None, "cache_size", 1000000)?;
connection.pragma_update(None, "temp_store", "MEMORY")?;
while let Ok(transaction) = transactions_rx.recv().await {
transaction(&mut connection);
}
anyhow::Ok(())
}
.log_err()
})
.detach();
let this = Self {
transactions: transactions_tx,
path,
};
this.initialize_database().await?;
Ok(this)
}
pub fn path(&self) -> &Arc<Path> {
&self.path
}
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
where
F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
T: 'static + Send,
{
let (tx, rx) = oneshot::channel();
let transactions = self.transactions.clone();
async move {
if transactions
.send(Box::new(|connection| {
let result = connection
.transaction()
.map_err(|err| anyhow!(err))
.and_then(|transaction| {
let result = f(&transaction)?;
transaction.commit()?;
Ok(result)
});
let _ = tx.send(result);
}))
.await
.is_err()
{
return Err(anyhow!("connection was dropped"))?;
}
rx.await?
}
}
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
self.transact(|db| {
rusqlite::vtab::array::load_module(&db)?;
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
let version_query = db.prepare("SELECT version from semantic_index_config");
let version = version_query
.and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0)));
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
log::trace!("vector database schema up to date");
return Ok(());
}
log::trace!("vector database schema out of date. updating...");
// We renamed the `documents` table to `spans`, so we want to drop
// `documents` without recreating it if it exists.
db.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
db.execute("DROP TABLE IF EXISTS spans", [])
.context("failed to drop 'spans' table")?;
db.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
db.execute("DROP TABLE IF EXISTS worktrees", [])
.context("failed to drop 'worktrees' table")?;
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
.context("failed to drop 'semantic_index_config' table")?;
// Initialize Vector Databasing Tables
db.execute(
"CREATE TABLE semantic_index_config (
version INTEGER NOT NULL
)",
[],
)?;
db.execute(
"INSERT INTO semantic_index_config (version) VALUES (?1)",
params![SEMANTIC_INDEX_VERSION],
)?;
db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
db.execute(
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
db.execute(
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
[],
)?;
db.execute(
"CREATE TABLE spans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
digest BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
db.execute(
"CREATE INDEX spans_digest ON spans (digest)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
})
}
pub fn delete_file(
&self,
worktree_id: i64,
delete_path: Arc<Path>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
})
}
pub fn insert_file(
&self,
worktree_id: i64,
path: Arc<Path>,
mtime: SystemTime,
spans: Vec<Span>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
// Return the existing ID, if both the file and mtime match
let mtime = Timestamp::from(mtime);
db.execute(
"
REPLACE INTO files
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
VALUES (?1, ?2, ?3, ?4)
",
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
)?;
let file_id = db.last_insert_rowid();
let mut query = db.prepare(
"
INSERT INTO spans
(file_id, start_byte, end_byte, name, embedding, digest)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)?;
for span in spans {
query.execute(params![
file_id,
span.range.start.to_string(),
span.range.end.to_string(),
span.name,
span.embedding,
span.digest
])?;
}
Ok(())
})
}
pub fn worktree_previously_indexed(
&self,
worktree_root_path: &Path,
) -> impl Future<Output = Result<bool>> {
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id =
worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0));
Ok(worktree_id.is_ok())
})
}
pub fn embeddings_for_digests(
&self,
digests: Vec<SpanDigest>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
WHERE digest IN rarray(?)
",
)?;
let mut embeddings_by_digest = HashMap::default();
let digests = Rc::new(
digests
.into_iter()
.map(|digest| Value::Blob(digest.0.to_vec()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![digests], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for (digest, embedding) in rows.flatten() {
embeddings_by_digest.insert(digest, embedding);
}
Ok(embeddings_by_digest)
})
}
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
LEFT JOIN files ON files.id = spans.file_id
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
",
)?;
let mut embeddings_by_digest = HashMap::default();
for (worktree_id, file_paths) in worktree_id_file_paths {
let file_paths = Rc::new(
file_paths
.into_iter()
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![worktree_id, file_paths], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for (digest, embedding) in rows.flatten() {
embeddings_by_digest.insert(digest, embedding);
}
}
Ok(embeddings_by_digest)
})
}
pub fn find_or_create_worktree(
&self,
worktree_root_path: Arc<Path>,
) -> impl Future<Output = Result<i64>> {
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
row.get::<_, i64>(0)
});
if worktree_id.is_ok() {
return Ok(worktree_id?);
}
// If worktree_id is Err, insert new worktree
db.execute(
"INSERT into worktrees (absolute_path) VALUES (?1)",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(db.last_insert_rowid())
})
}
pub fn get_file_mtimes(
&self,
worktree_id: i64,
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
})
}
pub fn top_k_search(
&self,
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let file_ids = file_ids.to_vec();
let query = query_embedding.clone().0;
let query = Array1::from_vec(query);
self.transact(move |db| {
let mut query_statement = db.prepare(
"
SELECT
id, embedding
FROM
spans
WHERE
file_id IN rarray(?)
",
)?;
let deserialized_rows = query_statement
.query_map(params![ids_to_sql(&file_ids)], |row| {
Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.collect::<Vec<(usize, Embedding)>>();
if deserialized_rows.len() == 0 {
return Ok(Vec::new());
}
// Get Length of Embeddings Returned
let embedding_len = deserialized_rows[0].1 .0.len();
let batch_n = 1000;
let mut batches = Vec::new();
let mut batch_ids = Vec::new();
let mut batch_embeddings: Vec<f32> = Vec::new();
deserialized_rows.iter().for_each(|(id, embedding)| {
batch_ids.push(id);
batch_embeddings.extend(&embedding.0);
if batch_ids.len() == batch_n {
let embeddings = std::mem::take(&mut batch_embeddings);
let ids = std::mem::take(&mut batch_ids);
let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
match array {
Ok(array) => {
batches.push((ids, array));
}
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
}
}
});
if batch_ids.len() > 0 {
let array = Array2::from_shape_vec(
(batch_ids.len(), embedding_len),
batch_embeddings.clone(),
);
match array {
Ok(array) => {
batches.push((batch_ids.clone(), array));
}
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
}
}
let mut ids: Vec<usize> = Vec::new();
let mut results = Vec::new();
for (batch_ids, array) in batches {
let scores = array
.dot(&query.t())
.to_vec()
.iter()
.map(|score| OrderedFloat(*score))
.collect::<Vec<OrderedFloat<f32>>>();
results.extend(scores);
ids.extend(batch_ids);
}
let sorted_idx = argsort(&results);
let mut sorted_results = Vec::new();
let last_idx = limit.min(sorted_idx.len());
for idx in &sorted_idx[0..last_idx] {
sorted_results.push((ids[*idx] as i64, results[*idx]))
}
Ok(sorted_results)
})
}
pub fn retrieve_included_file_ids(
&self,
worktree_ids: &[i64],
includes: &[PathMatcher],
excludes: &[PathMatcher],
) -> impl Future<Output = Result<Vec<i64>>> {
let worktree_ids = worktree_ids.to_vec();
let includes = includes.to_vec();
let excludes = excludes.to_vec();
self.transact(move |db| {
let mut file_query = db.prepare(
"
SELECT
id, relative_path
FROM
files
WHERE
worktree_id IN rarray(?)
",
)?;
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
let included =
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
if included && !excluded {
file_ids.push(file_id);
}
}
anyhow::Ok(file_ids)
})
}
pub fn spans_for_ids(
&self,
ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
let ids = ids.to_vec();
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT
spans.id,
files.worktree_id,
files.relative_path,
spans.start_byte,
spans.end_byte
FROM
spans, files
WHERE
spans.file_id = files.id AND
spans.id in rarray(?)
",
)?;
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?..row.get(4)?,
))
})?;
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
for row in result_iter {
let (id, worktree_id, path, range) = row?;
values_by_id.insert(id, (worktree_id, path, range));
}
let mut results = Vec::with_capacity(ids.len());
for id in &ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing span id {}", id))?;
results.push(value);
}
Ok(results)
})
}
}
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
Rc::new(
ids.iter()
.copied()
.map(|v| rusqlite::types::Value::from(v))
.collect::<Vec<_>>(),
)
}
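The removed `top_k_search` above ranks candidate spans by the dot product between the query embedding and each stored embedding, batching rows into ndarray matrices so the scoring is a single matrix multiply. A dependency-free sketch of the same ranking idea (illustrative only, not the deleted implementation):

// Hedged sketch: score stored embeddings against a query by dot product and
// return the ids of the top `limit` matches, mirroring the batch scoring and
// argsort in the deleted code above.
fn top_k_by_dot_product(query: &[f32], rows: &[(i64, Vec<f32>)], limit: usize) -> Vec<(i64, f32)> {
    let mut scored: Vec<(i64, f32)> = rows
        .iter()
        .map(|(id, embedding)| {
            let score: f32 = embedding.iter().zip(query).map(|(a, b)| a * b).sum();
            (*id, score)
        })
        .collect();
    // Highest score first, like argsort followed by reverse in the original.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(limit);
    scored
}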

View file

@ -1,169 +0,0 @@
use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use smol::channel;
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
#[derive(Clone)]
pub struct FileToEmbed {
pub worktree_id: i64,
pub path: Arc<Path>,
pub mtime: SystemTime,
pub spans: Vec<Span>,
pub job_handle: JobHandle,
}
impl std::fmt::Debug for FileToEmbed {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FileToEmbed")
.field("worktree_id", &self.worktree_id)
.field("path", &self.path)
.field("mtime", &self.mtime)
.field("spans", &self.spans)
.finish_non_exhaustive()
}
}
impl PartialEq for FileToEmbed {
fn eq(&self, other: &Self) -> bool {
self.worktree_id == other.worktree_id
&& self.path == other.path
&& self.mtime == other.mtime
&& self.spans == other.spans
}
}
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileFragmentToEmbed>,
executor: BackgroundExecutor,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
}
#[derive(Clone)]
pub struct FileFragmentToEmbed {
file: Arc<Mutex<FileToEmbed>>,
span_range: Range<usize>,
}
impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: BackgroundExecutor,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
executor,
pending_batch: Vec::new(),
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
}
}
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
span_range: 0..0,
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
for (ix, span) in file.lock().spans.iter().enumerate() {
let span_token_count = if span.embedding.is_none() {
span.token_count
} else {
0
};
let next_token_count = self.pending_batch_token_count + span_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
span_range: range_end..range_end,
});
fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += span_token_count;
}
}
pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0;
if batch.is_empty() {
return;
}
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
self.executor
.spawn(async move {
let mut spans = Vec::new();
for fragment in &batch {
let file = fragment.file.lock();
spans.extend(
file.spans[fragment.span_range.clone()]
.iter()
.filter(|d| d.embedding.is_none())
.map(|d| d.content.clone()),
);
}
// If there are no spans left to embed, just send each file to the finished channel once this is its last fragment.
if spans.is_empty() {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
.iter_mut()
.filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
span.embedding = Some(embedding);
} else {
log::error!("number of embeddings != number of documents");
}
}
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
}
Err(error) => {
log::error!("{:?}", error);
}
}
})
.detach();
}
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
self.finished_files_rx.clone()
}
}
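The removed `EmbeddingQueue` groups spans from many files into provider-sized batches: `push` accumulates fragments until the pending token count would exceed `max_tokens_per_batch`, at which point `flush` sends the batch off and completed files flow out of `finished_files`. A small standalone sketch of just the budgeting rule (illustrative; the real queue also tracks per-file fragments, skips spans that already have embeddings, and reports completion over a channel):

// Hedged sketch: split (id, token_count) items into batches whose total token
// count stays within `max_tokens_per_batch`, as the deleted push/flush logic did.
fn batch_by_token_budget(items: &[(usize, usize)], max_tokens_per_batch: usize) -> Vec<Vec<usize>> {
    let mut batches = Vec::new();
    let mut current = Vec::new();
    let mut budget_used = 0;
    for &(id, token_count) in items {
        if budget_used + token_count > max_tokens_per_batch && !current.is_empty() {
            batches.push(std::mem::take(&mut current));
            budget_used = 0;
        }
        current.push(id);
        budget_used += token_count;
    }
    if !current.is_empty() {
        batches.push(current);
    }
    batches
}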

View file

@ -1,414 +0,0 @@
use ai::{
embedding::{Embedding, EmbeddingProvider},
models::TruncationDirection,
};
use anyhow::{anyhow, Result};
use collections::HashSet;
use language::{Grammar, Language};
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
cmp::{self, Reverse},
ops::Range,
path::Path,
sync::Arc,
};
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let blob = value.as_blob()?;
let bytes =
blob.try_into()
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
expected_size: 20,
blob_size: blob.len(),
})?;
return Ok(SpanDigest(bytes));
}
}
impl ToSql for SpanDigest {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
self.0.to_sql()
}
}
impl From<&'_ str> for SpanDigest {
fn from(value: &'_ str) -> Self {
let mut sha1 = Sha1::new();
sha1.update(value);
Self(sha1.finalize().into())
}
}
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Option<Embedding>,
pub digest: SpanDigest,
pub token_count: usize,
}
const CODE_CONTEXT_TEMPLATE: &str =
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const ENTIRE_FILE_TEMPLATE: &str =
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
"TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
pub start_col: usize,
pub item_range: Option<Range<usize>>,
pub name_range: Option<Range<usize>>,
pub context_ranges: Vec<Range<usize>>,
pub collapse_ranges: Vec<Range<usize>>,
}
impl CodeContextRetriever {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
embedding_provider,
}
}
fn parse_entire_file(
&self,
relative_path: Option<&Path>,
language_name: Arc<str>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = ENTIRE_FILE_TEMPLATE
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: Default::default(),
name: language_name.to_string(),
digest,
token_count,
}])
}
fn parse_markdown_file(
&self,
relative_path: Option<&Path>,
content: &str,
) -> Result<Vec<Span>> {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: None,
name: "Markdown".to_string(),
digest,
token_count,
}])
}
fn get_matches_in_file(
&mut self,
content: &str,
grammar: &Arc<Grammar>,
) -> Result<Vec<CodeContextMatch>> {
let embedding_config = grammar
.embedding_config
.as_ref()
.ok_or_else(|| anyhow!("no embedding queries"))?;
self.parser.set_language(&grammar.ts_language).unwrap();
let tree = self
.parser
.parse(&content, None)
.ok_or_else(|| anyhow!("parsing failed"))?;
let mut captures: Vec<CodeContextMatch> = Vec::new();
let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
let mut keep_ranges: Vec<Range<usize>> = Vec::new();
for mat in self.cursor.matches(
&embedding_config.query,
tree.root_node(),
content.as_bytes(),
) {
let mut start_col = 0;
let mut item_range: Option<Range<usize>> = None;
let mut name_range: Option<Range<usize>> = None;
let mut context_ranges: Vec<Range<usize>> = Vec::new();
collapse_ranges.clear();
keep_ranges.clear();
for capture in mat.captures {
if capture.index == embedding_config.item_capture_ix {
item_range = Some(capture.node.byte_range());
start_col = capture.node.start_position().column;
} else if Some(capture.index) == embedding_config.name_capture_ix {
name_range = Some(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.context_capture_ix {
context_ranges.push(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.collapse_capture_ix {
collapse_ranges.push(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.keep_capture_ix {
keep_ranges.push(capture.node.byte_range());
}
}
captures.push(CodeContextMatch {
start_col,
item_range,
name_range,
context_ranges,
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
});
}
Ok(captures)
}
pub fn parse_file_with_template(
&mut self,
relative_path: Option<&Path>,
content: &str,
language: Arc<Language>,
) -> Result<Vec<Span>> {
let language_name = language.name();
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
return self.parse_entire_file(relative_path, language_name, &content);
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
return self.parse_markdown_file(relative_path, &content);
}
let mut spans = self.parse_file(content, language)?;
for span in &mut spans {
let document_content = CODE_CONTEXT_TEMPLATE
.replace(
"<path>",
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
)
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
let model = self.embedding_provider.base_model();
let document_content = model.truncate(
&document_content,
model.capacity()?,
TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_content)?;
span.content = document_content;
span.token_count = token_count;
}
Ok(spans)
}
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
let grammar = language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
// Iterate through query matches
let matches = self.get_matches_in_file(content, grammar)?;
let language_scope = language.default_scope();
let placeholder = language_scope.collapsed_placeholder();
let mut spans = Vec::new();
let mut collapsed_ranges_within = Vec::new();
let mut parsed_name_ranges = HashSet::default();
for (i, context_match) in matches.iter().enumerate() {
// Items which are collapsible but not embeddable have no item range
let item_range = if let Some(item_range) = context_match.item_range.clone() {
item_range
} else {
continue;
};
// Checks for deduplication
let name;
if let Some(name_range) = context_match.name_range.clone() {
name = content
.get(name_range.clone())
.map_or(String::new(), |s| s.to_string());
if parsed_name_ranges.contains(&name_range) {
continue;
}
parsed_name_ranges.insert(name_range);
} else {
name = String::new();
}
collapsed_ranges_within.clear();
'outer: for remaining_match in &matches[(i + 1)..] {
for collapsed_range in &remaining_match.collapse_ranges {
if item_range.start <= collapsed_range.start
&& item_range.end >= collapsed_range.end
{
collapsed_ranges_within.push(collapsed_range.clone());
} else {
break 'outer;
}
}
}
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
let mut span_content = String::new();
for context_range in &context_match.context_ranges {
add_content_from_range(
&mut span_content,
content,
context_range.clone(),
context_match.start_col,
);
span_content.push_str("\n");
}
let mut offset = item_range.start;
for collapsed_range in &collapsed_ranges_within {
if collapsed_range.start > offset {
add_content_from_range(
&mut span_content,
content,
offset..collapsed_range.start,
context_match.start_col,
);
offset = collapsed_range.start;
}
if collapsed_range.end > offset {
span_content.push_str(placeholder);
offset = collapsed_range.end;
}
}
if offset < item_range.end {
add_content_from_range(
&mut span_content,
content,
offset..item_range.end,
context_match.start_col,
);
}
let sha1 = SpanDigest::from(span_content.as_str());
spans.push(Span {
name,
content: span_content,
range: item_range.clone(),
embedding: None,
digest: sha1,
token_count: 0,
})
}
Ok(spans)
}
}
pub(crate) fn subtract_ranges(
ranges: &[Range<usize>],
ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
let mut result = Vec::new();
let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
for range in ranges {
let mut offset = range.start;
while offset < range.end {
if let Some(range_to_subtract) = ranges_to_subtract.peek() {
if offset < range_to_subtract.start {
let next_offset = cmp::min(range_to_subtract.start, range.end);
result.push(offset..next_offset);
offset = next_offset;
} else {
let next_offset = cmp::min(range_to_subtract.end, range.end);
offset = next_offset;
}
if offset >= range_to_subtract.end {
ranges_to_subtract.next();
}
} else {
result.push(offset..range.end);
offset = range.end;
}
}
}
result
}
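// Illustrative sketch, not from the original source: subtract_ranges operates
// on half-open byte ranges, so removing 2..4 from 0..6 leaves the pieces on
// either side of the subtracted range.
#[test]
fn subtract_ranges_splits_around_subtracted_range() {
    assert_eq!(subtract_ranges(&[0..6], &[2..4]), vec![0..2, 4..6]);
}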
fn add_content_from_range(
output: &mut String,
content: &str,
range: Range<usize>,
start_col: usize,
) {
for mut line in content.get(range.clone()).unwrap_or("").lines() {
for _ in 0..start_col {
if line.starts_with(' ') {
line = &line[1..];
} else {
break;
}
}
output.push_str(line);
output.push('\n');
}
output.pop();
}

File diff suppressed because it is too large Load diff

View file

@ -1,33 +0,0 @@
use anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
#[derive(Deserialize, Debug)]
pub struct SemanticIndexSettings {
pub enabled: bool,
}
/// Configuration of semantic index, an alternate search engine available in
/// project search.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct SemanticIndexSettingsContent {
/// Whether or not to display the Semantic mode in project search.
///
/// Default: true
pub enabled: Option<bool>,
}
impl Settings for SemanticIndexSettings {
const KEY: Option<&'static str> = Some("semantic_index");
type FileContent = SemanticIndexSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}

File diff suppressed because it is too large Load diff

View file

@ -479,7 +479,28 @@ impl SettingsStore {
merge_schema(target_schema, setting_schema.schema);
}
fn merge_schema(target: &mut SchemaObject, source: SchemaObject) {
fn merge_schema(target: &mut SchemaObject, mut source: SchemaObject) {
let source_subschemas = source.subschemas();
let target_subschemas = target.subschemas();
if let Some(all_of) = source_subschemas.all_of.take() {
target_subschemas
.all_of
.get_or_insert(Vec::new())
.extend(all_of);
}
if let Some(any_of) = source_subschemas.any_of.take() {
target_subschemas
.any_of
.get_or_insert(Vec::new())
.extend(any_of);
}
if let Some(one_of) = source_subschemas.one_of.take() {
target_subschemas
.one_of
.get_or_insert(Vec::new())
.extend(one_of);
}
if let Some(source) = source.object {
let target_properties = &mut target.object().properties;
for (key, value) in source.properties {

View file

@ -5,9 +5,8 @@ use futures_lite::FutureExt;
use isahc::config::{Configurable, RedirectPolicy};
pub use isahc::{
http::{Method, StatusCode, Uri},
Error,
AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
};
pub use isahc::{AsyncBody, Request, Response};
#[cfg(feature = "test-support")]
use std::fmt;
use std::{

View file

@ -71,7 +71,6 @@ recent_projects.workspace = true
release_channel.workspace = true
rope.workspace = true
search.workspace = true
semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true

View file

@ -174,7 +174,7 @@ fn main() {
node_runtime.clone(),
cx,
);
assistant::init(cx);
assistant::init(client.clone(), cx);
extension::init(
fs.clone(),
@ -247,7 +247,6 @@ fn main() {
tasks_ui::init(cx);
channel::init(&client, user_store.clone(), cx);
search::init(cx);
semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
vim::init(cx);
terminal_view::init(cx);

View file

@ -3060,7 +3060,7 @@ mod tests {
collab_ui::init(&app_state, cx);
project_panel::init((), cx);
terminal_view::init(cx);
assistant::init(cx);
assistant::init(app_state.client.clone(), cx);
initialize_workspace(app_state.clone(), cx);
app_state
})

View file

@ -606,28 +606,6 @@ These values take in the same options as the root-level settings with the same n
`boolean` values
## Semantic Index
- Description: Settings related to semantic index.
- Setting: `semantic_index`
- Default:
```json
"semantic_index": {
"enabled": false
},
```
### Enabled
- Description: Whether or not to display the `Semantic` mode in project search.
- Setting: `enabled`
- Default: `true`
**Options**
`boolean` values
## Show Call Status Icon
- Description: Whether or not to show the call status icon in the status bar.

View file

@ -11,3 +11,8 @@ cargo run -p collab -- migrate
echo "seeding database..."
script/seed-db
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
echo "Linux dependencies..."
script/linux
fi

View file

@ -1,3 +0,0 @@
#!/bin/bash
RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release

91
script/gemini.py Normal file
View file

@ -0,0 +1,91 @@
import subprocess
import json
import http.client
import mimetypes
import os
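# Usage (assumed): run from the repository root with a Gemini API key exported,
# e.g. `GEMINI_API_KEY=... python3 script/gemini.py`.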
def get_text_files():
text_files = []
# List all files tracked by Git
git_files_proc = subprocess.run(['git', 'ls-files'], stdout=subprocess.PIPE, text=True)
for file in git_files_proc.stdout.strip().split('\n'):
# Check MIME type for each file
mime_check_proc = subprocess.run(['file', '--mime', file], stdout=subprocess.PIPE, text=True)
if 'text' in mime_check_proc.stdout:
text_files.append(file)
print(f"File count: {len(text_files)}")
return text_files
def get_file_contents(file):
# Read file content
with open(file, 'r') as f:
return f.read()
def main():
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
# Your prompt
prompt = "Document the data types and dataflow in this codebase in preparation to port a streaming implementation to rust:\n\n"
# Fetch all text files
text_files = get_text_files()
code_blocks = []
for file in text_files:
file_contents = get_file_contents(file)
# Create a code block for each text file
code_blocks.append(f"\n`{file}`\n\n```{file_contents}```\n")
# Construct the JSON payload
payload = json.dumps({
"contents": [{
"parts": [{
"text": prompt + "".join(code_blocks)
}]
}]
})
# Prepare the HTTP connection
conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
# Define headers
headers = {
'Content-Type': 'application/json',
'Content-Length': str(len(payload))
}
# Output the content length in kilobytes
print(f"Content Length in kilobytes: {len(payload.encode('utf-8')) / 1024:.2f} KB")
# Send a request to count the tokens
conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:countTokens?key={GEMINI_API_KEY}", body=payload, headers=headers)
# Get the response
response = conn.getresponse()
if response.status == 200:
token_count = json.loads(response.read().decode('utf-8')).get('totalTokens')
print(f"Token count: {token_count}")
else:
print(f"Failed to get token count. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
# Prepare the HTTP connection
conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
conn.request("GET", f"/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={GEMINI_API_KEY}", body=payload, headers=headers)
# Get the response in a streaming manner
response = conn.getresponse()
if response.status == 200:
print("Successfully sent the data to the API.")
# Read the response in chunks
while chunk := response.read(4096):
print(chunk.decode('utf-8'))
else:
print(f"Failed to send the data to the API. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
# Close the connection
conn.close()
if __name__ == "__main__":
main()

View file

@ -1,4 +1,6 @@
#!/usr/bin/bash -e
#!/usr/bin/bash
set -e
# if sudo is not installed, define an empty alias
maysudo=$(command -v sudo || command -v doas || true)

1
script/script.py Normal file
View file

@ -0,0 +1 @@

View file

@ -3,12 +3,15 @@
set -e
# Install sqlx-cli if needed
[[ "$(sqlx --version)" == "sqlx-cli 0.5.7" ]] || cargo install sqlx-cli --version 0.5.7
if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
cargo install sqlx-cli --version 0.5.7
fi
cd crates/collab
# Export contents of .env.toml
eval "$(cargo run --quiet --bin dotenv)"
eval "$(cargo run --bin dotenv)"
# Run sqlx command
sqlx $@