Improve claude tools (#36538)
- Return unified diff from `Edit` tool so model can see the final state - Format on save if enabled - Provide `Write` tool - Disable `MultiEdit` tool - Better prompting Release Notes: - N/A
This commit is contained in:
parent
714c36fa7b
commit
7c7043947b
11 changed files with 606 additions and 302 deletions
|
@ -3,9 +3,12 @@ mod diff;
|
||||||
mod mention;
|
mod mention;
|
||||||
mod terminal;
|
mod terminal;
|
||||||
|
|
||||||
|
use collections::HashSet;
|
||||||
pub use connection::*;
|
pub use connection::*;
|
||||||
pub use diff::*;
|
pub use diff::*;
|
||||||
|
use language::language_settings::FormatOnSave;
|
||||||
pub use mention::*;
|
pub use mention::*;
|
||||||
|
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
pub use terminal::*;
|
pub use terminal::*;
|
||||||
|
|
||||||
|
@ -1051,6 +1054,22 @@ impl AcpThread {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
|
||||||
|
self.entries
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.rev()
|
||||||
|
.find_map(|(index, tool_call)| {
|
||||||
|
if let AgentThreadEntry::ToolCall(tool_call) = tool_call
|
||||||
|
&& &tool_call.id == id
|
||||||
|
{
|
||||||
|
Some((index, tool_call))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
|
pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
|
||||||
let project = self.project.clone();
|
let project = self.project.clone();
|
||||||
let Some((_, tool_call)) = self.tool_call_mut(&id) else {
|
let Some((_, tool_call)) = self.tool_call_mut(&id) else {
|
||||||
|
@ -1601,30 +1620,59 @@ impl AcpThread {
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
cx.update(|cx| {
|
|
||||||
project.update(cx, |project, cx| {
|
|
||||||
project.set_agent_location(
|
|
||||||
Some(AgentLocation {
|
|
||||||
buffer: buffer.downgrade(),
|
|
||||||
position: edits
|
|
||||||
.last()
|
|
||||||
.map(|(range, _)| range.end)
|
|
||||||
.unwrap_or(Anchor::MIN),
|
|
||||||
}),
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
|
project.update(cx, |project, cx| {
|
||||||
|
project.set_agent_location(
|
||||||
|
Some(AgentLocation {
|
||||||
|
buffer: buffer.downgrade(),
|
||||||
|
position: edits
|
||||||
|
.last()
|
||||||
|
.map(|(range, _)| range.end)
|
||||||
|
.unwrap_or(Anchor::MIN),
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let format_on_save = cx.update(|cx| {
|
||||||
action_log.update(cx, |action_log, cx| {
|
action_log.update(cx, |action_log, cx| {
|
||||||
action_log.buffer_read(buffer.clone(), cx);
|
action_log.buffer_read(buffer.clone(), cx);
|
||||||
});
|
});
|
||||||
buffer.update(cx, |buffer, cx| {
|
|
||||||
|
let format_on_save = buffer.update(cx, |buffer, cx| {
|
||||||
buffer.edit(edits, None, cx);
|
buffer.edit(edits, None, cx);
|
||||||
|
|
||||||
|
let settings = language::language_settings::language_settings(
|
||||||
|
buffer.language().map(|l| l.name()),
|
||||||
|
buffer.file(),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
settings.format_on_save != FormatOnSave::Off
|
||||||
});
|
});
|
||||||
action_log.update(cx, |action_log, cx| {
|
action_log.update(cx, |action_log, cx| {
|
||||||
action_log.buffer_edited(buffer.clone(), cx);
|
action_log.buffer_edited(buffer.clone(), cx);
|
||||||
});
|
});
|
||||||
|
format_on_save
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
if format_on_save {
|
||||||
|
let format_task = project.update(cx, |project, cx| {
|
||||||
|
project.format(
|
||||||
|
HashSet::from_iter([buffer.clone()]),
|
||||||
|
LspFormatTarget::Buffers,
|
||||||
|
false,
|
||||||
|
FormatTrigger::Save,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
format_task.await.log_err();
|
||||||
|
|
||||||
|
action_log.update(cx, |action_log, cx| {
|
||||||
|
action_log.buffer_edited(buffer.clone(), cx);
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
project
|
project
|
||||||
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
|
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
|
||||||
.await
|
.await
|
||||||
|
|
|
@ -29,6 +29,7 @@ futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
indoc.workspace = true
|
indoc.workspace = true
|
||||||
itertools.workspace = true
|
itertools.workspace = true
|
||||||
|
language.workspace = true
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
language_models.workspace = true
|
language_models.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
|
mod edit_tool;
|
||||||
mod mcp_server;
|
mod mcp_server;
|
||||||
|
mod permission_tool;
|
||||||
|
mod read_tool;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
mod write_tool;
|
||||||
|
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
@ -351,18 +355,16 @@ fn spawn_claude(
|
||||||
&format!(
|
&format!(
|
||||||
"mcp__{}__{}",
|
"mcp__{}__{}",
|
||||||
mcp_server::SERVER_NAME,
|
mcp_server::SERVER_NAME,
|
||||||
mcp_server::PermissionTool::NAME,
|
permission_tool::PermissionTool::NAME,
|
||||||
),
|
),
|
||||||
"--allowedTools",
|
"--allowedTools",
|
||||||
&format!(
|
&format!(
|
||||||
"mcp__{}__{},mcp__{}__{}",
|
"mcp__{}__{}",
|
||||||
mcp_server::SERVER_NAME,
|
mcp_server::SERVER_NAME,
|
||||||
mcp_server::EditTool::NAME,
|
read_tool::ReadTool::NAME
|
||||||
mcp_server::SERVER_NAME,
|
|
||||||
mcp_server::ReadTool::NAME
|
|
||||||
),
|
),
|
||||||
"--disallowedTools",
|
"--disallowedTools",
|
||||||
"Read,Edit",
|
"Read,Write,Edit,MultiEdit",
|
||||||
])
|
])
|
||||||
.args(match mode {
|
.args(match mode {
|
||||||
ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
|
ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
|
||||||
|
@ -470,9 +472,16 @@ impl ClaudeAgentSession {
|
||||||
let content = content.to_string();
|
let content = content.to_string();
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
|
let id = acp::ToolCallId(tool_use_id.into());
|
||||||
|
let set_new_content = !content.is_empty()
|
||||||
|
&& thread.tool_call(&id).is_none_or(|(_, tool_call)| {
|
||||||
|
// preserve rich diff if we have one
|
||||||
|
tool_call.diffs().next().is_none()
|
||||||
|
});
|
||||||
|
|
||||||
thread.update_tool_call(
|
thread.update_tool_call(
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id: acp::ToolCallId(tool_use_id.into()),
|
id,
|
||||||
fields: acp::ToolCallUpdateFields {
|
fields: acp::ToolCallUpdateFields {
|
||||||
status: if turn_state.borrow().is_canceled() {
|
status: if turn_state.borrow().is_canceled() {
|
||||||
// Do not set to completed if turn was canceled
|
// Do not set to completed if turn was canceled
|
||||||
|
@ -480,7 +489,7 @@ impl ClaudeAgentSession {
|
||||||
} else {
|
} else {
|
||||||
Some(acp::ToolCallStatus::Completed)
|
Some(acp::ToolCallStatus::Completed)
|
||||||
},
|
},
|
||||||
content: (!content.is_empty())
|
content: set_new_content
|
||||||
.then(|| vec![content.into()]),
|
.then(|| vec![content.into()]),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
|
|
178
crates/agent_servers/src/claude/edit_tool.rs
Normal file
178
crates/agent_servers/src/claude/edit_tool.rs
Normal file
|
@ -0,0 +1,178 @@
|
||||||
|
use acp_thread::AcpThread;
|
||||||
|
use anyhow::Result;
|
||||||
|
use context_server::{
|
||||||
|
listener::{McpServerTool, ToolResponse},
|
||||||
|
types::{ToolAnnotations, ToolResponseContent},
|
||||||
|
};
|
||||||
|
use gpui::{AsyncApp, WeakEntity};
|
||||||
|
use language::unified_diff;
|
||||||
|
use util::markdown::MarkdownCodeBlock;
|
||||||
|
|
||||||
|
use crate::tools::EditToolParams;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct EditTool {
|
||||||
|
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EditTool {
|
||||||
|
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
|
||||||
|
Self { thread_rx }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerTool for EditTool {
|
||||||
|
type Input = EditToolParams;
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
const NAME: &'static str = "Edit";
|
||||||
|
|
||||||
|
fn annotations(&self) -> ToolAnnotations {
|
||||||
|
ToolAnnotations {
|
||||||
|
title: Some("Edit file".to_string()),
|
||||||
|
read_only_hint: Some(false),
|
||||||
|
destructive_hint: Some(false),
|
||||||
|
open_world_hint: Some(false),
|
||||||
|
idempotent_hint: Some(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
input: Self::Input,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<ToolResponse<Self::Output>> {
|
||||||
|
let mut thread_rx = self.thread_rx.clone();
|
||||||
|
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||||
|
anyhow::bail!("Thread closed");
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
|
||||||
|
})?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (new_content, diff) = cx
|
||||||
|
.background_executor()
|
||||||
|
.spawn(async move {
|
||||||
|
let new_content = content.replace(&input.old_text, &input.new_text);
|
||||||
|
if new_content == content {
|
||||||
|
return Err(anyhow::anyhow!("Failed to find `old_text`",));
|
||||||
|
}
|
||||||
|
let diff = unified_diff(&content, &new_content);
|
||||||
|
|
||||||
|
Ok((new_content, diff))
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.write_text_file(input.abs_path, new_content, cx)
|
||||||
|
})?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(ToolResponse {
|
||||||
|
content: vec![ToolResponseContent::Text {
|
||||||
|
text: MarkdownCodeBlock {
|
||||||
|
tag: "diff",
|
||||||
|
text: diff.as_str().trim_end_matches('\n'),
|
||||||
|
}
|
||||||
|
.to_string(),
|
||||||
|
}],
|
||||||
|
structured_content: (),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
use acp_thread::{AgentConnection, StubAgentConnection};
|
||||||
|
use gpui::{Entity, TestAppContext};
|
||||||
|
use indoc::indoc;
|
||||||
|
use project::{FakeFs, Project};
|
||||||
|
use serde_json::json;
|
||||||
|
use settings::SettingsStore;
|
||||||
|
use util::path;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn old_text_not_found(cx: &mut TestAppContext) {
|
||||||
|
let (_thread, tool) = init_test(cx).await;
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.run(
|
||||||
|
EditToolParams {
|
||||||
|
abs_path: path!("/root/file.txt").into(),
|
||||||
|
old_text: "hi".into(),
|
||||||
|
new_text: "bye".into(),
|
||||||
|
},
|
||||||
|
&mut cx.to_async(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(result.unwrap_err().to_string(), "Failed to find `old_text`");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn found_and_replaced(cx: &mut TestAppContext) {
|
||||||
|
let (_thread, tool) = init_test(cx).await;
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.run(
|
||||||
|
EditToolParams {
|
||||||
|
abs_path: path!("/root/file.txt").into(),
|
||||||
|
old_text: "hello".into(),
|
||||||
|
new_text: "hi".into(),
|
||||||
|
},
|
||||||
|
&mut cx.to_async(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
result.unwrap().content[0].text().unwrap(),
|
||||||
|
indoc! {
|
||||||
|
r"
|
||||||
|
```diff
|
||||||
|
@@ -1,1 +1,1 @@
|
||||||
|
-hello
|
||||||
|
+hi
|
||||||
|
```
|
||||||
|
"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn init_test(cx: &mut TestAppContext) -> (Entity<AcpThread>, EditTool) {
|
||||||
|
cx.update(|cx| {
|
||||||
|
let settings_store = SettingsStore::test(cx);
|
||||||
|
cx.set_global(settings_store);
|
||||||
|
language::init(cx);
|
||||||
|
Project::init_settings(cx);
|
||||||
|
});
|
||||||
|
|
||||||
|
let connection = Rc::new(StubAgentConnection::new());
|
||||||
|
let fs = FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree(
|
||||||
|
path!("/root"),
|
||||||
|
json!({
|
||||||
|
"file.txt": "hello"
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||||
|
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||||
|
|
||||||
|
let thread = cx
|
||||||
|
.update(|cx| connection.new_thread(project, path!("/test").as_ref(), cx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
thread_tx.send(thread.downgrade()).unwrap();
|
||||||
|
|
||||||
|
(thread, EditTool::new(thread_rx))
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,23 +1,22 @@
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
use crate::claude::edit_tool::EditTool;
|
||||||
|
use crate::claude::permission_tool::PermissionTool;
|
||||||
|
use crate::claude::read_tool::ReadTool;
|
||||||
|
use crate::claude::write_tool::WriteTool;
|
||||||
use acp_thread::AcpThread;
|
use acp_thread::AcpThread;
|
||||||
use agent_client_protocol as acp;
|
#[cfg(not(test))]
|
||||||
use agent_settings::AgentSettings;
|
use anyhow::Context as _;
|
||||||
use anyhow::{Context, Result};
|
use anyhow::Result;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use context_server::listener::{McpServerTool, ToolResponse};
|
|
||||||
use context_server::types::{
|
use context_server::types::{
|
||||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||||
ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests,
|
ToolsCapabilities, requests,
|
||||||
};
|
};
|
||||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||||
use project::Fs;
|
use project::Fs;
|
||||||
use schemars::JsonSchema;
|
use serde::Serialize;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use settings::{Settings as _, update_settings_file};
|
|
||||||
use util::debug_panic;
|
|
||||||
|
|
||||||
pub struct ClaudeZedMcpServer {
|
pub struct ClaudeZedMcpServer {
|
||||||
server: context_server::listener::McpServer,
|
server: context_server::listener::McpServer,
|
||||||
|
@ -34,16 +33,10 @@ impl ClaudeZedMcpServer {
|
||||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||||
|
|
||||||
mcp_server.add_tool(PermissionTool {
|
mcp_server.add_tool(PermissionTool::new(fs.clone(), thread_rx.clone()));
|
||||||
thread_rx: thread_rx.clone(),
|
mcp_server.add_tool(ReadTool::new(thread_rx.clone()));
|
||||||
fs: fs.clone(),
|
mcp_server.add_tool(EditTool::new(thread_rx.clone()));
|
||||||
});
|
mcp_server.add_tool(WriteTool::new(thread_rx.clone()));
|
||||||
mcp_server.add_tool(ReadTool {
|
|
||||||
thread_rx: thread_rx.clone(),
|
|
||||||
});
|
|
||||||
mcp_server.add_tool(EditTool {
|
|
||||||
thread_rx: thread_rx.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(Self { server: mcp_server })
|
Ok(Self { server: mcp_server })
|
||||||
}
|
}
|
||||||
|
@ -104,249 +97,3 @@ pub struct McpServerConfig {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub env: Option<HashMap<String, String>>,
|
pub env: Option<HashMap<String, String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tools
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct PermissionTool {
|
|
||||||
fs: Arc<dyn Fs>,
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, JsonSchema, Debug)]
|
|
||||||
pub struct PermissionToolParams {
|
|
||||||
tool_name: String,
|
|
||||||
input: serde_json::Value,
|
|
||||||
tool_use_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
pub struct PermissionToolResponse {
|
|
||||||
behavior: PermissionToolBehavior,
|
|
||||||
updated_input: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
enum PermissionToolBehavior {
|
|
||||||
Allow,
|
|
||||||
Deny,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpServerTool for PermissionTool {
|
|
||||||
type Input = PermissionToolParams;
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
const NAME: &'static str = "Confirmation";
|
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
|
||||||
"Request permission for tool calls"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
input: Self::Input,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<ToolResponse<Self::Output>> {
|
|
||||||
if agent_settings::AgentSettings::try_read_global(cx, |settings| {
|
|
||||||
settings.always_allow_tool_actions
|
|
||||||
})
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
let response = PermissionToolResponse {
|
|
||||||
behavior: PermissionToolBehavior::Allow,
|
|
||||||
updated_input: input.input,
|
|
||||||
};
|
|
||||||
|
|
||||||
return Ok(ToolResponse {
|
|
||||||
content: vec![ToolResponseContent::Text {
|
|
||||||
text: serde_json::to_string(&response)?,
|
|
||||||
}],
|
|
||||||
structured_content: (),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut thread_rx = self.thread_rx.clone();
|
|
||||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
|
||||||
anyhow::bail!("Thread closed");
|
|
||||||
};
|
|
||||||
|
|
||||||
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
|
|
||||||
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
|
|
||||||
|
|
||||||
const ALWAYS_ALLOW: &str = "always_allow";
|
|
||||||
const ALLOW: &str = "allow";
|
|
||||||
const REJECT: &str = "reject";
|
|
||||||
|
|
||||||
let chosen_option = thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.request_tool_call_authorization(
|
|
||||||
claude_tool.as_acp(tool_call_id).into(),
|
|
||||||
vec![
|
|
||||||
acp::PermissionOption {
|
|
||||||
id: acp::PermissionOptionId(ALWAYS_ALLOW.into()),
|
|
||||||
name: "Always Allow".into(),
|
|
||||||
kind: acp::PermissionOptionKind::AllowAlways,
|
|
||||||
},
|
|
||||||
acp::PermissionOption {
|
|
||||||
id: acp::PermissionOptionId(ALLOW.into()),
|
|
||||||
name: "Allow".into(),
|
|
||||||
kind: acp::PermissionOptionKind::AllowOnce,
|
|
||||||
},
|
|
||||||
acp::PermissionOption {
|
|
||||||
id: acp::PermissionOptionId(REJECT.into()),
|
|
||||||
name: "Reject".into(),
|
|
||||||
kind: acp::PermissionOptionKind::RejectOnce,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})??
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let response = match chosen_option.0.as_ref() {
|
|
||||||
ALWAYS_ALLOW => {
|
|
||||||
cx.update(|cx| {
|
|
||||||
update_settings_file::<AgentSettings>(self.fs.clone(), cx, |settings, _| {
|
|
||||||
settings.set_always_allow_tool_actions(true);
|
|
||||||
});
|
|
||||||
})?;
|
|
||||||
|
|
||||||
PermissionToolResponse {
|
|
||||||
behavior: PermissionToolBehavior::Allow,
|
|
||||||
updated_input: input.input,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ALLOW => PermissionToolResponse {
|
|
||||||
behavior: PermissionToolBehavior::Allow,
|
|
||||||
updated_input: input.input,
|
|
||||||
},
|
|
||||||
REJECT => PermissionToolResponse {
|
|
||||||
behavior: PermissionToolBehavior::Deny,
|
|
||||||
updated_input: input.input,
|
|
||||||
},
|
|
||||||
opt => {
|
|
||||||
debug_panic!("Unexpected option: {}", opt);
|
|
||||||
PermissionToolResponse {
|
|
||||||
behavior: PermissionToolBehavior::Deny,
|
|
||||||
updated_input: input.input,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ToolResponse {
|
|
||||||
content: vec![ToolResponseContent::Text {
|
|
||||||
text: serde_json::to_string(&response)?,
|
|
||||||
}],
|
|
||||||
structured_content: (),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct ReadTool {
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpServerTool for ReadTool {
|
|
||||||
type Input = ReadToolParams;
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
const NAME: &'static str = "Read";
|
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
|
||||||
"Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn annotations(&self) -> ToolAnnotations {
|
|
||||||
ToolAnnotations {
|
|
||||||
title: Some("Read file".to_string()),
|
|
||||||
read_only_hint: Some(true),
|
|
||||||
destructive_hint: Some(false),
|
|
||||||
open_world_hint: Some(false),
|
|
||||||
idempotent_hint: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
input: Self::Input,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<ToolResponse<Self::Output>> {
|
|
||||||
let mut thread_rx = self.thread_rx.clone();
|
|
||||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
|
||||||
anyhow::bail!("Thread closed");
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(ToolResponse {
|
|
||||||
content: vec![ToolResponseContent::Text { text: content }],
|
|
||||||
structured_content: (),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct EditTool {
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpServerTool for EditTool {
|
|
||||||
type Input = EditToolParams;
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
const NAME: &'static str = "Edit";
|
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
|
||||||
"Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn annotations(&self) -> ToolAnnotations {
|
|
||||||
ToolAnnotations {
|
|
||||||
title: Some("Edit file".to_string()),
|
|
||||||
read_only_hint: Some(false),
|
|
||||||
destructive_hint: Some(false),
|
|
||||||
open_world_hint: Some(false),
|
|
||||||
idempotent_hint: Some(false),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
input: Self::Input,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<ToolResponse<Self::Output>> {
|
|
||||||
let mut thread_rx = self.thread_rx.clone();
|
|
||||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
|
||||||
anyhow::bail!("Thread closed");
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let new_content = content.replace(&input.old_text, &input.new_text);
|
|
||||||
if new_content == content {
|
|
||||||
return Err(anyhow::anyhow!("The old_text was not found in the content"));
|
|
||||||
}
|
|
||||||
|
|
||||||
thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.write_text_file(input.abs_path, new_content, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(ToolResponse {
|
|
||||||
content: vec![],
|
|
||||||
structured_content: (),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
158
crates/agent_servers/src/claude/permission_tool.rs
Normal file
158
crates/agent_servers/src/claude/permission_tool.rs
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use acp_thread::AcpThread;
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use agent_settings::AgentSettings;
|
||||||
|
use anyhow::{Context as _, Result};
|
||||||
|
use context_server::{
|
||||||
|
listener::{McpServerTool, ToolResponse},
|
||||||
|
types::ToolResponseContent,
|
||||||
|
};
|
||||||
|
use gpui::{AsyncApp, WeakEntity};
|
||||||
|
use project::Fs;
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use settings::{Settings as _, update_settings_file};
|
||||||
|
use util::debug_panic;
|
||||||
|
|
||||||
|
use crate::tools::ClaudeTool;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct PermissionTool {
|
||||||
|
fs: Arc<dyn Fs>,
|
||||||
|
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request permission for tool calls
|
||||||
|
#[derive(Deserialize, JsonSchema, Debug)]
|
||||||
|
pub struct PermissionToolParams {
|
||||||
|
tool_name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
tool_use_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PermissionToolResponse {
|
||||||
|
behavior: PermissionToolBehavior,
|
||||||
|
updated_input: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
enum PermissionToolBehavior {
|
||||||
|
Allow,
|
||||||
|
Deny,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PermissionTool {
|
||||||
|
pub fn new(fs: Arc<dyn Fs>, thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
|
||||||
|
Self { fs, thread_rx }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerTool for PermissionTool {
|
||||||
|
type Input = PermissionToolParams;
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
const NAME: &'static str = "Confirmation";
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
input: Self::Input,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<ToolResponse<Self::Output>> {
|
||||||
|
if agent_settings::AgentSettings::try_read_global(cx, |settings| {
|
||||||
|
settings.always_allow_tool_actions
|
||||||
|
})
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let response = PermissionToolResponse {
|
||||||
|
behavior: PermissionToolBehavior::Allow,
|
||||||
|
updated_input: input.input,
|
||||||
|
};
|
||||||
|
|
||||||
|
return Ok(ToolResponse {
|
||||||
|
content: vec![ToolResponseContent::Text {
|
||||||
|
text: serde_json::to_string(&response)?,
|
||||||
|
}],
|
||||||
|
structured_content: (),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut thread_rx = self.thread_rx.clone();
|
||||||
|
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||||
|
anyhow::bail!("Thread closed");
|
||||||
|
};
|
||||||
|
|
||||||
|
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
|
||||||
|
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
|
||||||
|
|
||||||
|
const ALWAYS_ALLOW: &str = "always_allow";
|
||||||
|
const ALLOW: &str = "allow";
|
||||||
|
const REJECT: &str = "reject";
|
||||||
|
|
||||||
|
let chosen_option = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.request_tool_call_authorization(
|
||||||
|
claude_tool.as_acp(tool_call_id).into(),
|
||||||
|
vec![
|
||||||
|
acp::PermissionOption {
|
||||||
|
id: acp::PermissionOptionId(ALWAYS_ALLOW.into()),
|
||||||
|
name: "Always Allow".into(),
|
||||||
|
kind: acp::PermissionOptionKind::AllowAlways,
|
||||||
|
},
|
||||||
|
acp::PermissionOption {
|
||||||
|
id: acp::PermissionOptionId(ALLOW.into()),
|
||||||
|
name: "Allow".into(),
|
||||||
|
kind: acp::PermissionOptionKind::AllowOnce,
|
||||||
|
},
|
||||||
|
acp::PermissionOption {
|
||||||
|
id: acp::PermissionOptionId(REJECT.into()),
|
||||||
|
name: "Reject".into(),
|
||||||
|
kind: acp::PermissionOptionKind::RejectOnce,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})??
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let response = match chosen_option.0.as_ref() {
|
||||||
|
ALWAYS_ALLOW => {
|
||||||
|
cx.update(|cx| {
|
||||||
|
update_settings_file::<AgentSettings>(self.fs.clone(), cx, |settings, _| {
|
||||||
|
settings.set_always_allow_tool_actions(true);
|
||||||
|
});
|
||||||
|
})?;
|
||||||
|
|
||||||
|
PermissionToolResponse {
|
||||||
|
behavior: PermissionToolBehavior::Allow,
|
||||||
|
updated_input: input.input,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ALLOW => PermissionToolResponse {
|
||||||
|
behavior: PermissionToolBehavior::Allow,
|
||||||
|
updated_input: input.input,
|
||||||
|
},
|
||||||
|
REJECT => PermissionToolResponse {
|
||||||
|
behavior: PermissionToolBehavior::Deny,
|
||||||
|
updated_input: input.input,
|
||||||
|
},
|
||||||
|
opt => {
|
||||||
|
debug_panic!("Unexpected option: {}", opt);
|
||||||
|
PermissionToolResponse {
|
||||||
|
behavior: PermissionToolBehavior::Deny,
|
||||||
|
updated_input: input.input,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ToolResponse {
|
||||||
|
content: vec![ToolResponseContent::Text {
|
||||||
|
text: serde_json::to_string(&response)?,
|
||||||
|
}],
|
||||||
|
structured_content: (),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
59
crates/agent_servers/src/claude/read_tool.rs
Normal file
59
crates/agent_servers/src/claude/read_tool.rs
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
use acp_thread::AcpThread;
|
||||||
|
use anyhow::Result;
|
||||||
|
use context_server::{
|
||||||
|
listener::{McpServerTool, ToolResponse},
|
||||||
|
types::{ToolAnnotations, ToolResponseContent},
|
||||||
|
};
|
||||||
|
use gpui::{AsyncApp, WeakEntity};
|
||||||
|
|
||||||
|
use crate::tools::ReadToolParams;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ReadTool {
|
||||||
|
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReadTool {
|
||||||
|
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
|
||||||
|
Self { thread_rx }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerTool for ReadTool {
|
||||||
|
type Input = ReadToolParams;
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
const NAME: &'static str = "Read";
|
||||||
|
|
||||||
|
fn annotations(&self) -> ToolAnnotations {
|
||||||
|
ToolAnnotations {
|
||||||
|
title: Some("Read file".to_string()),
|
||||||
|
read_only_hint: Some(true),
|
||||||
|
destructive_hint: Some(false),
|
||||||
|
open_world_hint: Some(false),
|
||||||
|
idempotent_hint: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
input: Self::Input,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<ToolResponse<Self::Output>> {
|
||||||
|
let mut thread_rx = self.thread_rx.clone();
|
||||||
|
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||||
|
anyhow::bail!("Thread closed");
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
|
||||||
|
})?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(ToolResponse {
|
||||||
|
content: vec![ToolResponseContent::Text { text: content }],
|
||||||
|
structured_content: (),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -34,6 +34,7 @@ impl ClaudeTool {
|
||||||
// Known tools
|
// Known tools
|
||||||
"mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
|
"mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
|
||||||
"mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
|
"mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
|
||||||
|
"mcp__zed__Write" => Self::Write(serde_json::from_value(input).log_err()),
|
||||||
"MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
|
"MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
|
||||||
"Write" => Self::Write(serde_json::from_value(input).log_err()),
|
"Write" => Self::Write(serde_json::from_value(input).log_err()),
|
||||||
"LS" => Self::Ls(serde_json::from_value(input).log_err()),
|
"LS" => Self::Ls(serde_json::from_value(input).log_err()),
|
||||||
|
@ -93,7 +94,7 @@ impl ClaudeTool {
|
||||||
}
|
}
|
||||||
Self::MultiEdit(None) => "Multi Edit".into(),
|
Self::MultiEdit(None) => "Multi Edit".into(),
|
||||||
Self::Write(Some(params)) => {
|
Self::Write(Some(params)) => {
|
||||||
format!("Write {}", params.file_path.display())
|
format!("Write {}", params.abs_path.display())
|
||||||
}
|
}
|
||||||
Self::Write(None) => "Write".into(),
|
Self::Write(None) => "Write".into(),
|
||||||
Self::Glob(Some(params)) => {
|
Self::Glob(Some(params)) => {
|
||||||
|
@ -153,7 +154,7 @@ impl ClaudeTool {
|
||||||
}],
|
}],
|
||||||
Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
|
Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
|
||||||
diff: acp::Diff {
|
diff: acp::Diff {
|
||||||
path: params.file_path.clone(),
|
path: params.abs_path.clone(),
|
||||||
old_text: None,
|
old_text: None,
|
||||||
new_text: params.content.clone(),
|
new_text: params.content.clone(),
|
||||||
},
|
},
|
||||||
|
@ -229,7 +230,10 @@ impl ClaudeTool {
|
||||||
line: None,
|
line: None,
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
Self::Write(Some(WriteToolParams { file_path, .. })) => {
|
Self::Write(Some(WriteToolParams {
|
||||||
|
abs_path: file_path,
|
||||||
|
..
|
||||||
|
})) => {
|
||||||
vec![acp::ToolCallLocation {
|
vec![acp::ToolCallLocation {
|
||||||
path: file_path.clone(),
|
path: file_path.clone(),
|
||||||
line: None,
|
line: None,
|
||||||
|
@ -302,6 +306,20 @@ impl ClaudeTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Edit a file.
|
||||||
|
///
|
||||||
|
/// In sessions with mcp__zed__Edit always use it instead of Edit as it will
|
||||||
|
/// allow the user to conveniently review changes.
|
||||||
|
///
|
||||||
|
/// File editing instructions:
|
||||||
|
/// - The `old_text` param must match existing file content, including indentation.
|
||||||
|
/// - The `old_text` param must come from the actual file, not an outline.
|
||||||
|
/// - The `old_text` section must not be empty.
|
||||||
|
/// - Be minimal with replacements:
|
||||||
|
/// - For unique lines, include only those lines.
|
||||||
|
/// - For non-unique lines, include enough context to identify them.
|
||||||
|
/// - Do not escape quotes, newlines, or other characters.
|
||||||
|
/// - Only edit the specified file.
|
||||||
#[derive(Deserialize, JsonSchema, Debug)]
|
#[derive(Deserialize, JsonSchema, Debug)]
|
||||||
pub struct EditToolParams {
|
pub struct EditToolParams {
|
||||||
/// The absolute path to the file to read.
|
/// The absolute path to the file to read.
|
||||||
|
@ -312,6 +330,11 @@ pub struct EditToolParams {
|
||||||
pub new_text: String,
|
pub new_text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Reads the content of the given file in the project.
|
||||||
|
///
|
||||||
|
/// Never attempt to read a path that hasn't been previously mentioned.
|
||||||
|
///
|
||||||
|
/// In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.
|
||||||
#[derive(Deserialize, JsonSchema, Debug)]
|
#[derive(Deserialize, JsonSchema, Debug)]
|
||||||
pub struct ReadToolParams {
|
pub struct ReadToolParams {
|
||||||
/// The absolute path to the file to read.
|
/// The absolute path to the file to read.
|
||||||
|
@ -324,11 +347,15 @@ pub struct ReadToolParams {
|
||||||
pub limit: Option<u32>,
|
pub limit: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Writes content to the specified file in the project.
|
||||||
|
///
|
||||||
|
/// In sessions with mcp__zed__Write always use it instead of Write as it will
|
||||||
|
/// allow the user to conveniently review changes.
|
||||||
#[derive(Deserialize, JsonSchema, Debug)]
|
#[derive(Deserialize, JsonSchema, Debug)]
|
||||||
pub struct WriteToolParams {
|
pub struct WriteToolParams {
|
||||||
/// Absolute path for new file
|
/// The absolute path of the file to write.
|
||||||
pub file_path: PathBuf,
|
pub abs_path: PathBuf,
|
||||||
/// File content
|
/// The full content to write.
|
||||||
pub content: String,
|
pub content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
59
crates/agent_servers/src/claude/write_tool.rs
Normal file
59
crates/agent_servers/src/claude/write_tool.rs
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
use acp_thread::AcpThread;
|
||||||
|
use anyhow::Result;
|
||||||
|
use context_server::{
|
||||||
|
listener::{McpServerTool, ToolResponse},
|
||||||
|
types::ToolAnnotations,
|
||||||
|
};
|
||||||
|
use gpui::{AsyncApp, WeakEntity};
|
||||||
|
|
||||||
|
use crate::tools::WriteToolParams;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WriteTool {
|
||||||
|
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WriteTool {
|
||||||
|
pub fn new(thread_rx: watch::Receiver<WeakEntity<AcpThread>>) -> Self {
|
||||||
|
Self { thread_rx }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpServerTool for WriteTool {
|
||||||
|
type Input = WriteToolParams;
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
const NAME: &'static str = "Write";
|
||||||
|
|
||||||
|
fn annotations(&self) -> ToolAnnotations {
|
||||||
|
ToolAnnotations {
|
||||||
|
title: Some("Write file".to_string()),
|
||||||
|
read_only_hint: Some(false),
|
||||||
|
destructive_hint: Some(false),
|
||||||
|
open_world_hint: Some(false),
|
||||||
|
idempotent_hint: Some(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
input: Self::Input,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<ToolResponse<Self::Output>> {
|
||||||
|
let mut thread_rx = self.thread_rx.clone();
|
||||||
|
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||||
|
anyhow::bail!("Thread closed");
|
||||||
|
};
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.write_text_file(input.abs_path, input.content, cx)
|
||||||
|
})?
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(ToolResponse {
|
||||||
|
content: vec![],
|
||||||
|
structured_content: (),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,6 +14,7 @@ use serde::de::DeserializeOwned;
|
||||||
use serde_json::{json, value::RawValue};
|
use serde_json::{json, value::RawValue};
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::{
|
use std::{
|
||||||
|
any::TypeId,
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
|
@ -87,18 +88,26 @@ impl McpServer {
|
||||||
settings.inline_subschemas = true;
|
settings.inline_subschemas = true;
|
||||||
let mut generator = settings.into_generator();
|
let mut generator = settings.into_generator();
|
||||||
|
|
||||||
let output_schema = generator.root_schema_for::<T::Output>();
|
let input_schema = generator.root_schema_for::<T::Input>();
|
||||||
let unit_schema = generator.root_schema_for::<T::Output>();
|
|
||||||
|
let description = input_schema
|
||||||
|
.get("description")
|
||||||
|
.and_then(|desc| desc.as_str())
|
||||||
|
.map(|desc| desc.to_string());
|
||||||
|
debug_assert!(
|
||||||
|
description.is_some(),
|
||||||
|
"Input schema struct must include a doc comment for the tool description"
|
||||||
|
);
|
||||||
|
|
||||||
let registered_tool = RegisteredTool {
|
let registered_tool = RegisteredTool {
|
||||||
tool: Tool {
|
tool: Tool {
|
||||||
name: T::NAME.into(),
|
name: T::NAME.into(),
|
||||||
description: Some(tool.description().into()),
|
description,
|
||||||
input_schema: generator.root_schema_for::<T::Input>().into(),
|
input_schema: input_schema.into(),
|
||||||
output_schema: if output_schema == unit_schema {
|
output_schema: if TypeId::of::<T::Output>() == TypeId::of::<()>() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(output_schema.into())
|
Some(generator.root_schema_for::<T::Output>().into())
|
||||||
},
|
},
|
||||||
annotations: Some(tool.annotations()),
|
annotations: Some(tool.annotations()),
|
||||||
},
|
},
|
||||||
|
@ -399,8 +408,6 @@ pub trait McpServerTool {
|
||||||
|
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
|
|
||||||
fn description(&self) -> &'static str;
|
|
||||||
|
|
||||||
fn annotations(&self) -> ToolAnnotations {
|
fn annotations(&self) -> ToolAnnotations {
|
||||||
ToolAnnotations {
|
ToolAnnotations {
|
||||||
title: None,
|
title: None,
|
||||||
|
@ -418,6 +425,7 @@ pub trait McpServerTool {
|
||||||
) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
|
) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct ToolResponse<T> {
|
pub struct ToolResponse<T> {
|
||||||
pub content: Vec<ToolResponseContent>,
|
pub content: Vec<ToolResponseContent>,
|
||||||
pub structured_content: T,
|
pub structured_content: T,
|
||||||
|
|
|
@ -711,6 +711,16 @@ pub enum ToolResponseContent {
|
||||||
Resource { resource: ResourceContents },
|
Resource { resource: ResourceContents },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToolResponseContent {
|
||||||
|
pub fn text(&self) -> Option<&str> {
|
||||||
|
if let ToolResponseContent::Text { text } = self {
|
||||||
|
Some(text)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ListToolsResponse {
|
pub struct ListToolsResponse {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue