diff --git a/Cargo.lock b/Cargo.lock index 1c7e594da2..5ee4e94281 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -197,6 +197,7 @@ dependencies = [ "clock", "cloud_llm_client", "collections", + "context_server", "ctor", "editor", "env_logger 0.11.8", @@ -18026,6 +18027,7 @@ dependencies = [ "command_palette_hooks", "db", "editor", + "env_logger 0.11.8", "futures 0.3.31", "git_ui", "gpui", @@ -20923,6 +20925,7 @@ dependencies = [ "menu", "postage", "project", + "rand 0.8.5", "regex", "release_channel", "reqwest_client", diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 57edb1e4c1..98f9cafc40 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -333,10 +333,14 @@ "ctrl-x ctrl-c": "editor::ShowEditPrediction", // zed specific "ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific "ctrl-x ctrl-z": "editor::Cancel", + "ctrl-x ctrl-e": "vim::LineDown", + "ctrl-x ctrl-y": "vim::LineUp", "ctrl-w": "editor::DeleteToPreviousWordStart", "ctrl-u": "editor::DeleteToBeginningOfLine", "ctrl-t": "vim::Indent", "ctrl-d": "vim::Outdent", + "ctrl-y": "vim::InsertFromAbove", + "ctrl-e": "vim::InsertFromBelow", "ctrl-k": ["vim::PushDigraph", {}], "ctrl-v": ["vim::PushLiteral", {}], "ctrl-shift-v": "editor::Paste", // note: this is *very* similar to ctrl-v in vim, but ctrl-shift-v on linux is the typical shortcut for paste when ctrl-v is already in use. diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 4f8773b416..eccbef96b8 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -219,6 +219,15 @@ impl ToolCall { } if let Some(raw_output) = raw_output { + if self.content.is_empty() { + if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx) + { + self.content + .push(ToolCallContent::ContentBlock(ContentBlock::Markdown { + markdown, + })); + } + } self.raw_output = Some(raw_output); } } @@ -1239,6 +1248,48 @@ impl AcpThread { } } +fn markdown_for_raw_output( + raw_output: &serde_json::Value, + language_registry: &Arc, + cx: &mut App, +) -> Option> { + match raw_output { + serde_json::Value::Null => None, + serde_json::Value::Bool(value) => Some(cx.new(|cx| { + Markdown::new( + value.to_string().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + serde_json::Value::Number(value) => Some(cx.new(|cx| { + Markdown::new( + value.to_string().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + serde_json::Value::String(value) => Some(cx.new(|cx| { + Markdown::new( + value.clone().into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + value => Some(cx.new(|cx| { + Markdown::new( + format!("```json\n{}\n```", value).into(), + Some(language_registry.clone()), + None, + cx, + ) + })), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 7ee48aca04..1030380dc0 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -23,6 +23,7 @@ assistant_tools.workspace = true chrono.workspace = true cloud_llm_client.workspace = true collections.workspace = true +context_server.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true @@ -60,6 +61,7 @@ workspace-hack.workspace = true ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } +context_server = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index d3aae39a7f..7439b2a088 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ - CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, - GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, OpenTool, ReadFileTool, - TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, + FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, }; use acp_thread::ModelSelector; use agent_client_protocol as acp; @@ -55,6 +55,7 @@ pub struct NativeAgent { project_context: Rc>, project_context_needs_refresh: watch::Sender<()>, _maintain_project_context: Task>, + context_server_registry: Entity, /// Shared templates for all threads templates: Arc, project: Entity, @@ -90,6 +91,9 @@ impl NativeAgent { _maintain_project_context: cx.spawn(async move |this, cx| { Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await }), + context_server_registry: cx.new(|cx| { + ContextServerRegistry::new(project.read(cx).context_server_store(), cx) + }), templates, project, prompt_store, @@ -385,7 +389,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Create AcpThread let acp_thread = cx.update(|cx| { cx.new(|cx| { - acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx) + acp_thread::AcpThread::new( + "agent2", + self.clone(), + project.clone(), + session_id.clone(), + cx, + ) }) })?; let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; @@ -413,11 +423,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }) .ok_or_else(|| { log::warn!("No default model configured in settings"); - anyhow!("No default model configured. Please configure a default model in settings.") + anyhow!( + "No default model. Please configure a default model in settings." + ) })?; let thread = cx.new(|cx| { - let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); + let mut thread = Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + default_model, + cx, + ); thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(CopyPathTool::new(project.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); @@ -450,7 +470,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { acp_thread: acp_thread.downgrade(), _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); - }) + }), }, ); })?; @@ -496,14 +516,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; log::debug!("Found session for: {}", session_id); - // Convert prompt to message let message: Vec = params .prompt .into_iter() .map(Into::into) .collect::>(); log::info!("Converted prompt to message: {} chars", message.len()); - // log::debug!("Message content: {}", message); + log::debug!("Message content: {:?}", message); // Get model using the ModelSelector capability (always available for agent2) // Get the selected model from the thread directly diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index cf5c70448c..88cf92836b 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -3,6 +3,7 @@ use crate::MessageContent; use acp_thread::AgentConnection; use action_log::ActionLog; use agent_client_protocol::{self as acp}; +use agent_settings::AgentProfileId; use anyhow::Result; use client::{Client, UserStore}; use fs::{FakeFs, Fs}; @@ -166,7 +167,9 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { } else { false } - }) + }), + "{}", + thread.to_markdown() ); }); } @@ -474,6 +477,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_profiles(cx: &mut TestAppContext) { + let ThreadTest { + model, thread, fs, .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(DelayTool); + thread.add_tool(EchoTool); + thread.add_tool(InfiniteTool); + }); + + // Override profiles and wait for settings to be loaded. + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "profiles": { + "test-1": { + "name": "Test Profile 1", + "tools": { + EchoTool.name(): true, + DelayTool.name(): true, + } + }, + "test-2": { + "name": "Test Profile 2", + "tools": { + InfiniteTool.name(): true, + } + } + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + + // Test that test-1 profile (default) has echo and delay tools + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-1".into())); + thread.send("test", cx); + }); + cx.run_until_parked(); + + let mut pending_completions = fake_model.pending_completions(); + assert_eq!(pending_completions.len(), 1); + let completion = pending_completions.pop().unwrap(); + let tool_names: Vec = completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]); + fake_model.end_last_completion_stream(); + + // Switch to test-2 profile, and verify that it has only the infinite tool. + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test-2".into())); + thread.send("test2", cx) + }); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!(pending_completions.len(), 1); + let completion = pending_completions.pop().unwrap(); + let tool_names: Vec = completion + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + assert_eq!(tool_names, vec![InfiniteTool.name()]); +} + #[gpui::test] #[ignore = "can't run on CI yet"] async fn test_cancellation(cx: &mut TestAppContext) { @@ -600,6 +679,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { language_models::init(user_store.clone(), client.clone(), cx); Project::init_settings(cx); LanguageModelRegistry::test(cx); + agent_settings::init(cx); }); cx.executor().forbid_parking(); @@ -795,6 +875,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { id: acp::ToolCallId("1".into()), fields: acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::Completed), + raw_output: Some("Finished thinking.".into()), ..Default::default() }, } @@ -818,6 +899,7 @@ struct ThreadTest { model: Arc, thread: Entity, project_context: Rc>, + fs: Arc, } enum TestModel { @@ -840,30 +922,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.executor().allow_parking(); let fs = FakeFs::new(cx.background_executor.clone()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_profile": "test-profile", + "profiles": { + "test-profile": { + "name": "Test Profile", + "tools": { + EchoTool.name(): true, + DelayTool.name(): true, + WordListTool.name(): true, + ToolRequiringPermission.name(): true, + InfiniteTool.name(): true, + } + } + } + } + }) + .to_string() + .into_bytes(), + ) + .await; cx.update(|cx| { settings::init(cx); - watch_settings(fs.clone(), cx); Project::init_settings(cx); agent_settings::init(cx); + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + watch_settings(fs.clone(), cx); }); + let templates = Templates::new(); fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; let model = cx .update(|cx| { - gpui_tokio::init(cx); - let http_client = ReqwestClient::user_agent("agent tests").unwrap(); - cx.set_http_client(Arc::new(http_client)); - - client::init_settings(cx); - let client = Client::production(cx); - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); - if let TestModel::Fake = model { Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>) } else { @@ -886,20 +995,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { .await; let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, project_context.clone(), + context_server_registry, action_log, templates, model.clone(), + cx, ) }); ThreadTest { model, thread, project_context, + fs, } } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 1bd46574d2..678e4cb5d2 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,8 +1,8 @@ -use crate::{SystemPromptTemplate, Template, Templates}; +use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; use acp_thread::MentionUri; use action_log::ActionLog; use agent_client_protocol as acp; -use agent_settings::AgentSettings; +use agent_settings::{AgentProfileId, AgentSettings}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use cloud_llm_client::{CompletionIntent, CompletionMode}; @@ -148,6 +148,8 @@ pub struct Thread { running_turn: Option>, pending_tool_uses: HashMap, tools: BTreeMap>, + context_server_registry: Entity, + profile_id: AgentProfileId, project_context: Rc>, templates: Arc, pub selected_model: Arc, @@ -159,16 +161,21 @@ impl Thread { pub fn new( project: Entity, project_context: Rc>, + context_server_registry: Entity, action_log: Entity, templates: Arc, default_model: Arc, + cx: &mut Context, ) -> Self { + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, pending_tool_uses: HashMap::default(), tools: BTreeMap::default(), + context_server_registry, + profile_id, project_context, templates, selected_model: default_model, @@ -201,6 +208,10 @@ impl Thread { self.tools.remove(name).is_some() } + pub fn set_profile(&mut self, profile_id: AgentProfileId) { + self.profile_id = profile_id; + } + pub fn cancel(&mut self) { self.running_turn.take(); @@ -321,6 +332,7 @@ impl Thread { } else { acp::ToolCallStatus::Completed }), + raw_output: tool_result.output.clone(), ..Default::default() }, ); @@ -627,21 +639,23 @@ impl Thread { let messages = self.build_request_messages(); log::info!("Request will include {} messages", messages.len()); - let tools: Vec = self - .tools - .values() - .filter_map(|tool| { - let tool_name = tool.name().to_string(); - log::trace!("Including tool: {}", tool_name); - Some(LanguageModelRequestTool { - name: tool_name, - description: tool.description(cx).to_string(), - input_schema: tool - .input_schema(self.selected_model.tool_input_format()) - .log_err()?, + let tools = if let Some(tools) = self.tools(cx).log_err() { + tools + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description().to_string(), + input_schema: tool + .input_schema(self.selected_model.tool_input_format()) + .log_err()?, + }) }) - }) - .collect(); + .collect() + } else { + Vec::new() + }; log::info!("Request includes {} tools", tools.len()); @@ -662,6 +676,35 @@ impl Thread { request } + fn tools<'a>(&'a self, cx: &'a App) -> Result>> { + let profile = AgentSettings::get_global(cx) + .profiles + .get(&self.profile_id) + .context("profile not found")?; + + Ok(self + .tools + .iter() + .filter_map(|(tool_name, tool)| { + if profile.is_tool_enabled(tool_name) { + Some(tool) + } else { + None + } + }) + .chain(self.context_server_registry.read(cx).servers().flat_map( + |(server_id, tools)| { + tools.iter().filter_map(|(tool_name, tool)| { + if profile.is_context_server_tool_enabled(&server_id.0, tool_name) { + Some(tool) + } else { + None + } + }) + }, + ))) + } + fn build_request_messages(&self) -> Vec { log::trace!( "Building request messages from {} thread messages", @@ -719,7 +762,7 @@ where fn name(&self) -> SharedString; - fn description(&self, _cx: &mut App) -> SharedString { + fn description(&self) -> SharedString { let schema = schemars::schema_for!(Self::Input); SharedString::new( schema @@ -755,13 +798,13 @@ where pub struct Erased(T); pub struct AgentToolOutput { - llm_output: LanguageModelToolResultContent, - raw_output: serde_json::Value, + pub llm_output: LanguageModelToolResultContent, + pub raw_output: serde_json::Value, } pub trait AnyAgentTool { fn name(&self) -> SharedString; - fn description(&self, cx: &mut App) -> SharedString; + fn description(&self) -> SharedString; fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; @@ -781,8 +824,8 @@ where self.0.name() } - fn description(&self, cx: &mut App) -> SharedString { - self.0.description(cx) + fn description(&self) -> SharedString { + self.0.description() } fn kind(&self) -> agent_client_protocol::ToolKind { diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index 8896b14538..d1f2b3b1c7 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1,3 +1,4 @@ +mod context_server_registry; mod copy_path_tool; mod create_directory_tool; mod delete_path_tool; @@ -15,6 +16,7 @@ mod terminal_tool; mod thinking_tool; mod web_search_tool; +pub use context_server_registry::*; pub use copy_path_tool::*; pub use create_directory_tool::*; pub use delete_path_tool::*; diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs new file mode 100644 index 0000000000..db39e9278c --- /dev/null +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -0,0 +1,231 @@ +use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; +use agent_client_protocol::ToolKind; +use anyhow::{Result, anyhow, bail}; +use collections::{BTreeMap, HashMap}; +use context_server::ContextServerId; +use gpui::{App, Context, Entity, SharedString, Task}; +use project::context_server_store::{ContextServerStatus, ContextServerStore}; +use std::sync::Arc; +use util::ResultExt; + +pub struct ContextServerRegistry { + server_store: Entity, + registered_servers: HashMap, + _subscription: gpui::Subscription, +} + +struct RegisteredContextServer { + tools: BTreeMap>, + load_tools: Task>, +} + +impl ContextServerRegistry { + pub fn new(server_store: Entity, cx: &mut Context) -> Self { + let mut this = Self { + server_store: server_store.clone(), + registered_servers: HashMap::default(), + _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event), + }; + for server in server_store.read(cx).running_servers() { + this.reload_tools_for_server(server.id(), cx); + } + this + } + + pub fn servers( + &self, + ) -> impl Iterator< + Item = ( + &ContextServerId, + &BTreeMap>, + ), + > { + self.registered_servers + .iter() + .map(|(id, server)| (id, &server.tools)) + } + + fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context) { + let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else { + return; + }; + let Some(client) = server.client() else { + return; + }; + if !client.capable(context_server::protocol::ServerCapability::Tools) { + return; + } + + let registered_server = + self.registered_servers + .entry(server_id.clone()) + .or_insert(RegisteredContextServer { + tools: BTreeMap::default(), + load_tools: Task::ready(Ok(())), + }); + registered_server.load_tools = cx.spawn(async move |this, cx| { + let response = client + .request::(()) + .await; + + this.update(cx, |this, cx| { + let Some(registered_server) = this.registered_servers.get_mut(&server_id) else { + return; + }; + + registered_server.tools.clear(); + if let Some(response) = response.log_err() { + for tool in response.tools { + let tool = Arc::new(ContextServerTool::new( + this.server_store.clone(), + server.id(), + tool, + )); + registered_server.tools.insert(tool.name(), tool); + } + cx.notify(); + } + }) + }); + } + + fn handle_context_server_store_event( + &mut self, + _: Entity, + event: &project::context_server_store::Event, + cx: &mut Context, + ) { + match event { + project::context_server_store::Event::ServerStatusChanged { server_id, status } => { + match status { + ContextServerStatus::Starting => {} + ContextServerStatus::Running => { + self.reload_tools_for_server(server_id.clone(), cx); + } + ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { + self.registered_servers.remove(&server_id); + cx.notify(); + } + } + } + } + } +} + +struct ContextServerTool { + store: Entity, + server_id: ContextServerId, + tool: context_server::types::Tool, +} + +impl ContextServerTool { + fn new( + store: Entity, + server_id: ContextServerId, + tool: context_server::types::Tool, + ) -> Self { + Self { + store, + server_id, + tool, + } + } +} + +impl AnyAgentTool for ContextServerTool { + fn name(&self) -> SharedString { + self.tool.name.clone().into() + } + + fn description(&self) -> SharedString { + self.tool.description.clone().unwrap_or_default().into() + } + + fn kind(&self) -> ToolKind { + ToolKind::Other + } + + fn initial_title(&self, _input: serde_json::Value) -> SharedString { + format!("Run MCP tool `{}`", self.tool.name).into() + } + + fn input_schema( + &self, + format: language_model::LanguageModelToolSchemaFormat, + ) -> Result { + let mut schema = self.tool.input_schema.clone(); + assistant_tool::adapt_schema_to_format(&mut schema, format)?; + Ok(match schema { + serde_json::Value::Null => { + serde_json::json!({ "type": "object", "properties": [] }) + } + serde_json::Value::Object(map) if map.is_empty() => { + serde_json::json!({ "type": "object", "properties": [] }) + } + _ => schema, + }) + } + + fn run( + self: Arc, + input: serde_json::Value, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else { + return Task::ready(Err(anyhow!("Context server not found"))); + }; + let tool_name = self.tool.name.clone(); + let server_clone = server.clone(); + let input_clone = input.clone(); + + cx.spawn(async move |_cx| { + let Some(protocol) = server_clone.client() else { + bail!("Context server not initialized"); + }; + + let arguments = if let serde_json::Value::Object(map) = input_clone { + Some(map.into_iter().collect()) + } else { + None + }; + + log::trace!( + "Running tool: {} with arguments: {:?}", + tool_name, + arguments + ); + let response = protocol + .request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ) + .await?; + + let mut result = String::new(); + for content in response.content { + match content { + context_server::types::ToolResponseContent::Text { text } => { + result.push_str(&text); + } + context_server::types::ToolResponseContent::Image { .. } => { + log::warn!("Ignoring image content from tool response"); + } + context_server::types::ToolResponseContent::Audio { .. } => { + log::warn!("Ignoring audio content from tool response"); + } + context_server::types::ToolResponseContent::Resource { .. } => { + log::warn!("Ignoring resource content from tool response"); + } + } + } + Ok(AgentToolOutput { + raw_output: result.clone().into(), + llm_output: result.into(), + }) + }) + } +} diff --git a/crates/agent2/src/tools/diagnostics_tool.rs b/crates/agent2/src/tools/diagnostics_tool.rs index bd0b20df5a..6ba8b7b377 100644 --- a/crates/agent2/src/tools/diagnostics_tool.rs +++ b/crates/agent2/src/tools/diagnostics_tool.rs @@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { match input.path { @@ -119,11 +119,6 @@ impl AgentTool for DiagnosticsTool { range.start.row + 1, entry.diagnostic.message )?; - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); } if output.is_empty() { @@ -158,18 +153,9 @@ impl AgentTool for DiagnosticsTool { } if has_diagnostics { - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); Task::ready(Ok(output)) } else { - let text = "No errors or warnings found in the project."; - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![text.into()]), - ..Default::default() - }); - Task::ready(Ok(text.into())) + Task::ready(Ok("No errors or warnings found in the project.".into())) } } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 88764d1953..134bc5e5e4 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -454,9 +454,8 @@ fn resolve_path( #[cfg(test)] mod tests { - use crate::Templates; - use super::*; + use crate::{ContextServerRegistry, Templates}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -475,9 +474,20 @@ mod tests { fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = - cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model)); + let thread = cx.new(|cx| { + Thread::new( + project, + Rc::default(), + context_server_registry, + action_log, + Templates::new(), + model, + cx, + ) + }); let result = cx .update(|cx| { let input = EditFileToolInput { @@ -661,14 +671,18 @@ mod tests { }); let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); @@ -792,15 +806,19 @@ mod tests { .unwrap(); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); @@ -914,15 +932,19 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1041,15 +1063,19 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project, Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1148,14 +1174,18 @@ mod tests { .await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1225,14 +1255,18 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1305,14 +1339,18 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry.clone(), action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); @@ -1382,14 +1420,18 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|_| { + let thread = cx.new(|cx| { Thread::new( project.clone(), Rc::default(), + context_server_registry, action_log.clone(), Templates::new(), model.clone(), + cx, ) }); let tool = Arc::new(EditFileTool { thread }); diff --git a/crates/agent2/src/tools/fetch_tool.rs b/crates/agent2/src/tools/fetch_tool.rs index 7f3752843c..ae26c5fe19 100644 --- a/crates/agent2/src/tools/fetch_tool.rs +++ b/crates/agent2/src/tools/fetch_tool.rs @@ -136,7 +136,7 @@ impl AgentTool for FetchTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let text = cx.background_spawn({ @@ -149,12 +149,6 @@ impl AgentTool for FetchTool { if text.trim().is_empty() { bail!("no textual content found"); } - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![text.clone().into()]), - ..Default::default() - }); - Ok(text) }) } diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs index 611d34e701..552de144a7 100644 --- a/crates/agent2/src/tools/find_path_tool.rs +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -139,9 +139,6 @@ impl AgentTool for FindPathTool { }) .collect(), ), - raw_output: Some(serde_json::json!({ - "paths": &matches, - })), ..Default::default() }); diff --git a/crates/agent2/src/tools/grep_tool.rs b/crates/agent2/src/tools/grep_tool.rs index 3266cb5734..e5d92b3c1d 100644 --- a/crates/agent2/src/tools/grep_tool.rs +++ b/crates/agent2/src/tools/grep_tool.rs @@ -101,7 +101,7 @@ impl AgentTool for GrepTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { const CONTEXT_LINES: u32 = 2; @@ -282,33 +282,22 @@ impl AgentTool for GrepTool { } } - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); matches_found += 1; } } - let output = if matches_found == 0 { - "No matches found".to_string() + if matches_found == 0 { + Ok("No matches found".into()) } else if has_more_matches { - format!( + Ok(format!( "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}", input.offset + 1, input.offset + matches_found, input.offset + RESULTS_PER_PAGE, - ) + )) } else { - format!("Found {matches_found} matches:\n{output}") - }; - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![output.clone().into()]), - ..Default::default() - }); - - Ok(output) + Ok(format!("Found {matches_found} matches:\n{output}")) + } }) } } diff --git a/crates/agent2/src/tools/now_tool.rs b/crates/agent2/src/tools/now_tool.rs index 71698b8275..a72ede26fe 100644 --- a/crates/agent2/src/tools/now_tool.rs +++ b/crates/agent2/src/tools/now_tool.rs @@ -47,20 +47,13 @@ impl AgentTool for NowTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, _cx: &mut App, ) -> Task> { let now = match input.timezone { Timezone::Utc => Utc::now().to_rfc3339(), Timezone::Local => Local::now().to_rfc3339(), }; - let content = format!("The current datetime is {now}."); - - event_stream.update_fields(acp::ToolCallUpdateFields { - content: Some(vec![content.clone().into()]), - ..Default::default() - }); - - Task::ready(Ok(content)) + Task::ready(Ok(format!("The current datetime is {now}."))) } } diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index a6b8633b34..402cf81678 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -48,6 +48,20 @@ pub struct AgentProfileSettings { pub context_servers: IndexMap, ContextServerPreset>, } +impl AgentProfileSettings { + pub fn is_tool_enabled(&self, tool_name: &str) -> bool { + self.tools.get(tool_name) == Some(&true) + } + + pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool { + self.enable_all_context_servers + || self + .context_servers + .get(server_id) + .map_or(false, |preset| preset.tools.get(tool_name) == Some(&true)) + } +} + #[derive(Debug, Clone, Default)] pub struct ContextServerPreset { pub tools: IndexMap, bool>, diff --git a/crates/gpui_macros/src/test.rs b/crates/gpui_macros/src/test.rs index 2c52149897..adb27f42ea 100644 --- a/crates/gpui_macros/src/test.rs +++ b/crates/gpui_macros/src/test.rs @@ -167,6 +167,7 @@ fn generate_test_function( )); cx_teardowns.extend(quote!( dispatcher.run_until_parked(); + #cx_varname.executor().forbid_parking(); #cx_varname.quit(); dispatcher.run_until_parked(); )); @@ -232,7 +233,7 @@ fn generate_test_function( cx_teardowns.extend(quote!( drop(#cx_varname_lock); dispatcher.run_until_parked(); - #cx_varname.update(|cx| { cx.quit() }); + #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); }); dispatcher.run_until_parked(); )); continue; @@ -247,6 +248,7 @@ fn generate_test_function( )); cx_teardowns.extend(quote!( dispatcher.run_until_parked(); + #cx_varname.executor().forbid_parking(); #cx_varname.quit(); dispatcher.run_until_parked(); )); diff --git a/crates/languages/src/go.rs b/crates/languages/src/go.rs index 16c1b67203..14f646133b 100644 --- a/crates/languages/src/go.rs +++ b/crates/languages/src/go.rs @@ -487,6 +487,8 @@ const GO_MODULE_ROOT_TASK_VARIABLE: VariableName = VariableName::Custom(Cow::Borrowed("GO_MODULE_ROOT")); const GO_SUBTEST_NAME_TASK_VARIABLE: VariableName = VariableName::Custom(Cow::Borrowed("GO_SUBTEST_NAME")); +const GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE: VariableName = + VariableName::Custom(Cow::Borrowed("GO_TABLE_TEST_CASE_NAME")); impl ContextProvider for GoContextProvider { fn build_context( @@ -545,10 +547,19 @@ impl ContextProvider for GoContextProvider { let go_subtest_variable = extract_subtest_name(_subtest_name.unwrap_or("")) .map(|subtest_name| (GO_SUBTEST_NAME_TASK_VARIABLE.clone(), subtest_name)); + let table_test_case_name = variables.get(&VariableName::Custom(Cow::Borrowed( + "_table_test_case_name", + ))); + + let go_table_test_case_variable = table_test_case_name + .and_then(extract_subtest_name) + .map(|case_name| (GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.clone(), case_name)); + Task::ready(Ok(TaskVariables::from_iter( [ go_package_variable, go_subtest_variable, + go_table_test_case_variable, go_module_root_variable, ] .into_iter() @@ -570,6 +581,28 @@ impl ContextProvider for GoContextProvider { let module_cwd = Some(GO_MODULE_ROOT_TASK_VARIABLE.template_value()); Task::ready(Some(TaskTemplates(vec![ + TaskTemplate { + label: format!( + "go test {} -v -run {}/{}", + GO_PACKAGE_TASK_VARIABLE.template_value(), + VariableName::Symbol.template_value(), + GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(), + ), + command: "go".into(), + args: vec![ + "test".into(), + "-v".into(), + "-run".into(), + format!( + "\\^{}\\$/\\^{}\\$", + VariableName::Symbol.template_value(), + GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(), + ), + ], + cwd: package_cwd.clone(), + tags: vec!["go-table-test-case".to_owned()], + ..TaskTemplate::default() + }, TaskTemplate { label: format!( "go test {} -run {}", @@ -842,10 +875,21 @@ mod tests { .collect() }); + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + assert!( - runnables.len() == 2, - "Should find test function and subtest with double quotes, found: {}", - runnables.len() + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-subtest".to_string()), + "Should find go-subtest tag, found: {:?}", + tag_strings ); let buffer = cx.new(|cx| { @@ -860,10 +904,299 @@ mod tests { .collect() }); + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + assert!( - runnables.len() == 2, - "Should find test function and subtest with backticks, found: {}", - runnables.len() + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-subtest".to_string()), + "Should find go-subtest tag, found: {:?}", + tag_strings + ); + } + + #[gpui::test] + fn test_go_table_test_slice_detection(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + import "testing" + + func TestExample(t *testing.T) { + _ = "some random string" + + testCases := []struct{ + name string + anotherStr string + }{ + { + name: "test case 1", + anotherStr: "foo", + }, + { + name: "test case 2", + anotherStr: "bar", + }, + } + + notATableTest := []struct{ + name string + }{ + { + name: "some string", + }, + { + name: "some other string", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // test code here + }) + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings + ); + + let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count(); + let go_table_test_count = tag_strings + .iter() + .filter(|&tag| tag == "go-table-test-case") + .count(); + + assert!( + go_test_count == 1, + "Should find exactly 1 go-test, found: {}", + go_test_count + ); + assert!( + go_table_test_count == 2, + "Should find exactly 2 go-table-test-case, found: {}", + go_table_test_count + ); + } + + #[gpui::test] + fn test_go_table_test_slice_ignored(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + func Example() { + _ = "some random string" + + notATableTest := []struct{ + name string + }{ + { + name: "some string", + }, + { + name: "some other string", + }, + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + !tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + !tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings + ); + } + + #[gpui::test] + fn test_go_table_test_map_detection(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + import "testing" + + func TestExample(t *testing.T) { + _ = "some random string" + + testCases := map[string]struct { + someStr string + fail bool + }{ + "test failure": { + someStr: "foo", + fail: true, + }, + "test success": { + someStr: "bar", + fail: false, + }, + } + + notATableTest := map[string]struct { + someStr string + }{ + "some string": { + someStr: "foo", + }, + "some other string": { + someStr: "bar", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // test code here + }) + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings + ); + + let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count(); + let go_table_test_count = tag_strings + .iter() + .filter(|&tag| tag == "go-table-test-case") + .count(); + + assert!( + go_test_count == 1, + "Should find exactly 1 go-test, found: {}", + go_test_count + ); + assert!( + go_table_test_count == 2, + "Should find exactly 2 go-table-test-case, found: {}", + go_table_test_count + ); + } + + #[gpui::test] + fn test_go_table_test_map_ignored(cx: &mut TestAppContext) { + let language = language("go", tree_sitter_go::LANGUAGE.into()); + + let table_test = r#" + package main + + func Example() { + _ = "some random string" + + notATableTest := map[string]struct { + someStr string + }{ + "some string": { + someStr: "foo", + }, + "some other string": { + someStr: "bar", + }, + } + } + "#; + + let buffer = + cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx)); + cx.executor().run_until_parked(); + + let runnables: Vec<_> = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot.runnable_ranges(0..table_test.len()).collect() + }); + + let tag_strings: Vec = runnables + .iter() + .flat_map(|r| &r.runnable.tags) + .map(|tag| tag.0.to_string()) + .collect(); + + assert!( + !tag_strings.contains(&"go-test".to_string()), + "Should find go-test tag, found: {:?}", + tag_strings + ); + assert!( + !tag_strings.contains(&"go-table-test-case".to_string()), + "Should find go-table-test-case tag, found: {:?}", + tag_strings ); } diff --git a/crates/languages/src/go/runnables.scm b/crates/languages/src/go/runnables.scm index 6418cd04d8..f56262f799 100644 --- a/crates/languages/src/go/runnables.scm +++ b/crates/languages/src/go/runnables.scm @@ -91,3 +91,103 @@ ) @_ (#set! tag go-main) ) + +; Table test cases - slice and map +( + (short_var_declaration + left: (expression_list (identifier) @_collection_var) + right: (expression_list + (composite_literal + type: [ + (slice_type) + (map_type + key: (type_identifier) @_key_type + (#eq? @_key_type "string") + ) + ] + body: (literal_value + [ + (literal_element + (literal_value + (keyed_element + (literal_element + (identifier) @_field_name + ) + (literal_element + [ + (interpreted_string_literal) @run @_table_test_case_name + (raw_string_literal) @run @_table_test_case_name + ] + ) + ) + ) + ) + (keyed_element + (literal_element + [ + (interpreted_string_literal) @run @_table_test_case_name + (raw_string_literal) @run @_table_test_case_name + ] + ) + ) + ] + ) + ) + ) + ) + (for_statement + (range_clause + left: (expression_list + [ + ( + (identifier) + (identifier) @_loop_var + ) + (identifier) @_loop_var + ] + ) + right: (identifier) @_range_var + (#eq? @_range_var @_collection_var) + ) + body: (block + (expression_statement + (call_expression + function: (selector_expression + operand: (identifier) @_t_var + field: (field_identifier) @_run_method + (#eq? @_run_method "Run") + ) + arguments: (argument_list + . + [ + (selector_expression + operand: (identifier) @_tc_var + (#eq? @_tc_var @_loop_var) + field: (field_identifier) @_field_check + (#eq? @_field_check @_field_name) + ) + (identifier) @_arg_var + (#eq? @_arg_var @_loop_var) + ] + . + (func_literal + parameters: (parameter_list + (parameter_declaration + type: (pointer_type + (qualified_type + package: (package_identifier) @_pkg + name: (type_identifier) @_type + (#eq? @_pkg "testing") + (#eq? @_type "T") + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) @_ + (#set! tag go-table-test-case) +) diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index de6544f5a2..827341d60d 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -3367,20 +3367,6 @@ impl LocalLspStore { } } -fn parse_register_capabilities( - reg: lsp::Registration, -) -> anyhow::Result> { - let caps = match reg - .register_options - .map(|options| serde_json::from_value::(options)) - .transpose()? - { - None => OneOf::Left(true), - Some(options) => OneOf::Right(options), - }; - Ok(caps) -} - fn notify_server_capabilities_updated(server: &LanguageServer, cx: &mut Context) { if let Some(capabilities) = serde_json::to_string(&server.capabilities()).ok() { cx.emit(LspStoreEvent::LanguageServerUpdate { @@ -11690,190 +11676,190 @@ impl LspStore { // Ignore payload since we notify clients of setting changes unconditionally, relying on them pulling the latest settings. } "workspace/symbol" => { - let options = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.workspace_symbol_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.workspace_symbol_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "workspace/fileOperations" => { - let caps = reg - .register_options - .map(serde_json::from_value) - .transpose()? - .unwrap_or_default(); - server.update_capabilities(|capabilities| { - capabilities - .workspace - .get_or_insert_default() - .file_operations = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = reg.register_options { + let caps = serde_json::from_value(options)?; + server.update_capabilities(|capabilities| { + capabilities + .workspace + .get_or_insert_default() + .file_operations = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } } "workspace/executeCommand" => { - let options = reg - .register_options - .map(serde_json::from_value) - .transpose()? - .unwrap_or_default(); - server.update_capabilities(|capabilities| { - capabilities.execute_command_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = reg.register_options { + let options = serde_json::from_value(options)?; + server.update_capabilities(|capabilities| { + capabilities.execute_command_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/rangeFormatting" => { - let options = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.document_range_formatting_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.document_range_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/onTypeFormatting" => { - let options = reg + if let Some(options) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_default(); - server.update_capabilities(|capabilities| { - capabilities.document_on_type_formatting_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + { + server.update_capabilities(|capabilities| { + capabilities.document_on_type_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/formatting" => { - let options = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.document_formatting_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.document_formatting_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/rename" => { - let options = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.rename_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.rename_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/inlayHint" => { - let options = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.inlay_hint_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.inlay_hint_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/documentSymbol" => { - let options = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.document_symbol_provider = Some(options); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.document_symbol_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/codeAction" => { - let options = reg + if let Some(options) = reg .register_options .map(serde_json::from_value) - .transpose()?; - let provider_capability = match options { - None => lsp::CodeActionProviderCapability::Simple(true), - Some(options) => lsp::CodeActionProviderCapability::Options(options), - }; - server.update_capabilities(|capabilities| { - capabilities.code_action_provider = Some(provider_capability); - }); - notify_server_capabilities_updated(&server, cx); + .transpose()? + { + server.update_capabilities(|capabilities| { + capabilities.code_action_provider = + Some(lsp::CodeActionProviderCapability::Options(options)); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/definition" => { - let caps = parse_register_capabilities(reg)?; - server.update_capabilities(|capabilities| { - capabilities.definition_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + if let Some(options) = parse_register_capabilities(reg)? { + server.update_capabilities(|capabilities| { + capabilities.definition_provider = Some(options); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/completion" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_default(); - server.update_capabilities(|capabilities| { - capabilities.completion_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + { + server.update_capabilities(|capabilities| { + capabilities.completion_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/hover" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_else(|| lsp::HoverProviderCapability::Simple(true)); - server.update_capabilities(|capabilities| { - capabilities.hover_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + { + server.update_capabilities(|capabilities| { + capabilities.hover_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/signatureHelp" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_default(); - server.update_capabilities(|capabilities| { - capabilities.signature_help_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + { + server.update_capabilities(|capabilities| { + capabilities.signature_help_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/synchronization" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_else(|| { - lsp::TextDocumentSyncCapability::Options( - lsp::TextDocumentSyncOptions::default(), - ) + { + server.update_capabilities(|capabilities| { + capabilities.text_document_sync = Some(caps); }); - server.update_capabilities(|capabilities| { - capabilities.text_document_sync = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/codeLens" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_else(|| lsp::CodeLensOptions { - resolve_provider: None, + { + server.update_capabilities(|capabilities| { + capabilities.code_lens_provider = Some(caps); }); - server.update_capabilities(|capabilities| { - capabilities.code_lens_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/diagnostic" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_else(|| { - lsp::DiagnosticServerCapabilities::RegistrationOptions( - lsp::DiagnosticRegistrationOptions::default(), - ) + { + server.update_capabilities(|capabilities| { + capabilities.diagnostic_provider = Some(caps); }); - server.update_capabilities(|capabilities| { - capabilities.diagnostic_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + notify_server_capabilities_updated(&server, cx); + } } "textDocument/colorProvider" => { - let caps = reg + if let Some(caps) = reg .register_options .map(serde_json::from_value) .transpose()? - .unwrap_or_else(|| lsp::ColorProviderCapability::Simple(true)); - server.update_capabilities(|capabilities| { - capabilities.color_provider = Some(caps); - }); - notify_server_capabilities_updated(&server, cx); + { + server.update_capabilities(|capabilities| { + capabilities.color_provider = Some(caps); + }); + notify_server_capabilities_updated(&server, cx); + } } _ => log::warn!("unhandled capability registration: {reg:?}"), } @@ -12016,6 +12002,18 @@ impl LspStore { } } +// Registration with empty capabilities should be ignored. +// https://github.com/microsoft/vscode-languageserver-node/blob/d90a87f9557a0df9142cfb33e251cfa6fe27d970/client/src/common/formatting.ts#L67-L70 +fn parse_register_capabilities( + reg: lsp::Registration, +) -> anyhow::Result>> { + Ok(reg + .register_options + .map(|options| serde_json::from_value::(options)) + .transpose()? + .map(OneOf::Right)) +} + fn subscribe_to_binary_statuses( languages: &Arc, cx: &mut Context<'_, LspStore>, diff --git a/crates/vim/Cargo.toml b/crates/vim/Cargo.toml index 9fb5c46564..434b14b07c 100644 --- a/crates/vim/Cargo.toml +++ b/crates/vim/Cargo.toml @@ -24,6 +24,7 @@ command_palette.workspace = true command_palette_hooks.workspace = true db.workspace = true editor.workspace = true +env_logger.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true diff --git a/crates/vim/src/insert.rs b/crates/vim/src/insert.rs index 584057a8c0..8ef1cd7811 100644 --- a/crates/vim/src/insert.rs +++ b/crates/vim/src/insert.rs @@ -3,7 +3,9 @@ use editor::{Bias, Editor}; use gpui::{Action, Context, Window, actions}; use language::SelectionGoal; use settings::Settings; +use text::Point; use vim_mode_setting::HelixModeSetting; +use workspace::searchable::Direction; actions!( vim, @@ -11,13 +13,23 @@ actions!( /// Switches to normal mode with cursor positioned before the current character. NormalBefore, /// Temporarily switches to normal mode for one command. - TemporaryNormal + TemporaryNormal, + /// Inserts the next character from the line above into the current line. + InsertFromAbove, + /// Inserts the next character from the line below into the current line. + InsertFromBelow ] ); pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, Vim::normal_before); Vim::action(editor, cx, Vim::temporary_normal); + Vim::action(editor, cx, |vim, _: &InsertFromAbove, window, cx| { + vim.insert_around(Direction::Prev, window, cx) + }); + Vim::action(editor, cx, |vim, _: &InsertFromBelow, window, cx| { + vim.insert_around(Direction::Next, window, cx) + }) } impl Vim { @@ -71,6 +83,29 @@ impl Vim { self.switch_mode(Mode::Normal, true, window, cx); self.temp_mode = true; } + + fn insert_around(&mut self, direction: Direction, _: &mut Window, cx: &mut Context) { + self.update_editor(cx, |_, editor, cx| { + let snapshot = editor.buffer().read(cx).snapshot(cx); + let mut edits = Vec::new(); + for selection in editor.selections.all::(cx) { + let point = selection.head(); + let new_row = match direction { + Direction::Next => point.row + 1, + Direction::Prev if point.row > 0 => point.row - 1, + _ => continue, + }; + let source = snapshot.clip_point(Point::new(new_row, point.column), Bias::Left); + if let Some(c) = snapshot.chars_at(source).next() + && c != '\n' + { + edits.push((point..point, c.to_string())) + } + } + + editor.edit(edits, cx); + }); + } } #[cfg(test)] @@ -156,4 +191,13 @@ mod test { .await; cx.shared_state().await.assert_eq("hehello\nˇllo\n"); } + + #[gpui::test] + async fn test_insert_ctrl_y(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state("hello\nˇ\nworld").await; + cx.simulate_shared_keystrokes("i ctrl-y ctrl-e").await; + cx.shared_state().await.assert_eq("hello\nhoˇ\nworld"); + } } diff --git a/crates/vim/src/test/vim_test_context.rs b/crates/vim/src/test/vim_test_context.rs index b8988b1d1f..904e48e5a3 100644 --- a/crates/vim/src/test/vim_test_context.rs +++ b/crates/vim/src/test/vim_test_context.rs @@ -15,6 +15,7 @@ impl VimTestContext { if cx.has_global::() { return; } + env_logger::try_init().ok(); cx.update(|cx| { let settings = SettingsStore::test(cx); cx.set_global(settings); diff --git a/crates/vim/test_data/test_insert_ctrl_y.json b/crates/vim/test_data/test_insert_ctrl_y.json new file mode 100644 index 0000000000..09b707a198 --- /dev/null +++ b/crates/vim/test_data/test_insert_ctrl_y.json @@ -0,0 +1,5 @@ +{"Put":{"state":"hello\nˇ\nworld"}} +{"Key":"i"} +{"Key":"ctrl-y"} +{"Key":"ctrl-e"} +{"Get":{"state":"hello\nhoˇ\nworld","mode":"Insert"}} diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 6fa5c969e7..b2d1340a7b 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -542,6 +542,20 @@ define_connection! { ALTER TABLE breakpoints ADD COLUMN condition TEXT; ALTER TABLE breakpoints ADD COLUMN hit_condition TEXT; ), + sql!(CREATE TABLE toolchains2 ( + workspace_id INTEGER, + worktree_id INTEGER, + language_name TEXT NOT NULL, + name TEXT NOT NULL, + path TEXT NOT NULL, + raw_json TEXT NOT NULL, + relative_worktree_path TEXT NOT NULL, + PRIMARY KEY (workspace_id, worktree_id, language_name, relative_worktree_path)) STRICT; + INSERT INTO toolchains2 + SELECT * FROM toolchains; + DROP TABLE toolchains; + ALTER TABLE toolchains2 RENAME TO toolchains; + ) ]; } @@ -1428,12 +1442,12 @@ impl WorkspaceDb { self.write(move |conn| { let mut insert = conn .exec_bound(sql!( - INSERT INTO toolchains(workspace_id, worktree_id, relative_worktree_path, language_name, name, path) VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO toolchains(workspace_id, worktree_id, relative_worktree_path, language_name, name, path, raw_json) VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT DO UPDATE SET name = ?5, - path = ?6 - + path = ?6, + raw_json = ?7 )) .context("Preparing insertion")?; @@ -1444,6 +1458,7 @@ impl WorkspaceDb { toolchain.language_name.as_ref(), toolchain.name.as_ref(), toolchain.path.as_ref(), + toolchain.as_json.to_string(), ))?; Ok(()) diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 9f1d02b790..ee76308ff3 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -26,6 +26,7 @@ collections.workspace = true command_palette_hooks.workspace = true copilot.workspace = true db.workspace = true +edit_prediction.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true @@ -33,13 +34,13 @@ futures.workspace = true gpui.workspace = true http_client.workspace = true indoc.workspace = true -edit_prediction.workspace = true language.workspace = true language_model.workspace = true log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true +rand.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 6900082003..1a6a8c2934 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -429,6 +429,7 @@ impl Zeta { body, editable_range, } = gather_task.await?; + let done_gathering_context_at = Instant::now(); log::debug!( "Events:\n{}\nExcerpt:\n{:?}", @@ -481,6 +482,7 @@ impl Zeta { } }; + let received_response_at = Instant::now(); log::debug!("completion response: {}", &response.output_excerpt); if let Some(usage) = usage { @@ -492,7 +494,7 @@ impl Zeta { .ok(); } - Self::process_completion_response( + let edit_prediction = Self::process_completion_response( response, buffer, &snapshot, @@ -505,7 +507,25 @@ impl Zeta { buffer_snapshotted_at, &cx, ) - .await + .await; + + let finished_at = Instant::now(); + + // record latency for ~1% of requests + if rand::random::() <= 2 { + telemetry::event!( + "Edit Prediction Request", + context_latency = done_gathering_context_at + .duration_since(buffer_snapshotted_at) + .as_millis(), + request_latency = received_response_at + .duration_since(done_gathering_context_at) + .as_millis(), + process_latency = finished_at.duration_since(received_response_at).as_millis() + ); + } + + edit_prediction }) } diff --git a/docs/src/languages/rust.md b/docs/src/languages/rust.md index 1ee25a37b5..7695280275 100644 --- a/docs/src/languages/rust.md +++ b/docs/src/languages/rust.md @@ -326,7 +326,7 @@ When you use `cargo build` or `cargo test` as the build command, Zed can infer t [ { "label": "Build & Debug native binary", - "adapter": "CodeLLDB" + "adapter": "CodeLLDB", "build": { "command": "cargo", "args": ["build"]