Centralize project context provided to the assistant (#11471)

This PR restructures the way that tools and attachments add information
about the current project to a conversation with the assistant. Rather
than each tool call or attachment generating a new tool or system
message containing information about the project, they can all
collectively mutate a new type called a `ProjectContext`, which stores
all of the project data that should be sent to the assistant. That data
is then formatted in a single place, and passed to the assistant in one
system message.

This prevents multiple tools/attachments from including redundant
context.

Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-05-06 17:01:50 -07:00 committed by GitHub
parent f2a415135b
commit a64e20ed96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 841 additions and 518 deletions

7
Cargo.lock generated
View file

@ -411,10 +411,17 @@ name = "assistant_tooling"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"collections",
"futures 0.3.28",
"gpui", "gpui",
"project",
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"settings",
"sum_tree",
"unindent",
"util",
] ]
[[package]] [[package]]

View file

@ -4,10 +4,16 @@ mod completion_provider;
mod tools; mod tools;
pub mod ui; pub mod ui;
use crate::{
attachments::ActiveEditorAttachmentTool,
tools::{CreateBufferTool, ProjectIndexTool},
ui::UserOrAssistant,
};
use ::ui::{div, prelude::*, Color, ViewContext}; use ::ui::{div, prelude::*, Color, ViewContext};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use assistant_tooling::{ToolFunctionCall, ToolRegistry}; use assistant_tooling::{
use attachments::{ActiveEditorAttachmentTool, UserAttachment, UserAttachmentStore}; AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
};
use client::{proto, Client, UserStore}; use client::{proto, Client, UserStore};
use collections::HashMap; use collections::HashMap;
use completion_provider::*; use completion_provider::*;
@ -34,9 +40,6 @@ use workspace::{
pub use assistant_settings::AssistantSettings; pub use assistant_settings::AssistantSettings;
use crate::tools::{CreateBufferTool, ProjectIndexTool};
use crate::ui::UserOrAssistant;
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5; const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)] #[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
@ -85,11 +88,10 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
}); });
workspace.register_action(|workspace, _: &DebugProjectIndex, cx| { workspace.register_action(|workspace, _: &DebugProjectIndex, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) { if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
if let Some(index) = panel.read(cx).chat.read(cx).project_index.clone() { let index = panel.read(cx).chat.read(cx).project_index.clone();
let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx)); let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
workspace.add_item_to_center(Box::new(view), cx); workspace.add_item_to_center(Box::new(view), cx);
} }
}
}); });
}, },
) )
@ -122,10 +124,7 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new(); let mut tool_registry = ToolRegistry::new();
tool_registry tool_registry
.register( .register(ProjectIndexTool::new(project_index.clone()), cx)
ProjectIndexTool::new(project_index.clone(), project.read(cx).fs().clone()),
cx,
)
.context("failed to register ProjectIndexTool") .context("failed to register ProjectIndexTool")
.log_err(); .log_err();
tool_registry tool_registry
@ -136,7 +135,7 @@ impl AssistantPanel {
.context("failed to register CreateBufferTool") .context("failed to register CreateBufferTool")
.log_err(); .log_err();
let mut attachment_store = UserAttachmentStore::new(); let mut attachment_store = AttachmentRegistry::new();
attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx)); attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx));
Self::new( Self::new(
@ -144,7 +143,7 @@ impl AssistantPanel {
Arc::new(tool_registry), Arc::new(tool_registry),
Arc::new(attachment_store), Arc::new(attachment_store),
app_state.user_store.clone(), app_state.user_store.clone(),
Some(project_index), project_index,
workspace, workspace,
cx, cx,
) )
@ -155,9 +154,9 @@ impl AssistantPanel {
pub fn new( pub fn new(
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>, tool_registry: Arc<ToolRegistry>,
attachment_store: Arc<UserAttachmentStore>, attachment_store: Arc<AttachmentRegistry>,
user_store: Model<UserStore>, user_store: Model<UserStore>,
project_index: Option<Model<ProjectIndex>>, project_index: Model<ProjectIndex>,
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
@ -241,16 +240,16 @@ pub struct AssistantChat {
list_state: ListState, list_state: ListState,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
composer_editor: View<Editor>, composer_editor: View<Editor>,
project_index_button: Option<View<ProjectIndexButton>>, project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>, active_file_button: Option<View<ActiveFileButton>>,
user_store: Model<UserStore>, user_store: Model<UserStore>,
next_message_id: MessageId, next_message_id: MessageId,
collapsed_messages: HashMap<MessageId, bool>, collapsed_messages: HashMap<MessageId, bool>,
editing_message: Option<EditingMessage>, editing_message: Option<EditingMessage>,
pending_completion: Option<Task<()>>, pending_completion: Option<Task<()>>,
attachment_store: Arc<UserAttachmentStore>,
tool_registry: Arc<ToolRegistry>, tool_registry: Arc<ToolRegistry>,
project_index: Option<Model<ProjectIndex>>, attachment_registry: Arc<AttachmentRegistry>,
project_index: Model<ProjectIndex>,
} }
struct EditingMessage { struct EditingMessage {
@ -263,9 +262,9 @@ impl AssistantChat {
fn new( fn new(
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>, tool_registry: Arc<ToolRegistry>,
attachment_store: Arc<UserAttachmentStore>, attachment_registry: Arc<AttachmentRegistry>,
user_store: Model<UserStore>, user_store: Model<UserStore>,
project_index: Option<Model<ProjectIndex>>, project_index: Model<ProjectIndex>,
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
@ -281,14 +280,14 @@ impl AssistantChat {
}, },
); );
let project_index_button = project_index.clone().map(|project_index| { let project_index_button = cx.new_view(|cx| {
cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx)) ProjectIndexButton::new(project_index.clone(), tool_registry.clone(), cx)
}); });
let active_file_button = match workspace.upgrade() { let active_file_button = match workspace.upgrade() {
Some(workspace) => { Some(workspace) => {
Some(cx.new_view( Some(cx.new_view(
|cx| ActiveFileButton::new(attachment_store.clone(), workspace, cx), // |cx| ActiveFileButton::new(attachment_registry.clone(), workspace, cx), //
)) ))
} }
_ => None, _ => None,
@ -313,7 +312,7 @@ impl AssistantChat {
editing_message: None, editing_message: None,
collapsed_messages: HashMap::default(), collapsed_messages: HashMap::default(),
pending_completion: None, pending_completion: None,
attachment_store, attachment_registry,
tool_registry, tool_registry,
} }
} }
@ -395,7 +394,7 @@ impl AssistantChat {
let mode = *mode; let mode = *mode;
self.pending_completion = Some(cx.spawn(move |this, mut cx| async move { self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
let attachments_task = this.update(&mut cx, |this, cx| { let attachments_task = this.update(&mut cx, |this, cx| {
let attachment_store = this.attachment_store.clone(); let attachment_store = this.attachment_registry.clone();
attachment_store.call_all_attachment_tools(cx) attachment_store.call_all_attachment_tools(cx)
}); });
@ -443,7 +442,7 @@ impl AssistantChat {
let mut call_count = 0; let mut call_count = 0;
loop { loop {
let complete = async { let complete = async {
let completion = this.update(cx, |this, cx| { let (tool_definitions, model_name, messages) = this.update(cx, |this, cx| {
this.push_new_assistant_message(cx); this.push_new_assistant_message(cx);
let definitions = if call_count < limit let definitions = if call_count < limit
@ -455,14 +454,22 @@ impl AssistantChat {
}; };
call_count += 1; call_count += 1;
let messages = this.completion_messages(cx); (
definitions,
CompletionProvider::get(cx).complete(
this.model.clone(), this.model.clone(),
this.completion_messages(cx),
)
})?;
let messages = messages.await?;
let completion = cx.update(|cx| {
CompletionProvider::get(cx).complete(
model_name,
messages, messages,
Vec::new(), Vec::new(),
1.0, 1.0,
definitions, tool_definitions,
) )
}); });
@ -765,7 +772,12 @@ impl AssistantChat {
} }
} }
fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> { fn completion_messages(&self, cx: &mut WindowContext) -> Task<Result<Vec<CompletionMessage>>> {
let project_index = self.project_index.read(cx);
let project = project_index.project();
let fs = project_index.fs();
let mut project_context = ProjectContext::new(project, fs);
let mut completion_messages = Vec::new(); let mut completion_messages = Vec::new();
for message in &self.messages { for message in &self.messages {
@ -773,12 +785,11 @@ impl AssistantChat {
ChatMessage::User(UserMessage { ChatMessage::User(UserMessage {
body, attachments, .. body, attachments, ..
}) => { }) => {
completion_messages.extend( for attachment in attachments {
attachments if let Some(content) = attachment.generate(&mut project_context, cx) {
.into_iter() completion_messages.push(CompletionMessage::System { content });
.filter_map(|attachment| attachment.message.clone()) }
.map(|content| CompletionMessage::System { content }), }
);
// Show user's message last so that the assistant is grounded in the user's request // Show user's message last so that the assistant is grounded in the user's request
completion_messages.push(CompletionMessage::User { completion_messages.push(CompletionMessage::User {
@ -815,7 +826,9 @@ impl AssistantChat {
for tool_call in tool_calls { for tool_call in tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error. // Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result { let content = match &tool_call.result {
Some(result) => result.format(&tool_call.name), Some(result) => {
result.generate(&tool_call.name, &mut project_context, cx)
}
None => "".to_string(), None => "".to_string(),
}; };
@ -828,7 +841,13 @@ impl AssistantChat {
} }
} }
completion_messages let system_message = project_context.generate_system_message(cx);
cx.background_executor().spawn(async move {
let content = system_message.await?;
completion_messages.insert(0, CompletionMessage::System { content });
Ok(completion_messages)
})
} }
} }

View file

@ -1,137 +1,18 @@
use std::{ pub mod active_file;
any::TypeId,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use collections::HashMap; use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
use editor::Editor; use editor::Editor;
use futures::future::join_all; use gpui::{Render, Task, View, WeakModel, WeakView};
use gpui::{AnyView, Render, Task, View, WeakView}; use language::Buffer;
use project::ProjectPath;
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext}; use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
use util::{maybe, ResultExt}; use util::maybe;
use workspace::Workspace; use workspace::Workspace;
/// A collected attachment from running an attachment tool
pub struct UserAttachment {
pub message: Option<String>,
pub view: AnyView,
}
pub struct UserAttachmentStore {
attachment_tools: HashMap<TypeId, DynamicAttachment>,
}
/// Internal representation of an attachment tool to allow us to treat them dynamically
struct DynamicAttachment {
enabled: AtomicBool,
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
}
impl UserAttachmentStore {
pub fn new() -> Self {
Self {
attachment_tools: HashMap::default(),
}
}
pub fn register<A: AttachmentTool + 'static>(&mut self, attachment: A) {
let call = Box::new(move |cx: &mut WindowContext| {
let result = attachment.run(cx);
cx.spawn(move |mut cx| async move {
let result: Result<A::Output> = result.await;
let message = A::format(&result);
let view = cx.update(|cx| A::view(result, cx))?;
Ok(UserAttachment {
message,
view: view.into(),
})
})
});
self.attachment_tools.insert(
TypeId::of::<A>(),
DynamicAttachment {
call,
enabled: AtomicBool::new(true),
},
);
}
pub fn set_attachment_tool_enabled<A: AttachmentTool + 'static>(&self, is_enabled: bool) {
if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
attachment.enabled.store(is_enabled, SeqCst);
}
}
pub fn is_attachment_tool_enabled<A: AttachmentTool + 'static>(&self) -> bool {
if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
attachment.enabled.load(SeqCst)
} else {
false
}
}
pub fn call<A: AttachmentTool + 'static>(
&self,
cx: &mut WindowContext,
) -> Task<Result<UserAttachment>> {
let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) else {
return Task::ready(Err(anyhow!("no attachment tool")));
};
(attachment.call)(cx)
}
pub fn call_all_attachment_tools(
self: Arc<Self>,
cx: &mut WindowContext<'_>,
) -> Task<Result<Vec<UserAttachment>>> {
let this = self.clone();
cx.spawn(|mut cx| async move {
let attachment_tasks = cx.update(|cx| {
let mut tasks = Vec::new();
for attachment in this
.attachment_tools
.values()
.filter(|attachment| attachment.enabled.load(SeqCst))
{
tasks.push((attachment.call)(cx))
}
tasks
})?;
let attachments = join_all(attachment_tasks.into_iter()).await;
Ok(attachments
.into_iter()
.filter_map(|attachment| attachment.log_err())
.collect())
})
}
}
pub trait AttachmentTool {
type Output: 'static;
type View: Render;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn format(output: &Result<Self::Output>) -> Option<String>;
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
pub struct ActiveEditorAttachment { pub struct ActiveEditorAttachment {
filename: Arc<str>, buffer: WeakModel<Buffer>,
language: Arc<str>, path: Option<ProjectPath>,
text: Arc<str>,
} }
pub struct FileAttachmentView { pub struct FileAttachmentView {
@ -142,7 +23,13 @@ impl Render for FileAttachmentView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement { fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
match &self.output { match &self.output {
Ok(attachment) => { Ok(attachment) => {
let filename = attachment.filename.clone(); let filename: SharedString = attachment
.path
.as_ref()
.and_then(|p| p.path.file_name()?.to_str())
.unwrap_or("Untitled")
.to_string()
.into();
// todo!(): make the button link to the actual file to open // todo!(): make the button link to the actual file to open
ButtonLike::new("file-attachment") ButtonLike::new("file-attachment")
@ -152,7 +39,7 @@ impl Render for FileAttachmentView {
.bg(cx.theme().colors().editor_background) .bg(cx.theme().colors().editor_background)
.rounded_md() .rounded_md()
.child(ui::Icon::new(IconName::File)) .child(ui::Icon::new(IconName::File))
.child(filename.to_string()), .child(filename.clone()),
) )
.tooltip({ .tooltip({
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx) move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
@ -164,6 +51,20 @@ impl Render for FileAttachmentView {
} }
} }
impl ToolOutput for FileAttachmentView {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
if let Ok(result) = &self.output {
if let Some(path) = &result.path {
project.add_file(path.clone());
return format!("current file: {}", path.path.display());
} else if let Some(buffer) = result.buffer.upgrade() {
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
}
}
String::new()
}
}
pub struct ActiveEditorAttachmentTool { pub struct ActiveEditorAttachmentTool {
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
} }
@ -174,7 +75,7 @@ impl ActiveEditorAttachmentTool {
} }
} }
impl AttachmentTool for ActiveEditorAttachmentTool { impl LanguageModelAttachment for ActiveEditorAttachmentTool {
type Output = ActiveEditorAttachment; type Output = ActiveEditorAttachment;
type View = FileAttachmentView; type View = FileAttachmentView;
@ -191,45 +92,20 @@ impl AttachmentTool for ActiveEditorAttachmentTool {
let buffer = active_buffer.read(cx); let buffer = active_buffer.read(cx);
if let Some(singleton) = buffer.as_singleton() { if let Some(buffer) = buffer.as_singleton() {
let singleton = singleton.read(cx); let path =
project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
let filename = singleton worktree_id: file.worktree_id(cx),
.file() path: file.path.clone(),
.map(|file| file.path().to_string_lossy())
.unwrap_or("Untitled".into());
let text = singleton.text();
let language = singleton
.language()
.map(|l| {
let name = l.code_fence_block_name();
name.to_string()
})
.unwrap_or_default();
return Ok(ActiveEditorAttachment {
filename: filename.into(),
language: language.into(),
text: text.into(),
}); });
} return Ok(ActiveEditorAttachment {
buffer: buffer.downgrade(),
path,
});
} else {
Err(anyhow!("no active buffer")) Err(anyhow!("no active buffer"))
}))
} }
}))
fn format(output: &Result<Self::Output>) -> Option<String> {
let output = output.as_ref().ok()?;
let filename = &output.filename;
let language = &output.language;
let text = &output.text;
Some(format!(
"User's active file `{filename}`:\n\n```{language}\n{text}```\n\n"
))
} }
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> { fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {

View file

@ -0,0 +1 @@

View file

@ -1,5 +1,5 @@
use anyhow::Result; use anyhow::Result;
use assistant_tooling::LanguageModelTool; use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
use editor::Editor; use editor::Editor;
use gpui::{prelude::*, Model, Task, View, WeakView}; use gpui::{prelude::*, Model, Task, View, WeakView};
use project::Project; use project::Project;
@ -31,11 +31,9 @@ pub struct CreateBufferInput {
language: String, language: String,
} }
pub struct CreateBufferOutput {}
impl LanguageModelTool for CreateBufferTool { impl LanguageModelTool for CreateBufferTool {
type Input = CreateBufferInput; type Input = CreateBufferInput;
type Output = CreateBufferOutput; type Output = ();
type View = CreateBufferView; type View = CreateBufferView;
fn name(&self) -> String { fn name(&self) -> String {
@ -83,32 +81,39 @@ impl LanguageModelTool for CreateBufferTool {
}) })
.log_err(); .log_err();
Ok(CreateBufferOutput {}) Ok(())
} }
}) })
} }
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String {
match output {
Ok(_) => format!("Created a new {} buffer", input.language),
Err(err) => format!("Failed to create buffer: {err:?}"),
}
}
fn output_view( fn output_view(
_tool_call_id: String, input: Self::Input,
_input: Self::Input, output: Result<Self::Output>,
_output: Result<Self::Output>,
cx: &mut WindowContext, cx: &mut WindowContext,
) -> View<Self::View> { ) -> View<Self::View> {
cx.new_view(|_cx| CreateBufferView {}) cx.new_view(|_cx| CreateBufferView {
language: input.language,
output,
})
} }
} }
pub struct CreateBufferView {} pub struct CreateBufferView {
language: String,
output: Result<()>,
}
impl Render for CreateBufferView { impl Render for CreateBufferView {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement { fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
div().child("Opening a buffer") div().child("Opening a buffer")
} }
} }
impl ToolOutput for CreateBufferView {
fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
match &self.output {
Ok(_) => format!("Created a new {} buffer", self.language),
Err(err) => format!("Failed to create buffer: {err:?}"),
}
}
}

