Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)
Co-authored-by: Antonio <antonio@zed.dev> Resurrected this from some assistant work I did in Spring of 2023. - [x] Resurrect streaming responses - [x] Use streaming responses to enable AI via Zed's servers by default (but preserve API key option for now) - [x] Simplify protobuf - [x] Proxy to OpenAI on zed.dev - [x] Proxy to Gemini on zed.dev - [x] Improve UX for switching between OpenAI and Google models - We currently disallow cycling when setting a custom model, but we need a better solution to keep OpenAI models available while testing the Google ones - [x] Show remaining tokens correctly for Google models - [x] Remove semantic index - [x] Delete `ai` crate - [x] Cloud front so we can ban abuse - [x] Rate-limiting - [x] Fix panic when using inline assistant - [x] Double-check that the upgraded `AssistantSettings` are backwards-compatible - [x] Add hosted LLM interaction behind a `language-models` feature flag. Release Notes: - We are temporarily removing the semantic index in order to redesign it from scratch. --------- Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Thorsten <thorsten@zed.dev> Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
905a24079a
commit
8ae5a3b61a
87 changed files with 3647 additions and 8937 deletions
234
Cargo.lock
generated
234
Cargo.lock
generated
|
@ -85,32 +85,6 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
"language",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"ordered-float 2.10.0",
|
||||
"parking_lot",
|
||||
"parse_duration",
|
||||
"postage",
|
||||
"rand 0.8.5",
|
||||
"rusqlite",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tiktoken-rs",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "alacritty_terminal"
|
||||
version = "0.22.1-dev"
|
||||
|
@ -339,9 +313,9 @@ dependencies = [
|
|||
name = "assistant"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ai",
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"editor",
|
||||
|
@ -354,13 +328,14 @@ dependencies = [
|
|||
"log",
|
||||
"menu",
|
||||
"multi_buffer",
|
||||
"open_ai",
|
||||
"ordered-float 2.10.0",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"schemars",
|
||||
"search",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
|
@ -1339,7 +1314,7 @@ version = "0.3.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
|
||||
dependencies = [
|
||||
"num-bigint 0.4.4",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
@ -2209,11 +2184,11 @@ dependencies = [
|
|||
"fs",
|
||||
"futures 0.3.28",
|
||||
"git",
|
||||
"google_ai",
|
||||
"gpui",
|
||||
"hex",
|
||||
"indoc",
|
||||
"language",
|
||||
"lazy_static",
|
||||
"live_kit_client",
|
||||
"live_kit_server",
|
||||
"log",
|
||||
|
@ -2222,6 +2197,7 @@ dependencies = [
|
|||
"nanoid",
|
||||
"node_runtime",
|
||||
"notifications",
|
||||
"open_ai",
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
|
@ -3554,24 +3530,12 @@ dependencies = [
|
|||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
|
||||
|
||||
[[package]]
|
||||
name = "fallible-streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.11.0"
|
||||
|
@ -4183,7 +4147,7 @@ version = "0.28.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
|
||||
dependencies = [
|
||||
"fallible-iterator 0.3.0",
|
||||
"fallible-iterator",
|
||||
"indexmap 2.0.0",
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
@ -4279,6 +4243,17 @@ dependencies = [
|
|||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google_ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gpu-alloc"
|
||||
version = "0.6.0"
|
||||
|
@ -5667,16 +5642,6 @@ version = "0.7.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "maybe-owned"
|
||||
version = "0.3.4"
|
||||
|
@ -5946,19 +5911,6 @@ dependencies = [
|
|||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.15.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||
dependencies = [
|
||||
"matrixmultiply",
|
||||
"num-complex 0.4.4",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk"
|
||||
version = "0.7.0"
|
||||
|
@ -6111,45 +6063,20 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
|
||||
dependencies = [
|
||||
"num-bigint 0.2.6",
|
||||
"num-complex 0.2.4",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational 0.2.4",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
|
||||
dependencies = [
|
||||
"num-bigint 0.4.4",
|
||||
"num-complex 0.4.4",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational 0.4.1",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.4"
|
||||
|
@ -6196,16 +6123,6 @@ dependencies = [
|
|||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.4"
|
||||
|
@ -6247,18 +6164,6 @@ dependencies = [
|
|||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-bigint 0.2.6",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.3.2"
|
||||
|
@ -6277,7 +6182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-bigint 0.4.4",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
@ -6436,7 +6341,7 @@ dependencies = [
|
|||
"futures-util",
|
||||
"hkdf",
|
||||
"hmac 0.12.1",
|
||||
"num 0.4.1",
|
||||
"num",
|
||||
"num-bigint-dig 0.8.4",
|
||||
"pbkdf2 0.12.2",
|
||||
"rand 0.8.5",
|
||||
|
@ -6464,6 +6369,18 @@ dependencies = [
|
|||
"pathdiff",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "open_ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.28",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.57"
|
||||
|
@ -6679,17 +6596,6 @@ dependencies = [
|
|||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parse_duration"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"num 0.2.1",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "password-hash"
|
||||
version = "0.2.1"
|
||||
|
@ -7471,12 +7377,6 @@ dependencies = [
|
|||
"raw-window-handle 0.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.8.0"
|
||||
|
@ -7935,20 +7835,6 @@ dependencies = [
|
|||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusqlite"
|
||||
version = "0.29.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
|
||||
dependencies = [
|
||||
"bitflags 2.4.2",
|
||||
"fallible-iterator 0.2.0",
|
||||
"fallible-streaming-iterator",
|
||||
"hashlink",
|
||||
"libsqlite3-sys",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-embed"
|
||||
version = "8.2.0"
|
||||
|
@ -8378,7 +8264,6 @@ dependencies = [
|
|||
"language",
|
||||
"menu",
|
||||
"project",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
|
@ -8434,52 +8319,6 @@ version = "1.0.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
|
||||
|
||||
[[package]]
|
||||
name = "semantic_index"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ai",
|
||||
"anyhow",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"language",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"ndarray",
|
||||
"ordered-float 2.10.0",
|
||||
"parking_lot",
|
||||
"postage",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"release_channel",
|
||||
"rpc",
|
||||
"rusqlite",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"sha1",
|
||||
"smol",
|
||||
"tempfile",
|
||||
"tree-sitter",
|
||||
"tree-sitter-cpp",
|
||||
"tree-sitter-elixir",
|
||||
"tree-sitter-json 0.20.0",
|
||||
"tree-sitter-lua",
|
||||
"tree-sitter-php",
|
||||
"tree-sitter-ruby",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-toml",
|
||||
"tree-sitter-typescript",
|
||||
"unindent",
|
||||
"util",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.18"
|
||||
|
@ -8766,7 +8605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"num-bigint 0.4.4",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"thiserror",
|
||||
]
|
||||
|
@ -9197,7 +9036,7 @@ dependencies = [
|
|||
"log",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"num-bigint 0.4.4",
|
||||
"num-bigint",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"rust_decimal",
|
||||
|
@ -12729,7 +12568,6 @@ dependencies = [
|
|||
"release_channel",
|
||||
"rope",
|
||||
"search",
|
||||
"semantic_index",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/ai",
|
||||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/audio",
|
||||
|
@ -34,6 +33,7 @@ members = [
|
|||
"crates/fuzzy",
|
||||
"crates/git",
|
||||
"crates/go_to_line",
|
||||
"crates/google_ai",
|
||||
"crates/gpui",
|
||||
"crates/gpui_macros",
|
||||
"crates/image_viewer",
|
||||
|
@ -52,6 +52,7 @@ members = [
|
|||
"crates/multi_buffer",
|
||||
"crates/node_runtime",
|
||||
"crates/notifications",
|
||||
"crates/open_ai",
|
||||
"crates/outline",
|
||||
"crates/picker",
|
||||
"crates/prettier",
|
||||
|
@ -69,7 +70,6 @@ members = [
|
|||
"crates/task",
|
||||
"crates/tasks_ui",
|
||||
"crates/search",
|
||||
"crates/semantic_index",
|
||||
"crates/settings",
|
||||
"crates/snippet",
|
||||
"crates/sqlez",
|
||||
|
@ -138,6 +138,7 @@ fsevent = { path = "crates/fsevent" }
|
|||
fuzzy = { path = "crates/fuzzy" }
|
||||
git = { path = "crates/git" }
|
||||
go_to_line = { path = "crates/go_to_line" }
|
||||
google_ai = { path = "crates/google_ai" }
|
||||
gpui = { path = "crates/gpui" }
|
||||
gpui_macros = { path = "crates/gpui_macros" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
|
@ -156,6 +157,7 @@ menu = { path = "crates/menu" }
|
|||
multi_buffer = { path = "crates/multi_buffer" }
|
||||
node_runtime = { path = "crates/node_runtime" }
|
||||
notifications = { path = "crates/notifications" }
|
||||
open_ai = { path = "crates/open_ai" }
|
||||
outline = { path = "crates/outline" }
|
||||
picker = { path = "crates/picker" }
|
||||
plugin = { path = "crates/plugin" }
|
||||
|
@ -174,7 +176,6 @@ rpc = { path = "crates/rpc" }
|
|||
task = { path = "crates/task" }
|
||||
tasks_ui = { path = "crates/tasks_ui" }
|
||||
search = { path = "crates/search" }
|
||||
semantic_index = { path = "crates/semantic_index" }
|
||||
settings = { path = "crates/settings" }
|
||||
snippet = { path = "crates/snippet" }
|
||||
sqlez = { path = "crates/sqlez" }
|
||||
|
|
|
@ -251,7 +251,6 @@
|
|||
"alt-tab": "search::CycleMode",
|
||||
"cmd-shift-h": "search::ToggleReplace",
|
||||
"alt-cmd-g": "search::ActivateRegexMode",
|
||||
"alt-cmd-s": "search::ActivateSemanticMode",
|
||||
"alt-cmd-x": "search::ActivateTextMode"
|
||||
}
|
||||
},
|
||||
|
@ -276,7 +275,6 @@
|
|||
"alt-tab": "search::CycleMode",
|
||||
"cmd-shift-h": "search::ToggleReplace",
|
||||
"alt-cmd-g": "search::ActivateRegexMode",
|
||||
"alt-cmd-s": "search::ActivateSemanticMode",
|
||||
"alt-cmd-x": "search::ActivateTextMode"
|
||||
}
|
||||
},
|
||||
|
@ -302,7 +300,6 @@
|
|||
"alt-tab": "search::CycleMode",
|
||||
"alt-cmd-f": "project_search::ToggleFilters",
|
||||
"alt-cmd-g": "search::ActivateRegexMode",
|
||||
"alt-cmd-s": "search::ActivateSemanticMode",
|
||||
"alt-cmd-x": "search::ActivateTextMode"
|
||||
}
|
||||
},
|
||||
|
|
|
@ -237,6 +237,8 @@
|
|||
"default_width": 380
|
||||
},
|
||||
"assistant": {
|
||||
// Version of this setting.
|
||||
"version": "1",
|
||||
// Whether to show the assistant panel button in the status bar.
|
||||
"button": true,
|
||||
// Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
|
||||
|
@ -245,28 +247,16 @@
|
|||
"default_width": 640,
|
||||
// Default height when the assistant is docked to the bottom.
|
||||
"default_height": 320,
|
||||
// Deprecated: Please use `provider.api_url` instead.
|
||||
// The default OpenAI API endpoint to use when starting new conversations.
|
||||
"openai_api_url": "https://api.openai.com/v1",
|
||||
// Deprecated: Please use `provider.default_model` instead.
|
||||
// The default OpenAI model to use when starting new conversations. This
|
||||
// setting can take three values:
|
||||
//
|
||||
// 1. "gpt-3.5-turbo-0613""
|
||||
// 2. "gpt-4-0613""
|
||||
// 3. "gpt-4-1106-preview"
|
||||
"default_open_ai_model": "gpt-4-1106-preview",
|
||||
// AI provider.
|
||||
"provider": {
|
||||
"type": "openai",
|
||||
// The default OpenAI API endpoint to use when starting new conversations.
|
||||
"api_url": "https://api.openai.com/v1",
|
||||
// The default OpenAI model to use when starting new conversations. This
|
||||
"name": "openai",
|
||||
// The default model to use when starting new conversations. This
|
||||
// setting can take three values:
|
||||
//
|
||||
// 1. "gpt-3.5-turbo-0613""
|
||||
// 2. "gpt-4-0613""
|
||||
// 3. "gpt-4-1106-preview"
|
||||
"default_model": "gpt-4-1106-preview"
|
||||
// 1. "gpt-3.5-turbo"
|
||||
// 2. "gpt-4"
|
||||
// 3. "gpt-4-turbo-preview"
|
||||
"default_model": "gpt-4-turbo-preview"
|
||||
}
|
||||
},
|
||||
// Whether the screen sharing icon is shown in the os status bar.
|
||||
|
@ -505,10 +495,6 @@
|
|||
// Existing terminals will not pick up this change until they are recreated.
|
||||
// "max_scroll_history_lines": 10000,
|
||||
},
|
||||
// Difference settings for semantic_index
|
||||
"semantic_index": {
|
||||
"enabled": true
|
||||
},
|
||||
// Settings specific to our elixir integration
|
||||
"elixir": {
|
||||
// Change the LSP zed uses for elixir.
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
[package]
|
||||
name = "ai"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/ai.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
bincode = "1.3.3"
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
isahc.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
matrixmultiply = "0.3.7"
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
parse_duration = "2.1.1"
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-GPL
|
|
@ -1,8 +0,0 @@
|
|||
pub mod auth;
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod models;
|
||||
pub mod prompts;
|
||||
pub mod providers;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
|
@ -1,23 +0,0 @@
|
|||
use futures::future::BoxFuture;
|
||||
use gpui::AppContext;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProviderCredential {
|
||||
Credentials { api_key: String },
|
||||
NoCredentials,
|
||||
NotNeeded,
|
||||
}
|
||||
|
||||
pub trait CredentialProvider: Send + Sync {
|
||||
fn has_credentials(&self) -> bool;
|
||||
#[must_use]
|
||||
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
|
||||
#[must_use]
|
||||
fn save_credentials(
|
||||
&self,
|
||||
cx: &mut AppContext,
|
||||
credential: ProviderCredential,
|
||||
) -> BoxFuture<()>;
|
||||
#[must_use]
|
||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
|
||||
use crate::{auth::CredentialProvider, models::LanguageModel};
|
||||
|
||||
pub trait CompletionRequest: Send + Sync {
|
||||
fn data(&self) -> serde_json::Result<String>;
|
||||
}
|
||||
|
||||
pub trait CompletionProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider>;
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn CompletionProvider> {
|
||||
fn clone(&self) -> Box<dyn CompletionProvider> {
|
||||
self.box_clone()
|
||||
}
|
||||
}
|
|
@ -1,121 +0,0 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use ordered_float::OrderedFloat;
|
||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||
use rusqlite::ToSql;
|
||||
|
||||
use crate::auth::CredentialProvider;
|
||||
use crate::models::LanguageModel;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Embedding(pub Vec<f32>);
|
||||
|
||||
// This is needed for semantic index functionality
|
||||
// Unfortunately it has to live wherever the "Embedding" struct is created.
|
||||
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
|
||||
// which is less than ideal
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding =
|
||||
bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
|
||||
Ok(Embedding(embedding))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for Embedding {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
let bytes = bincode::serialize(&self.0)
|
||||
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
}
|
||||
}
|
||||
impl From<Vec<f32>> for Embedding {
|
||||
fn from(value: Vec<f32>) -> Self {
|
||||
Embedding(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||
let len = self.0.len();
|
||||
assert_eq!(len, other.0.len());
|
||||
|
||||
let mut result = 0.0;
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
1,
|
||||
len,
|
||||
1,
|
||||
1.0,
|
||||
self.0.as_ptr(),
|
||||
len as isize,
|
||||
1,
|
||||
other.0.as_ptr(),
|
||||
1,
|
||||
len as isize,
|
||||
0.0,
|
||||
&mut result as *mut f32,
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
OrderedFloat(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait EmbeddingProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::prelude::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_similarity(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
Embedding::from(vec![1., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
|
||||
0.
|
||||
);
|
||||
assert_eq!(
|
||||
Embedding::from(vec![2., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
|
||||
6.
|
||||
);
|
||||
|
||||
for _ in 0..100 {
|
||||
let size = 1536;
|
||||
let mut a = vec![0.; size];
|
||||
let mut b = vec![0.; size];
|
||||
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||
*a = rng.gen();
|
||||
*b = rng.gen();
|
||||
}
|
||||
let a = Embedding::from(a);
|
||||
let b = Embedding::from(b);
|
||||
|
||||
assert_eq!(
|
||||
round_to_decimals(a.similarity(&b), 1),
|
||||
round_to_decimals(reference_dot(&a.0, &b.0), 1)
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||
let factor = 10.0_f32.powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
pub enum TruncationDirection {
|
||||
Start,
|
||||
End,
|
||||
}
|
||||
|
||||
pub trait LanguageModel {
|
||||
fn name(&self) -> String;
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String>;
|
||||
fn capacity(&self) -> anyhow::Result<usize>;
|
||||
}
|
|
@ -1,337 +0,0 @@
|
|||
use std::cmp::Reverse;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use language::BufferSnapshot;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::prompts::repository_context::PromptCodeSnippet;
|
||||
|
||||
pub(crate) enum PromptFileType {
|
||||
Text,
|
||||
Code,
|
||||
}
|
||||
|
||||
// TODO: Set this up to manage for defaults well
|
||||
pub struct PromptArguments {
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
pub user_prompt: Option<String>,
|
||||
pub language_name: Option<String>,
|
||||
pub project_name: Option<String>,
|
||||
pub snippets: Vec<PromptCodeSnippet>,
|
||||
pub reserved_tokens: usize,
|
||||
pub buffer: Option<BufferSnapshot>,
|
||||
pub selected_range: Option<Range<usize>>,
|
||||
}
|
||||
|
||||
impl PromptArguments {
|
||||
pub(crate) fn get_file_type(&self) -> PromptFileType {
|
||||
if self
|
||||
.language_name
|
||||
.as_ref()
|
||||
.map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
PromptFileType::Code
|
||||
} else {
|
||||
PromptFileType::Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PromptTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)>;
|
||||
}
|
||||
|
||||
#[repr(i8)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum PromptPriority {
|
||||
/// Ignores truncation.
|
||||
Mandatory,
|
||||
/// Truncates based on priority.
|
||||
Ordered { order: usize },
|
||||
}
|
||||
|
||||
impl PartialOrd for PromptPriority {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for PromptPriority {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
match (self, other) {
|
||||
(Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
|
||||
(Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
|
||||
(Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
|
||||
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PromptChain {
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
}
|
||||
|
||||
impl PromptChain {
|
||||
pub fn new(
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
) -> Self {
|
||||
PromptChain { args, templates }
|
||||
}
|
||||
|
||||
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
|
||||
// Argsort based on Prompt Priority
|
||||
let separator = "\n";
|
||||
let separator_tokens = self.args.model.count_tokens(separator)?;
|
||||
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
|
||||
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
|
||||
|
||||
let mut tokens_outstanding = if truncate {
|
||||
Some(self.args.model.capacity()? - self.args.reserved_tokens)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut prompts = vec!["".to_string(); sorted_indices.len()];
|
||||
for idx in sorted_indices {
|
||||
let (_, template) = &self.templates[idx];
|
||||
|
||||
if let Some((template_prompt, prompt_token_count)) =
|
||||
template.generate(&self.args, tokens_outstanding).log_err()
|
||||
{
|
||||
if template_prompt != "" {
|
||||
prompts[idx] = template_prompt;
|
||||
|
||||
if let Some(remaining_tokens) = tokens_outstanding {
|
||||
let new_tokens = prompt_token_count + separator_tokens;
|
||||
tokens_outstanding = if remaining_tokens > new_tokens {
|
||||
Some(remaining_tokens - new_tokens)
|
||||
} else {
|
||||
Some(0)
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompts.retain(|x| x != "");
|
||||
|
||||
let full_prompt = prompts.join(separator);
|
||||
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
|
||||
anyhow::Ok((prompts.join(separator), total_token_count))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::test::FakeLanguageModel;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
pub fn test_prompt_chain() {
|
||||
struct TestPromptTemplate {}
|
||||
impl PromptTemplate for TestPromptTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut content = "This is a test prompt template".to_string();
|
||||
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((content, token_count))
|
||||
}
|
||||
}
|
||||
|
||||
struct TestLowPriorityTemplate {}
|
||||
impl PromptTemplate for TestLowPriorityTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut content = "This is a low priority test prompt template".to_string();
|
||||
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((content, token_count))
|
||||
}
|
||||
}
|
||||
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||
.to_string()
|
||||
);
|
||||
|
||||
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||
|
||||
// Testing with Truncation Off
|
||||
// Should ignore capacity and return all prompts
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||
.to_string()
|
||||
);
|
||||
|
||||
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||
|
||||
// Testing with Truncation On
|
||||
// Should respect capacity and truncate the prompt to fit
|
||||
let capacity = 20;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 2 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||
|
||||
assert_eq!(prompt, "This is a test promp".to_string());
|
||||
assert_eq!(token_count, capacity);
|
||||
|
||||
// Change Ordering of Prompts Based on Priority
|
||||
let capacity = 120;
|
||||
let reserved_tokens = 10;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Mandatory,
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
|
||||
.to_string()
|
||||
);
|
||||
assert_eq!(token_count, capacity - reserved_tokens);
|
||||
}
|
||||
}
|
|
@ -1,164 +0,0 @@
|
|||
use anyhow::anyhow;
|
||||
use language::BufferSnapshot;
|
||||
use language::ToOffset;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::prompts::base::PromptArguments;
|
||||
use crate::prompts::base::PromptTemplate;
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn retrieve_context(
|
||||
buffer: &BufferSnapshot,
|
||||
selected_range: &Option<Range<usize>>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
max_token_count: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize, bool)> {
|
||||
let mut prompt = String::new();
|
||||
let mut truncated = false;
|
||||
if let Some(selected_range) = selected_range {
|
||||
let start = selected_range.start.to_offset(buffer);
|
||||
let end = selected_range.end.to_offset(buffer);
|
||||
|
||||
let start_window = buffer.text_for_range(0..start).collect::<String>();
|
||||
|
||||
let mut selected_window = String::new();
|
||||
if start == end {
|
||||
write!(selected_window, "<|START|>").unwrap();
|
||||
} else {
|
||||
write!(selected_window, "<|START|").unwrap();
|
||||
}
|
||||
|
||||
write!(
|
||||
selected_window,
|
||||
"{}",
|
||||
buffer.text_for_range(start..end).collect::<String>()
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if start != end {
|
||||
write!(selected_window, "|END|>").unwrap();
|
||||
}
|
||||
|
||||
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
|
||||
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
let selected_tokens = model.count_tokens(&selected_window)?;
|
||||
if selected_tokens > max_token_count {
|
||||
return Err(anyhow!(
|
||||
"selected range is greater than model context window, truncation not possible"
|
||||
));
|
||||
};
|
||||
|
||||
let mut remaining_tokens = max_token_count - selected_tokens;
|
||||
let start_window_tokens = model.count_tokens(&start_window)?;
|
||||
let end_window_tokens = model.count_tokens(&end_window)?;
|
||||
let outside_tokens = start_window_tokens + end_window_tokens;
|
||||
if outside_tokens > remaining_tokens {
|
||||
let (start_goal_tokens, end_goal_tokens) =
|
||||
if start_window_tokens < end_window_tokens {
|
||||
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
|
||||
remaining_tokens -= start_goal_tokens;
|
||||
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
|
||||
(start_goal_tokens, end_goal_tokens)
|
||||
} else {
|
||||
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
|
||||
remaining_tokens -= end_goal_tokens;
|
||||
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
|
||||
(start_goal_tokens, end_goal_tokens)
|
||||
};
|
||||
|
||||
let truncated_start_window =
|
||||
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
|
||||
let truncated_end_window =
|
||||
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
||||
)
|
||||
.unwrap();
|
||||
truncated = true;
|
||||
} else {
|
||||
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
|
||||
}
|
||||
} else {
|
||||
// If we dont have a selected range, include entire file.
|
||||
writeln!(prompt, "{}", &buffer.text()).unwrap();
|
||||
|
||||
// Dumb truncation strategy
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
if model.count_tokens(&prompt)? > max_token_count {
|
||||
truncated = true;
|
||||
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let token_count = model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count, truncated))
|
||||
}
|
||||
|
||||
pub struct FileContext {}
|
||||
|
||||
impl PromptTemplate for FileContext {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
if let Some(buffer) = &args.buffer {
|
||||
let mut prompt = String::new();
|
||||
// Add Initial Preamble
|
||||
// TODO: Do we want to add the path in here?
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following content:"
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let language_name = args
|
||||
.language_name
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
.to_lowercase();
|
||||
|
||||
let (context, _, truncated) = retrieve_context(
|
||||
buffer,
|
||||
&args.selected_range,
|
||||
args.model.clone(),
|
||||
max_token_length,
|
||||
)?;
|
||||
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
|
||||
|
||||
if truncated {
|
||||
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
|
||||
}
|
||||
|
||||
if let Some(selected_range) = &args.selected_range {
|
||||
let start = selected_range.start.to_offset(buffer);
|
||||
let end = selected_range.end.to_offset(buffer);
|
||||
|
||||
if start == end {
|
||||
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args
|
||||
.model
|
||||
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count))
|
||||
} else {
|
||||
Err(anyhow!("no buffer provided to retrieve file context from"))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use anyhow::anyhow;
|
||||
use std::fmt::Write;
|
||||
|
||||
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// An empty input yields an empty string. The upper-case form of the first
/// character may be more than one character long.
pub fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    let Some(first) = rest.next() else {
        return String::new();
    };

    let mut capitalized: String = first.to_uppercase().collect();
    capitalized.push_str(rest.as_str());
    capitalized
}
|
||||
|
||||
pub struct GenerateInlineContent {}
|
||||
|
||||
impl PromptTemplate for GenerateInlineContent {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let Some(user_prompt) = &args.user_prompt else {
|
||||
return Err(anyhow!("user prompt not provided"));
|
||||
};
|
||||
|
||||
let file_type = args.get_file_type();
|
||||
let content_type = match &file_type {
|
||||
PromptFileType::Code => "code",
|
||||
PromptFileType::Text => "text",
|
||||
};
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
if let Some(selected_range) = &args.selected_range {
|
||||
if selected_range.start == selected_range.end {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
capitalize(content_type)
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
|
||||
}
|
||||
} else {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if let Some(language_name) = &args.language_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
match file_type {
|
||||
PromptFileType::Code => {
|
||||
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args.model.truncate(
|
||||
&prompt,
|
||||
max_tokens,
|
||||
crate::models::TruncationDirection::End,
|
||||
)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
|
||||
anyhow::Ok((prompt, token_count))
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
pub mod base;
|
||||
pub mod file_context;
|
||||
pub mod generate;
|
||||
pub mod preamble;
|
||||
pub mod repository_context;
|
|
@ -1,52 +0,0 @@
|
|||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
|
||||
pub struct EngineerPreamble {}
|
||||
|
||||
impl PromptTemplate for EngineerPreamble {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut prompts = Vec::new();
|
||||
|
||||
match args.get_file_type() {
|
||||
PromptFileType::Code => {
|
||||
prompts.push(format!(
|
||||
"You are an expert {}engineer.",
|
||||
args.language_name.clone().unwrap_or("".to_string()) + " "
|
||||
));
|
||||
}
|
||||
PromptFileType::Text => {
|
||||
prompts.push("You are an expert engineer.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project_name) = args.project_name.clone() {
|
||||
prompts.push(format!(
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(mut remaining_tokens) = max_token_length {
|
||||
let mut prompt = String::new();
|
||||
let mut total_count = 0;
|
||||
for prompt_piece in prompts {
|
||||
let prompt_token_count =
|
||||
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
|
||||
if remaining_tokens > prompt_token_count {
|
||||
writeln!(prompt, "{prompt_piece}").unwrap();
|
||||
remaining_tokens -= prompt_token_count;
|
||||
total_count += prompt_token_count;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((prompt, total_count))
|
||||
} else {
|
||||
let prompt = prompts.join("\n");
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
use crate::prompts::base::{PromptArguments, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
use std::{ops::Range, path::PathBuf};
|
||||
|
||||
use gpui::{AsyncAppContext, Model};
|
||||
use language::{Anchor, Buffer};
|
||||
|
||||
/// A piece of buffer text together with where it came from.
#[derive(Clone)]
pub struct PromptCodeSnippet {
    /// Path of the file backing the buffer, if the buffer has one.
    path: Option<PathBuf>,
    /// Lower-cased name of the buffer's language, if it has one.
    language_name: Option<String>,
    /// The snippet text itself.
    content: String,
}
|
||||
|
||||
impl PromptCodeSnippet {
|
||||
pub fn new(
|
||||
buffer: Model<Buffer>,
|
||||
range: Range<Anchor>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let content = snapshot.text_for_range(range.clone()).collect::<String>();
|
||||
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|language| language.name().to_string().to_lowercase());
|
||||
|
||||
let file_path = buffer.file().map(|file| file.path().to_path_buf());
|
||||
|
||||
(content, language_name, file_path)
|
||||
})?;
|
||||
|
||||
anyhow::Ok(PromptCodeSnippet {
|
||||
path: file_path,
|
||||
language_name,
|
||||
content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for PromptCodeSnippet {
|
||||
fn to_string(&self) -> String {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|path| path.to_string_lossy().to_string())
|
||||
.unwrap_or("".to_string());
|
||||
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||
let content = self.content.clone();
|
||||
|
||||
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RepositoryContext {}
|
||||
|
||||
impl PromptTemplate for RepositoryContext {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
|
||||
let mut prompt = String::new();
|
||||
|
||||
let mut remaining_tokens = max_token_length;
|
||||
let separator_token_length = args.model.count_tokens("\n")?;
|
||||
for snippet in &args.snippets {
|
||||
let mut snippet_prompt = template.to_string();
|
||||
let content = snippet.to_string();
|
||||
writeln!(snippet_prompt, "{content}").unwrap();
|
||||
|
||||
let token_count = args.model.count_tokens(&snippet_prompt)?;
|
||||
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||
if let Some(tokens_left) = remaining_tokens {
|
||||
if tokens_left >= token_count {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
remaining_tokens = if tokens_left >= (token_count + separator_token_length)
|
||||
{
|
||||
Some(tokens_left - token_count - separator_token_length)
|
||||
} else {
|
||||
Some(0)
|
||||
};
|
||||
}
|
||||
} else {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, total_token_count))
|
||||
}
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
pub mod open_ai;
|
|
@ -1,9 +0,0 @@
|
|||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod model;
|
||||
|
||||
pub use completion::*;
|
||||
pub use embedding::*;
|
||||
pub use model::OpenAiLanguageModel;
|
||||
|
||||
/// Base URL of the OpenAI HTTP API (`v1`).
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
|
|
@ -1,421 +0,0 @@
|
|||
use std::{
|
||||
env,
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::{AppContext, BackgroundExecutor};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use parking_lot::RwLock;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
models::LanguageModel,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAiRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl CompletionRequest for OpenAiRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAiUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAiResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAiUsage>,
|
||||
}
|
||||
|
||||
async fn stream_completion(
|
||||
api_url: String,
|
||||
kind: OpenAiCompletionProviderKind,
|
||||
credential: ProviderCredential,
|
||||
executor: BackgroundExecutor,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
|
||||
let api_key = match credential {
|
||||
ProviderCredential::Credentials { api_key } => api_key,
|
||||
_ => {
|
||||
return Err(anyhow!("no credentials provider for completion"));
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
|
||||
|
||||
let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
|
||||
let json_data = request.data()?;
|
||||
let mut response = Request::post(kind.completions_endpoint_url(&api_url))
|
||||
.header("Content-Type", "application/json")
|
||||
.header(auth_header_name, auth_header_value)
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAiResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiResponse {
|
||||
error: OpenAiError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAiResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum AzureOpenAiApiVersion {
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-03-15-preview")]
|
||||
V2023_03_15Preview,
|
||||
#[serde(rename = "2023-05-15")]
|
||||
V2023_05_15,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-06-01-preview")]
|
||||
V2023_06_01Preview,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-07-01-preview")]
|
||||
V2023_07_01Preview,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-08-01-preview")]
|
||||
V2023_08_01Preview,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-09-01-preview")]
|
||||
V2023_09_01Preview,
|
||||
#[serde(rename = "2023-12-01-preview")]
|
||||
V2023_12_01Preview,
|
||||
#[serde(rename = "2024-02-15-preview")]
|
||||
V2024_02_15Preview,
|
||||
}
|
||||
|
||||
impl fmt::Display for AzureOpenAiApiVersion {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
Self::V2023_03_15Preview => "2023-03-15-preview",
|
||||
Self::V2023_05_15 => "2023-05-15",
|
||||
Self::V2023_06_01Preview => "2023-06-01-preview",
|
||||
Self::V2023_07_01Preview => "2023-07-01-preview",
|
||||
Self::V2023_08_01Preview => "2023-08-01-preview",
|
||||
Self::V2023_09_01Preview => "2023-09-01-preview",
|
||||
Self::V2023_12_01Preview => "2023-12-01-preview",
|
||||
Self::V2024_02_15Preview => "2024-02-15-preview",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum OpenAiCompletionProviderKind {
|
||||
OpenAi,
|
||||
AzureOpenAi {
|
||||
deployment_id: String,
|
||||
api_version: AzureOpenAiApiVersion,
|
||||
},
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProviderKind {
|
||||
/// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
|
||||
fn completions_endpoint_url(&self, api_url: &str) -> String {
|
||||
match self {
|
||||
Self::OpenAi => {
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
format!("{api_url}/chat/completions")
|
||||
}
|
||||
Self::AzureOpenAi {
|
||||
deployment_id,
|
||||
api_version,
|
||||
} => {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||
format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
|
||||
fn auth_header(&self, api_key: String) -> (&'static str, String) {
|
||||
match self {
|
||||
Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
|
||||
Self::AzureOpenAi { .. } => ("Api-Key", api_key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_url: String,
|
||||
kind: OpenAiCompletionProviderKind,
|
||||
model: OpenAiLanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
executor: BackgroundExecutor,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub async fn new(
|
||||
api_url: String,
|
||||
kind: OpenAiCompletionProviderKind,
|
||||
model_name: String,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let model = executor
|
||||
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
|
||||
.await;
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
Self {
|
||||
api_url,
|
||||
kind,
|
||||
model,
|
||||
credential,
|
||||
executor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAiCompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
|
||||
let existing_credential = self.credential.read().clone();
|
||||
let retrieved_credential = match existing_credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return async move { existing_credential }.boxed()
|
||||
}
|
||||
_ => {
|
||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
||||
} else {
|
||||
let credentials = cx.read_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
ProviderCredential::Credentials { api_key }
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
async move {
|
||||
let retrieved_credential = retrieved_credential.await;
|
||||
*self.credential.write() = retrieved_credential.clone();
|
||||
retrieved_credential
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn save_credentials(
|
||||
&self,
|
||||
cx: &mut AppContext,
|
||||
credential: ProviderCredential,
|
||||
) -> BoxFuture<()> {
|
||||
*self.credential.write() = credential.clone();
|
||||
let credential = credential.clone();
|
||||
let write_credentials = match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
async move {
|
||||
if let Some(write_credentials) = write_credentials {
|
||||
write_credentials.await.log_err();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
delete_credentials.await.log_err();
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAiCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
|
||||
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
|
||||
// which is currently model based, due to the language model.
|
||||
// At some point in the future we should rectify this.
|
||||
let credential = self.credential.read().clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let kind = self.kind.clone();
|
||||
let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
|
@ -1,345 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::AsyncReadExt;
|
||||
use futures::FutureExt;
|
||||
use gpui::AppContext;
|
||||
use gpui::BackgroundExecutor;
|
||||
use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use std::env;
|
||||
use std::ops::Add;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
use util::http::{HttpClient, Request};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||
use crate::models::LanguageModel;
|
||||
use crate::providers::open_ai::OpenAiLanguageModel;
|
||||
|
||||
use crate::providers::open_ai::OPEN_AI_API_URL;
|
||||
|
||||
/// Returns the shared `cl100k_base` BPE tokenizer, building it on first use.
///
/// The tokenizer lives in a function-local `OnceLock`, so the initializer runs
/// at most once and every call hands back the same `'static` reference.
///
/// # Panics
///
/// Panics if `cl100k_base()` returns an error when the tokenizer is first built.
pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
    static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
    OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
}
|
||||
|
||||
/// Embedding provider that talks to an OpenAI-style `/embeddings` HTTP endpoint.
#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
    /// Base URL; requests are sent to `{api_url}/embeddings`.
    api_url: String,
    /// Language model handed out by `base_model`.
    model: OpenAiLanguageModel,
    /// Currently known credential, shared between clones of the provider.
    credential: Arc<RwLock<ProviderCredential>>,
    /// HTTP client used to send embedding requests.
    pub client: Arc<dyn HttpClient>,
    /// Executor used to load the model and to wait out rate-limit delays.
    pub executor: BackgroundExecutor,
    /// Read side of the rate-limit reset time (`None` when no limit is recorded).
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    /// Write side of the rate-limit reset time, guarded by a mutex.
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
|
||||
|
||||
/// JSON body serialized for a request to the `/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    /// Name of the embedding model to use.
    model: &'static str,
    /// Text spans to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
|
||||
|
||||
/// JSON body deserialized from a successful `/embeddings` response.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
    /// One entry per embedding returned by the service.
    data: Vec<OpenAiEmbedding>,
    /// Token usage reported alongside the embeddings.
    usage: OpenAiEmbeddingUsage,
}
|
||||
|
||||
/// A single entry of the `data` array in an embeddings response.
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// `index` field of the entry (presumably the position of the matching
    /// input span; confirm against the API reference).
    index: usize,
    /// `object` field of the entry, deserialized verbatim.
    object: String,
}
|
||||
|
||||
/// Token accounting reported in the `usage` object of an embeddings response.
#[derive(Deserialize)]
struct OpenAiEmbeddingUsage {
    /// `prompt_tokens` value reported by the service.
    prompt_tokens: usize,
    /// `total_tokens` value reported by the service.
    total_tokens: usize,
}
|
||||
|
||||
impl OpenAiEmbeddingProvider {
|
||||
pub async fn new(
|
||||
api_url: String,
|
||||
client: Arc<dyn HttpClient>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
// Loading the model is expensive, so ensure this runs off the main thread.
|
||||
let model = executor
|
||||
.spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
|
||||
.await;
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
|
||||
OpenAiEmbeddingProvider {
|
||||
api_url,
|
||||
model,
|
||||
credential,
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
rate_limit_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_api_key(&self) -> Result<String> {
|
||||
match self.credential.read().clone() {
|
||||
ProviderCredential::Credentials { api_key } => Ok(api_key),
|
||||
_ => Err(anyhow!("api credentials not provided")),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_rate_limit(&self) {
|
||||
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
if let Some(reset_time) = reset_time {
|
||||
if Instant::now() >= reset_time {
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"resolving reset time: {:?}",
|
||||
*self.rate_limit_count_tx.lock().borrow()
|
||||
);
|
||||
}
|
||||
|
||||
fn update_reset_time(&self, reset_time: Instant) {
|
||||
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
let updated_time = if let Some(original_time) = original_time {
|
||||
if reset_time < original_time {
|
||||
Some(reset_time)
|
||||
} else {
|
||||
Some(original_time)
|
||||
}
|
||||
} else {
|
||||
Some(reset_time)
|
||||
};
|
||||
|
||||
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post(format!("{api_url}/embeddings"))
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
serde_json::to_string(&OpenAiEmbeddingRequest {
|
||||
input: spans.clone(),
|
||||
model: "text-embedding-ada-002",
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
Ok(self.client.send(request).await?)
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAiEmbeddingProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
|
||||
let existing_credential = self.credential.read().clone();
|
||||
let retrieved_credential = match existing_credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return async move { existing_credential }.boxed()
|
||||
}
|
||||
_ => {
|
||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
||||
} else {
|
||||
let credentials = cx.read_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
ProviderCredential::Credentials { api_key }
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
async move {
|
||||
let retrieved_credential = retrieved_credential.await;
|
||||
*self.credential.write() = retrieved_credential.clone();
|
||||
retrieved_credential
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn save_credentials(
|
||||
&self,
|
||||
cx: &mut AppContext,
|
||||
credential: ProviderCredential,
|
||||
) -> BoxFuture<()> {
|
||||
*self.credential.write() = credential.clone();
|
||||
let credential = credential.clone();
|
||||
let write_credentials = match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
async move {
|
||||
if let Some(write_credentials) = write_credentials {
|
||||
write_credentials.await.log_err();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
delete_credentials.await.log_err();
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    /// Returns a boxed clone of the language model loaded for this provider.
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }

    /// Maximum number of tokens allowed in one batch (fixed at 50000).
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }

    /// Returns the currently recorded rate-limit reset time, if any.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }

    /// Requests embeddings for `spans`, retrying on timeouts and rate limits.
    ///
    /// At most `MAX_RETRIES` requests are sent. A 408 response lengthens the
    /// request timeout by five seconds and retries immediately. A 429 response
    /// records a reset time and waits before the next attempt. Any other
    /// non-200 status is returned as an error straight away.
    ///
    /// # Errors
    ///
    /// Fails when no API key is available, when sending or reading a response
    /// fails, when the success body cannot be parsed, on an unexpected status,
    /// or with "openai max retries" once all attempts are used up.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        // Fallback waits (seconds) for the 1st..4th rate-limited attempt.
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;

        let api_url = self.api_url.as_str();
        let api_key = self.get_api_key()?;

        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_url,
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;

            // Counted before the status is inspected, so it is 1-based below.
            request_number += 1;

            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;

                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );

                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }

                    // Embeddings are returned in the order of `data`; the
                    // per-entry `index` field is not consulted.
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    // NOTE(review): the body is read but never used afterwards.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;

                    // Prefer the wait advertised in `x-ratelimit-reset-tokens`;
                    // fall back to the backoff table when the header is missing
                    // or cannot be parsed.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };

                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);

                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );

                    // NOTE(review): this wait also happens after the final
                    // attempt, just before the "max retries" error is returned.
                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}
|
|
@ -1,59 +0,0 @@
|
|||
use anyhow::anyhow;
|
||||
use tiktoken_rs::CoreBPE;
|
||||
|
||||
use crate::models::{LanguageModel, TruncationDirection};
|
||||
|
||||
use super::open_ai_bpe_tokenizer;
|
||||
|
||||
/// An OpenAI model identified by name, paired with the BPE tokenizer used to
/// count and truncate its input.
#[derive(Clone)]
pub struct OpenAiLanguageModel {
    /// Model name as passed to `load`.
    name: String,
    /// Tokenizer for the model; token operations fail when this is `None`.
    bpe: Option<CoreBPE>,
}
|
||||
|
||||
impl OpenAiLanguageModel {
|
||||
pub fn load(model_name: &str) -> Self {
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name)
|
||||
.unwrap_or(open_ai_bpe_tokenizer().to_owned());
|
||||
OpenAiLanguageModel {
|
||||
name: model_name.to_string(),
|
||||
bpe: Some(bpe),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAiLanguageModel {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
match direction {
|
||||
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
|
||||
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
|
||||
}
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
||||
}
|
||||
}
|
|
@ -1,206 +0,0 @@
|
|||
use std::{
|
||||
sync::atomic::{self, AtomicUsize, Ordering},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::AppContext;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::{LanguageModel, TruncationDirection},
|
||||
};
|
||||
|
||||
/// Test double for a language model that treats every `char` as one token.
#[derive(Clone)]
pub struct FakeLanguageModel {
    /// Value reported by `capacity()`.
    pub capacity: usize,
}
|
||||
|
||||
impl LanguageModel for FakeLanguageModel {
|
||||
fn name(&self) -> String {
|
||||
"dummy".to_string()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||
}
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
println!("TRYING TO TRUNCATE: {:?}", length.clone());
|
||||
|
||||
if length > self.count_tokens(content)? {
|
||||
println!("NOT TRUNCATING");
|
||||
return anyhow::Ok(content.to_string());
|
||||
}
|
||||
|
||||
anyhow::Ok(match direction {
|
||||
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
})
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(self.capacity)
|
||||
}
|
||||
}
|
||||
|
||||
/// Test double for an embedding provider that computes embeddings locally.
#[derive(Default)]
pub struct FakeEmbeddingProvider {
    /// Running total of spans passed to `embed_batch`.
    pub embedding_count: AtomicUsize,
}
|
||||
|
||||
impl Clone for FakeEmbeddingProvider {
|
||||
fn clone(&self) -> Self {
|
||||
FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeEmbeddingProvider {
|
||||
pub fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Credential handling for the fake provider: nothing is stored or required.
impl CredentialProvider for FakeEmbeddingProvider {
    /// Always reports that credentials are present.
    fn has_credentials(&self) -> bool {
        true
    }

    /// Resolves immediately to `ProviderCredential::NotNeeded`.
    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        async { ProviderCredential::NotNeeded }.boxed()
    }

    /// Does nothing; the credential is discarded.
    fn save_credentials(
        &self,
        _cx: &mut AppContext,
        _credential: ProviderCredential,
    ) -> BoxFuture<()> {
        async {}.boxed()
    }

    /// Does nothing.
    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
        async {}.boxed()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
Box::new(FakeLanguageModel { capacity: 1000 })
|
||||
}
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
1000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
|
||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
/// Test double for a completion provider whose output is driven by the test.
pub struct FakeCompletionProvider {
    /// Sender for the stream returned by the most recent `complete` call;
    /// `None` before the first call and after `finish_completion`.
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
|
||||
|
||||
impl Clone for FakeCompletionProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Credential handling for the fake provider: nothing is stored or required.
impl CredentialProvider for FakeCompletionProvider {
    /// Always reports that credentials are present.
    fn has_credentials(&self) -> bool {
        true
    }

    /// Resolves immediately to `ProviderCredential::NotNeeded`.
    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        async { ProviderCredential::NotNeeded }.boxed()
    }

    /// Does nothing; the credential is discarded.
    fn save_credentials(
        &self,
        _cx: &mut AppContext,
        _credential: ProviderCredential,
    ) -> BoxFuture<()> {
        async {}.boxed()
    }

    /// Does nothing.
    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
        async {}.boxed()
    }
}
|
||||
|
||||
impl CompletionProvider for FakeCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
|
@ -5,17 +5,14 @@ edition = "2021"
|
|||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
ai.workspace = true
|
||||
anyhow.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
fs.workspace = true
|
||||
|
@ -26,12 +23,13 @@ language.workspace = true
|
|||
log.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
schemars.workspace = true
|
||||
search.workspace = true
|
||||
semantic_index.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
@ -45,7 +43,6 @@ uuid.workspace = true
|
|||
workspace.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ai = { workspace = true, features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
|
|
|
@ -1,22 +1,24 @@
|
|||
pub mod assistant_panel;
|
||||
pub mod assistant_settings;
|
||||
mod codegen;
|
||||
mod completion_provider;
|
||||
mod prompts;
|
||||
mod saved_conversation;
|
||||
mod streaming_diff;
|
||||
|
||||
use ai::providers::open_ai::Role;
|
||||
use anyhow::Result;
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::OpenAiModel;
|
||||
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
|
||||
use chrono::{DateTime, Local};
|
||||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use client::{proto, Client};
|
||||
pub(crate) use completion_provider::*;
|
||||
use gpui::{actions, AppContext, SharedString};
|
||||
use regex::Regex;
|
||||
pub(crate) use saved_conversation::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
use settings::Settings;
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
actions!(
|
||||
assistant,
|
||||
|
@ -30,7 +32,6 @@ actions!(
|
|||
ResetKey,
|
||||
InlineAssist,
|
||||
ToggleIncludeConversation,
|
||||
ToggleRetrieveContext,
|
||||
]
|
||||
);
|
||||
|
||||
|
@ -39,6 +40,139 @@ actions!(
|
|||
)]
|
||||
struct MessageId(usize);
|
||||
|
||||
/// Author of a conversation message.
///
/// Serialized with lowercase variant names (`"user"`, `"assistant"`,
/// `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "user"),
|
||||
Role::Assistant => write!(f, "assistant"),
|
||||
Role::System => write!(f, "system"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A language model selection, tagged by the service that hosts it.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model accessed through zed.dev.
    ZedDotDev(ZedDotDevModel),
    /// A model accessed through OpenAI.
    OpenAi(OpenAiModel),
}
|
||||
|
||||
impl Default for LanguageModel {
    /// Defaults to the zed.dev variant holding `ZedDotDevModel::default()`.
    fn default() -> Self {
        LanguageModel::ZedDotDev(ZedDotDevModel::default())
    }
}
|
||||
|
||||
impl LanguageModel {
|
||||
pub fn telemetry_id(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
|
||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
|
||||
LanguageModel::ZedDotDev(model) => match model {
|
||||
ZedDotDevModel::GptThreePointFiveTurbo
|
||||
| ZedDotDevModel::GptFour
|
||||
| ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
|
||||
ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.id(),
|
||||
LanguageModel::ZedDotDev(model) => model.id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// One message of a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// Text of the message.
    pub content: String,
}
|
||||
|
||||
impl LanguageModelRequestMessage {
|
||||
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
|
||||
proto::LanguageModelRequestMessage {
|
||||
role: match self.role {
|
||||
Role::User => proto::LanguageModelRole::LanguageModelUser,
|
||||
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
|
||||
Role::System => proto::LanguageModelRole::LanguageModelSystem,
|
||||
} as i32,
|
||||
content: self.content.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A completion request addressed to a language model.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// Model that should answer the request.
    pub model: LanguageModel,
    /// Messages that make up the request.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop sequences forwarded with the request.
    pub stop: Vec<String>,
    /// Sampling temperature forwarded with the request.
    pub temperature: f32,
}
|
||||
|
||||
impl LanguageModelRequest {
|
||||
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
|
||||
proto::CompleteWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
|
||||
stop: self.stop.clone(),
|
||||
temperature: self.temperature,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A response message; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// Author of the message, when present.
    pub role: Option<Role>,
    /// Text of the message, when present.
    pub content: Option<String>,
}
|
||||
|
||||
/// Token usage figures deserialized from a response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// `prompt_tokens` value from the response.
    pub prompt_tokens: u32,
    /// `completion_tokens` value from the response.
    pub completion_tokens: u32,
    /// `total_tokens` value from the response.
    pub total_tokens: u32,
}
|
||||
|
||||
/// One choice entry of a response, carrying a message delta.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// `index` value of the choice.
    pub index: u32,
    /// Message delta carried by this choice.
    pub delta: LanguageModelResponseMessage,
    /// `finish_reason` value, when present.
    pub finish_reason: Option<String>,
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct MessageMetadata {
|
||||
role: Role,
|
||||
|
@ -53,71 +187,9 @@ enum MessageStatus {
|
|||
Error(SharedString),
|
||||
}
|
||||
|
||||
/// Serialized form of a single message in a saved conversation.
#[derive(Serialize, Deserialize)]
struct SavedMessage {
    /// Identifier of the message.
    id: MessageId,
    /// `start` value stored for the message (presumably its offset into the
    /// conversation text; confirm against the code that writes it).
    start: usize,
}
|
||||
|
||||
/// Serialized form of a whole conversation.
#[derive(Serialize, Deserialize)]
struct SavedConversation {
    /// Optional conversation identifier.
    id: Option<String>,
    // NOTE(review): the meaning of `zed` is not visible here; confirm with the
    // code that writes it.
    zed: String,
    /// Format version string (see `SavedConversation::VERSION`).
    version: String,
    /// Text stored for the conversation.
    text: String,
    /// Messages stored for the conversation.
    messages: Vec<SavedMessage>,
    /// Metadata for each message, keyed by message id.
    message_metadata: HashMap<MessageId, MessageMetadata>,
    /// Summary stored for the conversation.
    summary: String,
    /// Optional API URL stored with the conversation.
    api_url: Option<String>,
    /// OpenAI model stored with the conversation.
    model: OpenAiModel,
}
|
||||
|
||||
impl SavedConversation {
    /// Version string of the saved-conversation format.
    const VERSION: &'static str = "0.1.0";
}
|
||||
|
||||
/// Summary information about a conversation file on disk.
struct SavedConversationMetadata {
    /// Title derived from the file name.
    title: String,
    /// Path of the conversation file.
    path: PathBuf,
    /// Modification time of the file.
    mtime: chrono::DateTime<chrono::Local>,
}
|
||||
|
||||
impl SavedConversationMetadata {
|
||||
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
|
||||
fs.create_dir(&CONVERSATIONS_DIR).await?;
|
||||
|
||||
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
|
||||
let mut conversations = Vec::<SavedConversationMetadata>::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pattern = r" - \d+.zed.json$";
|
||||
let re = Regex::new(pattern).unwrap();
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
let title = re.replace(file_name, "");
|
||||
conversations.push(Self {
|
||||
title: title.into_owned(),
|
||||
path,
|
||||
mtime: metadata.mtime.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
AssistantSettings::register(cx);
|
||||
completion_provider::init(client, cx);
|
||||
assistant_panel::init(cx);
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,169 +1,296 @@
|
|||
use ai::providers::open_ai::{
|
||||
AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use std::fmt;
|
||||
|
||||
use gpui::Pixels;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::{
|
||||
schema::{InstanceType, Metadata, Schema, SchemaObject},
|
||||
JsonSchema,
|
||||
};
|
||||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
};
|
||||
use settings::Settings;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum OpenAiModel {
|
||||
#[serde(rename = "gpt-3.5-turbo-0613")]
|
||||
ThreePointFiveTurbo,
|
||||
#[serde(rename = "gpt-4-0613")]
|
||||
Four,
|
||||
#[serde(rename = "gpt-4-1106-preview")]
|
||||
FourTurbo,
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub enum ZedDotDevModel {
|
||||
GptThreePointFiveTurbo,
|
||||
GptFour,
|
||||
#[default]
|
||||
GptFourTurbo,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl OpenAiModel {
|
||||
pub fn full_name(&self) -> &'static str {
|
||||
impl Serialize for ZedDotDevModel {
    /// Serializes the model as a plain string holding its id.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.id())
    }
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ZedDotDevModel {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ZedDotDevModelVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
||||
type Value = ZedDotDevModel;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
match value {
|
||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
|
||||
"gpt-4" => Ok(ZedDotDevModel::GptFour),
|
||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
|
||||
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(ZedDotDevModelVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for ZedDotDevModel {
|
||||
fn schema_name() -> String {
|
||||
"ZedDotDevModel".to_owned()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
let variants = vec![
|
||||
"gpt-3.5-turbo".to_owned(),
|
||||
"gpt-4".to_owned(),
|
||||
"gpt-4-turbo-preview".to_owned(),
|
||||
];
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
|
||||
metadata: Some(Box::new(Metadata {
|
||||
title: Some("ZedDotDevModel".to_owned()),
|
||||
default: Some(serde_json::json!("gpt-4-turbo-preview")),
|
||||
examples: vec![
|
||||
serde_json::json!("gpt-3.5-turbo"),
|
||||
serde_json::json!("gpt-4"),
|
||||
serde_json::json!("gpt-4-turbo-preview"),
|
||||
serde_json::json!("custom-model-name"),
|
||||
],
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ZedDotDevModel {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
|
||||
Self::Four => "gpt-4-0613",
|
||||
Self::FourTurbo => "gpt-4-1106-preview",
|
||||
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::GptFour => "gpt-4",
|
||||
Self::GptFourTurbo => "gpt-4-turbo-preview",
|
||||
Self::Custom(id) => id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn short_name(&self) -> &'static str {
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::Four => "gpt-4",
|
||||
Self::FourTurbo => "gpt-4-turbo",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cycle(&self) -> Self {
|
||||
match self {
|
||||
Self::ThreePointFiveTurbo => Self::Four,
|
||||
Self::Four => Self::FourTurbo,
|
||||
Self::FourTurbo => Self::ThreePointFiveTurbo,
|
||||
Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::GptFour => "gpt-4",
|
||||
Self::GptFourTurbo => "gpt-4-turbo",
|
||||
Self::Custom(id) => id.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AssistantDockPosition {
|
||||
Left,
|
||||
#[default]
|
||||
Right,
|
||||
Bottom,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AssistantSettings {
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
pub button: bool,
|
||||
/// Where to dock the assistant.
|
||||
pub dock: AssistantDockPosition,
|
||||
/// Default width in pixels when the assistant is docked to the left or right.
|
||||
pub default_width: Pixels,
|
||||
/// Default height in pixels when the assistant is docked to the bottom.
|
||||
pub default_height: Pixels,
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
#[deprecated = "Please use `provider.default_model` instead."]
|
||||
pub default_open_ai_model: OpenAiModel,
|
||||
/// OpenAI API base URL to use when starting new conversations.
|
||||
#[deprecated = "Please use `provider.api_url` instead."]
|
||||
pub openai_api_url: String,
|
||||
/// The settings for the AI provider.
|
||||
pub provider: AiProviderSettings,
|
||||
// Selects which service backs the assistant and carries that service's
// configuration.
//
// Serialized as an internally tagged object: the `name` key picks the variant
// (`"zed.dev"` or `"openai"`) and the remaining keys are the variant's fields.
// Plain `//` comments are used here on purpose: the type derives `JsonSchema`,
// and `///` doc comments would presumably be emitted as schema descriptions.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(tag = "name", rename_all = "snake_case")]
|
||||
pub enum AssistantProvider {
|
||||
// The `zed.dev` service.
#[serde(rename = "zed.dev")]
|
||||
ZedDotDev {
|
||||
// Falls back to `ZedDotDevModel::default()` when the key is absent.
#[serde(default)]
|
||||
default_model: ZedDotDevModel,
|
||||
},
|
||||
// An external OpenAI endpoint.
#[serde(rename = "openai")]
|
||||
OpenAi {
|
||||
// Falls back to `OpenAiModel::default()` when the key is absent.
#[serde(default)]
|
||||
default_model: OpenAiModel,
|
||||
// Falls back to the value returned by `open_ai_url()` when the key is absent.
#[serde(default = "open_ai_url")]
|
||||
api_url: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl AssistantSettings {
|
||||
pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
|
||||
AiProviderSettings::AzureOpenAi(settings) => {
|
||||
let deployment_id = settings
|
||||
.deployment_id
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
|
||||
let api_version = settings
|
||||
.api_version
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
|
||||
|
||||
Ok(OpenAiCompletionProviderKind::AzureOpenAi {
|
||||
deployment_id,
|
||||
api_version,
|
||||
})
|
||||
}
|
||||
impl Default for AssistantProvider {
|
||||
fn default() -> Self {
|
||||
Self::ZedDotDev {
|
||||
default_model: ZedDotDevModel::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provider_api_url(&self) -> anyhow::Result<String> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(settings) => Ok(settings
|
||||
.api_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
|
||||
AiProviderSettings::AzureOpenAi(settings) => settings
|
||||
.api_url
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
|
||||
}
|
||||
/// Returns the base URL used for the OpenAI API when the settings do not
/// provide one (it is the serde default for `AssistantProvider::OpenAi::api_url`).
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
|
||||
|
||||
/// Fully resolved assistant settings, i.e. the result of merging the default
/// settings with every user-provided `AssistantSettingsContent`.
#[derive(Default, Debug, Deserialize, Serialize)]
|
||||
pub struct AssistantSettings {
|
||||
/// Whether to show the assistant panel button in the status bar.
pub button: bool,
|
||||
/// Where to dock the assistant.
pub dock: AssistantDockPosition,
|
||||
/// Default width in pixels when the assistant is docked to the left or right.
pub default_width: Pixels,
|
||||
/// Default height in pixels when the assistant is docked to the bottom.
pub default_height: Pixels,
|
||||
/// The provider of the assistant service, together with its configuration.
pub provider: AssistantProvider,
|
||||
}
|
||||
|
||||
/// Assistant panel settings
// On-disk representation of the `assistant` settings section.
//
// The enum is `untagged`, so serde tries the variants in declaration order:
// `Versioned` first (it needs a `version` tag), then `Legacy` for settings
// files written before versioning was introduced.
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum AssistantSettingsContent {
|
||||
// Settings that carry an explicit `version` key.
Versioned(VersionedAssistantSettingsContent),
|
||||
// Pre-versioning settings layout, kept for backward compatibility.
Legacy(LegacyAssistantSettingsContent),
|
||||
}
|
||||
|
||||
impl JsonSchema for AssistantSettingsContent {
|
||||
fn schema_name() -> String {
|
||||
VersionedAssistantSettingsContent::schema_name()
|
||||
}
|
||||
|
||||
pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(settings) => {
|
||||
Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
|
||||
}
|
||||
AiProviderSettings::AzureOpenAi(settings) => {
|
||||
let deployment_id = settings
|
||||
.deployment_id
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
|
||||
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
VersionedAssistantSettingsContent::json_schema(gen)
|
||||
}
|
||||
|
||||
match deployment_id {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview
|
||||
"gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four),
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35
|
||||
"gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => {
|
||||
Ok(OpenAiModel::ThreePointFiveTurbo)
|
||||
fn is_referenceable() -> bool {
|
||||
VersionedAssistantSettingsContent::is_referenceable()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::Versioned(VersionedAssistantSettingsContent::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantSettingsContent {
|
||||
fn upgrade(&self) -> AssistantSettingsContentV1 {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
|
||||
AssistantSettingsContentV1 {
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_height,
|
||||
provider: Some(AssistantProvider::OpenAi {
|
||||
default_model: settings
|
||||
.default_open_ai_model
|
||||
.clone()
|
||||
.unwrap_or_default(),
|
||||
api_url: open_ai_api_url.clone(),
|
||||
}),
|
||||
}
|
||||
} else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
|
||||
AssistantSettingsContentV1 {
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_height,
|
||||
provider: Some(AssistantProvider::OpenAi {
|
||||
default_model: open_ai_model,
|
||||
api_url: open_ai_url(),
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
AssistantSettingsContentV1 {
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_height,
|
||||
provider: None,
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"no matching OpenAI model found for deployment ID: '{deployment_id}'"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provider_model_name(&self) -> anyhow::Result<String> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(settings) => Ok(settings
|
||||
.default_model
|
||||
.unwrap_or(OpenAiModel::FourTurbo)
|
||||
.full_name()
|
||||
.to_string()),
|
||||
AiProviderSettings::AzureOpenAi(settings) => settings
|
||||
.deployment_id
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
|
||||
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Settings for AssistantSettings {
|
||||
const KEY: Option<&'static str> = Some("assistant");
|
||||
// Assistant settings that declare their schema version.
//
// Internally tagged: the `version` key selects the variant, and the remaining
// keys are that version's fields. Plain `//` comments are used because the
// type derives `JsonSchema`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[serde(tag = "version")]
|
||||
pub enum VersionedAssistantSettingsContent {
|
||||
// Selected by `"version": "1"`.
#[serde(rename = "1")]
|
||||
V1(AssistantSettingsContentV1),
|
||||
}
|
||||
|
||||
type FileContent = AssistantSettingsContent;
|
||||
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Self::load_via_json_merge(default_value, user_values)
|
||||
impl Default for VersionedAssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::V1(AssistantSettingsContentV1 {
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
provider: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContent {
|
||||
// Version 1 of the on-disk assistant settings.
//
// Every field is optional so a settings file can specify any subset of them;
// unset fields are `None`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContentV1 {
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
///
|
||||
/// Default: true
|
||||
button: Option<bool>,
|
||||
/// Where to dock the assistant.
|
||||
///
|
||||
/// Default: right
|
||||
dock: Option<AssistantDockPosition>,
|
||||
/// Default width in pixels when the assistant is docked to the left or right.
|
||||
///
|
||||
/// Default: 640
|
||||
default_width: Option<f32>,
|
||||
/// Default height in pixels when the assistant is docked to the bottom.
|
||||
///
|
||||
/// Default: 320
|
||||
default_height: Option<f32>,
|
||||
/// The provider of the assistant service.
|
||||
///
|
||||
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
||||
/// each with their respective default models and configurations.
|
||||
provider: Option<AssistantProvider>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct LegacyAssistantSettingsContent {
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
///
|
||||
/// Default: true
|
||||
|
@ -180,88 +307,164 @@ pub struct AssistantSettingsContent {
|
|||
///
|
||||
/// Default: 320
|
||||
pub default_height: Option<f32>,
|
||||
/// Deprecated: Please use `provider.default_model` instead.
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
///
|
||||
/// Default: gpt-4-1106-preview
|
||||
#[deprecated = "Please use `provider.default_model` instead."]
|
||||
pub default_open_ai_model: Option<OpenAiModel>,
|
||||
/// Deprecated: Please use `provider.api_url` instead.
|
||||
/// OpenAI API base URL to use when starting new conversations.
|
||||
///
|
||||
/// Default: https://api.openai.com/v1
|
||||
#[deprecated = "Please use `provider.api_url` instead."]
|
||||
pub openai_api_url: Option<String>,
|
||||
/// The settings for the AI provider.
|
||||
#[serde(default)]
|
||||
pub provider: AiProviderSettingsContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AiProviderSettings {
|
||||
/// The settings for the OpenAI provider.
|
||||
#[serde(rename = "openai")]
|
||||
OpenAi(OpenAiProviderSettings),
|
||||
/// The settings for the Azure OpenAI provider.
|
||||
#[serde(rename = "azure_openai")]
|
||||
AzureOpenAi(AzureOpenAiProviderSettings),
|
||||
}
|
||||
impl Settings for AssistantSettings {
|
||||
const KEY: Option<&'static str> = Some("assistant");
|
||||
|
||||
/// The settings for the AI provider used by the Zed Assistant.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AiProviderSettingsContent {
|
||||
/// The settings for the OpenAI provider.
|
||||
#[serde(rename = "openai")]
|
||||
OpenAi(OpenAiProviderSettingsContent),
|
||||
/// The settings for the Azure OpenAI provider.
|
||||
#[serde(rename = "azure_openai")]
|
||||
AzureOpenAi(AzureOpenAiProviderSettingsContent),
|
||||
}
|
||||
type FileContent = AssistantSettingsContent;
|
||||
|
||||
impl Default for AiProviderSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::OpenAi(OpenAiProviderSettingsContent::default())
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let mut settings = AssistantSettings::default();
|
||||
|
||||
for value in [default_value].iter().chain(user_values) {
|
||||
let value = value.upgrade();
|
||||
merge(&mut settings.button, value.button);
|
||||
merge(&mut settings.dock, value.dock);
|
||||
merge(
|
||||
&mut settings.default_width,
|
||||
value.default_width.map(Into::into),
|
||||
);
|
||||
merge(
|
||||
&mut settings.default_height,
|
||||
value.default_height.map(Into::into),
|
||||
);
|
||||
if let Some(provider) = value.provider.clone() {
|
||||
match (&mut settings.provider, provider) {
|
||||
(
|
||||
AssistantProvider::ZedDotDev { default_model },
|
||||
AssistantProvider::ZedDotDev {
|
||||
default_model: default_model_override,
|
||||
},
|
||||
) => {
|
||||
*default_model = default_model_override;
|
||||
}
|
||||
(
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
},
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: default_model_override,
|
||||
api_url: api_url_override,
|
||||
},
|
||||
) => {
|
||||
*default_model = default_model_override;
|
||||
*api_url = api_url_override;
|
||||
}
|
||||
(merged, provider_override) => {
|
||||
*merged = provider_override;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiProviderSettings {
|
||||
/// The OpenAI API base URL to use when starting new conversations.
|
||||
pub api_url: Option<String>,
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
pub default_model: Option<OpenAiModel>,
|
||||
/// Overwrites `target` with the contained value when `value` is `Some`;
/// leaves `target` untouched when it is `None`.
fn merge<T: Copy>(target: &mut T, value: Option<T>) {
    match value {
        Some(replacement) => *target = replacement,
        None => {}
    }
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OpenAiProviderSettingsContent {
|
||||
/// The OpenAI API base URL to use when starting new conversations.
|
||||
///
|
||||
/// Default: https://api.openai.com/v1
|
||||
pub api_url: Option<String>,
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
///
|
||||
/// Default: gpt-4-1106-preview
|
||||
pub default_model: Option<OpenAiModel>,
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::AppContext;
|
||||
use settings::SettingsStore;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AzureOpenAiProviderSettings {
|
||||
/// The Azure OpenAI API base URL to use when starting new conversations.
|
||||
pub api_url: Option<String>,
|
||||
/// The Azure OpenAI API version.
|
||||
pub api_version: Option<AzureOpenAiApiVersion>,
|
||||
/// The Azure OpenAI API deployment ID.
|
||||
pub deployment_id: Option<String>,
|
||||
}
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AzureOpenAiProviderSettingsContent {
|
||||
/// The Azure OpenAI API base URL to use when starting new conversations.
|
||||
pub api_url: Option<String>,
|
||||
/// The Azure OpenAI API version.
|
||||
pub api_version: Option<AzureOpenAiApiVersion>,
|
||||
/// The Azure OpenAI deployment ID.
|
||||
pub deployment_id: Option<String>,
|
||||
#[gpui::test]
|
||||
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
|
||||
let store = settings::SettingsStore::test(cx);
|
||||
cx.set_global(store);
|
||||
|
||||
// Settings default to gpt-4-turbo.
|
||||
AssistantSettings::register(cx);
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: open_ai_url()
|
||||
}
|
||||
);
|
||||
|
||||
// Ensure backward-compatibility.
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"openai_api_url": "test-url",
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: "test-url".into()
|
||||
}
|
||||
);
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"default_open_ai_model": "gpt-4-0613"
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::Four,
|
||||
api_url: open_ai_url()
|
||||
}
|
||||
);
|
||||
|
||||
// The new version supports setting a custom model when using zed.dev.
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"version": "1",
|
||||
"provider": {
|
||||
"name": "zed.dev",
|
||||
"default_model": "custom"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::ZedDotDev {
|
||||
default_model: ZedDotDevModel::Custom("custom".into())
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, CompletionRequest};
|
||||
use crate::{
|
||||
streaming_diff::{Hunk, StreamingDiff},
|
||||
CompletionProvider, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
use gpui::{EventEmitter, Model, ModelContext, Task};
|
||||
use language::{Rope, TransactionId};
|
||||
use multi_buffer;
|
||||
use std::{cmp, future, ops::Range, sync::Arc};
|
||||
use std::{cmp, future, ops::Range};
|
||||
|
||||
pub enum Event {
|
||||
Finished,
|
||||
|
@ -20,7 +21,6 @@ pub enum CodegenKind {
|
|||
}
|
||||
|
||||
pub struct Codegen {
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
buffer: Model<MultiBuffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
kind: CodegenKind,
|
||||
|
@ -35,15 +35,9 @@ pub struct Codegen {
|
|||
impl EventEmitter<Event> for Codegen {}
|
||||
|
||||
impl Codegen {
|
||||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
kind: CodegenKind,
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
Self {
|
||||
provider,
|
||||
buffer: buffer.clone(),
|
||||
snapshot,
|
||||
kind,
|
||||
|
@ -94,7 +88,7 @@ impl Codegen {
|
|||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
|
||||
let range = self.range();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
|
@ -108,7 +102,7 @@ impl Codegen {
|
|||
.next()
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
|
||||
|
||||
let response = self.provider.complete(prompt);
|
||||
let response = CompletionProvider::global(cx).complete(prompt);
|
||||
self.generation = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let generate = async {
|
||||
|
@ -305,7 +299,7 @@ fn strip_invalid_spans_from_codeblock(
|
|||
}
|
||||
|
||||
if first_line {
|
||||
if buffer == "" || buffer == "`" || buffer == "``" {
|
||||
if buffer.is_empty() || buffer == "`" || buffer == "``" {
|
||||
return future::ready(None);
|
||||
} else if buffer.starts_with("```") {
|
||||
starts_with_markdown_codeblock = true;
|
||||
|
@ -360,8 +354,9 @@ fn strip_invalid_spans_from_codeblock(
|
|||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::FakeCompletionProvider;
|
||||
|
||||
use super::*;
|
||||
use ai::test::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
use gpui::{Context, TestAppContext};
|
||||
use indoc::indoc;
|
||||
|
@ -378,15 +373,11 @@ mod tests {
|
|||
pub name: String,
|
||||
}
|
||||
|
||||
impl CompletionRequest for DummyCompletionRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
let provider = FakeCompletionProvider::default();
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.set_global(CompletionProvider::Fake(provider.clone()));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
|
@ -405,19 +396,10 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Transform { range },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
let request = LanguageModelRequest::default();
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
|
@ -430,8 +412,7 @@ mod tests {
|
|||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
println!("CHUNK: {:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
provider.send_completion(chunk.into());
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
|
@ -456,6 +437,8 @@ mod tests {
|
|||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
let provider = FakeCompletionProvider::default();
|
||||
cx.set_global(CompletionProvider::Fake(provider.clone()));
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
|
@ -472,19 +455,10 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
let request = LanguageModelRequest::default();
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
|
@ -497,7 +471,7 @@ mod tests {
|
|||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
provider.send_completion(chunk.into());
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
|
@ -522,6 +496,8 @@ mod tests {
|
|||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
let provider = FakeCompletionProvider::default();
|
||||
cx.set_global(CompletionProvider::Fake(provider.clone()));
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
|
@ -538,19 +514,10 @@ mod tests {
|
|||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
let request = LanguageModelRequest::default();
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
|
@ -563,8 +530,7 @@ mod tests {
|
|||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
println!("{:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
provider.send_completion(chunk.into());
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
|
|
188
crates/assistant/src/completion_provider.rs
Normal file
188
crates/assistant/src/completion_provider.rs
Normal file
|
@ -0,0 +1,188 @@
|
|||
#[cfg(test)]
|
||||
mod fake;
|
||||
mod open_ai;
|
||||
mod zed;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use fake::*;
|
||||
pub use open_ai::*;
|
||||
pub use zed::*;
|
||||
|
||||
use crate::{
|
||||
assistant_settings::{AssistantProvider, AssistantSettings},
|
||||
LanguageModel, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, AppContext, Task, WindowContext};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let mut settings_version = 0;
|
||||
let provider = match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { default_model } => {
|
||||
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
client.clone(),
|
||||
settings_version,
|
||||
cx,
|
||||
))
|
||||
}
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
settings_version,
|
||||
)),
|
||||
};
|
||||
cx.set_global(provider);
|
||||
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
settings_version += 1;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
|
||||
(
|
||||
CompletionProvider::OpenAi(provider),
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
},
|
||||
) => {
|
||||
provider.update(default_model.clone(), api_url.clone(), settings_version);
|
||||
}
|
||||
(
|
||||
CompletionProvider::ZedDotDev(provider),
|
||||
AssistantProvider::ZedDotDev { default_model },
|
||||
) => {
|
||||
provider.update(default_model.clone(), settings_version);
|
||||
}
|
||||
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
|
||||
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
client.clone(),
|
||||
settings_version,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
(
|
||||
CompletionProvider::ZedDotDev(_),
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
},
|
||||
) => {
|
||||
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
settings_version,
|
||||
));
|
||||
}
|
||||
#[cfg(test)]
|
||||
(CompletionProvider::Fake(_), _) => unimplemented!(),
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
/// The completion backend currently in use by the assistant.
///
/// Each variant wraps the concrete provider that requests are forwarded to.
pub enum CompletionProvider {
|
||||
/// Completions served through an OpenAI endpoint.
OpenAi(OpenAiCompletionProvider),
|
||||
/// Completions served through the `zed.dev` provider.
ZedDotDev(ZedDotDevCompletionProvider),
|
||||
/// Fake provider, only compiled in test builds.
#[cfg(test)]
|
||||
Fake(FakeCompletionProvider),
|
||||
}
|
||||
|
||||
// Marks `CompletionProvider` as a gpui global, which is what allows it to be
// stored with `cx.set_global(..)` and read back with `cx.global::<Self>()`.
impl gpui::Global for CompletionProvider {}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn global(cx: &AppContext) -> &Self {
|
||||
cx.global::<Self>()
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
|
||||
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> LanguageModel {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
|
||||
CompletionProvider::ZedDotDev(provider) => {
|
||||
LanguageModel::ZedDotDev(provider.default_model())
|
||||
}
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.complete(request),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(provider) => provider.complete(),
|
||||
}
|
||||
}
|
||||
}
|
29
crates/assistant/src/completion_provider/fake.rs
Normal file
29
crates/assistant/src/completion_provider/fake.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use anyhow::Result;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Completion provider whose output is driven manually (it is only compiled
/// in test builds, via the `#[cfg(test)] mod fake` declaration).
///
/// Clones share the same state because the sender lives behind an `Arc`.
#[derive(Clone, Default)]
|
||||
pub struct FakeCompletionProvider {
|
||||
/// Sending half of the stream returned by the most recent `complete()`
/// call; `None` before the first call and after `finish_completion()`.
current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
*self.current_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
pub fn send_completion(&self, chunk: String) {
|
||||
self.current_completion_tx
|
||||
.lock()
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.unbounded_send(chunk)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self) {
|
||||
self.current_completion_tx.lock().take();
|
||||
}
|
||||
}
|
301
crates/assistant/src/completion_provider/open_ai.rs
Normal file
301
crates/assistant/src/completion_provider/open_ai.rs
Normal file
|
@ -0,0 +1,301 @@
|
|||
use crate::{
|
||||
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
|
||||
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
|
||||
use settings::Settings;
|
||||
use std::{env, sync::Arc};
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::{http::HttpClient, ResultExt};
|
||||
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
default_model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub fn new(
|
||||
default_model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
default_model,
|
||||
http_client,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
|
||||
self.default_model = default_model;
|
||||
self.api_url = api_url;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::OpenAi(provider) = provider {
|
||||
provider.api_key = Some(api_key);
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::OpenAi(provider) = provider {
|
||||
provider.api_key = None;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> OpenAiModel {
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::ZedDotDev(_) => self.default_model(),
|
||||
LanguageModel::OpenAi(model) => model,
|
||||
};
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| RequestMessage {
|
||||
role: msg.role.into(),
|
||||
content: msg.content,
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
background_executor: &gpui::BackgroundExecutor,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
background_executor
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
impl From<Role> for open_ai::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => OpenAiRole::User,
|
||||
Role::Assistant => OpenAiRole::Assistant,
|
||||
Role::System => OpenAiRole::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::OpenAi(provider) = provider {
|
||||
provider.api_key = Some(api_key);
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features,
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: FontWeight::NORMAL,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 6] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
|
||||
" - You can create an API key at: platform.openai.com/api-keys",
|
||||
" - Make sure your OpenAI account has credits",
|
||||
" - Having a subscription for another service like GitHub Copilot won't work.",
|
||||
"",
|
||||
"Paste your OpenAI API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::Ai).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
167
crates/assistant/src/completion_provider/zed.rs
Normal file
167
crates/assistant/src/completion_provider/zed.rs
Normal file
|
@ -0,0 +1,167 @@
|
|||
use crate::{
|
||||
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{proto, Client};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use std::{future, sync::Arc};
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct ZedDotDevCompletionProvider {
|
||||
client: Arc<Client>,
|
||||
default_model: ZedDotDevModel,
|
||||
settings_version: usize,
|
||||
status: client::Status,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
impl ZedDotDevCompletionProvider {
|
||||
pub fn new(
|
||||
default_model: ZedDotDevModel,
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Self {
|
||||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
while let Some(status) = status_rx.next().await {
|
||||
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::ZedDotDev(provider) = provider {
|
||||
provider.status = status;
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
default_model,
|
||||
settings_version,
|
||||
status,
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
|
||||
self.default_model = default_model;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> ZedDotDevModel {
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.status.is_connected()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|_cx| AuthenticationPrompt).into()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match request.model {
|
||||
crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
|
||||
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
|
||||
| crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: request.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt;
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||
|
||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("sign_in", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Filled)
|
||||
.full_width()
|
||||
.on_click(|_, cx| {
|
||||
CompletionProvider::global(cx)
|
||||
.authenticate(cx)
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to enable collaboration.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,394 +1,95 @@
|
|||
use ai::models::LanguageModel;
|
||||
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||
use ai::prompts::file_context::FileContext;
|
||||
use ai::prompts::generate::GenerateInlineContent;
|
||||
use ai::prompts::preamble::EngineerPreamble;
|
||||
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||
use ai::providers::open_ai::OpenAiLanguageModel;
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use std::cmp::{self, Reverse};
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
|
||||
#[derive(Debug)]
|
||||
struct Match {
|
||||
collapse: Range<usize>,
|
||||
keep: Vec<Range<usize>>,
|
||||
}
|
||||
|
||||
let selected_range = selected_range.to_offset(buffer);
|
||||
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
|
||||
Some(&grammar.embedding_config.as_ref()?.query)
|
||||
});
|
||||
let configs = ts_matches
|
||||
.grammars()
|
||||
.iter()
|
||||
.map(|g| g.embedding_config.as_ref().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
let mut matches = Vec::new();
|
||||
while let Some(mat) = ts_matches.peek() {
|
||||
let config = &configs[mat.grammar_index];
|
||||
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
|
||||
if Some(cap.index) == config.collapse_capture_ix {
|
||||
Some(cap.node.byte_range())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
let mut keep = Vec::new();
|
||||
for capture in mat.captures.iter() {
|
||||
if Some(capture.index) == config.keep_capture_ix {
|
||||
keep.push(capture.node.byte_range());
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
ts_matches.advance();
|
||||
matches.push(Match { collapse, keep });
|
||||
} else {
|
||||
ts_matches.advance();
|
||||
}
|
||||
}
|
||||
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
|
||||
let mut matches = matches.into_iter().peekable();
|
||||
|
||||
let mut summary = String::new();
|
||||
let mut offset = 0;
|
||||
let mut flushed_selection = false;
|
||||
while let Some(mat) = matches.next() {
|
||||
// Keep extending the collapsed range if the next match surrounds
|
||||
// the current one.
|
||||
while let Some(next_mat) = matches.peek() {
|
||||
if mat.collapse.start <= next_mat.collapse.start
|
||||
&& mat.collapse.end >= next_mat.collapse.end
|
||||
{
|
||||
matches.next().unwrap();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if offset > mat.collapse.start {
|
||||
// Skip collapsed nodes that have already been summarized.
|
||||
offset = cmp::max(offset, mat.collapse.end);
|
||||
continue;
|
||||
}
|
||||
|
||||
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
|
||||
if !flushed_selection {
|
||||
// The collapsed node ends after the selection starts, so we'll flush the selection first.
|
||||
summary.extend(buffer.text_for_range(offset..selected_range.start));
|
||||
summary.push_str("<|S|");
|
||||
if selected_range.end == selected_range.start {
|
||||
summary.push_str(">");
|
||||
} else {
|
||||
summary.extend(buffer.text_for_range(selected_range.clone()));
|
||||
summary.push_str("|E|>");
|
||||
}
|
||||
offset = selected_range.end;
|
||||
flushed_selection = true;
|
||||
}
|
||||
|
||||
// If the selection intersects the collapsed node, we won't collapse it.
|
||||
if selected_range.end >= mat.collapse.start {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
|
||||
for keep in mat.keep {
|
||||
summary.extend(buffer.text_for_range(keep));
|
||||
}
|
||||
offset = mat.collapse.end;
|
||||
}
|
||||
|
||||
// Flush selection if we haven't already done so.
|
||||
if !flushed_selection && offset <= selected_range.start {
|
||||
summary.extend(buffer.text_for_range(offset..selected_range.start));
|
||||
summary.push_str("<|S|");
|
||||
if selected_range.end == selected_range.start {
|
||||
summary.push_str(">");
|
||||
} else {
|
||||
summary.extend(buffer.text_for_range(selected_range.clone()));
|
||||
summary.push_str("|E|>");
|
||||
}
|
||||
offset = selected_range.end;
|
||||
}
|
||||
|
||||
summary.extend(buffer.text_for_range(offset..buffer.len()));
|
||||
summary
|
||||
}
|
||||
use language::BufferSnapshot;
|
||||
use std::{fmt::Write, ops::Range};
|
||||
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
search_results: Vec<PromptCodeSnippet>,
|
||||
model: &str,
|
||||
project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
// Using new Prompt Templates
|
||||
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
|
||||
let lang_name = if let Some(language_name) = language_name {
|
||||
Some(language_name.to_string())
|
||||
let mut prompt = String::new();
|
||||
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => {
|
||||
writeln!(prompt, "You are an expert engineer.")?;
|
||||
"Text"
|
||||
}
|
||||
Some(language_name) => {
|
||||
writeln!(prompt, "You are an expert {language_name} engineer.")?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)?;
|
||||
"Code"
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(project_name) = project_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
)?;
|
||||
}
|
||||
|
||||
// Include file content.
|
||||
for chunk in buffer.text_for_range(0..range.start) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if range.is_empty() {
|
||||
prompt.push_str("<|START|>");
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let args = PromptArguments {
|
||||
model: openai_model,
|
||||
language_name: lang_name.clone(),
|
||||
project_name,
|
||||
snippets: search_results.clone(),
|
||||
reserved_tokens: 1000,
|
||||
buffer: Some(buffer),
|
||||
selected_range: Some(range),
|
||||
user_prompt: Some(user_prompt.clone()),
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(RepositoryContext {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(FileContext {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Mandatory,
|
||||
Box::new(GenerateInlineContent {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
let (prompt, _) = chain.generate(true)?;
|
||||
|
||||
anyhow::Ok(prompt)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use gpui::{AppContext, Context};
|
||||
use indoc::indoc;
|
||||
use language::{
|
||||
language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
|
||||
LanguageMatcher, Point,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(crate) fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::language()),
|
||||
)
|
||||
.with_embedding_query(
|
||||
r#"
|
||||
(
|
||||
[(line_comment) (attribute_item)]* @context
|
||||
.
|
||||
[
|
||||
(struct_item
|
||||
name: (_) @name)
|
||||
|
||||
(enum_item
|
||||
name: (_) @name)
|
||||
|
||||
(impl_item
|
||||
trait: (_)? @name
|
||||
"for"? @name
|
||||
type: (_) @name)
|
||||
|
||||
(trait_item
|
||||
name: (_) @name)
|
||||
|
||||
(function_item
|
||||
name: (_) @name
|
||||
body: (block
|
||||
"{" @keep
|
||||
"}" @keep) @collapse)
|
||||
|
||||
(macro_definition
|
||||
name: (_) @name)
|
||||
] @item
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
prompt.push_str("<|START|");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_outline_for_prompt(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language_settings::init(cx);
|
||||
let text = indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let a = 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {
|
||||
self.a
|
||||
}
|
||||
|
||||
pub fn b(&self) -> usize {
|
||||
self.b
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
<|S|>a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let <|S|a |E|>= 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
<|S|>
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
<|S|>"}
|
||||
);
|
||||
|
||||
// Ensure nested functions get collapsed properly.
|
||||
let text = indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let a = 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {
|
||||
let a = 30;
|
||||
fn nested() -> usize {
|
||||
3
|
||||
}
|
||||
self.a + nested()
|
||||
}
|
||||
|
||||
pub fn b(&self) -> usize {
|
||||
self.b
|
||||
}
|
||||
}
|
||||
"};
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
|
||||
indoc! {"
|
||||
<|S|>struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if !range.is_empty() {
|
||||
prompt.push_str("|END|>");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.end..buffer.len()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
prompt.push('\n');
|
||||
|
||||
if range.is_empty() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Double check that you only return code and not the '<|START|' and '|END|'> spans"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
|
121
crates/assistant/src/saved_conversation.rs
Normal file
121
crates/assistant/src/saved_conversation.rs
Normal file
|
@ -0,0 +1,121 @@
|
|||
use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
cmp::Reverse,
|
||||
ffi::OsStr,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedMessage {
|
||||
pub id: MessageId,
|
||||
pub start: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedConversation {
|
||||
pub id: Option<String>,
|
||||
pub zed: String,
|
||||
pub version: String,
|
||||
pub text: String,
|
||||
pub messages: Vec<SavedMessage>,
|
||||
pub message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
pub summary: String,
|
||||
}
|
||||
|
||||
impl SavedConversation {
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
|
||||
pub async fn load(path: &Path, fs: &dyn Fs) -> Result<Self> {
|
||||
let saved_conversation = fs.load(path).await?;
|
||||
let saved_conversation_json =
|
||||
serde_json::from_str::<serde_json::Value>(&saved_conversation)?;
|
||||
match saved_conversation_json
|
||||
.get("version")
|
||||
.ok_or_else(|| anyhow!("version not found"))?
|
||||
{
|
||||
serde_json::Value::String(version) => match version.as_str() {
|
||||
Self::VERSION => Ok(serde_json::from_value::<Self>(saved_conversation_json)?),
|
||||
"0.1.0" => {
|
||||
let saved_conversation =
|
||||
serde_json::from_value::<SavedConversationV0_1_0>(saved_conversation_json)?;
|
||||
Ok(Self {
|
||||
id: saved_conversation.id,
|
||||
zed: saved_conversation.zed,
|
||||
version: saved_conversation.version,
|
||||
text: saved_conversation.text,
|
||||
messages: saved_conversation.messages,
|
||||
message_metadata: saved_conversation.message_metadata,
|
||||
summary: saved_conversation.summary,
|
||||
})
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"unrecognized saved conversation version: {}",
|
||||
version
|
||||
)),
|
||||
},
|
||||
_ => Err(anyhow!("version not found on saved conversation")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SavedConversationV0_1_0 {
|
||||
id: Option<String>,
|
||||
zed: String,
|
||||
version: String,
|
||||
text: String,
|
||||
messages: Vec<SavedMessage>,
|
||||
message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
summary: String,
|
||||
api_url: Option<String>,
|
||||
model: OpenAiModel,
|
||||
}
|
||||
|
||||
pub struct SavedConversationMetadata {
|
||||
pub title: String,
|
||||
pub path: PathBuf,
|
||||
pub mtime: chrono::DateTime<chrono::Local>,
|
||||
}
|
||||
|
||||
impl SavedConversationMetadata {
|
||||
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
|
||||
fs.create_dir(&CONVERSATIONS_DIR).await?;
|
||||
|
||||
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
|
||||
let mut conversations = Vec::<SavedConversationMetadata>::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pattern = r" - \d+.zed.json$";
|
||||
let re = Regex::new(pattern).unwrap();
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
let title = re.replace(file_name, "");
|
||||
conversations.push(Self {
|
||||
title: title.into_owned(),
|
||||
path,
|
||||
mtime: metadata.mtime.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
|
@ -197,12 +197,10 @@ impl StreamingDiff {
|
|||
} else {
|
||||
hunks.push(Hunk::Remove { len: char_len })
|
||||
}
|
||||
} else if let Some(Hunk::Keep { len }) = hunks.last_mut() {
|
||||
*len += char_len;
|
||||
} else {
|
||||
if let Some(Hunk::Keep { len }) = hunks.last_mut() {
|
||||
*len += char_len;
|
||||
} else {
|
||||
hunks.push(Hunk::Keep { len: char_len })
|
||||
}
|
||||
hunks.push(Hunk::Keep { len: char_len })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ use async_tungstenite::tungstenite::{
|
|||
use clock::SystemClock;
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
|
||||
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
|
||||
TryFutureExt as _, TryStreamExt,
|
||||
};
|
||||
use gpui::{
|
||||
|
@ -36,7 +36,10 @@ use std::{
|
|||
future::Future,
|
||||
marker::PhantomData,
|
||||
path::PathBuf,
|
||||
sync::{atomic::AtomicU64, Arc, Weak},
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc, Weak,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use telemetry::Telemetry;
|
||||
|
@ -442,7 +445,7 @@ impl Client {
|
|||
}
|
||||
|
||||
pub fn id(&self) -> u64 {
|
||||
self.id.load(std::sync::atomic::Ordering::SeqCst)
|
||||
self.id.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
|
||||
|
@ -450,7 +453,7 @@ impl Client {
|
|||
}
|
||||
|
||||
pub fn set_id(&self, id: u64) -> &Self {
|
||||
self.id.store(id, std::sync::atomic::Ordering::SeqCst);
|
||||
self.id.store(id, Ordering::SeqCst);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -1260,6 +1263,30 @@ impl Client {
|
|||
.map_ok(|envelope| envelope.payload)
|
||||
}
|
||||
|
||||
pub fn request_stream<T: RequestMessage>(
|
||||
&self,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
|
||||
let client_id = self.id.load(Ordering::SeqCst);
|
||||
log::debug!(
|
||||
"rpc request start. client_id:{}. name:{}",
|
||||
client_id,
|
||||
T::NAME
|
||||
);
|
||||
let response = self
|
||||
.connection_id()
|
||||
.map(|conn_id| self.peer.request_stream(conn_id, request));
|
||||
async move {
|
||||
let response = response?.await;
|
||||
log::debug!(
|
||||
"rpc request finish. client_id:{}. name:{}",
|
||||
client_id,
|
||||
T::NAME
|
||||
);
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_envelope<T: RequestMessage>(
|
||||
&self,
|
||||
request: T,
|
||||
|
|
|
@ -261,7 +261,7 @@ impl Telemetry {
|
|||
self: &Arc<Self>,
|
||||
conversation_id: Option<String>,
|
||||
kind: AssistantKind,
|
||||
model: &str,
|
||||
model: String,
|
||||
) {
|
||||
let event = Event::Assistant(AssistantEvent {
|
||||
conversation_id,
|
||||
|
|
|
@ -31,10 +31,12 @@ collections.workspace = true
|
|||
dashmap = "5.4"
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
live_kit_server.workspace = true
|
||||
log.workspace = true
|
||||
nanoid = "0.4"
|
||||
open_ai.workspace = true
|
||||
parking_lot.workspace = true
|
||||
prometheus = "0.13"
|
||||
prost.workspace = true
|
||||
|
@ -80,7 +82,6 @@ git = { workspace = true, features = ["test-support"] }
|
|||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
lazy_static.workspace = true
|
||||
live_kit_client = { workspace = true, features = ["test-support"] }
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
menu.workspace = true
|
||||
|
|
|
@ -379,6 +379,16 @@ CREATE TABLE extension_versions (
|
|||
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
|
||||
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
|
||||
|
||||
CREATE TABLE rate_buckets (
|
||||
user_id INT NOT NULL,
|
||||
rate_limit_name VARCHAR(255) NOT NULL,
|
||||
token_count INT NOT NULL,
|
||||
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
PRIMARY KEY (user_id, rate_limit_name),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
|
||||
|
||||
CREATE TABLE hosted_projects (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id),
|
||||
|
|
11
crates/collab/migrations/20240220234826_add_rate_buckets.sql
Normal file
11
crates/collab/migrations/20240220234826_add_rate_buckets.sql
Normal file
|
@ -0,0 +1,11 @@
|
|||
-- Per-user rate limit state: one row per (user, rate limit name).
CREATE TABLE IF NOT EXISTS rate_buckets (
    user_id INT NOT NULL,
    rate_limit_name VARCHAR(255) NOT NULL,
    token_count INT NOT NULL,
    last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
    PRIMARY KEY (user_id, rate_limit_name),
    CONSTRAINT fk_user
        FOREIGN KEY (user_id) REFERENCES users(id)
);

-- Guarded like the table above so the script as a whole stays re-runnable.
CREATE INDEX IF NOT EXISTS idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
|
75
crates/collab/src/ai.rs
Normal file
75
crates/collab/src/ai.rs
Normal file
|
@ -0,0 +1,75 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use rpc::proto;
|
||||
|
||||
pub fn language_model_request_to_open_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
) -> Result<open_ai::Request> {
|
||||
Ok(open_ai::Request {
|
||||
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
let role = proto::LanguageModelRole::from_i32(message.role)
|
||||
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
|
||||
Ok(open_ai::RequestMessage {
|
||||
role: match role {
|
||||
proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
|
||||
proto::LanguageModelRole::LanguageModelAssistant => {
|
||||
open_ai::Role::Assistant
|
||||
}
|
||||
proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
|
||||
},
|
||||
content: message.content,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_model_request_to_google_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
) -> Result<google_ai::GenerateContentRequest> {
|
||||
Ok(google_ai::GenerateContentRequest {
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(language_model_request_message_to_google_ai)
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
generation_config: None,
|
||||
safety_settings: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_model_request_message_to_google_ai(
|
||||
message: proto::LanguageModelRequestMessage,
|
||||
) -> Result<google_ai::Content> {
|
||||
let role = proto::LanguageModelRole::from_i32(message.role)
|
||||
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
|
||||
|
||||
Ok(google_ai::Content {
|
||||
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
||||
text: message.content,
|
||||
})],
|
||||
role: match role {
|
||||
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
|
||||
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
|
||||
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn count_tokens_request_to_google_ai(
|
||||
request: proto::CountTokensWithLanguageModel,
|
||||
) -> Result<google_ai::CountTokensRequest> {
|
||||
Ok(google_ai::CountTokensRequest {
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(language_model_request_message_to_google_ai)
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
})
|
||||
}
|
|
@ -1,6 +1,5 @@
|
|||
use crate::{
|
||||
db::{ExtensionMetadata, NewExtensionVersion},
|
||||
executor::Executor,
|
||||
AppState, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
|
@ -136,7 +135,7 @@ async fn download_extension(
|
|||
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
|
||||
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
|
||||
|
||||
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
|
||||
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
|
||||
let Some(blob_store_client) = app_state.blob_store_client.clone() else {
|
||||
log::info!("no blob store client");
|
||||
return;
|
||||
|
@ -146,6 +145,7 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, e
|
|||
return;
|
||||
};
|
||||
|
||||
let executor = app_state.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
|
|
|
@ -10,6 +10,7 @@ pub mod hosted_projects;
|
|||
pub mod messages;
|
||||
pub mod notifications;
|
||||
pub mod projects;
|
||||
pub mod rate_buckets;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
pub mod users;
|
||||
|
|
58
crates/collab/src/db/queries/rate_buckets.rs
Normal file
58
crates/collab/src/db/queries/rate_buckets.rs
Normal file
|
@ -0,0 +1,58 @@
|
|||
use super::*;
|
||||
use crate::db::tables::rate_buckets;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
impl Database {
|
||||
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
|
||||
/// than the currently saved timestamp.
|
||||
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
|
||||
if buckets.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
|
||||
rate_buckets::ActiveModel {
|
||||
user_id: ActiveValue::Set(bucket.user_id),
|
||||
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
|
||||
token_count: ActiveValue::Set(bucket.token_count),
|
||||
last_refill: ActiveValue::Set(bucket.last_refill),
|
||||
}
|
||||
}))
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
rate_buckets::Column::UserId,
|
||||
rate_buckets::Column::RateLimitName,
|
||||
])
|
||||
.update_columns([
|
||||
rate_buckets::Column::TokenCount,
|
||||
rate_buckets::Column::LastRefill,
|
||||
])
|
||||
.to_owned(),
|
||||
)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieves the rate limit for the given user and rate limit name.
|
||||
pub async fn get_rate_bucket(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
rate_limit_name: &str,
|
||||
) -> Result<Option<rate_buckets::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
let rate_limit = rate_buckets::Entity::find()
|
||||
.filter(rate_buckets::Column::UserId.eq(user_id))
|
||||
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(rate_limit)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -22,6 +22,7 @@ pub mod observed_buffer_edits;
|
|||
pub mod observed_channel_messages;
|
||||
pub mod project;
|
||||
pub mod project_collaborator;
|
||||
pub mod rate_buckets;
|
||||
pub mod room;
|
||||
pub mod room_participant;
|
||||
pub mod server;
|
||||
|
|
31
crates/collab/src/db/tables/rate_buckets.rs
Normal file
31
crates/collab/src/db/tables/rate_buckets.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use crate::db::UserId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "rate_buckets")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub user_id: UserId,
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub rate_limit_name: String,
|
||||
pub token_count: i32,
|
||||
pub last_refill: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,8 +1,10 @@
|
|||
pub mod ai;
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
pub mod db;
|
||||
pub mod env;
|
||||
pub mod executor;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -13,6 +15,7 @@ use aws_config::{BehaviorVersion, Region};
|
|||
use axum::{http::StatusCode, response::IntoResponse};
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
pub use rate_limiter::*;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
@ -126,6 +129,8 @@ pub struct Config {
|
|||
pub blob_store_secret_key: Option<String>,
|
||||
pub blob_store_bucket: Option<String>,
|
||||
pub zed_environment: Arc<str>,
|
||||
pub openai_api_key: Option<Arc<str>>,
|
||||
pub google_ai_api_key: Option<Arc<str>>,
|
||||
pub zed_client_checksum_seed: Option<String>,
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
|
@ -147,12 +152,14 @@ pub struct AppState {
|
|||
pub db: Arc<Database>,
|
||||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub clickhouse_client: Option<clickhouse::Client>,
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new(config: Config) -> Result<Arc<Self>> {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||
db_options.max_connections(config.database_max_connections);
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
|
@ -173,10 +180,13 @@ impl AppState {
|
|||
None
|
||||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let this = Self {
|
||||
db: Arc::new(db),
|
||||
db: db.clone(),
|
||||
live_kit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
clickhouse_client: config
|
||||
.clickhouse_url
|
||||
.as_ref()
|
||||
|
|
|
@ -7,7 +7,7 @@ use axum::{
|
|||
};
|
||||
use collab::{
|
||||
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
|
||||
Config, MigrateConfig, Result,
|
||||
Config, MigrateConfig, RateLimiter, Result,
|
||||
};
|
||||
use db::Database;
|
||||
use std::{
|
||||
|
@ -62,18 +62,27 @@ async fn main() -> Result<()> {
|
|||
|
||||
run_migrations().await?;
|
||||
|
||||
let state = AppState::new(config).await?;
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
|
||||
.expect("failed to bind TCP listener");
|
||||
|
||||
let epoch = state
|
||||
.db
|
||||
.create_server(&state.config.zed_environment)
|
||||
.await?;
|
||||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||
rpc_server.start().await?;
|
||||
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
|
||||
|
||||
let rpc_server = if is_collab {
|
||||
let epoch = state
|
||||
.db
|
||||
.create_server(&state.config.zed_environment)
|
||||
.await?;
|
||||
let rpc_server =
|
||||
collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
|
||||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||
rpc_server.start().await?;
|
||||
|
||||
Some(rpc_server)
|
||||
|
@ -82,7 +91,7 @@ async fn main() -> Result<()> {
|
|||
};
|
||||
|
||||
if is_api {
|
||||
fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
}
|
||||
|
||||
let mut app = collab::api::routes(rpc_server.clone(), state.clone());
|
||||
|
|
274
crates/collab/src/rate_limiter.rs
Normal file
274
crates/collab/src/rate_limiter.rs
Normal file
|
@ -0,0 +1,274 @@
|
|||
use crate::{db::UserId, executor::Executor, Database, Error, Result};
|
||||
use anyhow::anyhow;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use sea_orm::prelude::DateTimeUtc;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub trait RateLimit: 'static {
|
||||
fn capacity() -> usize;
|
||||
fn refill_duration() -> Duration;
|
||||
fn db_name() -> &'static str;
|
||||
}
|
||||
|
||||
/// Used to enforce per-user rate limits
|
||||
pub struct RateLimiter {
|
||||
buckets: DashMap<(UserId, String), RateBucket>,
|
||||
dirty_buckets: DashSet<(UserId, String)>,
|
||||
db: Arc<Database>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(db: Arc<Database>) -> Self {
|
||||
RateLimiter {
|
||||
buckets: DashMap::new(),
|
||||
dirty_buckets: DashSet::new(),
|
||||
db,
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a new task that periodically saves rate limit data to the database.
|
||||
pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
|
||||
const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
executor.clone().spawn_detached(async move {
|
||||
loop {
|
||||
executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
|
||||
rate_limiter.save().await.log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns an error if the user has exceeded the specified `RateLimit`.
|
||||
/// Attempts to read the from the database if no cached RateBucket currently exists.
|
||||
pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
|
||||
self.check_internal::<T>(user_id, Utc::now()).await
|
||||
}
|
||||
|
||||
async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
|
||||
let bucket_key = (user_id, T::db_name().to_string());
|
||||
|
||||
// Attempt to fetch the bucket from the database if it hasn't been cached.
|
||||
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
|
||||
// but this enforces limits across restarts so long as the database is reachable.
|
||||
if !self.buckets.contains_key(&bucket_key) {
|
||||
if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
|
||||
self.buckets.insert(bucket_key.clone(), bucket);
|
||||
self.dirty_buckets.insert(bucket_key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut bucket = self
|
||||
.buckets
|
||||
.entry(bucket_key.clone())
|
||||
.or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
|
||||
|
||||
if bucket.value_mut().allow(now) {
|
||||
self.dirty_buckets.insert(bucket_key);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("rate limit exceeded"))?
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_bucket<K: RateLimit>(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<RateBucket>, Error> {
|
||||
Ok(self
|
||||
.db
|
||||
.get_rate_bucket(user_id, K::db_name())
|
||||
.await?
|
||||
.map(|saved_bucket| RateBucket {
|
||||
capacity: K::capacity(),
|
||||
refill_time_per_token: K::refill_duration(),
|
||||
token_count: saved_bucket.token_count as usize,
|
||||
last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn save(&self) -> Result<()> {
|
||||
let mut buckets = Vec::new();
|
||||
self.dirty_buckets.retain(|key| {
|
||||
if let Some(bucket) = self.buckets.get(&key) {
|
||||
buckets.push(crate::db::rate_buckets::Model {
|
||||
user_id: key.0,
|
||||
rate_limit_name: key.1.clone(),
|
||||
token_count: bucket.token_count as i32,
|
||||
last_refill: bucket.last_refill.naive_utc(),
|
||||
});
|
||||
}
|
||||
false
|
||||
});
|
||||
|
||||
match self.db.save_rate_buckets(&buckets).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) => {
|
||||
for bucket in buckets {
|
||||
self.dirty_buckets
|
||||
.insert((bucket.user_id, bucket.rate_limit_name));
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct RateBucket {
|
||||
capacity: usize,
|
||||
token_count: usize,
|
||||
refill_time_per_token: Duration,
|
||||
last_refill: DateTimeUtc,
|
||||
}
|
||||
|
||||
impl RateBucket {
|
||||
fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
|
||||
RateBucket {
|
||||
capacity,
|
||||
token_count: capacity,
|
||||
refill_time_per_token: refill_duration / capacity as i32,
|
||||
last_refill: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn allow(&mut self, now: DateTimeUtc) -> bool {
|
||||
self.refill(now);
|
||||
if self.token_count > 0 {
|
||||
self.token_count -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self, now: DateTimeUtc) {
|
||||
let elapsed = now - self.last_refill;
|
||||
if elapsed >= self.refill_time_per_token {
|
||||
let new_tokens =
|
||||
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
|
||||
|
||||
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
|
||||
self.last_refill = now;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{NewUserParams, TestDb};
    use gpui::TestAppContext;

    /// Test limit: two tokens, refilled over two seconds.
    struct RateLimitA;

    impl RateLimit for RateLimitA {
        fn capacity() -> usize {
            2
        }

        fn refill_duration() -> Duration {
            Duration::seconds(2)
        }

        fn db_name() -> &'static str {
            "rate-limit-a"
        }
    }

    /// Test limit: ten tokens, refilled over three seconds.
    struct RateLimitB;

    impl RateLimit for RateLimitB {
        fn capacity() -> usize {
            10
        }

        fn refill_duration() -> Duration {
            Duration::seconds(3)
        }

        fn db_name() -> &'static str {
            "rate-limit-b"
        }
    }

    #[gpui::test]
    async fn test_rate_limiter(cx: &mut TestAppContext) {
        let test_db = TestDb::sqlite(cx.executor().clone());
        let db = test_db.db().clone();

        let first_user_params = NewUserParams {
            github_login: "user-1".into(),
            github_user_id: 1,
        };
        let user_1 = db
            .create_user("user-1@zed.dev", false, first_user_params)
            .await
            .unwrap()
            .user_id;

        let second_user_params = NewUserParams {
            github_login: "user-2".into(),
            github_user_id: 2,
        };
        let user_2 = db
            .create_user("user-2@zed.dev", false, second_user_params)
            .await
            .unwrap()
            .user_id;

        let mut now = Utc::now();
        let rate_limiter = RateLimiter::new(db.clone());

        // User 1 can access resource A two times before being rate-limited.
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        // Buckets are independent per user and per limit: both users can still
        // access resource B.
        rate_limiter
            .check_internal::<RateLimitB>(user_2, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitB>(user_1, now)
            .await
            .unwrap();

        // After one second, user 1 can make another request before being rate-limited again.
        now += Duration::seconds(1);
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap();
        rate_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();

        rate_limiter.save().await.unwrap();

        // Rate limits are reloaded from the database, so user 1 is still rate-limited
        // for resource A when a fresh limiter is created.
        let reloaded_limiter = RateLimiter::new(db.clone());
        reloaded_limiter
            .check_internal::<RateLimitA>(user_1, now)
            .await
            .unwrap_err();
    }
}
|
|
@ -9,9 +9,9 @@ use crate::{
|
|||
User, UserId,
|
||||
},
|
||||
executor::Executor,
|
||||
AppState, Error, Result,
|
||||
AppState, Error, RateLimit, RateLimiter, Result,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use async_tungstenite::tungstenite::{
|
||||
protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
|
||||
};
|
||||
|
@ -30,6 +30,8 @@ use axum::{
|
|||
};
|
||||
use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
|
||||
use futures::{
|
||||
channel::oneshot,
|
||||
future::{self, BoxFuture},
|
||||
|
@ -39,15 +41,14 @@ use futures::{
|
|||
use prometheus::{register_int_gauge, IntGauge};
|
||||
use rpc::{
|
||||
proto::{
|
||||
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
|
||||
RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
|
||||
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
|
||||
LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
|
||||
},
|
||||
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
|
||||
};
|
||||
use serde::{Serialize, Serializer};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
fmt,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
|
@ -64,7 +65,7 @@ use time::OffsetDateTime;
|
|||
use tokio::sync::{watch, Semaphore};
|
||||
use tower::ServiceBuilder;
|
||||
use tracing::{field, info_span, instrument, Instrument};
|
||||
use util::SemanticVersion;
|
||||
use util::{http::IsahcHttpClient, SemanticVersion};
|
||||
|
||||
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
|
@ -92,6 +93,18 @@ impl<R: RequestMessage> Response<R> {
|
|||
}
|
||||
}
|
||||
|
||||
struct StreamingResponse<R: RequestMessage> {
|
||||
peer: Arc<Peer>,
|
||||
receipt: Receipt<R>,
|
||||
}
|
||||
|
||||
impl<R: RequestMessage> StreamingResponse<R> {
    /// Sends one response message for the request identified by this handle's receipt.
    ///
    /// May be called any number of times; it does not consume the handle.
    fn send(&self, payload: R::Response) -> Result<()> {
        match self.peer.respond(self.receipt, payload) {
            Ok(_) => Ok(()),
            Err(error) => Err(error.into()),
        }
    }
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Session {
|
||||
user_id: UserId,
|
||||
|
@ -100,6 +113,8 @@ struct Session {
|
|||
peer: Arc<Peer>,
|
||||
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
http_client: IsahcHttpClient,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
_executor: Executor,
|
||||
}
|
||||
|
||||
|
@ -124,8 +139,8 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Session {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
impl Debug for Session {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Session")
|
||||
.field("user_id", &self.user_id)
|
||||
.field("connection_id", &self.connection_id)
|
||||
|
@ -148,7 +163,6 @@ pub struct Server {
|
|||
peer: Arc<Peer>,
|
||||
pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||
app_state: Arc<AppState>,
|
||||
executor: Executor,
|
||||
handlers: HashMap<TypeId, MessageHandler>,
|
||||
teardown: watch::Sender<bool>,
|
||||
}
|
||||
|
@ -175,12 +189,11 @@ where
|
|||
}
|
||||
|
||||
impl Server {
|
||||
pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
|
||||
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
|
||||
let mut server = Self {
|
||||
id: parking_lot::Mutex::new(id),
|
||||
peer: Peer::new(id.0 as u32),
|
||||
app_state,
|
||||
executor,
|
||||
app_state: app_state.clone(),
|
||||
connection_pool: Default::default(),
|
||||
handlers: Default::default(),
|
||||
teardown: watch::channel(false).0,
|
||||
|
@ -280,7 +293,30 @@ impl Server {
|
|||
.add_message_handler(update_followers)
|
||||
.add_request_handler(get_private_user_info)
|
||||
.add_message_handler(acknowledge_channel_message)
|
||||
.add_message_handler(acknowledge_buffer_version);
|
||||
.add_message_handler(acknowledge_buffer_version)
|
||||
.add_streaming_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
complete_with_language_model(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
app_state.config.openai_api_key.clone(),
|
||||
app_state.config.google_ai_api_key.clone(),
|
||||
)
|
||||
}
|
||||
})
|
||||
.add_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
count_tokens_with_language_model(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
app_state.config.google_ai_api_key.clone(),
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
Arc::new(server)
|
||||
}
|
||||
|
@ -289,12 +325,12 @@ impl Server {
|
|||
let server_id = *self.id.lock();
|
||||
let app_state = self.app_state.clone();
|
||||
let peer = self.peer.clone();
|
||||
let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
|
||||
let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
|
||||
let pool = self.connection_pool.clone();
|
||||
let live_kit_client = self.app_state.live_kit_client.clone();
|
||||
|
||||
let span = info_span!("start server");
|
||||
self.executor.spawn_detached(
|
||||
self.app_state.executor.spawn_detached(
|
||||
async move {
|
||||
tracing::info!("waiting for cleanup timeout");
|
||||
timeout.await;
|
||||
|
@ -536,6 +572,40 @@ impl Server {
|
|||
})
|
||||
}
|
||||
|
||||
fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
|
||||
Fut: Send + Future<Output = Result<()>>,
|
||||
M: RequestMessage,
|
||||
{
|
||||
let handler = Arc::new(handler);
|
||||
self.add_handler(move |envelope, session| {
|
||||
let receipt = envelope.receipt();
|
||||
let handler = handler.clone();
|
||||
async move {
|
||||
let peer = session.peer.clone();
|
||||
let response = StreamingResponse {
|
||||
peer: peer.clone(),
|
||||
receipt,
|
||||
};
|
||||
match (handler)(envelope.payload, response, session).await {
|
||||
Ok(()) => {
|
||||
peer.end_stream(receipt)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let proto_err = match &error {
|
||||
Error::Internal(err) => err.to_proto(),
|
||||
_ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
|
||||
};
|
||||
peer.respond_with_error(receipt, proto_err)?;
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn handle_connection(
|
||||
self: &Arc<Self>,
|
||||
|
@ -569,6 +639,14 @@ impl Server {
|
|||
tracing::Span::current().record("connection_id", format!("{}", connection_id));
|
||||
tracing::info!("connection opened");
|
||||
|
||||
let http_client = match IsahcHttpClient::new() {
|
||||
Ok(http_client) => http_client,
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to create HTTP client");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let session = Session {
|
||||
user_id,
|
||||
connection_id,
|
||||
|
@ -576,7 +654,9 @@ impl Server {
|
|||
peer: this.peer.clone(),
|
||||
connection_pool: this.connection_pool.clone(),
|
||||
live_kit_client: this.app_state.live_kit_client.clone(),
|
||||
_executor: executor.clone()
|
||||
http_client,
|
||||
rate_limiter: this.app_state.rate_limiter.clone(),
|
||||
_executor: executor.clone(),
|
||||
};
|
||||
|
||||
if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
|
||||
|
@ -3220,6 +3300,207 @@ async fn acknowledge_buffer_version(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct CompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for CompleteWithLanguageModelRateLimit {
|
||||
fn capacity() -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(120) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
async fn complete_with_language_model(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
open_ai_api_key: Option<Arc<str>>,
|
||||
google_ai_api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id)
|
||||
.await?;
|
||||
|
||||
if request.model.starts_with("gpt") {
|
||||
let api_key =
|
||||
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
|
||||
complete_with_open_ai(request, response, session, api_key).await?;
|
||||
} else if request.model.starts_with("gemini") {
|
||||
let api_key = google_ai_api_key
|
||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
||||
complete_with_google_ai(request, response, session, api_key).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn complete_with_open_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
api_key: Arc<str>,
|
||||
) -> Result<()> {
|
||||
const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
let mut completion_stream = open_ai::stream_completion(
|
||||
&session.http_client,
|
||||
OPEN_AI_API_URL,
|
||||
&api_key,
|
||||
crate::ai::language_model_request_to_open_ai(request)?,
|
||||
)
|
||||
.await
|
||||
.context("open_ai::stream_completion request failed")?;
|
||||
|
||||
while let Some(event) = completion_stream.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: event
|
||||
.choices
|
||||
.into_iter()
|
||||
.map(|choice| proto::LanguageModelChoiceDelta {
|
||||
index: choice.index,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: choice.delta.role.map(|role| match role {
|
||||
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
|
||||
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
|
||||
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
|
||||
} as i32),
|
||||
content: choice.delta.content,
|
||||
}),
|
||||
finish_reason: choice.finish_reason,
|
||||
})
|
||||
.collect(),
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn complete_with_google_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
api_key: Arc<str>,
|
||||
) -> Result<()> {
|
||||
let mut stream = google_ai::stream_generate_content(
|
||||
&session.http_client,
|
||||
google_ai::API_URL,
|
||||
api_key.as_ref(),
|
||||
crate::ai::language_model_request_to_google_ai(request)?,
|
||||
)
|
||||
.await
|
||||
.context("google_ai::stream_generate_content request failed")?;
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: event
|
||||
.candidates
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|candidate| proto::LanguageModelChoiceDelta {
|
||||
index: candidate.index as u32,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: Some(match candidate.content.role {
|
||||
google_ai::Role::User => LanguageModelRole::LanguageModelUser,
|
||||
google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
|
||||
} as i32),
|
||||
content: Some(
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.filter_map(|part| match part {
|
||||
google_ai::Part::TextPart(part) => Some(part.text),
|
||||
google_ai::Part::InlineDataPart(_) => None,
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
}),
|
||||
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
|
||||
})
|
||||
.collect(),
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CountTokensWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
||||
fn capacity() -> usize {
|
||||
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(600) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"count-tokens-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
async fn count_tokens_with_language_model(
|
||||
request: proto::CountTokensWithLanguageModel,
|
||||
response: Response<proto::CountTokensWithLanguageModel>,
|
||||
session: Session,
|
||||
google_ai_api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
if !request.model.starts_with("gemini") {
|
||||
return Err(anyhow!(
|
||||
"counting tokens for model: {:?} is not supported",
|
||||
request.model
|
||||
))?;
|
||||
}
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
|
||||
.await?;
|
||||
|
||||
let api_key = google_ai_api_key
|
||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
||||
let tokens_response = google_ai::count_tokens(
|
||||
&session.http_client,
|
||||
google_ai::API_URL,
|
||||
&api_key,
|
||||
crate::ai::count_tokens_request_to_google_ai(request)?,
|
||||
)
|
||||
.await?;
|
||||
response.send(proto::CountTokensResponse {
|
||||
token_count: tokens_response.total_tokens as u32,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
|
||||
let db = session.db().await;
|
||||
let flags = db.get_user_flags(session.user_id).await?;
|
||||
if flags.iter().any(|flag| flag == "language-models") {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
}
|
||||
|
||||
/// Start receiving chat updates for a channel
|
||||
async fn join_channel_chat(
|
||||
request: proto::JoinChannelChat,
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
db::{tests::TestDb, NewUserParams, UserId},
|
||||
executor::Executor,
|
||||
rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
|
||||
AppState, Config,
|
||||
AppState, Config, RateLimiter,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use call::ActiveCall;
|
||||
|
@ -93,17 +93,14 @@ impl TestServer {
|
|||
deterministic.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
|
||||
let executor = Executor::Deterministic(deterministic.clone());
|
||||
let app_state = Self::build_app_state(&test_db, &live_kit_server, executor.clone()).await;
|
||||
let epoch = app_state
|
||||
.db
|
||||
.create_server(&app_state.config.zed_environment)
|
||||
.await
|
||||
.unwrap();
|
||||
let server = Server::new(
|
||||
epoch,
|
||||
app_state.clone(),
|
||||
Executor::Deterministic(deterministic.clone()),
|
||||
);
|
||||
let server = Server::new(epoch, app_state.clone());
|
||||
server.start().await.unwrap();
|
||||
// Advance clock to ensure the server's cleanup task is finished.
|
||||
deterministic.advance_clock(CLEANUP_TIMEOUT);
|
||||
|
@ -482,12 +479,15 @@ impl TestServer {
|
|||
|
||||
pub async fn build_app_state(
|
||||
test_db: &TestDb,
|
||||
fake_server: &live_kit_client::TestServer,
|
||||
live_kit_test_server: &live_kit_client::TestServer,
|
||||
executor: Executor,
|
||||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
|
||||
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
|
||||
executor,
|
||||
clickhouse_client: None,
|
||||
config: Config {
|
||||
http_port: 0,
|
||||
|
@ -506,6 +506,8 @@ impl TestServer {
|
|||
blob_store_access_key: None,
|
||||
blob_store_secret_key: None,
|
||||
blob_store_bucket: None,
|
||||
openai_api_key: None,
|
||||
google_ai_api_key: None,
|
||||
clickhouse_url: None,
|
||||
clickhouse_user: None,
|
||||
clickhouse_password: None,
|
||||
|
|
14
crates/google_ai/Cargo.toml
Normal file
14
crates/google_ai/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "google_ai"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/google_ai.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
util.workspace = true
|
266
crates/google_ai/src/google_ai.rs
Normal file
266
crates/google_ai/src/google_ai.rs
Normal file
|
@ -0,0 +1,266 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::http::HttpClient;
|
||||
|
||||
/// Base URL of the Google Generative Language API, without a trailing slash.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
pub async fn stream_generate_content<T: HttpClient>(
|
||||
client: &T,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: GenerateContentRequest,
|
||||
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
|
||||
let uri = format!(
|
||||
"{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
|
||||
api_url, api_key
|
||||
);
|
||||
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut response = client.post_json(&uri, request.into()).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
if let Some(line) = line.strip_prefix("data: ") {
|
||||
match serde_json::from_str(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut text = String::new();
|
||||
response.body_mut().read_to_string(&mut text).await?;
|
||||
Err(anyhow!(
|
||||
"error during streamGenerateContent, status code: {:?}, body: {}",
|
||||
response.status(),
|
||||
text
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn count_tokens<T: HttpClient>(
|
||||
client: &T,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: CountTokensRequest,
|
||||
) -> Result<CountTokensResponse> {
|
||||
let uri = format!(
|
||||
"{}/v1beta/models/gemini-pro:countTokens?key={}",
|
||||
api_url, api_key
|
||||
);
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut response = client.post_json(&uri, request.into()).await?;
|
||||
let mut text = String::new();
|
||||
response.body_mut().read_to_string(&mut text).await?;
|
||||
if response.status().is_success() {
|
||||
Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"error during countTokens, status code: {:?}, body: {}",
|
||||
response.status(),
|
||||
text
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum Task {
|
||||
#[serde(rename = "generateContent")]
|
||||
GenerateContent,
|
||||
#[serde(rename = "streamGenerateContent")]
|
||||
StreamGenerateContent,
|
||||
#[serde(rename = "countTokens")]
|
||||
CountTokens,
|
||||
#[serde(rename = "embedContent")]
|
||||
EmbedContent,
|
||||
#[serde(rename = "batchEmbedContents")]
|
||||
BatchEmbedContents,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentResponse {
|
||||
pub candidates: Option<Vec<GenerateContentCandidate>>,
|
||||
pub prompt_feedback: Option<PromptFeedback>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentCandidate {
|
||||
pub index: usize,
|
||||
pub content: Content,
|
||||
pub finish_reason: Option<String>,
|
||||
pub finish_message: Option<String>,
|
||||
pub safety_ratings: Option<Vec<SafetyRating>>,
|
||||
pub citation_metadata: Option<CitationMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Content {
|
||||
pub parts: Vec<Part>,
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Model,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Part {
|
||||
TextPart(TextPart),
|
||||
InlineDataPart(InlineDataPart),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TextPart {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InlineDataPart {
|
||||
pub inline_data: GenerativeContentBlob,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerativeContentBlob {
|
||||
pub mime_type: String,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CitationSource {
|
||||
pub start_index: Option<usize>,
|
||||
pub end_index: Option<usize>,
|
||||
pub uri: Option<String>,
|
||||
pub license: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CitationMetadata {
|
||||
pub citation_sources: Vec<CitationSource>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptFeedback {
|
||||
pub block_reason: Option<String>,
|
||||
pub safety_ratings: Vec<SafetyRating>,
|
||||
pub block_reason_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub candidate_count: Option<usize>,
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub max_output_tokens: Option<usize>,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
pub top_k: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetySetting {
|
||||
pub category: HarmCategory,
|
||||
pub threshold: HarmBlockThreshold,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum HarmCategory {
|
||||
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
#[serde(rename = "HARM_CATEGORY_DEROGATORY")]
|
||||
Derogatory,
|
||||
#[serde(rename = "HARM_CATEGORY_TOXICITY")]
|
||||
Toxicity,
|
||||
#[serde(rename = "HARM_CATEGORY_VIOLENCE")]
|
||||
Violence,
|
||||
#[serde(rename = "HARM_CATEGORY_SEXUAL")]
|
||||
Sexual,
|
||||
#[serde(rename = "HARM_CATEGORY_MEDICAL")]
|
||||
Medical,
|
||||
#[serde(rename = "HARM_CATEGORY_DANGEROUS")]
|
||||
Dangerous,
|
||||
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
|
||||
Harassment,
|
||||
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
|
||||
HateSpeech,
|
||||
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
|
||||
SexuallyExplicit,
|
||||
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
|
||||
DangerousContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub enum HarmBlockThreshold {
|
||||
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
|
||||
BlockLowAndAbove,
|
||||
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
|
||||
BlockMediumAndAbove,
|
||||
#[serde(rename = "BLOCK_ONLY_HIGH")]
|
||||
BlockOnlyHigh,
|
||||
#[serde(rename = "BLOCK_NONE")]
|
||||
BlockNone,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum HarmProbability {
|
||||
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
Negligible,
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetyRating {
|
||||
pub category: HarmCategory,
|
||||
pub probability: HarmProbability,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Vec<Content>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: usize,
|
||||
}
|
19
crates/open_ai/Cargo.toml
Normal file
19
crates/open_ai/Cargo.toml
Normal file
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "open_ai"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/open_ai.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
util.workspace = true
|
182
crates/open_ai/src/open_ai.rs
Normal file
182
crates/open_ai/src/open_ai.rs
Normal file
|
@ -0,0 +1,182 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryFrom;
|
||||
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
"system" => Ok(Self::System),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
Role::System => "system".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
|
||||
ThreePointFiveTurbo,
|
||||
#[serde(rename = "gpt-4", alias = "gpt-4-0613")]
|
||||
Four,
|
||||
#[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
|
||||
#[default]
|
||||
FourTurbo,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
match id {
|
||||
"gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
|
||||
"gpt-4" => Ok(Self::Four),
|
||||
"gpt-4-turbo-preview" => Ok(Self::FourTurbo),
|
||||
_ => Err(anyhow!("invalid model id")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::Four => "gpt-4",
|
||||
Self::FourTurbo => "gpt-4-turbo-preview",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::Four => "gpt-4",
|
||||
Self::FourTurbo => "gpt-4-turbo",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Request {
|
||||
pub model: Model,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ResponseStreamEvent {
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChoiceDelta>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
let line = line.strip_prefix("data: ")?;
|
||||
if line == "[DONE]" {
|
||||
None
|
||||
} else {
|
||||
match serde_json::from_str(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiResponse {
|
||||
error: OpenAiError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAiResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
syntax = "proto3";
|
||||
package zed.messages;
|
||||
|
||||
// Looking for a number? Search "// Current max"
|
||||
// Looking for a number? Search "// current max"
|
||||
|
||||
message PeerId {
|
||||
uint32 owner_id = 1;
|
||||
|
@ -26,6 +26,7 @@ message Envelope {
|
|||
Error error = 6;
|
||||
Ping ping = 7;
|
||||
Test test = 8;
|
||||
EndStream end_stream = 165;
|
||||
|
||||
CreateRoom create_room = 9;
|
||||
CreateRoomResponse create_room_response = 10;
|
||||
|
@ -198,6 +199,11 @@ message Envelope {
|
|||
GetImplementationResponse get_implementation_response = 163;
|
||||
|
||||
JoinHostedProject join_hosted_project = 164;
|
||||
|
||||
CompleteWithLanguageModel complete_with_language_model = 166;
|
||||
LanguageModelResponse language_model_response = 167;
|
||||
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
|
||||
CountTokensResponse count_tokens_response = 169; // current max
|
||||
}
|
||||
|
||||
reserved 158 to 161;
|
||||
|
@ -236,6 +242,8 @@ enum ErrorCode {
|
|||
reserved 6;
|
||||
}
|
||||
|
||||
// Empty marker sent in reply to a streaming request to signal that no more
// responses will follow.
message EndStream {}
|
||||
|
||||
// Message carrying a single 64-bit id.
message Test {
    uint64 id = 1;
}
|
||||
|
@ -1718,3 +1726,45 @@ message SetRoomParticipantRole {
|
|||
uint64 user_id = 2;
|
||||
ChannelRole role = 3;
|
||||
}
|
||||
|
||||
// Request to stream a completion from a hosted language model.
// The server selects the provider from the prefix of `model`.
message CompleteWithLanguageModel {
    string model = 1;
    repeated LanguageModelRequestMessage messages = 2;
    repeated string stop = 3;
    float temperature = 4;
}
|
||||
|
||||
// One message of a language-model request: its author role and its text.
message LanguageModelRequestMessage {
    LanguageModelRole role = 1;
    string content = 2;
}
|
||||
|
||||
// Author of a language-model message.
enum LanguageModelRole {
    LanguageModelUser = 0;
    LanguageModelAssistant = 1;
    LanguageModelSystem = 2;
}
|
||||
|
||||
// Message delta inside a streamed choice; either field may be unset.
message LanguageModelResponseMessage {
    optional LanguageModelRole role = 1;
    optional string content = 2;
}
|
||||
|
||||
// One streamed response to CompleteWithLanguageModel.
message LanguageModelResponse {
    repeated LanguageModelChoiceDelta choices = 1;
}
|
||||
|
||||
// One choice of a LanguageModelResponse: its index, the message delta and an
// optional finish reason.
message LanguageModelChoiceDelta {
    uint32 index = 1;
    LanguageModelResponseMessage delta = 2;
    optional string finish_reason = 3;
}
|
||||
|
||||
// Request to count the tokens of the given messages for `model`.
message CountTokensWithLanguageModel {
    string model = 1;
    repeated LanguageModelRequestMessage messages = 2;
}
|
||||
|
||||
// Response to CountTokensWithLanguageModel.
message CountTokensResponse {
    uint32 token_count = 1;
}
|
||||
|
|
|
@ -80,7 +80,7 @@ pub trait ErrorExt {
|
|||
fn error_tag(&self, k: &str) -> Option<&str>;
|
||||
/// to_proto() converts the error into a proto::Error
|
||||
fn to_proto(&self) -> proto::Error;
|
||||
///
|
||||
/// Clones the error and turns into an [anyhow::Error].
|
||||
fn cloned(&self) -> anyhow::Error;
|
||||
}
|
||||
|
||||
|
|
|
@ -9,19 +9,21 @@ use collections::HashMap;
|
|||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
stream::BoxStream,
|
||||
FutureExt, SinkExt, StreamExt, TryFutureExt,
|
||||
FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
|
||||
};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use serde::{ser::SerializeStruct, Serialize};
|
||||
use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
|
||||
use std::{
|
||||
fmt, future,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
sync::atomic::Ordering::SeqCst,
|
||||
sync::{
|
||||
atomic::{self, AtomicU32},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
time::Instant,
|
||||
};
|
||||
use tracing::instrument;
|
||||
|
||||
|
@ -118,6 +120,15 @@ pub struct ConnectionState {
|
|||
>,
|
||||
>,
|
||||
>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[serde(skip)]
|
||||
stream_response_channels: Arc<
|
||||
Mutex<
|
||||
Option<
|
||||
HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
|
||||
|
@ -171,17 +182,28 @@ impl Peer {
|
|||
outgoing_tx,
|
||||
next_message_id: Default::default(),
|
||||
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
|
||||
stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
|
||||
};
|
||||
let mut writer = MessageStream::new(connection.tx);
|
||||
let mut reader = MessageStream::new(connection.rx);
|
||||
|
||||
let this = self.clone();
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
let stream_response_channels = connection_state.stream_response_channels.clone();
|
||||
|
||||
let handle_io = async move {
|
||||
tracing::trace!(%connection_id, "handle io future: start");
|
||||
|
||||
let _end_connection = util::defer(|| {
|
||||
response_channels.lock().take();
|
||||
if let Some(channels) = stream_response_channels.lock().take() {
|
||||
for channel in channels.values() {
|
||||
let _ = channel.unbounded_send((
|
||||
Err(anyhow!("connection closed")),
|
||||
oneshot::channel().0,
|
||||
));
|
||||
}
|
||||
}
|
||||
this.connections.write().remove(&connection_id);
|
||||
tracing::trace!(%connection_id, "handle io future: end");
|
||||
});
|
||||
|
@ -273,12 +295,14 @@ impl Peer {
|
|||
};
|
||||
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
let stream_response_channels = connection_state.stream_response_channels.clone();
|
||||
self.connections
|
||||
.write()
|
||||
.insert(connection_id, connection_state);
|
||||
|
||||
let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
|
||||
let response_channels = response_channels.clone();
|
||||
let stream_response_channels = stream_response_channels.clone();
|
||||
async move {
|
||||
let message_id = incoming.id;
|
||||
tracing::trace!(?incoming, "incoming message future: start");
|
||||
|
@ -293,8 +317,15 @@ impl Peer {
|
|||
responding_to,
|
||||
"incoming response: received"
|
||||
);
|
||||
let channel = response_channels.lock().as_mut()?.remove(&responding_to);
|
||||
if let Some(tx) = channel {
|
||||
let response_channel =
|
||||
response_channels.lock().as_mut()?.remove(&responding_to);
|
||||
let stream_response_channel = stream_response_channels
|
||||
.lock()
|
||||
.as_ref()?
|
||||
.get(&responding_to)
|
||||
.cloned();
|
||||
|
||||
if let Some(tx) = response_channel {
|
||||
let requester_resumed = oneshot::channel();
|
||||
if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
|
||||
tracing::trace!(
|
||||
|
@ -319,6 +350,31 @@ impl Peer {
|
|||
responding_to,
|
||||
"incoming response: requester resumed"
|
||||
);
|
||||
} else if let Some(tx) = stream_response_channel {
|
||||
let requester_resumed = oneshot::channel();
|
||||
if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
|
||||
tracing::debug!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to = responding_to,
|
||||
?error,
|
||||
"incoming stream response: request future dropped",
|
||||
);
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming stream response: waiting to resume requester"
|
||||
);
|
||||
let _ = requester_resumed.1.await;
|
||||
tracing::debug!(
|
||||
%connection_id,
|
||||
message_id,
|
||||
responding_to,
|
||||
"incoming stream response: requester resumed"
|
||||
);
|
||||
} else {
|
||||
let message_type =
|
||||
proto::build_typed_envelope(connection_id, received_at, incoming)
|
||||
|
@ -451,6 +507,66 @@ impl Peer {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sends `request` to `receiver_id` and returns a future that resolves to a stream of the
/// responses received for it.
///
/// The request is registered and queued on the connection's outgoing channel eagerly, while
/// this function runs: `send` below is computed outside the `async` block, so awaiting the
/// returned future only surfaces that result and builds the stream.
///
/// Each stream item is `Ok(T::Response)` for a payload that decodes as `T::Response`, or
/// `Err` when the peer answered with an `Error` payload, when the payload is of the wrong
/// type, or when the channel itself delivered an `Err`. An `EndStream` payload is never
/// yielded; it unregisters this request's sender so that the stream ends.
///
/// Fails with "connection was closed" when the response-channel map is `None` or the
/// outgoing send fails; an error from `connection_state` is propagated unchanged.
pub fn request_stream<T: RequestMessage>(
|
||||
&self,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
|
||||
// `tx` is stored on the connection under this request's message id; `rx` backs the
// stream handed to the caller.
let (tx, rx) = mpsc::unbounded();
|
||||
let send = self.connection_state(receiver_id).and_then(|connection| {
|
||||
let message_id = connection.next_message_id.fetch_add(1, SeqCst);
|
||||
let stream_response_channels = connection.stream_response_channels.clone();
|
||||
// The sender is registered before the request is queued.
// NOTE(review): presumably so a response can never arrive before its channel exists.
stream_response_channels
|
||||
.lock()
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("connection was closed"))?
|
||||
.insert(message_id, tx);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(
|
||||
request.into_envelope(message_id, None, None),
|
||||
))
|
||||
.map_err(|_| anyhow!("connection was closed"))?;
|
||||
Ok((message_id, stream_response_channels))
|
||||
});
|
||||
|
||||
async move {
|
||||
let (message_id, stream_response_channels) = send?;
|
||||
// The stream keeps only a weak handle to the channel map; the EndStream branch below
// tolerates the map having been dropped.
let stream_response_channels = Arc::downgrade(&stream_response_channels);
|
||||
|
||||
// `_barrier` is the second element delivered with every response. It is unused here and
// is dropped when the closure returns.
// NOTE(review): looks like the oneshot sender the receive loop awaits before it
// continues - confirm against the incoming-message handler.
Ok(rx.filter_map(move |(response, _barrier)| {
|
||||
let stream_response_channels = stream_response_channels.clone();
|
||||
future::ready(match response {
|
||||
Ok(response) => {
|
||||
// The peer reported a failure: surface it as a stream item. The sender is not
// unregistered in this branch.
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
|
||||
Some(Err(anyhow!(
|
||||
"RPC request {} failed - {}",
|
||||
T::NAME,
|
||||
error.message
|
||||
)))
|
||||
} else if let Some(proto::envelope::Payload::EndStream(_)) =
|
||||
&response.payload
|
||||
{
|
||||
// Remove the transmitting end of the response channel to end the stream.
|
||||
if let Some(channels) = stream_response_channels.upgrade() {
|
||||
if let Some(channels) = channels.lock().as_mut() {
|
||||
channels.remove(&message_id);
|
||||
}
|
||||
}
|
||||
// The end-of-stream marker itself is filtered out of the stream.
None
|
||||
} else {
|
||||
// Any other payload has to decode as `T::Response`.
Some(
|
||||
T::Response::from_envelope(response)
|
||||
.ok_or_else(|| anyhow!("received response of the wrong type")),
|
||||
)
|
||||
}
|
||||
}
|
||||
// Errors delivered through the channel are passed to the caller unchanged.
Err(error) => Some(Err(error)),
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
|
||||
let connection = self.connection_state(receiver_id)?;
|
||||
let message_id = connection
|
||||
|
@ -503,6 +619,24 @@ impl Peer {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Sends an `EndStream` message in reply to the request identified by `receipt`.
///
/// The message goes to the connection `receipt.sender_id` under a freshly allocated message
/// id, with `receipt.message_id` passed as the second `into_envelope` argument
/// (NOTE(review): presumably the "responding to" id - confirm against `into_envelope`).
///
/// Returns an error if the connection lookup fails or the message cannot be queued on
/// `outgoing_tx`.
pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
|
||||
let connection = self.connection_state(receipt.sender_id)?;
|
||||
// The end-of-stream marker is an ordinary outgoing message and takes its own id.
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
|
||||
let message = proto::EndStream {};
|
||||
|
||||
connection
|
||||
.outgoing_tx
|
||||
.unbounded_send(proto::Message::Envelope(message.into_envelope(
|
||||
message_id,
|
||||
Some(receipt.message_id),
|
||||
None,
|
||||
)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn respond_with_error<T: RequestMessage>(
|
||||
&self,
|
||||
receipt: Receipt<T>,
|
||||
|
|
|
@ -149,7 +149,10 @@ messages!(
|
|||
(CallCanceled, Foreground),
|
||||
(CancelCall, Foreground),
|
||||
(ChannelMessageSent, Foreground),
|
||||
(CompleteWithLanguageModel, Background),
|
||||
(CopyProjectEntry, Foreground),
|
||||
(CountTokensWithLanguageModel, Background),
|
||||
(CountTokensResponse, Background),
|
||||
(CreateBufferForPeer, Foreground),
|
||||
(CreateChannel, Foreground),
|
||||
(CreateChannelResponse, Foreground),
|
||||
|
@ -160,6 +163,7 @@ messages!(
|
|||
(DeleteChannel, Foreground),
|
||||
(DeleteNotification, Foreground),
|
||||
(DeleteProjectEntry, Foreground),
|
||||
(EndStream, Foreground),
|
||||
(Error, Foreground),
|
||||
(ExpandProjectEntry, Foreground),
|
||||
(ExpandProjectEntryResponse, Foreground),
|
||||
|
@ -211,6 +215,7 @@ messages!(
|
|||
(JoinProjectResponse, Foreground),
|
||||
(JoinRoom, Foreground),
|
||||
(JoinRoomResponse, Foreground),
|
||||
(LanguageModelResponse, Background),
|
||||
(LeaveChannelBuffer, Background),
|
||||
(LeaveChannelChat, Foreground),
|
||||
(LeaveProject, Foreground),
|
||||
|
@ -300,6 +305,8 @@ request_messages!(
|
|||
(Call, Ack),
|
||||
(CancelCall, Ack),
|
||||
(CopyProjectEntry, ProjectEntryResponse),
|
||||
(CompleteWithLanguageModel, LanguageModelResponse),
|
||||
(CountTokensWithLanguageModel, CountTokensResponse),
|
||||
(CreateChannel, CreateChannelResponse),
|
||||
(CreateProjectEntry, ProjectEntryResponse),
|
||||
(CreateRoom, CreateRoomResponse),
|
||||
|
|
|
@ -22,7 +22,6 @@ gpui.workspace = true
|
|||
language.workspace = true
|
||||
menu.workspace = true
|
||||
project.workspace = true
|
||||
semantic_index.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
|
|
@ -705,11 +705,6 @@ impl BufferSearchBar {
|
|||
option.as_button(is_active, action)
|
||||
}
|
||||
pub fn activate_search_mode(&mut self, mode: SearchMode, cx: &mut ViewContext<Self>) {
|
||||
assert_ne!(
|
||||
mode,
|
||||
SearchMode::Semantic,
|
||||
"Semantic search is not supported in buffer search"
|
||||
);
|
||||
if mode == self.current_mode {
|
||||
return;
|
||||
}
|
||||
|
@ -1022,7 +1017,7 @@ impl BufferSearchBar {
|
|||
}
|
||||
}
|
||||
fn cycle_mode(&mut self, _: &CycleMode, cx: &mut ViewContext<Self>) {
|
||||
self.activate_search_mode(next_mode(&self.current_mode, false), cx);
|
||||
self.activate_search_mode(next_mode(&self.current_mode), cx);
|
||||
}
|
||||
fn toggle_replace(&mut self, _: &ToggleReplace, cx: &mut ViewContext<Self>) {
|
||||
if let Some(_) = &self.active_searchable_item {
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
use gpui::{Action, SharedString};
|
||||
|
||||
use crate::{ActivateRegexMode, ActivateSemanticMode, ActivateTextMode};
|
||||
use crate::{ActivateRegexMode, ActivateTextMode};
|
||||
|
||||
// TODO: Update the default search mode to get from config
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq)]
|
||||
pub enum SearchMode {
|
||||
#[default]
|
||||
Text,
|
||||
Semantic,
|
||||
Regex,
|
||||
}
|
||||
|
||||
|
@ -15,7 +14,6 @@ impl SearchMode {
|
|||
pub(crate) fn label(&self) -> &'static str {
|
||||
match self {
|
||||
SearchMode::Text => "Text",
|
||||
SearchMode::Semantic => "Semantic",
|
||||
SearchMode::Regex => "Regex",
|
||||
}
|
||||
}
|
||||
|
@ -25,22 +23,14 @@ impl SearchMode {
|
|||
pub(crate) fn action(&self) -> Box<dyn Action> {
|
||||
match self {
|
||||
SearchMode::Text => ActivateTextMode.boxed_clone(),
|
||||
SearchMode::Semantic => ActivateSemanticMode.boxed_clone(),
|
||||
SearchMode::Regex => ActivateRegexMode.boxed_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn next_mode(mode: &SearchMode, semantic_enabled: bool) -> SearchMode {
|
||||
pub(crate) fn next_mode(mode: &SearchMode) -> SearchMode {
|
||||
match mode {
|
||||
SearchMode::Text => SearchMode::Regex,
|
||||
SearchMode::Regex => {
|
||||
if semantic_enabled {
|
||||
SearchMode::Semantic
|
||||
} else {
|
||||
SearchMode::Text
|
||||
}
|
||||
}
|
||||
SearchMode::Semantic => SearchMode::Text,
|
||||
SearchMode::Regex => SearchMode::Text,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,33 +1,26 @@
|
|||
use crate::{
|
||||
history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateSemanticMode,
|
||||
ActivateTextMode, CycleMode, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext,
|
||||
SearchOptions, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored,
|
||||
ToggleReplace, ToggleWholeWord,
|
||||
history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateTextMode, CycleMode,
|
||||
NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOptions,
|
||||
SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored, ToggleReplace,
|
||||
ToggleWholeWord,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use anyhow::Context as _;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::{
|
||||
actions::SelectAll,
|
||||
items::active_match_index,
|
||||
scroll::{Autoscroll, Axis},
|
||||
Anchor, Editor, EditorEvent, MultiBuffer, MAX_TAB_TITLE_LEN,
|
||||
Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer, MAX_TAB_TITLE_LEN,
|
||||
};
|
||||
use editor::{EditorElement, EditorStyle};
|
||||
use gpui::{
|
||||
actions, div, Action, AnyElement, AnyView, AppContext, Context as _, Element, EntityId,
|
||||
EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, Hsla,
|
||||
InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point,
|
||||
PromptLevel, Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext,
|
||||
VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
|
||||
InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point, Render,
|
||||
SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext,
|
||||
WeakModel, WeakView, WhiteSpace, WindowContext,
|
||||
};
|
||||
use menu::Confirm;
|
||||
use project::{
|
||||
search::{SearchInputs, SearchQuery},
|
||||
Project,
|
||||
};
|
||||
use semantic_index::{SemanticIndex, SemanticIndexStatus};
|
||||
|
||||
use collections::HashSet;
|
||||
use project::{search::SearchQuery, Project};
|
||||
use settings::Settings;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{
|
||||
|
@ -35,22 +28,20 @@ use std::{
|
|||
mem,
|
||||
ops::{Not, Range},
|
||||
path::{Path, PathBuf},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use theme::ThemeSettings;
|
||||
use workspace::{DeploySearch, NewSearch};
|
||||
|
||||
use ui::{
|
||||
h_flex, prelude::*, v_flex, Icon, IconButton, IconName, Label, LabelCommon, LabelSize,
|
||||
Selectable, ToggleButton, Tooltip,
|
||||
};
|
||||
use util::{paths::PathMatcher, ResultExt as _};
|
||||
use util::paths::PathMatcher;
|
||||
use workspace::{
|
||||
item::{BreadcrumbText, Item, ItemEvent, ItemHandle},
|
||||
searchable::{Direction, SearchableItem, SearchableItemHandle},
|
||||
ItemNavHistory, Pane, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
|
||||
WorkspaceId,
|
||||
};
|
||||
use workspace::{DeploySearch, NewSearch};
|
||||
|
||||
const MIN_INPUT_WIDTH_REMS: f32 = 15.;
|
||||
const MAX_INPUT_WIDTH_REMS: f32 = 30.;
|
||||
|
@ -86,12 +77,6 @@ pub fn init(cx: &mut AppContext) {
|
|||
register_workspace_action(workspace, move |search_bar, _: &ActivateTextMode, cx| {
|
||||
search_bar.activate_search_mode(SearchMode::Text, cx)
|
||||
});
|
||||
register_workspace_action(
|
||||
workspace,
|
||||
move |search_bar, _: &ActivateSemanticMode, cx| {
|
||||
search_bar.activate_search_mode(SearchMode::Semantic, cx)
|
||||
},
|
||||
);
|
||||
register_workspace_action(workspace, move |search_bar, action: &CycleMode, cx| {
|
||||
search_bar.cycle_mode(action, cx)
|
||||
});
|
||||
|
@ -159,8 +144,6 @@ pub struct ProjectSearchView {
|
|||
query_editor: View<Editor>,
|
||||
replacement_editor: View<Editor>,
|
||||
results_editor: View<Editor>,
|
||||
semantic_state: Option<SemanticState>,
|
||||
semantic_permissioned: Option<bool>,
|
||||
search_options: SearchOptions,
|
||||
panels_with_errors: HashSet<InputPanel>,
|
||||
active_match_index: Option<usize>,
|
||||
|
@ -174,12 +157,6 @@ pub struct ProjectSearchView {
|
|||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
struct SemanticState {
|
||||
index_status: SemanticIndexStatus,
|
||||
maintain_rate_limit: Option<Task<()>>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ProjectSearchSettings {
|
||||
search_options: SearchOptions,
|
||||
|
@ -282,68 +259,6 @@ impl ProjectSearch {
|
|||
}));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn semantic_search(&mut self, inputs: &SearchInputs, cx: &mut ModelContext<Self>) {
|
||||
let search = SemanticIndex::global(cx).map(|index| {
|
||||
index.update(cx, |semantic_index, cx| {
|
||||
semantic_index.search_project(
|
||||
self.project.clone(),
|
||||
inputs.as_str().to_owned(),
|
||||
10,
|
||||
inputs.files_to_include().to_vec(),
|
||||
inputs.files_to_exclude().to_vec(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
self.search_id += 1;
|
||||
self.match_ranges.clear();
|
||||
self.search_history.add(inputs.as_str().to_string());
|
||||
self.no_results = None;
|
||||
self.pending_search = Some(cx.spawn(|this, mut cx| async move {
|
||||
let results = search?.await.log_err()?;
|
||||
let matches = results
|
||||
.into_iter()
|
||||
.map(|result| (result.buffer, vec![result.range.start..result.range.start]));
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.no_results = Some(true);
|
||||
this.excerpts.update(cx, |excerpts, cx| {
|
||||
excerpts.clear(cx);
|
||||
});
|
||||
})
|
||||
.ok()?;
|
||||
for (buffer, ranges) in matches {
|
||||
let mut match_ranges = this
|
||||
.update(&mut cx, |this, cx| {
|
||||
this.no_results = Some(false);
|
||||
this.excerpts.update(cx, |excerpts, cx| {
|
||||
excerpts.stream_excerpts_with_context_lines(buffer, ranges, 3, cx)
|
||||
})
|
||||
})
|
||||
.ok()?;
|
||||
while let Some(match_range) = match_ranges.next().await {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.match_ranges.push(match_range);
|
||||
while let Ok(Some(match_range)) = match_ranges.try_next() {
|
||||
this.match_ranges.push(match_range);
|
||||
}
|
||||
cx.notify();
|
||||
})
|
||||
.ok()?;
|
||||
}
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.pending_search.take();
|
||||
cx.notify();
|
||||
})
|
||||
.ok()?;
|
||||
|
||||
None
|
||||
}));
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
|
@ -358,8 +273,6 @@ impl EventEmitter<ViewEvent> for ProjectSearchView {}
|
|||
|
||||
impl Render for ProjectSearchView {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const PLEASE_AUTHENTICATE: &str = "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables. If you authenticated using the Assistant Panel, please restart Zed to Authenticate.";
|
||||
|
||||
if self.has_matches() {
|
||||
div()
|
||||
.flex_1()
|
||||
|
@ -370,7 +283,7 @@ impl Render for ProjectSearchView {
|
|||
let model = self.model.read(cx);
|
||||
let has_no_results = model.no_results.unwrap_or(false);
|
||||
let is_search_underway = model.pending_search.is_some();
|
||||
let mut major_text = if is_search_underway {
|
||||
let major_text = if is_search_underway {
|
||||
Label::new("Searching...")
|
||||
} else if has_no_results {
|
||||
Label::new("No results")
|
||||
|
@ -378,43 +291,6 @@ impl Render for ProjectSearchView {
|
|||
Label::new(format!("{} search all files", self.current_mode.label()))
|
||||
};
|
||||
|
||||
let mut show_minor_text = true;
|
||||
let semantic_status = self.semantic_state.as_ref().and_then(|semantic| {
|
||||
let status = semantic.index_status;
|
||||
match status {
|
||||
SemanticIndexStatus::NotAuthenticated => {
|
||||
major_text = Label::new("Not Authenticated");
|
||||
show_minor_text = false;
|
||||
Some(PLEASE_AUTHENTICATE.to_string())
|
||||
}
|
||||
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
|
||||
SemanticIndexStatus::Indexing {
|
||||
remaining_files,
|
||||
rate_limit_expiry,
|
||||
} => {
|
||||
if remaining_files == 0 {
|
||||
Some("Indexing...".to_string())
|
||||
} else {
|
||||
if let Some(rate_limit_expiry) = rate_limit_expiry {
|
||||
let remaining_seconds =
|
||||
rate_limit_expiry.duration_since(Instant::now());
|
||||
if remaining_seconds > Duration::from_secs(0) {
|
||||
Some(format!(
|
||||
"Remaining files to index (rate limit resets in {}s): {}",
|
||||
remaining_seconds.as_secs(),
|
||||
remaining_files
|
||||
))
|
||||
} else {
|
||||
Some(format!("Remaining files to index: {}", remaining_files))
|
||||
}
|
||||
} else {
|
||||
Some(format!("Remaining files to index: {}", remaining_files))
|
||||
}
|
||||
}
|
||||
}
|
||||
SemanticIndexStatus::NotIndexed => None,
|
||||
}
|
||||
});
|
||||
let major_text = div().justify_center().max_w_96().child(major_text);
|
||||
|
||||
let minor_text: Option<SharedString> = if let Some(no_results) = model.no_results {
|
||||
|
@ -424,12 +300,7 @@ impl Render for ProjectSearchView {
|
|||
None
|
||||
}
|
||||
} else {
|
||||
if let Some(mut semantic_status) = semantic_status {
|
||||
semantic_status.extend(self.landing_text_minor().chars());
|
||||
Some(semantic_status.into())
|
||||
} else {
|
||||
Some(self.landing_text_minor())
|
||||
}
|
||||
Some(self.landing_text_minor())
|
||||
};
|
||||
let minor_text = minor_text.map(|text| {
|
||||
div()
|
||||
|
@ -676,58 +547,6 @@ impl ProjectSearchView {
|
|||
});
|
||||
}
|
||||
|
||||
fn index_project(&mut self, cx: &mut ViewContext<Self>) {
|
||||
if let Some(semantic_index) = SemanticIndex::global(cx) {
|
||||
// Semantic search uses no options
|
||||
self.search_options = SearchOptions::none();
|
||||
|
||||
let project = self.model.read(cx).project.clone();
|
||||
|
||||
semantic_index.update(cx, |semantic_index, cx| {
|
||||
semantic_index
|
||||
.index_project(project.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
|
||||
self.semantic_state = Some(SemanticState {
|
||||
index_status: semantic_index.read(cx).status(&project),
|
||||
maintain_rate_limit: None,
|
||||
_subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
|
||||
});
|
||||
self.semantic_index_changed(semantic_index, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn semantic_index_changed(
|
||||
&mut self,
|
||||
semantic_index: Model<SemanticIndex>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
let project = self.model.read(cx).project.clone();
|
||||
if let Some(semantic_state) = self.semantic_state.as_mut() {
|
||||
cx.notify();
|
||||
semantic_state.index_status = semantic_index.read(cx).status(&project);
|
||||
if let SemanticIndexStatus::Indexing {
|
||||
rate_limit_expiry: Some(_),
|
||||
..
|
||||
} = &semantic_state.index_status
|
||||
{
|
||||
if semantic_state.maintain_rate_limit.is_none() {
|
||||
semantic_state.maintain_rate_limit =
|
||||
Some(cx.spawn(|this, mut cx| async move {
|
||||
loop {
|
||||
cx.background_executor().timer(Duration::from_secs(1)).await;
|
||||
this.update(&mut cx, |_, cx| cx.notify()).log_err();
|
||||
}
|
||||
}));
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
semantic_state.maintain_rate_limit = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_search(&mut self, cx: &mut ViewContext<Self>) {
|
||||
self.model.update(cx, |model, cx| {
|
||||
model.pending_search = None;
|
||||
|
@ -750,63 +569,7 @@ impl ProjectSearchView {
|
|||
self.clear_search(cx);
|
||||
self.current_mode = mode;
|
||||
self.active_match_index = None;
|
||||
|
||||
match mode {
|
||||
SearchMode::Semantic => {
|
||||
let has_permission = self.semantic_permissioned(cx);
|
||||
self.active_match_index = None;
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let has_permission = has_permission.await?;
|
||||
|
||||
if !has_permission {
|
||||
let answer = this.update(&mut cx, |this, cx| {
|
||||
let project = this.model.read(cx).project.clone();
|
||||
let project_name = project
|
||||
.read(cx)
|
||||
.worktree_root_names(cx)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("/");
|
||||
let is_plural =
|
||||
project_name.chars().filter(|letter| *letter == '/').count() > 0;
|
||||
let prompt_text = format!("Would you like to index the '{}' project{} for semantic search? This requires sending code to the OpenAI API", project_name,
|
||||
if is_plural {
|
||||
"s"
|
||||
} else {""});
|
||||
cx.prompt(
|
||||
PromptLevel::Info,
|
||||
prompt_text.as_str(),
|
||||
None,
|
||||
&["Continue", "Cancel"],
|
||||
)
|
||||
})?;
|
||||
|
||||
if answer.await? == 0 {
|
||||
this.update(&mut cx, |this, _| {
|
||||
this.semantic_permissioned = Some(true);
|
||||
})?;
|
||||
} else {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.semantic_permissioned = Some(false);
|
||||
debug_assert_ne!(previous_mode, SearchMode::Semantic, "Tried to re-enable semantic search mode after user modal was rejected");
|
||||
this.activate_search_mode(previous_mode, cx);
|
||||
})?;
|
||||
return anyhow::Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.index_project(cx);
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}).detach_and_log_err(cx);
|
||||
}
|
||||
SearchMode::Regex | SearchMode::Text => {
|
||||
self.semantic_state = None;
|
||||
self.active_match_index = None;
|
||||
self.search(cx);
|
||||
}
|
||||
}
|
||||
self.search(cx);
|
||||
|
||||
cx.update_global(|state: &mut ActiveSettings, cx| {
|
||||
state.0.insert(
|
||||
|
@ -973,8 +736,6 @@ impl ProjectSearchView {
|
|||
model,
|
||||
query_editor,
|
||||
results_editor,
|
||||
semantic_state: None,
|
||||
semantic_permissioned: None,
|
||||
search_options: options,
|
||||
panels_with_errors: HashSet::default(),
|
||||
active_match_index: None,
|
||||
|
@ -990,19 +751,6 @@ impl ProjectSearchView {
|
|||
this
|
||||
}
|
||||
|
||||
fn semantic_permissioned(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
|
||||
if let Some(value) = self.semantic_permissioned {
|
||||
return Task::ready(Ok(value));
|
||||
}
|
||||
|
||||
SemanticIndex::global(cx)
|
||||
.map(|semantic| {
|
||||
let project = self.model.read(cx).project.clone();
|
||||
semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
|
||||
})
|
||||
.unwrap_or(Task::ready(Ok(false)))
|
||||
}
|
||||
|
||||
pub fn new_search_in_directory(
|
||||
workspace: &mut Workspace,
|
||||
dir_path: &Path,
|
||||
|
@ -1126,22 +874,8 @@ impl ProjectSearchView {
|
|||
}
|
||||
|
||||
fn search(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let mode = self.current_mode;
|
||||
match mode {
|
||||
SearchMode::Semantic => {
|
||||
if self.semantic_state.is_some() {
|
||||
if let Some(query) = self.build_search_query(cx) {
|
||||
self.model
|
||||
.update(cx, |model, cx| model.semantic_search(query.as_inner(), cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
if let Some(query) = self.build_search_query(cx) {
|
||||
self.model.update(cx, |model, cx| model.search(query, cx));
|
||||
}
|
||||
}
|
||||
if let Some(query) = self.build_search_query(cx) {
|
||||
self.model.update(cx, |model, cx| model.search(query, cx));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1356,7 +1090,6 @@ impl ProjectSearchView {
|
|||
fn landing_text_minor(&self) -> SharedString {
|
||||
match self.current_mode {
|
||||
SearchMode::Text | SearchMode::Regex => "Include/exclude specific paths with the filter option. Matching exact word and/or casing is available too.".into(),
|
||||
SearchMode::Semantic => "\nSimply explain the code you are looking to find. ex. 'prompt user for permissions to index their project'".into()
|
||||
}
|
||||
}
|
||||
fn border_color_for(&self, panel: InputPanel, cx: &WindowContext) -> Hsla {
|
||||
|
@ -1387,8 +1120,7 @@ impl ProjectSearchBar {
|
|||
fn cycle_mode(&self, _: &CycleMode, cx: &mut ViewContext<Self>) {
|
||||
if let Some(view) = self.active_project_search.as_ref() {
|
||||
view.update(cx, |this, cx| {
|
||||
let new_mode =
|
||||
crate::mode::next_mode(&this.current_mode, SemanticIndex::enabled(cx));
|
||||
let new_mode = crate::mode::next_mode(&this.current_mode);
|
||||
this.activate_search_mode(new_mode, cx);
|
||||
let editor_handle = this.query_editor.focus_handle(cx);
|
||||
cx.focus(&editor_handle);
|
||||
|
@ -1681,7 +1413,6 @@ impl Render for ProjectSearchBar {
|
|||
});
|
||||
}
|
||||
let search = search.read(cx);
|
||||
let semantic_is_available = SemanticIndex::enabled(cx);
|
||||
|
||||
let query_column = h_flex()
|
||||
.flex_1()
|
||||
|
@ -1711,12 +1442,8 @@ impl Render for ProjectSearchBar {
|
|||
.unwrap_or_default(),
|
||||
),
|
||||
)
|
||||
.when(search.current_mode != SearchMode::Semantic, |this| {
|
||||
this.child(
|
||||
IconButton::new(
|
||||
"project-search-case-sensitive",
|
||||
IconName::CaseSensitive,
|
||||
)
|
||||
.child(
|
||||
IconButton::new("project-search-case-sensitive", IconName::CaseSensitive)
|
||||
.tooltip(|cx| {
|
||||
Tooltip::for_action(
|
||||
"Toggle case sensitive",
|
||||
|
@ -1728,18 +1455,17 @@ impl Render for ProjectSearchBar {
|
|||
.on_click(cx.listener(|this, _, cx| {
|
||||
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("project-search-whole-word", IconName::WholeWord)
|
||||
.tooltip(|cx| {
|
||||
Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
|
||||
})
|
||||
.selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
|
||||
.on_click(cx.listener(|this, _, cx| {
|
||||
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
|
||||
})),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("project-search-whole-word", IconName::WholeWord)
|
||||
.tooltip(|cx| {
|
||||
Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
|
||||
})
|
||||
.selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
|
||||
.on_click(cx.listener(|this, _, cx| {
|
||||
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
|
||||
})),
|
||||
),
|
||||
);
|
||||
|
||||
let mode_column = v_flex().items_start().justify_start().child(
|
||||
|
@ -1775,33 +1501,8 @@ impl Render for ProjectSearchBar {
|
|||
cx,
|
||||
)
|
||||
})
|
||||
.map(|this| {
|
||||
if semantic_is_available {
|
||||
this.middle()
|
||||
} else {
|
||||
this.last()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.when(semantic_is_available, |this| {
|
||||
this.child(
|
||||
ToggleButton::new("project-search-semantic-button", "Semantic")
|
||||
.style(ButtonStyle::Filled)
|
||||
.size(ButtonSize::Large)
|
||||
.selected(search.current_mode == SearchMode::Semantic)
|
||||
.on_click(cx.listener(|this, _, cx| {
|
||||
this.activate_search_mode(SearchMode::Semantic, cx)
|
||||
}))
|
||||
.tooltip(|cx| {
|
||||
Tooltip::for_action(
|
||||
"Toggle semantic search",
|
||||
&ActivateSemanticMode,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.last(),
|
||||
)
|
||||
}),
|
||||
.last(),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("project-search-toggle-replace", IconName::Replace)
|
||||
|
@ -1929,21 +1630,16 @@ impl Render for ProjectSearchBar {
|
|||
.border_color(search.border_color_for(InputPanel::Include, cx))
|
||||
.rounded_lg()
|
||||
.child(self.render_text_input(&search.included_files_editor, cx))
|
||||
.when(search.current_mode != SearchMode::Semantic, |this| {
|
||||
this.child(
|
||||
SearchOptions::INCLUDE_IGNORED.as_button(
|
||||
search
|
||||
.search_options
|
||||
.contains(SearchOptions::INCLUDE_IGNORED),
|
||||
cx.listener(|this, _, cx| {
|
||||
this.toggle_search_option(
|
||||
SearchOptions::INCLUDE_IGNORED,
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
),
|
||||
)
|
||||
}),
|
||||
.child(
|
||||
SearchOptions::INCLUDE_IGNORED.as_button(
|
||||
search
|
||||
.search_options
|
||||
.contains(SearchOptions::INCLUDE_IGNORED),
|
||||
cx.listener(|this, _, cx| {
|
||||
this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
|
||||
}),
|
||||
),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
|
@ -1972,9 +1668,6 @@ impl Render for ProjectSearchBar {
|
|||
.on_action(cx.listener(|this, _: &ActivateRegexMode, cx| {
|
||||
this.activate_search_mode(SearchMode::Regex, cx)
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ActivateSemanticMode, cx| {
|
||||
this.activate_search_mode(SearchMode::Semantic, cx)
|
||||
}))
|
||||
.capture_action(cx.listener(|this, action, cx| {
|
||||
this.tab(action, cx);
|
||||
cx.stop_propagation();
|
||||
|
@ -1987,35 +1680,33 @@ impl Render for ProjectSearchBar {
|
|||
.on_action(cx.listener(|this, action, cx| {
|
||||
this.cycle_mode(action, cx);
|
||||
}))
|
||||
.when(search.current_mode != SearchMode::Semantic, |this| {
|
||||
this.on_action(cx.listener(|this, action, cx| {
|
||||
this.toggle_replace(action, cx);
|
||||
.on_action(cx.listener(|this, action, cx| {
|
||||
this.toggle_replace(action, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
|
||||
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
|
||||
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, action, cx| {
|
||||
if let Some(search) = this.active_project_search.as_ref() {
|
||||
search.update(cx, |this, cx| {
|
||||
this.replace_next(action, cx);
|
||||
})
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, action, cx| {
|
||||
if let Some(search) = this.active_project_search.as_ref() {
|
||||
search.update(cx, |this, cx| {
|
||||
this.replace_all(action, cx);
|
||||
})
|
||||
}
|
||||
}))
|
||||
.when(search.filters_enabled, |this| {
|
||||
this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
|
||||
this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
|
||||
this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
|
||||
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, action, cx| {
|
||||
if let Some(search) = this.active_project_search.as_ref() {
|
||||
search.update(cx, |this, cx| {
|
||||
this.replace_next(action, cx);
|
||||
})
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, action, cx| {
|
||||
if let Some(search) = this.active_project_search.as_ref() {
|
||||
search.update(cx, |this, cx| {
|
||||
this.replace_all(action, cx);
|
||||
})
|
||||
}
|
||||
}))
|
||||
.when(search.filters_enabled, |this| {
|
||||
this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
|
||||
this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
|
||||
}))
|
||||
})
|
||||
})
|
||||
.on_action(cx.listener(Self::select_next_match))
|
||||
.on_action(cx.listener(Self::select_prev_match))
|
||||
|
@ -2039,12 +1730,6 @@ impl ToolbarItemView for ProjectSearchBar {
|
|||
self.subscription = None;
|
||||
self.active_project_search = None;
|
||||
if let Some(search) = active_pane_item.and_then(|i| i.downcast::<ProjectSearchView>()) {
|
||||
search.update(cx, |search, cx| {
|
||||
if search.current_mode == SearchMode::Semantic {
|
||||
search.index_project(cx);
|
||||
}
|
||||
});
|
||||
|
||||
self.subscription = Some(cx.observe(&search, |_, _, cx| cx.notify()));
|
||||
self.active_project_search = Some(search);
|
||||
ToolbarItemLocation::PrimaryLeft {}
|
||||
|
@ -2123,9 +1808,8 @@ pub mod tests {
|
|||
use editor::DisplayPoint;
|
||||
use gpui::{Action, TestAppContext, WindowHandle};
|
||||
use project::FakeFs;
|
||||
use semantic_index::semantic_index_settings::SemanticIndexSettings;
|
||||
use serde_json::json;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
use workspace::DeploySearch;
|
||||
|
||||
|
@ -3446,8 +3130,6 @@ pub mod tests {
|
|||
let settings = SettingsStore::test(cx);
|
||||
cx.set_global(settings);
|
||||
|
||||
SemanticIndexSettings::register(cx);
|
||||
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
|
||||
language::init(cx);
|
||||
|
|
|
@ -33,7 +33,6 @@ actions!(
|
|||
NextHistoryQuery,
|
||||
PreviousHistoryQuery,
|
||||
ActivateTextMode,
|
||||
ActivateSemanticMode,
|
||||
ActivateRegexMode,
|
||||
ReplaceAll,
|
||||
ReplaceNext,
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
[package]
|
||||
name = "semantic_index"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/semantic_index.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
ai.workspace = true
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
lazy_static.workspace = true
|
||||
log.workspace = true
|
||||
ndarray = { version = "0.15.0" }
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
release_channel.workspace = true
|
||||
rpc.workspace = true
|
||||
rusqlite.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
sha1 = "0.10.5"
|
||||
smol.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ai = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
rpc = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"]}
|
||||
tempfile.workspace = true
|
||||
tree-sitter-cpp.workspace = true
|
||||
tree-sitter-elixir.workspace = true
|
||||
tree-sitter-json.workspace = true
|
||||
tree-sitter-lua.workspace = true
|
||||
tree-sitter-php.workspace = true
|
||||
tree-sitter-ruby.workspace = true
|
||||
tree-sitter-rust.workspace = true
|
||||
tree-sitter-toml.workspace = true
|
||||
tree-sitter-typescript.workspace = true
|
||||
unindent.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-GPL
|
|
@ -1,20 +0,0 @@
|
|||
|
||||
# Semantic Index
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Metrics
|
||||
|
||||
nDCG@k:
|
||||
- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return."
|
||||
- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
|
||||
|
||||
MRR@k:
|
||||
- "Mean reciprocal rank quantifies the rank of the first relevant item found in the recommendation list."
|
||||
|
||||
MAP@k:
|
||||
- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list."
|
||||
|
||||
Resources:
|
||||
- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
|
||||
- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)
|
|
@ -1,114 +0,0 @@
|
|||
{
|
||||
"repo": "https://github.com/AntonOsika/gpt-engineer.git",
|
||||
"commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
|
||||
"assertions": [
|
||||
{
|
||||
"query": "How do I contribute to this project?",
|
||||
"matches": [
|
||||
".github/CONTRIBUTING.md:1",
|
||||
"ROADMAP.md:48"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "What version of the openai package is active?",
|
||||
"matches": [
|
||||
"pyproject.toml:14"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Ask user for clarification",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:69"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "generate tests for python code",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:153"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "get item from database based on key",
|
||||
"matches": [
|
||||
"gpt_engineer/db.py:42",
|
||||
"gpt_engineer/db.py:68"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "prompt user to select files",
|
||||
"matches": [
|
||||
"gpt_engineer/file_selector.py:171",
|
||||
"gpt_engineer/file_selector.py:306",
|
||||
"gpt_engineer/file_selector.py:289",
|
||||
"gpt_engineer/file_selector.py:234"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "send to rudderstack",
|
||||
"matches": [
|
||||
"gpt_engineer/collect.py:11",
|
||||
"gpt_engineer/collect.py:38"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "parse code blocks from chat messages",
|
||||
"matches": [
|
||||
"gpt_engineer/chat_to_files.py:10",
|
||||
"docs/intro/chat_parsing.md:1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "how do I use the docker cli?",
|
||||
"matches": [
|
||||
"docker/README.md:1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "ask the user if the code ran successfully?",
|
||||
"matches": [
|
||||
"gpt_engineer/learning.py:54"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "how is consent granted by the user?",
|
||||
"matches": [
|
||||
"gpt_engineer/learning.py:107",
|
||||
"gpt_engineer/learning.py:130",
|
||||
"gpt_engineer/learning.py:152"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what are all the different steps the agent can take?",
|
||||
"matches": [
|
||||
"docs/intro/steps_module.md:1",
|
||||
"gpt_engineer/steps.py:391"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "ask the user for clarification?",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:69"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what models are available?",
|
||||
"matches": [
|
||||
"gpt_engineer/ai.py:315",
|
||||
"gpt_engineer/ai.py:341",
|
||||
"docs/open-models.md:1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what is the current focus of the project?",
|
||||
"matches": [
|
||||
"ROADMAP.md:11"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "does the agent know how to fix code?",
|
||||
"matches": [
|
||||
"gpt_engineer/steps.py:367"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -1,104 +0,0 @@
|
|||
{
|
||||
"repo": "https://github.com/tree-sitter/tree-sitter.git",
|
||||
"commit": "46af27796a76c72d8466627d499f2bca4af958ee",
|
||||
"assertions": [
|
||||
{
|
||||
"query": "What attributes are available for the tags configuration struct?",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "create a new tag configuration",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:119"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "generate tags based on config",
|
||||
"matches": [
|
||||
"tags/src/lib.rs:261"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "match on ts quantifier in rust",
|
||||
"matches": [
|
||||
"lib/binding_rust/lib.rs:139"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "cli command to generate tags",
|
||||
"matches": [
|
||||
"cli/src/tags.rs:10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "what version of the tree-sitter-tags package is active?",
|
||||
"matches": [
|
||||
"tags/Cargo.toml:4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Insert a new parse state",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:153"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Handle conflict when numerous actions occur on the same symbol",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:363",
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:442"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Match based on associativity of actions",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/build_parse_table.rs:542"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "Format token set display",
|
||||
"matches": [
|
||||
"cli/src/generate/build_tables/item.rs:246"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "extract choices from rule",
|
||||
"matches": [
|
||||
"cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "How do we identify if a symbol is being used?",
|
||||
"matches": [
|
||||
"cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "How do we launch the playground?",
|
||||
"matches": [
|
||||
"cli/src/playground.rs:46"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "How do we test treesitter query matches in rust?",
|
||||
"matches": [
|
||||
"cli/src/query_testing.rs:152",
|
||||
"cli/src/tests/query_test.rs:781",
|
||||
"cli/src/tests/query_test.rs:2163",
|
||||
"cli/src/tests/query_test.rs:3781",
|
||||
"cli/src/tests/query_test.rs:887"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "What does the CLI do?",
|
||||
"matches": [
|
||||
"cli/README.md:10",
|
||||
"cli/loader/README.md:3",
|
||||
"docs/section-5-implementation.md:14",
|
||||
"docs/section-5-implementation.md:18"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -1,594 +0,0 @@
|
|||
use crate::{
|
||||
parsing::{Span, SpanDigest},
|
||||
SEMANTIC_INDEX_VERSION,
|
||||
};
|
||||
use ai::embedding::Embedding;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::BackgroundExecutor;
|
||||
use ndarray::{Array1, Array2};
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::Fs;
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::params;
|
||||
use rusqlite::types::Value;
|
||||
use std::{
|
||||
future::Future,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::SystemTime,
|
||||
};
|
||||
use util::{paths::PathMatcher, TryFutureExt};
|
||||
|
||||
/// Returns the indices of `data` ordered from the largest element to the smallest.
///
/// The indices are stably sorted in ascending order of the elements they refer
/// to and the resulting permutation is then reversed, so among equal elements
/// the higher index comes first.
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut ascending: Vec<usize> = (0..data.len()).collect();
    ascending.sort_by(|&lhs, &rhs| data[lhs].cmp(&data[rhs]));
    ascending.into_iter().rev().collect()
}
|
||||
|
||||
/// Identifier, relative path and modification time of one file.
// NOTE(review): the fields line up with the `id`, `relative_path` and
// `mtime_*` columns of the `files` table created in `initialize_database`;
// confirm against the callers, which are not visible here.
#[derive(Debug)]
pub struct FileRecord {
    /// Numeric identifier of the file.
    pub id: usize,
    /// Path of the file, stored as a string.
    pub relative_path: String,
    /// Modification time as an `rpc::proto::Timestamp`.
    pub mtime: Timestamp,
}
|
||||
|
||||
/// Cloneable handle to the semantic-index SQLite database.
///
/// The handle does not own a connection. Work is submitted as boxed closures
/// over `transactions`, and the task spawned in `VectorDatabase::new` runs
/// them one after another against a single `rusqlite::Connection`.
#[derive(Clone)]
pub struct VectorDatabase {
    /// Path the database was opened with.
    path: Arc<Path>,
    /// Sending half of the queue of closures to run against the connection.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
|
||||
|
||||
impl VectorDatabase {
|
||||
pub async fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
path: Arc<Path>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Result<Self> {
|
||||
if let Some(db_directory) = path.parent() {
|
||||
fs.create_dir(db_directory).await?;
|
||||
}
|
||||
|
||||
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
|
||||
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
|
||||
>();
|
||||
executor
|
||||
.spawn({
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let mut connection = rusqlite::Connection::open(&path)?;
|
||||
|
||||
connection.pragma_update(None, "journal_mode", "wal")?;
|
||||
connection.pragma_update(None, "synchronous", "normal")?;
|
||||
connection.pragma_update(None, "cache_size", 1000000)?;
|
||||
connection.pragma_update(None, "temp_store", "MEMORY")?;
|
||||
|
||||
while let Ok(transaction) = transactions_rx.recv().await {
|
||||
transaction(&mut connection);
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
.detach();
|
||||
let this = Self {
|
||||
transactions: transactions_tx,
|
||||
path,
|
||||
};
|
||||
this.initialize_database().await?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
/// Returns the path this database handle was created with.
pub fn path(&self) -> &Arc<Path> {
    &self.path
}
|
||||
|
||||
/// Runs `f` inside a SQLite transaction on the connection task and resolves
/// to its result.
///
/// The work is boxed and sent over `self.transactions`; the outcome travels
/// back through a oneshot channel. The transaction is committed only when `f`
/// returns `Ok`; otherwise it is dropped without being committed.
///
/// The returned future fails with "connection was dropped" when the job
/// cannot be queued, and with the oneshot receive error when the reply sender
/// is dropped before answering.
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
where
    F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
    T: 'static + Send,
{
    let (tx, rx) = oneshot::channel();
    let transactions = self.transactions.clone();
    async move {
        if transactions
            .send(Box::new(|connection| {
                let result = connection
                    .transaction()
                    .map_err(|err| anyhow!(err))
                    .and_then(|transaction| {
                        let result = f(&transaction)?;
                        transaction.commit()?;
                        Ok(result)
                    });
                // A failed send is deliberately ignored: it only means the
                // receiving side is gone.
                let _ = tx.send(result);
            }))
            .await
            .is_err()
        {
            return Err(anyhow!("connection was dropped"))?;
        }
        rx.await?
    }
}
|
||||
|
||||
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
|
||||
self.transact(|db| {
|
||||
rusqlite::vtab::array::load_module(&db)?;
|
||||
|
||||
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
|
||||
let version_query = db.prepare("SELECT version from semantic_index_config");
|
||||
let version = version_query
|
||||
.and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0)));
|
||||
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
|
||||
log::trace!("vector database schema up to date");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("vector database schema out of date. updating...");
|
||||
// We renamed the `documents` table to `spans`, so we want to drop
|
||||
// `documents` without recreating it if it exists.
|
||||
db.execute("DROP TABLE IF EXISTS documents", [])
|
||||
.context("failed to drop 'documents' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS spans", [])
|
||||
.context("failed to drop 'spans' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS files", [])
|
||||
.context("failed to drop 'files' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS worktrees", [])
|
||||
.context("failed to drop 'worktrees' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
|
||||
.context("failed to drop 'semantic_index_config' table")?;
|
||||
|
||||
// Initialize Vector Databasing Tables
|
||||
db.execute(
|
||||
"CREATE TABLE semantic_index_config (
|
||||
version INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"INSERT INTO semantic_index_config (version) VALUES (?1)",
|
||||
params![SEMANTIC_INDEX_VERSION],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE worktrees (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
absolute_path VARCHAR NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
|
||||
",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
worktree_id INTEGER NOT NULL,
|
||||
relative_path VARCHAR NOT NULL,
|
||||
mtime_seconds INTEGER NOT NULL,
|
||||
mtime_nanos INTEGER NOT NULL,
|
||||
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE spans (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_id INTEGER NOT NULL,
|
||||
start_byte INTEGER NOT NULL,
|
||||
end_byte INTEGER NOT NULL,
|
||||
name VARCHAR NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
digest BLOB NOT NULL,
|
||||
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
db.execute(
|
||||
"CREATE INDEX spans_digest ON spans (digest)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
log::trace!("vector database initialized with updated schema.");
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
delete_path: Arc<Path>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.transact(move |db| {
|
||||
db.execute(
|
||||
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
|
||||
params![worktree_id, delete_path.to_str()],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
path: Arc<Path>,
|
||||
mtime: SystemTime,
|
||||
spans: Vec<Span>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.transact(move |db| {
|
||||
// Return the existing ID, if both the file and mtime match
|
||||
let mtime = Timestamp::from(mtime);
|
||||
|
||||
db.execute(
|
||||
"
|
||||
REPLACE INTO files
|
||||
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
|
||||
VALUES (?1, ?2, ?3, ?4)
|
||||
",
|
||||
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
|
||||
)?;
|
||||
|
||||
let file_id = db.last_insert_rowid();
|
||||
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
INSERT INTO spans
|
||||
(file_id, start_byte, end_byte, name, embedding, digest)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
",
|
||||
)?;
|
||||
|
||||
for span in spans {
|
||||
query.execute(params![
|
||||
file_id,
|
||||
span.range.start.to_string(),
|
||||
span.range.end.to_string(),
|
||||
span.name,
|
||||
span.embedding,
|
||||
span.digest
|
||||
])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn worktree_previously_indexed(
|
||||
&self,
|
||||
worktree_root_path: &Path,
|
||||
) -> impl Future<Output = Result<bool>> {
|
||||
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
|
||||
self.transact(move |db| {
|
||||
let mut worktree_query =
|
||||
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id =
|
||||
worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0));
|
||||
|
||||
Ok(worktree_id.is_ok())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_digests(
|
||||
&self,
|
||||
digests: Vec<SpanDigest>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
WHERE digest IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
let digests = Rc::new(
|
||||
digests
|
||||
.into_iter()
|
||||
.map(|digest| Value::Blob(digest.0.to_vec()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![digests], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for (digest, embedding) in rows.flatten() {
|
||||
embeddings_by_digest.insert(digest, embedding);
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embeddings_for_files(
|
||||
&self,
|
||||
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
LEFT JOIN files ON files.id = spans.file_id
|
||||
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
for (worktree_id, file_paths) in worktree_id_file_paths {
|
||||
let file_paths = Rc::new(
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![worktree_id, file_paths], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for (digest, embedding) in rows.flatten() {
|
||||
embeddings_by_digest.insert(digest, embedding);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_or_create_worktree(
|
||||
&self,
|
||||
worktree_root_path: Arc<Path>,
|
||||
) -> impl Future<Output = Result<i64>> {
|
||||
self.transact(move |db| {
|
||||
let mut worktree_query =
|
||||
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
|
||||
row.get::<_, i64>(0)
|
||||
});
|
||||
|
||||
if worktree_id.is_ok() {
|
||||
return Ok(worktree_id?);
|
||||
}
|
||||
|
||||
// If worktree_id is Err, insert new worktree
|
||||
db.execute(
|
||||
"INSERT into worktrees (absolute_path) VALUES (?1)",
|
||||
params![worktree_root_path.to_string_lossy()],
|
||||
)?;
|
||||
Ok(db.last_insert_rowid())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_file_mtimes(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
|
||||
self.transact(move |db| {
|
||||
let mut statement = db.prepare(
|
||||
"
|
||||
SELECT relative_path, mtime_seconds, mtime_nanos
|
||||
FROM files
|
||||
WHERE worktree_id = ?1
|
||||
ORDER BY relative_path",
|
||||
)?;
|
||||
let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
|
||||
for row in statement.query_map(params![worktree_id], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?.into(),
|
||||
Timestamp {
|
||||
seconds: row.get(1)?,
|
||||
nanos: row.get(2)?,
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
})? {
|
||||
let row = row?;
|
||||
result.insert(row.0, row.1);
|
||||
}
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn top_k_search(
|
||||
&self,
|
||||
query_embedding: &Embedding,
|
||||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
|
||||
let file_ids = file_ids.to_vec();
|
||||
let query = query_embedding.clone().0;
|
||||
let query = Array1::from_vec(query);
|
||||
self.transact(move |db| {
|
||||
let mut query_statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, embedding
|
||||
FROM
|
||||
spans
|
||||
WHERE
|
||||
file_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let deserialized_rows = query_statement
|
||||
.query_map(params![ids_to_sql(&file_ids)], |row| {
|
||||
Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?
|
||||
.filter_map(|row| row.ok())
|
||||
.collect::<Vec<(usize, Embedding)>>();
|
||||
|
||||
if deserialized_rows.len() == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Get Length of Embeddings Returned
|
||||
let embedding_len = deserialized_rows[0].1 .0.len();
|
||||
|
||||
let batch_n = 1000;
|
||||
let mut batches = Vec::new();
|
||||
let mut batch_ids = Vec::new();
|
||||
let mut batch_embeddings: Vec<f32> = Vec::new();
|
||||
deserialized_rows.iter().for_each(|(id, embedding)| {
|
||||
batch_ids.push(id);
|
||||
batch_embeddings.extend(&embedding.0);
|
||||
|
||||
if batch_ids.len() == batch_n {
|
||||
let embeddings = std::mem::take(&mut batch_embeddings);
|
||||
let ids = std::mem::take(&mut batch_ids);
|
||||
let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
|
||||
match array {
|
||||
Ok(array) => {
|
||||
batches.push((ids, array));
|
||||
}
|
||||
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if batch_ids.len() > 0 {
|
||||
let array = Array2::from_shape_vec(
|
||||
(batch_ids.len(), embedding_len),
|
||||
batch_embeddings.clone(),
|
||||
);
|
||||
match array {
|
||||
Ok(array) => {
|
||||
batches.push((batch_ids.clone(), array));
|
||||
}
|
||||
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let mut ids: Vec<usize> = Vec::new();
|
||||
let mut results = Vec::new();
|
||||
for (batch_ids, array) in batches {
|
||||
let scores = array
|
||||
.dot(&query.t())
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|score| OrderedFloat(*score))
|
||||
.collect::<Vec<OrderedFloat<f32>>>();
|
||||
results.extend(scores);
|
||||
ids.extend(batch_ids);
|
||||
}
|
||||
|
||||
let sorted_idx = argsort(&results);
|
||||
let mut sorted_results = Vec::new();
|
||||
let last_idx = limit.min(sorted_idx.len());
|
||||
for idx in &sorted_idx[0..last_idx] {
|
||||
sorted_results.push((ids[*idx] as i64, results[*idx]))
|
||||
}
|
||||
|
||||
Ok(sorted_results)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn retrieve_included_file_ids(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
includes: &[PathMatcher],
|
||||
excludes: &[PathMatcher],
|
||||
) -> impl Future<Output = Result<Vec<i64>>> {
|
||||
let worktree_ids = worktree_ids.to_vec();
|
||||
let includes = includes.to_vec();
|
||||
let excludes = excludes.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut file_query = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, relative_path
|
||||
FROM
|
||||
files
|
||||
WHERE
|
||||
worktree_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let mut file_ids = Vec::<i64>::new();
|
||||
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
|
||||
|
||||
while let Some(row) = rows.next()? {
|
||||
let file_id = row.get(0)?;
|
||||
let relative_path = row.get_ref(1)?.as_str()?;
|
||||
let included =
|
||||
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
|
||||
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
|
||||
if included && !excluded {
|
||||
file_ids.push(file_id);
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(file_ids)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn spans_for_ids(
|
||||
&self,
|
||||
ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
|
||||
let ids = ids.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
spans.id,
|
||||
files.worktree_id,
|
||||
files.relative_path,
|
||||
spans.start_byte,
|
||||
spans.end_byte
|
||||
FROM
|
||||
spans, files
|
||||
WHERE
|
||||
spans.file_id = files.id AND
|
||||
spans.id in rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
|
||||
Ok((
|
||||
row.get::<_, i64>(0)?,
|
||||
row.get::<_, i64>(1)?,
|
||||
row.get::<_, String>(2)?.into(),
|
||||
row.get(3)?..row.get(4)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
|
||||
for row in result_iter {
|
||||
let (id, worktree_id, path, range) = row?;
|
||||
values_by_id.insert(id, (worktree_id, path, range));
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(ids.len());
|
||||
for id in &ids {
|
||||
let value = values_by_id
|
||||
.remove(id)
|
||||
.ok_or(anyhow!("missing span id {}", id))?;
|
||||
results.push(value);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps integer ids in the `Rc<Vec<Value>>` form that is bound to the
/// `rarray(?)` parameters used by the queries in this file.
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    let values: Vec<rusqlite::types::Value> = ids
        .iter()
        .map(|&id| rusqlite::types::Value::Integer(id))
        .collect();
    Rc::new(values)
}
|
|
@ -1,169 +0,0 @@
|
|||
use crate::{parsing::Span, JobHandle};
|
||||
use ai::embedding::EmbeddingProvider;
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
|
||||
|
||||
/// A file travelling through the embedding queue together with its spans.
///
/// `Debug` and `PartialEq` are implemented by hand below and both leave
/// `job_handle` out.
#[derive(Clone)]
pub struct FileToEmbed {
    /// Id of the worktree the file belongs to.
    pub worktree_id: i64,
    /// Path of the file.
    pub path: Arc<Path>,
    /// Modification time of the file.
    pub mtime: SystemTime,
    /// Spans of the file; `EmbeddingQueue` fills in any missing embeddings.
    pub spans: Vec<Span>,
    /// Handle carried along with the file; its semantics are defined by
    /// `JobHandle` elsewhere in the crate.
    pub job_handle: JobHandle,
}
|
||||
|
||||
impl std::fmt::Debug for FileToEmbed {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("FileToEmbed")
|
||||
.field("worktree_id", &self.worktree_id)
|
||||
.field("path", &self.path)
|
||||
.field("mtime", &self.mtime)
|
||||
.field("spans", &self.spans)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for FileToEmbed {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.worktree_id == other.worktree_id
|
||||
&& self.path == other.path
|
||||
&& self.mtime == other.mtime
|
||||
&& self.spans == other.spans
|
||||
}
|
||||
}
|
||||
|
||||
/// Groups file spans into token-bounded batches, embeds each batch on a
/// background executor and publishes completed files on a channel.
pub struct EmbeddingQueue {
    /// Provider used to embed batches; also supplies the per-batch token limit.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Fragments collected for the next batch.
    pending_batch: Vec<FileFragmentToEmbed>,
    /// Executor on which `flush` spawns the embedding work.
    executor: BackgroundExecutor,
    /// Sum of the token counts of the not-yet-embedded spans in `pending_batch`.
    pending_batch_token_count: usize,
    /// Sending half of the finished-files channel.
    finished_files_tx: channel::Sender<FileToEmbed>,
    /// Receiving half of the finished-files channel; cloned by `finished_files`.
    finished_files_rx: channel::Receiver<FileToEmbed>,
}
|
||||
|
||||
/// A contiguous run of spans of one file that belongs to a single batch.
#[derive(Clone)]
pub struct FileFragmentToEmbed {
    /// The file, shared between all fragments cut from it.
    file: Arc<Mutex<FileToEmbed>>,
    /// Index range into `file.spans` covered by this fragment.
    span_range: Range<usize>,
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
executor,
|
||||
pending_batch: Vec::new(),
|
||||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
finished_files_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Queues `file` for embedding, splitting its spans across batches so that
/// no batch exceeds the provider's token limit.
///
/// A file without spans is sent straight to the finished-files channel.
/// Otherwise a fragment covering a growing range of the file's spans is
/// appended to the pending batch. Whenever adding the next span would push
/// the pending token count above `max_tokens_per_batch()`, the pending batch
/// is flushed and a new fragment is started where the previous one ended.
/// Spans that already carry an embedding count as zero tokens.
///
/// Panics if sending an empty file to the finished-files channel fails.
pub fn push(&mut self, file: FileToEmbed) {
    if file.spans.is_empty() {
        self.finished_files_tx.try_send(file).unwrap();
        return;
    }

    let file = Arc::new(Mutex::new(file));

    // Start with an empty fragment; its end is advanced span by span below.
    self.pending_batch.push(FileFragmentToEmbed {
        file: file.clone(),
        span_range: 0..0,
    });

    let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
    for (ix, span) in file.lock().spans.iter().enumerate() {
        let span_token_count = if span.embedding.is_none() {
            span.token_count
        } else {
            0
        };

        let next_token_count = self.pending_batch_token_count + span_token_count;
        if next_token_count > self.embedding_provider.max_tokens_per_batch() {
            // Remember where the current fragment stops, hand the batch off,
            // and continue this file in a fresh fragment of the next batch.
            let range_end = fragment_range.end;
            self.flush();
            self.pending_batch.push(FileFragmentToEmbed {
                file: file.clone(),
                span_range: range_end..range_end,
            });
            fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
        }

        fragment_range.end = ix + 1;
        self.pending_batch_token_count += span_token_count;
    }
}
|
||||
|
||||
pub fn flush(&mut self) {
|
||||
let batch = mem::take(&mut self.pending_batch);
|
||||
self.pending_batch_token_count = 0;
|
||||
if batch.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let mut spans = Vec::new();
|
||||
for fragment in &batch {
|
||||
let file = fragment.file.lock();
|
||||
spans.extend(
|
||||
file.spans[fragment.span_range.clone()]
|
||||
.iter()
|
||||
.filter(|d| d.embedding.is_none())
|
||||
.map(|d| d.content.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
// If spans is 0, just send the fragment to the finished files if its the last one.
|
||||
if spans.is_empty() {
|
||||
for fragment in batch.clone() {
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
|
||||
.iter_mut()
|
||||
.filter(|d| d.embedding.is_none())
|
||||
{
|
||||
if let Some(embedding) = embeddings.next() {
|
||||
span.embedding = Some(embedding);
|
||||
} else {
|
||||
log::error!("number of embeddings != number of documents");
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("{:?}", error);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
/// Returns a clone of the receiver on which finished files are delivered.
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
    self.finished_files_rx.clone()
}
|
||||
}
|
|
@ -1,414 +0,0 @@
|
|||
use ai::{
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::TruncationDirection,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashSet;
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
|
||||
ToSql,
|
||||
};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::{self, Reverse},
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
/// 20-byte digest identifying a span's content; the `From<&str>` impl below
/// produces it with SHA-1.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
|
||||
|
||||
impl FromSql for SpanDigest {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let blob = value.as_blob()?;
|
||||
let bytes =
|
||||
blob.try_into()
|
||||
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
|
||||
expected_size: 20,
|
||||
blob_size: blob.len(),
|
||||
})?;
|
||||
return Ok(SpanDigest(bytes));
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for SpanDigest {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
self.0.to_sql()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&'_ str> for SpanDigest {
|
||||
fn from(value: &'_ str) -> Self {
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(value);
|
||||
Self(sha1.finalize().into())
|
||||
}
|
||||
}
|
||||
|
||||
/// A piece of a file prepared for embedding.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    /// Name stored with the span; `parse_entire_file` uses the language name.
    pub name: String,
    /// Byte range of the span in the original file content.
    pub range: Range<usize>,
    /// Text that is embedded (the rendered template, possibly truncated).
    pub content: String,
    /// The embedding, once it has been computed or loaded.
    pub embedding: Option<Embedding>,
    /// Digest of the rendered text.
    pub digest: SpanDigest,
    /// Token count of `content` according to the embedding model.
    pub token_count: usize,
}
|
||||
|
||||
// Prompt templates. `<path>`, `<language>` and `<item>` are placeholders that
// are filled in with `str::replace` (see `parse_entire_file`).

/// Template with `<path>`, `<language>` and `<item>` placeholders for a code snippet.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template used by `parse_entire_file` to wrap a whole file.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template with `<path>` and `<item>` placeholders for file contents without a code fence.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
/// Names of file types listed as parseable as an entire file.
// NOTE(review): presumably compared against language names by the caller of
// `parse_entire_file`; that code is outside this view.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];
|
||||
|
||||
/// Turns file contents into `Span`s ready for embedding.
pub struct CodeContextRetriever {
    /// Tree-sitter parser owned by this retriever.
    pub parser: Parser,
    /// Tree-sitter query cursor owned by this retriever.
    pub cursor: QueryCursor,
    /// Provider whose base model is used for truncation and token counting.
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
|
||||
|
||||
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
/// The captures collected for one tree-sitter query match, as described above.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Start column recorded for the match.
    pub start_col: usize,
    /// Range of the `item` capture, if any.
    pub item_range: Option<Range<usize>>,
    /// Range of the `name` capture, if any.
    pub name_range: Option<Range<usize>>,
    /// Ranges of the `context` captures.
    pub context_ranges: Vec<Range<usize>>,
    /// Ranges of the `collapse` captures.
    pub collapse_ranges: Vec<Range<usize>>,
}
|
||||
|
||||
impl CodeContextRetriever {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
|
||||
Self {
|
||||
parser: Parser::new(),
|
||||
cursor: QueryCursor::new(),
|
||||
embedding_provider,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_entire_file(
|
||||
&self,
|
||||
relative_path: Option<&Path>,
|
||||
language_name: Arc<str>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = ENTIRE_FILE_TEMPLATE
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: Default::default(),
|
||||
name: language_name.to_string(),
|
||||
digest,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
||||
fn parse_markdown_file(
|
||||
&self,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = MARKDOWN_CONTEXT_TEMPLATE
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: None,
|
||||
name: "Markdown".to_string(),
|
||||
digest,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
||||
fn get_matches_in_file(
|
||||
&mut self,
|
||||
content: &str,
|
||||
grammar: &Arc<Grammar>,
|
||||
) -> Result<Vec<CodeContextMatch>> {
|
||||
let embedding_config = grammar
|
||||
.embedding_config
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no embedding queries"))?;
|
||||
self.parser.set_language(&grammar.ts_language).unwrap();
|
||||
|
||||
let tree = self
|
||||
.parser
|
||||
.parse(&content, None)
|
||||
.ok_or_else(|| anyhow!("parsing failed"))?;
|
||||
|
||||
let mut captures: Vec<CodeContextMatch> = Vec::new();
|
||||
let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
|
||||
let mut keep_ranges: Vec<Range<usize>> = Vec::new();
|
||||
for mat in self.cursor.matches(
|
||||
&embedding_config.query,
|
||||
tree.root_node(),
|
||||
content.as_bytes(),
|
||||
) {
|
||||
let mut start_col = 0;
|
||||
let mut item_range: Option<Range<usize>> = None;
|
||||
let mut name_range: Option<Range<usize>> = None;
|
||||
let mut context_ranges: Vec<Range<usize>> = Vec::new();
|
||||
collapse_ranges.clear();
|
||||
keep_ranges.clear();
|
||||
for capture in mat.captures {
|
||||
if capture.index == embedding_config.item_capture_ix {
|
||||
item_range = Some(capture.node.byte_range());
|
||||
start_col = capture.node.start_position().column;
|
||||
} else if Some(capture.index) == embedding_config.name_capture_ix {
|
||||
name_range = Some(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.context_capture_ix {
|
||||
context_ranges.push(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.collapse_capture_ix {
|
||||
collapse_ranges.push(capture.node.byte_range());
|
||||
} else if Some(capture.index) == embedding_config.keep_capture_ix {
|
||||
keep_ranges.push(capture.node.byte_range());
|
||||
}
|
||||
}
|
||||
|
||||
captures.push(CodeContextMatch {
|
||||
start_col,
|
||||
item_range,
|
||||
name_range,
|
||||
context_ranges,
|
||||
collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
|
||||
});
|
||||
}
|
||||
Ok(captures)
|
||||
}
|
||||
|
||||
pub fn parse_file_with_template(
|
||||
&mut self,
|
||||
relative_path: Option<&Path>,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Span>> {
|
||||
let language_name = language.name();
|
||||
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||
return self.parse_entire_file(relative_path, language_name, &content);
|
||||
} else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
|
||||
return self.parse_markdown_file(relative_path, &content);
|
||||
}
|
||||
|
||||
let mut spans = self.parse_file(content, language)?;
|
||||
for span in &mut spans {
|
||||
let document_content = CODE_CONTEXT_TEMPLATE
|
||||
.replace(
|
||||
"<path>",
|
||||
&relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
|
||||
)
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_content = model.truncate(
|
||||
&document_content,
|
||||
model.capacity()?,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_content)?;
|
||||
|
||||
span.content = document_content;
|
||||
span.token_count = token_count;
|
||||
}
|
||||
Ok(spans)
|
||||
}
|
||||
|
||||
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
|
||||
let grammar = language
|
||||
.grammar()
|
||||
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
||||
|
||||
// Iterate through query matches
|
||||
let matches = self.get_matches_in_file(content, grammar)?;
|
||||
|
||||
let language_scope = language.default_scope();
|
||||
let placeholder = language_scope.collapsed_placeholder();
|
||||
|
||||
let mut spans = Vec::new();
|
||||
let mut collapsed_ranges_within = Vec::new();
|
||||
let mut parsed_name_ranges = HashSet::default();
|
||||
for (i, context_match) in matches.iter().enumerate() {
|
||||
// Items which are collapsible but not embeddable have no item range
|
||||
let item_range = if let Some(item_range) = context_match.item_range.clone() {
|
||||
item_range
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Checks for deduplication
|
||||
let name;
|
||||
if let Some(name_range) = context_match.name_range.clone() {
|
||||
name = content
|
||||
.get(name_range.clone())
|
||||
.map_or(String::new(), |s| s.to_string());
|
||||
if parsed_name_ranges.contains(&name_range) {
|
||||
continue;
|
||||
}
|
||||
parsed_name_ranges.insert(name_range);
|
||||
} else {
|
||||
name = String::new();
|
||||
}
|
||||
|
||||
collapsed_ranges_within.clear();
|
||||
'outer: for remaining_match in &matches[(i + 1)..] {
|
||||
for collapsed_range in &remaining_match.collapse_ranges {
|
||||
if item_range.start <= collapsed_range.start
|
||||
&& item_range.end >= collapsed_range.end
|
||||
{
|
||||
collapsed_ranges_within.push(collapsed_range.clone());
|
||||
} else {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
|
||||
|
||||
let mut span_content = String::new();
|
||||
for context_range in &context_match.context_ranges {
|
||||
add_content_from_range(
|
||||
&mut span_content,
|
||||
content,
|
||||
context_range.clone(),
|
||||
context_match.start_col,
|
||||
);
|
||||
span_content.push_str("\n");
|
||||
}
|
||||
|
||||
let mut offset = item_range.start;
|
||||
for collapsed_range in &collapsed_ranges_within {
|
||||
if collapsed_range.start > offset {
|
||||
add_content_from_range(
|
||||
&mut span_content,
|
||||
content,
|
||||
offset..collapsed_range.start,
|
||||
context_match.start_col,
|
||||
);
|
||||
offset = collapsed_range.start;
|
||||
}
|
||||
|
||||
if collapsed_range.end > offset {
|
||||
span_content.push_str(placeholder);
|
||||
offset = collapsed_range.end;
|
||||
}
|
||||
}
|
||||
|
||||
if offset < item_range.end {
|
||||
add_content_from_range(
|
||||
&mut span_content,
|
||||
content,
|
||||
offset..item_range.end,
|
||||
context_match.start_col,
|
||||
);
|
||||
}
|
||||
|
||||
let sha1 = SpanDigest::from(span_content.as_str());
|
||||
spans.push(Span {
|
||||
name,
|
||||
content: span_content,
|
||||
range: item_range.clone(),
|
||||
embedding: None,
|
||||
digest: sha1,
|
||||
token_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
return Ok(spans);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the parts of `ranges` that are not covered by `ranges_to_subtract`.
///
/// Both slices are walked once, front to back, so each is expected to be
/// sorted by start and non-overlapping. Subtract ranges that end at or before
/// the current position are skipped; previously such a range moved the
/// position backwards and produced output outside the input range.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            // Copy the bounds out so the peeked borrow does not outlive this line.
            match ranges_to_subtract.peek().map(|r| (r.start, r.end)) {
                // Entirely behind the current position: nothing left to remove.
                Some((_, subtract_end)) if subtract_end <= offset => {
                    ranges_to_subtract.next();
                }
                // Gap before the next subtract range: keep it.
                Some((subtract_start, _)) if offset < subtract_start => {
                    let next_offset = cmp::min(subtract_start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                }
                // Currently inside a subtract range: skip past it.
                Some((_, subtract_end)) => {
                    offset = cmp::min(subtract_end, range.end);
                    if offset >= subtract_end {
                        ranges_to_subtract.next();
                    }
                }
                None => {
                    result.push(offset..range.end);
                    offset = range.end;
                }
            }
        }
    }

    result
}
|
||||
|
||||
/// Appends `content[range]` to `output`, removing up to `start_col` leading
/// spaces from every line and joining the lines with `\n` (no trailing newline).
///
/// An out-of-bounds or non-char-boundary `range` is treated as empty. An empty
/// slice leaves `output` untouched; previously the trailing `pop()` removed a
/// character that belonged to earlier output.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let slice = content.get(range).unwrap_or("");
    for (index, line) in slice.lines().enumerate() {
        if index > 0 {
            output.push('\n');
        }
        // Spaces are single-byte, so `indent` is always a valid char boundary.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|byte| *byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,33 +0,0 @@
|
|||
use anyhow;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
||||
/// Resolved semantic index settings, registered under the `semantic_index` key.
#[derive(Deserialize, Debug)]
|
||||
pub struct SemanticIndexSettings {
|
||||
/// Whether the semantic index is enabled.
pub enabled: bool,
|
||||
}
|
||||
|
||||
// File-level representation of the settings: the field is optional, presumably so a
// partial user settings file can be merged over the defaults (see `Settings::load`).
/// Configuration of semantic index, an alternate search engine available in
|
||||
/// project search.
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct SemanticIndexSettingsContent {
|
||||
/// Whether or not to display the Semantic mode in project search.
|
||||
///
|
||||
/// Default: true
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
// Registers `SemanticIndexSettings` with the settings system.
impl Settings for SemanticIndexSettings {
|
||||
// Top-level key in the settings file that holds these settings.
const KEY: Option<&'static str> = Some("semantic_index");
|
||||
|
||||
// Shape of the settings as written in a settings file.
type FileContent = SemanticIndexSettingsContent;
|
||||
|
||||
// Builds the resolved settings by JSON-merging the user values over the defaults;
// the app context argument is unused.
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Self::load_via_json_merge(default_value, user_values)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -479,7 +479,28 @@ impl SettingsStore {
|
|||
merge_schema(target_schema, setting_schema.schema);
|
||||
}
|
||||
|
||||
fn merge_schema(target: &mut SchemaObject, source: SchemaObject) {
|
||||
fn merge_schema(target: &mut SchemaObject, mut source: SchemaObject) {
|
||||
let source_subschemas = source.subschemas();
|
||||
let target_subschemas = target.subschemas();
|
||||
if let Some(all_of) = source_subschemas.all_of.take() {
|
||||
target_subschemas
|
||||
.all_of
|
||||
.get_or_insert(Vec::new())
|
||||
.extend(all_of);
|
||||
}
|
||||
if let Some(any_of) = source_subschemas.any_of.take() {
|
||||
target_subschemas
|
||||
.any_of
|
||||
.get_or_insert(Vec::new())
|
||||
.extend(any_of);
|
||||
}
|
||||
if let Some(one_of) = source_subschemas.one_of.take() {
|
||||
target_subschemas
|
||||
.one_of
|
||||
.get_or_insert(Vec::new())
|
||||
.extend(one_of);
|
||||
}
|
||||
|
||||
if let Some(source) = source.object {
|
||||
let target_properties = &mut target.object().properties;
|
||||
for (key, value) in source.properties {
|
||||
|
|
|
@ -5,9 +5,8 @@ use futures_lite::FutureExt;
|
|||
use isahc::config::{Configurable, RedirectPolicy};
|
||||
pub use isahc::{
|
||||
http::{Method, StatusCode, Uri},
|
||||
Error,
|
||||
AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
|
||||
};
|
||||
pub use isahc::{AsyncBody, Request, Response};
|
||||
#[cfg(feature = "test-support")]
|
||||
use std::fmt;
|
||||
use std::{
|
||||
|
|
|
@ -71,7 +71,6 @@ recent_projects.workspace = true
|
|||
release_channel.workspace = true
|
||||
rope.workspace = true
|
||||
search.workspace = true
|
||||
semantic_index.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
|
|
@ -174,7 +174,7 @@ fn main() {
|
|||
node_runtime.clone(),
|
||||
cx,
|
||||
);
|
||||
assistant::init(cx);
|
||||
assistant::init(client.clone(), cx);
|
||||
|
||||
extension::init(
|
||||
fs.clone(),
|
||||
|
@ -247,7 +247,6 @@ fn main() {
|
|||
tasks_ui::init(cx);
|
||||
channel::init(&client, user_store.clone(), cx);
|
||||
search::init(cx);
|
||||
semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
|
||||
vim::init(cx);
|
||||
terminal_view::init(cx);
|
||||
|
||||
|
|
|
@ -3060,7 +3060,7 @@ mod tests {
|
|||
collab_ui::init(&app_state, cx);
|
||||
project_panel::init((), cx);
|
||||
terminal_view::init(cx);
|
||||
assistant::init(cx);
|
||||
assistant::init(app_state.client.clone(), cx);
|
||||
initialize_workspace(app_state.clone(), cx);
|
||||
app_state
|
||||
})
|
||||
|
|
|
@ -606,28 +606,6 @@ These values take in the same options as the root-level settings with the same n
|
|||
|
||||
`boolean` values
|
||||
|
||||
## Semantic Index
|
||||
|
||||
- Description: Settings related to semantic index.
|
||||
- Setting: `semantic_index`
|
||||
- Default:
|
||||
|
||||
```json
|
||||
"semantic_index": {
|
||||
"enabled": false
|
||||
},
|
||||
```
|
||||
|
||||
### Enabled
|
||||
|
||||
- Description: Whether or not to display the `Semantic` mode in project search.
|
||||
- Setting: `enabled`
|
||||
- Default: `true`
|
||||
|
||||
**Options**
|
||||
|
||||
`boolean` values
|
||||
|
||||
## Show Call Status Icon
|
||||
|
||||
- Description: Whether or not to show the call status icon in the status bar.
|
||||
|
|
|
@ -11,3 +11,8 @@ cargo run -p collab -- migrate
|
|||
|
||||
echo "seeding database..."
|
||||
script/seed-db
|
||||
|
||||
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
|
||||
echo "Linux dependencies..."
|
||||
script/linux
|
||||
fi
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release
|
91
script/gemini.py
Normal file
91
script/gemini.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
import subprocess
|
||||
import json
|
||||
import http.client
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
def get_text_files():
    """Return the Git-tracked files that ``file`` classifies as text.

    Runs ``git ls-files`` in the current directory and asks ``file`` for the
    MIME description of every tracked path. A path is kept when that
    description contains ``text``. Prints the number of files kept.
    """
    # -z gives NUL-separated, unquoted paths, so unusual file names survive
    # intact; filtering empty entries also covers an empty repository.
    git_files_proc = subprocess.run(
        ['git', 'ls-files', '-z'], stdout=subprocess.PIPE, text=True
    )
    tracked_files = [path for path in git_files_proc.stdout.split('\0') if path]

    text_files = []
    for file in tracked_files:
        # --brief omits the file name from the output; without it a binary
        # file such as "context.png" matched because its *name* contains "text".
        # "--" stops a path that starts with "-" being parsed as an option.
        mime_check_proc = subprocess.run(
            ['file', '--brief', '--mime', '--', file],
            stdout=subprocess.PIPE,
            text=True,
        )
        if 'text' in mime_check_proc.stdout:
            text_files.append(file)

    print(f"File count: {len(text_files)}")

    return text_files
|
||||
|
||||
def get_file_contents(file):
    """Return the text of *file*, decoded as UTF-8.

    Bytes that are not valid UTF-8 are replaced with U+FFFD instead of raising,
    so one oddly encoded file cannot abort the whole run. The encoding is
    explicit so the result does not depend on the platform locale.
    """
    with open(file, 'r', encoding='utf-8', errors='replace') as f:
        return f.read()
|
||||
|
||||
|
||||
def main():
    """Send every tracked text file to Gemini with a fixed documentation prompt.

    Reads the API key from the ``GEMINI_API_KEY`` environment variable, builds
    one JSON payload containing the prompt and all text files, prints the
    payload size and the token count reported by the ``countTokens`` endpoint,
    then streams the ``streamGenerateContent`` response to stdout.

    Raises:
        SystemExit: if ``GEMINI_API_KEY`` is not set.
    """
    import codecs  # local import: only needed to decode the streamed response

    GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
    if not GEMINI_API_KEY:
        # Without this check the requests would be sent with "key=None".
        raise SystemExit("GEMINI_API_KEY environment variable is not set")

    # Your prompt
    prompt = "Document the data types and dataflow in this codebase in preparation to port a streaming implementation to rust:\n\n"
    # One code block per tracked text file
    code_blocks = [
        f"\n`{file}`\n\n```{get_file_contents(file)}```\n"
        for file in get_text_files()
    ]

    # Encode once so the Content-Length header and the reported size both
    # describe the bytes that are actually sent.
    payload = json.dumps({
        "contents": [{
            "parts": [{
                "text": prompt + "".join(code_blocks)
            }]
        }]
    }).encode('utf-8')

    headers = {
        'Content-Type': 'application/json',
        'Content-Length': str(len(payload))
    }

    print(f"Content Length in kilobytes: {len(payload) / 1024:.2f} KB")

    # Count the tokens
    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
    try:
        conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:countTokens?key={GEMINI_API_KEY}", body=payload, headers=headers)
        response = conn.getresponse()
        if response.status == 200:
            token_count = json.loads(response.read().decode('utf-8')).get('totalTokens')
            print(f"Token count: {token_count}")
        else:
            print(f"Failed to get token count. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
    finally:
        # The original never closed this first connection.
        conn.close()

    # Generate content; the request carries a JSON body, so it must be a POST.
    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
    try:
        conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={GEMINI_API_KEY}", body=payload, headers=headers)

        response = conn.getresponse()
        if response.status == 200:
            print("Successfully sent the data to the API.")
            # A 4096-byte chunk can end in the middle of a multi-byte character,
            # so decode incrementally. Chunks are printed without an added
            # newline because chunk boundaries carry no meaning.
            decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
            while chunk := response.read(4096):
                print(decoder.decode(chunk), end='', flush=True)
            print(decoder.decode(b'', final=True))
        else:
            print(f"Failed to send the data to the API. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
    finally:
        conn.close()


if __name__ == "__main__":
    main()
|
|
@ -1,4 +1,6 @@
|
|||
#!/usr/bin/bash -e
|
||||
#!/usr/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# if sudo is not installed, define an empty alias
|
||||
maysudo=$(command -v sudo || command -v doas || true)
|
||||
|
|
1
script/script.py
Normal file
1
script/script.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
|
@ -3,12 +3,15 @@
|
|||
set -e
|
||||
|
||||
# Install sqlx-cli if needed
|
||||
[[ "$(sqlx --version)" == "sqlx-cli 0.5.7" ]] || cargo install sqlx-cli --version 0.5.7
|
||||
if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
|
||||
echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
|
||||
cargo install sqlx-cli --version 0.5.7
|
||||
fi
|
||||
|
||||
cd crates/collab
|
||||
|
||||
# Export contents of .env.toml
|
||||
eval "$(cargo run --quiet --bin dotenv)"
|
||||
eval "$(cargo run --bin dotenv)"
|
||||
|
||||
# Run sqlx command
|
||||
sqlx $@
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue