Centralize project context provided to the assistant (#11471)

This PR restructures the way that tools and attachments add information
about the current project to a conversation with the assistant. Rather
than each tool call or attachment generating a new tool or system
message containing information about the project, they can all
collectively mutate a new type called a `ProjectContext`, which stores
all of the project data that should be sent to the assistant. That data
is then formatted in a single place, and passed to the assistant in one
system message.

This prevents multiple tools/attachments from including redundant
context.

Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-05-06 17:01:50 -07:00 committed by GitHub
parent f2a415135b
commit a64e20ed96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 841 additions and 518 deletions

View file

@ -4,10 +4,16 @@ mod completion_provider;
mod tools;
pub mod ui;
use crate::{
attachments::ActiveEditorAttachmentTool,
tools::{CreateBufferTool, ProjectIndexTool},
ui::UserOrAssistant,
};
use ::ui::{div, prelude::*, Color, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{ToolFunctionCall, ToolRegistry};
use attachments::{ActiveEditorAttachmentTool, UserAttachment, UserAttachmentStore};
use assistant_tooling::{
AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
};
use client::{proto, Client, UserStore};
use collections::HashMap;
use completion_provider::*;
@ -34,9 +40,6 @@ use workspace::{
pub use assistant_settings::AssistantSettings;
use crate::tools::{CreateBufferTool, ProjectIndexTool};
use crate::ui::UserOrAssistant;
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
@ -85,10 +88,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
});
workspace.register_action(|workspace, _: &DebugProjectIndex, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
if let Some(index) = panel.read(cx).chat.read(cx).project_index.clone() {
let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
workspace.add_item_to_center(Box::new(view), cx);
}
let index = panel.read(cx).chat.read(cx).project_index.clone();
let view = cx.new_view(|cx| ProjectIndexDebugView::new(index, cx));
workspace.add_item_to_center(Box::new(view), cx);
}
});
},
@ -122,10 +124,7 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new();
tool_registry
.register(
ProjectIndexTool::new(project_index.clone(), project.read(cx).fs().clone()),
cx,
)
.register(ProjectIndexTool::new(project_index.clone()), cx)
.context("failed to register ProjectIndexTool")
.log_err();
tool_registry
@ -136,7 +135,7 @@ impl AssistantPanel {
.context("failed to register CreateBufferTool")
.log_err();
let mut attachment_store = UserAttachmentStore::new();
let mut attachment_store = AttachmentRegistry::new();
attachment_store.register(ActiveEditorAttachmentTool::new(workspace.clone(), cx));
Self::new(
@ -144,7 +143,7 @@ impl AssistantPanel {
Arc::new(tool_registry),
Arc::new(attachment_store),
app_state.user_store.clone(),
Some(project_index),
project_index,
workspace,
cx,
)
@ -155,9 +154,9 @@ impl AssistantPanel {
pub fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
attachment_store: Arc<UserAttachmentStore>,
attachment_store: Arc<AttachmentRegistry>,
user_store: Model<UserStore>,
project_index: Option<Model<ProjectIndex>>,
project_index: Model<ProjectIndex>,
workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
@ -241,16 +240,16 @@ pub struct AssistantChat {
list_state: ListState,
language_registry: Arc<LanguageRegistry>,
composer_editor: View<Editor>,
project_index_button: Option<View<ProjectIndexButton>>,
project_index_button: View<ProjectIndexButton>,
active_file_button: Option<View<ActiveFileButton>>,
user_store: Model<UserStore>,
next_message_id: MessageId,
collapsed_messages: HashMap<MessageId, bool>,
editing_message: Option<EditingMessage>,
pending_completion: Option<Task<()>>,
attachment_store: Arc<UserAttachmentStore>,
tool_registry: Arc<ToolRegistry>,
project_index: Option<Model<ProjectIndex>>,
attachment_registry: Arc<AttachmentRegistry>,
project_index: Model<ProjectIndex>,
}
struct EditingMessage {
@ -263,9 +262,9 @@ impl AssistantChat {
fn new(
language_registry: Arc<LanguageRegistry>,
tool_registry: Arc<ToolRegistry>,
attachment_store: Arc<UserAttachmentStore>,
attachment_registry: Arc<AttachmentRegistry>,
user_store: Model<UserStore>,
project_index: Option<Model<ProjectIndex>>,
project_index: Model<ProjectIndex>,
workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
@ -281,14 +280,14 @@ impl AssistantChat {
},
);
let project_index_button = project_index.clone().map(|project_index| {
cx.new_view(|cx| ProjectIndexButton::new(project_index, tool_registry.clone(), cx))
let project_index_button = cx.new_view(|cx| {
ProjectIndexButton::new(project_index.clone(), tool_registry.clone(), cx)
});
let active_file_button = match workspace.upgrade() {
Some(workspace) => {
Some(cx.new_view(
|cx| ActiveFileButton::new(attachment_store.clone(), workspace, cx), //
|cx| ActiveFileButton::new(attachment_registry.clone(), workspace, cx), //
))
}
_ => None,
@ -313,7 +312,7 @@ impl AssistantChat {
editing_message: None,
collapsed_messages: HashMap::default(),
pending_completion: None,
attachment_store,
attachment_registry,
tool_registry,
}
}
@ -395,7 +394,7 @@ impl AssistantChat {
let mode = *mode;
self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
let attachments_task = this.update(&mut cx, |this, cx| {
let attachment_store = this.attachment_store.clone();
let attachment_store = this.attachment_registry.clone();
attachment_store.call_all_attachment_tools(cx)
});
@ -443,7 +442,7 @@ impl AssistantChat {
let mut call_count = 0;
loop {
let complete = async {
let completion = this.update(cx, |this, cx| {
let (tool_definitions, model_name, messages) = this.update(cx, |this, cx| {
this.push_new_assistant_message(cx);
let definitions = if call_count < limit
@ -455,14 +454,22 @@ impl AssistantChat {
};
call_count += 1;
let messages = this.completion_messages(cx);
CompletionProvider::get(cx).complete(
(
definitions,
this.model.clone(),
this.completion_messages(cx),
)
})?;
let messages = messages.await?;
let completion = cx.update(|cx| {
CompletionProvider::get(cx).complete(
model_name,
messages,
Vec::new(),
1.0,
definitions,
tool_definitions,
)
});
@ -765,7 +772,12 @@ impl AssistantChat {
}
}
fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
fn completion_messages(&self, cx: &mut WindowContext) -> Task<Result<Vec<CompletionMessage>>> {
let project_index = self.project_index.read(cx);
let project = project_index.project();
let fs = project_index.fs();
let mut project_context = ProjectContext::new(project, fs);
let mut completion_messages = Vec::new();
for message in &self.messages {
@ -773,12 +785,11 @@ impl AssistantChat {
ChatMessage::User(UserMessage {
body, attachments, ..
}) => {
completion_messages.extend(
attachments
.into_iter()
.filter_map(|attachment| attachment.message.clone())
.map(|content| CompletionMessage::System { content }),
);
for attachment in attachments {
if let Some(content) = attachment.generate(&mut project_context, cx) {
completion_messages.push(CompletionMessage::System { content });
}
}
// Show user's message last so that the assistant is grounded in the user's request
completion_messages.push(CompletionMessage::User {
@ -815,7 +826,9 @@ impl AssistantChat {
for tool_call in tool_calls {
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
let content = match &tool_call.result {
Some(result) => result.format(&tool_call.name),
Some(result) => {
result.generate(&tool_call.name, &mut project_context, cx)
}
None => "".to_string(),
};
@ -828,7 +841,13 @@ impl AssistantChat {
}
}
completion_messages
let system_message = project_context.generate_system_message(cx);
cx.background_executor().spawn(async move {
let content = system_message.await?;
completion_messages.insert(0, CompletionMessage::System { content });
Ok(completion_messages)
})
}
}