/auto (#16696)
Add `/auto` behind a feature flag that's disabled for now, even for staff. We've decided on a different design for context inference, but there are parts of `/auto` that will be useful for that, so we want them in the code base even if they're unused for now.

Release Notes:

- N/A

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Commit 91ffa02e2c (parent 93a3e8bc94)
42 changed files with 2776 additions and 1054 deletions
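For orientation before the diff: `/auto` asks the active language model which of the supported slash commands (`/search` and `/file`) would add useful context for the user's prompt, then prepends those commands to the prompt text. The model is told to reply with one action per line, with no leading slashes. Below is a simplified, self-contained sketch of that translation step; it is modeled on the loop in `AutoCommand::run` and the validation in `commands_for_summaries` added in this PR, but the function name here is hypothetical and this is an illustration, not the PR's code.

```rust
/// Turn a model response such as "file src/lib.rs\nsearch error handling"
/// into slash commands, then re-append the original prompt so the user
/// doesn't have to retype it. Unsupported or malformed lines are skipped.
fn build_prompt_from_actions(model_response: &str, original_prompt: &str) -> String {
    // The commands the model is told about (mirrors SUPPORTED_SLASH_COMMANDS below).
    const SUPPORTED: &[&str] = &["search", "file"];

    let mut prompt = String::new();
    for line in model_response.lines() {
        // Every supported action takes an argument, so a valid line contains a space.
        if let Some(space) = line.find(' ') {
            let name = line[..space].trim();
            let arg = line[space..].trim();
            if !name.is_empty() && SUPPORTED.contains(&name) {
                // The model omits the slash; prepend it here.
                prompt.push('/');
                prompt.push_str(name);
                prompt.push(' ');
                prompt.push_str(arg);
                prompt.push('\n');
            }
        }
    }
    prompt.push('\n');
    prompt.push_str(original_prompt);
    prompt
}

fn main() {
    let response = "search where slash commands are registered\nfile src/assistant.rs";
    println!("{}", build_prompt_from_actions(response, "How do I add a slash command?"));
}
```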
Cargo.lock (generated): 27 changes
@@ -304,6 +304,9 @@ name = "arrayvec"
 version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
+dependencies = [
+ "serde",
+]
 
 [[package]]
 name = "as-raw-xcb-connection"

@@ -1709,6 +1712,19 @@ dependencies = [
 "profiling",
 ]
 
+[[package]]
+name = "blake3"
+version = "1.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7"
+dependencies = [
+ "arrayref",
+ "arrayvec",
+ "cc",
+ "cfg-if",
+ "constant_time_eq",
+]
+
 [[package]]
 name = "block"
 version = "0.1.6"

@@ -2752,6 +2768,12 @@ dependencies = [
 "tiny-keccak",
 ]
 
+[[package]]
+name = "constant_time_eq"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
+
 [[package]]
 name = "context_servers"
 version = "0.1.0"

@@ -4187,6 +4209,7 @@ dependencies = [
 name = "feature_flags"
 version = "0.1.0"
 dependencies = [
+ "futures 0.3.30",
 "gpui",
 ]
 

@@ -9814,10 +9837,13 @@ name = "semantic_index"
 version = "0.1.0"
 dependencies = [
 "anyhow",
+ "arrayvec",
+ "blake3",
 "client",
 "clock",
 "collections",
 "env_logger",
+ "feature_flags",
 "fs",
 "futures 0.3.30",
 "futures-batch",

@@ -9825,6 +9851,7 @@ dependencies = [
 "heed",
 "http_client",
 "language",
+ "language_model",
 "languages",
 "log",
 "open_ai",
@@ -309,6 +309,7 @@ aho-corasick = "1.1"
 alacritty_terminal = { git = "https://github.com/alacritty/alacritty", rev = "91d034ff8b53867143c005acfaa14609147c9a2c" }
 any_vec = "0.14"
 anyhow = "1.0.86"
+arrayvec = { version = "0.7.4", features = ["serde"] }
 ashpd = "0.9.1"
 async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
 async-dispatcher = "0.1"

@@ -325,6 +326,7 @@ bitflags = "2.6.0"
 blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
 blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
 blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
+blake3 = "1.5.3"
 cargo_metadata = "0.18"
 cargo_toml = "0.20"
 chrono = { version = "0.4", features = ["serde"] }
@@ -37,13 +37,13 @@ use language_model::{
 pub(crate) use model_selector::*;
 pub use prompts::PromptBuilder;
 use prompts::PromptLoadingParams;
-use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 use serde::{Deserialize, Serialize};
 use settings::{update_settings_file, Settings, SettingsStore};
 use slash_command::{
-    context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
-    file_command, now_command, project_command, prompt_command, search_command, symbols_command,
-    tab_command, terminal_command, workflow_command,
+    auto_command, context_server_command, default_command, diagnostics_command, docs_command,
+    fetch_command, file_command, now_command, project_command, prompt_command, search_command,
+    symbols_command, tab_command, terminal_command, workflow_command,
 };
 use std::path::PathBuf;
 use std::sync::Arc;

@@ -210,12 +210,13 @@ pub fn init(
             let client = client.clone();
             async move {
                 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
-                let semantic_index = SemanticIndex::new(
+                let semantic_index = SemanticDb::new(
                     paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                     Arc::new(embedding_provider),
                     &mut cx,
                 )
                 .await?;
+
                 cx.update(|cx| cx.set_global(semantic_index))
             }
         })

@@ -364,6 +365,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
 
 fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
     let slash_command_registry = SlashCommandRegistry::global(cx);
 
     slash_command_registry.register_command(file_command::FileSlashCommand, true);
     slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
     slash_command_registry.register_command(tab_command::TabSlashCommand, true);

@@ -382,6 +384,17 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
     }
     slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
 
+    cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
+        let slash_command_registry = slash_command_registry.clone();
+        move |is_enabled, _cx| {
+            if is_enabled {
+                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
+                slash_command_registry.register_command(auto_command::AutoCommand, true);
+            }
+        }
+    })
+    .detach();
+
     update_slash_commands_from_settings(cx);
     cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
         .detach();
@@ -4723,6 +4723,20 @@ impl Render for ContextEditorToolbarItem {
         let weak_self = cx.view().downgrade();
         let right_side = h_flex()
             .gap_2()
+            // TODO display this in a nicer way, once we have a design for it.
+            // .children({
+            //     let project = self
+            //         .workspace
+            //         .upgrade()
+            //         .map(|workspace| workspace.read(cx).project().downgrade());
+            //
+            //     let scan_items_remaining = cx.update_global(|db: &mut SemanticDb, cx| {
+            //         project.and_then(|project| db.remaining_summaries(&project, cx))
+            //     });
+
+            //     scan_items_remaining
+            //         .map(|remaining_items| format!("Files to scan: {}", remaining_items))
+            // })
             .child(
                 ModelSelector::new(
                     self.fs.clone(),
@@ -519,6 +519,7 @@ impl Settings for AssistantSettings {
                 &mut settings.default_model,
                 value.default_model.map(Into::into),
             );
+            // merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
         }
 
         Ok(settings)
@@ -19,6 +19,7 @@ use std::{
 use ui::ActiveTheme;
 use workspace::Workspace;
 
+pub mod auto_command;
 pub mod context_server_command;
 pub mod default_command;
 pub mod diagnostics_command;
crates/assistant/src/slash_command/auto_command.rs (new file): 360 lines
@@ -0,0 +1,360 @@
use super::create_label_for_command;
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Result};
use assistant_slash_command::ArgumentCompletion;
use feature_flags::FeatureFlag;
use futures::StreamExt;
use gpui::{AppContext, AsyncAppContext, Task, WeakView};
use language::{CodeLabel, LspAdapterDelegate};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, Role,
};
use semantic_index::{FileSummary, SemanticDb};
use smol::channel;
use std::sync::{atomic::AtomicBool, Arc};
use ui::{BorrowAppContext, WindowContext};
use util::ResultExt;
use workspace::Workspace;

pub struct AutoSlashCommandFeatureFlag;

impl FeatureFlag for AutoSlashCommandFeatureFlag {
    const NAME: &'static str = "auto-slash-command";
}

pub(crate) struct AutoCommand;

impl SlashCommand for AutoCommand {
    fn name(&self) -> String {
        "auto".into()
    }

    fn description(&self) -> String {
        "Automatically infer what context to add, based on your prompt".into()
    }

    fn menu_text(&self) -> String {
        "Automatically Infer Context".into()
    }

    fn label(&self, cx: &AppContext) -> CodeLabel {
        create_label_for_command("auto", &["--prompt"], cx)
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // There's no autocomplete for a prompt, since it's arbitrary text.
        // However, we can use this opportunity to kick off a drain of the backlog.
        // That way, it can hopefully be done resummarizing by the time we've actually
        // typed out our prompt. This re-runs on every keystroke during autocomplete,
        // but in the future, we could instead do it only once, when /auto is first entered.
        let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
            log::warn!("workspace was dropped or unavailable during /auto autocomplete");

            return Task::ready(Ok(Vec::new()));
        };

        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
        };

        let cx: &mut AppContext = cx;

        cx.spawn(|cx: gpui::AsyncAppContext| async move {
            let task = project_index.read_with(&cx, |project_index, cx| {
                project_index.flush_summary_backlogs(cx)
            })?;

            cx.background_executor().spawn(task).await;

            anyhow::Ok(Vec::new())
        })
    }

    fn requires_argument(&self) -> bool {
        true
    }

    fn run(
        self: Arc<Self>,
        arguments: &[String],
        workspace: WeakView<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<Result<SlashCommandOutput>> {
        let Some(workspace) = workspace.upgrade() else {
            return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
        };
        if arguments.is_empty() {
            return Task::ready(Err(anyhow!("missing prompt")));
        };
        let argument = arguments.join(" ");
        let original_prompt = argument.to_string();
        let project = workspace.read(cx).project().clone();
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow!("no project indexer")));
        };

        let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
            let summaries = project_index
                .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
                .await?;

            commands_for_summaries(&summaries, &original_prompt, &cx).await
        });

        // As a convenience, append /auto's argument to the end of the prompt
        // so you don't have to write it again.
        let original_prompt = argument.to_string();

        cx.background_executor().spawn(async move {
            let commands = task.await?;
            let mut prompt = String::new();

            log::info!(
                "Translating this response into slash-commands: {:?}",
                commands
            );

            for command in commands {
                prompt.push('/');
                prompt.push_str(&command.name);
                prompt.push(' ');
                prompt.push_str(&command.arg);
                prompt.push('\n');
            }

            prompt.push('\n');
            prompt.push_str(&original_prompt);

            Ok(SlashCommandOutput {
                text: prompt,
                sections: Vec::new(),
                run_commands_in_text: true,
            })
        })
    }
}

const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");

fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
    let json_summaries = serde_json::to_string(summaries).unwrap();

    format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
}

/// The slash commands that the model is told about, and which we look for in the inference response.
const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];

#[derive(Debug, Clone)]
struct CommandToRun {
    name: String,
    arg: String,
}

/// Given the pre-indexed file summaries for this project, as well as the original prompt
/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
///
/// The prompt's output does not include the slashes (to reduce the chance that it makes a mistake),
/// so taking one of these returned Strings and turning it into a real slash-command-with-argument
/// involves prepending a slash to it.
///
/// This function will validate that each of the returned lines begins with one of SUPPORTED_SLASH_COMMANDS.
/// Any other lines it encounters will be discarded, with a warning logged.
async fn commands_for_summaries(
    summaries: &[FileSummary],
    original_prompt: &str,
    cx: &AsyncAppContext,
) -> Result<Vec<CommandToRun>> {
    if summaries.is_empty() {
        log::warn!("Inferring no context because there were no summaries available.");
        return Ok(Vec::new());
    }

    // Use the globally configured model to translate the summaries into slash-commands,
    // because Qwen2-7B-Instruct has not done a good job at that task.
    let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
        log::warn!("Can't infer context because there's no active model.");
        return Ok(Vec::new());
    };
    // Only go up to 90% of the actual max token count, to reduce chances of
    // exceeding the token count due to inaccuracies in the token counting heuristic.
    let max_token_count = (model.max_token_count() * 9) / 10;

    // Rather than recursing (which would require this async function use a pinned box),
    // we use an explicit stack of arguments and answers for when we need to "recurse."
    let mut stack = vec![summaries];
    let mut final_response = Vec::new();
    let mut prompts = Vec::new();

    // TODO We only need to create multiple Requests because we currently
    // don't have the ability to tell if a CompletionProvider::complete response
    // was a "too many tokens in this request" error. If we had that, then
    // we could try the request once, instead of having to make separate requests
    // to check the token count and then afterwards to run the actual prompt.
    let make_request = |prompt: String| LanguageModelRequest {
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            // Nothing in here will benefit from caching
            cache: false,
        }],
        tools: Vec::new(),
        stop: Vec::new(),
        temperature: 1.0,
    };

    while let Some(current_summaries) = stack.pop() {
        // The split can result in one slice being empty and the other having one element.
        // Whenever that happens, skip the empty one.
        if current_summaries.is_empty() {
            continue;
        }

        log::info!(
            "Inferring prompt context using {} file summaries",
            current_summaries.len()
        );

        let prompt = summaries_prompt(&current_summaries, original_prompt);
        let start = std::time::Instant::now();
        // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
        // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
        // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
        // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
        let token_estimate = prompt.len() * 2 / 9;
        let duration = start.elapsed();
        log::info!(
            "Time taken to count tokens for prompt of length {:?}B: {:?}",
            prompt.len(),
            duration
        );

        if token_estimate < max_token_count {
            prompts.push(prompt);
        } else if current_summaries.len() == 1 {
            log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
        } else {
            log::info!(
                "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
            );
            let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
            stack.push(right);
            stack.push(left);
        }
    }

    let all_start = std::time::Instant::now();

    let (tx, rx) = channel::bounded(1024);

    let completion_streams = prompts
        .into_iter()
        .map(|prompt| {
            let request = make_request(prompt.clone());
            let model = model.clone();
            let tx = tx.clone();
            let stream = model.stream_completion(request, &cx);

            (stream, tx)
        })
        .collect::<Vec<_>>();

    cx.background_executor()
        .spawn(async move {
            let futures = completion_streams
                .into_iter()
                .enumerate()
                .map(|(ix, (stream, tx))| async move {
                    let start = std::time::Instant::now();
                    let events = stream.await?;
                    log::info!("Time taken for awaiting /await chunk stream #{ix}: {:?}", start.elapsed());

                    let completion: String = events
                        .filter_map(|event| async {
                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                                Some(text)
                            } else {
                                None
                            }
                        })
                        .collect()
                        .await;

                    log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());

                    for line in completion.split('\n') {
                        if let Some(first_space) = line.find(' ') {
                            let command = &line[..first_space].trim();
                            let arg = &line[first_space..].trim();

                            tx.send(CommandToRun {
                                name: command.to_string(),
                                arg: arg.to_string(),
                            })
                            .await?;
                        } else if !line.trim().is_empty() {
                            // All slash-commands currently supported in context inference need a space for the argument.
                            log::warn!(
                                "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
                                line
                            );
                        }
                    }

                    anyhow::Ok(())
                })
                .collect::<Vec<_>>();

            let _ = futures::future::try_join_all(futures).await.log_err();

            let duration = all_start.elapsed();
            eprintln!("All futures completed in {:?}", duration);
        })
        .await;

    drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
    let results = rx.collect::<Vec<_>>().await;
    eprintln!(
        "Finished collecting from the channel with {} results",
        results.len()
    );
    for command in results {
        // Don't return empty or duplicate commands
        if !command.name.is_empty()
            && !final_response
                .iter()
                .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
        {
            if SUPPORTED_SLASH_COMMANDS
                .iter()
                .any(|supported| &command.name == supported)
            {
                final_response.push(command);
            } else {
                log::warn!(
                    "Context inference returned an unrecognized slash command: {:?}",
                    command
                );
            }
        }
    }

    // Sort the commands by name (reversed just so that /search appears before /file)
    final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());

    Ok(final_response)
}
crates/assistant/src/slash_command/prompt_after_summary.txt (new file): 24 lines
@@ -0,0 +1,24 @@
Actions have a cost, so only include actions that you think
will be helpful to you in doing a great job answering the
prompt in the future.

You must respond ONLY with a list of actions you would like to
perform. Each action should be on its own line, and followed by a space and then its parameter.

Actions can be performed more than once with different parameters.
Here is an example valid response:

```
file path/to/my/file.txt
file path/to/another/file.txt
search something to search for
search something else to search for
```

Once again, do not forget: you must respond ONLY in the format of
one action per line, and the action name should be followed by
its parameter. Your response must not include anything other
than a list of actions, with one action per line, in this format.
It is extremely important that you do not deviate from this format even slightly!

This is the end of my instructions for how to respond. The rest is the prompt:
crates/assistant/src/slash_command/prompt_before_summary.txt (new file): 31 lines
@@ -0,0 +1,31 @@
I'm going to give you a prompt. I don't want you to respond
to the prompt itself. I want you to figure out which of the following
actions on my project, if any, would help you answer the prompt.

Here are the actions:

## file

This action's parameter is a file path to one of the files
in the project. If you ask for this action, I will tell you
the full contents of the file, so you can learn all the
details of the file.

## search

This action's parameter is a string to do a semantic search for
across the files in the project. (You will have a JSON summary
of all the files in the project.) It will tell you which files this string
(or similar strings; it is a semantic search) appear in,
as well as some context of the lines surrounding each result.
It's very important that you only use this action when you think
that searching across the specific files in this project for the query
in question will be useful. For example, don't use this command to search
for queries you might put into a general Web search engine, because those
will be too general to give useful results in this project-specific search.

---

That was the end of the list of actions.

Here is a JSON summary of each of the files in my project:
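For illustration, these two prompt files bracket the request that `/auto` sends to the model: the opening instructions come first, then a JSON dump of the file summaries, then the closing instructions, then the user's prompt. The sketch below mirrors `summaries_prompt` in `auto_command.rs` above; the `FileSummary` struct here is a simplified stand-in for the real type in `semantic_index` (its exact fields are an assumption), and the prompt fragments are passed in as parameters rather than read with `include_str!`.

```rust
use serde::Serialize;

/// Simplified stand-in for semantic_index::FileSummary (assumed shape, for illustration only).
#[derive(Serialize)]
struct FileSummary {
    filename: String,
    summary: String,
}

/// Assemble the full request body: before-instructions, JSON summaries,
/// after-instructions, then the user's original prompt.
fn summaries_prompt(
    before: &str,
    after: &str,
    summaries: &[FileSummary],
    original_prompt: &str,
) -> String {
    let json_summaries = serde_json::to_string(summaries).expect("summaries serialize to JSON");
    format!("{before}\n{json_summaries}\n{after}\n{original_prompt}")
}
```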
@@ -8,7 +8,7 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
 use feature_flags::FeatureFlag;
 use gpui::{AppContext, Task, WeakView};
 use language::{CodeLabel, LineEnding, LspAdapterDelegate};
-use semantic_index::SemanticIndex;
+use semantic_index::SemanticDb;
 use std::{
     fmt::Write,
     path::PathBuf,

@@ -92,8 +92,11 @@ impl SlashCommand for SearchSlashCommand {
 
         let project = workspace.read(cx).project().clone();
         let fs = project.read(cx).fs().clone();
-        let project_index =
-            cx.update_global(|index: &mut SemanticIndex, cx| index.project_index(project, cx));
+        let Some(project_index) =
+            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+        else {
+            return Task::ready(Err(anyhow::anyhow!("no project indexer")));
+        };
 
         cx.spawn(|cx| async move {
             let results = project_index
@@ -149,16 +149,16 @@ spec:
               secretKeyRef:
                 name: google-ai
                 key: api_key
-          - name: QWEN2_7B_API_KEY
+          - name: RUNPOD_API_KEY
             valueFrom:
               secretKeyRef:
-                name: hugging-face
+                name: runpod
                 key: api_key
-          - name: QWEN2_7B_API_URL
+          - name: RUNPOD_API_SUMMARY_URL
             valueFrom:
               secretKeyRef:
-                name: hugging-face
+                name: runpod
-                key: qwen2_api_url
+                key: summary
           - name: BLOB_STORE_ACCESS_KEY
             valueFrom:
               secretKeyRef:
@@ -728,6 +728,11 @@ impl Database {
                     is_ignored: db_entry.is_ignored,
                     is_external: db_entry.is_external,
                     git_status: db_entry.git_status.map(|status| status as i32),
+                    // This is only used in the summarization backlog, so if it's None,
+                    // that just means we won't be able to detect when to resummarize
+                    // based on total number of backlogged bytes - instead, we'd go
+                    // on number of files only. That shouldn't be a huge deal in practice.
+                    size: None,
                     is_fifo: db_entry.is_fifo,
                 });
             }
@@ -663,6 +663,11 @@ impl Database {
                     is_ignored: db_entry.is_ignored,
                     is_external: db_entry.is_external,
                     git_status: db_entry.git_status.map(|status| status as i32),
+                    // This is only used in the summarization backlog, so if it's None,
+                    // that just means we won't be able to detect when to resummarize
+                    // based on total number of backlogged bytes - instead, we'd go
+                    // on number of files only. That shouldn't be a huge deal in practice.
+                    size: None,
                     is_fifo: db_entry.is_fifo,
                 });
             }
@@ -170,8 +170,8 @@ pub struct Config {
     pub anthropic_api_key: Option<Arc<str>>,
     pub anthropic_staff_api_key: Option<Arc<str>>,
     pub llm_closed_beta_model_name: Option<Arc<str>>,
-    pub qwen2_7b_api_key: Option<Arc<str>>,
-    pub qwen2_7b_api_url: Option<Arc<str>>,
+    pub runpod_api_key: Option<Arc<str>>,
+    pub runpod_api_summary_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,

@@ -235,8 +235,8 @@ impl Config {
             stripe_api_key: None,
             stripe_price_id: None,
             supermaven_admin_api_key: None,
-            qwen2_7b_api_key: None,
-            qwen2_7b_api_url: None,
+            runpod_api_key: None,
+            runpod_api_summary_url: None,
             user_backfiller_github_access_token: None,
         }
     }
 }
@@ -402,12 +402,12 @@ async fn perform_completion(
         LanguageModelProvider::Zed => {
             let api_key = state
                 .config
-                .qwen2_7b_api_key
+                .runpod_api_key
                 .as_ref()
                 .context("no Qwen2-7B API key configured on the server")?;
             let api_url = state
                 .config
-                .qwen2_7b_api_url
+                .runpod_api_summary_url
                 .as_ref()
                 .context("no Qwen2-7B URL configured on the server")?;
             let chunks = open_ai::stream_completion(
@@ -1,5 +1,5 @@
 use super::*;
-use sea_orm::QueryOrder;
+use sea_orm::{sea_query::OnConflict, QueryOrder};
 use std::str::FromStr;
 use strum::IntoEnumIterator as _;
 

@@ -99,6 +99,17 @@ impl LlmDatabase {
                     ..Default::default()
                 }
             }))
+            .on_conflict(
+                OnConflict::columns([model::Column::ProviderId, model::Column::Name])
+                    .update_columns([
+                        model::Column::MaxRequestsPerMinute,
+                        model::Column::MaxTokensPerMinute,
+                        model::Column::MaxTokensPerDay,
+                        model::Column::PricePerMillionInputTokens,
+                        model::Column::PricePerMillionOutputTokens,
+                    ])
+                    .to_owned(),
+            )
             .exec_without_returning(&*tx)
             .await?;
             Ok(())
@@ -40,6 +40,15 @@ pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool)
             price_per_million_input_tokens: 25,  // $0.25/MTok
             price_per_million_output_tokens: 125, // $1.25/MTok
         },
+        ModelParams {
+            provider: LanguageModelProvider::Zed,
+            name: "Qwen/Qwen2-7B-Instruct".into(),
+            max_requests_per_minute: 5,
+            max_tokens_per_minute: 25_000, // These are arbitrary limits we've set to cap costs; we control this number
+            max_tokens_per_day: 300_000,
+            price_per_million_input_tokens: 25,
+            price_per_million_output_tokens: 125,
+        },
     ])
     .await
 }
@@ -679,8 +679,8 @@ impl TestServer {
                 stripe_api_key: None,
                 stripe_price_id: None,
                 supermaven_admin_api_key: None,
-                qwen2_7b_api_key: None,
-                qwen2_7b_api_url: None,
+                runpod_api_key: None,
+                runpod_api_summary_url: None,
                 user_backfiller_github_access_token: None,
             },
         })
@@ -13,3 +13,4 @@ path = "src/feature_flags.rs"
 
 [dependencies]
 gpui.workspace = true
+futures.workspace = true
@@ -1,4 +1,10 @@
+use futures::{channel::oneshot, FutureExt as _};
 use gpui::{AppContext, Global, Subscription, ViewContext};
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
 
 #[derive(Default)]
 struct FeatureFlags {

@@ -53,6 +59,15 @@ impl FeatureFlag for ZedPro {
     const NAME: &'static str = "zed-pro";
 }
 
+pub struct AutoCommand {}
+impl FeatureFlag for AutoCommand {
+    const NAME: &'static str = "auto-command";
+
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
 pub trait FeatureFlagViewExt<V: 'static> {
     fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
     where

@@ -75,6 +90,7 @@ where
 }
 
 pub trait FeatureFlagAppExt {
+    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
     fn update_flags(&mut self, staff: bool, flags: Vec<String>);
     fn set_staff(&mut self, staff: bool);
     fn has_flag<T: FeatureFlag>(&self) -> bool;

@@ -82,7 +98,7 @@ pub trait FeatureFlagAppExt {
 
     fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
     where
-        F: Fn(bool, &mut AppContext) + 'static;
+        F: FnMut(bool, &mut AppContext) + 'static;
 }
 
 impl FeatureFlagAppExt for AppContext {

@@ -109,13 +125,49 @@ impl FeatureFlagAppExt for AppContext {
             .unwrap_or(false)
     }
 
-    fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
+    fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
     where
-        F: Fn(bool, &mut AppContext) + 'static,
+        F: FnMut(bool, &mut AppContext) + 'static,
     {
         self.observe_global::<FeatureFlags>(move |cx| {
             let feature_flags = cx.global::<FeatureFlags>();
             callback(feature_flags.has_flag::<T>(), cx);
         })
     }
+
+    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
+        let (tx, rx) = oneshot::channel::<bool>();
+        let mut tx = Some(tx);
+        let subscription: Option<Subscription>;
+
+        match self.try_global::<FeatureFlags>() {
+            Some(feature_flags) => {
+                subscription = None;
+                tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
+            }
+            None => {
+                subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
+                    let feature_flags = cx.global::<FeatureFlags>();
+                    if let Some(tx) = tx.take() {
+                        tx.send(feature_flags.has_flag::<T>()).ok();
+                    }
+                }));
+            }
+        }
+
+        WaitForFlag(rx, subscription)
+    }
+}
+
+pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
+
+impl Future for WaitForFlag {
+    type Output = bool;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.0.poll_unpin(cx).map(|result| {
+            self.1.take();
+            result.unwrap_or(false)
+        })
+    }
 }
@@ -171,6 +171,7 @@ pub struct Metadata {
     pub mtime: SystemTime,
     pub is_symlink: bool,
     pub is_dir: bool,
+    pub len: u64,
     pub is_fifo: bool,
 }
 

@@ -497,6 +498,7 @@ impl Fs for RealFs {
         Ok(Some(Metadata {
             inode,
             mtime: metadata.modified().unwrap(),
+            len: metadata.len(),
             is_symlink,
             is_dir: metadata.file_type().is_dir(),
             is_fifo,

@@ -800,11 +802,13 @@ enum FakeFsEntry {
     File {
         inode: u64,
         mtime: SystemTime,
+        len: u64,
         content: Vec<u8>,
     },
     Dir {
         inode: u64,
         mtime: SystemTime,
+        len: u64,
         entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
         git_repo_state: Option<Arc<Mutex<git::repository::FakeGitRepositoryState>>>,
     },

@@ -935,6 +939,7 @@ impl FakeFs {
             root: Arc::new(Mutex::new(FakeFsEntry::Dir {
                 inode: 0,
                 mtime: SystemTime::UNIX_EPOCH,
+                len: 0,
                 entries: Default::default(),
                 git_repo_state: None,
             })),

@@ -969,6 +974,7 @@ impl FakeFs {
                     inode: new_inode,
                     mtime: new_mtime,
                     content: Vec::new(),
+                    len: 0,
                 })));
             }
             btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {

@@ -1016,6 +1022,7 @@ impl FakeFs {
         let file = Arc::new(Mutex::new(FakeFsEntry::File {
             inode,
             mtime,
+            len: content.len() as u64,
             content,
         }));
         let mut kind = None;

@@ -1369,6 +1376,7 @@ impl Fs for FakeFs {
                 Arc::new(Mutex::new(FakeFsEntry::Dir {
                     inode,
                     mtime,
+                    len: 0,
                     entries: Default::default(),
                     git_repo_state: None,
                 }))

@@ -1391,6 +1399,7 @@ impl Fs for FakeFs {
         let file = Arc::new(Mutex::new(FakeFsEntry::File {
             inode,
             mtime,
+            len: 0,
             content: Vec::new(),
         }));
         let mut kind = Some(PathEventKind::Created);

@@ -1539,6 +1548,7 @@ impl Fs for FakeFs {
                 e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
                     inode,
                     mtime,
+                    len: content.len() as u64,
                     content: Vec::new(),
                 })))
                 .clone(),

@@ -1694,16 +1704,22 @@ impl Fs for FakeFs {
 
         let entry = entry.lock();
         Ok(Some(match &*entry {
-            FakeFsEntry::File { inode, mtime, .. } => Metadata {
+            FakeFsEntry::File {
+                inode, mtime, len, ..
+            } => Metadata {
                 inode: *inode,
                 mtime: *mtime,
+                len: *len,
                 is_dir: false,
                 is_symlink,
                 is_fifo: false,
             },
-            FakeFsEntry::Dir { inode, mtime, .. } => Metadata {
+            FakeFsEntry::Dir {
+                inode, mtime, len, ..
+            } => Metadata {
                 inode: *inode,
                 mtime: *mtime,
+                len: *len,
                 is_dir: true,
                 is_symlink,
                 is_fifo: false,
@@ -57,7 +57,6 @@ impl GitStatus {
             let stderr = String::from_utf8_lossy(&output.stderr);
             return Err(anyhow!("git status process failed: {}", stderr));
         }
-
         let stdout = String::from_utf8_lossy(&output.stdout);
         let mut entries = stdout
             .split('\0')
|
||||||
|
|
||||||
pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpClient> {
|
pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpClient> {
|
||||||
let mut builder = isahc::HttpClient::builder()
|
let mut builder = isahc::HttpClient::builder()
|
||||||
|
// Some requests to Qwen2 models on Runpod can take 32+ seconds,
|
||||||
|
// especially if there's a cold boot involved. We may need to have
|
||||||
|
// those requests use a different http client, because global timeouts
|
||||||
|
// of 50 and 60 seconds, respectively, would be very high!
|
||||||
.connect_timeout(Duration::from_secs(5))
|
.connect_timeout(Duration::from_secs(5))
|
||||||
.low_speed_timeout(100, Duration::from_secs(5))
|
.low_speed_timeout(100, Duration::from_secs(5))
|
||||||
.proxy(proxy.clone());
|
.proxy(proxy.clone());
|
||||||
|
|
|
@@ -17,14 +17,14 @@ pub enum CloudModel {
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
 pub enum ZedModel {
-    #[serde(rename = "qwen2-7b-instruct")]
+    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
     Qwen2_7bInstruct,
 }
 
 impl ZedModel {
     pub fn id(&self) -> &str {
         match self {
-            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
+            ZedModel::Qwen2_7bInstruct => "Qwen/Qwen2-7B-Instruct",
         }
     }
 
@@ -319,7 +319,7 @@ impl AnthropicModel {
         };
 
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Anthropic API Key"))?;
             let request = anthropic::stream_completion(
                 http_client.as_ref(),
                 &api_url,
@@ -265,7 +265,7 @@ impl LanguageModel for GoogleLanguageModel {
         let low_speed_timeout = settings.low_speed_timeout;
 
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
             let response = google_ai::count_tokens(
                 http_client.as_ref(),
                 &api_url,

@@ -304,7 +304,7 @@ impl LanguageModel for GoogleLanguageModel {
         };
 
         let future = self.rate_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
             let response = stream_generate_content(
                 http_client.as_ref(),
                 &api_url,
@@ -239,7 +239,7 @@ impl OpenAiLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
             let request = stream_completion(
                 http_client.as_ref(),
                 &api_url,
@@ -159,11 +159,13 @@ impl LanguageModelRegistry {
         providers
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
+    pub fn available_models<'a>(
+        &'a self,
+        cx: &'a AppContext,
+    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
         self.providers
             .values()
             .flat_map(|provider| provider.provided_models(cx))
-            .collect()
     }
 
     pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
@@ -1823,6 +1823,7 @@ impl ProjectPanel {
             path: entry.path.join("\0").into(),
             inode: 0,
             mtime: entry.mtime,
+            size: entry.size,
             is_ignored: entry.is_ignored,
             is_external: false,
             is_private: false,
@@ -1855,6 +1855,7 @@ message Entry {
     bool is_external = 8;
     optional GitStatus git_status = 9;
     bool is_fifo = 10;
+    optional uint64 size = 11;
 }
 
 message RepositoryEntry {
@@ -19,14 +19,18 @@ crate-type = ["bin"]
 
 [dependencies]
 anyhow.workspace = true
+arrayvec.workspace = true
+blake3.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
+feature_flags.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
 gpui.workspace = true
 language.workspace = true
+language_model.workspace = true
 log.workspace = true
 heed.workspace = true
 http_client.workspace = true
@@ -4,7 +4,7 @@ use gpui::App;
 use http_client::HttpClientWithUrl;
 use language::language_settings::AllLanguageSettings;
 use project::Project;
-use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
+use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticDb};
 use settings::SettingsStore;
 use std::{
     path::{Path, PathBuf},

@@ -50,7 +50,7 @@ fn main() {
     ));
 
     cx.spawn(|mut cx| async move {
-        let semantic_index = SemanticIndex::new(
+        let semantic_index = SemanticDb::new(
             PathBuf::from("/tmp/semantic-index-db.mdb"),
             embedding_provider,
             &mut cx,

@@ -71,6 +71,7 @@ fn main() {
 
         let project_index = cx
             .update(|cx| semantic_index.project_index(project.clone(), cx))
+            .unwrap()
             .unwrap();
 
         let (tx, rx) = oneshot::channel();
@@ -12,6 +12,12 @@ use futures::{future::BoxFuture, FutureExt};
 use serde::{Deserialize, Serialize};
 use std::{fmt, future};
 
+/// Trait for embedding providers. Texts in, vectors out.
+pub trait EmbeddingProvider: Sync + Send {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
+    fn batch_size(&self) -> usize;
+}
+
 #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
 pub struct Embedding(Vec<f32>);
 

@@ -68,12 +74,6 @@ impl fmt::Display for Embedding {
     }
 }
 
-/// Trait for embedding providers. Texts in, vectors out.
-pub trait EmbeddingProvider: Sync + Send {
-    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
-    fn batch_size(&self) -> usize;
-}
-
 #[derive(Debug)]
 pub struct TextToEmbed<'a> {
     pub text: &'a str,
crates/semantic_index/src/embedding_index.rs (new file): 469 lines
@@ -0,0 +1,469 @@
use crate::{
    chunking::{self, Chunk},
    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
    indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use log;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    path::Path,
    sync::Arc,
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;

pub struct EmbeddingIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    fs: Arc<dyn Fs>,
    language_registry: Arc<LanguageRegistry>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
}

impl EmbeddingIndex {
    pub fn new(
        worktree: Model<Worktree>,
        fs: Arc<dyn Fs>,
        db_connection: heed::Env,
        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        language_registry: Arc<LanguageRegistry>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
    ) -> Self {
        Self {
            worktree,
            fs,
            db_connection,
            db: embedding_db,
            language_registry,
            embedding_provider,
            entry_ids_being_indexed,
        }
    }

    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
        &self.db
    }

    pub fn index_entries_changed_on_disk(
        &self,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree, cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    pub fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
|
||||||
|
let db_connection = self.db_connection.clone();
|
||||||
|
let db = self.db;
|
||||||
|
let entries_being_indexed = self.entry_ids_being_indexed.clone();
|
||||||
|
let task = cx.background_executor().spawn(async move {
|
||||||
|
let txn = db_connection
|
||||||
|
.read_txn()
|
||||||
|
.context("failed to create read transaction")?;
|
||||||
|
let mut db_entries = db
|
||||||
|
.iter(&txn)
|
||||||
|
.context("failed to create iterator")?
|
||||||
|
.move_between_keys()
|
||||||
|
.peekable();
|
||||||
|
|
||||||
|
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
|
||||||
|
for entry in worktree.files(false, 0) {
|
||||||
|
log::trace!("scanning for embedding index: {:?}", &entry.path);
|
||||||
|
|
||||||
|
let entry_db_key = db_key_for_path(&entry.path);
|
||||||
|
|
||||||
|
let mut saved_mtime = None;
|
||||||
|
while let Some(db_entry) = db_entries.peek() {
|
||||||
|
match db_entry {
|
||||||
|
Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
|
||||||
|
Ordering::Less => {
|
||||||
|
if let Some(deletion_range) = deletion_range.as_mut() {
|
||||||
|
deletion_range.1 = Bound::Included(db_path);
|
||||||
|
} else {
|
||||||
|
deletion_range =
|
||||||
|
Some((Bound::Included(db_path), Bound::Included(db_path)));
|
||||||
|
}
|
||||||
|
|
||||||
|
db_entries.next();
|
||||||
|
}
|
||||||
|
Ordering::Equal => {
|
||||||
|
if let Some(deletion_range) = deletion_range.take() {
|
||||||
|
deleted_entry_ranges_tx
|
||||||
|
.send((
|
||||||
|
deletion_range.0.map(ToString::to_string),
|
||||||
|
deletion_range.1.map(ToString::to_string),
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
saved_mtime = db_embedded_file.mtime;
|
||||||
|
db_entries.next();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Ordering::Greater => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.mtime != saved_mtime {
|
||||||
|
let handle = entries_being_indexed.insert(entry.id);
|
||||||
|
updated_entries_tx.send((entry.clone(), handle)).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(db_entry) = db_entries.next() {
|
||||||
|
let (db_path, _) = db_entry?;
|
||||||
|
deleted_entry_ranges_tx
|
||||||
|
.send((Bound::Included(db_path.to_string()), Bound::Unbounded))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
ScanEntries {
|
||||||
|
updated_entries: updated_entries_rx,
|
||||||
|
deleted_entry_ranges: deleted_entry_ranges_rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scan_updated_entries(
|
||||||
|
&self,
|
||||||
|
worktree: Snapshot,
|
||||||
|
updated_entries: UpdatedEntriesSet,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> ScanEntries {
|
||||||
|
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
|
||||||
|
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
|
||||||
|
let entries_being_indexed = self.entry_ids_being_indexed.clone();
|
||||||
|
let task = cx.background_executor().spawn(async move {
|
||||||
|
for (path, entry_id, status) in updated_entries.iter() {
|
||||||
|
match status {
|
||||||
|
project::PathChange::Added
|
||||||
|
| project::PathChange::Updated
|
||||||
|
| project::PathChange::AddedOrUpdated => {
|
||||||
|
if let Some(entry) = worktree.entry_for_id(*entry_id) {
|
||||||
|
if entry.is_file() {
|
||||||
|
let handle = entries_being_indexed.insert(entry.id);
|
||||||
|
updated_entries_tx.send((entry.clone(), handle)).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
project::PathChange::Removed => {
|
||||||
|
let db_path = db_key_for_path(path);
|
||||||
|
deleted_entry_ranges_tx
|
||||||
|
.send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
project::PathChange::Loaded => {
|
||||||
|
// Do nothing.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
ScanEntries {
|
||||||
|
updated_entries: updated_entries_rx,
|
||||||
|
deleted_entry_ranges: deleted_entry_ranges_rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chunk_files(
|
||||||
|
&self,
|
||||||
|
worktree_abs_path: Arc<Path>,
|
||||||
|
entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> ChunkFiles {
|
||||||
|
let language_registry = self.language_registry.clone();
|
||||||
|
let fs = self.fs.clone();
|
||||||
|
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
|
||||||
|
let task = cx.spawn(|cx| async move {
|
||||||
|
cx.background_executor()
|
||||||
|
.scoped(|cx| {
|
||||||
|
for _ in 0..cx.num_cpus() {
|
||||||
|
cx.spawn(async {
|
||||||
|
while let Ok((entry, handle)) = entries.recv().await {
|
||||||
|
let entry_abs_path = worktree_abs_path.join(&entry.path);
|
||||||
|
match fs.load(&entry_abs_path).await {
|
||||||
|
Ok(text) => {
|
||||||
|
let language = language_registry
|
||||||
|
.language_for_file_path(&entry.path)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
let chunked_file = ChunkedFile {
|
||||||
|
chunks: chunking::chunk_text(
|
||||||
|
&text,
|
||||||
|
language.as_ref(),
|
||||||
|
&entry.path,
|
||||||
|
),
|
||||||
|
handle,
|
||||||
|
path: entry.path,
|
||||||
|
mtime: entry.mtime,
|
||||||
|
text,
|
||||||
|
};
|
||||||
|
|
||||||
|
if chunked_files_tx.send(chunked_file).await.is_err() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_)=> {
|
||||||
|
log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
ChunkFiles {
|
||||||
|
files: chunked_files_rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_files(
|
||||||
|
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||||
|
chunked_files: channel::Receiver<ChunkedFile>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> EmbedFiles {
|
||||||
|
let embedding_provider = embedding_provider.clone();
|
||||||
|
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
|
||||||
|
let task = cx.background_executor().spawn(async move {
|
||||||
|
let mut chunked_file_batches =
|
||||||
|
chunked_files.chunks_timeout(512, Duration::from_secs(2));
|
||||||
|
while let Some(chunked_files) = chunked_file_batches.next().await {
|
||||||
|
// View the batch of files as a vec of chunks
|
||||||
|
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
|
||||||
|
// Once those are done, reassemble them back into the files in which they belong
|
||||||
|
// If any embeddings fail for a file, the entire file is discarded
|
||||||
|
|
||||||
|
let chunks: Vec<TextToEmbed> = chunked_files
|
||||||
|
.iter()
|
||||||
|
.flat_map(|file| {
|
||||||
|
file.chunks.iter().map(|chunk| TextToEmbed {
|
||||||
|
text: &file.text[chunk.range.clone()],
|
||||||
|
digest: chunk.digest,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let mut embeddings: Vec<Option<Embedding>> = Vec::new();
|
||||||
|
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
|
||||||
|
if let Some(batch_embeddings) =
|
||||||
|
embedding_provider.embed(embedding_batch).await.log_err()
|
||||||
|
{
|
||||||
|
if batch_embeddings.len() == embedding_batch.len() {
|
||||||
|
embeddings.extend(batch_embeddings.into_iter().map(Some));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
log::error!(
|
||||||
|
"embedding provider returned unexpected embedding count {}, expected {}",
|
||||||
|
batch_embeddings.len(), embedding_batch.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut embeddings = embeddings.into_iter();
|
||||||
|
for chunked_file in chunked_files {
|
||||||
|
let mut embedded_file = EmbeddedFile {
|
||||||
|
path: chunked_file.path,
|
||||||
|
mtime: chunked_file.mtime,
|
||||||
|
chunks: Vec::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut embedded_all_chunks = true;
|
||||||
|
for (chunk, embedding) in
|
||||||
|
chunked_file.chunks.into_iter().zip(embeddings.by_ref())
|
||||||
|
{
|
||||||
|
if let Some(embedding) = embedding {
|
||||||
|
embedded_file
|
||||||
|
.chunks
|
||||||
|
.push(EmbeddedChunk { chunk, embedding });
|
||||||
|
} else {
|
||||||
|
embedded_all_chunks = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedded_all_chunks {
|
||||||
|
embedded_files_tx
|
||||||
|
.send((embedded_file, chunked_file.handle))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
EmbedFiles {
|
||||||
|
files: embedded_files_rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn persist_embeddings(
|
||||||
|
&self,
|
||||||
|
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
|
||||||
|
embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Task<Result<()>> {
|
||||||
|
let db_connection = self.db_connection.clone();
|
||||||
|
let db = self.db;
|
||||||
|
cx.background_executor().spawn(async move {
|
||||||
|
while let Some(deletion_range) = deleted_entry_ranges.next().await {
|
||||||
|
let mut txn = db_connection.write_txn()?;
|
||||||
|
let start = deletion_range.0.as_ref().map(|start| start.as_str());
|
||||||
|
let end = deletion_range.1.as_ref().map(|end| end.as_str());
|
||||||
|
log::debug!("deleting embeddings in range {:?}", &(start, end));
|
||||||
|
db.delete_range(&mut txn, &(start, end))?;
|
||||||
|
txn.commit()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
|
||||||
|
while let Some(embedded_files) = embedded_files.next().await {
|
||||||
|
let mut txn = db_connection.write_txn()?;
|
||||||
|
for (file, _) in &embedded_files {
|
||||||
|
log::debug!("saving embedding for file {:?}", file.path);
|
||||||
|
let key = db_key_for_path(&file.path);
|
||||||
|
db.put(&mut txn, &key, file)?;
|
||||||
|
}
|
||||||
|
txn.commit()?;
|
||||||
|
|
||||||
|
drop(embedded_files);
|
||||||
|
log::debug!("committed");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
|
||||||
|
let connection = self.db_connection.clone();
|
||||||
|
let db = self.db;
|
||||||
|
cx.background_executor().spawn(async move {
|
||||||
|
let tx = connection
|
||||||
|
.read_txn()
|
||||||
|
.context("failed to create read transaction")?;
|
||||||
|
let result = db
|
||||||
|
.iter(&tx)?
|
||||||
|
.map(|entry| Ok(entry?.1.path.clone()))
|
||||||
|
.collect::<Result<Vec<Arc<Path>>>>();
|
||||||
|
drop(tx);
|
||||||
|
result
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn chunks_for_path(
|
||||||
|
&self,
|
||||||
|
path: Arc<Path>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Task<Result<Vec<EmbeddedChunk>>> {
|
||||||
|
let connection = self.db_connection.clone();
|
||||||
|
let db = self.db;
|
||||||
|
cx.background_executor().spawn(async move {
|
||||||
|
let tx = connection
|
||||||
|
.read_txn()
|
||||||
|
.context("failed to create read transaction")?;
|
||||||
|
Ok(db
|
||||||
|
.get(&tx, &db_key_for_path(&path))?
|
||||||
|
.ok_or_else(|| anyhow!("no such path"))?
|
||||||
|
.chunks
|
||||||
|
.clone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ScanEntries {
|
||||||
|
updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
|
||||||
|
deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
|
||||||
|
task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ChunkFiles {
|
||||||
|
files: channel::Receiver<ChunkedFile>,
|
||||||
|
task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChunkedFile {
|
||||||
|
pub path: Arc<Path>,
|
||||||
|
pub mtime: Option<SystemTime>,
|
||||||
|
pub handle: IndexingEntryHandle,
|
||||||
|
pub text: String,
|
||||||
|
pub chunks: Vec<Chunk>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EmbedFiles {
|
||||||
|
pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
|
||||||
|
pub task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct EmbeddedFile {
|
||||||
|
pub path: Arc<Path>,
|
||||||
|
pub mtime: Option<SystemTime>,
|
||||||
|
pub chunks: Vec<EmbeddedChunk>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct EmbeddedChunk {
|
||||||
|
pub chunk: Chunk,
|
||||||
|
pub embedding: Embedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn db_key_for_path(path: &Arc<Path>) -> String {
|
||||||
|
path.to_string_lossy().replace('/', "\0")
|
||||||
|
}
|
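A quick standalone illustration (not from the PR) of what `db_key_for_path` produces. The helper is inlined here so the snippet compiles on its own; replacing '/' with '\0' presumably keeps keys for entries under the same directory adjacent in the key space, which is convenient for the range deletes above.

use std::{path::Path, sync::Arc};

fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

fn main() {
    let a: Arc<Path> = Arc::from(Path::new("src/lib.rs"));
    let b: Arc<Path> = Arc::from(Path::new("src/editor/mod.rs"));
    // '/' becomes the NUL byte, which sorts before every printable character.
    assert_eq!(db_key_for_path(&a), "src\0lib.rs");
    assert_eq!(db_key_for_path(&b), "src\0editor\0mod.rs");
}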
49 crates/semantic_index/src/indexing.rs Normal file
@ -0,0 +1,49 @@
use collections::HashSet;
use parking_lot::Mutex;
use project::ProjectEntryId;
use smol::channel;
use std::sync::{Arc, Weak};

/// The set of entries that are currently being indexed.
pub struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    tx: channel::Sender<()>,
}

/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
pub(crate) struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    set: Weak<IndexingEntrySet>,
}

impl IndexingEntrySet {
    pub fn new(tx: channel::Sender<()>) -> Self {
        Self {
            entry_ids: Default::default(),
            tx,
        }
    }

    pub fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
        self.entry_ids.lock().insert(entry_id);
        self.tx.send_blocking(()).ok();
        IndexingEntryHandle {
            entry_id,
            set: Arc::downgrade(self),
        }
    }

    pub fn len(&self) -> usize {
        self.entry_ids.lock().len()
    }
}

impl Drop for IndexingEntryHandle {
    fn drop(&mut self) {
        if let Some(set) = self.set.upgrade() {
            set.tx.send_blocking(()).ok();
            set.entry_ids.lock().remove(&self.entry_id);
        }
    }
}
523 crates/semantic_index/src/project_index.rs Normal file
@ -0,0 +1,523 @@
use crate::{
    embedding::{EmbeddingProvider, TextToEmbed},
    summary_index::FileSummary,
    worktree_index::{WorktreeIndex, WorktreeIndexHandle},
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use fs::Fs;
use futures::{stream::StreamExt, FutureExt};
use gpui::{
    AppContext, Entity, EntityId, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel,
};
use language::LanguageRegistry;
use log;
use project::{Project, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
use util::ResultExt;

#[derive(Debug)]
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
    Idle,
    Loading,
    Scanning { remaining_count: NonZeroUsize },
}

pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    last_status: Status,
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _maintain_status: Task<()>,
    _subscription: Subscription,
}

impl ProjectIndex {
    pub fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    pub fn status(&self) -> Status {
        self.last_status
    }

    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }

    pub fn fs(&self) -> Arc<dyn Fs> {
        self.fs.clone()
    }

    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    let result = match worktree_index.await {
                        Ok(worktree_index) => {
                            this.update(&mut cx, |this, _| {
                                this.worktree_indices.insert(
                                    worktree_id,
                                    WorktreeIndexHandle::Loaded {
                                        index: worktree_index.clone(),
                                    },
                                );
                            })?;
                            Ok(worktree_index)
                        }
                        Err(error) => {
                            this.update(&mut cx, |this, _cx| {
                                this.worktree_indices.remove(&worktree_id)
                            })?;
                            Err(Arc::new(error))
                        }
                    };

                    this.update(&mut cx, |this, cx| this.update_status(cx))?;

                    result
                });

                WorktreeIndexHandle::Loading {
                    index: load_worktree.shared(),
                }
            });
        }

        self.update_status(cx);
    }

    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed().len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            let worktree_index = worktree_index.clone();
            let chunks_tx = chunks_tx.clone();
            worktree_scan_tasks.push(cx.spawn(|cx| async move {
                let index = match worktree_index {
                    WorktreeIndexHandle::Loading { index } => {
                        index.clone().await.map_err(|error| anyhow!(error))?
                    }
                    WorktreeIndexHandle::Loaded { index } => index.clone(),
                };

                index
                    .read_with(&cx, |index, cx| {
                        let worktree_id = index.worktree().read(cx).id();
                        let db_connection = index.db_connection().clone();
                        let db = *index.embedding_index().db();
                        cx.background_executor().spawn(async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        })
                    })?
                    .await
            }));
        }
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            for scan_task in futures::future::join_all(worktree_scan_tasks).await {
                scan_task.log_err();
            }

            project.read_with(&cx, |project, cx| {
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }

    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    pub(crate) fn worktree_index(
        &self,
        worktree_id: WorktreeId,
        cx: &AppContext,
    ) -> Option<Model<WorktreeIndex>> {
        for index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = index {
                if index.read(cx).worktree().read(cx).id() == worktree_id {
                    return Some(index.clone());
                }
            }
        }
        None
    }

    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
        let mut result = self
            .worktree_indices
            .values()
            .filter_map(|index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        result.sort_by_key(|index| index.read(cx).worktree().read(cx).id());
        result
    }

    pub fn all_summaries(&self, cx: &AppContext) -> Task<Result<Vec<FileSummary>>> {
        let (summaries_tx, summaries_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            let worktree_index = worktree_index.clone();
            let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone();
            worktree_scan_tasks.push(cx.spawn(|cx| async move {
                let index = match worktree_index {
                    WorktreeIndexHandle::Loading { index } => {
                        index.clone().await.map_err(|error| anyhow!(error))?
                    }
                    WorktreeIndexHandle::Loaded { index } => index.clone(),
                };

                index
                    .read_with(&cx, |index, cx| {
                        let db_connection = index.db_connection().clone();
                        let summary_index = index.summary_index();
                        let file_digest_db = summary_index.file_digest_db();
                        let summary_db = summary_index.summary_db();

                        cx.background_executor().spawn(async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create db read transaction")?;
                            let db_entries = file_digest_db
                                .iter(&txn)
                                .context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (file_path, db_file) = db_entry?;

                                match summary_db.get(&txn, &db_file.digest) {
                                    Ok(opt_summary) => {
                                        // Currently, we only use summaries we already have. If the file hasn't been
                                        // summarized yet, then we skip it and don't include it in the inferred context.
                                        // If we want to do just-in-time summarization, this would be the place to do it!
                                        if let Some(summary) = opt_summary {
                                            summaries_tx
                                                .send((file_path.to_string(), summary.to_string()))
                                                .await?;
                                        } else {
                                            log::warn!("No summary found for {:?}", &db_file);
                                        }
                                    }
                                    Err(err) => {
                                        log::error!(
                                            "Error reading from summary database: {:?}",
                                            err
                                        );
                                    }
                                }
                            }
                            anyhow::Ok(())
                        })
                    })?
                    .await
            }));
        }
        drop(summaries_tx);

        let project = self.project.clone();
        cx.spawn(|cx| async move {
            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<FileSummary>::new());
            }

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((filename, summary)) = summaries_rx.recv().await {
                                results.push(FileSummary { filename, summary });
                            }
                        });
                    }
                })
                .await;

            for scan_task in futures::future::join_all(worktree_scan_tasks).await {
                scan_task.log_err();
            }

            project.read_with(&cx, |_project, _cx| {
                results_by_worker.into_iter().flatten().collect()
            })
        })
    }

    /// Empty out the backlogs of all the worktrees in the project
    pub fn flush_summary_backlogs(&self, cx: &AppContext) -> impl Future<Output = ()> {
        let flush_start = std::time::Instant::now();

        futures::future::join_all(self.worktree_indices.values().map(|worktree_index| {
            let worktree_index = worktree_index.clone();

            cx.spawn(|cx| async move {
                let index = match worktree_index {
                    WorktreeIndexHandle::Loading { index } => {
                        index.clone().await.map_err(|error| anyhow!(error))?
                    }
                    WorktreeIndexHandle::Loaded { index } => index.clone(),
                };
                let worktree_abs_path =
                    cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?;

                index
                    .read_with(&cx, |index, cx| {
                        cx.background_executor()
                            .spawn(index.summary_index().flush_backlog(worktree_abs_path, cx))
                    })?
                    .await
            })
        }))
        .map(move |results| {
            // Log any errors, but don't block the user. These summaries are supposed to
            // improve quality by providing extra context, but they aren't hard requirements!
            for result in results {
                if let Err(err) = result {
                    log::error!("Error flushing summary backlog: {:?}", err);
                }
            }

            log::info!("Summary backlog flushed in {:?}", flush_start.elapsed());
        })
    }

    pub fn remaining_summaries(&self, cx: &mut ModelContext<Self>) -> usize {
        self.worktree_indices(cx)
            .iter()
            .map(|index| index.read(cx).summary_index().backlog_len())
            .sum()
    }
}

impl EventEmitter<Status> for ProjectIndex {}
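A standalone sketch (not from the PR) of the bounded insertion each search worker above performs: comparing the incoming score against the probe (rather than the other way around) makes `binary_search_by` treat the Vec as sorted by descending score, so inserting at the returned index and truncating keeps only the top `limit` results.

use std::cmp::Ordering;

fn push_top_k(results: &mut Vec<(f32, &'static str)>, score: f32, path: &'static str, limit: usize) {
    let ix = match results.binary_search_by(|probe| {
        // Reversed comparison: keeps `results` ordered from highest to lowest score.
        score.partial_cmp(&probe.0).unwrap_or(Ordering::Equal)
    }) {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (score, path));
    results.truncate(limit);
}

fn main() {
    let mut top = Vec::new();
    for (score, path) in [(0.2, "a.rs"), (0.9, "b.rs"), (0.5, "c.rs"), (0.7, "d.rs")] {
        push_top_k(&mut top, score, path, 3);
    }
    assert_eq!(top.iter().map(|r| r.1).collect::<Vec<_>>(), ["b.rs", "d.rs", "c.rs"]);
}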
@ -55,8 +55,12 @@ impl ProjectIndexDebugView {
         for index in worktree_indices {
             let (root_path, worktree_id, worktree_paths) =
                 index.read_with(&cx, |index, cx| {
-                    let worktree = index.worktree.read(cx);
-                    (worktree.abs_path(), worktree.id(), index.paths(cx))
+                    let worktree = index.worktree().read(cx);
+                    (
+                        worktree.abs_path(),
+                        worktree.id(),
+                        index.embedding_index().paths(cx),
+                    )
                 })?;
             rows.push(Row::Worktree(root_path));
             rows.extend(
@ -82,10 +86,12 @@ impl ProjectIndexDebugView {
         cx: &mut ViewContext<Self>,
     ) -> Option<()> {
         let project_index = self.index.read(cx);
-        let fs = project_index.fs.clone();
+        let fs = project_index.fs().clone();
         let worktree_index = project_index.worktree_index(worktree_id, cx)?.read(cx);
-        let root_path = worktree_index.worktree.read(cx).abs_path();
-        let chunks = worktree_index.chunks_for_path(file_path.clone(), cx);
+        let root_path = worktree_index.worktree().read(cx).abs_path();
+        let chunks = worktree_index
+            .embedding_index()
+            .chunks_for_path(file_path.clone(), cx);

         cx.spawn(|this, mut cx| async move {
             let chunks = chunks.await?;

File diff suppressed because it is too large
48 crates/semantic_index/src/summary_backlog.rs Normal file
@ -0,0 +1,48 @@
use collections::HashMap;
use std::{path::Path, sync::Arc, time::SystemTime};

const MAX_FILES_BEFORE_RESUMMARIZE: usize = 4;
const MAX_BYTES_BEFORE_RESUMMARIZE: u64 = 1_000_000; // 1 MB

#[derive(Default, Debug)]
pub struct SummaryBacklog {
    /// Key: path to a file that needs summarization, but that we haven't summarized yet. Value: that file's size on disk, in bytes, and its mtime.
    files: HashMap<Arc<Path>, (u64, Option<SystemTime>)>,
    /// Cache of the sum of all values in `files`, so we don't have to traverse the whole map to check if we're over the byte limit.
    total_bytes: u64,
}

impl SummaryBacklog {
    /// Store the given path in the backlog, along with how many bytes are in it.
    pub fn insert(&mut self, path: Arc<Path>, bytes_on_disk: u64, mtime: Option<SystemTime>) {
        let (prev_bytes, _) = self
            .files
            .insert(path, (bytes_on_disk, mtime))
            .unwrap_or_default(); // Default to 0 prev_bytes

        // Update the cached total by subtracting out the old amount and adding the new one.
        self.total_bytes = self.total_bytes - prev_bytes + bytes_on_disk;
    }

    /// Returns true if the number of files or the total number of bytes in the backlog exceeds a predefined threshold.
    pub fn needs_drain(&self) -> bool {
        self.files.len() > MAX_FILES_BEFORE_RESUMMARIZE ||
        // The whole purpose of the cached total_bytes is to make this comparison cheap.
        // Otherwise we'd have to traverse the entire dictionary every time we wanted this answer.
        self.total_bytes > MAX_BYTES_BEFORE_RESUMMARIZE
    }

    /// Remove all the entries in the backlog and return the file paths as an iterator.
    #[allow(clippy::needless_lifetimes)] // Clippy thinks this 'a can be elided, but eliding it gives a compile error
    pub fn drain<'a>(&'a mut self) -> impl Iterator<Item = (Arc<Path>, Option<SystemTime>)> + 'a {
        self.total_bytes = 0;

        self.files
            .drain()
            .map(|(path, (_size, mtime))| (path, mtime))
    }

    pub fn len(&self) -> usize {
        self.files.len()
    }
}
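A quick usage sketch (illustration only, not from the PR) of the backlog above: files accumulate until either threshold trips, then the whole backlog is drained for summarization.

use std::{path::Path, sync::Arc};

fn main() {
    let mut backlog = SummaryBacklog::default();
    for name in ["a.rs", "b.rs", "c.rs", "d.rs", "e.rs"] {
        backlog.insert(Arc::from(Path::new(name)), 10, None);
    }
    // Five files exceeds MAX_FILES_BEFORE_RESUMMARIZE (4), so a drain is due.
    assert!(backlog.needs_drain());
    let drained: Vec<_> = backlog.drain().collect();
    assert_eq!(drained.len(), 5);
    assert_eq!(backlog.len(), 0);
}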
693
crates/semantic_index/src/summary_index.rs
Normal file
693
crates/semantic_index/src/summary_index.rs
Normal file
|
@ -0,0 +1,693 @@
|
||||||
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
|
use arrayvec::ArrayString;
|
||||||
|
use fs::Fs;
|
||||||
|
use futures::{stream::StreamExt, TryFutureExt};
|
||||||
|
use futures_batch::ChunksTimeoutStreamExt;
|
||||||
|
use gpui::{AppContext, Model, Task};
|
||||||
|
use heed::{
|
||||||
|
types::{SerdeBincode, Str},
|
||||||
|
RoTxn,
|
||||||
|
};
|
||||||
|
use language_model::{
|
||||||
|
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
|
||||||
|
LanguageModelRequestMessage, Role,
|
||||||
|
};
|
||||||
|
use log;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use project::{Entry, UpdatedEntriesSet, Worktree};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use smol::channel;
|
||||||
|
use std::{
|
||||||
|
future::Future,
|
||||||
|
path::Path,
|
||||||
|
sync::Arc,
|
||||||
|
time::{Duration, Instant, SystemTime},
|
||||||
|
};
|
||||||
|
use util::ResultExt;
|
||||||
|
use worktree::Snapshot;
|
||||||
|
|
||||||
|
use crate::{indexing::IndexingEntrySet, summary_backlog::SummaryBacklog};
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct FileSummary {
|
||||||
|
pub filename: String,
|
||||||
|
pub summary: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct UnsummarizedFile {
|
||||||
|
// Path to the file on disk
|
||||||
|
path: Arc<Path>,
|
||||||
|
// The mtime of the file on disk
|
||||||
|
mtime: Option<SystemTime>,
|
||||||
|
// BLAKE3 hash of the source file's contents
|
||||||
|
digest: Blake3Digest,
|
||||||
|
// The source file's contents
|
||||||
|
contents: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct SummarizedFile {
|
||||||
|
// Path to the file on disk
|
||||||
|
path: String,
|
||||||
|
// The mtime of the file on disk
|
||||||
|
mtime: Option<SystemTime>,
|
||||||
|
// BLAKE3 hash of the source file's contents
|
||||||
|
digest: Blake3Digest,
|
||||||
|
// The LLM's summary of the file's contents
|
||||||
|
summary: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This is what blake3's to_hex() method returns - see https://docs.rs/blake3/1.5.3/src/blake3/lib.rs.html#246
|
||||||
|
pub type Blake3Digest = ArrayString<{ blake3::OUT_LEN * 2 }>;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FileDigest {
|
||||||
|
pub mtime: Option<SystemTime>,
|
||||||
|
pub digest: Blake3Digest,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NeedsSummary {
|
||||||
|
files: channel::Receiver<UnsummarizedFile>,
|
||||||
|
task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SummarizeFiles {
|
||||||
|
files: channel::Receiver<SummarizedFile>,
|
||||||
|
task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SummaryIndex {
|
||||||
|
worktree: Model<Worktree>,
|
||||||
|
fs: Arc<dyn Fs>,
|
||||||
|
db_connection: heed::Env,
|
||||||
|
file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>, // Key: file path. Val: BLAKE3 digest of its contents.
|
||||||
|
summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>, // Key: BLAKE3 digest of a file's contents. Val: LLM summary of those contents.
|
||||||
|
backlog: Arc<Mutex<SummaryBacklog>>,
|
||||||
|
_entry_ids_being_indexed: Arc<IndexingEntrySet>, // TODO can this be removed?
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Backlogged {
|
||||||
|
paths_to_digest: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
|
||||||
|
task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MightNeedSummaryFiles {
|
||||||
|
files: channel::Receiver<UnsummarizedFile>,
|
||||||
|
task: Task<Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SummaryIndex {
|
||||||
|
pub fn new(
|
||||||
|
worktree: Model<Worktree>,
|
||||||
|
fs: Arc<dyn Fs>,
|
||||||
|
db_connection: heed::Env,
|
||||||
|
file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
|
||||||
|
summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>,
|
||||||
|
_entry_ids_being_indexed: Arc<IndexingEntrySet>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
worktree,
|
||||||
|
fs,
|
||||||
|
db_connection,
|
||||||
|
file_digest_db,
|
||||||
|
summary_db,
|
||||||
|
_entry_ids_being_indexed,
|
||||||
|
backlog: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn file_digest_db(&self) -> heed::Database<Str, SerdeBincode<FileDigest>> {
|
||||||
|
self.file_digest_db
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn summary_db(&self) -> heed::Database<SerdeBincode<Blake3Digest>, Str> {
|
||||||
|
self.summary_db
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index_entries_changed_on_disk(
|
||||||
|
&self,
|
||||||
|
is_auto_available: bool,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> impl Future<Output = Result<()>> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let backlogged;
|
||||||
|
let digest;
|
||||||
|
let needs_summary;
|
||||||
|
let summaries;
|
||||||
|
let persist;
|
||||||
|
|
||||||
|
if is_auto_available {
|
||||||
|
let worktree = self.worktree.read(cx).snapshot();
|
||||||
|
let worktree_abs_path = worktree.abs_path().clone();
|
||||||
|
|
||||||
|
backlogged = self.scan_entries(worktree, cx);
|
||||||
|
digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
|
||||||
|
needs_summary = self.check_summary_cache(digest.files, cx);
|
||||||
|
summaries = self.summarize_files(needs_summary.files, cx);
|
||||||
|
persist = self.persist_summaries(summaries.files, cx);
|
||||||
|
} else {
|
||||||
|
// This feature is only staff-shipped, so make the rest of these no-ops.
|
||||||
|
backlogged = Backlogged {
|
||||||
|
paths_to_digest: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
digest = MightNeedSummaryFiles {
|
||||||
|
files: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
needs_summary = NeedsSummary {
|
||||||
|
files: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
summaries = SummarizeFiles {
|
||||||
|
files: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
persist = Task::ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
async move {
|
||||||
|
futures::try_join!(
|
||||||
|
backlogged.task,
|
||||||
|
digest.task,
|
||||||
|
needs_summary.task,
|
||||||
|
summaries.task,
|
||||||
|
persist
|
||||||
|
)?;
|
||||||
|
|
||||||
|
if is_auto_available {
|
||||||
|
log::info!(
|
||||||
|
"Summarizing everything that changed on disk took {:?}",
|
||||||
|
start.elapsed()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index_updated_entries(
|
||||||
|
&mut self,
|
||||||
|
updated_entries: UpdatedEntriesSet,
|
||||||
|
is_auto_available: bool,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> impl Future<Output = Result<()>> {
|
||||||
|
let start = Instant::now();
|
||||||
|
let backlogged;
|
||||||
|
let digest;
|
||||||
|
let needs_summary;
|
||||||
|
let summaries;
|
||||||
|
let persist;
|
||||||
|
|
||||||
|
if is_auto_available {
|
||||||
|
let worktree = self.worktree.read(cx).snapshot();
|
||||||
|
let worktree_abs_path = worktree.abs_path().clone();
|
||||||
|
|
||||||
|
backlogged = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
|
||||||
|
digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
|
||||||
|
needs_summary = self.check_summary_cache(digest.files, cx);
|
||||||
|
summaries = self.summarize_files(needs_summary.files, cx);
|
||||||
|
persist = self.persist_summaries(summaries.files, cx);
|
||||||
|
} else {
|
||||||
|
// This feature is only staff-shipped, so make the rest of these no-ops.
|
||||||
|
backlogged = Backlogged {
|
||||||
|
paths_to_digest: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
digest = MightNeedSummaryFiles {
|
||||||
|
files: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
needs_summary = NeedsSummary {
|
||||||
|
files: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
summaries = SummarizeFiles {
|
||||||
|
files: channel::unbounded().1,
|
||||||
|
task: Task::ready(Ok(())),
|
||||||
|
};
|
||||||
|
persist = Task::ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
async move {
|
||||||
|
futures::try_join!(
|
||||||
|
backlogged.task,
|
||||||
|
digest.task,
|
||||||
|
needs_summary.task,
|
||||||
|
summaries.task,
|
||||||
|
persist
|
||||||
|
)?;
|
||||||
|
|
||||||
|
log::info!("Summarizing updated entries took {:?}", start.elapsed());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_summary_cache(
|
||||||
|
&self,
|
||||||
|
mut might_need_summary: channel::Receiver<UnsummarizedFile>,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> NeedsSummary {
|
||||||
|
let db_connection = self.db_connection.clone();
|
||||||
|
let db = self.summary_db;
|
||||||
|
let (needs_summary_tx, needs_summary_rx) = channel::bounded(512);
|
||||||
|
let task = cx.background_executor().spawn(async move {
|
||||||
|
while let Some(file) = might_need_summary.next().await {
|
||||||
|
let tx = db_connection
|
||||||
|
.read_txn()
|
||||||
|
.context("Failed to create read transaction for checking which hashes are in summary cache")?;
|
||||||
|
|
||||||
|
match db.get(&tx, &file.digest) {
|
||||||
|
Ok(opt_answer) => {
|
||||||
|
if opt_answer.is_none() {
|
||||||
|
// It's not in the summary cache db, so we need to summarize it.
|
||||||
|
log::debug!("File {:?} (digest {:?}) was NOT in the db cache and needs to be resummarized.", file.path.display(), &file.digest);
|
||||||
|
needs_summary_tx.send(file).await?;
|
||||||
|
} else {
|
||||||
|
log::debug!("File {:?} (digest {:?}) was in the db cache and does not need to be resummarized.", file.path.display(), &file.digest);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
log::error!("Reading from the summaries database failed: {:?}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
NeedsSummary {
|
||||||
|
files: needs_summary_rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> Backlogged {
|
||||||
|
let (tx, rx) = channel::bounded(512);
|
||||||
|
let db_connection = self.db_connection.clone();
|
||||||
|
let digest_db = self.file_digest_db;
|
||||||
|
let backlog = Arc::clone(&self.backlog);
|
||||||
|
let task = cx.background_executor().spawn(async move {
|
||||||
|
let txn = db_connection
|
||||||
|
.read_txn()
|
||||||
|
.context("failed to create read transaction")?;
|
||||||
|
|
||||||
|
for entry in worktree.files(false, 0) {
|
||||||
|
let needs_summary =
|
||||||
|
Self::add_to_backlog(Arc::clone(&backlog), digest_db, &txn, entry);
|
||||||
|
|
||||||
|
if !needs_summary.is_empty() {
|
||||||
|
tx.send(needs_summary).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO delete db entries for deleted files
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
Backlogged {
|
||||||
|
paths_to_digest: rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_to_backlog(
|
||||||
|
backlog: Arc<Mutex<SummaryBacklog>>,
|
||||||
|
digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
|
||||||
|
txn: &RoTxn<'_>,
|
||||||
|
entry: &Entry,
|
||||||
|
) -> Vec<(Arc<Path>, Option<SystemTime>)> {
|
||||||
|
let entry_db_key = db_key_for_path(&entry.path);
|
||||||
|
|
||||||
|
match digest_db.get(&txn, &entry_db_key) {
|
||||||
|
Ok(opt_saved_digest) => {
|
||||||
|
// The file path is the same, but the mtime is different. (Or there was no mtime.)
|
||||||
|
// It needs updating, so add it to the backlog! Then, if the backlog is full, drain it and summarize its contents.
|
||||||
|
if entry.mtime != opt_saved_digest.and_then(|digest| digest.mtime) {
|
||||||
|
let mut backlog = backlog.lock();
|
||||||
|
|
||||||
|
log::info!(
|
||||||
|
"Inserting {:?} ({:?} bytes) into backlog",
|
||||||
|
&entry.path,
|
||||||
|
entry.size,
|
||||||
|
);
|
||||||
|
backlog.insert(Arc::clone(&entry.path), entry.size, entry.mtime);
|
||||||
|
|
||||||
|
if backlog.needs_drain() {
|
||||||
|
log::info!("Draining summary backlog...");
|
||||||
|
return backlog.drain().collect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
log::error!(
|
||||||
|
"Error trying to get file digest db entry {:?}: {:?}",
|
||||||
|
&entry_db_key,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scan_updated_entries(
|
||||||
|
&self,
|
||||||
|
worktree: Snapshot,
|
||||||
|
updated_entries: UpdatedEntriesSet,
|
||||||
|
cx: &AppContext,
|
||||||
|
) -> Backlogged {
|
||||||
|
log::info!("Scanning for updated entries that might need summarization...");
|
||||||
|
let (tx, rx) = channel::bounded(512);
|
||||||
|
// let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
|
||||||
|
let db_connection = self.db_connection.clone();
|
||||||
|
let digest_db = self.file_digest_db;
|
||||||
|
let backlog = Arc::clone(&self.backlog);
|
||||||
|
let task = cx.background_executor().spawn(async move {
|
||||||
|
let txn = db_connection
|
||||||
|
.read_txn()
|
||||||
|
.context("failed to create read transaction")?;
|
||||||
|
|
||||||
|
for (path, entry_id, status) in updated_entries.iter() {
|
||||||
|
match status {
|
||||||
|
project::PathChange::Loaded
|
||||||
|
| project::PathChange::Added
|
||||||
|
| project::PathChange::Updated
|
||||||
|
| project::PathChange::AddedOrUpdated => {
|
||||||
|
if let Some(entry) = worktree.entry_for_id(*entry_id) {
|
||||||
|
if entry.is_file() {
|
||||||
|
let needs_summary = Self::add_to_backlog(
|
||||||
|
Arc::clone(&backlog),
|
||||||
|
digest_db,
|
||||||
|
&txn,
|
||||||
|
entry,
|
||||||
|
);
|
||||||
|
|
||||||
|
if !needs_summary.is_empty() {
|
||||||
|
tx.send(needs_summary).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
project::PathChange::Removed => {
|
||||||
|
let _db_path = db_key_for_path(path);
|
||||||
|
// TODO delete db entries for deleted files
|
||||||
|
// deleted_entry_ranges_tx
|
||||||
|
// .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
|
||||||
|
// .await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
Backlogged {
|
||||||
|
paths_to_digest: rx,
|
||||||
|
// deleted_entry_ranges: deleted_entry_ranges_rx,
|
||||||
|
task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
    fn digest_files(
        &self,
        paths: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
        worktree_abs_path: Arc<Path>,
        cx: &AppContext,
    ) -> MightNeedSummaryFiles {
        let fs = self.fs.clone();
        let (rx, tx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok(pairs) = paths.recv().await {
                                // Note: we could process all these files concurrently if desired. Might or might not speed things up.
                                for (path, mtime) in pairs {
                                    let entry_abs_path = worktree_abs_path.join(&path);

                                    // Load the file's contents and compute its hash digest.
                                    let unsummarized_file = {
                                        let Some(contents) = fs
                                            .load(&entry_abs_path)
                                            .await
                                            .with_context(|| {
                                                format!("failed to read path {entry_abs_path:?}")
                                            })
                                            .log_err()
                                        else {
                                            continue;
                                        };

                                        let digest = {
                                            let mut hasher = blake3::Hasher::new();
                                            // Incorporate both the (relative) file path as well as the contents of the file into the hash.
                                            // This is because in some languages and frameworks, identical files can do different things
                                            // depending on their paths (e.g. Rails controllers). It's also why we send the path to the model.
                                            hasher.update(path.display().to_string().as_bytes());
                                            hasher.update(contents.as_bytes());
                                            hasher.finalize().to_hex()
                                        };

                                        UnsummarizedFile {
                                            digest,
                                            contents,
                                            path,
                                            mtime,
                                        }
                                    };

                                    if let Err(err) = rx
                                        .send(unsummarized_file)
                                        .map_err(|error| anyhow!(error))
                                        .await
                                    {
                                        log::error!("Error: {:?}", err);

                                        return;
                                    }
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        MightNeedSummaryFiles { files: tx, task }
    }

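For illustration, a standalone version of the digest computed above: hashing the relative path together with the contents means identical files at different paths still get distinct digests. It assumes only the blake3 crate; the helper name and the Rails-style paths are made up.

// Illustrative sketch, not part of the diff above.
fn file_digest(relative_path: &str, contents: &str) -> String {
    let mut hasher = blake3::Hasher::new();
    hasher.update(relative_path.as_bytes());
    hasher.update(contents.as_bytes());
    hasher.finalize().to_hex().to_string()
}

fn main() {
    let a = file_digest("app/controllers/users_controller.rb", "class ApplicationController; end");
    let b = file_digest("app/controllers/posts_controller.rb", "class ApplicationController; end");
    assert_ne!(a, b); // same contents, different paths => different digests
    println!("{a}\n{b}");
}
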
    fn summarize_files(
        &self,
        mut unsummarized_files: channel::Receiver<UnsummarizedFile>,
        cx: &AppContext,
    ) -> SummarizeFiles {
        let (summarized_tx, summarized_rx) = channel::bounded(512);
        let task = cx.spawn(|cx| async move {
            while let Some(file) = unsummarized_files.next().await {
                log::debug!("Summarizing {:?}", file);
                let summary = cx
                    .update(|cx| Self::summarize_code(&file.contents, &file.path, cx))?
                    .await
                    .unwrap_or_else(|err| {
                        // Log a warning because we'll continue anyway.
                        // In the future, we may want to try splitting it up into multiple requests and concatenating the summaries,
                        // but this might give bad summaries due to cutting off source code files in the middle.
                        log::warn!("Failed to summarize {} - {:?}", file.path.display(), err);

                        String::new()
                    });

                // Note that the summary could be empty because of an error talking to a cloud provider,
                // e.g. because the context limit was exceeded. In that case, we return Ok(String::new()).
                if !summary.is_empty() {
                    summarized_tx
                        .send(SummarizedFile {
                            path: file.path.display().to_string(),
                            digest: file.digest,
                            summary,
                            mtime: file.mtime,
                        })
                        .await?
                }
            }

            Ok(())
        });

        SummarizeFiles {
            files: summarized_rx,
            task,
        }
    }

    fn summarize_code(
        code: &str,
        path: &Path,
        cx: &AppContext,
    ) -> impl Future<Output = Result<String>> {
        let start = Instant::now();
        let (summary_model_id, use_cache): (LanguageModelId, bool) = (
            "Qwen/Qwen2-7B-Instruct".to_string().into(), // TODO read this from the user's settings.
            false, // qwen2 doesn't have a cache, but we should probably infer this from the model
        );
        let Some(model) = LanguageModelRegistry::read_global(cx)
            .available_models(cx)
            .find(|model| &model.id() == &summary_model_id)
        else {
            return cx.background_executor().spawn(async move {
                Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))
            });
        };
        let utf8_path = path.to_string_lossy();
        const PROMPT_BEFORE_CODE: &str = "Summarize what the code in this file does in 3 sentences, using no newlines or bullet points in the summary:";
        let prompt = format!("{PROMPT_BEFORE_CODE}\n{utf8_path}:\n{code}");

        log::debug!(
            "Summarizing code by sending this prompt to {:?}: {:?}",
            model.name(),
            &prompt
        );

        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![prompt.into()],
                cache: use_cache,
            }],
            tools: Vec::new(),
            stop: Vec::new(),
            temperature: 1.0,
        };

        let code_len = code.len();
        cx.spawn(|cx| async move {
            let stream = model.stream_completion(request, &cx);
            cx.background_executor()
                .spawn(async move {
                    let answer: String = stream
                        .await?
                        .filter_map(|event| async {
                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
                                Some(text)
                            } else {
                                None
                            }
                        })
                        .collect()
                        .await;

                    log::info!(
                        "It took {:?} to summarize {:?} bytes of code.",
                        start.elapsed(),
                        code_len
                    );

                    log::debug!("Summary was: {:?}", &answer);

                    Ok(answer)
                })
                .await

            // TODO if summarization failed, put it back in the backlog!
        })
    }

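The request built above boils down to a single user message. A standalone sketch of what that prompt looks like for a small file; the path and code are made up, and only the leading instruction matches the constant above.

// Illustrative sketch, std only; not part of the diff above.
fn main() {
    const PROMPT_BEFORE_CODE: &str = "Summarize what the code in this file does in 3 sentences, using no newlines or bullet points in the summary:";
    let utf8_path = "src/lib.rs"; // hypothetical file
    let code = "pub fn add(a: i32, b: i32) -> i32 { a + b }"; // hypothetical contents
    let prompt = format!("{PROMPT_BEFORE_CODE}\n{utf8_path}:\n{code}");
    println!("{prompt}");
}
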
    fn persist_summaries(
        &self,
        summaries: channel::Receiver<SummarizedFile>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let digest_db = self.file_digest_db;
        let summary_db = self.summary_db;
        cx.background_executor().spawn(async move {
            let mut summaries = summaries.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(summaries) = summaries.next().await {
                let mut txn = db_connection.write_txn()?;
                for file in &summaries {
                    log::debug!(
                        "Saving summary of {:?} - which is {} bytes of summary for content digest {:?}",
                        &file.path,
                        file.summary.len(),
                        file.digest
                    );
                    digest_db.put(
                        &mut txn,
                        &file.path,
                        &FileDigest {
                            mtime: file.mtime,
                            digest: file.digest,
                        },
                    )?;
                    summary_db.put(&mut txn, &file.digest, &file.summary)?;
                }
                txn.commit()?;

                drop(summaries);
                log::debug!("committed summaries");
            }

            Ok(())
        })
    }

    /// Empty out the backlog of files that haven't been resummarized, and resummarize them immediately.
    pub(crate) fn flush_backlog(
        &self,
        worktree_abs_path: Arc<Path>,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let start = Instant::now();
        let backlogged = {
            let (tx, rx) = channel::bounded(512);
            let needs_summary: Vec<(Arc<Path>, Option<SystemTime>)> = {
                let mut backlog = self.backlog.lock();

                backlog.drain().collect()
            };

            let task = cx.background_executor().spawn(async move {
                tx.send(needs_summary).await?;
                Ok(())
            });

            Backlogged {
                paths_to_digest: rx,
                task,
            }
        };

        let digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
        let needs_summary = self.check_summary_cache(digest.files, cx);
        let summaries = self.summarize_files(needs_summary.files, cx);
        let persist = self.persist_summaries(summaries.files, cx);

        async move {
            futures::try_join!(
                backlogged.task,
                digest.task,
                needs_summary.task,
                summaries.task,
                persist
            )?;

            log::info!("Summarizing backlogged entries took {:?}", start.elapsed());

            Ok(())
        }
    }

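flush_backlog waits on every stage at once, and the flush as a whole fails if any stage fails. A minimal illustration of that fan-in with futures::try_join!; the stage bodies are placeholders, not the real tasks.

// Illustrative sketch (assumes the futures and anyhow crates); not part of the diff above.
use futures::executor::block_on;

fn main() -> anyhow::Result<()> {
    block_on(async {
        // Stand-ins for the digest/summarize/persist tasks joined above.
        let digest = async { anyhow::Ok(3usize) };
        let summarize = async { anyhow::Ok(5usize) };
        let (digested, summarized) = futures::try_join!(digest, summarize)?;
        println!("digested {digested} files, summarized {summarized}");
        Ok(())
    })
}
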
    pub(crate) fn backlog_len(&self) -> usize {
        self.backlog.lock().len()
    }
}

fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

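A standalone look at the key transform above: slashes in the relative path are replaced with NUL bytes before the path is used as a database key. The motivation isn't stated in this change, so treat the framing here as an assumption; the sketch only shows the transform itself.

// Illustrative sketch, std only; not part of the diff above.
use std::path::Path;
use std::sync::Arc;

fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

fn main() {
    let path: Arc<Path> = Arc::from(Path::new("crates/semantic_index/src/summary_index.rs"));
    let key = db_key_for_path(&path);
    // Make the NUL separators visible when printing.
    println!("{}", key.replace('\0', "<NUL>"));
}
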
217
crates/semantic_index/src/worktree_index.rs
Normal file
@ -0,0 +1,217 @@

use crate::embedding::EmbeddingProvider;
use crate::embedding_index::EmbeddingIndex;
use crate::indexing::IndexingEntrySet;
use crate::summary_index::SummaryIndex;
use anyhow::Result;
use feature_flags::{AutoCommand, FeatureFlagAppExt};
use fs::Fs;
use futures::future::Shared;
use gpui::{
    AppContext, AsyncAppContext, Context, Model, ModelContext, Subscription, Task, WeakModel,
};
use language::LanguageRegistry;
use log;
use project::{UpdatedEntriesSet, Worktree};
use smol::channel;
use std::sync::Arc;
use util::ResultExt;

#[derive(Clone)]
pub enum WorktreeIndexHandle {
    Loading {
        index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
    },
    Loaded {
        index: Model<WorktreeIndex>,
    },
}

pub struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    embedding_index: EmbeddingIndex,
    summary_index: SummaryIndex,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}

impl WorktreeIndex {
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_for_index = worktree.clone();
        let worktree_for_summary = worktree.clone();
        let worktree_abs_path = worktree.read(cx).abs_path();
        let embedding_fs = Arc::clone(&fs);
        let summary_fs = fs;
        cx.spawn(|mut cx| async move {
            let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
            let (embedding_index, summary_index) = cx
                .background_executor()
                .spawn({
                    let entries_being_indexed = Arc::clone(&entries_being_indexed);
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let embedding_index = {
                            let db_name = worktree_abs_path.to_string_lossy();
                            let db = db_connection.create_database(&mut txn, Some(&db_name))?;

                            EmbeddingIndex::new(
                                worktree_for_index,
                                embedding_fs,
                                db_connection.clone(),
                                db,
                                language_registry,
                                embedding_provider,
                                Arc::clone(&entries_being_indexed),
                            )
                        };
                        let summary_index = {
                            let file_digest_db = {
                                let db_name =
                                    // Prepend something that wouldn't be found at the beginning of an
                                    // absolute path, so we don't get db key namespace conflicts with
                                    // embeddings, which use the abs path as a key.
                                    format!("digests-{}", worktree_abs_path.to_string_lossy());
                                db_connection.create_database(&mut txn, Some(&db_name))?
                            };
                            let summary_db = {
                                let db_name =
                                    // Prepend something that wouldn't be found at the beginning of an
                                    // absolute path, so we don't get db key namespace conflicts with
                                    // embeddings, which use the abs path as a key.
                                    format!("summaries-{}", worktree_abs_path.to_string_lossy());
                                db_connection.create_database(&mut txn, Some(&db_name))?
                            };
                            SummaryIndex::new(
                                worktree_for_summary,
                                summary_fs,
                                db_connection.clone(),
                                file_digest_db,
                                summary_db,
                                Arc::clone(&entries_being_indexed),
                            )
                        };
                        txn.commit()?;
                        anyhow::Ok((embedding_index, summary_index))
                    }
                })
                .await?;

            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    embedding_index,
                    summary_index,
                    entries_being_indexed,
                    cx,
                )
            })
        })
    }

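The load path above creates the per-worktree databases inside a single write transaction: embeddings are keyed by the worktree's absolute path, while digests and summaries get prefixed names so the three key spaces can't collide. A toy illustration of just the naming scheme; the path is made up and no real database is opened here.

// Illustrative sketch, std only; not part of the diff above.
fn main() {
    let worktree_abs_path = "/home/me/projects/zed"; // hypothetical worktree
    let embeddings_db = worktree_abs_path.to_string();
    let digests_db = format!("digests-{worktree_abs_path}");
    let summaries_db = format!("summaries-{worktree_abs_path}");
    println!("{embeddings_db}\n{digests_db}\n{summaries_db}");
}
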
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        embedding_index: EmbeddingIndex,
        summary_index: SummaryIndex,
        entry_ids_being_indexed: Arc<IndexingEntrySet>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                log::debug!("Updating entries...");
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            embedding_index,
            summary_index,
            worktree,
            entry_ids_being_indexed,
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    pub fn entry_ids_being_indexed(&self) -> &IndexingEntrySet {
        self.entry_ids_being_indexed.as_ref()
    }

    pub fn worktree(&self) -> &Model<Worktree> {
        &self.worktree
    }

    pub fn db_connection(&self) -> &heed::Env {
        &self.db_connection
    }

    pub fn embedding_index(&self) -> &EmbeddingIndex {
        &self.embedding_index
    }

    pub fn summary_index(&self) -> &SummaryIndex {
        &self.summary_index
    }

    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let is_auto_available = cx.update(|cx| cx.wait_for_flag::<AutoCommand>())?.await;
        let index = this.update(&mut cx, |this, cx| {
            futures::future::try_join(
                this.embedding_index.index_entries_changed_on_disk(cx),
                this.summary_index
                    .index_entries_changed_on_disk(is_auto_available, cx),
            )
        })?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let is_auto_available = cx
                .update(|cx| cx.has_flag::<AutoCommand>())
                .unwrap_or(false);

            let index = this.update(&mut cx, |this, cx| {
                futures::future::try_join(
                    this.embedding_index
                        .index_updated_entries(updated_entries.clone(), cx),
                    this.summary_index.index_updated_entries(
                        updated_entries,
                        is_auto_available,
                        cx,
                    ),
                )
            })?;
            index.await.log_err();
        }

        Ok(())
    }

    #[cfg(test)]
    pub fn path_count(&self) -> Result<u64> {
        use anyhow::Context;

        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.embedding_index().db().len(&txn)?)
    }
}

@ -3227,6 +3227,8 @@ pub struct Entry {
     pub git_status: Option<GitFileStatus>,
     /// Whether this entry is considered to be a `.env` file.
     pub is_private: bool,
+    /// The entry's size on disk, in bytes.
+    pub size: u64,
     pub char_bag: CharBag,
     pub is_fifo: bool,
 }

@ -3282,6 +3284,7 @@ impl Entry {
             path,
             inode: metadata.inode,
             mtime: Some(metadata.mtime),
+            size: metadata.len,
             canonical_path,
             is_symlink: metadata.is_symlink,
             is_ignored: false,

@ -5210,6 +5213,7 @@ impl<'a> From<&'a Entry> for proto::Entry {
             is_external: entry.is_external,
             git_status: entry.git_status.map(git_status_to_proto),
             is_fifo: entry.is_fifo,
+            size: Some(entry.size),
         }
     }
 }

@ -5231,6 +5235,7 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry {
             path,
             inode: entry.inode,
             mtime: entry.mtime.map(|time| time.into()),
+            size: entry.size.unwrap_or(0),
             canonical_path: None,
             is_ignored: entry.is_ignored,
             is_external: entry.is_external,
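The new size field is optional on the wire, so peers that don't send it fall back to 0 on the receiving side. A trivial sketch of that round trip with made-up values.

// Illustrative sketch, std only; not part of the diff above.
fn main() {
    let local_size: u64 = 4096;
    let on_the_wire: Option<u64> = Some(local_size); // as in From<&Entry> for proto::Entry
    assert_eq!(on_the_wire.unwrap_or(0), local_size); // as in TryFrom<proto::Entry> for Entry

    let from_older_peer: Option<u64> = None; // field omitted entirely
    assert_eq!(from_older_peer.unwrap_or(0), 0);
    println!("size round-trips; a missing value defaults to 0");
}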