Merge branch 'main' into mention-more
This commit is contained in:
commit
cf8e056ec4
28 changed files with 1264 additions and 260 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -197,6 +197,7 @@ dependencies = [
|
|||
"clock",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
|
@ -18026,6 +18027,7 @@ dependencies = [
|
|||
"command_palette_hooks",
|
||||
"db",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"futures 0.3.31",
|
||||
"git_ui",
|
||||
"gpui",
|
||||
|
@ -20923,6 +20925,7 @@ dependencies = [
|
|||
"menu",
|
||||
"postage",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
|
|
|
@ -333,10 +333,14 @@
|
|||
"ctrl-x ctrl-c": "editor::ShowEditPrediction", // zed specific
|
||||
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
|
||||
"ctrl-x ctrl-z": "editor::Cancel",
|
||||
"ctrl-x ctrl-e": "vim::LineDown",
|
||||
"ctrl-x ctrl-y": "vim::LineUp",
|
||||
"ctrl-w": "editor::DeleteToPreviousWordStart",
|
||||
"ctrl-u": "editor::DeleteToBeginningOfLine",
|
||||
"ctrl-t": "vim::Indent",
|
||||
"ctrl-d": "vim::Outdent",
|
||||
"ctrl-y": "vim::InsertFromAbove",
|
||||
"ctrl-e": "vim::InsertFromBelow",
|
||||
"ctrl-k": ["vim::PushDigraph", {}],
|
||||
"ctrl-v": ["vim::PushLiteral", {}],
|
||||
"ctrl-shift-v": "editor::Paste", // note: this is *very* similar to ctrl-v in vim, but ctrl-shift-v on linux is the typical shortcut for paste when ctrl-v is already in use.
|
||||
|
|
|
@ -219,6 +219,15 @@ impl ToolCall {
|
|||
}
|
||||
|
||||
if let Some(raw_output) = raw_output {
|
||||
if self.content.is_empty() {
|
||||
if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
|
||||
{
|
||||
self.content
|
||||
.push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
|
||||
markdown,
|
||||
}));
|
||||
}
|
||||
}
|
||||
self.raw_output = Some(raw_output);
|
||||
}
|
||||
}
|
||||
|
@ -1239,6 +1248,48 @@ impl AcpThread {
|
|||
}
|
||||
}
|
||||
|
||||
fn markdown_for_raw_output(
|
||||
raw_output: &serde_json::Value,
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) -> Option<Entity<Markdown>> {
|
||||
match raw_output {
|
||||
serde_json::Value::Null => None,
|
||||
serde_json::Value::Bool(value) => Some(cx.new(|cx| {
|
||||
Markdown::new(
|
||||
value.to_string().into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
serde_json::Value::Number(value) => Some(cx.new(|cx| {
|
||||
Markdown::new(
|
||||
value.to_string().into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
serde_json::Value::String(value) => Some(cx.new(|cx| {
|
||||
Markdown::new(
|
||||
value.clone().into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
value => Some(cx.new(|cx| {
|
||||
Markdown::new(
|
||||
format!("```json\n{}\n```", value).into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -23,6 +23,7 @@ assistant_tools.workspace = true
|
|||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
|
@ -60,6 +61,7 @@ workspace-hack.workspace = true
|
|||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
clock = { workspace = true, "features" = ["test-support"] }
|
||||
context_server = { workspace = true, "features" = ["test-support"] }
|
||||
editor = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||
use crate::{
|
||||
CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool,
|
||||
GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, OpenTool, ReadFileTool,
|
||||
TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
||||
};
|
||||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
|
@ -55,6 +55,7 @@ pub struct NativeAgent {
|
|||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
_maintain_project_context: Task<Result<()>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
project: Entity<Project>,
|
||||
|
@ -90,6 +91,9 @@ impl NativeAgent {
|
|||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
||||
}),
|
||||
context_server_registry: cx.new(|cx| {
|
||||
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||
}),
|
||||
templates,
|
||||
project,
|
||||
prompt_store,
|
||||
|
@ -385,7 +389,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||
|
@ -413,11 +423,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
anyhow!(
|
||||
"No default model. Please configure a default model in settings."
|
||||
)
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
||||
let mut thread = Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
|
@ -450,7 +470,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
})
|
||||
}),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
@ -496,14 +516,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
// Convert prompt to message
|
||||
let message: Vec<MessageContent> = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>();
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
// log::debug!("Message content: {}", message);
|
||||
log::debug!("Message content: {:?}", message);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::MessageContent;
|
|||
use acp_thread::AgentConnection;
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_settings::AgentProfileId;
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use fs::{FakeFs, Fs};
|
||||
|
@ -166,7 +167,9 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
|||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}),
|
||||
"{}",
|
||||
thread.to_markdown()
|
||||
);
|
||||
});
|
||||
}
|
||||
|
@ -474,6 +477,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_profiles(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model, thread, fs, ..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.add_tool(InfiniteTool);
|
||||
});
|
||||
|
||||
// Override profiles and wait for settings to be loaded.
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test-1": {
|
||||
"name": "Test Profile 1",
|
||||
"tools": {
|
||||
EchoTool.name(): true,
|
||||
DelayTool.name(): true,
|
||||
}
|
||||
},
|
||||
"test-2": {
|
||||
"name": "Test Profile 2",
|
||||
"tools": {
|
||||
InfiniteTool.name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
|
||||
// Test that test-1 profile (default) has echo and delay tools
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-1".into()));
|
||||
thread.send("test", cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(pending_completions.len(), 1);
|
||||
let completion = pending_completions.pop().unwrap();
|
||||
let tool_names: Vec<String> = completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect();
|
||||
assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-2".into()));
|
||||
thread.send("test2", cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(pending_completions.len(), 1);
|
||||
let completion = pending_completions.pop().unwrap();
|
||||
let tool_names: Vec<String> = completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect();
|
||||
assert_eq!(tool_names, vec![InfiniteTool.name()]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
|
@ -600,6 +679,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
agent_settings::init(cx);
|
||||
});
|
||||
cx.executor().forbid_parking();
|
||||
|
||||
|
@ -795,6 +875,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
raw_output: Some("Finished thinking.".into()),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
|
@ -818,6 +899,7 @@ struct ThreadTest {
|
|||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
fs: Arc<FakeFs>,
|
||||
}
|
||||
|
||||
enum TestModel {
|
||||
|
@ -840,30 +922,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"default_profile": "test-profile",
|
||||
"profiles": {
|
||||
"test-profile": {
|
||||
"name": "Test Profile",
|
||||
"tools": {
|
||||
EchoTool.name(): true,
|
||||
DelayTool.name(): true,
|
||||
WordListTool.name(): true,
|
||||
ToolRequiringPermission.name(): true,
|
||||
InfiniteTool.name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
|
||||
cx.update(|cx| {
|
||||
settings::init(cx);
|
||||
watch_settings(fs.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
agent_settings::init(cx);
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
watch_settings(fs.clone(), cx);
|
||||
});
|
||||
|
||||
let templates = Templates::new();
|
||||
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
if let TestModel::Fake = model {
|
||||
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
||||
} else {
|
||||
|
@ -886,20 +995,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
.await;
|
||||
|
||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
project_context.clone(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
templates,
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
fs,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::{SystemPromptTemplate, Template, Templates};
|
||||
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||
use acp_thread::MentionUri;
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use agent_settings::{AgentProfileId, AgentSettings};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
|
@ -148,6 +148,8 @@ pub struct Thread {
|
|||
running_turn: Option<Task<()>>,
|
||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
|
@ -159,16 +161,21 @@ impl Thread {
|
|||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
default_model: Arc<dyn LanguageModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
pending_tool_uses: HashMap::default(),
|
||||
tools: BTreeMap::default(),
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
|
@ -201,6 +208,10 @@ impl Thread {
|
|||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||
self.profile_id = profile_id;
|
||||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
self.running_turn.take();
|
||||
|
||||
|
@ -321,6 +332,7 @@ impl Thread {
|
|||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
@ -627,21 +639,23 @@ impl Thread {
|
|||
let messages = self.build_request_messages();
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools: Vec<LanguageModelRequestTool> = self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.log_err()?,
|
||||
let tools = if let Some(tools) = self.tools(cx).log_err() {
|
||||
tools
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
|
@ -662,6 +676,35 @@ impl Thread {
|
|||
request
|
||||
}
|
||||
|
||||
fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
|
||||
let profile = AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
|
||||
Ok(self
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool_name, tool)| {
|
||||
if profile.is_tool_enabled(tool_name) {
|
||||
Some(tool)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(self.context_server_registry.read(cx).servers().flat_map(
|
||||
|(server_id, tools)| {
|
||||
tools.iter().filter_map(|(tool_name, tool)| {
|
||||
if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
|
||||
Some(tool)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
|
@ -719,7 +762,7 @@ where
|
|||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
fn description(&self) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
|
@ -755,13 +798,13 @@ where
|
|||
pub struct Erased<T>(T);
|
||||
|
||||
pub struct AgentToolOutput {
|
||||
llm_output: LanguageModelToolResultContent,
|
||||
raw_output: serde_json::Value,
|
||||
pub llm_output: LanguageModelToolResultContent,
|
||||
pub raw_output: serde_json::Value,
|
||||
}
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn description(&self) -> SharedString;
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
fn initial_title(&self, input: serde_json::Value) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
|
@ -781,8 +824,8 @@ where
|
|||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
fn description(&self) -> SharedString {
|
||||
self.0.description()
|
||||
}
|
||||
|
||||
fn kind(&self) -> agent_client_protocol::ToolKind {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
mod context_server_registry;
|
||||
mod copy_path_tool;
|
||||
mod create_directory_tool;
|
||||
mod delete_path_tool;
|
||||
|
@ -15,6 +16,7 @@ mod terminal_tool;
|
|||
mod thinking_tool;
|
||||
mod web_search_tool;
|
||||
|
||||
pub use context_server_registry::*;
|
||||
pub use copy_path_tool::*;
|
||||
pub use create_directory_tool::*;
|
||||
pub use delete_path_tool::*;
|
||||
|
|
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
|
@ -0,0 +1,231 @@
|
|||
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol::ToolKind;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use context_server::ContextServerId;
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct ContextServerRegistry {
|
||||
server_store: Entity<ContextServerStore>,
|
||||
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
|
||||
_subscription: gpui::Subscription,
|
||||
}
|
||||
|
||||
struct RegisteredContextServer {
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
load_tools: Task<Result<()>>,
|
||||
}
|
||||
|
||||
impl ContextServerRegistry {
|
||||
pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
|
||||
let mut this = Self {
|
||||
server_store: server_store.clone(),
|
||||
registered_servers: HashMap::default(),
|
||||
_subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
|
||||
};
|
||||
for server in server_store.read(cx).running_servers() {
|
||||
this.reload_tools_for_server(server.id(), cx);
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
pub fn servers(
|
||||
&self,
|
||||
) -> impl Iterator<
|
||||
Item = (
|
||||
&ContextServerId,
|
||||
&BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
),
|
||||
> {
|
||||
self.registered_servers
|
||||
.iter()
|
||||
.map(|(id, server)| (id, &server.tools))
|
||||
}
|
||||
|
||||
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
|
||||
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
|
||||
return;
|
||||
};
|
||||
let Some(client) = server.client() else {
|
||||
return;
|
||||
};
|
||||
if !client.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
return;
|
||||
}
|
||||
|
||||
let registered_server =
|
||||
self.registered_servers
|
||||
.entry(server_id.clone())
|
||||
.or_insert(RegisteredContextServer {
|
||||
tools: BTreeMap::default(),
|
||||
load_tools: Task::ready(Ok(())),
|
||||
});
|
||||
registered_server.load_tools = cx.spawn(async move |this, cx| {
|
||||
let response = client
|
||||
.request::<context_server::types::requests::ListTools>(())
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
registered_server.tools.clear();
|
||||
if let Some(response) = response.log_err() {
|
||||
for tool in response.tools {
|
||||
let tool = Arc::new(ContextServerTool::new(
|
||||
this.server_store.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
));
|
||||
registered_server.tools.insert(tool.name(), tool);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_context_server_store_event(
|
||||
&mut self,
|
||||
_: Entity<ContextServerStore>,
|
||||
event: &project::context_server_store::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
ContextServerStatus::Starting => {}
|
||||
ContextServerStatus::Running => {
|
||||
self.reload_tools_for_server(server_id.clone(), cx);
|
||||
}
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
self.registered_servers.remove(&server_id);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerTool {
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: context_server::types::Tool,
|
||||
}
|
||||
|
||||
impl ContextServerTool {
|
||||
fn new(
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: context_server::types::Tool,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
server_id,
|
||||
tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnyAgentTool for ContextServerTool {
|
||||
fn name(&self) -> SharedString {
|
||||
self.tool.name.clone().into()
|
||||
}
|
||||
|
||||
fn description(&self) -> SharedString {
|
||||
self.tool.description.clone().unwrap_or_default().into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
|
||||
format!("Run MCP tool `{}`", self.tool.name).into()
|
||||
}
|
||||
|
||||
fn input_schema(
|
||||
&self,
|
||||
format: language_model::LanguageModelToolSchemaFormat,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut schema = self.tool.input_schema.clone();
|
||||
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
|
||||
Ok(match schema {
|
||||
serde_json::Value::Null => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
serde_json::Value::Object(map) if map.is_empty() => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
_ => schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
|
||||
return Task::ready(Err(anyhow!("Context server not found")));
|
||||
};
|
||||
let tool_name = self.tool.name.clone();
|
||||
let server_clone = server.clone();
|
||||
let input_clone = input.clone();
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let Some(protocol) = server_clone.client() else {
|
||||
bail!("Context server not initialized");
|
||||
};
|
||||
|
||||
let arguments = if let serde_json::Value::Object(map) = input_clone {
|
||||
Some(map.into_iter().collect())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"Running tool: {} with arguments: {:?}",
|
||||
tool_name,
|
||||
arguments
|
||||
);
|
||||
let response = protocol
|
||||
.request::<context_server::types::requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: tool_name,
|
||||
arguments,
|
||||
meta: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut result = String::new();
|
||||
for content in response.content {
|
||||
match content {
|
||||
context_server::types::ToolResponseContent::Text { text } => {
|
||||
result.push_str(&text);
|
||||
}
|
||||
context_server::types::ToolResponseContent::Image { .. } => {
|
||||
log::warn!("Ignoring image content from tool response");
|
||||
}
|
||||
context_server::types::ToolResponseContent::Audio { .. } => {
|
||||
log::warn!("Ignoring audio content from tool response");
|
||||
}
|
||||
context_server::types::ToolResponseContent::Resource { .. } => {
|
||||
log::warn!("Ignoring resource content from tool response");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(AgentToolOutput {
|
||||
raw_output: result.clone().into(),
|
||||
llm_output: result.into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
|
@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool {
|
|||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
match input.path {
|
||||
|
@ -119,11 +119,6 @@ impl AgentTool for DiagnosticsTool {
|
|||
range.start.row + 1,
|
||||
entry.diagnostic.message
|
||||
)?;
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
|
@ -158,18 +153,9 @@ impl AgentTool for DiagnosticsTool {
|
|||
}
|
||||
|
||||
if has_diagnostics {
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
Task::ready(Ok(output))
|
||||
} else {
|
||||
let text = "No errors or warnings found in the project.";
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![text.into()]),
|
||||
..Default::default()
|
||||
});
|
||||
Task::ready(Ok(text.into()))
|
||||
Task::ready(Ok("No errors or warnings found in the project.".into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -454,9 +454,8 @@ fn resolve_path(
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::Templates;
|
||||
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates};
|
||||
use action_log::ActionLog;
|
||||
use client::TelemetrySettings;
|
||||
use fs::Fs;
|
||||
|
@ -475,9 +474,20 @@ mod tests {
|
|||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread =
|
||||
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
Templates::new(),
|
||||
model,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
|
@ -661,14 +671,18 @@ mod tests {
|
|||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
|
@ -792,15 +806,19 @@ mod tests {
|
|||
.unwrap();
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
|
@ -914,15 +932,19 @@ mod tests {
|
|||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
@ -1041,15 +1063,19 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
@ -1148,14 +1174,18 @@ mod tests {
|
|||
.await;
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
@ -1225,14 +1255,18 @@ mod tests {
|
|||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
@ -1305,14 +1339,18 @@ mod tests {
|
|||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
@ -1382,14 +1420,18 @@ mod tests {
|
|||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
|
|
@ -136,7 +136,7 @@ impl AgentTool for FetchTool {
|
|||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let text = cx.background_spawn({
|
||||
|
@ -149,12 +149,6 @@ impl AgentTool for FetchTool {
|
|||
if text.trim().is_empty() {
|
||||
bail!("no textual content found");
|
||||
}
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![text.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Ok(text)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -139,9 +139,6 @@ impl AgentTool for FindPathTool {
|
|||
})
|
||||
.collect(),
|
||||
),
|
||||
raw_output: Some(serde_json::json!({
|
||||
"paths": &matches,
|
||||
})),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
|
|
|
@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
|
|||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
|
@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
|
|||
}
|
||||
}
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
matches_found += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let output = if matches_found == 0 {
|
||||
"No matches found".to_string()
|
||||
if matches_found == 0 {
|
||||
Ok("No matches found".into())
|
||||
} else if has_more_matches {
|
||||
format!(
|
||||
Ok(format!(
|
||||
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
|
||||
input.offset + 1,
|
||||
input.offset + matches_found,
|
||||
input.offset + RESULTS_PER_PAGE,
|
||||
)
|
||||
))
|
||||
} else {
|
||||
format!("Found {matches_found} matches:\n{output}")
|
||||
};
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Ok(output)
|
||||
Ok(format!("Found {matches_found} matches:\n{output}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,20 +47,13 @@ impl AgentTool for NowTool {
|
|||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let now = match input.timezone {
|
||||
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||
Timezone::Local => Local::now().to_rfc3339(),
|
||||
};
|
||||
let content = format!("The current datetime is {now}.");
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![content.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Task::ready(Ok(content))
|
||||
Task::ready(Ok(format!("The current datetime is {now}.")))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
|
|||
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
|
||||
}
|
||||
|
||||
impl AgentProfileSettings {
|
||||
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
|
||||
self.tools.get(tool_name) == Some(&true)
|
||||
}
|
||||
|
||||
pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
|
||||
self.enable_all_context_servers
|
||||
|| self
|
||||
.context_servers
|
||||
.get(server_id)
|
||||
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ContextServerPreset {
|
||||
pub tools: IndexMap<Arc<str>, bool>,
|
||||
|
|
|
@ -167,6 +167,7 @@ fn generate_test_function(
|
|||
));
|
||||
cx_teardowns.extend(quote!(
|
||||
dispatcher.run_until_parked();
|
||||
#cx_varname.executor().forbid_parking();
|
||||
#cx_varname.quit();
|
||||
dispatcher.run_until_parked();
|
||||
));
|
||||
|
@ -232,7 +233,7 @@ fn generate_test_function(
|
|||
cx_teardowns.extend(quote!(
|
||||
drop(#cx_varname_lock);
|
||||
dispatcher.run_until_parked();
|
||||
#cx_varname.update(|cx| { cx.quit() });
|
||||
#cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
|
||||
dispatcher.run_until_parked();
|
||||
));
|
||||
continue;
|
||||
|
@ -247,6 +248,7 @@ fn generate_test_function(
|
|||
));
|
||||
cx_teardowns.extend(quote!(
|
||||
dispatcher.run_until_parked();
|
||||
#cx_varname.executor().forbid_parking();
|
||||
#cx_varname.quit();
|
||||
dispatcher.run_until_parked();
|
||||
));
|
||||
|
|
|
@ -487,6 +487,8 @@ const GO_MODULE_ROOT_TASK_VARIABLE: VariableName =
|
|||
VariableName::Custom(Cow::Borrowed("GO_MODULE_ROOT"));
|
||||
const GO_SUBTEST_NAME_TASK_VARIABLE: VariableName =
|
||||
VariableName::Custom(Cow::Borrowed("GO_SUBTEST_NAME"));
|
||||
const GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE: VariableName =
|
||||
VariableName::Custom(Cow::Borrowed("GO_TABLE_TEST_CASE_NAME"));
|
||||
|
||||
impl ContextProvider for GoContextProvider {
|
||||
fn build_context(
|
||||
|
@ -545,10 +547,19 @@ impl ContextProvider for GoContextProvider {
|
|||
let go_subtest_variable = extract_subtest_name(_subtest_name.unwrap_or(""))
|
||||
.map(|subtest_name| (GO_SUBTEST_NAME_TASK_VARIABLE.clone(), subtest_name));
|
||||
|
||||
let table_test_case_name = variables.get(&VariableName::Custom(Cow::Borrowed(
|
||||
"_table_test_case_name",
|
||||
)));
|
||||
|
||||
let go_table_test_case_variable = table_test_case_name
|
||||
.and_then(extract_subtest_name)
|
||||
.map(|case_name| (GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.clone(), case_name));
|
||||
|
||||
Task::ready(Ok(TaskVariables::from_iter(
|
||||
[
|
||||
go_package_variable,
|
||||
go_subtest_variable,
|
||||
go_table_test_case_variable,
|
||||
go_module_root_variable,
|
||||
]
|
||||
.into_iter()
|
||||
|
@ -570,6 +581,28 @@ impl ContextProvider for GoContextProvider {
|
|||
let module_cwd = Some(GO_MODULE_ROOT_TASK_VARIABLE.template_value());
|
||||
|
||||
Task::ready(Some(TaskTemplates(vec![
|
||||
TaskTemplate {
|
||||
label: format!(
|
||||
"go test {} -v -run {}/{}",
|
||||
GO_PACKAGE_TASK_VARIABLE.template_value(),
|
||||
VariableName::Symbol.template_value(),
|
||||
GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(),
|
||||
),
|
||||
command: "go".into(),
|
||||
args: vec![
|
||||
"test".into(),
|
||||
"-v".into(),
|
||||
"-run".into(),
|
||||
format!(
|
||||
"\\^{}\\$/\\^{}\\$",
|
||||
VariableName::Symbol.template_value(),
|
||||
GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(),
|
||||
),
|
||||
],
|
||||
cwd: package_cwd.clone(),
|
||||
tags: vec!["go-table-test-case".to_owned()],
|
||||
..TaskTemplate::default()
|
||||
},
|
||||
TaskTemplate {
|
||||
label: format!(
|
||||
"go test {} -run {}",
|
||||
|
@ -842,10 +875,21 @@ mod tests {
|
|||
.collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
runnables.len() == 2,
|
||||
"Should find test function and subtest with double quotes, found: {}",
|
||||
runnables.len()
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-subtest".to_string()),
|
||||
"Should find go-subtest tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
|
||||
let buffer = cx.new(|cx| {
|
||||
|
@ -860,10 +904,299 @@ mod tests {
|
|||
.collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
runnables.len() == 2,
|
||||
"Should find test function and subtest with backticks, found: {}",
|
||||
runnables.len()
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-subtest".to_string()),
|
||||
"Should find go-subtest tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_slice_detection(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
_ = "some random string"
|
||||
|
||||
testCases := []struct{
|
||||
name string
|
||||
anotherStr string
|
||||
}{
|
||||
{
|
||||
name: "test case 1",
|
||||
anotherStr: "foo",
|
||||
},
|
||||
{
|
||||
name: "test case 2",
|
||||
anotherStr: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
notATableTest := []struct{
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "some string",
|
||||
},
|
||||
{
|
||||
name: "some other string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// test code here
|
||||
})
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
|
||||
let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count();
|
||||
let go_table_test_count = tag_strings
|
||||
.iter()
|
||||
.filter(|&tag| tag == "go-table-test-case")
|
||||
.count();
|
||||
|
||||
assert!(
|
||||
go_test_count == 1,
|
||||
"Should find exactly 1 go-test, found: {}",
|
||||
go_test_count
|
||||
);
|
||||
assert!(
|
||||
go_table_test_count == 2,
|
||||
"Should find exactly 2 go-table-test-case, found: {}",
|
||||
go_table_test_count
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_slice_ignored(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
func Example() {
|
||||
_ = "some random string"
|
||||
|
||||
notATableTest := []struct{
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "some string",
|
||||
},
|
||||
{
|
||||
name: "some other string",
|
||||
},
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_map_detection(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
_ = "some random string"
|
||||
|
||||
testCases := map[string]struct {
|
||||
someStr string
|
||||
fail bool
|
||||
}{
|
||||
"test failure": {
|
||||
someStr: "foo",
|
||||
fail: true,
|
||||
},
|
||||
"test success": {
|
||||
someStr: "bar",
|
||||
fail: false,
|
||||
},
|
||||
}
|
||||
|
||||
notATableTest := map[string]struct {
|
||||
someStr string
|
||||
}{
|
||||
"some string": {
|
||||
someStr: "foo",
|
||||
},
|
||||
"some other string": {
|
||||
someStr: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// test code here
|
||||
})
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
|
||||
let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count();
|
||||
let go_table_test_count = tag_strings
|
||||
.iter()
|
||||
.filter(|&tag| tag == "go-table-test-case")
|
||||
.count();
|
||||
|
||||
assert!(
|
||||
go_test_count == 1,
|
||||
"Should find exactly 1 go-test, found: {}",
|
||||
go_test_count
|
||||
);
|
||||
assert!(
|
||||
go_table_test_count == 2,
|
||||
"Should find exactly 2 go-table-test-case, found: {}",
|
||||
go_table_test_count
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_map_ignored(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
func Example() {
|
||||
_ = "some random string"
|
||||
|
||||
notATableTest := map[string]struct {
|
||||
someStr string
|
||||
}{
|
||||
"some string": {
|
||||
someStr: "foo",
|
||||
},
|
||||
"some other string": {
|
||||
someStr: "bar",
|
||||
},
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -91,3 +91,103 @@
|
|||
) @_
|
||||
(#set! tag go-main)
|
||||
)
|
||||
|
||||
; Table test cases - slice and map
|
||||
(
|
||||
(short_var_declaration
|
||||
left: (expression_list (identifier) @_collection_var)
|
||||
right: (expression_list
|
||||
(composite_literal
|
||||
type: [
|
||||
(slice_type)
|
||||
(map_type
|
||||
key: (type_identifier) @_key_type
|
||||
(#eq? @_key_type "string")
|
||||
)
|
||||
]
|
||||
body: (literal_value
|
||||
[
|
||||
(literal_element
|
||||
(literal_value
|
||||
(keyed_element
|
||||
(literal_element
|
||||
(identifier) @_field_name
|
||||
)
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal) @run @_table_test_case_name
|
||||
(raw_string_literal) @run @_table_test_case_name
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(keyed_element
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal) @run @_table_test_case_name
|
||||
(raw_string_literal) @run @_table_test_case_name
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(for_statement
|
||||
(range_clause
|
||||
left: (expression_list
|
||||
[
|
||||
(
|
||||
(identifier)
|
||||
(identifier) @_loop_var
|
||||
)
|
||||
(identifier) @_loop_var
|
||||
]
|
||||
)
|
||||
right: (identifier) @_range_var
|
||||
(#eq? @_range_var @_collection_var)
|
||||
)
|
||||
body: (block
|
||||
(expression_statement
|
||||
(call_expression
|
||||
function: (selector_expression
|
||||
operand: (identifier) @_t_var
|
||||
field: (field_identifier) @_run_method
|
||||
(#eq? @_run_method "Run")
|
||||
)
|
||||
arguments: (argument_list
|
||||
.
|
||||
[
|
||||
(selector_expression
|
||||
operand: (identifier) @_tc_var
|
||||
(#eq? @_tc_var @_loop_var)
|
||||
field: (field_identifier) @_field_check
|
||||
(#eq? @_field_check @_field_name)
|
||||
)
|
||||
(identifier) @_arg_var
|
||||
(#eq? @_arg_var @_loop_var)
|
||||
]
|
||||
.
|
||||
(func_literal
|
||||
parameters: (parameter_list
|
||||
(parameter_declaration
|
||||
type: (pointer_type
|
||||
(qualified_type
|
||||
package: (package_identifier) @_pkg
|
||||
name: (type_identifier) @_type
|
||||
(#eq? @_pkg "testing")
|
||||
(#eq? @_type "T")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
) @_
|
||||
(#set! tag go-table-test-case)
|
||||
)
|
||||
|
|
|
@ -3367,20 +3367,6 @@ impl LocalLspStore {
|
|||
}
|
||||
}
|
||||
|
||||
fn parse_register_capabilities<T: serde::de::DeserializeOwned>(
|
||||
reg: lsp::Registration,
|
||||
) -> anyhow::Result<OneOf<bool, T>> {
|
||||
let caps = match reg
|
||||
.register_options
|
||||
.map(|options| serde_json::from_value::<T>(options))
|
||||
.transpose()?
|
||||
{
|
||||
None => OneOf::Left(true),
|
||||
Some(options) => OneOf::Right(options),
|
||||
};
|
||||
Ok(caps)
|
||||
}
|
||||
|
||||
fn notify_server_capabilities_updated(server: &LanguageServer, cx: &mut Context<LspStore>) {
|
||||
if let Some(capabilities) = serde_json::to_string(&server.capabilities()).ok() {
|
||||
cx.emit(LspStoreEvent::LanguageServerUpdate {
|
||||
|
@ -11690,190 +11676,190 @@ impl LspStore {
|
|||
// Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings.
|
||||
}
|
||||
"workspace/symbol" => {
|
||||
let options = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.workspace_symbol_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.workspace_symbol_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"workspace/fileOperations" => {
|
||||
let caps = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities
|
||||
.workspace
|
||||
.get_or_insert_default()
|
||||
.file_operations = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = reg.register_options {
|
||||
let caps = serde_json::from_value(options)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities
|
||||
.workspace
|
||||
.get_or_insert_default()
|
||||
.file_operations = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"workspace/executeCommand" => {
|
||||
let options = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.execute_command_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = reg.register_options {
|
||||
let options = serde_json::from_value(options)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.execute_command_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/rangeFormatting" => {
|
||||
let options = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_range_formatting_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_range_formatting_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/onTypeFormatting" => {
|
||||
let options = reg
|
||||
if let Some(options) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_on_type_formatting_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_on_type_formatting_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/formatting" => {
|
||||
let options = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_formatting_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_formatting_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/rename" => {
|
||||
let options = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.rename_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.rename_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/inlayHint" => {
|
||||
let options = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.inlay_hint_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.inlay_hint_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/documentSymbol" => {
|
||||
let options = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_symbol_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.document_symbol_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/codeAction" => {
|
||||
let options = reg
|
||||
if let Some(options) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?;
|
||||
let provider_capability = match options {
|
||||
None => lsp::CodeActionProviderCapability::Simple(true),
|
||||
Some(options) => lsp::CodeActionProviderCapability::Options(options),
|
||||
};
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.code_action_provider = Some(provider_capability);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
.transpose()?
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.code_action_provider =
|
||||
Some(lsp::CodeActionProviderCapability::Options(options));
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/definition" => {
|
||||
let caps = parse_register_capabilities(reg)?;
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.definition_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
if let Some(options) = parse_register_capabilities(reg)? {
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.definition_provider = Some(options);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/completion" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.completion_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.completion_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/hover" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| lsp::HoverProviderCapability::Simple(true));
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.hover_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.hover_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/signatureHelp" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.signature_help_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.signature_help_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/synchronization" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| {
|
||||
lsp::TextDocumentSyncCapability::Options(
|
||||
lsp::TextDocumentSyncOptions::default(),
|
||||
)
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.text_document_sync = Some(caps);
|
||||
});
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.text_document_sync = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/codeLens" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| lsp::CodeLensOptions {
|
||||
resolve_provider: None,
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.code_lens_provider = Some(caps);
|
||||
});
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.code_lens_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/diagnostic" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| {
|
||||
lsp::DiagnosticServerCapabilities::RegistrationOptions(
|
||||
lsp::DiagnosticRegistrationOptions::default(),
|
||||
)
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.diagnostic_provider = Some(caps);
|
||||
});
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.diagnostic_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
"textDocument/colorProvider" => {
|
||||
let caps = reg
|
||||
if let Some(caps) = reg
|
||||
.register_options
|
||||
.map(serde_json::from_value)
|
||||
.transpose()?
|
||||
.unwrap_or_else(|| lsp::ColorProviderCapability::Simple(true));
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.color_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
{
|
||||
server.update_capabilities(|capabilities| {
|
||||
capabilities.color_provider = Some(caps);
|
||||
});
|
||||
notify_server_capabilities_updated(&server, cx);
|
||||
}
|
||||
}
|
||||
_ => log::warn!("unhandled capability registration: {reg:?}"),
|
||||
}
|
||||
|
@ -12016,6 +12002,18 @@ impl LspStore {
|
|||
}
|
||||
}
|
||||
|
||||
// Registration with empty capabilities should be ignored.
|
||||
// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/formatting.ts#L67-L70
|
||||
fn parse_register_capabilities<T: serde::de::DeserializeOwned>(
|
||||
reg: lsp::Registration,
|
||||
) -> anyhow::Result<Option<OneOf<bool, T>>> {
|
||||
Ok(reg
|
||||
.register_options
|
||||
.map(|options| serde_json::from_value::<T>(options))
|
||||
.transpose()?
|
||||
.map(OneOf::Right))
|
||||
}
|
||||
|
||||
fn subscribe_to_binary_statuses(
|
||||
languages: &Arc<LanguageRegistry>,
|
||||
cx: &mut Context<'_, LspStore>,
|
||||
|
|
|
@ -24,6 +24,7 @@ command_palette.workspace = true
|
|||
command_palette_hooks.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
|
|
|
@ -3,7 +3,9 @@ use editor::{Bias, Editor};
|
|||
use gpui::{Action, Context, Window, actions};
|
||||
use language::SelectionGoal;
|
||||
use settings::Settings;
|
||||
use text::Point;
|
||||
use vim_mode_setting::HelixModeSetting;
|
||||
use workspace::searchable::Direction;
|
||||
|
||||
actions!(
|
||||
vim,
|
||||
|
@ -11,13 +13,23 @@ actions!(
|
|||
/// Switches to normal mode with cursor positioned before the current character.
|
||||
NormalBefore,
|
||||
/// Temporarily switches to normal mode for one command.
|
||||
TemporaryNormal
|
||||
TemporaryNormal,
|
||||
/// Inserts the next character from the line above into the current line.
|
||||
InsertFromAbove,
|
||||
/// Inserts the next character from the line below into the current line.
|
||||
InsertFromBelow
|
||||
]
|
||||
);
|
||||
|
||||
pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
|
||||
Vim::action(editor, cx, Vim::normal_before);
|
||||
Vim::action(editor, cx, Vim::temporary_normal);
|
||||
Vim::action(editor, cx, |vim, _: &InsertFromAbove, window, cx| {
|
||||
vim.insert_around(Direction::Prev, window, cx)
|
||||
});
|
||||
Vim::action(editor, cx, |vim, _: &InsertFromBelow, window, cx| {
|
||||
vim.insert_around(Direction::Next, window, cx)
|
||||
})
|
||||
}
|
||||
|
||||
impl Vim {
|
||||
|
@ -71,6 +83,29 @@ impl Vim {
|
|||
self.switch_mode(Mode::Normal, true, window, cx);
|
||||
self.temp_mode = true;
|
||||
}
|
||||
|
||||
fn insert_around(&mut self, direction: Direction, _: &mut Window, cx: &mut Context<Self>) {
|
||||
self.update_editor(cx, |_, editor, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let mut edits = Vec::new();
|
||||
for selection in editor.selections.all::<Point>(cx) {
|
||||
let point = selection.head();
|
||||
let new_row = match direction {
|
||||
Direction::Next => point.row + 1,
|
||||
Direction::Prev if point.row > 0 => point.row - 1,
|
||||
_ => continue,
|
||||
};
|
||||
let source = snapshot.clip_point(Point::new(new_row, point.column), Bias::Left);
|
||||
if let Some(c) = snapshot.chars_at(source).next()
|
||||
&& c != '\n'
|
||||
{
|
||||
edits.push((point..point, c.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
editor.edit(edits, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -156,4 +191,13 @@ mod test {
|
|||
.await;
|
||||
cx.shared_state().await.assert_eq("hehello\nˇllo\n");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_insert_ctrl_y(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
cx.set_shared_state("hello\nˇ\nworld").await;
|
||||
cx.simulate_shared_keystrokes("i ctrl-y ctrl-e").await;
|
||||
cx.shared_state().await.assert_eq("hello\nhoˇ\nworld");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ impl VimTestContext {
|
|||
if cx.has_global::<VimGlobals>() {
|
||||
return;
|
||||
}
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings = SettingsStore::test(cx);
|
||||
cx.set_global(settings);
|
||||
|
|
5
crates/vim/test_data/test_insert_ctrl_y.json
Normal file
5
crates/vim/test_data/test_insert_ctrl_y.json
Normal file
|
@ -0,0 +1,5 @@
|
|||
{"Put":{"state":"hello\nˇ\nworld"}}
|
||||
{"Key":"i"}
|
||||
{"Key":"ctrl-y"}
|
||||
{"Key":"ctrl-e"}
|
||||
{"Get":{"state":"hello\nhoˇ\nworld","mode":"Insert"}}
|
|
@ -542,6 +542,20 @@ define_connection! {
|
|||
ALTER TABLE breakpoints ADD COLUMN condition TEXT;
|
||||
ALTER TABLE breakpoints ADD COLUMN hit_condition TEXT;
|
||||
),
|
||||
sql!(CREATE TABLE toolchains2 (
|
||||
workspace_id INTEGER,
|
||||
worktree_id INTEGER,
|
||||
language_name TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
raw_json TEXT NOT NULL,
|
||||
relative_worktree_path TEXT NOT NULL,
|
||||
PRIMARY KEY (workspace_id, worktree_id, language_name, relative_worktree_path)) STRICT;
|
||||
INSERT INTO toolchains2
|
||||
SELECT * FROM toolchains;
|
||||
DROP TABLE toolchains;
|
||||
ALTER TABLE toolchains2 RENAME TO toolchains;
|
||||
)
|
||||
];
|
||||
}
|
||||
|
||||
|
@ -1428,12 +1442,12 @@ impl WorkspaceDb {
|
|||
self.write(move |conn| {
|
||||
let mut insert = conn
|
||||
.exec_bound(sql!(
|
||||
INSERT INTO toolchains(workspace_id, worktree_id, relative_worktree_path, language_name, name, path) VALUES (?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO toolchains(workspace_id, worktree_id, relative_worktree_path, language_name, name, path, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT DO
|
||||
UPDATE SET
|
||||
name = ?5,
|
||||
path = ?6
|
||||
|
||||
path = ?6,
|
||||
raw_json = ?7
|
||||
))
|
||||
.context("Preparing insertion")?;
|
||||
|
||||
|
@ -1444,6 +1458,7 @@ impl WorkspaceDb {
|
|||
toolchain.language_name.as_ref(),
|
||||
toolchain.name.as_ref(),
|
||||
toolchain.path.as_ref(),
|
||||
toolchain.as_json.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -26,6 +26,7 @@ collections.workspace = true
|
|||
command_palette_hooks.workspace = true
|
||||
copilot.workspace = true
|
||||
db.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
|
@ -33,13 +34,13 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
menu.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
release_channel.workspace = true
|
||||
serde.workspace = true
|
||||
|
|
|
@ -429,6 +429,7 @@ impl Zeta {
|
|||
body,
|
||||
editable_range,
|
||||
} = gather_task.await?;
|
||||
let done_gathering_context_at = Instant::now();
|
||||
|
||||
log::debug!(
|
||||
"Events:\n{}\nExcerpt:\n{:?}",
|
||||
|
@ -481,6 +482,7 @@ impl Zeta {
|
|||
}
|
||||
};
|
||||
|
||||
let received_response_at = Instant::now();
|
||||
log::debug!("completion response: {}", &response.output_excerpt);
|
||||
|
||||
if let Some(usage) = usage {
|
||||
|
@ -492,7 +494,7 @@ impl Zeta {
|
|||
.ok();
|
||||
}
|
||||
|
||||
Self::process_completion_response(
|
||||
let edit_prediction = Self::process_completion_response(
|
||||
response,
|
||||
buffer,
|
||||
&snapshot,
|
||||
|
@ -505,7 +507,25 @@ impl Zeta {
|
|||
buffer_snapshotted_at,
|
||||
&cx,
|
||||
)
|
||||
.await
|
||||
.await;
|
||||
|
||||
let finished_at = Instant::now();
|
||||
|
||||
// record latency for ~1% of requests
|
||||
if rand::random::<u8>() <= 2 {
|
||||
telemetry::event!(
|
||||
"Edit Prediction Request",
|
||||
context_latency = done_gathering_context_at
|
||||
.duration_since(buffer_snapshotted_at)
|
||||
.as_millis(),
|
||||
request_latency = received_response_at
|
||||
.duration_since(done_gathering_context_at)
|
||||
.as_millis(),
|
||||
process_latency = finished_at.duration_since(received_response_at).as_millis()
|
||||
);
|
||||
}
|
||||
|
||||
edit_prediction
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -326,7 +326,7 @@ When you use `cargo build` or `cargo test` as the build command, Zed can infer t
|
|||
[
|
||||
{
|
||||
"label": "Build & Debug native binary",
|
||||
"adapter": "CodeLLDB"
|
||||
"adapter": "CodeLLDB",
|
||||
"build": {
|
||||
"command": "cargo",
|
||||
"args": ["build"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue