
This PR removes the dependency on the `ui` crate from the `assistant_tool` and `context_server` crates. These crates were only depending on it for `IconName`, which can now be depended on from `icons` directly. Release Notes: - N/A
127 lines
4 KiB
Rust
127 lines
4 KiB
Rust
mod tool_registry;
|
|
mod tool_working_set;
|
|
|
|
use std::fmt::{self, Debug, Formatter};
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::Result;
|
|
use collections::{HashMap, HashSet};
|
|
use gpui::{App, Context, Entity, SharedString, Task};
|
|
use icons::IconName;
|
|
use language::Buffer;
|
|
use language_model::LanguageModelRequestMessage;
|
|
use project::Project;
|
|
|
|
pub use crate::tool_registry::*;
|
|
pub use crate::tool_working_set::*;
|
|
|
|
/// Sets up this crate's global state on the given application context.
///
/// The only work done here is a call to `ToolRegistry::default_global`;
/// presumably that installs (or fetches) the process-wide tool registry —
/// confirm against the `tool_registry` module.
pub fn init(cx: &mut App) {
    ToolRegistry::default_global(cx);
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
|
|
pub enum ToolSource {
|
|
/// A native tool built-in to Zed.
|
|
Native,
|
|
/// A tool provided by a context server.
|
|
ContextServer { id: SharedString },
|
|
}
|
|
|
|
/// A tool that can be used by a language model.
|
|
pub trait Tool: 'static + Send + Sync {
|
|
/// Returns the name of the tool.
|
|
fn name(&self) -> String;
|
|
|
|
/// Returns the description of the tool.
|
|
fn description(&self) -> String;
|
|
|
|
/// Returns the icon for the tool.
|
|
fn icon(&self) -> IconName;
|
|
|
|
/// Returns the source of the tool.
|
|
fn source(&self) -> ToolSource {
|
|
ToolSource::Native
|
|
}
|
|
|
|
/// Returns true iff the tool needs the users's confirmation
|
|
/// before having permission to run.
|
|
fn needs_confirmation(&self) -> bool;
|
|
|
|
/// Returns the JSON schema that describes the tool's input.
|
|
fn input_schema(&self) -> serde_json::Value {
|
|
serde_json::Value::Object(serde_json::Map::default())
|
|
}
|
|
|
|
/// Returns markdown to be displayed in the UI for this tool.
|
|
fn ui_text(&self, input: &serde_json::Value) -> String;
|
|
|
|
/// Runs the tool with the provided input.
|
|
fn run(
|
|
self: Arc<Self>,
|
|
input: serde_json::Value,
|
|
messages: &[LanguageModelRequestMessage],
|
|
project: Entity<Project>,
|
|
action_log: Entity<ActionLog>,
|
|
cx: &mut App,
|
|
) -> Task<Result<String>>;
|
|
}
|
|
|
|
impl Debug for dyn Tool {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("Tool").field("name", &self.name()).finish()
|
|
}
|
|
}
|
|
|
|
/// Tracks actions performed by tools in a thread
|
|
#[derive(Debug)]
|
|
pub struct ActionLog {
|
|
/// Buffers that user manually added to the context, and whose content has
|
|
/// changed since the model last saw them.
|
|
stale_buffers_in_context: HashSet<Entity<Buffer>>,
|
|
/// Buffers that we want to notify the model about when they change.
|
|
tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
|
|
}
|
|
|
|
#[derive(Debug, Default)]
|
|
struct TrackedBuffer {
|
|
version: clock::Global,
|
|
}
|
|
|
|
impl ActionLog {
|
|
/// Creates a new, empty action log.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
stale_buffers_in_context: HashSet::default(),
|
|
tracked_buffers: HashMap::default(),
|
|
}
|
|
}
|
|
|
|
/// Track a buffer as read, so we can notify the model about user edits.
|
|
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
|
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
|
|
tracked_buffer.version = buffer.read(cx).version();
|
|
}
|
|
|
|
/// Mark a buffer as edited, so we can refresh it in the context
|
|
pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
|
|
for buffer in &buffers {
|
|
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
|
|
tracked_buffer.version = buffer.read(cx).version();
|
|
}
|
|
|
|
self.stale_buffers_in_context.extend(buffers);
|
|
}
|
|
|
|
/// Iterate over buffers changed since last read or edited by the model
|
|
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
|
|
self.tracked_buffers
|
|
.iter()
|
|
.filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
|
|
.map(|(buffer, _)| buffer)
|
|
}
|
|
|
|
/// Takes and returns the set of buffers pending refresh, clearing internal state.
|
|
pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
|
|
std::mem::take(&mut self.stale_buffers_in_context)
|
|
}
|
|
}
|