View file

@ -1,25 +1,18 @@
use anyhow::Result; use anyhow::Result;
use assistant_tooling::LanguageModelTool; use assistant_tooling::{LanguageModelTool, ToolOutput};
use collections::BTreeMap;
use gpui::{prelude::*, Model, Task}; use gpui::{prelude::*, Model, Task};
use project::Fs; use project::ProjectPath;
use schemars::JsonSchema; use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status}; use semantic_index::{ProjectIndex, Status};
use serde::Deserialize; use serde::Deserialize;
use std::{collections::HashSet, sync::Arc}; use std::{fmt::Write as _, ops::Range};
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
use ui::{
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
WindowContext,
};
use util::ResultExt as _;
const DEFAULT_SEARCH_LIMIT: usize = 20; const DEFAULT_SEARCH_LIMIT: usize = 20;
#[derive(Clone)] pub struct ProjectIndexTool {
pub struct CodebaseExcerpt { project_index: Model<ProjectIndex>,
path: SharedString,
text: SharedString,
score: f32,
} }
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model. // Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
@ -40,6 +33,11 @@ pub struct ProjectIndexView {
expanded_header: bool, expanded_header: bool,
} }
pub struct ProjectIndexOutput {
status: Status,
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
}
impl ProjectIndexView { impl ProjectIndexView {
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self { fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
let element_id = ElementId::Name(nanoid::nanoid!().into()); let element_id = ElementId::Name(nanoid::nanoid!().into());
@ -71,19 +69,15 @@ impl Render for ProjectIndexView {
Ok(output) => output, Ok(output) => output,
}; };
let num_files_searched = output.files_searched.len(); let file_count = output.excerpts.len();
let header = h_flex() let header = h_flex()
.gap_2() .gap_2()
.child(Icon::new(IconName::File)) .child(Icon::new(IconName::File))
.child(format!( .child(format!(
"Read {} {}", "Read {} {}",
num_files_searched, file_count,
if num_files_searched == 1 { if file_count == 1 { "file" } else { "files" }
"file"
} else {
"files"
}
)); ));
v_flex().gap_3().child( v_flex().gap_3().child(
@ -102,36 +96,50 @@ impl Render for ProjectIndexView {
.child(Icon::new(IconName::MagnifyingGlass)) .child(Icon::new(IconName::MagnifyingGlass))
.child(Label::new(format!("`{}`", query)).color(Color::Muted)), .child(Label::new(format!("`{}`", query)).color(Color::Muted)),
) )
.child(v_flex().gap_2().children(output.files_searched.iter().map( .child(
|path| { v_flex()
h_flex()
.gap_2() .gap_2()
.child(Icon::new(IconName::File)) .children(output.excerpts.keys().map(|path| {
.child(Label::new(path.clone()).color(Color::Muted)) h_flex().gap_2().child(Icon::new(IconName::File)).child(
}, Label::new(path.path.to_string_lossy().to_string())
))), .color(Color::Muted),
)
})),
),
), ),
) )
} }
} }
pub struct ProjectIndexTool { impl ToolOutput for ProjectIndexView {
project_index: Model<ProjectIndex>, fn generate(
fs: Arc<dyn Fs>, &self,
} context: &mut assistant_tooling::ProjectContext,
_: &mut WindowContext,
) -> String {
match &self.output {
Ok(output) => {
let mut body = "found results in the following paths:\n".to_string();
pub struct ProjectIndexOutput { for (project_path, ranges) in &output.excerpts {
excerpts: Vec<CodebaseExcerpt>, context.add_excerpts(project_path.clone(), ranges);
status: Status, writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
files_searched: HashSet<SharedString>, }
if output.status != Status::Idle {
body.push_str("Still indexing. Results may be incomplete.\n");
}
body
}
Err(err) => format!("Error: {}", err),
}
}
} }
impl ProjectIndexTool { impl ProjectIndexTool {
pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self { pub fn new(project_index: Model<ProjectIndex>) -> Self {
// Listen for project index status and update the ProjectIndexTool directly Self { project_index }
// TODO: setup a better description based on the user's current codebase.
Self { project_index, fs }
} }
} }
@ -151,64 +159,42 @@ impl LanguageModelTool for ProjectIndexTool {
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> { fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx); let project_index = self.project_index.read(cx);
let status = project_index.status(); let status = project_index.status();
let results = project_index.search( let search = project_index.search(
query.query.clone(), query.query.clone(),
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT), query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
cx, cx,
); );
let fs = self.fs.clone(); cx.spawn(|mut cx| async move {
let search_results = search.await?;
cx.spawn(|cx| async move { cx.update(|cx| {
let results = results.await?; let mut output = ProjectIndexOutput {
let excerpts = results.into_iter().map(|result| {
let abs_path = result
.worktree
.read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
let fs = fs.clone();
async move {
let path = result.path.clone();
let text = fs.load(&abs_path?).await?;
let mut start = result.range.start;
let mut end = result.range.end.min(text.len());
while !text.is_char_boundary(start) {
start += 1;
}
while !text.is_char_boundary(end) {
end -= 1;
}
anyhow::Ok(CodebaseExcerpt {
path: path.to_string_lossy().to_string().into(),
text: SharedString::from(text[start..end].to_string()),
score: result.score,
})
}
});
let mut files_searched = HashSet::new();
let excerpts = futures::future::join_all(excerpts)
.await
.into_iter()
.filter_map(|result| result.log_err())
.inspect(|excerpt| {
files_searched.insert(excerpt.path.clone());
})
.collect::<Vec<_>>();
anyhow::Ok(ProjectIndexOutput {
excerpts,
status, status,
files_searched, excerpts: Default::default(),
};
for search_result in search_results {
let path = ProjectPath {
worktree_id: search_result.worktree.read(cx).id(),
path: search_result.path.clone(),
};
let excerpts_for_path = output.excerpts.entry(path).or_default();
let ix = match excerpts_for_path
.binary_search_by_key(&search_result.range.start, |r| r.start)
{
Ok(ix) | Err(ix) => ix,
};
excerpts_for_path.insert(ix, search_result.range);
}
output
}) })
}) })
} }
fn output_view( fn output_view(
_tool_call_id: String,
input: Self::Input, input: Self::Input,
output: Result<Self::Output>, output: Result<Self::Output>,
cx: &mut WindowContext, cx: &mut WindowContext,
@ -220,34 +206,4 @@ impl LanguageModelTool for ProjectIndexTool {
CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false) CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
.start_slot("Searching code base") .start_slot("Searching code base")
} }
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
match &output {
Ok(output) => {
let mut body = "Semantic search results:\n".to_string();
if output.status != Status::Idle {
body.push_str("Still indexing. Results may be incomplete.\n");
}
if output.excerpts.is_empty() {
body.push_str("No results found");
return body;
}
for excerpt in &output.excerpts {
body.push_str("Excerpt from ");
body.push_str(excerpt.path.as_ref());
body.push_str(", score ");
body.push_str(&excerpt.score.to_string());
body.push_str(":\n");
body.push_str("~~~\n");
body.push_str(excerpt.text.as_ref());
body.push_str("~~~\n");
}
body
}
Err(err) => format!("Error: {}", err),
}
}
} }

