diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index e132acaf0b..73aea45340 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -9,10 +9,11 @@ use axum::{ routing::get, Extension, Json, Router, }; -use collections::HashMap; -use rpc::{ExtensionApiManifest, GetExtensionsResponse}; +use collections::{BTreeSet, HashMap}; +use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse}; use semantic_version::SemanticVersion; use serde::Deserialize; +use std::str::FromStr; use std::{sync::Arc, time::Duration}; use time::PrimitiveDateTime; use util::{maybe, ResultExt}; @@ -35,6 +36,14 @@ pub fn router() -> Router { #[derive(Debug, Deserialize)] struct GetExtensionsParams { filter: Option, + /// A comma-delimited list of features that the extension must provide. + /// + /// For example: + /// - `themes` + /// - `themes,icon-themes` + /// - `languages,language-servers` + #[serde(default)] + provides: Option, #[serde(default)] max_schema_version: i32, } @@ -43,9 +52,22 @@ async fn get_extensions( Extension(app): Extension>, Query(params): Query, ) -> Result> { + let provides_filter = params.provides.map(|provides| { + provides + .split(',') + .map(|value| value.trim()) + .filter_map(|value| ExtensionProvides::from_str(value).ok()) + .collect::>() + }); + let mut extensions = app .db - .get_extensions(params.filter.as_deref(), params.max_schema_version, 500) + .get_extensions( + params.filter.as_deref(), + provides_filter.as_ref(), + params.max_schema_version, + 500, + ) .await?; if let Some(filter) = params.filter.as_deref() { diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index 54f47ae45e..2b76e12335 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -10,6 +10,7 @@ impl Database { pub async fn get_extensions( &self, filter: Option<&str>, + provides_filter: Option<&BTreeSet>, max_schema_version: i32, limit: usize, ) -> Result> { @@ -26,6 +27,10 @@ impl Database { condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter)); } + if let Some(provides_filter) = provides_filter { + condition = apply_provides_filter(condition, provides_filter); + } + self.get_extensions_where(condition, Some(limit as u64), &tx) .await }) @@ -385,6 +390,49 @@ impl Database { } } +fn apply_provides_filter( + mut condition: Condition, + provides_filter: &BTreeSet, +) -> Condition { + if provides_filter.contains(&ExtensionProvides::Themes) { + condition = condition.add(extension_version::Column::ProvidesThemes.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::IconThemes) { + condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Languages) { + condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Grammars) { + condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::LanguageServers) { + condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::ContextServers) { + condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::SlashCommands) { + condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) { + condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true)); + } + + if provides_filter.contains(&ExtensionProvides::Snippets) { + condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true)); + } + + condition +} + fn metadata_from_extension_and_version( extension: extension::Model, version: extension_version::Model, diff --git a/crates/collab/src/db/tests/extension_tests.rs b/crates/collab/src/db/tests/extension_tests.rs index f7a5398d3c..460d74ffc0 100644 --- a/crates/collab/src/db/tests/extension_tests.rs +++ b/crates/collab/src/db/tests/extension_tests.rs @@ -20,7 +20,7 @@ async fn test_extensions(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert!(extensions.is_empty()); let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); @@ -90,7 +90,7 @@ async fn test_extensions(db: &Arc) { ); // The latest version of each extension is returned. - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -128,7 +128,7 @@ async fn test_extensions(db: &Arc) { ); // Extensions with too new of a schema version are excluded. - let extensions = db.get_extensions(None, 0, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 0, 5).await.unwrap(); assert_eq!( extensions, &[ExtensionMetadata { @@ -168,7 +168,7 @@ async fn test_extensions(db: &Arc) { .unwrap()); // Extensions are returned in descending order of total downloads. - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -258,7 +258,7 @@ async fn test_extensions(db: &Arc) { .collect() ); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert_eq!( extensions, &[ @@ -306,7 +306,7 @@ async fn test_extensions_by_id(db: &Arc) { let versions = db.get_known_extension_versions().await.unwrap(); assert!(versions.is_empty()); - let extensions = db.get_extensions(None, 1, 5).await.unwrap(); + let extensions = db.get_extensions(None, None, 1, 5).await.unwrap(); assert!(extensions.is_empty()); let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap(); diff --git a/crates/rpc/src/extension.rs b/crates/rpc/src/extension.rs index 67b9116b83..f1dcdc28d6 100644 --- a/crates/rpc/src/extension.rs +++ b/crates/rpc/src/extension.rs @@ -1,8 +1,9 @@ use std::collections::BTreeSet; +use std::sync::Arc; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; +use strum::EnumString; #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] pub struct ExtensionApiManifest { @@ -17,8 +18,11 @@ pub struct ExtensionApiManifest { pub provides: BTreeSet, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumString, +)] #[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] pub enum ExtensionProvides { Themes, IconThemes,