Add `/auto` behind a feature flag that's disabled for now, even for
staff.

We've decided on a different design for context inference, but there are
parts of /auto that will be useful for that, so we want them in the code
base even if they're unused for now.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Richard Feldman 2024-09-13 13:17:49 -04:00 committed by GitHub
parent 93a3e8bc94
commit 91ffa02e2c
42 changed files with 2776 additions and 1054 deletions

Cargo.lock

@@ -304,6 +304,9 @@ name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
dependencies = [
"serde",
]
[[package]]
name = "as-raw-xcb-connection"
@@ -1709,6 +1712,19 @@ dependencies = [
"profiling",
]
[[package]]
name = "blake3"
version = "1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7"
dependencies = [
"arrayref",
"arrayvec",
"cc",
"cfg-if",
"constant_time_eq",
]
[[package]]
name = "block"
version = "0.1.6"
@@ -2752,6 +2768,12 @@ dependencies = [
"tiny-keccak",
]
[[package]]
name = "constant_time_eq"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
[[package]]
name = "context_servers"
version = "0.1.0"
@@ -4187,6 +4209,7 @@ dependencies = [
name = "feature_flags"
version = "0.1.0"
dependencies = [
"futures 0.3.30",
"gpui",
]
@@ -9814,10 +9837,13 @@ name = "semantic_index"
version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
"blake3",
"client",
"clock",
"collections",
"env_logger",
"feature_flags",
"fs",
"futures 0.3.30",
"futures-batch",
@@ -9825,6 +9851,7 @@ dependencies = [
"heed",
"http_client",
"language",
"language_model",
"languages",
"log",
"open_ai",


@@ -309,6 +309,7 @@ aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/alacritty/alacritty", rev = "91d034ff8b53867143c005acfaa14609147c9a2c" }
any_vec = "0.14"
anyhow = "1.0.86"
arrayvec = { version = "0.7.4", features = ["serde"] }
ashpd = "0.9.1"
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
async-dispatcher = "0.1"
@@ -325,6 +326,7 @@ bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blake3 = "1.5.3"
cargo_metadata = "0.18"
cargo_toml = "0.20"
chrono = { version = "0.4", features = ["serde"] }


@@ -37,13 +37,13 @@ use language_model::{
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
use prompts::PromptLoadingParams;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
tab_command, terminal_command, workflow_command,
auto_command, context_server_command, default_command, diagnostics_command, docs_command,
fetch_command, file_command, now_command, project_command, prompt_command, search_command,
symbols_command, tab_command, terminal_command, workflow_command,
};
use std::path::PathBuf;
use std::sync::Arc;
@@ -210,12 +210,13 @@ pub fn init(
let client = client.clone();
async move {
let embedding_provider = CloudEmbeddingProvider::new(client.clone());
let semantic_index = SemanticIndex::new(
let semantic_index = SemanticDb::new(
paths::embeddings_dir().join("semantic-index-db.0.mdb"),
Arc::new(embedding_provider),
&mut cx,
)
.await?;
cx.update(|cx| cx.set_global(semantic_index))
}
})
@@ -364,6 +365,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true);
slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
slash_command_registry.register_command(tab_command::TabSlashCommand, true);
@@ -382,6 +384,17 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
}
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
let slash_command_registry = slash_command_registry.clone();
move |is_enabled, _cx| {
if is_enabled {
// [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
slash_command_registry.register_command(auto_command::AutoCommand, true);
}
}
})
.detach();
update_slash_commands_from_settings(cx);
cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
.detach();


@@ -4723,6 +4723,20 @@ impl Render for ContextEditorToolbarItem {
let weak_self = cx.view().downgrade();
let right_side = h_flex()
.gap_2()
// TODO display this in a nicer way, once we have a design for it.
// .children({
// let project = self
// .workspace
// .upgrade()
// .map(|workspace| workspace.read(cx).project().downgrade());
//
// let scan_items_remaining = cx.update_global(|db: &mut SemanticDb, cx| {
// project.and_then(|project| db.remaining_summaries(&project, cx))
// });
// scan_items_remaining
// .map(|remaining_items| format!("Files to scan: {}", remaining_items))
// })
.child(
ModelSelector::new(
self.fs.clone(),


@@ -519,6 +519,7 @@ impl Settings for AssistantSettings {
&mut settings.default_model,
value.default_model.map(Into::into),
);
// merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
}
Ok(settings)


@@ -19,6 +19,7 @@ use std::{
use ui::ActiveTheme;
use workspace::Workspace;
pub mod auto_command;
pub mod context_server_command;
pub mod default_command;
pub mod diagnostics_command;


@@ -0,0 +1,360 @@
use super::create_label_for_command;
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Result};
use assistant_slash_command::ArgumentCompletion;
use feature_flags::FeatureFlag;
use futures::StreamExt;
use gpui::{AppContext, AsyncAppContext, Task, WeakView};
use language::{CodeLabel, LspAdapterDelegate};
use language_model::{
LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, Role,
};
use semantic_index::{FileSummary, SemanticDb};
use smol::channel;
use std::sync::{atomic::AtomicBool, Arc};
use ui::{BorrowAppContext, WindowContext};
use util::ResultExt;
use workspace::Workspace;
pub struct AutoSlashCommandFeatureFlag;
impl FeatureFlag for AutoSlashCommandFeatureFlag {
const NAME: &'static str = "auto-slash-command";
}
pub(crate) struct AutoCommand;
impl SlashCommand for AutoCommand {
fn name(&self) -> String {
"auto".into()
}
fn description(&self) -> String {
"Automatically infer what context to add, based on your prompt".into()
}
fn menu_text(&self) -> String {
"Automatically Infer Context".into()
}
fn label(&self, cx: &AppContext) -> CodeLabel {
create_label_for_command("auto", &["--prompt"], cx)
}
fn complete_argument(
self: Arc<Self>,
_arguments: &[String],
_cancel: Arc<AtomicBool>,
workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>> {
// There's no autocomplete for a prompt, since it's arbitrary text.
// However, we can use this opportunity to kick off a drain of the backlog.
// That way, the resummarizing can hopefully be finished by the time we've
// actually typed out our prompt. This re-runs on every keystroke during autocomplete,
// but in the future, we could instead do it only once, when /auto is first entered.
let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
log::warn!("workspace was dropped or unavailable during /auto autocomplete");
return Task::ready(Ok(Vec::new()));
};
let project = workspace.read(cx).project().clone();
let Some(project_index) =
cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
else {
return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
};
let cx: &mut AppContext = cx;
cx.spawn(|cx: gpui::AsyncAppContext| async move {
let task = project_index.read_with(&cx, |project_index, cx| {
project_index.flush_summary_backlogs(cx)
})?;
cx.background_executor().spawn(task).await;
anyhow::Ok(Vec::new())
})
}
fn requires_argument(&self) -> bool {
true
}
fn run(
self: Arc<Self>,
arguments: &[String],
workspace: WeakView<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
cx: &mut WindowContext,
) -> Task<Result<SlashCommandOutput>> {
let Some(workspace) = workspace.upgrade() else {
return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
};
if arguments.is_empty() {
return Task::ready(Err(anyhow!("missing prompt")));
};
let argument = arguments.join(" ");
let original_prompt = argument.to_string();
let project = workspace.read(cx).project().clone();
let Some(project_index) =
cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
else {
return Task::ready(Err(anyhow!("no project indexer")));
};
let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
let summaries = project_index
.read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
.await?;
commands_for_summaries(&summaries, &original_prompt, &cx).await
});
// As a convenience, append /auto's argument to the end of the prompt
// so you don't have to write it again.
let original_prompt = argument.to_string();
cx.background_executor().spawn(async move {
let commands = task.await?;
let mut prompt = String::new();
log::info!(
"Translating this response into slash-commands: {:?}",
commands
);
for command in commands {
prompt.push('/');
prompt.push_str(&command.name);
prompt.push(' ');
prompt.push_str(&command.arg);
prompt.push('\n');
}
prompt.push('\n');
prompt.push_str(&original_prompt);
Ok(SlashCommandOutput {
text: prompt,
sections: Vec::new(),
run_commands_in_text: true,
})
})
}
}
const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");
fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
let json_summaries = serde_json::to_string(summaries).unwrap();
format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
}
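// For reference, the serialized summaries form a JSON array of `FileSummary`
// values, e.g. `[{"filename":"src/main.rs","summary":"Entry point; ..."}]`
// (hypothetical contents, shown only to illustrate the shape embedded in the prompt).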
/// The slash commands that the model is told about, and which we look for in the inference response.
const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];
#[derive(Debug, Clone)]
struct CommandToRun {
name: String,
arg: String,
}
/// Given the pre-indexed file summaries for this project, as well as the original prompt
/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
///
/// The prompt's output does not include the slashes (to reduce the chance that it makes a mistake),
/// so taking one of these returned Strings and turning it into a real slash-command-with-argument
/// involves prepending a slash to it.
///
/// This function will validate that each of the returned lines begins with one of SUPPORTED_SLASH_COMMANDS.
/// Any other lines it encounters will be discarded, with a warning logged.
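///
/// For example (hypothetical paths), a model response of:
///
/// ```text
/// file src/main.rs
/// search token counting heuristic
/// ```
///
/// comes back, after validation and the reverse-by-name sort at the end of this
/// function, as the `search` command followed by the `file` command, ready to be
/// prefixed with slashes.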
async fn commands_for_summaries(
summaries: &[FileSummary],
original_prompt: &str,
cx: &AsyncAppContext,
) -> Result<Vec<CommandToRun>> {
if summaries.is_empty() {
log::warn!("Inferring no context because there were no summaries available.");
return Ok(Vec::new());
}
// Use the globally configured model to translate the summaries into slash-commands,
// because Qwen2-7B-Instruct has not done a good job at that task.
let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
log::warn!("Can't infer context because there's no active model.");
return Ok(Vec::new());
};
// Only go up to 90% of the actual max token count, to reduce chances of
// exceeding the token count due to inaccuracies in the token counting heuristic.
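// (For example, a model advertising a 128,000-token limit is treated here as 115,200.)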
let max_token_count = (model.max_token_count() * 9) / 10;
// Rather than recursing (which would require this async function use a pinned box),
// we use an explicit stack of arguments and answers for when we need to "recurse."
let mut stack = vec![summaries];
let mut final_response = Vec::new();
let mut prompts = Vec::new();
// TODO We only need to create multiple Requests because we currently
// don't have the ability to tell if a CompletionProvider::complete response
// was a "too many tokens in this request" error. If we had that, then
// we could try the request once, instead of having to make separate requests
// to check the token count and then afterwards to run the actual prompt.
let make_request = |prompt: String| LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
// Nothing in here will benefit from caching
cache: false,
}],
tools: Vec::new(),
stop: Vec::new(),
temperature: 1.0,
};
while let Some(current_summaries) = stack.pop() {
// The split can result in one slice being empty and the other having one element.
// Whenever that happens, skip the empty one.
if current_summaries.is_empty() {
continue;
}
log::info!(
"Inferring prompt context using {} file summaries",
current_summaries.len()
);
let prompt = summaries_prompt(&current_summaries, original_prompt);
let start = std::time::Instant::now();
// Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
// Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
// getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
// source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
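// Worked example of the estimate below: a 45,000-byte prompt comes out to 45,000 * 2 / 9 = 10,000 tokens.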
let token_estimate = prompt.len() * 2 / 9;
let duration = start.elapsed();
log::info!(
"Time taken to count tokens for prompt of length {:?}B: {:?}",
prompt.len(),
duration
);
if token_estimate < max_token_count {
prompts.push(prompt);
} else if current_summaries.len() == 1 {
log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
} else {
log::info!(
"Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
);
let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
stack.push(right);
stack.push(left);
}
}
let all_start = std::time::Instant::now();
let (tx, rx) = channel::bounded(1024);
let completion_streams = prompts
.into_iter()
.map(|prompt| {
let request = make_request(prompt.clone());
let model = model.clone();
let tx = tx.clone();
let stream = model.stream_completion(request, &cx);
(stream, tx)
})
.collect::<Vec<_>>();
cx.background_executor()
.spawn(async move {
let futures = completion_streams
.into_iter()
.enumerate()
.map(|(ix, (stream, tx))| async move {
let start = std::time::Instant::now();
let events = stream.await?;
log::info!("Time taken for awaiting /await chunk stream #{ix}: {:?}", start.elapsed());
let completion: String = events
.filter_map(|event| async {
if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
Some(text)
} else {
None
}
})
.collect()
.await;
log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());
for line in completion.split('\n') {
if let Some(first_space) = line.find(' ') {
let command = &line[..first_space].trim();
let arg = &line[first_space..].trim();
tx.send(CommandToRun {
name: command.to_string(),
arg: arg.to_string(),
})
.await?;
} else if !line.trim().is_empty() {
// All slash-commands currently supported in context inference need a space for the argument.
log::warn!(
"Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
line
);
}
}
anyhow::Ok(())
})
.collect::<Vec<_>>();
let _ = futures::future::try_join_all(futures).await.log_err();
let duration = all_start.elapsed();
eprintln!("All futures completed in {:?}", duration);
})
.await;
drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
let results = rx.collect::<Vec<_>>().await;
eprintln!(
"Finished collecting from the channel with {} results",
results.len()
);
for command in results {
// Don't return empty or duplicate commands
if !command.name.is_empty()
&& !final_response
.iter()
.any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
{
if SUPPORTED_SLASH_COMMANDS
.iter()
.any(|supported| &command.name == supported)
{
final_response.push(command);
} else {
log::warn!(
"Context inference returned an unrecognized slash command: {:?}",
command
);
}
}
}
// Sort the commands by name (reversed just so that /search appears before /file)
final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());
Ok(final_response)
}


@@ -0,0 +1,24 @@
Actions have a cost, so only include actions that you think
will be helpful to you in doing a great job answering the
prompt in the future.
You must respond ONLY with a list of actions you would like to
perform. Each action should be on its own line, and followed by a space and then its parameter.
Actions can be performed more than once with different parameters.
Here is an example valid response:
```
file path/to/my/file.txt
file path/to/another/file.txt
search something to search for
search something else to search for
```
Once again, do not forget: you must respond ONLY in the format of
one action per line, and the action name should be followed by
its parameter. Your response must not include anything other
than a list of actions, with one action per line, in this format.
It is extremely important that you do not deviate from this format even slightly!
This is the end of my instructions for how to respond. The rest is the prompt:


@@ -0,0 +1,31 @@
I'm going to give you a prompt. I don't want you to respond
to the prompt itself. I want you to figure out which of the following
actions on my project, if any, would help you answer the prompt.
Here are the actions:
## file
This action's parameter is a file path to one of the files
in the project. If you ask for this action, I will tell you
the full contents of the file, so you can learn all the
details of the file.
## search
This action's parameter is a string to do a semantic search for
across the files in the project. (You will have a JSON summary
of all the files in the project.) It will tell you which files this string
(or similar strings; it is a semantic search) appear in,
as well as some context of the lines surrounding each result.
It's very important that you only use this action when you think
that searching across the specific files in this project for the query
in question will be useful. For example, don't use this command to search
for queries you might put into a general Web search engine, because those
will be too general to give useful results in this project-specific search.
---
That was the end of the list of actions.
Here is a JSON summary of each of the files in my project:


@@ -8,7 +8,7 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use feature_flags::FeatureFlag;
use gpui::{AppContext, Task, WeakView};
use language::{CodeLabel, LineEnding, LspAdapterDelegate};
use semantic_index::SemanticIndex;
use semantic_index::SemanticDb;
use std::{
fmt::Write,
path::PathBuf,
@@ -92,8 +92,11 @@ impl SlashCommand for SearchSlashCommand {
let project = workspace.read(cx).project().clone();
let fs = project.read(cx).fs().clone();
let project_index =
cx.update_global(|index: &mut SemanticIndex, cx| index.project_index(project, cx));
let Some(project_index) =
cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
else {
return Task::ready(Err(anyhow::anyhow!("no project indexer")));
};
cx.spawn(|cx| async move {
let results = project_index


@@ -149,16 +149,16 @@ spec:
secretKeyRef:
name: google-ai
key: api_key
- name: QWEN2_7B_API_KEY
- name: RUNPOD_API_KEY
valueFrom:
secretKeyRef:
name: hugging-face
name: runpod
key: api_key
- name: QWEN2_7B_API_URL
- name: RUNPOD_API_SUMMARY_URL
valueFrom:
secretKeyRef:
name: hugging-face
key: qwen2_api_url
name: runpod
key: summary
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:


@@ -728,6 +728,11 @@ impl Database {
is_ignored: db_entry.is_ignored,
is_external: db_entry.is_external,
git_status: db_entry.git_status.map(|status| status as i32),
// This is only used in the summarization backlog, so if it's None,
// that just means we won't be able to detect when to resummarize
// based on total number of backlogged bytes - instead, we'd go
// on number of files only. That shouldn't be a huge deal in practice.
size: None,
is_fifo: db_entry.is_fifo,
});
}


@@ -663,6 +663,11 @@ impl Database {
is_ignored: db_entry.is_ignored,
is_external: db_entry.is_external,
git_status: db_entry.git_status.map(|status| status as i32),
// This is only used in the summarization backlog, so if it's None,
// that just means we won't be able to detect when to resummarize
// based on total number of backlogged bytes - instead, we'd go
// on number of files only. That shouldn't be a huge deal in practice.
size: None,
is_fifo: db_entry.is_fifo,
});
}


@@ -170,8 +170,8 @@ pub struct Config {
pub anthropic_api_key: Option<Arc<str>>,
pub anthropic_staff_api_key: Option<Arc<str>>,
pub llm_closed_beta_model_name: Option<Arc<str>>,
pub qwen2_7b_api_key: Option<Arc<str>>,
pub qwen2_7b_api_url: Option<Arc<str>>,
pub runpod_api_key: Option<Arc<str>>,
pub runpod_api_summary_url: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
@@ -235,8 +235,8 @@ impl Config {
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
qwen2_7b_api_key: None,
qwen2_7b_api_url: None,
runpod_api_key: None,
runpod_api_summary_url: None,
user_backfiller_github_access_token: None,
}
}


@@ -402,12 +402,12 @@ async fn perform_completion(
LanguageModelProvider::Zed => {
let api_key = state
.config
.qwen2_7b_api_key
.runpod_api_key
.as_ref()
.context("no Qwen2-7B API key configured on the server")?;
let api_url = state
.config
.qwen2_7b_api_url
.runpod_api_summary_url
.as_ref()
.context("no Qwen2-7B URL configured on the server")?;
let chunks = open_ai::stream_completion(


@@ -1,5 +1,5 @@
use super::*;
use sea_orm::QueryOrder;
use sea_orm::{sea_query::OnConflict, QueryOrder};
use std::str::FromStr;
use strum::IntoEnumIterator as _;
@@ -99,6 +99,17 @@ impl LlmDatabase {
..Default::default()
}
}))
.on_conflict(
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
.update_columns([
model::Column::MaxRequestsPerMinute,
model::Column::MaxTokensPerMinute,
model::Column::MaxTokensPerDay,
model::Column::PricePerMillionInputTokens,
model::Column::PricePerMillionOutputTokens,
])
.to_owned(),
)
.exec_without_returning(&*tx)
.await?;
Ok(())


@@ -40,6 +40,15 @@ pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool)
price_per_million_input_tokens: 25, // $0.25/MTok
price_per_million_output_tokens: 125, // $1.25/MTok
},
ModelParams {
provider: LanguageModelProvider::Zed,
name: "Qwen/Qwen2-7B-Instruct".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 25_000, // These are arbitrary limits we've set to cap costs; we control this number
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 25,
price_per_million_output_tokens: 125,
},
])
.await
}


@@ -679,8 +679,8 @@ impl TestServer {
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
qwen2_7b_api_key: None,
qwen2_7b_api_url: None,
runpod_api_key: None,
runpod_api_summary_url: None,
user_backfiller_github_access_token: None,
},
})


@@ -13,3 +13,4 @@ path = "src/feature_flags.rs"
[dependencies]
gpui.workspace = true
futures.workspace = true


@@ -1,4 +1,10 @@
use futures::{channel::oneshot, FutureExt as _};
use gpui::{AppContext, Global, Subscription, ViewContext};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
#[derive(Default)]
struct FeatureFlags {
@@ -53,6 +59,15 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
pub struct AutoCommand {}
impl FeatureFlag for AutoCommand {
const NAME: &'static str = "auto-command";
fn enabled_for_staff() -> bool {
false
}
}
pub trait FeatureFlagViewExt<V: 'static> {
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
@@ -75,6 +90,7 @@ where
}
pub trait FeatureFlagAppExt {
fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
fn update_flags(&mut self, staff: bool, flags: Vec<String>);
fn set_staff(&mut self, staff: bool);
fn has_flag<T: FeatureFlag>(&self) -> bool;
@@ -82,7 +98,7 @@ pub trait FeatureFlagAppExt {
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
F: Fn(bool, &mut AppContext) + 'static;
F: FnMut(bool, &mut AppContext) + 'static;
}
impl FeatureFlagAppExt for AppContext {
@@ -109,13 +125,49 @@ impl FeatureFlagAppExt for AppContext {
.unwrap_or(false)
}
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
where
F: Fn(bool, &mut AppContext) + 'static,
F: FnMut(bool, &mut AppContext) + 'static,
{
self.observe_global::<FeatureFlags>(move |cx| {
let feature_flags = cx.global::<FeatureFlags>();
callback(feature_flags.has_flag::<T>(), cx);
})
}
fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
let (tx, rx) = oneshot::channel::<bool>();
let mut tx = Some(tx);
let subscription: Option<Subscription>;
match self.try_global::<FeatureFlags>() {
Some(feature_flags) => {
subscription = None;
tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
}
None => {
subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
let feature_flags = cx.global::<FeatureFlags>();
if let Some(tx) = tx.take() {
tx.send(feature_flags.has_flag::<T>()).ok();
}
}));
}
}
WaitForFlag(rx, subscription)
}
}
pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
impl Future for WaitForFlag {
type Output = bool;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_unpin(cx).map(|result| {
self.1.take();
result.unwrap_or(false)
})
}
}

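For reference, a minimal sketch of how a caller might exercise the APIs above (assuming a `&mut AppContext`, the `ZedPro` flag defined in this file, and illustrative logging; this is not code from the commit):

```rust
use feature_flags::{FeatureFlagAppExt, ZedPro};
use gpui::AppContext;

fn feature_flag_sketch(cx: &mut AppContext) {
    // Synchronous check against the current FeatureFlags global.
    if cx.has_flag::<ZedPro>() {
        log::info!("zed-pro is currently enabled");
    }

    // `observe_flag` now accepts an `FnMut`, so the callback can carry mutable state.
    let mut updates_seen = 0;
    cx.observe_flag::<ZedPro, _>(move |enabled, _cx| {
        updates_seen += 1;
        log::info!("zed-pro changed to {enabled} (update #{updates_seen})");
    })
    .detach();

    // `wait_for_flag` returns a future that resolves with the flag's value once
    // the FeatureFlags global exists (immediately, if it already does).
    let flag = cx.wait_for_flag::<ZedPro>();
    cx.spawn(|_cx| async move {
        let enabled = flag.await;
        log::info!("zed-pro resolved to {enabled}");
    })
    .detach();
}
```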

@@ -171,6 +171,7 @@ pub struct Metadata {
pub mtime: SystemTime,
pub is_symlink: bool,
pub is_dir: bool,
pub len: u64,
pub is_fifo: bool,
}
@@ -497,6 +498,7 @@ impl Fs for RealFs {
Ok(Some(Metadata {
inode,
mtime: metadata.modified().unwrap(),
len: metadata.len(),
is_symlink,
is_dir: metadata.file_type().is_dir(),
is_fifo,
@@ -800,11 +802,13 @@ enum FakeFsEntry {
File {
inode: u64,
mtime: SystemTime,
len: u64,
content: Vec<u8>,
},
Dir {
inode: u64,
mtime: SystemTime,
len: u64,
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
git_repo_state: Option<Arc<Mutex<git::repository::FakeGitRepositoryState>>>,
},
@@ -935,6 +939,7 @@ impl FakeFs {
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
inode: 0,
mtime: SystemTime::UNIX_EPOCH,
len: 0,
entries: Default::default(),
git_repo_state: None,
})),
@@ -969,6 +974,7 @@ impl FakeFs {
inode: new_inode,
mtime: new_mtime,
content: Vec::new(),
len: 0,
})));
}
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
@@ -1016,6 +1022,7 @@ impl FakeFs {
let file = Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
len: content.len() as u64,
content,
}));
let mut kind = None;
@@ -1369,6 +1376,7 @@ impl Fs for FakeFs {
Arc::new(Mutex::new(FakeFsEntry::Dir {
inode,
mtime,
len: 0,
entries: Default::default(),
git_repo_state: None,
}))
@@ -1391,6 +1399,7 @@ impl Fs for FakeFs {
let file = Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
len: 0,
content: Vec::new(),
}));
let mut kind = Some(PathEventKind::Created);
@@ -1539,6 +1548,7 @@ impl Fs for FakeFs {
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
len: content.len() as u64,
content: Vec::new(),
})))
.clone(),
@@ -1694,16 +1704,22 @@ impl Fs for FakeFs {
let entry = entry.lock();
Ok(Some(match &*entry {
FakeFsEntry::File { inode, mtime, .. } => Metadata {
FakeFsEntry::File {
inode, mtime, len, ..
} => Metadata {
inode: *inode,
mtime: *mtime,
len: *len,
is_dir: false,
is_symlink,
is_fifo: false,
},
FakeFsEntry::Dir { inode, mtime, .. } => Metadata {
FakeFsEntry::Dir {
inode, mtime, len, ..
} => Metadata {
inode: *inode,
mtime: *mtime,
len: *len,
is_dir: true,
is_symlink,
is_fifo: false,


@@ -57,7 +57,6 @@ impl GitStatus {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow!("git status process failed: {}", stderr));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let mut entries = stdout
.split('\0')


@@ -221,6 +221,10 @@ impl HttpClient for HttpClientWithUrl {
pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpClient> {
let mut builder = isahc::HttpClient::builder()
// Some requests to Qwen2 models on Runpod can take 32+ seconds,
// especially if there's a cold boot involved. We may need to have
// those requests use a different http client, because global timeouts
// of 50 and 60 seconds, respectively, would be very high!
.connect_timeout(Duration::from_secs(5))
.low_speed_timeout(100, Duration::from_secs(5))
.proxy(proxy.clone());


@@ -17,14 +17,14 @@ pub enum CloudModel {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
#[serde(rename = "qwen2-7b-instruct")]
#[serde(rename = "Qwen/Qwen2-7B-Instruct")]
Qwen2_7bInstruct,
}
impl ZedModel {
pub fn id(&self) -> &str {
match self {
ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
ZedModel::Qwen2_7bInstruct => "Qwen/Qwen2-7B-Instruct",
}
}


@@ -319,7 +319,7 @@ impl AnthropicModel {
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let api_key = api_key.ok_or_else(|| anyhow!("Missing Anthropic API Key"))?;
let request = anthropic::stream_completion(
http_client.as_ref(),
&api_url,


@@ -265,7 +265,7 @@ impl LanguageModel for GoogleLanguageModel {
let low_speed_timeout = settings.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
let response = google_ai::count_tokens(
http_client.as_ref(),
&api_url,
@@ -304,7 +304,7 @@ impl LanguageModel for GoogleLanguageModel {
};
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
let response = stream_generate_content(
http_client.as_ref(),
&api_url,


@@ -239,7 +239,7 @@ impl OpenAiLanguageModel {
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,


@@ -159,11 +159,13 @@ impl LanguageModelRegistry {
providers
}
pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
pub fn available_models<'a>(
&'a self,
cx: &'a AppContext,
) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
self.providers
.values()
.flat_map(|provider| provider.provided_models(cx))
.collect()
}
pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {


@@ -1823,6 +1823,7 @@ impl ProjectPanel {
path: entry.path.join("\0").into(),
inode: 0,
mtime: entry.mtime,
size: entry.size,
is_ignored: entry.is_ignored,
is_external: false,
is_private: false,


@@ -1855,6 +1855,7 @@ message Entry {
bool is_external = 8;
optional GitStatus git_status = 9;
bool is_fifo = 10;
optional uint64 size = 11;
}
message RepositoryEntry {


@@ -19,14 +19,18 @@ crate-type = ["bin"]
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
blake3.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true
gpui.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
heed.workspace = true
http_client.workspace = true


@@ -4,7 +4,7 @@ use gpui::App;
use http_client::HttpClientWithUrl;
use language::language_settings::AllLanguageSettings;
use project::Project;
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticDb};
use settings::SettingsStore;
use std::{
path::{Path, PathBuf},
@@ -50,7 +50,7 @@ fn main() {
));
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
let semantic_index = SemanticDb::new(
PathBuf::from("/tmp/semantic-index-db.mdb"),
embedding_provider,
&mut cx,
@@ -71,6 +71,7 @@ fn main() {
let project_index = cx
.update(|cx| semantic_index.project_index(project.clone(), cx))
.unwrap()
.unwrap();
let (tx, rx) = oneshot::channel();


@@ -12,6 +12,12 @@ use futures::{future::BoxFuture, FutureExt};
use serde::{Deserialize, Serialize};
use std::{fmt, future};
/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
fn batch_size(&self) -> usize;
}
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);
@@ -68,12 +74,6 @@ impl fmt::Display for Embedding {
}
}
/// Trait for embedding providers. Texts in, vectors out.
pub trait EmbeddingProvider: Sync + Send {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
fn batch_size(&self) -> usize;
}
#[derive(Debug)]
pub struct TextToEmbed<'a> {
pub text: &'a str,

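To illustrate the hoisted trait, here is a minimal stub provider (a sketch, not part of the commit; it assumes it lives in this same module, since `Embedding`'s tuple field is private):

```rust
use anyhow::Result;
use futures::{future::BoxFuture, FutureExt};

/// Test stand-in: returns a fixed-size zero vector for every input text.
struct StubEmbeddingProvider;

impl EmbeddingProvider for StubEmbeddingProvider {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        // One embedding per input, preserving order, as callers expect.
        let embeddings: Vec<Embedding> =
            texts.iter().map(|_| Embedding(vec![0.0; 64])).collect();
        async move { Ok(embeddings) }.boxed()
    }

    fn batch_size(&self) -> usize {
        // Maximum number of texts callers should pass to `embed` at once.
        16
    }
}
```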

@@ -0,0 +1,469 @@
use crate::{
chunking::{self, Chunk},
embedding::{Embedding, EmbeddingProvider, TextToEmbed},
indexing::{IndexingEntryHandle, IndexingEntrySet},
};
use anyhow::{anyhow, Context as _, Result};
use collections::Bound;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use log;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
future::Future,
iter,
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;
pub struct EmbeddingIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
}
impl EmbeddingIndex {
pub fn new(
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
language_registry: Arc<LanguageRegistry>,
embedding_provider: Arc<dyn EmbeddingProvider>,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
) -> Self {
Self {
worktree,
fs,
db_connection,
db: embedding_db,
language_registry,
embedding_provider,
entry_ids_being_indexed,
}
}
pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
&self.db
}
pub fn index_entries_changed_on_disk(
&self,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_entries(worktree, cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
pub fn index_updated_entries(
&self,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
async move {
futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
Ok(())
}
}
fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let db = self.db;
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let mut db_entries = db
.iter(&txn)
.context("failed to create iterator")?
.move_between_keys()
.peekable();
let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
for entry in worktree.files(false, 0) {
log::trace!("scanning for embedding index: {:?}", &entry.path);
let entry_db_key = db_key_for_path(&entry.path);
let mut saved_mtime = None;
while let Some(db_entry) = db_entries.peek() {
match db_entry {
Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
Ordering::Less => {
if let Some(deletion_range) = deletion_range.as_mut() {
deletion_range.1 = Bound::Included(db_path);
} else {
deletion_range =
Some((Bound::Included(db_path), Bound::Included(db_path)));
}
db_entries.next();
}
Ordering::Equal => {
if let Some(deletion_range) = deletion_range.take() {
deleted_entry_ranges_tx
.send((
deletion_range.0.map(ToString::to_string),
deletion_range.1.map(ToString::to_string),
))
.await?;
}
saved_mtime = db_embedded_file.mtime;
db_entries.next();
break;
}
Ordering::Greater => {
break;
}
},
Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
}
}
if entry.mtime != saved_mtime {
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
if let Some(db_entry) = db_entries.next() {
let (db_path, _) = db_entry?;
deleted_entry_ranges_tx
.send((Bound::Included(db_path.to_string()), Bound::Unbounded))
.await?;
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn scan_updated_entries(
&self,
worktree: Snapshot,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {
match status {
project::PathChange::Added
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
let handle = entries_being_indexed.insert(entry.id);
updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
}
project::PathChange::Removed => {
let db_path = db_key_for_path(path);
deleted_entry_ranges_tx
.send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
.await?;
}
project::PathChange::Loaded => {
// Do nothing.
}
}
}
Ok(())
});
ScanEntries {
updated_entries: updated_entries_rx,
deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn chunk_files(
&self,
worktree_abs_path: Arc<Path>,
entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
cx: &AppContext,
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
let fs = self.fs.clone();
let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
while let Ok((entry, handle)) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
match fs.load(&entry_abs_path).await {
Ok(text) => {
let language = language_registry
.language_for_file_path(&entry.path)
.await
.ok();
let chunked_file = ChunkedFile {
chunks: chunking::chunk_text(
&text,
language.as_ref(),
&entry.path,
),
handle,
path: entry.path,
mtime: entry.mtime,
text,
};
if chunked_files_tx.send(chunked_file).await.is_err() {
return;
}
}
Err(_) => {
log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
}
}
}
});
}
})
.await;
Ok(())
});
ChunkFiles {
files: chunked_files_rx,
task,
}
}
pub fn embed_files(
embedding_provider: Arc<dyn EmbeddingProvider>,
chunked_files: channel::Receiver<ChunkedFile>,
cx: &AppContext,
) -> EmbedFiles {
let embedding_provider = embedding_provider.clone();
let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
let mut chunked_file_batches =
chunked_files.chunks_timeout(512, Duration::from_secs(2));
while let Some(chunked_files) = chunked_file_batches.next().await {
// View the batch of files as a vec of chunks
// Flatten out to a vec of chunks that we can subdivide into batch sized pieces
// Once those are done, reassemble them back into the files in which they belong
// If any embeddings fail for a file, the entire file is discarded
let chunks: Vec<TextToEmbed> = chunked_files
.iter()
.flat_map(|file| {
file.chunks.iter().map(|chunk| TextToEmbed {
text: &file.text[chunk.range.clone()],
digest: chunk.digest,
})
})
.collect::<Vec<_>>();
let mut embeddings: Vec<Option<Embedding>> = Vec::new();
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
if let Some(batch_embeddings) =
embedding_provider.embed(embedding_batch).await.log_err()
{
if batch_embeddings.len() == embedding_batch.len() {
embeddings.extend(batch_embeddings.into_iter().map(Some));
continue;
}
log::error!(
"embedding provider returned unexpected embedding count {}, expected {}",
batch_embeddings.len(), embedding_batch.len()
);
}
embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
}
let mut embeddings = embeddings.into_iter();
for chunked_file in chunked_files {
let mut embedded_file = EmbeddedFile {
path: chunked_file.path,
mtime: chunked_file.mtime,
chunks: Vec::new(),
};
let mut embedded_all_chunks = true;
for (chunk, embedding) in
chunked_file.chunks.into_iter().zip(embeddings.by_ref())
{
if let Some(embedding) = embedding {
embedded_file
.chunks
.push(EmbeddedChunk { chunk, embedding });
} else {
embedded_all_chunks = false;
}
}
if embedded_all_chunks {
embedded_files_tx
.send((embedded_file, chunked_file.handle))
.await?;
}
}
}
Ok(())
});
EmbedFiles {
files: embedded_files_rx,
task,
}
}
fn persist_embeddings(
&self,
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
while let Some(deletion_range) = deleted_entry_ranges.next().await {
let mut txn = db_connection.write_txn()?;
let start = deletion_range.0.as_ref().map(|start| start.as_str());
let end = deletion_range.1.as_ref().map(|end| end.as_str());
log::debug!("deleting embeddings in range {:?}", &(start, end));
db.delete_range(&mut txn, &(start, end))?;
txn.commit()?;
}
let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
while let Some(embedded_files) = embedded_files.next().await {
let mut txn = db_connection.write_txn()?;
for (file, _) in &embedded_files {
log::debug!("saving embedding for file {:?}", file.path);
let key = db_key_for_path(&file.path);
db.put(&mut txn, &key, file)?;
}
txn.commit()?;
drop(embedded_files);
log::debug!("committed");
}
Ok(())
})
}
pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
let connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
let tx = connection
.read_txn()
.context("failed to create read transaction")?;
let result = db
.iter(&tx)?
.map(|entry| Ok(entry?.1.path.clone()))
.collect::<Result<Vec<Arc<Path>>>>();
drop(tx);
result
})
}
pub fn chunks_for_path(
&self,
path: Arc<Path>,
cx: &AppContext,
) -> Task<Result<Vec<EmbeddedChunk>>> {
let connection = self.db_connection.clone();
let db = self.db;
cx.background_executor().spawn(async move {
let tx = connection
.read_txn()
.context("failed to create read transaction")?;
Ok(db
.get(&tx, &db_key_for_path(&path))?
.ok_or_else(|| anyhow!("no such path"))?
.chunks
.clone())
})
}
}
struct ScanEntries {
updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
task: Task<Result<()>>,
}
struct ChunkFiles {
files: channel::Receiver<ChunkedFile>,
task: Task<Result<()>>,
}
pub struct ChunkedFile {
pub path: Arc<Path>,
pub mtime: Option<SystemTime>,
pub handle: IndexingEntryHandle,
pub text: String,
pub chunks: Vec<Chunk>,
}
pub struct EmbedFiles {
pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
pub task: Task<Result<()>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddedFile {
pub path: Arc<Path>,
pub mtime: Option<SystemTime>,
pub chunks: Vec<EmbeddedChunk>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EmbeddedChunk {
pub chunk: Chunk,
pub embedding: Embedding,
}
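// Keys store '/' as '\0' so that, in the database's lexicographic key order,
// entries under a directory sort immediately after the directory's own prefix
// ('\0' compares below every printable byte). This keeps iteration order
// aligned with the worktree's path traversal order, which the merge-style
// scan in `scan_entries` appears to rely on.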
fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}


@@ -0,0 +1,49 @@
use collections::HashSet;
use parking_lot::Mutex;
use project::ProjectEntryId;
use smol::channel;
use std::sync::{Arc, Weak};
/// The set of entries that are currently being indexed.
pub struct IndexingEntrySet {
entry_ids: Mutex<HashSet<ProjectEntryId>>,
tx: channel::Sender<()>,
}
/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
pub(crate) struct IndexingEntryHandle {
entry_id: ProjectEntryId,
set: Weak<IndexingEntrySet>,
}
impl IndexingEntrySet {
pub fn new(tx: channel::Sender<()>) -> Self {
Self {
entry_ids: Default::default(),
tx,
}
}
pub fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
self.entry_ids.lock().insert(entry_id);
self.tx.send_blocking(()).ok();
IndexingEntryHandle {
entry_id,
set: Arc::downgrade(self),
}
}
pub fn len(&self) -> usize {
self.entry_ids.lock().len()
}
}
impl Drop for IndexingEntryHandle {
fn drop(&mut self) {
if let Some(set) = self.set.upgrade() {
set.tx.send_blocking(()).ok();
set.entry_ids.lock().remove(&self.entry_id);
}
}
}

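A short sketch of the intended handle lifecycle, written as a crate-internal test (not part of the commit; it assumes `ProjectEntryId::from_proto` is available for fabricating an id):

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn dropping_handle_removes_entry() {
        let (tx, rx) = smol::channel::unbounded();
        let set = Arc::new(IndexingEntrySet::new(tx));

        // Inserting registers the entry and pings the status channel.
        let handle = set.insert(ProjectEntryId::from_proto(1));
        assert_eq!(set.len(), 1);

        // Dropping the handle removes the entry and pings the channel again.
        drop(handle);
        assert_eq!(set.len(), 0);
        assert_eq!(rx.len(), 2);
    }
}
```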

@@ -0,0 +1,523 @@
use crate::{
embedding::{EmbeddingProvider, TextToEmbed},
summary_index::FileSummary,
worktree_index::{WorktreeIndex, WorktreeIndexHandle},
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use fs::Fs;
use futures::{stream::StreamExt, FutureExt};
use gpui::{
AppContext, Entity, EntityId, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel,
};
use language::LanguageRegistry;
use log;
use project::{Project, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
use util::ResultExt;
#[derive(Debug)]
pub struct SearchResult {
pub worktree: Model<Worktree>,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
}
pub struct WorktreeSearchResult {
pub worktree_id: WorktreeId,
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
Idle,
Loading,
Scanning { remaining_count: NonZeroUsize },
}
pub struct ProjectIndex {
db_connection: heed::Env,
project: WeakModel<Project>,
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
last_status: Status,
status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
_maintain_status: Task<()>,
_subscription: Subscription,
}
impl ProjectIndex {
pub fn new(
project: Model<Project>,
db_connection: heed::Env,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let language_registry = project.read(cx).languages().clone();
let fs = project.read(cx).fs().clone();
let (status_tx, mut status_rx) = channel::unbounded();
let mut this = ProjectIndex {
db_connection,
project: project.downgrade(),
worktree_indices: HashMap::default(),
language_registry,
fs,
status_tx,
last_status: Status::Idle,
embedding_provider,
_subscription: cx.subscribe(&project, Self::handle_project_event),
_maintain_status: cx.spawn(|this, mut cx| async move {
while status_rx.next().await.is_some() {
if this
.update(&mut cx, |this, cx| this.update_status(cx))
.is_err()
{
break;
}
}
}),
};
this.update_worktree_indices(cx);
this
}
pub fn status(&self) -> Status {
self.last_status
}
pub fn project(&self) -> WeakModel<Project> {
self.project.clone()
}
pub fn fs(&self) -> Arc<dyn Fs> {
self.fs.clone()
}
fn handle_project_event(
&mut self,
_: Model<Project>,
event: &project::Event,
cx: &mut ModelContext<Self>,
) {
match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
self.update_worktree_indices(cx);
}
_ => {}
}
}
fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
let Some(project) = self.project.upgrade() else {
return;
};
let worktrees = project
.read(cx)
.visible_worktrees(cx)
.filter_map(|worktree| {
if worktree.read(cx).is_local() {
Some((worktree.entity_id(), worktree))
} else {
None
}
})
.collect::<HashMap<_, _>>();
self.worktree_indices
.retain(|worktree_id, _| worktrees.contains_key(worktree_id));
for (worktree_id, worktree) in worktrees {
self.worktree_indices.entry(worktree_id).or_insert_with(|| {
let worktree_index = WorktreeIndex::load(
worktree.clone(),
self.db_connection.clone(),
self.language_registry.clone(),
self.fs.clone(),
self.status_tx.clone(),
self.embedding_provider.clone(),
cx,
);
let load_worktree = cx.spawn(|this, mut cx| async move {
let result = match worktree_index.await {
Ok(worktree_index) => {
this.update(&mut cx, |this, _| {
this.worktree_indices.insert(
worktree_id,
WorktreeIndexHandle::Loaded {
index: worktree_index.clone(),
},
);
})?;
Ok(worktree_index)
}
Err(error) => {
this.update(&mut cx, |this, _cx| {
this.worktree_indices.remove(&worktree_id)
})?;
Err(Arc::new(error))
}
};
this.update(&mut cx, |this, cx| this.update_status(cx))?;
result
});
WorktreeIndexHandle::Loading {
index: load_worktree.shared(),
}
});
}
self.update_status(cx);
}
fn update_status(&mut self, cx: &mut ModelContext<Self>) {
let mut indexing_count = 0;
let mut any_loading = false;
for index in self.worktree_indices.values_mut() {
match index {
WorktreeIndexHandle::Loading { .. } => {
any_loading = true;
break;
}
WorktreeIndexHandle::Loaded { index, .. } => {
indexing_count += index.read(cx).entry_ids_being_indexed().len();
}
}
}
let status = if any_loading {
Status::Loading
} else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
Status::Scanning { remaining_count }
} else {
Status::Idle
};
if status != self.last_status {
self.last_status = status;
cx.emit(status);
}
}
pub fn search(
&self,
query: String,
limit: usize,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
let (chunks_tx, chunks_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
let worktree_index = worktree_index.clone();
let chunks_tx = chunks_tx.clone();
worktree_scan_tasks.push(cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
index
.read_with(&cx, |index, cx| {
let worktree_id = index.worktree().read(cx).id();
let db_connection = index.db_connection().clone();
let db = *index.embedding_index().db();
cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
let db_entries = db.iter(&txn).context("failed to iterate database")?;
for db_entry in db_entries {
let (_key, db_embedded_file) = db_entry?;
for chunk in db_embedded_file.chunks {
chunks_tx
.send((worktree_id, db_embedded_file.path.clone(), chunk))
.await?;
}
}
anyhow::Ok(())
})
})?
.await
}));
}
drop(chunks_tx);
let project = self.project.clone();
let embedding_provider = self.embedding_provider.clone();
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {query}");
let query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
.await?;
let query_embedding = query_embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("no embedding for query"))?;
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
results_by_worker.push(Vec::<WorktreeSearchResult>::new());
}
#[cfg(debug_assertions)]
let search_start = std::time::Instant::now();
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let score = chunk.embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
score,
},
);
results.truncate(limit);
}
});
}
})
.await;
for scan_task in futures::future::join_all(worktree_scan_tasks).await {
scan_task.log_err();
}
project.read_with(&cx, |project, cx| {
let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
for worker_results in results_by_worker {
search_results.extend(worker_results.into_iter().filter_map(|result| {
Some(SearchResult {
worktree: project.worktree_for_id(result.worktree_id, cx)?,
path: result.path,
range: result.range,
score: result.score,
})
}));
}
search_results.sort_unstable_by(|a, b| {
b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
});
search_results.truncate(limit);
#[cfg(debug_assertions)]
{
let search_elapsed = search_start.elapsed();
log::debug!(
"searched {} entries in {:?}",
search_results.len(),
search_elapsed
);
let embedding_query_elapsed = embedding_query_start.elapsed();
log::debug!("embedding query took {:?}", embedding_query_elapsed);
}
search_results
})
})
}
#[cfg(test)]
pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
let mut result = 0;
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
result += index.read(cx).path_count()?;
}
}
Ok(result)
}
pub(crate) fn worktree_index(
&self,
worktree_id: WorktreeId,
cx: &AppContext,
) -> Option<Model<WorktreeIndex>> {
for index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = index {
if index.read(cx).worktree().read(cx).id() == worktree_id {
return Some(index.clone());
}
}
}
None
}
pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
let mut result = self
.worktree_indices
.values()
.filter_map(|index| {
if let WorktreeIndexHandle::Loaded { index, .. } = index {
Some(index.clone())
} else {
None
}
})
.collect::<Vec<_>>();
result.sort_by_key(|index| index.read(cx).worktree().read(cx).id());
result
}
pub fn all_summaries(&self, cx: &AppContext) -> Task<Result<Vec<FileSummary>>> {
let (summaries_tx, summaries_rx) = channel::bounded(1024);
let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
let worktree_index = worktree_index.clone();
let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone();
worktree_scan_tasks.push(cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
index
.read_with(&cx, |index, cx| {
let db_connection = index.db_connection().clone();
let summary_index = index.summary_index();
let file_digest_db = summary_index.file_digest_db();
let summary_db = summary_index.summary_db();
cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create db read transaction")?;
let db_entries = file_digest_db
.iter(&txn)
.context("failed to iterate database")?;
for db_entry in db_entries {
let (file_path, db_file) = db_entry?;
match summary_db.get(&txn, &db_file.digest) {
Ok(opt_summary) => {
// Currently, we only use summaries we already have. If the file hasn't been
// summarized yet, then we skip it and don't include it in the inferred context.
// If we want to do just-in-time summarization, this would be the place to do it!
if let Some(summary) = opt_summary {
summaries_tx
.send((file_path.to_string(), summary.to_string()))
.await?;
} else {
log::warn!("No summary found for {:?}", &db_file);
}
}
Err(err) => {
log::error!(
"Error reading from summary database: {:?}",
err
);
}
}
}
anyhow::Ok(())
})
})?
.await
}));
}
drop(summaries_tx);
let project = self.project.clone();
cx.spawn(|cx| async move {
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
results_by_worker.push(Vec::<FileSummary>::new());
}
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((filename, summary)) = summaries_rx.recv().await {
results.push(FileSummary { filename, summary });
}
});
}
})
.await;
for scan_task in futures::future::join_all(worktree_scan_tasks).await {
scan_task.log_err();
}
project.read_with(&cx, |_project, _cx| {
results_by_worker.into_iter().flatten().collect()
})
})
}
/// Empty out the backlogs of all the worktrees in the project
pub fn flush_summary_backlogs(&self, cx: &AppContext) -> impl Future<Output = ()> {
let flush_start = std::time::Instant::now();
futures::future::join_all(self.worktree_indices.values().map(|worktree_index| {
let worktree_index = worktree_index.clone();
cx.spawn(|cx| async move {
let index = match worktree_index {
WorktreeIndexHandle::Loading { index } => {
index.clone().await.map_err(|error| anyhow!(error))?
}
WorktreeIndexHandle::Loaded { index } => index.clone(),
};
let worktree_abs_path =
cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?;
index
.read_with(&cx, |index, cx| {
cx.background_executor()
.spawn(index.summary_index().flush_backlog(worktree_abs_path, cx))
})?
.await
})
}))
.map(move |results| {
// Log any errors, but don't block the user. These summaries are supposed to
// improve quality by providing extra context, but they aren't hard requirements!
for result in results {
if let Err(err) = result {
log::error!("Error flushing summary backlog: {:?}", err);
}
}
log::info!("Summary backlog flushed in {:?}", flush_start.elapsed());
})
}
pub fn remaining_summaries(&self, cx: &mut ModelContext<Self>) -> usize {
self.worktree_indices(cx)
.iter()
.map(|index| index.read(cx).summary_index().backlog_len())
.sum()
}
}
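// A self-contained sketch of the per-worker top-k insertion used in
// ProjectIndex::search above: `results` stays sorted by descending score via
// an inverted binary-search comparator, and truncate() caps memory at `limit`.
// (Illustrative only; `id` stands in for the worktree/path/range payload.)
#[allow(dead_code)]
fn insert_top_k(results: &mut Vec<(f32, usize)>, score: f32, id: usize, limit: usize) {
    let ix = match results.binary_search_by(|&(probe_score, _)| {
        // Comparing `score` against the probe (rather than the usual probe-vs-target)
        // keeps the vector ordered from highest score to lowest.
        score.partial_cmp(&probe_score).unwrap_or(Ordering::Equal)
    }) {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (score, id));
    results.truncate(limit);
}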
impl EventEmitter<Status> for ProjectIndex {}

View file

@ -55,8 +55,12 @@ impl ProjectIndexDebugView {
for index in worktree_indices {
let (root_path, worktree_id, worktree_paths) =
index.read_with(&cx, |index, cx| {
let worktree = index.worktree.read(cx);
(worktree.abs_path(), worktree.id(), index.paths(cx))
let worktree = index.worktree().read(cx);
(
worktree.abs_path(),
worktree.id(),
index.embedding_index().paths(cx),
)
})?;
rows.push(Row::Worktree(root_path));
rows.extend(
@ -82,10 +86,12 @@ impl ProjectIndexDebugView {
cx: &mut ViewContext<Self>,
) -> Option<()> {
let project_index = self.index.read(cx);
let fs = project_index.fs.clone();
let fs = project_index.fs().clone();
let worktree_index = project_index.worktree_index(worktree_id, cx)?.read(cx);
let root_path = worktree_index.worktree.read(cx).abs_path();
let chunks = worktree_index.chunks_for_path(file_path.clone(), cx);
let root_path = worktree_index.worktree().read(cx).abs_path();
let chunks = worktree_index
.embedding_index()
.chunks_for_path(file_path.clone(), cx);
cx.spawn(|this, mut cx| async move {
let chunks = chunks.await?;

File diff suppressed because it is too large

View file

@ -0,0 +1,48 @@
use collections::HashMap;
use std::{path::Path, sync::Arc, time::SystemTime};
const MAX_FILES_BEFORE_RESUMMARIZE: usize = 4;
const MAX_BYTES_BEFORE_RESUMMARIZE: u64 = 1_000_000; // 1 MB
#[derive(Default, Debug)]
pub struct SummaryBacklog {
/// Key: path to a file that needs summarization, but that we haven't summarized yet. Value: that file's size on disk, in bytes, and its mtime.
files: HashMap<Arc<Path>, (u64, Option<SystemTime>)>,
/// Cache of the sum of all values in `files`, so we don't have to traverse the whole map to check if we're over the byte limit.
total_bytes: u64,
}
impl SummaryBacklog {
/// Store the given path in the backlog, along with how many bytes are in it.
pub fn insert(&mut self, path: Arc<Path>, bytes_on_disk: u64, mtime: Option<SystemTime>) {
let (prev_bytes, _) = self
.files
.insert(path, (bytes_on_disk, mtime))
.unwrap_or_default(); // Default to 0 prev_bytes
// Update the cached total by subtracting out the old amount and adding the new one.
self.total_bytes = self.total_bytes - prev_bytes + bytes_on_disk;
}
/// Returns true if the total number of bytes in the backlog exceeds a predefined threshold.
pub fn needs_drain(&self) -> bool {
self.files.len() > MAX_FILES_BEFORE_RESUMMARIZE ||
// The whole purpose of the cached total_bytes is to make this comparison cheap.
// Otherwise we'd have to traverse the entire dictionary every time we wanted this answer.
self.total_bytes > MAX_BYTES_BEFORE_RESUMMARIZE
}
/// Remove all the entries in the backlog and return the file paths as an iterator.
#[allow(clippy::needless_lifetimes)] // Clippy thinks this 'a can be elided, but eliding it gives a compile error
pub fn drain<'a>(&'a mut self) -> impl Iterator<Item = (Arc<Path>, Option<SystemTime>)> + 'a {
self.total_bytes = 0;
self.files
.drain()
.map(|(path, (_size, mtime))| (path, mtime))
}
pub fn len(&self) -> usize {
self.files.len()
}
}
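// A minimal usage sketch (hypothetical test, for illustration): entries
// accumulate until one of the thresholds above trips, then drain() empties
// the backlog and resets the cached byte count.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    #[test]
    fn drains_after_file_threshold() {
        let mut backlog = SummaryBacklog::default();

        // Insert one more file than MAX_FILES_BEFORE_RESUMMARIZE allows.
        for i in 0..=MAX_FILES_BEFORE_RESUMMARIZE {
            let path: Arc<Path> = PathBuf::from(format!("src/file_{i}.rs")).into();
            backlog.insert(path, 100, None);
        }

        assert!(backlog.needs_drain());
        assert_eq!(backlog.drain().count(), MAX_FILES_BEFORE_RESUMMARIZE + 1);
        assert_eq!(backlog.len(), 0);
    }
}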

View file

@ -0,0 +1,693 @@
use anyhow::{anyhow, Context as _, Result};
use arrayvec::ArrayString;
use fs::Fs;
use futures::{stream::StreamExt, TryFutureExt};
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{AppContext, Model, Task};
use heed::{
types::{SerdeBincode, Str},
RoTxn,
};
use language_model::{
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, Role,
};
use log;
use parking_lot::Mutex;
use project::{Entry, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
future::Future,
path::Path,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use util::ResultExt;
use worktree::Snapshot;
use crate::{indexing::IndexingEntrySet, summary_backlog::SummaryBacklog};
#[derive(Serialize, Deserialize, Debug)]
pub struct FileSummary {
pub filename: String,
pub summary: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct UnsummarizedFile {
// Path to the file on disk
path: Arc<Path>,
// The mtime of the file on disk
mtime: Option<SystemTime>,
// BLAKE3 hash of the source file's contents
digest: Blake3Digest,
// The source file's contents
contents: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct SummarizedFile {
// Path to the file on disk
path: String,
// The mtime of the file on disk
mtime: Option<SystemTime>,
// BLAKE3 hash of the source file's contents
digest: Blake3Digest,
// The LLM's summary of the file's contents
summary: String,
}
/// This is what blake3's to_hex() method returns - see https://docs.rs/blake3/1.5.3/src/blake3/lib.rs.html#246
pub type Blake3Digest = ArrayString<{ blake3::OUT_LEN * 2 }>;
#[derive(Debug, Serialize, Deserialize)]
pub struct FileDigest {
pub mtime: Option<SystemTime>,
pub digest: Blake3Digest,
}
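// A hedged sketch (illustrative helper, not used by the index) of how a
// Blake3Digest is derived, mirroring digest_files() below: the relative path
// is hashed along with the contents, so identical contents at two different
// paths produce two different digests.
#[allow(dead_code)]
fn example_digest(path: &Path, contents: &str) -> Blake3Digest {
    let mut hasher = blake3::Hasher::new();
    hasher.update(path.display().to_string().as_bytes());
    hasher.update(contents.as_bytes());
    hasher.finalize().to_hex()
}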
struct NeedsSummary {
files: channel::Receiver<UnsummarizedFile>,
task: Task<Result<()>>,
}
struct SummarizeFiles {
files: channel::Receiver<SummarizedFile>,
task: Task<Result<()>>,
}
pub struct SummaryIndex {
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>, // Key: file path. Val: BLAKE3 digest of its contents.
summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>, // Key: BLAKE3 digest of a file's contents. Val: LLM summary of those contents.
backlog: Arc<Mutex<SummaryBacklog>>,
_entry_ids_being_indexed: Arc<IndexingEntrySet>, // TODO can this be removed?
}
struct Backlogged {
paths_to_digest: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
task: Task<Result<()>>,
}
struct MightNeedSummaryFiles {
files: channel::Receiver<UnsummarizedFile>,
task: Task<Result<()>>,
}
impl SummaryIndex {
pub fn new(
worktree: Model<Worktree>,
fs: Arc<dyn Fs>,
db_connection: heed::Env,
file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>,
_entry_ids_being_indexed: Arc<IndexingEntrySet>,
) -> Self {
Self {
worktree,
fs,
db_connection,
file_digest_db,
summary_db,
_entry_ids_being_indexed,
backlog: Default::default(),
}
}
pub fn file_digest_db(&self) -> heed::Database<Str, SerdeBincode<FileDigest>> {
self.file_digest_db
}
pub fn summary_db(&self) -> heed::Database<SerdeBincode<Blake3Digest>, Str> {
self.summary_db
}
pub fn index_entries_changed_on_disk(
&self,
is_auto_available: bool,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let start = Instant::now();
let backlogged;
let digest;
let needs_summary;
let summaries;
let persist;
if is_auto_available {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
backlogged = self.scan_entries(worktree, cx);
digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
needs_summary = self.check_summary_cache(digest.files, cx);
summaries = self.summarize_files(needs_summary.files, cx);
persist = self.persist_summaries(summaries.files, cx);
} else {
// `/auto` isn't available (its feature flag is currently disabled, even for staff), so make the rest of these no-ops.
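// (`channel::unbounded().1` keeps just the receiver and drops the sender
// immediately, so each of these streams ends as soon as it's polled.)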
backlogged = Backlogged {
paths_to_digest: channel::unbounded().1,
task: Task::ready(Ok(())),
};
digest = MightNeedSummaryFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
needs_summary = NeedsSummary {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
summaries = SummarizeFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
persist = Task::ready(Ok(()));
}
async move {
futures::try_join!(
backlogged.task,
digest.task,
needs_summary.task,
summaries.task,
persist
)?;
if is_auto_available {
log::info!(
"Summarizing everything that changed on disk took {:?}",
start.elapsed()
);
}
Ok(())
}
}
pub fn index_updated_entries(
&mut self,
updated_entries: UpdatedEntriesSet,
is_auto_available: bool,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let start = Instant::now();
let backlogged;
let digest;
let needs_summary;
let summaries;
let persist;
if is_auto_available {
let worktree = self.worktree.read(cx).snapshot();
let worktree_abs_path = worktree.abs_path().clone();
backlogged = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
needs_summary = self.check_summary_cache(digest.files, cx);
summaries = self.summarize_files(needs_summary.files, cx);
persist = self.persist_summaries(summaries.files, cx);
} else {
// `/auto` isn't available (its feature flag is currently disabled, even for staff), so make the rest of these no-ops.
backlogged = Backlogged {
paths_to_digest: channel::unbounded().1,
task: Task::ready(Ok(())),
};
digest = MightNeedSummaryFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
needs_summary = NeedsSummary {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
summaries = SummarizeFiles {
files: channel::unbounded().1,
task: Task::ready(Ok(())),
};
persist = Task::ready(Ok(()));
}
async move {
futures::try_join!(
backlogged.task,
digest.task,
needs_summary.task,
summaries.task,
persist
)?;
if is_auto_available {
    log::info!("Summarizing updated entries took {:?}", start.elapsed());
}
Ok(())
}
}
fn check_summary_cache(
&self,
mut might_need_summary: channel::Receiver<UnsummarizedFile>,
cx: &AppContext,
) -> NeedsSummary {
let db_connection = self.db_connection.clone();
let db = self.summary_db;
let (needs_summary_tx, needs_summary_rx) = channel::bounded(512);
let task = cx.background_executor().spawn(async move {
while let Some(file) = might_need_summary.next().await {
let tx = db_connection
.read_txn()
.context("Failed to create read transaction for checking which hashes are in summary cache")?;
match db.get(&tx, &file.digest) {
Ok(opt_answer) => {
if opt_answer.is_none() {
// It's not in the summary cache db, so we need to summarize it.
log::debug!("File {:?} (digest {:?}) was NOT in the db cache and needs to be resummarized.", file.path.display(), &file.digest);
needs_summary_tx.send(file).await?;
} else {
log::debug!("File {:?} (digest {:?}) was in the db cache and does not need to be resummarized.", file.path.display(), &file.digest);
}
}
Err(err) => {
log::error!("Reading from the summaries database failed: {:?}", err);
}
}
}
Ok(())
});
NeedsSummary {
files: needs_summary_rx,
task,
}
}
fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> Backlogged {
let (tx, rx) = channel::bounded(512);
let db_connection = self.db_connection.clone();
let digest_db = self.file_digest_db;
let backlog = Arc::clone(&self.backlog);
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
for entry in worktree.files(false, 0) {
let needs_summary =
Self::add_to_backlog(Arc::clone(&backlog), digest_db, &txn, entry);
if !needs_summary.is_empty() {
tx.send(needs_summary).await?;
}
}
// TODO delete db entries for deleted files
Ok(())
});
Backlogged {
paths_to_digest: rx,
task,
}
}
fn add_to_backlog(
backlog: Arc<Mutex<SummaryBacklog>>,
digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
txn: &RoTxn<'_>,
entry: &Entry,
) -> Vec<(Arc<Path>, Option<SystemTime>)> {
let entry_db_key = db_key_for_path(&entry.path);
match digest_db.get(&txn, &entry_db_key) {
Ok(opt_saved_digest) => {
// The file's mtime differs from the one we have saved (or one of them is
// missing), so its summary may be stale. Add it to the backlog; then, if the
// backlog is full, drain it and summarize its contents.
if entry.mtime != opt_saved_digest.and_then(|digest| digest.mtime) {
let mut backlog = backlog.lock();
log::info!(
"Inserting {:?} ({:?} bytes) into backlog",
&entry.path,
entry.size,
);
backlog.insert(Arc::clone(&entry.path), entry.size, entry.mtime);
if backlog.needs_drain() {
log::info!("Draining summary backlog...");
return backlog.drain().collect();
}
}
}
Err(err) => {
log::error!(
"Error trying to get file digest db entry {:?}: {:?}",
&entry_db_key,
err
);
}
}
Vec::new()
}
fn scan_updated_entries(
&self,
worktree: Snapshot,
updated_entries: UpdatedEntriesSet,
cx: &AppContext,
) -> Backlogged {
log::info!("Scanning for updated entries that might need summarization...");
let (tx, rx) = channel::bounded(512);
// let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let digest_db = self.file_digest_db;
let backlog = Arc::clone(&self.backlog);
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
.context("failed to create read transaction")?;
for (path, entry_id, status) in updated_entries.iter() {
match status {
project::PathChange::Loaded
| project::PathChange::Added
| project::PathChange::Updated
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
let needs_summary = Self::add_to_backlog(
Arc::clone(&backlog),
digest_db,
&txn,
entry,
);
if !needs_summary.is_empty() {
tx.send(needs_summary).await?;
}
}
}
}
project::PathChange::Removed => {
let _db_path = db_key_for_path(path);
// TODO delete db entries for deleted files
// deleted_entry_ranges_tx
// .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
// .await?;
}
}
}
Ok(())
});
Backlogged {
paths_to_digest: rx,
// deleted_entry_ranges: deleted_entry_ranges_rx,
task,
}
}
fn digest_files(
&self,
paths: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
worktree_abs_path: Arc<Path>,
cx: &AppContext,
) -> MightNeedSummaryFiles {
let fs = self.fs.clone();
let (tx, rx) = channel::bounded(2048);
let task = cx.spawn(|cx| async move {
cx.background_executor()
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
while let Ok(pairs) = paths.recv().await {
// Note: we could process all these files concurrently if desired. Might or might not speed things up.
for (path, mtime) in pairs {
let entry_abs_path = worktree_abs_path.join(&path);
// Load the file's contents and compute its hash digest.
let unsummarized_file = {
let Some(contents) = fs
.load(&entry_abs_path)
.await
.with_context(|| {
format!("failed to read path {entry_abs_path:?}")
})
.log_err()
else {
continue;
};
let digest = {
let mut hasher = blake3::Hasher::new();
// Incorporate both the (relative) file path as well as the contents of the file into the hash.
// This is because in some languages and frameworks, identical files can do different things
// depending on their paths (e.g. Rails controllers). It's also why we send the path to the model.
hasher.update(path.display().to_string().as_bytes());
hasher.update(contents.as_bytes());
hasher.finalize().to_hex()
};
UnsummarizedFile {
digest,
contents,
path,
mtime,
}
};
if let Err(err) = tx
.send(unsummarized_file)
.map_err(|error| anyhow!(error))
.await
{
log::error!("Error: {:?}", err);
return;
}
}
}
});
}
})
.await;
Ok(())
});
MightNeedSummaryFiles { files: rx, task }
}
fn summarize_files(
&self,
mut unsummarized_files: channel::Receiver<UnsummarizedFile>,
cx: &AppContext,
) -> SummarizeFiles {
let (summarized_tx, summarized_rx) = channel::bounded(512);
let task = cx.spawn(|cx| async move {
while let Some(file) = unsummarized_files.next().await {
log::debug!("Summarizing {:?}", file);
let summary = cx
.update(|cx| Self::summarize_code(&file.contents, &file.path, cx))?
.await
.unwrap_or_else(|err| {
// Log a warning because we'll continue anyway.
// In the future, we may want to try splitting it up into multiple requests and concatenating the summaries,
// but this might give bad summaries due to cutting off source code files in the middle.
log::warn!("Failed to summarize {} - {:?}", file.path.display(), err);
String::new()
});
// Note that the summary could be empty because of an error talking to a cloud provider,
// e.g. because the context limit was exceeded. In that case, we return Ok(String::new()).
if !summary.is_empty() {
summarized_tx
.send(SummarizedFile {
path: file.path.display().to_string(),
digest: file.digest,
summary,
mtime: file.mtime,
})
.await?
}
}
Ok(())
});
SummarizeFiles {
files: summarized_rx,
task,
}
}
fn summarize_code(
code: &str,
path: &Path,
cx: &AppContext,
) -> impl Future<Output = Result<String>> {
let start = Instant::now();
let (summary_model_id, use_cache): (LanguageModelId, bool) = (
"Qwen/Qwen2-7B-Instruct".to_string().into(), // TODO read this from the user's settings.
false, // qwen2 doesn't have a cache, but we should probably infer this from the model
);
let Some(model) = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.find(|model| model.id() == summary_model_id)
else {
return cx.background_executor().spawn(async move {
Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))
});
};
let utf8_path = path.to_string_lossy();
const PROMPT_BEFORE_CODE: &str = "Summarize what the code in this file does in 3 sentences, using no newlines or bullet points in the summary:";
let prompt = format!("{PROMPT_BEFORE_CODE}\n{utf8_path}:\n{code}");
log::debug!(
"Summarizing code by sending this prompt to {:?}: {:?}",
model.name(),
&prompt
);
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: use_cache,
}],
tools: Vec::new(),
stop: Vec::new(),
temperature: 1.0,
};
let code_len = code.len();
cx.spawn(|cx| async move {
let stream = model.stream_completion(request, &cx);
cx.background_executor()
.spawn(async move {
let answer: String = stream
.await?
.filter_map(|event| async {
if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
Some(text)
} else {
None
}
})
.collect()
.await;
log::info!(
"It took {:?} to summarize {:?} bytes of code.",
start.elapsed(),
code_len
);
log::debug!("Summary was: {:?}", &answer);
Ok(answer)
})
.await
// TODO if summarization failed, put it back in the backlog!
})
}
fn persist_summaries(
&self,
summaries: channel::Receiver<SummarizedFile>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let digest_db = self.file_digest_db;
let summary_db = self.summary_db;
cx.background_executor().spawn(async move {
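// Batch writes: collect up to 4096 summaries, or whatever arrived within
// 2 seconds, and persist each batch in a single LMDB write transaction.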
let mut summaries = summaries.chunks_timeout(4096, Duration::from_secs(2));
while let Some(summaries) = summaries.next().await {
let mut txn = db_connection.write_txn()?;
for file in &summaries {
log::debug!(
"Saving {}-byte summary of {:?} (content digest {:?})",
file.summary.len(),
&file.path,
file.digest
);
digest_db.put(
&mut txn,
&file.path,
&FileDigest {
mtime: file.mtime,
digest: file.digest,
},
)?;
summary_db.put(&mut txn, &file.digest, &file.summary)?;
}
txn.commit()?;
drop(summaries);
log::debug!("committed summaries");
}
Ok(())
})
}
/// Empty out the backlog of files that haven't been resummarized, and resummarize them immediately.
pub(crate) fn flush_backlog(
&self,
worktree_abs_path: Arc<Path>,
cx: &AppContext,
) -> impl Future<Output = Result<()>> {
let start = Instant::now();
let backlogged = {
let (tx, rx) = channel::bounded(512);
let needs_summary: Vec<(Arc<Path>, Option<SystemTime>)> = {
let mut backlog = self.backlog.lock();
backlog.drain().collect()
};
let task = cx.background_executor().spawn(async move {
tx.send(needs_summary).await?;
Ok(())
});
Backlogged {
paths_to_digest: rx,
task,
}
};
let digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
let needs_summary = self.check_summary_cache(digest.files, cx);
let summaries = self.summarize_files(needs_summary.files, cx);
let persist = self.persist_summaries(summaries.files, cx);
async move {
futures::try_join!(
backlogged.task,
digest.task,
needs_summary.task,
summaries.task,
persist
)?;
log::info!("Summarizing backlogged entries took {:?}", start.elapsed());
Ok(())
}
}
pub(crate) fn backlog_len(&self) -> usize {
self.backlog.lock().len()
}
}
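/// Replace '/' with NUL in database keys. NUL sorts before every other byte,
/// so (presumably) a directory's entries stay grouped ahead of sibling paths
/// that share the same prefix in LMDB's lexicographic key order.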
fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}

View file

@ -0,0 +1,217 @@
use crate::embedding::EmbeddingProvider;
use crate::embedding_index::EmbeddingIndex;
use crate::indexing::IndexingEntrySet;
use crate::summary_index::SummaryIndex;
use anyhow::Result;
use feature_flags::{AutoCommand, FeatureFlagAppExt};
use fs::Fs;
use futures::future::Shared;
use gpui::{
AppContext, AsyncAppContext, Context, Model, ModelContext, Subscription, Task, WeakModel,
};
use language::LanguageRegistry;
use log;
use project::{UpdatedEntriesSet, Worktree};
use smol::channel;
use std::sync::Arc;
use util::ResultExt;
#[derive(Clone)]
pub enum WorktreeIndexHandle {
Loading {
index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
},
Loaded {
index: Model<WorktreeIndex>,
},
}
pub struct WorktreeIndex {
worktree: Model<Worktree>,
db_connection: heed::Env,
embedding_index: EmbeddingIndex,
summary_index: SummaryIndex,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
_index_entries: Task<Result<()>>,
_subscription: Subscription,
}
impl WorktreeIndex {
pub fn load(
worktree: Model<Worktree>,
db_connection: heed::Env,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let worktree_for_index = worktree.clone();
let worktree_for_summary = worktree.clone();
let worktree_abs_path = worktree.read(cx).abs_path();
let embedding_fs = Arc::clone(&fs);
let summary_fs = fs;
cx.spawn(|mut cx| async move {
let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
let (embedding_index, summary_index) = cx
.background_executor()
.spawn({
let entries_being_indexed = Arc::clone(&entries_being_indexed);
let db_connection = db_connection.clone();
async move {
let mut txn = db_connection.write_txn()?;
let embedding_index = {
let db_name = worktree_abs_path.to_string_lossy();
let db = db_connection.create_database(&mut txn, Some(&db_name))?;
EmbeddingIndex::new(
worktree_for_index,
embedding_fs,
db_connection.clone(),
db,
language_registry,
embedding_provider,
Arc::clone(&entries_being_indexed),
)
};
let summary_index = {
let file_digest_db = {
let db_name =
// Prepend something that wouldn't be found at the beginning of an
// absolute path, so we don't get db key namespace conflicts with
// embeddings, which use the abs path as a key.
format!("digests-{}", worktree_abs_path.to_string_lossy());
db_connection.create_database(&mut txn, Some(&db_name))?
};
let summary_db = {
let db_name =
// Prepend something that wouldn't be found at the beginning of an
// absolute path, so we don't get db key namespace conflicts with
// embeddings, which use the abs path as a key.
format!("summaries-{}", worktree_abs_path.to_string_lossy());
db_connection.create_database(&mut txn, Some(&db_name))?
};
SummaryIndex::new(
worktree_for_summary,
summary_fs,
db_connection.clone(),
file_digest_db,
summary_db,
Arc::clone(&entries_being_indexed),
)
};
txn.commit()?;
anyhow::Ok((embedding_index, summary_index))
}
})
.await?;
cx.new_model(|cx| {
Self::new(
worktree,
db_connection,
embedding_index,
summary_index,
entries_being_indexed,
cx,
)
})
})
}
#[allow(clippy::too_many_arguments)]
pub fn new(
worktree: Model<Worktree>,
db_connection: heed::Env,
embedding_index: EmbeddingIndex,
summary_index: SummaryIndex,
entry_ids_being_indexed: Arc<IndexingEntrySet>,
cx: &mut ModelContext<Self>,
) -> Self {
let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
if let worktree::Event::UpdatedEntries(update) = event {
log::debug!("Updating entries...");
_ = updated_entries_tx.try_send(update.clone());
}
});
Self {
db_connection,
embedding_index,
summary_index,
worktree,
entry_ids_being_indexed,
_index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
_subscription,
}
}
pub fn entry_ids_being_indexed(&self) -> &IndexingEntrySet {
self.entry_ids_being_indexed.as_ref()
}
pub fn worktree(&self) -> &Model<Worktree> {
&self.worktree
}
pub fn db_connection(&self) -> &heed::Env {
&self.db_connection
}
pub fn embedding_index(&self) -> &EmbeddingIndex {
&self.embedding_index
}
pub fn summary_index(&self) -> &SummaryIndex {
&self.summary_index
}
async fn index_entries(
this: WeakModel<Self>,
updated_entries: channel::Receiver<UpdatedEntriesSet>,
mut cx: AsyncAppContext,
) -> Result<()> {
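// Block until feature-flag values have loaded before the initial scan;
// subsequent passes (below) just read the current flag value synchronously.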
let is_auto_available = cx.update(|cx| cx.wait_for_flag::<AutoCommand>())?.await;
let index = this.update(&mut cx, |this, cx| {
futures::future::try_join(
this.embedding_index.index_entries_changed_on_disk(cx),
this.summary_index
.index_entries_changed_on_disk(is_auto_available, cx),
)
})?;
index.await.log_err();
while let Ok(updated_entries) = updated_entries.recv().await {
let is_auto_available = cx
.update(|cx| cx.has_flag::<AutoCommand>())
.unwrap_or(false);
let index = this.update(&mut cx, |this, cx| {
futures::future::try_join(
this.embedding_index
.index_updated_entries(updated_entries.clone(), cx),
this.summary_index.index_updated_entries(
updated_entries,
is_auto_available,
cx,
),
)
})?;
index.await.log_err();
}
Ok(())
}
#[cfg(test)]
pub fn path_count(&self) -> Result<u64> {
use anyhow::Context;
let txn = self
.db_connection
.read_txn()
.context("failed to create read transaction")?;
Ok(self.embedding_index().db().len(&txn)?)
}
}

View file

@ -3227,6 +3227,8 @@ pub struct Entry {
pub git_status: Option<GitFileStatus>,
/// Whether this entry is considered to be a `.env` file.
pub is_private: bool,
/// The entry's size on disk, in bytes.
pub size: u64,
pub char_bag: CharBag,
pub is_fifo: bool,
}
@ -3282,6 +3284,7 @@ impl Entry {
path,
inode: metadata.inode,
mtime: Some(metadata.mtime),
size: metadata.len,
canonical_path,
is_symlink: metadata.is_symlink,
is_ignored: false,
@ -5210,6 +5213,7 @@ impl<'a> From<&'a Entry> for proto::Entry {
is_external: entry.is_external,
git_status: entry.git_status.map(git_status_to_proto),
is_fifo: entry.is_fifo,
size: Some(entry.size),
}
}
}
@ -5231,6 +5235,7 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry {
path,
inode: entry.inode,
mtime: entry.mtime.map(|time| time.into()),
size: entry.size.unwrap_or(0),
canonical_path: None,
is_ignored: entry.is_ignored,
is_external: entry.is_external,