View file

@ -1,4 +1,5 @@
use crate::attachments::{ActiveEditorAttachmentTool, UserAttachmentStore}; use crate::attachments::ActiveEditorAttachmentTool;
use assistant_tooling::AttachmentRegistry;
use editor::Editor; use editor::Editor;
use gpui::{prelude::*, Subscription, View}; use gpui::{prelude::*, Subscription, View};
use std::sync::Arc; use std::sync::Arc;
@ -13,7 +14,7 @@ enum Status {
} }
pub struct ActiveFileButton { pub struct ActiveFileButton {
attachment_store: Arc<UserAttachmentStore>, attachment_registry: Arc<AttachmentRegistry>,
status: Status, status: Status,
#[allow(dead_code)] #[allow(dead_code)]
workspace_subscription: Subscription, workspace_subscription: Subscription,
@ -21,7 +22,7 @@ pub struct ActiveFileButton {
impl ActiveFileButton { impl ActiveFileButton {
pub fn new( pub fn new(
attachment_store: Arc<UserAttachmentStore>, attachment_store: Arc<AttachmentRegistry>,
workspace: View<Workspace>, workspace: View<Workspace>,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
@ -30,14 +31,14 @@ impl ActiveFileButton {
cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx)); cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx));
Self { Self {
attachment_store, attachment_registry: attachment_store,
status: Status::NoFile, status: Status::NoFile,
workspace_subscription, workspace_subscription,
} }
} }
pub fn set_enabled(&mut self, enabled: bool) { pub fn set_enabled(&mut self, enabled: bool) {
self.attachment_store self.attachment_registry
.set_attachment_tool_enabled::<ActiveEditorAttachmentTool>(enabled); .set_attachment_tool_enabled::<ActiveEditorAttachmentTool>(enabled);
} }
@ -79,7 +80,7 @@ impl ActiveFileButton {
impl Render for ActiveFileButton { impl Render for ActiveFileButton {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement { fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let is_enabled = self let is_enabled = self
.attachment_store .attachment_registry
.is_attachment_tool_enabled::<ActiveEditorAttachmentTool>(); .is_attachment_tool_enabled::<ActiveEditorAttachmentTool>();
let icon = if is_enabled { let icon = if is_enabled {

View file

@ -11,7 +11,7 @@ use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, Divider, TextSize, T
#[derive(IntoElement)] #[derive(IntoElement)]
pub struct Composer { pub struct Composer {
editor: View<Editor>, editor: View<Editor>,
project_index_button: Option<View<ProjectIndexButton>>, project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>, active_file_button: Option<View<ActiveFileButton>>,
model_selector: AnyElement, model_selector: AnyElement,
} }
@ -19,7 +19,7 @@ pub struct Composer {
impl Composer { impl Composer {
pub fn new( pub fn new(
editor: View<Editor>, editor: View<Editor>,
project_index_button: Option<View<ProjectIndexButton>>, project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>, active_file_button: Option<View<ActiveFileButton>>,
model_selector: AnyElement, model_selector: AnyElement,
) -> Self { ) -> Self {
@ -32,11 +32,7 @@ impl Composer {
} }
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement { fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
h_flex().children( h_flex().child(self.project_index_button.clone())
self.project_index_button
.clone()
.map(|view| view.into_any_element()),
)
} }
fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement { fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {

View file

@ -13,10 +13,18 @@ path = "src/assistant_tooling.rs"
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true gpui.workspace = true
project.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
sum_tree.workspace = true
util.workspace = true
[dev-dependencies] [dev-dependencies]
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View file

@ -1,5 +1,9 @@
pub mod registry; mod attachment_registry;
pub mod tool; mod project_context;
mod tool_registry;
pub use crate::registry::ToolRegistry; pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition}; pub use project_context::ProjectContext;
pub use tool_registry::{
LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry,
};

View file

@ -0,0 +1,148 @@
use crate::{ProjectContext, ToolOutput};
use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
use gpui::{AnyView, Render, Task, View, WindowContext};
use std::{
any::TypeId,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use util::ResultExt as _;
pub struct AttachmentRegistry {
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
}
pub trait LanguageModelAttachment {
type Output: 'static;
type View: Render + ToolOutput;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
/// A collected attachment from running an attachment tool
pub struct UserAttachment {
pub view: AnyView,
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
}
/// Internal representation of an attachment tool to allow us to treat them dynamically
struct RegisteredAttachment {
enabled: AtomicBool,
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
}
impl AttachmentRegistry {
pub fn new() -> Self {
Self {
registered_attachments: HashMap::default(),
}
}
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
let call = Box::new(move |cx: &mut WindowContext| {
let result = attachment.run(cx);
cx.spawn(move |mut cx| async move {
let result: Result<A::Output> = result.await;
let view = cx.update(|cx| A::view(result, cx))?;
Ok(UserAttachment {
view: view.into(),
generate_fn: generate::<A>,
})
})
});
self.registered_attachments.insert(
TypeId::of::<A>(),
RegisteredAttachment {
call,
enabled: AtomicBool::new(true),
},
);
return;
fn generate<T: LanguageModelAttachment>(
view: AnyView,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
view.downcast::<T::View>()
.unwrap()
.update(cx, |view, cx| T::View::generate(view, project, cx))
}
}
pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
&self,
is_enabled: bool,
) {
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
attachment.enabled.store(is_enabled, SeqCst);
}
}
pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
attachment.enabled.load(SeqCst)
} else {
false
}
}
pub fn call<A: LanguageModelAttachment + 'static>(
&self,
cx: &mut WindowContext,
) -> Task<Result<UserAttachment>> {
let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
return Task::ready(Err(anyhow!("no attachment tool")));
};
(attachment.call)(cx)
}
pub fn call_all_attachment_tools(
self: Arc<Self>,
cx: &mut WindowContext<'_>,
) -> Task<Result<Vec<UserAttachment>>> {
let this = self.clone();
cx.spawn(|mut cx| async move {
let attachment_tasks = cx.update(|cx| {
let mut tasks = Vec::new();
for attachment in this
.registered_attachments
.values()
.filter(|attachment| attachment.enabled.load(SeqCst))
{
tasks.push((attachment.call)(cx))
}
tasks
})?;
let attachments = join_all(attachment_tasks.into_iter()).await;
Ok(attachments
.into_iter()
.filter_map(|attachment| attachment.log_err())
.collect())
})
}
}
impl UserAttachment {
pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
let result = (self.generate_fn)(self.view.clone(), output, cx);
if result.is_empty() {
None
} else {
Some(result)
}
}
}

View file

@ -0,0 +1,296 @@
use anyhow::{anyhow, Result};
use gpui::{AppContext, Model, Task, WeakModel};
use project::{Fs, Project, ProjectPath, Worktree};
use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
use sum_tree::TreeMap;
pub struct ProjectContext {
files: TreeMap<ProjectPath, PathState>,
project: WeakModel<Project>,
fs: Arc<dyn Fs>,
}
#[derive(Debug, Clone)]
enum PathState {
PathOnly,
EntireFile,
Excerpts { ranges: Vec<Range<usize>> },
}
impl ProjectContext {
pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
Self {
files: TreeMap::default(),
fs,
project,
}
}
pub fn add_path(&mut self, project_path: ProjectPath) {
if self.files.get(&project_path).is_none() {
self.files.insert(project_path, PathState::PathOnly);
}
}
pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
let previous_state = self
.files
.get(&project_path)
.unwrap_or(&PathState::PathOnly);
let mut ranges = match previous_state {
PathState::EntireFile => return,
PathState::PathOnly => Vec::new(),
PathState::Excerpts { ranges } => ranges.to_vec(),
};
for new_range in new_ranges {
let ix = ranges.binary_search_by(|probe| {
if probe.end < new_range.start {
Ordering::Less
} else if probe.start > new_range.end {
Ordering::Greater
} else {
Ordering::Equal
}
});
match ix {
Ok(mut ix) => {
let existing = &mut ranges[ix];
existing.start = existing.start.min(new_range.start);
existing.end = existing.end.max(new_range.end);
while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
ranges.remove(ix + 1);
}
while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
ranges.remove(ix - 1);
ix -= 1;
}
}
Err(ix) => {
ranges.insert(ix, new_range.clone());
}
}
}
self.files
.insert(project_path, PathState::Excerpts { ranges });
}
pub fn add_file(&mut self, project_path: ProjectPath) {
self.files.insert(project_path, PathState::EntireFile);
}
pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
let project = self
.project
.upgrade()
.ok_or_else(|| anyhow!("project dropped"));
let files = self.files.clone();
let fs = self.fs.clone();
cx.spawn(|cx| async move {
let project = project?;
let mut result = "project structure:\n".to_string();
let mut last_worktree: Option<Model<Worktree>> = None;
for (project_path, path_state) in files.iter() {
if let Some(worktree) = &last_worktree {
if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
last_worktree = None;
}
}
let worktree;
if let Some(last_worktree) = &last_worktree {
worktree = last_worktree.clone();
} else if let Some(tree) = project.read_with(&cx, |project, cx| {
project.worktree_for_id(project_path.worktree_id, cx)
})? {
worktree = tree;
last_worktree = Some(worktree.clone());
let worktree_name =
worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
writeln!(&mut result, "# {}", worktree_name).unwrap();
} else {
continue;
}
let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
let path = &project_path.path;
writeln!(&mut result, "## {}", path.display()).unwrap();
match path_state {
PathState::PathOnly => {}
PathState::EntireFile => {
let text = fs.load(&worktree_abs_path.join(&path)).await?;
writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
}
PathState::Excerpts { ranges } => {
let text = fs.load(&worktree_abs_path.join(&path)).await?;
writeln!(&mut result, "~~~").unwrap();
// Assumption: ranges are in order, not overlapping
let mut prev_range_end = 0;
for range in ranges {
if range.start > prev_range_end {
writeln!(&mut result, "...").unwrap();
prev_range_end = range.end;
}
let mut start = range.start;
let mut end = range.end.min(text.len());
while !text.is_char_boundary(start) {
start += 1;
}
while !text.is_char_boundary(end) {
end -= 1;
}
result.push_str(&text[start..end]);
if !result.ends_with('\n') {
result.push('\n');
}
}
if prev_range_end < text.len() {
writeln!(&mut result, "...").unwrap();
}
writeln!(&mut result, "~~~").unwrap();
}
}
}
Ok(result)
})
}
}
#[cfg(test)]
mod tests {
use std::path::Path;
use super::*;
use gpui::TestAppContext;
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
use unindent::Unindent as _;
#[gpui::test]
async fn test_system_message_generation(cx: &mut TestAppContext) {
init_test(cx);
let file_3_contents = r#"
fn test1() {}
fn test2() {}
fn test3() {}
"#
.unindent();
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/code",
json!({
"root1": {
"lib": {
"file1.rs": "mod example;",
"file2.rs": "",
},
"test": {
"file3.rs": file_3_contents,
}
},
"root2": {
"src": {
"main.rs": ""
}
}
}),
)
.await;
let project = Project::test(
fs.clone(),
["/code/root1".as_ref(), "/code/root2".as_ref()],
cx,
)
.await;
let worktree_ids = project.read_with(cx, |project, cx| {
project
.worktrees()
.map(|worktree| worktree.read(cx).id())
.collect::<Vec<_>>()
});
let mut ax = ProjectContext::new(project.downgrade(), fs);
ax.add_file(ProjectPath {
worktree_id: worktree_ids[0],
path: Path::new("lib/file1.rs").into(),
});
let message = cx
.update(|cx| ax.generate_system_message(cx))
.await
.unwrap();
assert_eq!(
r#"
project structure:
# root1
## lib/file1.rs
~~~
mod example;
~~~
"#
.unindent(),
message
);
ax.add_excerpts(
ProjectPath {
worktree_id: worktree_ids[0],
path: Path::new("test/file3.rs").into(),
},
&[
file_3_contents.find("fn test2").unwrap()
..file_3_contents.find("fn test3").unwrap(),
],
);
let message = cx
.update(|cx| ax.generate_system_message(cx))
.await
.unwrap();
assert_eq!(
r#"
project structure:
# root1
## lib/file1.rs
~~~
mod example;
~~~
## test/file3.rs
~~~
...
fn test2() {}
...
~~~
"#
.unindent(),
message
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
});
}
}

