From e1919b41215f2db23c1ff79edae435191f3fd510 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 5 Feb 2025 17:12:18 -0500 Subject: [PATCH] collab: Add the ability to filter extensions by what they provide (#24315) This PR adds the ability to filter extension results from the extension API by the features that they provide. For instance, to filter down just to extensions that provide icon themes: ``` https://api.zed.dev/extensions?provides=icon-themes ``` Release Notes: - N/A --- crates/collab/src/api/extensions.rs | 28 +++++++++-- crates/collab/src/db/queries/extensions.rs | 48 +++++++++++++++++++ crates/collab/src/db/tests/extension_tests.rs | 12 ++--- crates/rpc/src/extension.rs | 8 +++- 4 files changed, 85 insertions(+), 11 deletions(-) diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index e132acaf0b..73aea45340 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -9,10 +9,11 @@ use axum::{ routing::get, Extension, Json, Router, }; -use collections::HashMap; -use rpc::{ExtensionApiManifest, GetExtensionsResponse}; +use collections::{BTreeSet, HashMap}; +use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse}; use semantic_version::SemanticVersion; use serde::Deserialize; +use std::str::FromStr; use std::{sync::Arc, time::Duration}; use time::PrimitiveDateTime; use util::{maybe, ResultExt}; @@ -35,6 +36,14 @@ pub fn router() -> Router { #[derive(Debug, Deserialize)] struct GetExtensionsParams { filter: Option, + /// A comma-delimited list of features that the extension must provide. + /// + /// For example: + /// - `themes` + /// - `themes,icon-themes` + /// - `languages,language-servers` + #[serde(default)] + provides: Option, #[serde(default)] max_schema_version: i32, } @@ -43,9 +52,22 @@ async fn get_extensions( Extension(app): Extension>, Query(params): Query, ) -> Result> { + let provides_filter = params.provides.map(|provides| { + provides + .split(',') + .map(|value| value.trim()) + .filter_map(|value| ExtensionProvides::from_str(value).ok()) + .collect::>() + }); + let mut extensions = app .db - .get_extensions(params.filter.as_deref(), params.max_schema_version, 500) + .get_extensions( + params.filter.as_deref(), + provides_filter.as_ref(), + params.max_schema_version, + 500, + ) .await?; if let Some(filter) = params.filter.as_deref() { diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index 54f47ae45e..2b76e12335 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -10,6 +10,7 @@ impl Database { pub async fn get_extensions( &self, filter: Option<&str>, + provides_filter: Option<&BTreeSet>, max_schema_version: i32, limit: usize, ) -> Result> { @@ -26,6 +27,10 @@ impl Database { condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter)); } + if let Some(provides_filter) = provides_filter { + condition = apply_provides_filter(condition, provides_filter); + } + self.get_extensions_where(condition, Some(limit as u64), &tx) .await }) @@ -385,6 +390,49 @@ impl Database { } } +fn apply_provides_filter( + mut condition: Condition, + provides_filter: &BTreeSet, +) -> Condition { + if provides_filter.contains(&ExtensionProvides::Themes) { + condition = condition.add(extension_version::Column::ProvidesThemes.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::IconThemes) { + condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Languages) { + condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Grammars) { + condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::LanguageServers) { + condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::ContextServers) { + condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::SlashCommands) { + condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) { + condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Snippets) { + condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true)); + } + + condition +} + fn metadata_from_extension_and_version( extension: extension::Model, version: extension_version::Model, diff --git a/crates/collab/src/db/tests/extension_tests.rs b/crates/collab/src/db/tests/extension_tests.rs index f7a5398d3c..460d74ffc0 100644 --- a/crates/collab/src/db/tests/extension_tests.rs +++ b/crates/collab/src/db/tests/extension_tests.rs @@ -20,7 +20,7 @@ async fn test_extensions(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert!(extensions.is_empty()); let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); @@ -90,7 +90,7 @@ async fn test_extensions(db: &Arc) { ); // The latest version of each extension is returned. - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -128,7 +128,7 @@ async fn test_extensions(db: &Arc) { ); // Extensions with too new of a schema version are excluded. - let extensions = db.get_extensions(None, 0, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 0, 5).await.unwrap(); assert_eq!( extensions, &[ExtensionMetadata { @@ -168,7 +168,7 @@ async fn test_extensions(db: &Arc) { .unwrap()); // Extensions are returned in descending order of total downloads. - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -258,7 +258,7 @@ async fn test_extensions(db: &Arc) { .collect() ); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -306,7 +306,7 @@ async fn test_extensions_by_id(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert!(extensions.is_empty()); let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); diff --git a/crates/rpc/src/extension.rs b/crates/rpc/src/extension.rs index 67b9116b83..f1dcdc28d6 100644 --- a/crates/rpc/src/extension.rs +++ b/crates/rpc/src/extension.rs @@ -1,8 +1,9 @@ use std::collections::BTreeSet; +use std::sync::Arc; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; +use strum::EnumString; #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] pub struct ExtensionApiManifest { @@ -17,8 +18,11 @@ pub struct ExtensionApiManifest { pub provides: BTreeSet, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumString, +)] #[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] pub enum ExtensionProvides { Themes, IconThemes,