assistant2: Add support for using tools provided by context servers (#21418)

This PR adds support to Assistant 2 for using tools provided by context
servers.

As part of this I introduced a new `ThreadStore`.

Release Notes:

- N/A

---------

Co-authored-by: Cole <cole@zed.dev>
This commit is contained in:
Marshall Bowers 2024-12-02 15:01:18 -05:00 committed by GitHub
parent f32ffcf5bb
commit b88daae67b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 139 additions and 2 deletions

3
Cargo.lock generated
View file

@ -458,12 +458,15 @@ dependencies = [
"assistant_tool", "assistant_tool",
"collections", "collections",
"command_palette_hooks", "command_palette_hooks",
"context_server",
"editor", "editor",
"feature_flags", "feature_flags",
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
"language_model", "language_model",
"language_model_selector", "language_model_selector",
"log",
"project",
"proto", "proto",
"serde", "serde",
"serde_json", "serde_json",

View file

@ -17,12 +17,15 @@ anyhow.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true
collections.workspace = true collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
context_server.workspace = true
editor.workspace = true editor.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
language_model.workspace = true language_model.workspace = true
language_model_selector.workspace = true language_model_selector.workspace = true
log.workspace = true
project.workspace = true
proto.workspace = true proto.workspace = true
settings.workspace = true settings.workspace = true
serde.workspace = true serde.workspace = true

View file

@ -1,6 +1,7 @@
mod assistant_panel; mod assistant_panel;
mod message_editor; mod message_editor;
mod thread; mod thread;
mod thread_store;
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt}; use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};

View file

@ -14,6 +14,7 @@ use workspace::Workspace;
use crate::message_editor::MessageEditor; use crate::message_editor::MessageEditor;
use crate::thread::{Message, Thread, ThreadEvent}; use crate::thread::{Message, Thread, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::{NewThread, ToggleFocus, ToggleModelSelector}; use crate::{NewThread, ToggleFocus, ToggleModelSelector};
pub fn init(cx: &mut AppContext) { pub fn init(cx: &mut AppContext) {
@ -29,6 +30,8 @@ pub fn init(cx: &mut AppContext) {
pub struct AssistantPanel { pub struct AssistantPanel {
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
#[allow(unused)]
thread_store: Model<ThreadStore>,
thread: Model<Thread>, thread: Model<Thread>,
message_editor: View<MessageEditor>, message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
@ -42,13 +45,25 @@ impl AssistantPanel {
) -> Task<Result<View<Self>>> { ) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default()); let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace
.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), cx)
})?
.await?;
workspace.update(&mut cx, |workspace, cx| { workspace.update(&mut cx, |workspace, cx| {
cx.new_view(|cx| Self::new(workspace, tools, cx)) cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
}) })
}) })
} }
fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self { fn new(
workspace: &Workspace,
thread_store: Model<ThreadStore>,
tools: Arc<ToolWorkingSet>,
cx: &mut ViewContext<Self>,
) -> Self {
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx)); let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let subscriptions = vec![ let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()), cx.observe(&thread, |_, _, cx| cx.notify()),
@ -57,6 +72,7 @@ impl AssistantPanel {
Self { Self {
workspace: workspace.weak_handle(), workspace: workspace.weak_handle(),
thread_store,
thread: thread.clone(), thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools, tools,

View file

@ -0,0 +1,114 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{ToolId, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
use project::Project;
use util::ResultExt as _;
pub struct ThreadStore {
#[allow(unused)]
project: Model<Project>,
tools: Arc<ToolWorkingSet>,
context_server_manager: Model<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
}
impl ThreadStore {
pub fn new(
project: Model<Project>,
tools: Arc<ToolWorkingSet>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
cx.spawn(|mut cx| async move {
let this = cx.new_model(|cx: &mut ModelContext<Self>| {
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new_model(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let this = Self {
project,
tools,
context_server_manager,
context_server_tool_ids: HashMap::default(),
};
this.register_context_server_handlers(cx);
this
})?;
Ok(this)
})
}
fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
cx.subscribe(
&self.context_server_manager.clone(),
Self::handle_context_server_event,
)
.detach();
}
fn handle_context_server_event(
&mut self,
context_server_manager: Model<ContextServerManager>,
event: &context_server::manager::Event,
cx: &mut ModelContext<Self>,
) {
let tool_working_set = self.tools.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
|this, mut cx| async move {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>();
this.update(&mut cx, |this, _cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();
}
}
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
}
}
}
}
}