Semantic index progress (#11071)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Authored by Max Brunsfeld on 2024-04-26 17:06:05 -07:00, committed by GitHub
parent 1aa9c868d4
commit b7d9aeb29d
10 changed files with 298 additions and 407 deletions

Cargo.lock (generated)
View file

@@ -8689,6 +8689,7 @@ dependencies = [
  "languages",
  "log",
  "open_ai",
+ "parking_lot",
  "project",
  "serde",
  "serde_json",

View file

@@ -87,16 +87,14 @@ fn main() {
        let project_index = semantic_index.project_index(project.clone(), cx);
+       cx.open_window(WindowOptions::default(), |cx| {
            let mut tool_registry = ToolRegistry::new();
            tool_registry
-               .register(ProjectIndexTool::new(project_index.clone(), fs.clone()))
+               .register(ProjectIndexTool::new(project_index.clone(), fs.clone()), cx)
                .context("failed to register ProjectIndexTool")
                .log_err();
-           let tool_registry = Arc::new(tool_registry);
-           cx.open_window(WindowOptions::default(), |cx| {
-               cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
+           cx.new_view(|cx| Example::new(language_registry, Arc::new(tool_registry), cx))
            });
            cx.activate(true);
        })

View file

@@ -135,7 +135,7 @@ impl LanguageModelTool for RollDiceTool {
        return Task::ready(Ok(DiceRoll { rolls }));
    }
-   fn new_view(
+   fn output_view(
        _tool_call_id: String,
        _input: Self::Input,
        result: Result<Self::Output>,
@@ -194,9 +194,10 @@ fn main() {
    cx.spawn(|cx| async move {
        cx.update(|cx| {
+           cx.open_window(WindowOptions::default(), |cx| {
                let mut tool_registry = ToolRegistry::new();
                tool_registry
-                   .register(RollDiceTool::new())
+                   .register(RollDiceTool::new(), cx)
                    .context("failed to register DummyTool")
                    .log_err();
@@ -207,7 +208,6 @@ fn main() {
                    println!("{}", definition);
                }
-           cx.open_window(WindowOptions::default(), |cx| {
                cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
            });
            cx.activate(true);

View file

@@ -115,7 +115,7 @@ impl LanguageModelTool for FileBrowserTool {
        })
    }
-   fn new_view(
+   fn output_view(
        _tool_call_id: String,
        _input: Self::Input,
        result: Result<Self::Output>,
@@ -174,9 +174,10 @@ fn main() {
        let fs = Arc::new(fs::RealFs::new(None));
        let cwd = std::env::current_dir().expect("Failed to get current working directory");
+       cx.open_window(WindowOptions::default(), |cx| {
            let mut tool_registry = ToolRegistry::new();
            tool_registry
-               .register(FileBrowserTool::new(fs, cwd))
+               .register(FileBrowserTool::new(fs, cwd), cx)
                .context("failed to register FileBrowserTool")
                .log_err();
@@ -187,7 +188,6 @@ fn main() {
                println!("{}", definition);
            }
-           cx.open_window(WindowOptions::default(), |cx| {
                cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
            });
            cx.activate(true);

View file

@@ -8,22 +8,21 @@ use client::{proto, Client};
 use completion_provider::*;
 use editor::Editor;
 use feature_flags::FeatureFlagAppExt as _;
-use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt};
+use futures::{future::join_all, StreamExt};
 use gpui::{
     list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
-    FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView,
+    FocusableView, Global, ListAlignment, ListState, Render, Task, View, WeakView,
 };
 use language::{language_settings::SoftWrap, LanguageRegistry};
 use open_ai::{FunctionContent, ToolCall, ToolCallContent};
-use project::Fs;
 use rich_text::RichText;
-use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 use serde::Deserialize;
 use settings::Settings;
-use std::{cmp, sync::Arc};
+use std::sync::Arc;
 use theme::ThemeSettings;
 use tools::ProjectIndexTool;
-use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip};
+use ui::{popover_menu, prelude::*, ButtonLike, Color, ContextMenu, Tooltip};
 use util::{paths::EMBEDDINGS_DIR, ResultExt};
 use workspace::{
     dock::{DockPosition, Panel, PanelEvent},
@@ -110,10 +109,10 @@ impl AssistantPanel {
            let mut tool_registry = ToolRegistry::new();
            tool_registry
-               .register(ProjectIndexTool::new(
-                   project_index.clone(),
-                   app_state.fs.clone(),
-               ))
+               .register(
+                   ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
+                   cx,
+               )
                .context("failed to register ProjectIndexTool")
                .log_err();
@@ -447,11 +446,7 @@ impl AssistantChat {
            }
            editor
        });
-       let message = ChatMessage::User(UserMessage {
-           id,
-           body,
-           contexts: Vec::new(),
-       });
+       let message = ChatMessage::User(UserMessage { id, body });
        self.push_message(message, cx);
    }
@@ -525,11 +520,7 @@ impl AssistantChat {
        let is_last = ix == self.messages.len() - 1;
        match &self.messages[ix] {
-           ChatMessage::User(UserMessage {
-               body,
-               contexts: _contexts,
-               ..
-           }) => div()
+           ChatMessage::User(UserMessage { body, .. }) => div()
                .when(!is_last, |element| element.mb_2())
                .child(div().p_2().child(Label::new("You").color(Color::Default)))
                .child(
@@ -539,7 +530,7 @@ impl AssistantChat {
                    .text_color(cx.theme().colors().editor_foreground)
                    .font(ThemeSettings::get_global(cx).buffer_font.clone())
                    .bg(cx.theme().colors().editor_background)
-                   .child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))),
+                   .child(body.clone()),
                )
                .into_any(),
            ChatMessage::Assistant(AssistantMessage {
@@ -588,11 +579,11 @@ impl AssistantChat {
        for message in &self.messages {
            match message {
-               ChatMessage::User(UserMessage { body, contexts, .. }) => {
-                   // setup context for model
-                   contexts.iter().for_each(|context| {
-                       completion_messages.extend(context.completion_messages(cx))
-                   });
+               ChatMessage::User(UserMessage { body, .. }) => {
+                   // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
+                   // contexts.iter().for_each(|context| {
+                   //     completion_messages.extend(context.completion_messages(cx))
+                   // });

                    // Show user's message last so that the assistant is grounded in the user's request
                    completion_messages.push(CompletionMessage::User {
@@ -712,6 +703,12 @@ impl Render for AssistantChat {
            .text_color(Color::Default.color(cx))
            .child(self.render_model_dropdown(cx))
            .child(list(self.list_state.clone()).flex_1())
+           .child(
+               h_flex()
+                   .mt_2()
+                   .gap_2()
+                   .children(self.tool_registry.status_views().iter().cloned()),
+           )
    }
 }
@@ -743,7 +740,6 @@ impl ChatMessage {
 struct UserMessage {
     id: MessageId,
     body: View<Editor>,
-    contexts: Vec<AssistantContext>,
 }

 struct AssistantMessage {
@@ -752,211 +748,3 @@ struct AssistantMessage {
     tool_calls: Vec<ToolFunctionCall>,
     error: Option<SharedString>,
 }
-
-// Since we're swapping out for direct query usage, we might not need to use this injected context
-// It will be useful though for when the user _definitely_ wants the model to see a specific file,
-// query, error, etc.
-#[allow(dead_code)]
-enum AssistantContext {
-    Codebase(View<CodebaseContext>),
-}
-
-#[allow(dead_code)]
-struct CodebaseExcerpt {
-    element_id: ElementId,
-    path: SharedString,
-    text: SharedString,
-    score: f32,
-    expanded: bool,
-}
-
-impl AssistantContext {
-    #[allow(dead_code)]
-    fn render(&self, _cx: &mut ViewContext<AssistantChat>) -> AnyElement {
-        match self {
-            AssistantContext::Codebase(context) => context.clone().into_any_element(),
-        }
-    }
-
-    fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
-        match self {
-            AssistantContext::Codebase(context) => context.read(cx).completion_messages(),
-        }
-    }
-}
-
-enum CodebaseContext {
-    Pending { _task: Task<()> },
-    Done(Result<Vec<CodebaseExcerpt>>),
-}
-
-impl CodebaseContext {
-    fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
-        if let CodebaseContext::Done(Ok(excerpts)) = self {
-            if let Some(excerpt) = excerpts
-                .iter_mut()
-                .find(|excerpt| excerpt.element_id == element_id)
-            {
-                excerpt.expanded = !excerpt.expanded;
-                cx.notify();
-            }
-        }
-    }
-}
-
-impl Render for CodebaseContext {
-    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        match self {
-            CodebaseContext::Pending { .. } => div()
-                .h_flex()
-                .items_center()
-                .gap_1()
-                .child(Icon::new(IconName::Ai).color(Color::Muted).into_element())
-                .child("Searching codebase..."),
-            CodebaseContext::Done(Ok(excerpts)) => {
-                div()
-                    .v_flex()
-                    .gap_2()
-                    .children(excerpts.iter().map(|excerpt| {
-                        let expanded = excerpt.expanded;
-                        let element_id = excerpt.element_id.clone();
-
-                        CollapsibleContainer::new(element_id.clone(), expanded)
-                            .start_slot(
-                                h_flex()
-                                    .gap_1()
-                                    .child(Icon::new(IconName::File).color(Color::Muted))
-                                    .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
-                            )
-                            .on_click(cx.listener(move |this, _, cx| {
-                                this.toggle_expanded(element_id.clone(), cx);
-                            }))
-                            .child(
-                                div()
-                                    .p_2()
-                                    .rounded_md()
-                                    .bg(cx.theme().colors().editor_background)
-                                    .child(
-                                        excerpt.text.clone(), // todo!(): Show as an editor block
-                                    ),
-                            )
-                    }))
-            }
-            CodebaseContext::Done(Err(error)) => div().child(error.to_string()),
-        }
-    }
-}
-
-impl CodebaseContext {
-    #[allow(dead_code)]
-    fn new(
-        query: impl 'static + Future<Output = Result<String>>,
-        populated: oneshot::Sender<bool>,
-        project_index: Model<ProjectIndex>,
-        fs: Arc<dyn Fs>,
-        cx: &mut ViewContext<Self>,
-    ) -> Self {
-        let query = query.boxed_local();
-        let _task = cx.spawn(|this, mut cx| async move {
-            let result = async {
-                let query = query.await?;
-                let results = this
-                    .update(&mut cx, |_this, cx| {
-                        project_index.read(cx).search(&query, 16, cx)
-                    })?
-                    .await;
-
-                let excerpts = results.into_iter().map(|result| {
-                    let abs_path = result
-                        .worktree
-                        .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
-                    let fs = fs.clone();
-
-                    async move {
-                        let path = result.path.clone();
-                        let text = fs.load(&abs_path?).await?;
-                        // todo!("what should we do with stale ranges?");
-                        let range = cmp::min(result.range.start, text.len())
-                            ..cmp::min(result.range.end, text.len());
-
-                        let text = SharedString::from(text[range].to_string());
-
-                        anyhow::Ok(CodebaseExcerpt {
-                            element_id: ElementId::Name(nanoid::nanoid!().into()),
-                            path: path.to_string_lossy().to_string().into(),
-                            text,
-                            score: result.score,
-                            expanded: false,
-                        })
-                    }
-                });
-
-                anyhow::Ok(
-                    futures::future::join_all(excerpts)
-                        .await
-                        .into_iter()
-                        .filter_map(|result| result.log_err())
-                        .collect(),
-                )
-            }
-            .await;
-
-            this.update(&mut cx, |this, cx| {
-                this.populate(result, populated, cx);
-            })
-            .ok();
-        });
-
-        Self::Pending { _task }
-    }
-
-    #[allow(dead_code)]
-    fn populate(
-        &mut self,
-        result: Result<Vec<CodebaseExcerpt>>,
-        populated: oneshot::Sender<bool>,
-        cx: &mut ViewContext<Self>,
-    ) {
-        let success = result.is_ok();
-        *self = Self::Done(result);
-        populated.send(success).ok();
-        cx.notify();
-    }
-
-    fn completion_messages(&self) -> Vec<CompletionMessage> {
-        // One system message for the whole batch of excerpts:
-        // Semantic search results for user query:
-        //
-        // Excerpt from $path:
-        // ~~~
-        // `text`
-        // ~~~
-        //
-        // Excerpt from $path:
-        match self {
-            CodebaseContext::Done(Ok(excerpts)) => {
-                if excerpts.is_empty() {
-                    return Vec::new();
-                }
-
-                let mut body = "Semantic search results for user query:\n".to_string();
-
-                for excerpt in excerpts {
-                    body.push_str("Excerpt from ");
-                    body.push_str(excerpt.path.as_ref());
-                    body.push_str(", score ");
-                    body.push_str(&excerpt.score.to_string());
-                    body.push_str(":\n");
-                    body.push_str("~~~\n");
-                    body.push_str(excerpt.text.as_ref());
-                    body.push_str("~~~\n");
-                }
-
-                vec![CompletionMessage::System { content: body }]
-            }
-            _ => vec![],
-        }
-    }
-}

View file

@@ -1,9 +1,9 @@
 use anyhow::Result;
 use assistant_tooling::LanguageModelTool;
-use gpui::{prelude::*, AppContext, Model, Task};
+use gpui::{prelude::*, AnyView, AppContext, Model, Task};
 use project::Fs;
 use schemars::JsonSchema;
-use semantic_index::ProjectIndex;
+use semantic_index::{ProjectIndex, Status};
 use serde::Deserialize;
 use std::sync::Arc;
 use ui::{
@@ -36,13 +36,14 @@ pub struct CodebaseQuery {
 pub struct ProjectIndexView {
     input: CodebaseQuery,
-    output: Result<Vec<CodebaseExcerpt>>,
+    output: Result<ProjectIndexOutput>,
 }

 impl ProjectIndexView {
     fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
-        if let Ok(excerpts) = &mut self.output {
-            if let Some(excerpt) = excerpts
+        if let Ok(output) = &mut self.output {
+            if let Some(excerpt) = output
+                .excerpts
                 .iter_mut()
                 .find(|excerpt| excerpt.element_id == element_id)
             {
@@ -59,11 +60,11 @@ impl Render for ProjectIndexView {
         let result = &self.output;

-        let excerpts = match result {
+        let output = match result {
             Err(err) => {
                 return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
             }
-            Ok(excerpts) => excerpts,
+            Ok(output) => output,
         };

         div()
@@ -80,7 +81,7 @@ impl Render for ProjectIndexView {
                     .child(Label::new(query).color(Color::Muted)),
                 ),
             )
-            .children(excerpts.iter().map(|excerpt| {
+            .children(output.excerpts.iter().map(|excerpt| {
                 let element_id = excerpt.element_id.clone();
                 let expanded = excerpt.expanded;
@@ -99,9 +100,7 @@ impl Render for ProjectIndexView {
                             .p_2()
                             .rounded_md()
                             .bg(cx.theme().colors().editor_background)
-                            .child(
-                                excerpt.text.clone(), // todo!(): Show as an editor block
-                            ),
+                            .child(excerpt.text.clone()),
                     )
             }))
     }
@@ -112,8 +111,15 @@ pub struct ProjectIndexTool {
     fs: Arc<dyn Fs>,
 }

+pub struct ProjectIndexOutput {
+    excerpts: Vec<CodebaseExcerpt>,
+    status: Status,
+}
+
 impl ProjectIndexTool {
     pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
-        // Listen for project index status and update the ProjectIndexTool directly
         // TODO: setup a better description based on the user's current codebase.
         Self { project_index, fs }
     }
@@ -121,7 +127,7 @@ impl ProjectIndexTool {
 impl LanguageModelTool for ProjectIndexTool {
     type Input = CodebaseQuery;
-    type Output = Vec<CodebaseExcerpt>;
+    type Output = ProjectIndexOutput;
     type View = ProjectIndexView;

     fn name(&self) -> String {
@@ -135,6 +141,7 @@ impl LanguageModelTool for ProjectIndexTool {
     fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
         let project_index = self.project_index.read(cx);
+        let status = project_index.status();
         let results = project_index.search(
             query.query.as_str(),
             query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
@@ -180,11 +187,11 @@ impl LanguageModelTool for ProjectIndexTool {
                 .into_iter()
                 .filter_map(|result| result.log_err())
                 .collect();
-            anyhow::Ok(excerpts)
+            anyhow::Ok(ProjectIndexOutput { excerpts, status })
         })
     }

-    fn new_view(
+    fn output_view(
         _tool_call_id: String,
         input: Self::Input,
         output: Result<Self::Output>,
@@ -193,16 +200,28 @@ impl LanguageModelTool for ProjectIndexTool {
         cx.new_view(|_cx| ProjectIndexView { input, output })
     }

-    fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
-        match &output {
-            Ok(excerpts) => {
-                if excerpts.len() == 0 {
-                    return "No results found".to_string();
-                }
+    fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
+        Some(
+            cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
+                .into(),
+        )
+    }

+    fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
+        match &output {
+            Ok(output) => {
                 let mut body = "Semantic search results:\n".to_string();

-                for excerpt in excerpts {
+                if output.status != Status::Idle {
+                    body.push_str("Still indexing. Results may be incomplete.\n");
+                }
+
+                if output.excerpts.is_empty() {
+                    body.push_str("No results found");
+                    return body;
+                }
+
+                for excerpt in &output.excerpts {
                     body.push_str("Excerpt from ");
                     body.push_str(excerpt.path.as_ref());
                     body.push_str(", score ");
@@ -218,3 +237,31 @@ impl LanguageModelTool for ProjectIndexTool {
         }
     }
 }
+
+struct ProjectIndexStatusView {
+    project_index: Model<ProjectIndex>,
+}
+
+impl ProjectIndexStatusView {
+    pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
+        cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
+            cx.notify();
+        })
+        .detach();
+        Self { project_index }
+    }
+}
+
+impl Render for ProjectIndexStatusView {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let status = self.project_index.read(cx).status();
+
+        h_flex().gap_2().map(|element| match status {
+            Status::Idle => element.child(Label::new("Project index ready")),
+            Status::Loading => element.child(Label::new("Project index loading...")),
+            Status::Scanning { remaining_count } => element.child(Label::new(format!(
+                "Project index scanning: {remaining_count} remaining..."
+            ))),
+        })
+    }
+}

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use gpui::{Task, WindowContext}; use gpui::{AnyView, Task, WindowContext};
use std::collections::HashMap; use std::collections::HashMap;
use crate::tool::{ use crate::tool::{
@ -12,6 +12,7 @@ pub struct ToolRegistry {
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>, Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
>, >,
definitions: Vec<ToolFunctionDefinition>, definitions: Vec<ToolFunctionDefinition>,
status_views: Vec<AnyView>,
} }
impl ToolRegistry { impl ToolRegistry {
@ -19,6 +20,7 @@ impl ToolRegistry {
Self { Self {
tools: HashMap::new(), tools: HashMap::new(),
definitions: Vec::new(), definitions: Vec::new(),
status_views: Vec::new(),
} }
} }
@ -26,8 +28,17 @@ impl ToolRegistry {
&self.definitions &self.definitions
} }
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> { pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
cx: &mut WindowContext,
) -> Result<()> {
self.definitions.push(tool.definition()); self.definitions.push(tool.definition());
if let Some(tool_view) = tool.status_view(cx) {
self.status_views.push(tool_view);
}
let name = tool.name(); let name = tool.name();
let previous = self.tools.insert( let previous = self.tools.insert(
name.clone(), name.clone(),
@ -52,7 +63,7 @@ impl ToolRegistry {
cx.spawn(move |mut cx| async move { cx.spawn(move |mut cx| async move {
let result: Result<T::Output> = result.await; let result: Result<T::Output> = result.await;
let for_model = T::format(&input, &result); let for_model = T::format(&input, &result);
let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?; let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
Ok(ToolFunctionCall { Ok(ToolFunctionCall {
id, id,
@ -100,6 +111,10 @@ impl ToolRegistry {
tool(tool_call, cx) tool(tool_call, cx)
} }
pub fn status_views(&self) -> &[AnyView] {
&self.status_views
}
} }
#[cfg(test)] #[cfg(test)]
@ -165,7 +180,7 @@ mod test {
Task::ready(Ok(weather)) Task::ready(Ok(weather))
} }
fn new_view( fn output_view(
_tool_call_id: String, _tool_call_id: String,
_input: Self::Input, _input: Self::Input,
result: Result<Self::Output>, result: Result<Self::Output>,
@ -182,46 +197,6 @@ mod test {
} }
} }
#[gpui::test]
async fn test_function_registry(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let mut registry = ToolRegistry::new();
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
registry.register(tool).unwrap();
// let _result = cx
// .update(|cx| {
// registry.call(
// &ToolFunctionCall {
// name: "get_current_weather".to_string(),
// arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
// .to_string(),
// id: "test-123".to_string(),
// result: None,
// },
// cx,
// )
// })
// .await;
// assert!(result.is_ok());
// let result = result.unwrap();
// let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
// todo!(): Put this back in after the interface is stabilized
// assert_eq!(result, expected);
}
#[gpui::test] #[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) { async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked(); cx.background_executor.run_until_parked();
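
For reference, a minimal sketch of how the reworked registration flow fits together (assumes the `assistant_tooling` items above — `ToolRegistry`, `LanguageModelTool` — are in scope; the generic tool parameter stands in for any concrete tool such as `ProjectIndexTool`):

```rust
use anyhow::Result;
use gpui::WindowContext;

// Registration now takes a WindowContext so the registry can collect the
// tool's optional status view alongside its function definition.
fn build_registry<T: 'static + LanguageModelTool>(
    tool: T,
    cx: &mut WindowContext,
) -> Result<ToolRegistry> {
    let mut registry = ToolRegistry::new();
    registry.register(tool, cx)?;
    // The host UI can later render the collected views, e.g.:
    // h_flex().children(registry.status_views().iter().cloned())
    Ok(registry)
}
```

This mirrors what the assistant panel does before constructing its chat view, with the collected `status_views()` rendered in a footer row (see the `Render for AssistantChat` hunk above).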

View file

@@ -95,10 +95,14 @@ pub trait LanguageModelTool {
     fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;

-    fn new_view(
+    fn output_view(
         tool_call_id: String,
         input: Self::Input,
         output: Result<Self::Output>,
         cx: &mut WindowContext,
     ) -> View<Self::View>;
+
+    fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
+        None
+    }
 }

View file

@@ -30,6 +30,7 @@ language.workspace = true
 log.workspace = true
 heed.workspace = true
 open_ai.workspace = true
+parking_lot.workspace = true
 project.workspace = true
 settings.workspace = true
 serde.workspace = true

View file

@@ -3,7 +3,7 @@ mod embedding;
 use anyhow::{anyhow, Context as _, Result};
 use chunking::{chunk_text, Chunk};
-use collections::{Bound, HashMap};
+use collections::{Bound, HashMap, HashSet};
 pub use embedding::*;
 use fs::Fs;
 use futures::stream::StreamExt;
@@ -14,15 +14,17 @@ use gpui::{
 };
 use heed::types::{SerdeBincode, Str};
 use language::LanguageRegistry;
-use project::{Entry, Project, UpdatedEntriesSet, Worktree};
+use parking_lot::Mutex;
+use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
 use serde::{Deserialize, Serialize};
 use smol::channel;
 use std::{
     cmp::Ordering,
     future::Future,
+    num::NonZeroUsize,
     ops::Range,
     path::{Path, PathBuf},
-    sync::Arc,
+    sync::{Arc, Weak},
     time::{Duration, SystemTime},
 };
 use util::ResultExt;
@@ -102,19 +104,16 @@ pub struct ProjectIndex {
     worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
     language_registry: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
-    pub last_status: Status,
+    last_status: Status,
+    status_tx: channel::Sender<()>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
+    _maintain_status: Task<()>,
     _subscription: Subscription,
 }

 enum WorktreeIndexHandle {
-    Loading {
-        _task: Task<Result<()>>,
-    },
-    Loaded {
-        index: Model<WorktreeIndex>,
-        _subscription: Subscription,
-    },
+    Loading { _task: Task<Result<()>> },
+    Loaded { index: Model<WorktreeIndex> },
 }

 impl ProjectIndex {
@@ -126,20 +125,36 @@ impl ProjectIndex {
     ) -> Self {
         let language_registry = project.read(cx).languages().clone();
         let fs = project.read(cx).fs().clone();
+        let (status_tx, mut status_rx) = channel::unbounded();
         let mut this = ProjectIndex {
             db_connection,
             project: project.downgrade(),
             worktree_indices: HashMap::default(),
             language_registry,
             fs,
+            status_tx,
             last_status: Status::Idle,
             embedding_provider,
             _subscription: cx.subscribe(&project, Self::handle_project_event),
+            _maintain_status: cx.spawn(|this, mut cx| async move {
+                while status_rx.next().await.is_some() {
+                    if this
+                        .update(&mut cx, |this, cx| this.update_status(cx))
+                        .is_err()
+                    {
+                        break;
+                    }
+                }
+            }),
         };
         this.update_worktree_indices(cx);
         this
     }

+    pub fn status(&self) -> Status {
+        self.last_status
+    }
+
     fn handle_project_event(
         &mut self,
         _: Model<Project>,
@@ -180,19 +195,18 @@ impl ProjectIndex {
             self.db_connection.clone(),
             self.language_registry.clone(),
             self.fs.clone(),
+            self.status_tx.clone(),
             self.embedding_provider.clone(),
             cx,
         );
         let load_worktree = cx.spawn(|this, mut cx| async move {
-            if let Some(index) = worktree_index.await.log_err() {
-                this.update(&mut cx, |this, cx| {
+            if let Some(worktree_index) = worktree_index.await.log_err() {
+                this.update(&mut cx, |this, _| {
                     this.worktree_indices.insert(
                         worktree_id,
                         WorktreeIndexHandle::Loaded {
-                            _subscription: cx
-                                .observe(&index, |this, _, cx| this.update_status(cx)),
-                            index,
+                            index: worktree_index,
                         },
                     );
                 })?;
@@ -215,22 +229,29 @@ impl ProjectIndex {
     }

     fn update_status(&mut self, cx: &mut ModelContext<Self>) {
-        let mut status = Status::Idle;
-        for index in self.worktree_indices.values() {
+        let mut indexing_count = 0;
+        let mut any_loading = false;
+
+        for index in self.worktree_indices.values_mut() {
             match index {
                 WorktreeIndexHandle::Loading { .. } => {
-                    status = Status::Scanning;
+                    any_loading = true;
                     break;
                 }
                 WorktreeIndexHandle::Loaded { index, .. } => {
-                    if index.read(cx).status == Status::Scanning {
-                        status = Status::Scanning;
-                        break;
-                    }
+                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                 }
             }
         }

+        let status = if any_loading {
+            Status::Loading
+        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
+            Status::Scanning { remaining_count }
+        } else {
+            Status::Idle
+        };
+
         if status != self.last_status {
             self.last_status = status;
             cx.emit(status);
@@ -263,6 +284,17 @@ impl ProjectIndex {
             results
         })
     }
+
+    #[cfg(test)]
+    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
+        let mut result = 0;
+        for worktree_index in self.worktree_indices.values() {
+            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+                result += index.read(cx).path_count()?;
+            }
+        }
+        Ok(result)
+    }
 }

 pub struct SearchResult {
@@ -275,7 +307,8 @@ pub struct SearchResult {
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Status {
     Idle,
-    Scanning,
+    Loading,
+    Scanning { remaining_count: NonZeroUsize },
 }

 impl EventEmitter<Status> for ProjectIndex {}
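
Taken together with the `update_status` hunk above, the new status reporting reduces to a small derivation; the sketch below is for illustration only (the free function and test are not part of this commit, and assume this crate's `Status` is in scope):

```rust
use std::num::NonZeroUsize;

// Any worktree still loading wins; otherwise a non-zero number of entries
// currently being indexed means scanning; otherwise the index is idle.
fn derive_status(any_loading: bool, indexing_count: usize) -> Status {
    if any_loading {
        Status::Loading
    } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
        Status::Scanning { remaining_count }
    } else {
        Status::Idle
    }
}

#[test]
fn derive_status_examples() {
    assert_eq!(derive_status(false, 0), Status::Idle);
    assert_eq!(derive_status(true, 3), Status::Loading);
    assert_eq!(
        derive_status(false, 3),
        Status::Scanning {
            remaining_count: NonZeroUsize::new(3).unwrap()
        }
    );
}
```

Because `ProjectIndex` implements `EventEmitter<Status>`, observers such as `ProjectIndexStatusView` (earlier in this commit) only need to subscribe and re-read `status()` when an event arrives.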
@@ -287,7 +320,7 @@ struct WorktreeIndex {
     language_registry: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
-    status: Status,
+    entry_ids_being_indexed: Arc<IndexingEntrySet>,
     _index_entries: Task<Result<()>>,
     _subscription: Subscription,
 }
@@ -298,6 +331,7 @@ impl WorktreeIndex {
         db_connection: heed::Env,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
+        status_tx: channel::Sender<()>,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         cx: &mut AppContext,
     ) -> Task<Result<Model<Self>>> {
@@ -321,6 +355,7 @@ impl WorktreeIndex {
                 worktree,
                 db_connection,
                 db,
+                status_tx,
                 language_registry,
                 fs,
                 embedding_provider,
@@ -330,10 +365,12 @@ impl WorktreeIndex {
         })
     }

+    #[allow(clippy::too_many_arguments)]
     fn new(
         worktree: Model<Worktree>,
         db_connection: heed::Env,
         db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+        status: channel::Sender<()>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         embedding_provider: Arc<dyn EmbeddingProvider>,
@@ -353,7 +390,7 @@ impl WorktreeIndex {
             language_registry,
             fs,
             embedding_provider,
-            status: Status::Idle,
+            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
             _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
             _subscription,
         }
@@ -364,28 +401,14 @@ impl WorktreeIndex {
         updated_entries: channel::Receiver<UpdatedEntriesSet>,
         mut cx: AsyncAppContext,
     ) -> Result<()> {
-        let index = this.update(&mut cx, |this, cx| {
-            cx.notify();
-            this.status = Status::Scanning;
-            this.index_entries_changed_on_disk(cx)
-        })?;
+        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
         index.await.log_err();
-        this.update(&mut cx, |this, cx| {
-            this.status = Status::Idle;
-            cx.notify();
-        })?;

         while let Ok(updated_entries) = updated_entries.recv().await {
             let index = this.update(&mut cx, |this, cx| {
-                cx.notify();
-                this.status = Status::Scanning;
                 this.index_updated_entries(updated_entries, cx)
             })?;
             index.await.log_err();
-            this.update(&mut cx, |this, cx| {
-                this.status = Status::Idle;
-                cx.notify();
-            })?;
         }

         Ok(())
@@ -426,6 +449,7 @@ impl WorktreeIndex {
         let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
         let db_connection = self.db_connection.clone();
         let db = self.db;
+        let entries_being_indexed = self.entry_ids_being_indexed.clone();
         let task = cx.background_executor().spawn(async move {
             let txn = db_connection
                 .read_txn()
@@ -476,7 +500,8 @@
                     }

                     if entry.mtime != saved_mtime {
-                        updated_entries_tx.send(entry.clone()).await?;
+                        let handle = entries_being_indexed.insert(&entry);
+                        updated_entries_tx.send((entry.clone(), handle)).await?;
                     }
                 }
@@ -505,6 +530,7 @@
     ) -> ScanEntries {
         let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
         let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+        let entries_being_indexed = self.entry_ids_being_indexed.clone();
         let task = cx.background_executor().spawn(async move {
             for (path, entry_id, status) in updated_entries.iter() {
                 match status {
@@ -513,7 +539,8 @@
                     | project::PathChange::AddedOrUpdated => {
                         if let Some(entry) = worktree.entry_for_id(*entry_id) {
                             if entry.is_file() {
-                                updated_entries_tx.send(entry.clone()).await?;
+                                let handle = entries_being_indexed.insert(&entry);
+                                updated_entries_tx.send((entry.clone(), handle)).await?;
                             }
                         }
                     }
@@ -542,7 +569,7 @@
     fn chunk_files(
         &self,
         worktree_abs_path: Arc<Path>,
-        entries: channel::Receiver<Entry>,
+        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> ChunkFiles {
         let language_registry = self.language_registry.clone();
@@ -553,7 +580,7 @@
             .scoped(|cx| {
                 for _ in 0..cx.num_cpus() {
                     cx.spawn(async {
-                        while let Ok(entry) = entries.recv().await {
+                        while let Ok((entry, handle)) = entries.recv().await {
                             let entry_abs_path = worktree_abs_path.join(&entry.path);
                             let Some(text) = fs
                                 .load(&entry_abs_path)
@@ -572,8 +599,8 @@
                             let grammar =
                                 language.as_ref().and_then(|language| language.grammar());
                             let chunked_file = ChunkedFile {
-                                worktree_root: worktree_abs_path.clone(),
                                 chunks: chunk_text(&text, grammar),
+                                handle,
                                 entry,
                                 text,
                             };
@@ -622,7 +649,11 @@
             let mut embeddings = Vec::new();
             for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
-                embeddings.extend(embedding_provider.embed(embedding_batch).await?);
+                if let Some(batch_embeddings) =
+                    embedding_provider.embed(embedding_batch).await.log_err()
+                {
+                    embeddings.extend_from_slice(&batch_embeddings);
+                }
             }

             let mut embeddings = embeddings.into_iter();
@@ -643,7 +674,9 @@
                     chunks: embedded_chunks,
                 };
-                embedded_files_tx.send(embedded_file).await?;
+                embedded_files_tx
+                    .send((embedded_file, chunked_file.handle))
+                    .await?;
             }
         }
         Ok(())
@@ -658,7 +691,7 @@
     fn persist_embeddings(
         &self,
         mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
-        embedded_files: channel::Receiver<EmbeddedFile>,
+        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> Task<Result<()>> {
         let db_connection = self.db_connection.clone();
@@ -676,12 +709,15 @@
             let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
             while let Some(embedded_files) = embedded_files.next().await {
                 let mut txn = db_connection.write_txn()?;
-                for file in embedded_files {
+                for (file, _) in &embedded_files {
                     log::debug!("saving embedding for file {:?}", file.path);
                     let key = db_key_for_path(&file.path);
-                    db.put(&mut txn, &key, &file)?;
+                    db.put(&mut txn, &key, file)?;
                 }
                 txn.commit()?;
+                eprintln!("committed {:?}", embedded_files.len());
+                drop(embedded_files);

                 log::debug!("committed");
             }
@@ -789,10 +825,19 @@
             Ok(search_results)
         })
     }
+
+    #[cfg(test)]
+    fn path_count(&self) -> Result<u64> {
+        let txn = self
+            .db_connection
+            .read_txn()
+            .context("failed to create read transaction")?;
+        Ok(self.db.len(&txn)?)
+    }
 }

 struct ScanEntries {
-    updated_entries: channel::Receiver<Entry>,
+    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
     deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
     task: Task<Result<()>>,
 }
@@ -803,15 +848,14 @@ struct ChunkFiles {
 }

 struct ChunkedFile {
-    #[allow(dead_code)]
-    pub worktree_root: Arc<Path>,
     pub entry: Entry,
+    pub handle: IndexingEntryHandle,
     pub text: String,
     pub chunks: Vec<Chunk>,
 }

 struct EmbedFiles {
-    files: channel::Receiver<EmbeddedFile>,
+    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
     task: Task<Result<()>>,
 }
@@ -828,6 +872,47 @@ struct EmbeddedChunk {
     embedding: Embedding,
 }

+struct IndexingEntrySet {
+    entry_ids: Mutex<HashSet<ProjectEntryId>>,
+    tx: channel::Sender<()>,
+}
+
+struct IndexingEntryHandle {
+    entry_id: ProjectEntryId,
+    set: Weak<IndexingEntrySet>,
+}
+
+impl IndexingEntrySet {
+    fn new(tx: channel::Sender<()>) -> Self {
+        Self {
+            entry_ids: Default::default(),
+            tx,
+        }
+    }
+
+    fn insert(self: &Arc<Self>, entry: &project::Entry) -> IndexingEntryHandle {
+        self.entry_ids.lock().insert(entry.id);
+        self.tx.send_blocking(()).ok();
+        IndexingEntryHandle {
+            entry_id: entry.id,
+            set: Arc::downgrade(self),
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.entry_ids.lock().len()
+    }
+}
+
+impl Drop for IndexingEntryHandle {
+    fn drop(&mut self) {
+        if let Some(set) = self.set.upgrade() {
+            set.tx.send_blocking(()).ok();
+            set.entry_ids.lock().remove(&self.entry_id);
+        }
+    }
+}
+
 fn db_key_for_path(path: &Arc<Path>) -> String {
     path.to_string_lossy().replace('/', "\0")
 }
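
The `IndexingEntrySet`/`IndexingEntryHandle` pair added above is an RAII progress counter: inserting an entry registers its id and pings the status channel, and dropping the handle anywhere downstream (after chunking, embedding, and persistence, or if a pipeline channel is dropped on error) unregisters it and pings again so `ProjectIndex::update_status` can recompute the remaining count. Here is a standalone sketch of the same pattern with simplified types (`usize` ids, `std::sync::Mutex`, and `std::sync::mpsc` standing in for `parking_lot` and `smol::channel`); it is an illustration of the idea, not the crate's code:

```rust
use std::collections::HashSet;
use std::sync::{mpsc, Arc, Mutex, Weak};

struct InFlightSet {
    ids: Mutex<HashSet<usize>>,
    notify: mpsc::Sender<()>,
}

struct InFlightHandle {
    id: usize,
    set: Weak<InFlightSet>,
}

impl InFlightSet {
    fn insert(self: &Arc<Self>, id: usize) -> InFlightHandle {
        self.ids.lock().unwrap().insert(id);
        self.notify.send(()).ok(); // wake whoever recomputes the status
        InFlightHandle {
            id,
            set: Arc::downgrade(self),
        }
    }

    fn len(&self) -> usize {
        self.ids.lock().unwrap().len()
    }
}

impl Drop for InFlightHandle {
    fn drop(&mut self) {
        // Dropping the handle marks the item as done, wherever that happens.
        if let Some(set) = self.set.upgrade() {
            set.ids.lock().unwrap().remove(&self.id);
            set.notify.send(()).ok();
        }
    }
}

fn main() {
    let (notify, wake) = mpsc::channel();
    let set = Arc::new(InFlightSet {
        ids: Mutex::new(HashSet::new()),
        notify,
    });

    let handle = set.insert(42);
    assert_eq!(set.len(), 1);
    drop(handle); // e.g. once the entry's embeddings have been persisted
    assert_eq!(set.len(), 0);
    assert_eq!(wake.try_iter().count(), 2); // one ping on insert, one on drop
}
```

The `Weak` back-reference (as in the real `IndexingEntryHandle`) keeps a late-dropped handle from keeping the whole set alive after the index itself is gone.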
@@ -835,10 +920,7 @@ fn db_key_for_path(path: &Arc<Path>) -> String {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::channel::oneshot;
     use futures::{future::BoxFuture, FutureExt};
     use gpui::{Global, TestAppContext};
     use language::language_settings::AllLanguageSettings;
     use project::Project;
@@ -922,18 +1004,13 @@
         let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

-        let (tx, rx) = oneshot::channel();
-        let mut tx = Some(tx);
-        let subscription = cx.update(|cx| {
-            cx.subscribe(&project_index, move |_, event, _| {
-                if let Some(tx) = tx.take() {
-                    _ = tx.send(*event);
-                }
-            })
-        });
-
-        rx.await.expect("no event emitted");
-        drop(subscription);
+        while project_index
+            .read_with(cx, |index, cx| index.path_count(cx))
+            .unwrap()
+            == 0
+        {
+            project_index.next_event(cx).await;
+        }

         let results = cx
             .update(|cx| {