Respect version constraints when installing extensions (#10052)

This PR modifies the extension installation and update process to
respect version constraints (schema version and Wasm API version) to
ensure only compatible versions of extensions are able to be installed.

To achieve this there is a new `GET /extensions/updates` endpoint that
will return extension versions based on the provided constraints.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Marshall Bowers 2024-04-01 17:10:30 -04:00 committed by GitHub
parent 39cc3c0778
commit 83ce783856
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 304 additions and 78 deletions

View file

@ -1,3 +1,4 @@
use crate::db::ExtensionVersionConstraints;
use crate::{db::NewExtensionVersion, AppState, Error, Result};
use anyhow::{anyhow, Context as _};
use aws_sdk_s3::presigning::PresigningConfig;
@ -10,14 +11,16 @@ use axum::{
};
use collections::HashMap;
use rpc::{ExtensionApiManifest, GetExtensionsResponse};
use semantic_version::SemanticVersion;
use serde::Deserialize;
use std::{sync::Arc, time::Duration};
use time::PrimitiveDateTime;
use util::ResultExt;
use util::{maybe, ResultExt};
pub fn router() -> Router {
Router::new()
.route("/extensions", get(get_extensions))
.route("/extensions/updates", get(get_extension_updates))
.route("/extensions/:extension_id", get(get_extension_versions))
.route(
"/extensions/:extension_id/download",
@ -48,9 +51,7 @@ async fn get_extensions(
.map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
let extensions = if let Some(extension_ids) = extension_ids {
app.db
.get_extensions_by_ids(&extension_ids, params.max_schema_version)
.await?
app.db.get_extensions_by_ids(&extension_ids, None).await?
} else {
app.db
.get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
@ -60,6 +61,34 @@ async fn get_extensions(
Ok(Json(GetExtensionsResponse { data: extensions }))
}
#[derive(Debug, Deserialize)]
struct GetExtensionUpdatesParams {
ids: String,
min_schema_version: i32,
max_schema_version: i32,
min_wasm_api_version: SemanticVersion,
max_wasm_api_version: SemanticVersion,
}
async fn get_extension_updates(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionUpdatesParams>,
) -> Result<Json<GetExtensionsResponse>> {
let constraints = ExtensionVersionConstraints {
schema_versions: params.min_schema_version..=params.max_schema_version,
wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
};
let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
let extensions = app
.db
.get_extensions_by_ids(&extension_ids, Some(&constraints))
.await?;
Ok(Json(GetExtensionsResponse { data: extensions }))
}
#[derive(Debug, Deserialize)]
struct GetExtensionVersionsParams {
extension_id: String,
@ -79,15 +108,31 @@ async fn get_extension_versions(
#[derive(Debug, Deserialize)]
struct DownloadLatestExtensionParams {
extension_id: String,
min_schema_version: Option<i32>,
max_schema_version: Option<i32>,
min_wasm_api_version: Option<SemanticVersion>,
max_wasm_api_version: Option<SemanticVersion>,
}
async fn download_latest_extension(
Extension(app): Extension<Arc<AppState>>,
Path(params): Path<DownloadLatestExtensionParams>,
) -> Result<Redirect> {
let constraints = maybe!({
let min_schema_version = params.min_schema_version?;
let max_schema_version = params.max_schema_version?;
let min_wasm_api_version = params.min_wasm_api_version?;
let max_wasm_api_version = params.max_wasm_api_version?;
Some(ExtensionVersionConstraints {
schema_versions: min_schema_version..=max_schema_version,
wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
})
});
let extension = app
.db
.get_extension(&params.extension_id)
.get_extension(&params.extension_id, constraints.as_ref())
.await?
.ok_or_else(|| anyhow!("unknown extension"))?;
download_extension(

View file

@ -21,11 +21,13 @@ use sea_orm::{
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
};
use serde::{ser::Error as _, Deserialize, Serialize, Serializer};
use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
use sqlx::{
migrate::{Migrate, Migration, MigrationSource},
Connection,
};
use std::ops::RangeInclusive;
use std::{
fmt::Write as _,
future::Future,
@ -36,7 +38,7 @@ use std::{
sync::Arc,
time::Duration,
};
use time::{format_description::well_known::iso8601, PrimitiveDateTime};
use time::PrimitiveDateTime;
use tokio::sync::{Mutex, OwnedMutexGuard};
#[cfg(test)]
@ -730,20 +732,7 @@ pub struct NewExtensionVersion {
pub published_at: PrimitiveDateTime,
}
pub fn serialize_iso8601<S: Serializer>(
datetime: &PrimitiveDateTime,
serializer: S,
) -> Result<S::Ok, S::Error> {
const SERDE_CONFIG: iso8601::EncodedConfig = iso8601::Config::DEFAULT
.set_year_is_six_digits(false)
.set_time_precision(iso8601::TimePrecision::Second {
decimal_digits: None,
})
.encode();
datetime
.assume_utc()
.format(&time::format_description::well_known::Iso8601::<SERDE_CONFIG>)
.map_err(S::Error::custom)?
.serialize(serializer)
pub struct ExtensionVersionConstraints {
pub schema_versions: RangeInclusive<i32>,
pub wasm_api_versions: RangeInclusive<SemanticVersion>,
}

View file

@ -1,5 +1,8 @@
use std::str::FromStr;
use chrono::Utc;
use sea_orm::sea_query::IntoCondition;
use util::ResultExt;
use super::*;
@ -32,23 +35,83 @@ impl Database {
pub async fn get_extensions_by_ids(
&self,
ids: &[&str],
max_schema_version: i32,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
let condition = Condition::all()
.add(
extension::Column::LatestVersion
.into_expr()
.eq(extension_version::Column::Version.into_expr()),
)
.add(extension::Column::ExternalId.is_in(ids.iter().copied()))
.add(extension_version::Column::SchemaVersion.lte(max_schema_version));
let extensions = extension::Entity::find()
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
.all(&*tx)
.await?;
self.get_extensions_where(condition, None, &tx).await
let mut max_versions = self
.get_latest_versions_for_extensions(&extensions, constraints, &tx)
.await?;
Ok(extensions
.into_iter()
.filter_map(|extension| {
let (version, _) = max_versions.remove(&extension.id)?;
Some(metadata_from_extension_and_version(extension, version))
})
.collect())
})
.await
}
async fn get_latest_versions_for_extensions(
&self,
extensions: &[extension::Model],
constraints: Option<&ExtensionVersionConstraints>,
tx: &DatabaseTransaction,
) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
let mut versions = extension_version::Entity::find()
.filter(
extension_version::Column::ExtensionId
.is_in(extensions.iter().map(|extension| extension.id)),
)
.stream(tx)
.await?;
let mut max_versions =
HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
while let Some(version) = versions.next().await {
let version = version?;
let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
else {
continue;
};
if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
if max_extension_version > &extension_version {
continue;
}
}
if let Some(constraints) = constraints {
if !constraints
.schema_versions
.contains(&version.schema_version)
{
continue;
}
if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
if !constraints.wasm_api_versions.contains(&version) {
continue;
}
} else {
continue;
}
}
}
max_versions.insert(version.extension_id, (version, extension_version));
}
Ok(max_versions)
}
/// Returns all of the versions for the extension with the given ID.
pub async fn get_extension_versions(
&self,
@ -88,22 +151,26 @@ impl Database {
.collect())
}
pub async fn get_extension(&self, extension_id: &str) -> Result<Option<ExtensionMetadata>> {
pub async fn get_extension(
&self,
extension_id: &str,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.filter(
extension::Column::LatestVersion
.into_expr()
.eq(extension_version::Column::Version.into_expr()),
)
.inner_join(extension_version::Entity)
.select_also(extension_version::Entity)
.one(&*tx)
.await?;
.await?
.ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
Ok(extension.and_then(|(extension, version)| {
Some(metadata_from_extension_and_version(extension, version?))
let extensions = [extension];
let mut versions = self
.get_latest_versions_for_extensions(&extensions, constraints, &tx)
.await?;
let [extension] = extensions;
Ok(versions.remove(&extension.id).map(|(max_version, _)| {
metadata_from_extension_and_version(extension, max_version)
}))
})
.await

View file

@ -1,4 +1,5 @@
use super::Database;
use crate::db::ExtensionVersionConstraints;
use crate::{
db::{queries::extensions::convert_time_to_chrono, ExtensionMetadata, NewExtensionVersion},
test_both_dbs,
@ -278,3 +279,108 @@ async fn test_extensions(db: &Arc<Database>) {
]
);
}
test_both_dbs!(
test_extensions_by_id,
test_extensions_by_id_postgres,
test_extensions_by_id_sqlite
);
async fn test_extensions_by_id(db: &Arc<Database>) {
let versions = db.get_known_extension_versions().await.unwrap();
assert!(versions.is_empty());
let extensions = db.get_extensions(None, 1, 5).await.unwrap();
assert!(extensions.is_empty());
let t0 = time::OffsetDateTime::from_unix_timestamp_nanos(0).unwrap();
let t0 = time::PrimitiveDateTime::new(t0.date(), t0.time());
let t0_chrono = convert_time_to_chrono(t0);
db.insert_extension_versions(
&[
(
"ext1",
vec![
NewExtensionVersion {
name: "Extension 1".into(),
version: semver::Version::parse("0.0.1").unwrap(),
description: "an extension".into(),
authors: vec!["max".into()],
repository: "ext1/repo".into(),
schema_version: 1,
wasm_api_version: Some("0.0.4".into()),
published_at: t0,
},
NewExtensionVersion {
name: "Extension 1".into(),
version: semver::Version::parse("0.0.2").unwrap(),
description: "a good extension".into(),
authors: vec!["max".into()],
repository: "ext1/repo".into(),
schema_version: 1,
wasm_api_version: Some("0.0.4".into()),
published_at: t0,
},
NewExtensionVersion {
name: "Extension 1".into(),
version: semver::Version::parse("0.0.3").unwrap(),
description: "a real good extension".into(),
authors: vec!["max".into(), "marshall".into()],
repository: "ext1/repo".into(),
schema_version: 1,
wasm_api_version: Some("0.0.5".into()),
published_at: t0,
},
],
),
(
"ext2",
vec![NewExtensionVersion {
name: "Extension 2".into(),
version: semver::Version::parse("0.2.0").unwrap(),
description: "a great extension".into(),
authors: vec!["marshall".into()],
repository: "ext2/repo".into(),
schema_version: 0,
wasm_api_version: None,
published_at: t0,
}],
),
]
.into_iter()
.collect(),
)
.await
.unwrap();
let extensions = db
.get_extensions_by_ids(
&["ext1"],
Some(&ExtensionVersionConstraints {
schema_versions: 1..=1,
wasm_api_versions: "0.0.1".parse().unwrap()..="0.0.4".parse().unwrap(),
}),
)
.await
.unwrap();
assert_eq!(
extensions,
&[ExtensionMetadata {
id: "ext1".into(),
manifest: rpc::ExtensionApiManifest {
name: "Extension 1".into(),
version: "0.0.2".into(),
authors: vec!["max".into()],
description: Some("a good extension".into()),
repository: "ext1/repo".into(),
schema_version: Some(1),
wasm_api_version: Some("0.0.4".into()),
},
published_at: t0_chrono,
download_count: 0,
}]
);
}