Centralize project context provided to the assistant (#11471)
This PR restructures the way that tools and attachments add information about the current project to a conversation with the assistant. Rather than each tool call or attachment generating a new tool or system message containing information about the project, they can all collectively mutate a new type called a `ProjectContext`, which stores all of the project data that should be sent to the assistant. That data is then formatted in a single place, and passed to the assistant in one system message. This prevents multiple tools/attachments from including redundant context. Release Notes: - N/A --------- Co-authored-by: Kyle <kylek@zed.dev>
This commit is contained in:
parent
f2a415135b
commit
a64e20ed96
15 changed files with 841 additions and 518 deletions
7
Cargo.lock
generated
7
Cargo.lock
generated
|
@ -411,10 +411,17 @@ name = "assistant_tooling"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"collections",
|
||||||
|
"futures 0.3.28",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"project",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"settings",
|
||||||
|
"sum_tree",
|
||||||
|
"unindent",
|
||||||
|
"util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -4,10 +4,16 @@ mod completion_provider;
|
||||||
mod tools;
|
mod tools;
|
||||||
pub mod ui;
|
pub mod ui;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
attachments::ActiveEditorAttachmentTool,
|
||||||
|
tools::{CreateBufferTool, ProjectIndexTool},
|
||||||
|
ui::UserOrAssistant,
|
||||||
|
};
|
||||||
use ::ui::{div, prelude::*, Color, ViewContext};
|
use ::ui::{div, prelude::*, Color, ViewContext};
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use assistant_tooling::{ToolFunctionCall, ToolRegistry};
|
use assistant_tooling::{
|
||||||
use attachments::{ActiveEditorAttachmentTool, UserAttachment, UserAttachmentStore};
|
AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
|
||||||
|
};
|
||||||
use client::{proto, Client, UserStore};
|
use client::{proto, Client, UserStore};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use completion_provider::*;
|
use completion_provider::*;
|
||||||
|
@ -34,9 +40,6 @@ use workspace::{
|
||||||
|
|
||||||
pub use assistant_settings::AssistantSettings;
|
pub use assistant_settings::AssistantSettings;
|
||||||
|
|
||||||
use crate::tools::{CreateBufferTool, ProjectIndexTool};
|
|
||||||
use crate::ui::UserOrAssistant;
|
|
||||||
|
|
||||||
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
|
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
|
||||||
|
|
||||||
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
|
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
|
||||||
|
@ -85,10 +88,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||||
});
|
});
|
||||||
workspace.register_action(|workspace, _: &DebugProjectIndex, cx| {
|
workspace.register_action(|workspace, _: &DebugProjectIndex, cx| {
|
||||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||||
if let Some(index) = panel.read(cx).chat.read(cx).project_index.clone() {
|
let index = panel.read(cx).chat.read(cx).project_index.clone();
|
||||||
let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
|
let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
|
||||||
workspace.add_item_to_center(Box::new(view), cx);
|
workspace.add_item_to_center(Box::new(view), cx);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
@ -122,10 +124,7 @@ impl AssistantPanel {
|
||||||
|
|
||||||
let mut tool_registry = ToolRegistry::new();
|
let mut tool_registry = ToolRegistry::new();
|
||||||
tool_registry
|
tool_registry
|
||||||
.register(
|
.register(ProjectIndexTool::new(project_index.clone()), cx)
|
||||||
ProjectIndexTool::new(project_index.clone(), project.read(cx).fs().clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
.context("failed to register ProjectIndexTool")
|
.context("failed to register ProjectIndexTool")
|
||||||
.log_err();
|
.log_err();
|
||||||
tool_registry
|
tool_registry
|
||||||
|
@ -136,7 +135,7 @@ impl AssistantPanel {
|
||||||
.context("failed to register CreateBufferTool")
|
.context("failed to register CreateBufferTool")
|
||||||
.log_err();
|
.log_err();
|
||||||
|
|
||||||
let mut attachment_store = UserAttachmentStore::new();
|
let mut attachment_store = AttachmentRegistry::new();
|
||||||
attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx));
|
attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx));
|
||||||
|
|
||||||
Self::new(
|
Self::new(
|
||||||
|
@ -144,7 +143,7 @@ impl AssistantPanel {
|
||||||
Arc::new(tool_registry),
|
Arc::new(tool_registry),
|
||||||
Arc::new(attachment_store),
|
Arc::new(attachment_store),
|
||||||
app_state.user_store.clone(),
|
app_state.user_store.clone(),
|
||||||
Some(project_index),
|
project_index,
|
||||||
workspace,
|
workspace,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -155,9 +154,9 @@ impl AssistantPanel {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
tool_registry: Arc<ToolRegistry>,
|
tool_registry: Arc<ToolRegistry>,
|
||||||
attachment_store: Arc<UserAttachmentStore>,
|
attachment_store: Arc<AttachmentRegistry>,
|
||||||
user_store: Model<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
project_index: Option<Model<ProjectIndex>>,
|
project_index: Model<ProjectIndex>,
|
||||||
workspace: WeakView<Workspace>,
|
workspace: WeakView<Workspace>,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -241,16 +240,16 @@ pub struct AssistantChat {
|
||||||
list_state: ListState,
|
list_state: ListState,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
composer_editor: View<Editor>,
|
composer_editor: View<Editor>,
|
||||||
project_index_button: Option<View<ProjectIndexButton>>,
|
project_index_button: View<ProjectIndexButton>,
|
||||||
active_file_button: Option<View<ActiveFileButton>>,
|
active_file_button: Option<View<ActiveFileButton>>,
|
||||||
user_store: Model<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
next_message_id: MessageId,
|
next_message_id: MessageId,
|
||||||
collapsed_messages: HashMap<MessageId, bool>,
|
collapsed_messages: HashMap<MessageId, bool>,
|
||||||
editing_message: Option<EditingMessage>,
|
editing_message: Option<EditingMessage>,
|
||||||
pending_completion: Option<Task<()>>,
|
pending_completion: Option<Task<()>>,
|
||||||
attachment_store: Arc<UserAttachmentStore>,
|
|
||||||
tool_registry: Arc<ToolRegistry>,
|
tool_registry: Arc<ToolRegistry>,
|
||||||
project_index: Option<Model<ProjectIndex>>,
|
attachment_registry: Arc<AttachmentRegistry>,
|
||||||
|
project_index: Model<ProjectIndex>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct EditingMessage {
|
struct EditingMessage {
|
||||||
|
@ -263,9 +262,9 @@ impl AssistantChat {
|
||||||
fn new(
|
fn new(
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
tool_registry: Arc<ToolRegistry>,
|
tool_registry: Arc<ToolRegistry>,
|
||||||
attachment_store: Arc<UserAttachmentStore>,
|
attachment_registry: Arc<AttachmentRegistry>,
|
||||||
user_store: Model<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
project_index: Option<Model<ProjectIndex>>,
|
project_index: Model<ProjectIndex>,
|
||||||
workspace: WeakView<Workspace>,
|
workspace: WeakView<Workspace>,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -281,14 +280,14 @@ impl AssistantChat {
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
let project_index_button = project_index.clone().map(|project_index| {
|
let project_index_button = cx.new_view(|cx| {
|
||||||
cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
|
ProjectIndexButton::new(project_index.clone(), tool_registry.clone(), cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
let active_file_button = match workspace.upgrade() {
|
let active_file_button = match workspace.upgrade() {
|
||||||
Some(workspace) => {
|
Some(workspace) => {
|
||||||
Some(cx.new_view(
|
Some(cx.new_view(
|
||||||
|cx| ActiveFileButton::new(attachment_store.clone(), workspace, cx), //
|
|cx| ActiveFileButton::new(attachment_registry.clone(), workspace, cx), //
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
|
@ -313,7 +312,7 @@ impl AssistantChat {
|
||||||
editing_message: None,
|
editing_message: None,
|
||||||
collapsed_messages: HashMap::default(),
|
collapsed_messages: HashMap::default(),
|
||||||
pending_completion: None,
|
pending_completion: None,
|
||||||
attachment_store,
|
attachment_registry,
|
||||||
tool_registry,
|
tool_registry,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -395,7 +394,7 @@ impl AssistantChat {
|
||||||
let mode = *mode;
|
let mode = *mode;
|
||||||
self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
|
self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
|
||||||
let attachments_task = this.update(&mut cx, |this, cx| {
|
let attachments_task = this.update(&mut cx, |this, cx| {
|
||||||
let attachment_store = this.attachment_store.clone();
|
let attachment_store = this.attachment_registry.clone();
|
||||||
attachment_store.call_all_attachment_tools(cx)
|
attachment_store.call_all_attachment_tools(cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -443,7 +442,7 @@ impl AssistantChat {
|
||||||
let mut call_count = 0;
|
let mut call_count = 0;
|
||||||
loop {
|
loop {
|
||||||
let complete = async {
|
let complete = async {
|
||||||
let completion = this.update(cx, |this, cx| {
|
let (tool_definitions, model_name, messages) = this.update(cx, |this, cx| {
|
||||||
this.push_new_assistant_message(cx);
|
this.push_new_assistant_message(cx);
|
||||||
|
|
||||||
let definitions = if call_count < limit
|
let definitions = if call_count < limit
|
||||||
|
@ -455,14 +454,22 @@ impl AssistantChat {
|
||||||
};
|
};
|
||||||
call_count += 1;
|
call_count += 1;
|
||||||
|
|
||||||
let messages = this.completion_messages(cx);
|
(
|
||||||
|
definitions,
|
||||||
CompletionProvider::get(cx).complete(
|
|
||||||
this.model.clone(),
|
this.model.clone(),
|
||||||
|
this.completion_messages(cx),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let messages = messages.await?;
|
||||||
|
|
||||||
|
let completion = cx.update(|cx| {
|
||||||
|
CompletionProvider::get(cx).complete(
|
||||||
|
model_name,
|
||||||
messages,
|
messages,
|
||||||
Vec::new(),
|
Vec::new(),
|
||||||
1.0,
|
1.0,
|
||||||
definitions,
|
tool_definitions,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -765,7 +772,12 @@ impl AssistantChat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
|
fn completion_messages(&self, cx: &mut WindowContext) -> Task<Result<Vec<CompletionMessage>>> {
|
||||||
|
let project_index = self.project_index.read(cx);
|
||||||
|
let project = project_index.project();
|
||||||
|
let fs = project_index.fs();
|
||||||
|
|
||||||
|
let mut project_context = ProjectContext::new(project, fs);
|
||||||
let mut completion_messages = Vec::new();
|
let mut completion_messages = Vec::new();
|
||||||
|
|
||||||
for message in &self.messages {
|
for message in &self.messages {
|
||||||
|
@ -773,12 +785,11 @@ impl AssistantChat {
|
||||||
ChatMessage::User(UserMessage {
|
ChatMessage::User(UserMessage {
|
||||||
body, attachments, ..
|
body, attachments, ..
|
||||||
}) => {
|
}) => {
|
||||||
completion_messages.extend(
|
for attachment in attachments {
|
||||||
attachments
|
if let Some(content) = attachment.generate(&mut project_context, cx) {
|
||||||
.into_iter()
|
completion_messages.push(CompletionMessage::System { content });
|
||||||
.filter_map(|attachment| attachment.message.clone())
|
}
|
||||||
.map(|content| CompletionMessage::System { content }),
|
}
|
||||||
);
|
|
||||||
|
|
||||||
// Show user's message last so that the assistant is grounded in the user's request
|
// Show user's message last so that the assistant is grounded in the user's request
|
||||||
completion_messages.push(CompletionMessage::User {
|
completion_messages.push(CompletionMessage::User {
|
||||||
|
@ -815,7 +826,9 @@ impl AssistantChat {
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
||||||
let content = match &tool_call.result {
|
let content = match &tool_call.result {
|
||||||
Some(result) => result.format(&tool_call.name),
|
Some(result) => {
|
||||||
|
result.generate(&tool_call.name, &mut project_context, cx)
|
||||||
|
}
|
||||||
None => "".to_string(),
|
None => "".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -828,7 +841,13 @@ impl AssistantChat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
completion_messages
|
let system_message = project_context.generate_system_message(cx);
|
||||||
|
|
||||||
|
cx.background_executor().spawn(async move {
|
||||||
|
let content = system_message.await?;
|
||||||
|
completion_messages.insert(0, CompletionMessage::System { content });
|
||||||
|
Ok(completion_messages)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,137 +1,18 @@
|
||||||
use std::{
|
pub mod active_file;
|
||||||
any::TypeId,
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicBool, Ordering::SeqCst},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use collections::HashMap;
|
use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
use futures::future::join_all;
|
use gpui::{Render, Task, View, WeakModel, WeakView};
|
||||||
use gpui::{AnyView, Render, Task, View, WeakView};
|
use language::Buffer;
|
||||||
|
use project::ProjectPath;
|
||||||
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
|
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
|
||||||
use util::{maybe, ResultExt};
|
use util::maybe;
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
|
|
||||||
/// A collected attachment from running an attachment tool
|
|
||||||
pub struct UserAttachment {
|
|
||||||
pub message: Option<String>,
|
|
||||||
pub view: AnyView,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct UserAttachmentStore {
|
|
||||||
attachment_tools: HashMap<TypeId, DynamicAttachment>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Internal representation of an attachment tool to allow us to treat them dynamically
|
|
||||||
struct DynamicAttachment {
|
|
||||||
enabled: AtomicBool,
|
|
||||||
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UserAttachmentStore {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
attachment_tools: HashMap::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn register<A: AttachmentTool + 'static>(&mut self, attachment: A) {
|
|
||||||
let call = Box::new(move |cx: &mut WindowContext| {
|
|
||||||
let result = attachment.run(cx);
|
|
||||||
|
|
||||||
cx.spawn(move |mut cx| async move {
|
|
||||||
let result: Result<A::Output> = result.await;
|
|
||||||
let message = A::format(&result);
|
|
||||||
let view = cx.update(|cx| A::view(result, cx))?;
|
|
||||||
|
|
||||||
Ok(UserAttachment {
|
|
||||||
message,
|
|
||||||
view: view.into(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
self.attachment_tools.insert(
|
|
||||||
TypeId::of::<A>(),
|
|
||||||
DynamicAttachment {
|
|
||||||
call,
|
|
||||||
enabled: AtomicBool::new(true),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_attachment_tool_enabled<A: AttachmentTool + 'static>(&self, is_enabled: bool) {
|
|
||||||
if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
|
|
||||||
attachment.enabled.store(is_enabled, SeqCst);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_attachment_tool_enabled<A: AttachmentTool + 'static>(&self) -> bool {
|
|
||||||
if let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) {
|
|
||||||
attachment.enabled.load(SeqCst)
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call<A: AttachmentTool + 'static>(
|
|
||||||
&self,
|
|
||||||
cx: &mut WindowContext,
|
|
||||||
) -> Task<Result<UserAttachment>> {
|
|
||||||
let Some(attachment) = self.attachment_tools.get(&TypeId::of::<A>()) else {
|
|
||||||
return Task::ready(Err(anyhow!("no attachment tool")));
|
|
||||||
};
|
|
||||||
|
|
||||||
(attachment.call)(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call_all_attachment_tools(
|
|
||||||
self: Arc<Self>,
|
|
||||||
cx: &mut WindowContext<'_>,
|
|
||||||
) -> Task<Result<Vec<UserAttachment>>> {
|
|
||||||
let this = self.clone();
|
|
||||||
cx.spawn(|mut cx| async move {
|
|
||||||
let attachment_tasks = cx.update(|cx| {
|
|
||||||
let mut tasks = Vec::new();
|
|
||||||
for attachment in this
|
|
||||||
.attachment_tools
|
|
||||||
.values()
|
|
||||||
.filter(|attachment| attachment.enabled.load(SeqCst))
|
|
||||||
{
|
|
||||||
tasks.push((attachment.call)(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let attachments = join_all(attachment_tasks.into_iter()).await;
|
|
||||||
|
|
||||||
Ok(attachments
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|attachment| attachment.log_err())
|
|
||||||
.collect())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait AttachmentTool {
|
|
||||||
type Output: 'static;
|
|
||||||
type View: Render;
|
|
||||||
|
|
||||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
|
||||||
|
|
||||||
fn format(output: &Result<Self::Output>) -> Option<String>;
|
|
||||||
|
|
||||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ActiveEditorAttachment {
|
pub struct ActiveEditorAttachment {
|
||||||
filename: Arc<str>,
|
buffer: WeakModel<Buffer>,
|
||||||
language: Arc<str>,
|
path: Option<ProjectPath>,
|
||||||
text: Arc<str>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct FileAttachmentView {
|
pub struct FileAttachmentView {
|
||||||
|
@ -142,7 +23,13 @@ impl Render for FileAttachmentView {
|
||||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
match &self.output {
|
match &self.output {
|
||||||
Ok(attachment) => {
|
Ok(attachment) => {
|
||||||
let filename = attachment.filename.clone();
|
let filename: SharedString = attachment
|
||||||
|
.path
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|p| p.path.file_name()?.to_str())
|
||||||
|
.unwrap_or("Untitled")
|
||||||
|
.to_string()
|
||||||
|
.into();
|
||||||
|
|
||||||
// todo!(): make the button link to the actual file to open
|
// todo!(): make the button link to the actual file to open
|
||||||
ButtonLike::new("file-attachment")
|
ButtonLike::new("file-attachment")
|
||||||
|
@ -152,7 +39,7 @@ impl Render for FileAttachmentView {
|
||||||
.bg(cx.theme().colors().editor_background)
|
.bg(cx.theme().colors().editor_background)
|
||||||
.rounded_md()
|
.rounded_md()
|
||||||
.child(ui::Icon::new(IconName::File))
|
.child(ui::Icon::new(IconName::File))
|
||||||
.child(filename.to_string()),
|
.child(filename.clone()),
|
||||||
)
|
)
|
||||||
.tooltip({
|
.tooltip({
|
||||||
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
|
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
|
||||||
|
@ -164,6 +51,20 @@ impl Render for FileAttachmentView {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToolOutput for FileAttachmentView {
|
||||||
|
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
|
||||||
|
if let Ok(result) = &self.output {
|
||||||
|
if let Some(path) = &result.path {
|
||||||
|
project.add_file(path.clone());
|
||||||
|
return format!("current file: {}", path.path.display());
|
||||||
|
} else if let Some(buffer) = result.buffer.upgrade() {
|
||||||
|
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ActiveEditorAttachmentTool {
|
pub struct ActiveEditorAttachmentTool {
|
||||||
workspace: WeakView<Workspace>,
|
workspace: WeakView<Workspace>,
|
||||||
}
|
}
|
||||||
|
@ -174,7 +75,7 @@ impl ActiveEditorAttachmentTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AttachmentTool for ActiveEditorAttachmentTool {
|
impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||||
type Output = ActiveEditorAttachment;
|
type Output = ActiveEditorAttachment;
|
||||||
type View = FileAttachmentView;
|
type View = FileAttachmentView;
|
||||||
|
|
||||||
|
@ -191,47 +92,22 @@ impl AttachmentTool for ActiveEditorAttachmentTool {
|
||||||
|
|
||||||
let buffer = active_buffer.read(cx);
|
let buffer = active_buffer.read(cx);
|
||||||
|
|
||||||
if let Some(singleton) = buffer.as_singleton() {
|
if let Some(buffer) = buffer.as_singleton() {
|
||||||
let singleton = singleton.read(cx);
|
let path =
|
||||||
|
project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
|
||||||
let filename = singleton
|
worktree_id: file.worktree_id(cx),
|
||||||
.file()
|
path: file.path.clone(),
|
||||||
.map(|file| file.path().to_string_lossy())
|
});
|
||||||
.unwrap_or("Untitled".into());
|
|
||||||
|
|
||||||
let text = singleton.text();
|
|
||||||
|
|
||||||
let language = singleton
|
|
||||||
.language()
|
|
||||||
.map(|l| {
|
|
||||||
let name = l.code_fence_block_name();
|
|
||||||
name.to_string()
|
|
||||||
})
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
return Ok(ActiveEditorAttachment {
|
return Ok(ActiveEditorAttachment {
|
||||||
filename: filename.into(),
|
buffer: buffer.downgrade(),
|
||||||
language: language.into(),
|
path,
|
||||||
text: text.into(),
|
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("no active buffer"))
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(anyhow!("no active buffer"))
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(output: &Result<Self::Output>) -> Option<String> {
|
|
||||||
let output = output.as_ref().ok()?;
|
|
||||||
|
|
||||||
let filename = &output.filename;
|
|
||||||
let language = &output.language;
|
|
||||||
let text = &output.text;
|
|
||||||
|
|
||||||
Some(format!(
|
|
||||||
"User's active file `{filename}`:\n\n```{language}\n{text}```\n\n"
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
|
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
|
||||||
cx.new_view(|_cx| FileAttachmentView { output })
|
cx.new_view(|_cx| FileAttachmentView { output })
|
||||||
}
|
}
|
||||||
|
|
1
crates/assistant2/src/attachments/active_file.rs
Normal file
1
crates/assistant2/src/attachments/active_file.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use assistant_tooling::LanguageModelTool;
|
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
use gpui::{prelude::*, Model, Task, View, WeakView};
|
use gpui::{prelude::*, Model, Task, View, WeakView};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -31,11 +31,9 @@ pub struct CreateBufferInput {
|
||||||
language: String,
|
language: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CreateBufferOutput {}
|
|
||||||
|
|
||||||
impl LanguageModelTool for CreateBufferTool {
|
impl LanguageModelTool for CreateBufferTool {
|
||||||
type Input = CreateBufferInput;
|
type Input = CreateBufferInput;
|
||||||
type Output = CreateBufferOutput;
|
type Output = ();
|
||||||
type View = CreateBufferView;
|
type View = CreateBufferView;
|
||||||
|
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
|
@ -83,32 +81,39 @@ impl LanguageModelTool for CreateBufferTool {
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
|
|
||||||
Ok(CreateBufferOutput {})
|
Ok(())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String {
|
|
||||||
match output {
|
|
||||||
Ok(_) => format!("Created a new {} buffer", input.language),
|
|
||||||
Err(err) => format!("Failed to create buffer: {err:?}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn output_view(
|
fn output_view(
|
||||||
_tool_call_id: String,
|
input: Self::Input,
|
||||||
_input: Self::Input,
|
output: Result<Self::Output>,
|
||||||
_output: Result<Self::Output>,
|
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> View<Self::View> {
|
) -> View<Self::View> {
|
||||||
cx.new_view(|_cx| CreateBufferView {})
|
cx.new_view(|_cx| CreateBufferView {
|
||||||
|
language: input.language,
|
||||||
|
output,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CreateBufferView {}
|
pub struct CreateBufferView {
|
||||||
|
language: String,
|
||||||
|
output: Result<()>,
|
||||||
|
}
|
||||||
|
|
||||||
impl Render for CreateBufferView {
|
impl Render for CreateBufferView {
|
||||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
div().child("Opening a buffer")
|
div().child("Opening a buffer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToolOutput for CreateBufferView {
|
||||||
|
fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
|
||||||
|
match &self.output {
|
||||||
|
Ok(_) => format!("Created a new {} buffer", self.language),
|
||||||
|
Err(err) => format!("Failed to create buffer: {err:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,25 +1,18 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use assistant_tooling::LanguageModelTool;
|
use assistant_tooling::{LanguageModelTool, ToolOutput};
|
||||||
|
use collections::BTreeMap;
|
||||||
use gpui::{prelude::*, Model, Task};
|
use gpui::{prelude::*, Model, Task};
|
||||||
use project::Fs;
|
use project::ProjectPath;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use semantic_index::{ProjectIndex, Status};
|
use semantic_index::{ProjectIndex, Status};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{collections::HashSet, sync::Arc};
|
use std::{fmt::Write as _, ops::Range};
|
||||||
|
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
|
||||||
use ui::{
|
|
||||||
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
|
|
||||||
WindowContext,
|
|
||||||
};
|
|
||||||
use util::ResultExt as _;
|
|
||||||
|
|
||||||
const DEFAULT_SEARCH_LIMIT: usize = 20;
|
const DEFAULT_SEARCH_LIMIT: usize = 20;
|
||||||
|
|
||||||
#[derive(Clone)]
|
pub struct ProjectIndexTool {
|
||||||
pub struct CodebaseExcerpt {
|
project_index: Model<ProjectIndex>,
|
||||||
path: SharedString,
|
|
||||||
text: SharedString,
|
|
||||||
score: f32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
|
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
|
||||||
|
@ -40,6 +33,11 @@ pub struct ProjectIndexView {
|
||||||
expanded_header: bool,
|
expanded_header: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ProjectIndexOutput {
|
||||||
|
status: Status,
|
||||||
|
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
|
||||||
|
}
|
||||||
|
|
||||||
impl ProjectIndexView {
|
impl ProjectIndexView {
|
||||||
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
|
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
|
||||||
let element_id = ElementId::Name(nanoid::nanoid!().into());
|
let element_id = ElementId::Name(nanoid::nanoid!().into());
|
||||||
|
@ -71,19 +69,15 @@ impl Render for ProjectIndexView {
|
||||||
Ok(output) => output,
|
Ok(output) => output,
|
||||||
};
|
};
|
||||||
|
|
||||||
let num_files_searched = output.files_searched.len();
|
let file_count = output.excerpts.len();
|
||||||
|
|
||||||
let header = h_flex()
|
let header = h_flex()
|
||||||
.gap_2()
|
.gap_2()
|
||||||
.child(Icon::new(IconName::File))
|
.child(Icon::new(IconName::File))
|
||||||
.child(format!(
|
.child(format!(
|
||||||
"Read {} {}",
|
"Read {} {}",
|
||||||
num_files_searched,
|
file_count,
|
||||||
if num_files_searched == 1 {
|
if file_count == 1 { "file" } else { "files" }
|
||||||
"file"
|
|
||||||
} else {
|
|
||||||
"files"
|
|
||||||
}
|
|
||||||
));
|
));
|
||||||
|
|
||||||
v_flex().gap_3().child(
|
v_flex().gap_3().child(
|
||||||
|
@ -102,36 +96,50 @@ impl Render for ProjectIndexView {
|
||||||
.child(Icon::new(IconName::MagnifyingGlass))
|
.child(Icon::new(IconName::MagnifyingGlass))
|
||||||
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
|
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
|
||||||
)
|
)
|
||||||
.child(v_flex().gap_2().children(output.files_searched.iter().map(
|
.child(
|
||||||
|path| {
|
v_flex()
|
||||||
h_flex()
|
.gap_2()
|
||||||
.gap_2()
|
.children(output.excerpts.keys().map(|path| {
|
||||||
.child(Icon::new(IconName::File))
|
h_flex().gap_2().child(Icon::new(IconName::File)).child(
|
||||||
.child(Label::new(path.clone()).color(Color::Muted))
|
Label::new(path.path.to_string_lossy().to_string())
|
||||||
},
|
.color(Color::Muted),
|
||||||
))),
|
)
|
||||||
|
})),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ProjectIndexTool {
|
impl ToolOutput for ProjectIndexView {
|
||||||
project_index: Model<ProjectIndex>,
|
fn generate(
|
||||||
fs: Arc<dyn Fs>,
|
&self,
|
||||||
}
|
context: &mut assistant_tooling::ProjectContext,
|
||||||
|
_: &mut WindowContext,
|
||||||
|
) -> String {
|
||||||
|
match &self.output {
|
||||||
|
Ok(output) => {
|
||||||
|
let mut body = "found results in the following paths:\n".to_string();
|
||||||
|
|
||||||
pub struct ProjectIndexOutput {
|
for (project_path, ranges) in &output.excerpts {
|
||||||
excerpts: Vec<CodebaseExcerpt>,
|
context.add_excerpts(project_path.clone(), ranges);
|
||||||
status: Status,
|
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
|
||||||
files_searched: HashSet<SharedString>,
|
}
|
||||||
|
|
||||||
|
if output.status != Status::Idle {
|
||||||
|
body.push_str("Still indexing. Results may be incomplete.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
|
Err(err) => format!("Error: {}", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProjectIndexTool {
|
impl ProjectIndexTool {
|
||||||
pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
|
pub fn new(project_index: Model<ProjectIndex>) -> Self {
|
||||||
// Listen for project index status and update the ProjectIndexTool directly
|
Self { project_index }
|
||||||
|
|
||||||
// TODO: setup a better description based on the user's current codebase.
|
|
||||||
Self { project_index, fs }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,64 +159,42 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||||
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||||
let project_index = self.project_index.read(cx);
|
let project_index = self.project_index.read(cx);
|
||||||
let status = project_index.status();
|
let status = project_index.status();
|
||||||
let results = project_index.search(
|
let search = project_index.search(
|
||||||
query.query.clone(),
|
query.query.clone(),
|
||||||
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
|
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
|
||||||
let fs = self.fs.clone();
|
cx.spawn(|mut cx| async move {
|
||||||
|
let search_results = search.await?;
|
||||||
|
|
||||||
cx.spawn(|cx| async move {
|
cx.update(|cx| {
|
||||||
let results = results.await?;
|
let mut output = ProjectIndexOutput {
|
||||||
|
status,
|
||||||
|
excerpts: Default::default(),
|
||||||
|
};
|
||||||
|
|
||||||
let excerpts = results.into_iter().map(|result| {
|
for search_result in search_results {
|
||||||
let abs_path = result
|
let path = ProjectPath {
|
||||||
.worktree
|
worktree_id: search_result.worktree.read(cx).id(),
|
||||||
.read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
|
path: search_result.path.clone(),
|
||||||
let fs = fs.clone();
|
};
|
||||||
|
|
||||||
async move {
|
let excerpts_for_path = output.excerpts.entry(path).or_default();
|
||||||
let path = result.path.clone();
|
let ix = match excerpts_for_path
|
||||||
let text = fs.load(&abs_path?).await?;
|
.binary_search_by_key(&search_result.range.start, |r| r.start)
|
||||||
|
{
|
||||||
let mut start = result.range.start;
|
Ok(ix) | Err(ix) => ix,
|
||||||
let mut end = result.range.end.min(text.len());
|
};
|
||||||
while !text.is_char_boundary(start) {
|
excerpts_for_path.insert(ix, search_result.range);
|
||||||
start += 1;
|
|
||||||
}
|
|
||||||
while !text.is_char_boundary(end) {
|
|
||||||
end -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::Ok(CodebaseExcerpt {
|
|
||||||
path: path.to_string_lossy().to_string().into(),
|
|
||||||
text: SharedString::from(text[start..end].to_string()),
|
|
||||||
score: result.score,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
let mut files_searched = HashSet::new();
|
output
|
||||||
let excerpts = futures::future::join_all(excerpts)
|
|
||||||
.await
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|result| result.log_err())
|
|
||||||
.inspect(|excerpt| {
|
|
||||||
files_searched.insert(excerpt.path.clone());
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
anyhow::Ok(ProjectIndexOutput {
|
|
||||||
excerpts,
|
|
||||||
status,
|
|
||||||
files_searched,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn output_view(
|
fn output_view(
|
||||||
_tool_call_id: String,
|
|
||||||
input: Self::Input,
|
input: Self::Input,
|
||||||
output: Result<Self::Output>,
|
output: Result<Self::Output>,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
|
@ -220,34 +206,4 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||||
CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
|
CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
|
||||||
.start_slot("Searching code base")
|
.start_slot("Searching code base")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
|
|
||||||
match &output {
|
|
||||||
Ok(output) => {
|
|
||||||
let mut body = "Semantic search results:\n".to_string();
|
|
||||||
|
|
||||||
if output.status != Status::Idle {
|
|
||||||
body.push_str("Still indexing. Results may be incomplete.\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
if output.excerpts.is_empty() {
|
|
||||||
body.push_str("No results found");
|
|
||||||
return body;
|
|
||||||
}
|
|
||||||
|
|
||||||
for excerpt in &output.excerpts {
|
|
||||||
body.push_str("Excerpt from ");
|
|
||||||
body.push_str(excerpt.path.as_ref());
|
|
||||||
body.push_str(", score ");
|
|
||||||
body.push_str(&excerpt.score.to_string());
|
|
||||||
body.push_str(":\n");
|
|
||||||
body.push_str("~~~\n");
|
|
||||||
body.push_str(excerpt.text.as_ref());
|
|
||||||
body.push_str("~~~\n");
|
|
||||||
}
|
|
||||||
body
|
|
||||||
}
|
|
||||||
Err(err) => format!("Error: {}", err),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::attachments::{ActiveEditorAttachmentTool, UserAttachmentStore};
|
use crate::attachments::ActiveEditorAttachmentTool;
|
||||||
|
use assistant_tooling::AttachmentRegistry;
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
use gpui::{prelude::*, Subscription, View};
|
use gpui::{prelude::*, Subscription, View};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -13,7 +14,7 @@ enum Status {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ActiveFileButton {
|
pub struct ActiveFileButton {
|
||||||
attachment_store: Arc<UserAttachmentStore>,
|
attachment_registry: Arc<AttachmentRegistry>,
|
||||||
status: Status,
|
status: Status,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
workspace_subscription: Subscription,
|
workspace_subscription: Subscription,
|
||||||
|
@ -21,7 +22,7 @@ pub struct ActiveFileButton {
|
||||||
|
|
||||||
impl ActiveFileButton {
|
impl ActiveFileButton {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
attachment_store: Arc<UserAttachmentStore>,
|
attachment_store: Arc<AttachmentRegistry>,
|
||||||
workspace: View<Workspace>,
|
workspace: View<Workspace>,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -30,14 +31,14 @@ impl ActiveFileButton {
|
||||||
cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx));
|
cx.defer(move |this, cx| this.update_active_buffer(workspace.clone(), cx));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
attachment_store,
|
attachment_registry: attachment_store,
|
||||||
status: Status::NoFile,
|
status: Status::NoFile,
|
||||||
workspace_subscription,
|
workspace_subscription,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_enabled(&mut self, enabled: bool) {
|
pub fn set_enabled(&mut self, enabled: bool) {
|
||||||
self.attachment_store
|
self.attachment_registry
|
||||||
.set_attachment_tool_enabled::<ActiveEditorAttachmentTool>(enabled);
|
.set_attachment_tool_enabled::<ActiveEditorAttachmentTool>(enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +80,7 @@ impl ActiveFileButton {
|
||||||
impl Render for ActiveFileButton {
|
impl Render for ActiveFileButton {
|
||||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
let is_enabled = self
|
let is_enabled = self
|
||||||
.attachment_store
|
.attachment_registry
|
||||||
.is_attachment_tool_enabled::<ActiveEditorAttachmentTool>();
|
.is_attachment_tool_enabled::<ActiveEditorAttachmentTool>();
|
||||||
|
|
||||||
let icon = if is_enabled {
|
let icon = if is_enabled {
|
||||||
|
|
|
@ -11,7 +11,7 @@ use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, Divider, TextSize, T
|
||||||
#[derive(IntoElement)]
|
#[derive(IntoElement)]
|
||||||
pub struct Composer {
|
pub struct Composer {
|
||||||
editor: View<Editor>,
|
editor: View<Editor>,
|
||||||
project_index_button: Option<View<ProjectIndexButton>>,
|
project_index_button: View<ProjectIndexButton>,
|
||||||
active_file_button: Option<View<ActiveFileButton>>,
|
active_file_button: Option<View<ActiveFileButton>>,
|
||||||
model_selector: AnyElement,
|
model_selector: AnyElement,
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ pub struct Composer {
|
||||||
impl Composer {
|
impl Composer {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
editor: View<Editor>,
|
editor: View<Editor>,
|
||||||
project_index_button: Option<View<ProjectIndexButton>>,
|
project_index_button: View<ProjectIndexButton>,
|
||||||
active_file_button: Option<View<ActiveFileButton>>,
|
active_file_button: Option<View<ActiveFileButton>>,
|
||||||
model_selector: AnyElement,
|
model_selector: AnyElement,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -32,11 +32,7 @@ impl Composer {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
|
fn render_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
|
||||||
h_flex().children(
|
h_flex().child(self.project_index_button.clone())
|
||||||
self.project_index_button
|
|
||||||
.clone()
|
|
||||||
.map(|view| view.into_any_element()),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
|
fn render_attachment_tools(&mut self, _cx: &mut WindowContext) -> impl IntoElement {
|
||||||
|
|
|
@ -13,10 +13,18 @@ path = "src/assistant_tooling.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
project.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
sum_tree.workspace = true
|
||||||
|
util.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
|
project = { workspace = true, features = ["test-support"] }
|
||||||
|
settings = { workspace = true, features = ["test-support"] }
|
||||||
|
unindent.workspace = true
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
pub mod registry;
|
mod attachment_registry;
|
||||||
pub mod tool;
|
mod project_context;
|
||||||
|
mod tool_registry;
|
||||||
|
|
||||||
pub use crate::registry::ToolRegistry;
|
pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
|
||||||
pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};
|
pub use project_context::ProjectContext;
|
||||||
|
pub use tool_registry::{
|
||||||
|
LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry,
|
||||||
|
};
|
||||||
|
|
148
crates/assistant_tooling/src/attachment_registry.rs
Normal file
148
crates/assistant_tooling/src/attachment_registry.rs
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
use crate::{ProjectContext, ToolOutput};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use gpui::{AnyView, Render, Task, View, WindowContext};
|
||||||
|
use std::{
|
||||||
|
any::TypeId,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use util::ResultExt as _;
|
||||||
|
|
||||||
|
pub struct AttachmentRegistry {
|
||||||
|
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModelAttachment {
|
||||||
|
type Output: 'static;
|
||||||
|
type View: Render + ToolOutput;
|
||||||
|
|
||||||
|
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||||
|
|
||||||
|
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A collected attachment from running an attachment tool
|
||||||
|
pub struct UserAttachment {
|
||||||
|
pub view: AnyView,
|
||||||
|
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal representation of an attachment tool to allow us to treat them dynamically
|
||||||
|
struct RegisteredAttachment {
|
||||||
|
enabled: AtomicBool,
|
||||||
|
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AttachmentRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
registered_attachments: HashMap::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
|
||||||
|
let call = Box::new(move |cx: &mut WindowContext| {
|
||||||
|
let result = attachment.run(cx);
|
||||||
|
|
||||||
|
cx.spawn(move |mut cx| async move {
|
||||||
|
let result: Result<A::Output> = result.await;
|
||||||
|
let view = cx.update(|cx| A::view(result, cx))?;
|
||||||
|
|
||||||
|
Ok(UserAttachment {
|
||||||
|
view: view.into(),
|
||||||
|
generate_fn: generate::<A>,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
self.registered_attachments.insert(
|
||||||
|
TypeId::of::<A>(),
|
||||||
|
RegisteredAttachment {
|
||||||
|
call,
|
||||||
|
enabled: AtomicBool::new(true),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
|
||||||
|
fn generate<T: LanguageModelAttachment>(
|
||||||
|
view: AnyView,
|
||||||
|
project: &mut ProjectContext,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> String {
|
||||||
|
view.downcast::<T::View>()
|
||||||
|
.unwrap()
|
||||||
|
.update(cx, |view, cx| T::View::generate(view, project, cx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
|
||||||
|
&self,
|
||||||
|
is_enabled: bool,
|
||||||
|
) {
|
||||||
|
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
|
||||||
|
attachment.enabled.store(is_enabled, SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
|
||||||
|
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
|
||||||
|
attachment.enabled.load(SeqCst)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call<A: LanguageModelAttachment + 'static>(
|
||||||
|
&self,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> Task<Result<UserAttachment>> {
|
||||||
|
let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
|
||||||
|
return Task::ready(Err(anyhow!("no attachment tool")));
|
||||||
|
};
|
||||||
|
|
||||||
|
(attachment.call)(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call_all_attachment_tools(
|
||||||
|
self: Arc<Self>,
|
||||||
|
cx: &mut WindowContext<'_>,
|
||||||
|
) -> Task<Result<Vec<UserAttachment>>> {
|
||||||
|
let this = self.clone();
|
||||||
|
cx.spawn(|mut cx| async move {
|
||||||
|
let attachment_tasks = cx.update(|cx| {
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
for attachment in this
|
||||||
|
.registered_attachments
|
||||||
|
.values()
|
||||||
|
.filter(|attachment| attachment.enabled.load(SeqCst))
|
||||||
|
{
|
||||||
|
tasks.push((attachment.call)(cx))
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let attachments = join_all(attachment_tasks.into_iter()).await;
|
||||||
|
|
||||||
|
Ok(attachments
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|attachment| attachment.log_err())
|
||||||
|
.collect())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserAttachment {
|
||||||
|
pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
|
||||||
|
let result = (self.generate_fn)(self.view.clone(), output, cx);
|
||||||
|
if result.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
296
crates/assistant_tooling/src/project_context.rs
Normal file
296
crates/assistant_tooling/src/project_context.rs
Normal file
|
@ -0,0 +1,296 @@
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use gpui::{AppContext, Model, Task, WeakModel};
|
||||||
|
use project::{Fs, Project, ProjectPath, Worktree};
|
||||||
|
use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
|
||||||
|
use sum_tree::TreeMap;
|
||||||
|
|
||||||
|
pub struct ProjectContext {
|
||||||
|
files: TreeMap<ProjectPath, PathState>,
|
||||||
|
project: WeakModel<Project>,
|
||||||
|
fs: Arc<dyn Fs>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum PathState {
|
||||||
|
PathOnly,
|
||||||
|
EntireFile,
|
||||||
|
Excerpts { ranges: Vec<Range<usize>> },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProjectContext {
|
||||||
|
pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
|
||||||
|
Self {
|
||||||
|
files: TreeMap::default(),
|
||||||
|
fs,
|
||||||
|
project,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_path(&mut self, project_path: ProjectPath) {
|
||||||
|
if self.files.get(&project_path).is_none() {
|
||||||
|
self.files.insert(project_path, PathState::PathOnly);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
|
||||||
|
let previous_state = self
|
||||||
|
.files
|
||||||
|
.get(&project_path)
|
||||||
|
.unwrap_or(&PathState::PathOnly);
|
||||||
|
|
||||||
|
let mut ranges = match previous_state {
|
||||||
|
PathState::EntireFile => return,
|
||||||
|
PathState::PathOnly => Vec::new(),
|
||||||
|
PathState::Excerpts { ranges } => ranges.to_vec(),
|
||||||
|
};
|
||||||
|
|
||||||
|
for new_range in new_ranges {
|
||||||
|
let ix = ranges.binary_search_by(|probe| {
|
||||||
|
if probe.end < new_range.start {
|
||||||
|
Ordering::Less
|
||||||
|
} else if probe.start > new_range.end {
|
||||||
|
Ordering::Greater
|
||||||
|
} else {
|
||||||
|
Ordering::Equal
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
match ix {
|
||||||
|
Ok(mut ix) => {
|
||||||
|
let existing = &mut ranges[ix];
|
||||||
|
existing.start = existing.start.min(new_range.start);
|
||||||
|
existing.end = existing.end.max(new_range.end);
|
||||||
|
while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
|
||||||
|
ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
|
||||||
|
ranges.remove(ix + 1);
|
||||||
|
}
|
||||||
|
while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
|
||||||
|
ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
|
||||||
|
ranges.remove(ix - 1);
|
||||||
|
ix -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(ix) => {
|
||||||
|
ranges.insert(ix, new_range.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.files
|
||||||
|
.insert(project_path, PathState::Excerpts { ranges });
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_file(&mut self, project_path: ProjectPath) {
|
||||||
|
self.files.insert(project_path, PathState::EntireFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
|
||||||
|
let project = self
|
||||||
|
.project
|
||||||
|
.upgrade()
|
||||||
|
.ok_or_else(|| anyhow!("project dropped"));
|
||||||
|
let files = self.files.clone();
|
||||||
|
let fs = self.fs.clone();
|
||||||
|
cx.spawn(|cx| async move {
|
||||||
|
let project = project?;
|
||||||
|
let mut result = "project structure:\n".to_string();
|
||||||
|
|
||||||
|
let mut last_worktree: Option<Model<Worktree>> = None;
|
||||||
|
for (project_path, path_state) in files.iter() {
|
||||||
|
if let Some(worktree) = &last_worktree {
|
||||||
|
if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
|
||||||
|
last_worktree = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let worktree;
|
||||||
|
if let Some(last_worktree) = &last_worktree {
|
||||||
|
worktree = last_worktree.clone();
|
||||||
|
} else if let Some(tree) = project.read_with(&cx, |project, cx| {
|
||||||
|
project.worktree_for_id(project_path.worktree_id, cx)
|
||||||
|
})? {
|
||||||
|
worktree = tree;
|
||||||
|
last_worktree = Some(worktree.clone());
|
||||||
|
let worktree_name =
|
||||||
|
worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
|
||||||
|
writeln!(&mut result, "# {}", worktree_name).unwrap();
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
|
||||||
|
let path = &project_path.path;
|
||||||
|
writeln!(&mut result, "## {}", path.display()).unwrap();
|
||||||
|
|
||||||
|
match path_state {
|
||||||
|
PathState::PathOnly => {}
|
||||||
|
PathState::EntireFile => {
|
||||||
|
let text = fs.load(&worktree_abs_path.join(&path)).await?;
|
||||||
|
writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
|
||||||
|
}
|
||||||
|
PathState::Excerpts { ranges } => {
|
||||||
|
let text = fs.load(&worktree_abs_path.join(&path)).await?;
|
||||||
|
|
||||||
|
writeln!(&mut result, "~~~").unwrap();
|
||||||
|
|
||||||
|
// Assumption: ranges are in order, not overlapping
|
||||||
|
let mut prev_range_end = 0;
|
||||||
|
for range in ranges {
|
||||||
|
if range.start > prev_range_end {
|
||||||
|
writeln!(&mut result, "...").unwrap();
|
||||||
|
prev_range_end = range.end;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut start = range.start;
|
||||||
|
let mut end = range.end.min(text.len());
|
||||||
|
while !text.is_char_boundary(start) {
|
||||||
|
start += 1;
|
||||||
|
}
|
||||||
|
while !text.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
result.push_str(&text[start..end]);
|
||||||
|
if !result.ends_with('\n') {
|
||||||
|
result.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prev_range_end < text.len() {
|
||||||
|
writeln!(&mut result, "...").unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
writeln!(&mut result, "~~~").unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use gpui::TestAppContext;
|
||||||
|
use project::FakeFs;
|
||||||
|
use serde_json::json;
|
||||||
|
use settings::SettingsStore;
|
||||||
|
|
||||||
|
use unindent::Unindent as _;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_system_message_generation(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
|
||||||
|
let file_3_contents = r#"
|
||||||
|
fn test1() {}
|
||||||
|
fn test2() {}
|
||||||
|
fn test3() {}
|
||||||
|
"#
|
||||||
|
.unindent();
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree(
|
||||||
|
"/code",
|
||||||
|
json!({
|
||||||
|
"root1": {
|
||||||
|
"lib": {
|
||||||
|
"file1.rs": "mod example;",
|
||||||
|
"file2.rs": "",
|
||||||
|
},
|
||||||
|
"test": {
|
||||||
|
"file3.rs": file_3_contents,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root2": {
|
||||||
|
"src": {
|
||||||
|
"main.rs": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let project = Project::test(
|
||||||
|
fs.clone(),
|
||||||
|
["/code/root1".as_ref(), "/code/root2".as_ref()],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let worktree_ids = project.read_with(cx, |project, cx| {
|
||||||
|
project
|
||||||
|
.worktrees()
|
||||||
|
.map(|worktree| worktree.read(cx).id())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut ax = ProjectContext::new(project.downgrade(), fs);
|
||||||
|
|
||||||
|
ax.add_file(ProjectPath {
|
||||||
|
worktree_id: worktree_ids[0],
|
||||||
|
path: Path::new("lib/file1.rs").into(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let message = cx
|
||||||
|
.update(|cx| ax.generate_system_message(cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
r#"
|
||||||
|
project structure:
|
||||||
|
# root1
|
||||||
|
## lib/file1.rs
|
||||||
|
~~~
|
||||||
|
mod example;
|
||||||
|
~~~
|
||||||
|
"#
|
||||||
|
.unindent(),
|
||||||
|
message
|
||||||
|
);
|
||||||
|
|
||||||
|
ax.add_excerpts(
|
||||||
|
ProjectPath {
|
||||||
|
worktree_id: worktree_ids[0],
|
||||||
|
path: Path::new("test/file3.rs").into(),
|
||||||
|
},
|
||||||
|
&[
|
||||||
|
file_3_contents.find("fn test2").unwrap()
|
||||||
|
..file_3_contents.find("fn test3").unwrap(),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
let message = cx
|
||||||
|
.update(|cx| ax.generate_system_message(cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
r#"
|
||||||
|
project structure:
|
||||||
|
# root1
|
||||||
|
## lib/file1.rs
|
||||||
|
~~~
|
||||||
|
mod example;
|
||||||
|
~~~
|
||||||
|
## test/file3.rs
|
||||||
|
~~~
|
||||||
|
...
|
||||||
|
fn test2() {}
|
||||||
|
...
|
||||||
|
~~~
|
||||||
|
"#
|
||||||
|
.unindent(),
|
||||||
|
message
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
|
cx.update(|cx| {
|
||||||
|
let settings_store = SettingsStore::test(cx);
|
||||||
|
cx.set_global(settings_store);
|
||||||
|
Project::init_settings(cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,111 +0,0 @@
|
||||||
use anyhow::Result;
|
|
||||||
use gpui::{div, AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
|
|
||||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use std::fmt::Display;
|
|
||||||
|
|
||||||
#[derive(Default, Deserialize)]
|
|
||||||
pub struct ToolFunctionCall {
|
|
||||||
pub id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub arguments: String,
|
|
||||||
#[serde(skip)]
|
|
||||||
pub result: Option<ToolFunctionCallResult>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum ToolFunctionCallResult {
|
|
||||||
NoSuchTool,
|
|
||||||
ParsingFailed,
|
|
||||||
Finished { for_model: String, view: AnyView },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ToolFunctionCallResult {
|
|
||||||
pub fn format(&self, name: &String) -> String {
|
|
||||||
match self {
|
|
||||||
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
|
|
||||||
ToolFunctionCallResult::ParsingFailed => {
|
|
||||||
format!("Unable to parse arguments for {name}")
|
|
||||||
}
|
|
||||||
ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_any_element(&self, name: &String) -> AnyElement {
|
|
||||||
match self {
|
|
||||||
ToolFunctionCallResult::NoSuchTool => {
|
|
||||||
format!("Language Model attempted to call {name}").into_any_element()
|
|
||||||
}
|
|
||||||
ToolFunctionCallResult::ParsingFailed => {
|
|
||||||
format!("Language Model called {name} with bad arguments").into_any_element()
|
|
||||||
}
|
|
||||||
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct ToolFunctionDefinition {
|
|
||||||
pub name: String,
|
|
||||||
pub description: String,
|
|
||||||
pub parameters: RootSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for ToolFunctionDefinition {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
let schema = serde_json::to_string(&self.parameters).ok();
|
|
||||||
let schema = schema.unwrap_or("None".to_string());
|
|
||||||
write!(f, "Name: {}:\n", self.name)?;
|
|
||||||
write!(f, "Description: {}\n", self.description)?;
|
|
||||||
write!(f, "Parameters: {}", schema)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait LanguageModelTool {
|
|
||||||
/// The input type that will be passed in to `execute` when the tool is called
|
|
||||||
/// by the language model.
|
|
||||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
|
||||||
|
|
||||||
/// The output returned by executing the tool.
|
|
||||||
type Output: 'static;
|
|
||||||
|
|
||||||
type View: Render;
|
|
||||||
|
|
||||||
/// Returns the name of the tool.
|
|
||||||
///
|
|
||||||
/// This name is exposed to the language model to allow the model to pick
|
|
||||||
/// which tools to use. As this name is used to identify the tool within a
|
|
||||||
/// tool registry, it should be unique.
|
|
||||||
fn name(&self) -> String;
|
|
||||||
|
|
||||||
/// Returns the description of the tool.
|
|
||||||
///
|
|
||||||
/// This can be used to _prompt_ the model as to what the tool does.
|
|
||||||
fn description(&self) -> String;
|
|
||||||
|
|
||||||
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
|
|
||||||
fn definition(&self) -> ToolFunctionDefinition {
|
|
||||||
let root_schema = schema_for!(Self::Input);
|
|
||||||
|
|
||||||
ToolFunctionDefinition {
|
|
||||||
name: self.name(),
|
|
||||||
description: self.description(),
|
|
||||||
parameters: root_schema,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Executes the tool with the given input.
|
|
||||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
|
||||||
|
|
||||||
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
|
|
||||||
|
|
||||||
fn output_view(
|
|
||||||
tool_call_id: String,
|
|
||||||
input: Self::Input,
|
|
||||||
output: Result<Self::Output>,
|
|
||||||
cx: &mut WindowContext,
|
|
||||||
) -> View<Self::View>;
|
|
||||||
|
|
||||||
fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
|
|
||||||
div()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,54 +1,115 @@
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use gpui::{div, AnyElement, IntoElement as _, ParentElement, Styled, Task, WindowContext};
|
use gpui::{
|
||||||
|
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
|
||||||
|
};
|
||||||
|
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||||
|
use serde::Deserialize;
|
||||||
use std::{
|
use std::{
|
||||||
any::TypeId,
|
any::TypeId,
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
|
fmt::Display,
|
||||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::tool::{
|
use crate::ProjectContext;
|
||||||
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Internal Tool representation for the registry
|
pub struct ToolRegistry {
|
||||||
pub struct Tool {
|
registered_tools: HashMap<String, RegisteredTool>,
|
||||||
enabled: AtomicBool,
|
|
||||||
type_id: TypeId,
|
|
||||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
|
||||||
render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
|
|
||||||
definition: ToolFunctionDefinition,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tool {
|
#[derive(Default, Deserialize)]
|
||||||
fn new(
|
pub struct ToolFunctionCall {
|
||||||
type_id: TypeId,
|
pub id: String,
|
||||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
pub name: String,
|
||||||
render_running: Box<dyn Fn(&mut WindowContext) -> gpui::AnyElement>,
|
pub arguments: String,
|
||||||
definition: ToolFunctionDefinition,
|
#[serde(skip)]
|
||||||
) -> Self {
|
pub result: Option<ToolFunctionCallResult>,
|
||||||
Self {
|
}
|
||||||
enabled: AtomicBool::new(true),
|
|
||||||
type_id,
|
pub enum ToolFunctionCallResult {
|
||||||
call,
|
NoSuchTool,
|
||||||
render_running,
|
ParsingFailed,
|
||||||
definition,
|
Finished {
|
||||||
|
view: AnyView,
|
||||||
|
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ToolFunctionDefinition {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub parameters: RootSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModelTool {
|
||||||
|
/// The input type that will be passed in to `execute` when the tool is called
|
||||||
|
/// by the language model.
|
||||||
|
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||||
|
|
||||||
|
/// The output returned by executing the tool.
|
||||||
|
type Output: 'static;
|
||||||
|
|
||||||
|
type View: Render + ToolOutput;
|
||||||
|
|
||||||
|
/// Returns the name of the tool.
|
||||||
|
///
|
||||||
|
/// This name is exposed to the language model to allow the model to pick
|
||||||
|
/// which tools to use. As this name is used to identify the tool within a
|
||||||
|
/// tool registry, it should be unique.
|
||||||
|
fn name(&self) -> String;
|
||||||
|
|
||||||
|
/// Returns the description of the tool.
|
||||||
|
///
|
||||||
|
/// This can be used to _prompt_ the model as to what the tool does.
|
||||||
|
fn description(&self) -> String;
|
||||||
|
|
||||||
|
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
|
||||||
|
fn definition(&self) -> ToolFunctionDefinition {
|
||||||
|
let root_schema = schema_for!(Self::Input);
|
||||||
|
|
||||||
|
ToolFunctionDefinition {
|
||||||
|
name: self.name(),
|
||||||
|
description: self.description(),
|
||||||
|
parameters: root_schema,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Executes the tool with the given input.
|
||||||
|
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||||
|
|
||||||
|
fn output_view(
|
||||||
|
input: Self::Input,
|
||||||
|
output: Result<Self::Output>,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> View<Self::View>;
|
||||||
|
|
||||||
|
fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
|
||||||
|
div()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ToolRegistry {
|
pub trait ToolOutput: Sized {
|
||||||
tools: HashMap<String, Tool>,
|
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RegisteredTool {
|
||||||
|
enabled: AtomicBool,
|
||||||
|
type_id: TypeId,
|
||||||
|
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||||
|
render_running: fn(&mut WindowContext) -> gpui::AnyElement,
|
||||||
|
definition: ToolFunctionDefinition,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolRegistry {
|
impl ToolRegistry {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
tools: HashMap::new(),
|
registered_tools: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
|
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
|
||||||
for tool in self.tools.values() {
|
for tool in self.registered_tools.values() {
|
||||||
if tool.type_id == TypeId::of::<T>() {
|
if tool.type_id == TypeId::of::<T>() {
|
||||||
tool.enabled.store(is_enabled, SeqCst);
|
tool.enabled.store(is_enabled, SeqCst);
|
||||||
return;
|
return;
|
||||||
|
@ -57,7 +118,7 @@ impl ToolRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
|
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
|
||||||
for tool in self.tools.values() {
|
for tool in self.registered_tools.values() {
|
||||||
if tool.type_id == TypeId::of::<T>() {
|
if tool.type_id == TypeId::of::<T>() {
|
||||||
return tool.enabled.load(SeqCst);
|
return tool.enabled.load(SeqCst);
|
||||||
}
|
}
|
||||||
|
@ -66,7 +127,7 @@ impl ToolRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
|
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
|
||||||
self.tools
|
self.registered_tools
|
||||||
.values()
|
.values()
|
||||||
.filter(|tool| tool.enabled.load(SeqCst))
|
.filter(|tool| tool.enabled.load(SeqCst))
|
||||||
.map(|tool| tool.definition.clone())
|
.map(|tool| tool.definition.clone())
|
||||||
|
@ -84,7 +145,7 @@ impl ToolRegistry {
|
||||||
.child(result.into_any_element(&tool_call.name))
|
.child(result.into_any_element(&tool_call.name))
|
||||||
.into_any_element(),
|
.into_any_element(),
|
||||||
None => self
|
None => self
|
||||||
.tools
|
.registered_tools
|
||||||
.get(&tool_call.name)
|
.get(&tool_call.name)
|
||||||
.map(|tool| (tool.render_running)(cx))
|
.map(|tool| (tool.render_running)(cx))
|
||||||
.unwrap_or_else(|| div().into_any_element()),
|
.unwrap_or_else(|| div().into_any_element()),
|
||||||
|
@ -96,13 +157,12 @@ impl ToolRegistry {
|
||||||
tool: T,
|
tool: T,
|
||||||
_cx: &mut WindowContext,
|
_cx: &mut WindowContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let definition = tool.definition();
|
|
||||||
|
|
||||||
let name = tool.name();
|
let name = tool.name();
|
||||||
|
let registered_tool = RegisteredTool {
|
||||||
let registered_tool = Tool::new(
|
type_id: TypeId::of::<T>(),
|
||||||
TypeId::of::<T>(),
|
definition: tool.definition(),
|
||||||
Box::new(
|
enabled: AtomicBool::new(true),
|
||||||
|
call: Box::new(
|
||||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||||
let name = tool_call.name.clone();
|
let name = tool_call.name.clone();
|
||||||
let arguments = tool_call.arguments.clone();
|
let arguments = tool_call.arguments.clone();
|
||||||
|
@ -121,8 +181,7 @@ impl ToolRegistry {
|
||||||
|
|
||||||
cx.spawn(move |mut cx| async move {
|
cx.spawn(move |mut cx| async move {
|
||||||
let result: Result<T::Output> = result.await;
|
let result: Result<T::Output> = result.await;
|
||||||
let for_model = T::format(&input, &result);
|
let view = cx.update(|cx| T::output_view(input, result, cx))?;
|
||||||
let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
|
|
||||||
|
|
||||||
Ok(ToolFunctionCall {
|
Ok(ToolFunctionCall {
|
||||||
id,
|
id,
|
||||||
|
@ -130,23 +189,35 @@ impl ToolRegistry {
|
||||||
arguments,
|
arguments,
|
||||||
result: Some(ToolFunctionCallResult::Finished {
|
result: Some(ToolFunctionCallResult::Finished {
|
||||||
view: view.into(),
|
view: view.into(),
|
||||||
for_model,
|
generate_fn: generate::<T>,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Box::new(|cx| T::render_running(cx).into_any_element()),
|
render_running: render_running::<T>,
|
||||||
definition,
|
};
|
||||||
);
|
|
||||||
|
|
||||||
let previous = self.tools.insert(name.clone(), registered_tool);
|
|
||||||
|
|
||||||
|
let previous = self.registered_tools.insert(name.clone(), registered_tool);
|
||||||
if previous.is_some() {
|
if previous.is_some() {
|
||||||
return Err(anyhow!("already registered a tool with name {}", name));
|
return Err(anyhow!("already registered a tool with name {}", name));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
return Ok(());
|
||||||
|
|
||||||
|
fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
|
||||||
|
T::render_running(cx).into_any_element()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate<T: LanguageModelTool>(
|
||||||
|
view: AnyView,
|
||||||
|
project: &mut ProjectContext,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> String {
|
||||||
|
view.downcast::<T::View>()
|
||||||
|
.unwrap()
|
||||||
|
.update(cx, |view, cx| T::View::generate(view, project, cx))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
|
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
|
||||||
|
@ -159,7 +230,7 @@ impl ToolRegistry {
|
||||||
let arguments = tool_call.arguments.clone();
|
let arguments = tool_call.arguments.clone();
|
||||||
let id = tool_call.id.clone();
|
let id = tool_call.id.clone();
|
||||||
|
|
||||||
let tool = match self.tools.get(&name) {
|
let tool = match self.registered_tools.get(&name) {
|
||||||
Some(tool) => tool,
|
Some(tool) => tool,
|
||||||
None => {
|
None => {
|
||||||
let name = name.clone();
|
let name = name.clone();
|
||||||
|
@ -176,6 +247,47 @@ impl ToolRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToolFunctionCallResult {
|
||||||
|
pub fn generate(
|
||||||
|
&self,
|
||||||
|
name: &String,
|
||||||
|
project: &mut ProjectContext,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> String {
|
||||||
|
match self {
|
||||||
|
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
|
||||||
|
ToolFunctionCallResult::ParsingFailed => {
|
||||||
|
format!("Unable to parse arguments for {name}")
|
||||||
|
}
|
||||||
|
ToolFunctionCallResult::Finished { generate_fn, view } => {
|
||||||
|
(generate_fn)(view.clone(), project, cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_any_element(&self, name: &String) -> AnyElement {
|
||||||
|
match self {
|
||||||
|
ToolFunctionCallResult::NoSuchTool => {
|
||||||
|
format!("Language Model attempted to call {name}").into_any_element()
|
||||||
|
}
|
||||||
|
ToolFunctionCallResult::ParsingFailed => {
|
||||||
|
format!("Language Model called {name} with bad arguments").into_any_element()
|
||||||
|
}
|
||||||
|
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for ToolFunctionDefinition {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let schema = serde_json::to_string(&self.parameters).ok();
|
||||||
|
let schema = schema.unwrap_or("None".to_string());
|
||||||
|
write!(f, "Name: {}:\n", self.name)?;
|
||||||
|
write!(f, "Description: {}\n", self.description)?;
|
||||||
|
write!(f, "Parameters: {}", schema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -213,6 +325,12 @@ mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToolOutput for WeatherView {
|
||||||
|
fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
|
||||||
|
serde_json::to_string(&self.result).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl LanguageModelTool for WeatherTool {
|
impl LanguageModelTool for WeatherTool {
|
||||||
type Input = WeatherQuery;
|
type Input = WeatherQuery;
|
||||||
type Output = WeatherResult;
|
type Output = WeatherResult;
|
||||||
|
@ -240,7 +358,6 @@ mod test {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn output_view(
|
fn output_view(
|
||||||
_tool_call_id: String,
|
|
||||||
_input: Self::Input,
|
_input: Self::Input,
|
||||||
result: Result<Self::Output>,
|
result: Result<Self::Output>,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
|
@ -250,10 +367,6 @@ mod test {
|
||||||
WeatherView { result }
|
WeatherView { result }
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
|
|
||||||
serde_json::to_string(&output.as_ref().unwrap()).unwrap()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
|
@ -163,6 +163,10 @@ impl ProjectIndex {
|
||||||
self.project.clone()
|
self.project.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn fs(&self) -> Arc<dyn Fs> {
|
||||||
|
self.fs.clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn handle_project_event(
|
fn handle_project_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
_: Model<Project>,
|
_: Model<Project>,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue