Add an extensions API to the collaboration server (#7807)

This PR adds a REST API to the collab server for searching and
downloading extensions. Previously, we had implemented this API in
zed.dev directly, but this implementation is better, because we use the
collab database to store the download counts for extensions.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Conrad <conrad@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-02-15 12:53:57 -08:00 committed by GitHub
parent bdc2558eac
commit e1ae0d46da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 1755 additions and 174 deletions

View file

@ -85,6 +85,7 @@ id_type!(SignupId);
id_type!(UserId);
id_type!(ChannelBufferCollaboratorId);
id_type!(FlagId);
id_type!(ExtensionId);
id_type!(NotificationId);
id_type!(NotificationKindId);

View file

@ -5,6 +5,7 @@ pub mod buffers;
pub mod channels;
pub mod contacts;
pub mod contributors;
pub mod extensions;
pub mod messages;
pub mod notifications;
pub mod projects;

View file

@ -0,0 +1,205 @@
use super::*;
impl Database {
pub async fn get_extensions(
&self,
filter: Option<&str>,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
let mut condition = Condition::all();
if let Some(filter) = filter {
let fuzzy_name_filter = Self::fuzzy_like_string(filter);
condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
}
let extensions = extension::Entity::find()
.filter(condition)
.order_by_desc(extension::Column::TotalDownloadCount)
.order_by_asc(extension::Column::Id)
.limit(Some(limit as u64))
.filter(
extension::Column::LatestVersion
.into_expr()
.eq(extension_version::Column::Version.into_expr()),
)
.inner_join(extension_version::Entity)
.select_also(extension_version::Entity)
.all(&*tx)
.await?;
Ok(extensions
.into_iter()
.filter_map(|(extension, latest_version)| {
let version = latest_version?;
Some(ExtensionMetadata {
id: extension.external_id,
name: extension.name,
version: version.version,
authors: version
.authors
.split(',')
.map(|author| author.trim().to_string())
.collect::<Vec<_>>(),
repository: version.repository,
published_at: version.published_at,
download_count: extension.total_download_count as u64,
})
})
.collect())
})
.await
}
pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
let mut rows = extension::Entity::find().stream(&*tx).await?;
while let Some(row) = rows.next().await {
let row = row?;
extension_external_ids_by_id.insert(row.id, row.external_id);
}
drop(rows);
let mut known_versions_by_extension_id: HashMap<String, Vec<String>> =
HashMap::default();
let mut rows = extension_version::Entity::find().stream(&*tx).await?;
while let Some(row) = rows.next().await {
let row = row?;
let Some(extension_id) = extension_external_ids_by_id.get(&row.extension_id) else {
continue;
};
let versions = known_versions_by_extension_id
.entry(extension_id.clone())
.or_default();
if let Err(ix) = versions.binary_search(&row.version) {
versions.insert(ix, row.version);
}
}
drop(rows);
Ok(known_versions_by_extension_id)
})
.await
}
pub async fn insert_extension_versions(
&self,
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
) -> Result<()> {
self.transaction(|tx| async move {
for (external_id, versions) in versions_by_extension_id {
if versions.is_empty() {
continue;
}
let latest_version = versions
.iter()
.max_by_key(|version| &version.version)
.unwrap();
let insert = extension::Entity::insert(extension::ActiveModel {
name: ActiveValue::Set(latest_version.name.clone()),
external_id: ActiveValue::Set(external_id.to_string()),
id: ActiveValue::NotSet,
latest_version: ActiveValue::Set(latest_version.version.to_string()),
total_download_count: ActiveValue::NotSet,
})
.on_conflict(
OnConflict::columns([extension::Column::ExternalId])
.update_column(extension::Column::ExternalId)
.to_owned(),
);
let extension = if tx.support_returning() {
insert.exec_with_returning(&*tx).await?
} else {
// Sqlite
insert.exec_without_returning(&*tx).await?;
extension::Entity::find()
.filter(extension::Column::ExternalId.eq(*external_id))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to insert extension"))?
};
extension_version::Entity::insert_many(versions.iter().map(|version| {
extension_version::ActiveModel {
extension_id: ActiveValue::Set(extension.id),
published_at: ActiveValue::Set(version.published_at),
version: ActiveValue::Set(version.version.to_string()),
authors: ActiveValue::Set(version.authors.join(", ")),
repository: ActiveValue::Set(version.repository.clone()),
description: ActiveValue::Set(version.description.clone()),
download_count: ActiveValue::NotSet,
}
}))
.on_conflict(OnConflict::new().do_nothing().to_owned())
.exec_without_returning(&*tx)
.await?;
if let Ok(db_version) = semver::Version::parse(&extension.latest_version) {
if db_version >= latest_version.version {
continue;
}
}
let mut extension = extension.into_active_model();
extension.latest_version = ActiveValue::Set(latest_version.version.to_string());
extension.name = ActiveValue::set(latest_version.name.clone());
extension::Entity::update(extension).exec(&*tx).await?;
}
Ok(())
})
.await
}
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
self.transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryId {
Id,
}
let extension_id: Option<ExtensionId> = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension))
.select_only()
.column(extension::Column::Id)
.into_values::<_, QueryId>()
.one(&*tx)
.await?;
let Some(extension_id) = extension_id else {
return Ok(false);
};
extension_version::Entity::update_many()
.col_expr(
extension_version::Column::DownloadCount,
extension_version::Column::DownloadCount.into_expr().add(1),
)
.filter(
extension_version::Column::ExtensionId
.eq(extension_id)
.and(extension_version::Column::Version.eq(version)),
)
.exec(&*tx)
.await?;
extension::Entity::update_many()
.col_expr(
extension::Column::TotalDownloadCount,
extension::Column::TotalDownloadCount.into_expr().add(1),
)
.filter(extension::Column::Id.eq(extension_id))
.exec(&*tx)
.await?;
Ok(true)
})
.await
}
}