View file

@ -1,111 +0,0 @@
use anyhow::Result;
use gpui::{div, AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::Deserialize;
use std::fmt::Display;
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
#[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished { for_model: String, view: AnyView },
}
impl ToolFunctionCallResult {
pub fn format(&self, name: &String) -> String {
match self {
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
}
}
pub fn into_any_element(&self, name: &String) -> AnyElement {
match self {
ToolFunctionCallResult::NoSuchTool => {
format!("Language Model attempted to call {name}").into_any_element()
}
ToolFunctionCallResult::ParsingFailed => {
format!("Language Model called {name} with bad arguments").into_any_element()
}
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
}
}
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: RootSchema,
}
impl Display for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
write!(f, "Name: {}:\n", self.name)?;
write!(f, "Description: {}\n", self.description)?;
write!(f, "Parameters: {}", schema)
}
}
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: for<'de> Deserialize<'de> + JsonSchema;
/// The output returned by executing the tool.
type Output: 'static;
type View: Render;
/// Returns the name of the tool.
///
/// This name is exposed to the language model to allow the model to pick
/// which tools to use. As this name is used to identify the tool within a
/// tool registry, it should be unique.
fn name(&self) -> String;
/// Returns the description of the tool.
///
/// This can be used to _prompt_ the model as to what the tool does.
fn description(&self) -> String;
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
let root_schema = schema_for!(Self::Input);
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: root_schema,
}
}
/// Executes the tool with the given input.
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
fn output_view(
tool_call_id: String,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
div()
}
}

