diff --git a/crates/collab/src/tests/random_project_collaboration_tests.rs b/crates/collab/src/tests/random_project_collaboration_tests.rs index 66a9d06804..1f39190d75 100644 --- a/crates/collab/src/tests/random_project_collaboration_tests.rs +++ b/crates/collab/src/tests/random_project_collaboration_tests.rs @@ -835,7 +835,7 @@ impl RandomizedTest for ProjectCollaborationTest { .map_ok(|_| ()) .boxed(), LspRequestKind::CodeAction => project - .code_actions(&buffer, offset..offset, cx) + .code_actions(&buffer, offset..offset, None, cx) .map(|_| Ok(())) .boxed(), LspRequestKind::Definition => project diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index b31938bcfd..401462795e 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -13811,7 +13811,9 @@ impl CodeActionProvider for Model { range: Range, cx: &mut WindowContext, ) -> Task>> { - self.update(cx, |project, cx| project.code_actions(buffer, range, cx)) + self.update(cx, |project, cx| { + project.code_actions(buffer, range, None, cx) + }) } fn apply_code_action( diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index 57f8cea348..6de4902746 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -2090,19 +2090,33 @@ impl LspCommand for GetCodeActions { server_id: LanguageServerId, _: AsyncAppContext, ) -> Result> { + let requested_kinds_set = if let Some(kinds) = self.kinds { + Some(kinds.into_iter().collect::>()) + } else { + None + }; + Ok(actions .unwrap_or_default() .into_iter() .filter_map(|entry| { - if let lsp::CodeActionOrCommand::CodeAction(lsp_action) = entry { - Some(CodeAction { - server_id, - range: self.range.clone(), - lsp_action, - }) - } else { - None + let lsp::CodeActionOrCommand::CodeAction(lsp_action) = entry else { + return None; + }; + + if let Some((requested_kinds, kind)) = + requested_kinds_set.as_ref().zip(lsp_action.kind.as_ref()) + { + if !requested_kinds.contains(kind) { + return None; + } } + + Some(CodeAction { + server_id, + range: self.range.clone(), + lsp_action, + }) }) .collect()) } diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 3ed311a51d..29a4c8e71b 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -2015,6 +2015,7 @@ impl LspStore { &mut self, buffer_handle: &Model, range: Range, + kinds: Option>, cx: &mut ModelContext, ) -> Task>> { if let Some((upstream_client, project_id)) = self.upstream_client() { @@ -2028,7 +2029,7 @@ impl LspStore { request: Some(proto::multi_lsp_query::Request::GetCodeActions( GetCodeActions { range: range.clone(), - kinds: None, + kinds: kinds.clone(), } .to_proto(project_id, buffer_handle.read(cx)), )), @@ -2054,7 +2055,7 @@ impl LspStore { .map(|code_actions_response| { GetCodeActions { range: range.clone(), - kinds: None, + kinds: kinds.clone(), } .response_from_proto( code_actions_response, @@ -2079,7 +2080,7 @@ impl LspStore { Some(range.start), GetCodeActions { range: range.clone(), - kinds: None, + kinds: kinds.clone(), }, cx, ); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 61a700e5d6..40da76ff3a 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -52,8 +52,8 @@ use language::{ Transaction, Unclipped, }; use lsp::{ - CompletionContext, CompletionItemKind, DocumentHighlightKind, LanguageServer, LanguageServerId, - LanguageServerName, MessageActionItem, + CodeActionKind, CompletionContext, CompletionItemKind, DocumentHighlightKind, LanguageServer, + LanguageServerId, LanguageServerName, MessageActionItem, }; use lsp_command::*; use node_runtime::NodeRuntime; @@ -2843,12 +2843,13 @@ impl Project { &mut self, buffer_handle: &Model, range: Range, + kinds: Option>, cx: &mut ModelContext, ) -> Task>> { let buffer = buffer_handle.read(cx); let range = buffer.anchor_before(range.start)..buffer.anchor_before(range.end); self.lsp_store.update(cx, |lsp_store, cx| { - lsp_store.code_actions(buffer_handle, range, cx) + lsp_store.code_actions(buffer_handle, range, kinds, cx) }) } diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index ab00d62d6c..2704259306 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -2792,7 +2792,9 @@ async fn test_apply_code_actions_with_commands(cx: &mut gpui::TestAppContext) { let fake_server = fake_language_servers.next().await.unwrap(); // Language server returns code actions that contain commands, and not edits. - let actions = project.update(cx, |project, cx| project.code_actions(&buffer, 0..0, cx)); + let actions = project.update(cx, |project, cx| { + project.code_actions(&buffer, 0..0, None, cx) + }); fake_server .handle_request::(|_, _| async move { Ok(Some(vec![ @@ -4961,6 +4963,84 @@ async fn test_hovers_with_empty_parts(cx: &mut gpui::TestAppContext) { ); } +#[gpui::test] +async fn test_code_actions_only_kinds(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/dir", + json!({ + "a.ts": "a", + }), + ) + .await; + + let project = Project::test(fs, ["/dir".as_ref()], cx).await; + + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(typescript_lang()); + let mut fake_language_servers = language_registry.register_fake_lsp( + "TypeScript", + FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + code_action_provider: Some(lsp::CodeActionProviderCapability::Simple(true)), + ..lsp::ServerCapabilities::default() + }, + ..FakeLspAdapter::default() + }, + ); + + let buffer = project + .update(cx, |p, cx| p.open_local_buffer("/dir/a.ts", cx)) + .await + .unwrap(); + cx.executor().run_until_parked(); + + let fake_server = fake_language_servers + .next() + .await + .expect("failed to get the language server"); + + let mut request_handled = fake_server.handle_request::( + move |_, _| async move { + Ok(Some(vec![ + lsp::CodeActionOrCommand::CodeAction(lsp::CodeAction { + title: "organize imports".to_string(), + kind: Some(CodeActionKind::SOURCE_ORGANIZE_IMPORTS), + ..lsp::CodeAction::default() + }), + lsp::CodeActionOrCommand::CodeAction(lsp::CodeAction { + title: "fix code".to_string(), + kind: Some(CodeActionKind::SOURCE_FIX_ALL), + ..lsp::CodeAction::default() + }), + ])) + }, + ); + + let code_actions_task = project.update(cx, |project, cx| { + project.code_actions( + &buffer, + 0..buffer.read(cx).len(), + Some(vec![CodeActionKind::SOURCE_ORGANIZE_IMPORTS]), + cx, + ) + }); + + let () = request_handled + .next() + .await + .expect("The code action request should have been triggered"); + + let code_actions = code_actions_task.await.unwrap(); + assert_eq!(code_actions.len(), 1); + assert_eq!( + code_actions[0].lsp_action.kind, + Some(CodeActionKind::SOURCE_ORGANIZE_IMPORTS) + ); +} + #[gpui::test] async fn test_multiple_language_server_actions(cx: &mut gpui::TestAppContext) { init_test(cx); @@ -5092,7 +5172,7 @@ async fn test_multiple_language_server_actions(cx: &mut gpui::TestAppContext) { } let code_actions_task = project.update(cx, |project, cx| { - project.code_actions(&buffer, 0..buffer.read(cx).len(), cx) + project.code_actions(&buffer, 0..buffer.read(cx).len(), None, cx) }); // cx.run_until_parked();