ZIm/crates/assistant_tool/src/assistant_tool.rs
Agus Zubiaga a05066cd83
assistant edit tool: Track read buffers and notify model of user edits (#26952)
When the model reads a file, we'll track the version it read and let it
know if the user makes edits to the buffer. This helps prevent edit
failures, because the model will know to re-read the file first.

Release Notes:

- N/A
2025-03-17 21:50:16 +00:00

110 lines
3.4 KiB
Rust

mod tool_registry;
mod tool_working_set;
use std::sync::Arc;
use anyhow::Result;
use collections::{HashMap, HashSet};
use gpui::Context;
use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
use language_model::LanguageModelRequestMessage;
use project::Project;
pub use crate::tool_registry::*;
pub use crate::tool_working_set::*;
/// Crate initialization hook: calls [`ToolRegistry::default_global`] with the
/// application context `cx`.
pub fn init(cx: &mut App) {
    ToolRegistry::default_global(cx);
}
/// Identifies where a [`Tool`] comes from.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ToolSource {
    /// A tool that ships built into Zed itself.
    Native,
    /// A tool supplied by a context server.
    ContextServer {
        /// Identifier of the context server that provides the tool.
        id: SharedString,
    },
}
/// Interface implemented by every tool that a language model can invoke.
pub trait Tool: 'static + Send + Sync {
    /// The tool's name.
    fn name(&self) -> String;

    /// A human-readable description of what the tool does.
    fn description(&self) -> String;

    /// Where this tool comes from; defaults to [`ToolSource::Native`].
    fn source(&self) -> ToolSource {
        ToolSource::Native
    }

    /// JSON schema describing the tool's input; defaults to an empty JSON object.
    fn input_schema(&self) -> serde_json::Value {
        let empty_schema = serde_json::Map::new();
        serde_json::Value::Object(empty_schema)
    }

    /// Executes the tool with the given JSON `input`.
    ///
    /// Returns a [`Task`] that resolves to the tool's `String` output, or an
    /// error. `action_log` is the [`ActionLog`] through which the tool can
    /// report buffers it reads or edits. `messages` are the request messages
    /// supplied by the caller (presumably the conversation so far — confirm at
    /// call sites).
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        messages: &[LanguageModelRequestMessage],
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        cx: &mut App,
    ) -> Task<Result<String>>;
}
/// Tracks actions performed by tools in a thread.
///
/// `ActionLog::default()` produces the same empty log as [`ActionLog::new`].
#[derive(Debug, Default)]
pub struct ActionLog {
    /// Buffers passed to [`ActionLog::buffer_edited`] that still need to be
    /// refreshed in the context; drained by
    /// [`ActionLog::take_stale_buffers_in_context`].
    ///
    /// NOTE(review): callers may restrict this to buffers the user added to the
    /// context — confirm against call sites.
    stale_buffers_in_context: HashSet<Entity<Buffer>>,
    /// Buffers whose version was recorded by [`ActionLog::buffer_read`] or
    /// [`ActionLog::buffer_edited`]; [`ActionLog::stale_buffers`] compares the
    /// recorded version with the buffer's current one to detect later changes.
    tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
}
/// Bookkeeping that [`ActionLog`] keeps for each tracked buffer.
#[derive(Default, Debug)]
struct TrackedBuffer {
    /// Buffer version recorded the last time the buffer was read or edited
    /// through the action log.
    version: clock::Global,
}
impl ActionLog {
    /// Returns a log with no tracked buffers and no pending stale buffers.
    pub fn new() -> Self {
        let stale_buffers_in_context = HashSet::default();
        let tracked_buffers = HashMap::default();
        Self {
            stale_buffers_in_context,
            tracked_buffers,
        }
    }

    /// Records that `buffer` was read, storing its current version.
    ///
    /// If the buffer is not tracked yet, a new entry is created. Later changes
    /// to the buffer are then reported by [`ActionLog::stale_buffers`].
    pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
        let seen_version = buffer.read(cx).version();
        self.tracked_buffers.entry(buffer).or_default().version = seen_version;
    }

    /// Marks every buffer in `buffers` as edited.
    ///
    /// Each buffer is tracked at its current version, exactly as
    /// [`ActionLog::buffer_read`] does. It is also queued for refresh in the
    /// set returned by [`ActionLog::take_stale_buffers_in_context`].
    pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
        for buffer in &buffers {
            self.buffer_read(buffer.clone(), cx);
        }
        self.stale_buffers_in_context.extend(buffers);
    }

    /// Yields the tracked buffers whose current version differs from the
    /// version recorded by the last read or edit.
    pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
        self.tracked_buffers
            .iter()
            .filter_map(move |(buffer, tracked)| {
                let changed = tracked.version != buffer.read(cx).version;
                changed.then_some(buffer)
            })
    }

    /// Returns every buffer queued by `buffer_edited` and leaves the queue empty.
    pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
        std::mem::replace(&mut self.stale_buffers_in_context, HashSet::default())
    }
}