View file

@ -1,54 +1,115 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext}; use gpui::{
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::Deserialize;
use std::{ use std::{
any::TypeId, any::TypeId,
collections::HashMap, collections::HashMap,
fmt::Display,
sync::atomic::{AtomicBool, Ordering::SeqCst}, sync::atomic::{AtomicBool, Ordering::SeqCst},
}; };
use crate::tool::{ use crate::ProjectContext;
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
// Internal Tool representation for the registry pub struct ToolRegistry {
pub struct Tool { registered_tools: HashMap<String, RegisteredTool>,
}
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
#[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
view: AnyView,
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
},
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: RootSchema,
}
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: for<'de> Deserialize<'de> + JsonSchema;
/// The output returned by executing the tool.
type Output: 'static;
type View: Render + ToolOutput;
/// Returns the name of the tool.
///
/// This name is exposed to the language model to allow the model to pick
/// which tools to use. As this name is used to identify the tool within a
/// tool registry, it should be unique.
fn name(&self) -> String;
/// Returns the description of the tool.
///
/// This can be used to _prompt_ the model as to what the tool does.
fn description(&self) -> String;
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
let root_schema = schema_for!(Self::Input);
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: root_schema,
}
}
/// Executes the tool with the given input.
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn output_view(
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
div()
}
}
pub trait ToolOutput: Sized {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
struct RegisteredTool {
enabled: AtomicBool, enabled: AtomicBool,
type_id: TypeId, type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>, call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>, render_running: fn(&mut WindowContext) -> gpui::AnyElement,
definition: ToolFunctionDefinition, definition: ToolFunctionDefinition,
} }
impl Tool {
fn new(
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
definition: ToolFunctionDefinition,
) -> Self {
Self {
enabled: AtomicBool::new(true),
type_id,
call,
render_running,
definition,
}
}
}
pub struct ToolRegistry {
tools: HashMap<String, Tool>,
}
impl ToolRegistry { impl ToolRegistry {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
tools: HashMap::new(), registered_tools: HashMap::new(),
} }
} }
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) { pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
for tool in self.tools.values() { for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() { if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst); tool.enabled.store(is_enabled, SeqCst);
return; return;
@ -57,7 +118,7 @@ impl ToolRegistry {
} }
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool { pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
for tool in self.tools.values() { for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() { if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst); return tool.enabled.load(SeqCst);
} }
@ -66,7 +127,7 @@ impl ToolRegistry {
} }
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> { pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
self.tools self.registered_tools
.values() .values()
.filter(|tool| tool.enabled.load(SeqCst)) .filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone()) .map(|tool| tool.definition.clone())
@ -84,7 +145,7 @@ impl ToolRegistry {
.child(result.into_any_element(&tool_call.name)) .child(result.into_any_element(&tool_call.name))
.into_any_element(), .into_any_element(),
None => self None => self
.tools .registered_tools
.get(&tool_call.name) .get(&tool_call.name)
.map(|tool| (tool.render_running)(cx)) .map(|tool| (tool.render_running)(cx))
.unwrap_or_else(|| div().into_any_element()), .unwrap_or_else(|| div().into_any_element()),
@ -96,13 +157,12 @@ impl ToolRegistry {
tool: T, tool: T,
_cx: &mut WindowContext, _cx: &mut WindowContext,
) -> Result<()> { ) -> Result<()> {
let definition = tool.definition();
let name = tool.name(); let name = tool.name();
let registered_tool = RegisteredTool {
let registered_tool = Tool::new( type_id: TypeId::of::<T>(),
TypeId::of::<T>(), definition: tool.definition(),
Box::new( enabled: AtomicBool::new(true),
call: Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let name = tool_call.name.clone(); let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone(); let arguments = tool_call.arguments.clone();
@ -121,8 +181,7 @@ impl ToolRegistry {
cx.spawn(move |mut cx| async move { cx.spawn(move |mut cx| async move {
let result: Result<T::Output> = result.await; let result: Result<T::Output> = result.await;
let for_model = T::format(&input, &result); let view = cx.update(|cx| T::output_view(input, result, cx))?;
let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
Ok(ToolFunctionCall { Ok(ToolFunctionCall {
id, id,
@ -130,23 +189,35 @@ impl ToolRegistry {
arguments, arguments,
result: Some(ToolFunctionCallResult::Finished { result: Some(ToolFunctionCallResult::Finished {
view: view.into(), view: view.into(),
for_model, generate_fn: generate::<T>,
}), }),
}) })
}) })
}, },
), ),
Box::new(|cx| T::render_running(cx).into_any_element()), render_running: render_running::<T>,
definition, };
);
let previous = self.tools.insert(name.clone(), registered_tool);
let previous = self.registered_tools.insert(name.clone(), registered_tool);
if previous.is_some() { if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name)); return Err(anyhow!("already registered a tool with name {}", name));
} }
Ok(()) return Ok(());
fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
T::render_running(cx).into_any_element()
}
fn generate<T: LanguageModelTool>(
view: AnyView,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
view.downcast::<T::View>()
.unwrap()
.update(cx, |view, cx| T::View::generate(view, project, cx))
}
} }
/// Task yields an error if the window for the given WindowContext is closed before the task completes. /// Task yields an error if the window for the given WindowContext is closed before the task completes.
@ -159,7 +230,7 @@ impl ToolRegistry {
let arguments = tool_call.arguments.clone(); let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone(); let id = tool_call.id.clone();
let tool = match self.tools.get(&name) { let tool = match self.registered_tools.get(&name) {
Some(tool) => tool, Some(tool) => tool,
None => { None => {
let name = name.clone(); let name = name.clone();
@ -176,6 +247,47 @@ impl ToolRegistry {
} }
} }
impl ToolFunctionCallResult {
pub fn generate(
&self,
name: &String,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
match self {
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
ToolFunctionCallResult::Finished { generate_fn, view } => {
(generate_fn)(view.clone(), project, cx)
}
}
}
fn into_any_element(&self, name: &String) -> AnyElement {
match self {
ToolFunctionCallResult::NoSuchTool => {
format!("Language Model attempted to call {name}").into_any_element()
}
ToolFunctionCallResult::ParsingFailed => {
format!("Language Model called {name} with bad arguments").into_any_element()
}
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
}
}
}
impl Display for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
write!(f, "Name: {}:\n", self.name)?;
write!(f, "Description: {}\n", self.description)?;
write!(f, "Parameters: {}", schema)
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
@ -213,6 +325,12 @@ mod test {
} }
} }
impl ToolOutput for WeatherView {
fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
serde_json::to_string(&self.result).unwrap()
}
}
impl LanguageModelTool for WeatherTool { impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery; type Input = WeatherQuery;
type Output = WeatherResult; type Output = WeatherResult;
@ -240,7 +358,6 @@ mod test {
} }
fn output_view( fn output_view(
_tool_call_id: String,
_input: Self::Input, _input: Self::Input,
result: Result<Self::Output>, result: Result<Self::Output>,
cx: &mut WindowContext, cx: &mut WindowContext,
@ -250,10 +367,6 @@ mod test {
WeatherView { result } WeatherView { result }
}) })
} }
fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
serde_json::to_string(&output.as_ref().unwrap()).unwrap()
}
} }
#[gpui::test] #[gpui::test]

View file

@ -163,6 +163,10 @@ impl ProjectIndex {
self.project.clone() self.project.clone()
} }
pub fn fs(&self) -> Arc<dyn Fs> {
self.fs.clone()
}
fn handle_project_event( fn handle_project_event(
&mut self, &mut self,
_: Model<Project>, _: Model<Project>,