mod tool_registry; mod tool_working_set; use std::sync::Arc; use anyhow::Result; use collections::{HashMap, HashSet}; use gpui::Context; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; use language_model::LanguageModelRequestMessage; use project::Project; pub use crate::tool_registry::*; pub use crate::tool_working_set::*; pub fn init(cx: &mut App) { ToolRegistry::default_global(cx); } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] pub enum ToolSource { /// A native tool built-in to Zed. Native, /// A tool provided by a context server. ContextServer { id: SharedString }, } /// A tool that can be used by a language model. pub trait Tool: 'static + Send + Sync { /// Returns the name of the tool. fn name(&self) -> String; /// Returns the description of the tool. fn description(&self) -> String; /// Returns the source of the tool. fn source(&self) -> ToolSource { ToolSource::Native } /// Returns the JSON schema that describes the tool's input. fn input_schema(&self) -> serde_json::Value { serde_json::Value::Object(serde_json::Map::default()) } /// Runs the tool with the provided input. fn run( self: Arc, input: serde_json::Value, messages: &[LanguageModelRequestMessage], project: Entity, action_log: Entity, cx: &mut App, ) -> Task>; } /// Tracks actions performed by tools in a thread #[derive(Debug)] pub struct ActionLog { /// Buffers that user manually added to the context, and whose content has /// changed since the model last saw them. stale_buffers_in_context: HashSet>, /// Buffers that we want to notify the model about when they change. tracked_buffers: HashMap, TrackedBuffer>, } #[derive(Debug, Default)] struct TrackedBuffer { version: clock::Global, } impl ActionLog { /// Creates a new, empty action log. pub fn new() -> Self { Self { stale_buffers_in_context: HashSet::default(), tracked_buffers: HashMap::default(), } } /// Track a buffer as read, so we can notify the model about user edits. pub fn buffer_read(&mut self, buffer: Entity, cx: &mut Context) { let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default(); tracked_buffer.version = buffer.read(cx).version(); } /// Mark a buffer as edited, so we can refresh it in the context pub fn buffer_edited(&mut self, buffers: HashSet>, cx: &mut Context) { for buffer in &buffers { let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default(); tracked_buffer.version = buffer.read(cx).version(); } self.stale_buffers_in_context.extend(buffers); } /// Iterate over buffers changed since last read or edited by the model pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator> { self.tracked_buffers .iter() .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version) .map(|(buffer, _)| buffer) } /// Takes and returns the set of buffers pending refresh, clearing internal state. pub fn take_stale_buffers_in_context(&mut self) -> HashSet> { std::mem::take(&mut self.stale_buffers_in_context) } }