Streaming tools (#11629)

Stream characters in for tool calls to allow rendering partial input.


https://github.com/zed-industries/zed/assets/836375/0f023a4b-9c46-4449-ae69-8b6bcab41673

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Kyle Kelley 2024-05-09 15:57:14 -07:00 committed by GitHub
parent 27ed0f4273
commit 50c45c7897
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 786 additions and 653 deletions

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ToolOutput};
use collections::BTreeMap;
use gpui::{prelude::*, Model, Task};
@ -6,9 +6,8 @@ use project::ProjectPath;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc};
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
use ui::{prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
const DEFAULT_SEARCH_LIMIT: usize = 20;
@ -16,10 +15,26 @@ pub struct ProjectIndexTool {
project_index: Model<ProjectIndex>,
}
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
// Any changes or deletions to the `CodebaseQuery` comments will change model behavior.
#[derive(Default)]
enum ProjectIndexToolState {
#[default]
CollectingQuery,
Searching,
Error(anyhow::Error),
Finished {
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
index_status: Status,
},
}
#[derive(Deserialize, JsonSchema)]
pub struct ProjectIndexView {
project_index: Model<ProjectIndex>,
input: CodebaseQuery,
expanded_header: bool,
state: ProjectIndexToolState,
}
#[derive(Default, Deserialize, JsonSchema)]
pub struct CodebaseQuery {
/// Semantic search query
query: String,
@ -27,21 +42,14 @@ pub struct CodebaseQuery {
limit: Option<usize>,
}
pub struct ProjectIndexView {
input: CodebaseQuery,
status: Status,
excerpts: Result<BTreeMap<ProjectPath, Vec<Range<usize>>>>,
element_id: ElementId,
expanded_header: bool,
}
#[derive(Serialize, Deserialize)]
pub struct ProjectIndexOutput {
status: Status,
pub struct SerializedState {
index_status: Status,
error_message: Option<String>,
worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
}
#[derive(Serialize, Deserialize)]
#[derive(Default, Serialize, Deserialize)]
struct WorktreeIndexOutput {
excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
}
@ -56,58 +64,80 @@ impl ProjectIndexView {
impl Render for ProjectIndexView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let query = self.input.query.clone();
let excerpts = match &self.excerpts {
Err(err) => {
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
let (header_text, content) = match &self.state {
ProjectIndexToolState::Error(error) => {
return format!("failed to search: {error:?}").into_any_element()
}
ProjectIndexToolState::CollectingQuery | ProjectIndexToolState::Searching => {
("Searching...".to_string(), div())
}
ProjectIndexToolState::Finished { excerpts, .. } => {
let file_count = excerpts.len();
let header_text = format!(
"Read {} {}",
file_count,
if file_count == 1 { "file" } else { "files" }
);
let el = v_flex().gap_2().children(excerpts.keys().map(|path| {
h_flex().gap_2().child(Icon::new(IconName::File)).child(
Label::new(path.path.to_string_lossy().to_string()).color(Color::Muted),
)
}));
(header_text, el)
}
Ok(excerpts) => excerpts,
};
let file_count = excerpts.len();
let header = h_flex()
.gap_2()
.child(Icon::new(IconName::File))
.child(format!(
"Read {} {}",
file_count,
if file_count == 1 { "file" } else { "files" }
));
.child(header_text);
v_flex().gap_3().child(
CollapsibleContainer::new(self.element_id.clone(), self.expanded_header)
.start_slot(header)
.on_click(cx.listener(move |this, _, cx| {
this.toggle_header(cx);
}))
.child(
v_flex()
.gap_3()
.p_3()
.child(
h_flex()
.gap_2()
.child(Icon::new(IconName::MagnifyingGlass))
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
)
.child(v_flex().gap_2().children(excerpts.keys().map(|path| {
h_flex().gap_2().child(Icon::new(IconName::File)).child(
Label::new(path.path.to_string_lossy().to_string())
.color(Color::Muted),
v_flex()
.gap_3()
.child(
CollapsibleContainer::new("collapsible-container", self.expanded_header)
.start_slot(header)
.on_click(cx.listener(move |this, _, cx| {
this.toggle_header(cx);
}))
.child(
v_flex()
.gap_3()
.p_3()
.child(
h_flex()
.gap_2()
.child(Icon::new(IconName::MagnifyingGlass))
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
)
}))),
),
)
.child(content),
),
)
.into_any_element()
}
}
impl ToolOutput for ProjectIndexView {
type Input = CodebaseQuery;
type SerializedState = SerializedState;
fn generate(
&self,
context: &mut assistant_tooling::ProjectContext,
_: &mut WindowContext,
_: &mut ViewContext<Self>,
) -> String {
match &self.excerpts {
Ok(excerpts) => {
match &self.state {
ProjectIndexToolState::CollectingQuery => String::new(),
ProjectIndexToolState::Searching => String::new(),
ProjectIndexToolState::Error(error) => format!("failed to search: {error:?}"),
ProjectIndexToolState::Finished {
excerpts,
index_status,
} => {
let mut body = "found results in the following paths:\n".to_string();
for (project_path, ranges) in excerpts {
@ -115,15 +145,126 @@ impl ToolOutput for ProjectIndexView {
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
}
if self.status != Status::Idle {
if *index_status != Status::Idle {
body.push_str("Still indexing. Results may be incomplete.\n");
}
body
}
Err(err) => format!("Error: {}", err),
}
}
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
self.input = input;
cx.notify();
}
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
self.state = ProjectIndexToolState::Searching;
cx.notify();
let project_index = self.project_index.read(cx);
let index_status = project_index.status();
let search = project_index.search(
self.input.query.clone(),
self.input.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx,
);
cx.spawn(|this, mut cx| async move {
let search_result = search.await;
this.update(&mut cx, |this, cx| {
match search_result {
Ok(search_results) => {
let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
for search_result in search_results {
let project_path = ProjectPath {
worktree_id: search_result.worktree.read(cx).id(),
path: search_result.path,
};
excerpts
.entry(project_path)
.or_default()
.push(search_result.range);
}
this.state = ProjectIndexToolState::Finished {
excerpts,
index_status,
};
}
Err(error) => {
this.state = ProjectIndexToolState::Error(error);
}
}
cx.notify();
})
})
}
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState {
let mut serialized = SerializedState {
error_message: None,
index_status: Status::Idle,
worktrees: Default::default(),
};
match &self.state {
ProjectIndexToolState::Error(err) => serialized.error_message = Some(err.to_string()),
ProjectIndexToolState::Finished {
excerpts,
index_status,
} => {
serialized.index_status = *index_status;
if let Some(project) = self.project_index.read(cx).project().upgrade() {
let project = project.read(cx);
for (project_path, excerpts) in excerpts {
if let Some(worktree) =
project.worktree_for_id(project_path.worktree_id, cx)
{
let worktree_path = worktree.read(cx).abs_path();
serialized
.worktrees
.entry(worktree_path)
.or_default()
.excerpts
.insert(project_path.path.clone(), excerpts.clone());
}
}
}
}
_ => {}
}
serialized
}
fn deserialize(
&mut self,
serialized: Self::SerializedState,
cx: &mut ViewContext<Self>,
) -> Result<()> {
if !serialized.worktrees.is_empty() {
let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
if let Some(project) = self.project_index.read(cx).project().upgrade() {
let project = project.read(cx);
for (worktree_path, worktree_state) in serialized.worktrees {
if let Some(worktree) = project
.worktrees()
.find(|worktree| worktree.read(cx).abs_path() == worktree_path)
{
let worktree_id = worktree.read(cx).id();
for (path, serialized_excerpts) in worktree_state.excerpts {
excerpts.insert(ProjectPath { worktree_id, path }, serialized_excerpts);
}
}
}
}
self.state = ProjectIndexToolState::Finished {
excerpts,
index_status: serialized.index_status,
};
}
cx.notify();
Ok(())
}
}
impl ProjectIndexTool {
@ -133,8 +274,6 @@ impl ProjectIndexTool {
}
impl LanguageModelTool for ProjectIndexTool {
type Input = CodebaseQuery;
type Output = ProjectIndexOutput;
type View = ProjectIndexView;
fn name(&self) -> String {
@ -145,109 +284,12 @@ impl LanguageModelTool for ProjectIndexTool {
"Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string()
}
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx);
let status = project_index.status();
let search = project_index.search(
query.query.clone(),
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx,
);
cx.spawn(|mut cx| async move {
let search_results = search.await?;
cx.update(|cx| {
let mut output = ProjectIndexOutput {
status,
worktrees: Default::default(),
};
for search_result in search_results {
let worktree_path = search_result.worktree.read(cx).abs_path();
let excerpts = &mut output
.worktrees
.entry(worktree_path)
.or_insert(WorktreeIndexOutput {
excerpts: Default::default(),
})
.excerpts;
let excerpts_for_path = excerpts.entry(search_result.path).or_default();
let ix = match excerpts_for_path
.binary_search_by_key(&search_result.range.start, |r| r.start)
{
Ok(ix) | Err(ix) => ix,
};
excerpts_for_path.insert(ix, search_result.range);
}
output
})
fn view(&self, cx: &mut WindowContext) -> gpui::View<Self::View> {
cx.new_view(|_| ProjectIndexView {
state: ProjectIndexToolState::CollectingQuery,
input: Default::default(),
expanded_header: false,
project_index: self.project_index.clone(),
})
}
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> gpui::View<Self::View> {
cx.new_view(|cx| {
let status;
let excerpts;
match output {
Ok(output) => {
status = output.status;
let project_index = self.project_index.read(cx);
if let Some(project) = project_index.project().upgrade() {
let project = project.read(cx);
excerpts = Ok(output
.worktrees
.into_iter()
.filter_map(|(abs_path, output)| {
for worktree in project.worktrees() {
let worktree = worktree.read(cx);
if worktree.abs_path() == abs_path {
return Some((worktree.id(), output.excerpts));
}
}
None
})
.flat_map(|(worktree_id, excerpts)| {
excerpts.into_iter().map(move |(path, ranges)| {
(ProjectPath { worktree_id, path }, ranges)
})
})
.collect::<BTreeMap<_, _>>());
} else {
excerpts = Err(anyhow!("project was dropped"));
}
}
Err(err) => {
status = Status::Idle;
excerpts = Err(err);
}
};
ProjectIndexView {
input,
status,
excerpts,
element_id: ElementId::Name(nanoid::nanoid!().into()),
expanded_header: false,
}
})
}
fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {
let text: String = arguments
.as_ref()
.and_then(|arguments| arguments.get("query"))
.and_then(|query| query.as_str())
.map(|query| format!("Searching for: {}", query))
.unwrap_or_else(|| "Preparing search...".to_string());
CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false).start_slot(text)
}
}