collab: Add the ability to filter extensions by what they provide (#24315)

This PR adds the ability to filter extension results from the extension
API by the features that they provide.

For instance, to filter down just to extensions that provide icon
themes:

```
https://api.zed.dev/extensions?provides=icon-themes
```

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-05 17:12:18 -05:00 committed by GitHub
parent c0dd7e8367
commit e1919b4121
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 85 additions and 11 deletions

View file

@ -9,10 +9,11 @@ use axum::{
routing::get,
Extension, Json, Router,
};
use collections::HashMap;
use rpc::{ExtensionApiManifest, GetExtensionsResponse};
use collections::{BTreeSet, HashMap};
use rpc::{ExtensionApiManifest, ExtensionProvides, GetExtensionsResponse};
use semantic_version::SemanticVersion;
use serde::Deserialize;
use std::str::FromStr;
use std::{sync::Arc, time::Duration};
use time::PrimitiveDateTime;
use util::{maybe, ResultExt};
@ -35,6 +36,14 @@ pub fn router() -> Router {
#[derive(Debug, Deserialize)]
struct GetExtensionsParams {
filter: Option<String>,
/// A comma-delimited list of features that the extension must provide.
///
/// For example:
/// - `themes`
/// - `themes,icon-themes`
/// - `languages,language-servers`
#[serde(default)]
provides: Option<String>,
#[serde(default)]
max_schema_version: i32,
}
@ -43,9 +52,22 @@ async fn get_extensions(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionsParams>,
) -> Result<Json<GetExtensionsResponse>> {
let provides_filter = params.provides.map(|provides| {
provides
.split(',')
.map(|value| value.trim())
.filter_map(|value| ExtensionProvides::from_str(value).ok())
.collect::<BTreeSet<_>>()
});
let mut extensions = app
.db
.get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
.get_extensions(
params.filter.as_deref(),
provides_filter.as_ref(),
params.max_schema_version,
500,
)
.await?;
if let Some(filter) = params.filter.as_deref() {

View file

@ -10,6 +10,7 @@ impl Database {
pub async fn get_extensions(
&self,
filter: Option<&str>,
provides_filter: Option<&BTreeSet<ExtensionProvides>>,
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
@ -26,6 +27,10 @@ impl Database {
condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
}
if let Some(provides_filter) = provides_filter {
condition = apply_provides_filter(condition, provides_filter);
}
self.get_extensions_where(condition, Some(limit as u64), &tx)
.await
})
@ -385,6 +390,49 @@ impl Database {
}
}
fn apply_provides_filter(
mut condition: Condition,
provides_filter: &BTreeSet<ExtensionProvides>,
) -> Condition {
if provides_filter.contains(&ExtensionProvides::Themes) {
condition = condition.add(extension_version::Column::ProvidesThemes.eq(true));
}
if provides_filter.contains(&ExtensionProvides::IconThemes) {
condition = condition.add(extension_version::Column::ProvidesIconThemes.eq(true));
}
if provides_filter.contains(&ExtensionProvides::Languages) {
condition = condition.add(extension_version::Column::ProvidesLanguages.eq(true));
}
if provides_filter.contains(&ExtensionProvides::Grammars) {
condition = condition.add(extension_version::Column::ProvidesGrammars.eq(true));
}
if provides_filter.contains(&ExtensionProvides::LanguageServers) {
condition = condition.add(extension_version::Column::ProvidesLanguageServers.eq(true));
}
if provides_filter.contains(&ExtensionProvides::ContextServers) {
condition = condition.add(extension_version::Column::ProvidesContextServers.eq(true));
}
if provides_filter.contains(&ExtensionProvides::SlashCommands) {
condition = condition.add(extension_version::Column::ProvidesSlashCommands.eq(true));
}
if provides_filter.contains(&ExtensionProvides::IndexedDocsProviders) {
condition = condition.add(extension_version::Column::ProvidesIndexedDocsProviders.eq(true));
}
if provides_filter.contains(&ExtensionProvides::Snippets) {
condition = condition.add(extension_version::Column::ProvidesSnippets.eq(true));
}
condition
}
fn metadata_from_extension_and_version(
extension: extension::Model,
version: extension_version::Model,

View file

@ -20,7 +20,7 @@ async fn test_extensions(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
let extensions = db.get_extensions(None, 1, 5).await.unwrap();
let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
@ -90,7 +90,7 @@ async fn test_extensions(db: &Arc<Database>) {
);
// The latest version of each extension is returned.
let extensions = db.get_extensions(None, 1, 5).await.unwrap();
let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert_eq!(
extensions,
&[
@ -128,7 +128,7 @@ async fn test_extensions(db: &Arc<Database>) {
);
// Extensions with too new of a schema version are excluded.
let extensions = db.get_extensions(None, 0, 5).await.unwrap();
let extensions = db.get_extensions(None, None, 0, 5).await.unwrap();
assert_eq!(
extensions,
&[ExtensionMetadata {
@ -168,7 +168,7 @@ async fn test_extensions(db: &Arc<Database>) {
.unwrap());
// Extensions are returned in descending order of total downloads.
let extensions = db.get_extensions(None, 1, 5).await.unwrap();
let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert_eq!(
extensions,
&[
@ -258,7 +258,7 @@ async fn test_extensions(db: &Arc<Database>) {
.collect()
);
let extensions = db.get_extensions(None, 1, 5).await.unwrap();
let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert_eq!(
extensions,
&[
@ -306,7 +306,7 @@ async fn test_extensions_by_id(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
let extensions = db.get_extensions(None, 1, 5).await.unwrap();
let extensions = db.get_extensions(None, None, 1, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();

View file

@ -1,8 +1,9 @@
use std::collections::BTreeSet;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use strum::EnumString;
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct ExtensionApiManifest {
@ -17,8 +18,11 @@ pub struct ExtensionApiManifest {
pub provides: BTreeSet<ExtensionProvides>,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize, EnumString,
)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
pub enum ExtensionProvides {
Themes,
IconThemes,