View file

@ -10,6 +10,8 @@ pub mod channel_message;
pub mod channel_message_mention;
pub mod contact;
pub mod contributor;
pub mod extension;
pub mod extension_version;
pub mod feature_flag;
pub mod follower;
pub mod language_server;

View file

@ -0,0 +1,27 @@
use crate::db::ExtensionId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "extensions")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ExtensionId,
pub external_id: String,
pub name: String,
pub latest_version: String,
pub total_download_count: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_one = "super::extension_version::Entity")]
LatestVersion,
}
impl Related<super::extension_version::Entity> for Entity {
fn to() -> RelationDef {
Relation::LatestVersion.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -0,0 +1,36 @@
use crate::db::ExtensionId;
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "extension_versions")]
pub struct Model {
#[sea_orm(primary_key)]
pub extension_id: ExtensionId,
#[sea_orm(primary_key)]
pub version: String,
pub published_at: PrimitiveDateTime,
pub authors: String,
pub repository: String,
pub description: String,
pub download_count: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::extension::Entity",
from = "Column::ExtensionId",
to = "super::extension::Column::Id"
on_condition = r#"super::extension::Column::LatestVersion.into_expr().eq(Column::Version.into_expr())"#
)]
Extension,
}
impl Related<super::extension::Entity> for Entity {
fn to() -> RelationDef {
Relation::Extension.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -2,6 +2,7 @@ mod buffer_tests;
mod channel_tests;
mod contributor_tests;
mod db_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;

View file

@ -0,0 +1,219 @@
use super::Database;
use crate::{
db::{ExtensionMetadata, NewExtensionVersion},
test_both_dbs,
};
use std::sync::Arc;
use time::{OffsetDateTime, PrimitiveDateTime};
test_both_dbs!(
test_extensions,
test_extensions_postgres,
test_extensions_sqlite
);
async fn test_extensions(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
let extensions = db.get_extensions(None, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
let t0 = PrimitiveDateTime::new(t0.date(), t0.time());
db.insert_extension_versions(
&[
(
"ext1",
vec![
NewExtensionVersion {
name: "Extension 1".into(),
version: semver::Version::parse("0.0.1").unwrap(),
description: "an extension".into(),
authors: vec!["max".into()],
repository: "ext1/repo".into(),
published_at: t0,
},
NewExtensionVersion {
name: "Extension One".into(),
version: semver::Version::parse("0.0.2").unwrap(),
description: "a good extension".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
},
],
),
(
"ext2",
vec![NewExtensionVersion {
name: "Extension Two".into(),
version: semver::Version::parse("0.2.0").unwrap(),
description: "a great extension".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
}],
),
]
.into_iter()
.collect(),
)
.await
.unwrap();
let versions = db.get_known_extension_versions().await.unwrap();
assert_eq!(
versions,
[
("ext1".into(), vec!["0.0.1".into(), "0.0.2".into()]),
("ext2".into(), vec!["0.2.0".into()])
]
.into_iter()
.collect()
);
// The latest version of each extension is returned.
let extensions = db.get_extensions(None, 5).await.unwrap();
assert_eq!(
extensions,
&[
ExtensionMetadata {
id: "ext1".into(),
name: "Extension One".into(),
version: "0.0.2".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
download_count: 0,
},
ExtensionMetadata {
id: "ext2".into(),
name: "Extension Two".into(),
version: "0.2.0".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
download_count: 0
},
]
);
// Record extensions being downloaded.
for _ in 0..7 {
assert!(db.record_extension_download("ext2", "0.0.2").await.unwrap());
}
for _ in 0..3 {
assert!(db.record_extension_download("ext1", "0.0.1").await.unwrap());
}
for _ in 0..2 {
assert!(db.record_extension_download("ext1", "0.0.2").await.unwrap());
}
// Record download returns false if the extension does not exist.
assert!(!db
.record_extension_download("no-such-extension", "0.0.2")
.await
.unwrap());
// Extensions are returned in descending order of total downloads.
let extensions = db.get_extensions(None, 5).await.unwrap();
assert_eq!(
extensions,
&[
ExtensionMetadata {
id: "ext2".into(),
name: "Extension Two".into(),
version: "0.2.0".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
download_count: 7
},
ExtensionMetadata {
id: "ext1".into(),
name: "Extension One".into(),
version: "0.0.2".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
download_count: 5,
},
]
);
// Add more extensions, including a new version of `ext1`, and backfilling
// an older version of `ext2`.
db.insert_extension_versions(
&[
(
"ext1",
vec![NewExtensionVersion {
name: "Extension One".into(),
version: semver::Version::parse("0.0.3").unwrap(),
description: "a real good extension".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
}],
),
(
"ext2",
vec![NewExtensionVersion {
name: "Extension Two".into(),
version: semver::Version::parse("0.1.0").unwrap(),
description: "an old extension".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
}],
),
]
.into_iter()
.collect(),
)
.await
.unwrap();
let versions = db.get_known_extension_versions().await.unwrap();
assert_eq!(
versions,
[
(
"ext1".into(),
vec!["0.0.1".into(), "0.0.2".into(), "0.0.3".into()]
),
("ext2".into(), vec!["0.1.0".into(), "0.2.0".into()])
]
.into_iter()
.collect()
);
let extensions = db.get_extensions(None, 5).await.unwrap();
assert_eq!(
extensions,
&[
ExtensionMetadata {
id: "ext2".into(),
name: "Extension Two".into(),
version: "0.2.0".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
published_at: t0,
download_count: 7
},
ExtensionMetadata {
id: "ext1".into(),
name: "Extension One".into(),
version: "0.0.3".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
published_at: t0,
download_count: 5,
},
]
);
}