From 1a44dfe56fdbd9d6ae807bf87768a76a3895ca24 Mon Sep 17 00:00:00 2001 From: Natty Date: Wed, 2 Aug 2023 03:10:53 +0200 Subject: [PATCH] Initial API authentication implementation --- .dev/Caddyfile | 4 + Cargo.lock | 183 +++++++++++++++++++++++ Cargo.toml | 19 ++- config/default.toml | 6 +- ext_calckey_model/Cargo.toml | 5 + ext_calckey_model/src/lib.rs | 271 +++++++++++++++++++++++++++++++++- magnetar_common/src/config.rs | 7 + src/api_v1/mod.rs | 20 +++ src/api_v1/user.rs | 18 +++ src/main.rs | 26 +++- src/service/mod.rs | 36 +++++ src/service/user_cache.rs | 239 ++++++++++++++++++++++++++++++ src/util.rs | 118 +++++++++++++++ src/web/auth.rs | 262 ++++++++++++++++++++++++++++++++ src/web/mod.rs | 83 +++++++++++ 15 files changed, 1285 insertions(+), 12 deletions(-) create mode 100644 src/api_v1/mod.rs create mode 100644 src/api_v1/user.rs create mode 100644 src/service/mod.rs create mode 100644 src/service/user_cache.rs create mode 100644 src/util.rs create mode 100644 src/web/auth.rs create mode 100644 src/web/mod.rs diff --git a/.dev/Caddyfile b/.dev/Caddyfile index 546372f..c302e22 100644 --- a/.dev/Caddyfile +++ b/.dev/Caddyfile @@ -15,6 +15,10 @@ nattyarch.local { reverse_proxy 127.0.0.1:4939 } + handle /mag/* { + reverse_proxy 127.0.0.1:4939 + } + @render_html { not path /api* /proxy* /files* /avatar* /identicon* /streaming header Accept text/html* diff --git a/Cargo.lock b/Cargo.lock index 624db9e..72ad8ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -125,6 +125,7 @@ checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", "axum-core", + "axum-macros", "bitflags", "bytes", "futures-util", @@ -167,6 +168,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdca6a10ecad987bda04e95606ef85a5417dcaac1a78455242d72e031e2b6b62" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 2.0.16", +] + [[package]] name = "bae" version = "0.1.7" @@ -325,6 +338,42 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +[[package]] +name = "cached" +version = "0.44.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b195e4fbc4b6862bbd065b991a34750399c119797efff72492f28a5864de8700" +dependencies = [ + "async-trait", + "cached_proc_macro", + "cached_proc_macro_types", + "futures", + "hashbrown 0.13.2", + "instant", + "once_cell", + "thiserror", + "tokio", +] + +[[package]] +name = "cached_proc_macro" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b48814962d2fd604c50d2b9433c2a41a0ab567779ee2c02f7fba6eca1221f082" +dependencies = [ + "cached_proc_macro_types", + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "cached_proc_macro_types" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663" + [[package]] name = "cc" version = "1.0.79" @@ -419,6 +468,20 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "core-foundation-sys" version = "0.8.4" @@ -463,6 +526,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + [[package]] name = "deunicode" version = "0.4.3" @@ -594,6 +692,17 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +[[package]] +name = "futures-macro" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.16", +] + [[package]] name = "futures-sink" version = "0.3.28" @@ -615,6 +724,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -883,6 +993,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.3.0" @@ -989,7 +1105,11 @@ name = "magnetar" version = "0.2.0" dependencies = [ "axum", + "cached", + "cfg-if", + "chrono", "dotenvy", + "headers", "hyper", "magnetar_calckey_model", "magnetar_common", @@ -1000,6 +1120,7 @@ dependencies = [ "percent-encoding", "serde", "serde_json", + "strum", "thiserror", "tokio", "toml 0.7.4", @@ -1041,12 +1162,17 @@ dependencies = [ "ck", "dotenvy", "ext_calckey_model_migration", + "futures-core", + "futures-util", "magnetar_common", + "redis", "sea-orm", "serde", "serde_json", + "strum", "thiserror", "tokio", + "tokio-util", "tracing", ] @@ -1571,6 +1697,29 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redis" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff5d95dd18a4d76650f0c2607ed8ebdbf63baf9cb934e1c233cd220c694db1d7" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "ryu", + "serde", + "serde_json", + "sha1_smol", + "socket2", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -1992,6 +2141,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.6" @@ -2192,6 +2347,34 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6069ca09d878a33f883cc06aaa9718ede171841d3832450354410b718b097232" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.16", +] + [[package]] name = "subtle" version = "2.4.1" diff --git a/Cargo.toml b/Cargo.toml index de549d9..6b6540f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,19 +22,27 @@ edition = "2021" [workspace.dependencies] axum = "0.6" +cached = "0.44" +cfg-if = "1" chrono = "0.4" dotenvy = "0.15" +futures-core = "0.3" +futures-util = "0.3" +headers = "0.3" hyper = "0.14" log = "0.4" miette = "5.9" percent-encoding = "2.2" +redis = "0.23" sea-orm = "0.11" sea-orm-migration = "0.11" serde = "1" serde_json = "1" +strum = "0.25" tera = { version = "1", default-features = false } thiserror = "1" tokio = "1.24" +tokio-util = "0.7" toml = "0.7" tower = "0.4" tower-http = "0.4" @@ -50,9 +58,12 @@ magnetar_webfinger = { path = "./ext_webfinger" } magnetar_nodeinfo = { path = "./ext_nodeinfo" } magnetar_calckey_model = { path = "./ext_calckey_model" } +cached = { workspace = true } +chrono = { workspace = true } dotenvy = { workspace = true } -axum = { workspace = true } +axum = { workspace = true, features = ["macros"] } +headers = { workspace = true } hyper = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] } tower = { workspace = true } @@ -61,12 +72,14 @@ tower-http = { workspace = true, features = ["cors", "trace", "fs"] } tracing-subscriber = { workspace = true, features = ["env-filter"] } tracing = { workspace = true } +cfg-if = { workspace = true } + +strum = { workspace = true, features = ["derive"] } thiserror = { workspace = true } miette = { workspace = true } percent-encoding = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } toml = { workspace = true } - -serde_json = { workspace = true } \ No newline at end of file diff --git a/config/default.toml b/config/default.toml index 662b777..07cd085 100644 --- a/config/default.toml +++ b/config/default.toml @@ -58,10 +58,14 @@ # ----------------------------------[ DATA ]----------------------------------- # [REQUIRED] -# An URL pointing to a Postgres database, with a Calckey database +# An URI pointing to a Postgres database, with a Calckey database # Environment variables: MAG_C_DATABASE_URL, DATABASE_URL # data.database_url = "postgres://username:password@db:5432/calckey" +# [REQUIRED] +# An URI pointing to a Redis instance +# Environment variables: MAG_C_REDIS_URL +# data.redis_url = "redis://redis:6379" # -------------------------------[ FEDERATION ]-------------------------------- diff --git a/ext_calckey_model/Cargo.toml b/ext_calckey_model/Cargo.toml index e8646cf..12e6a72 100644 --- a/ext_calckey_model/Cargo.toml +++ b/ext_calckey_model/Cargo.toml @@ -13,10 +13,15 @@ ext_calckey_model_migration = { path = "./migration" } magnetar_common = { path = "../magnetar_common" } dotenvy = { workspace = true} +futures-core = { workspace = true } +futures-util = { workspace = true } tokio = { workspace = true, features = ["full"] } +tokio-util = { workspace = true} +redis = { workspace = true, features = ["tokio-comp", "json", "serde_json"]} sea-orm = { workspace = true, features = ["sqlx-postgres", "runtime-tokio-rustls", "macros"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +strum = { workspace = true } chrono = { workspace = true } tracing = { workspace = true } thiserror = { workspace = true } \ No newline at end of file diff --git a/ext_calckey_model/src/lib.rs b/ext_calckey_model/src/lib.rs index 6f4d0c7..ee74e8e 100644 --- a/ext_calckey_model/src/lib.rs +++ b/ext_calckey_model/src/lib.rs @@ -1,8 +1,21 @@ -use ck::user; +use chrono::Utc; +pub use ck::*; use ext_calckey_model_migration::{Migrator, MigratorTrait}; -use sea_orm::{ColumnTrait, ConnectOptions, DatabaseConnection, EntityTrait, QueryFilter}; +use futures_util::StreamExt; +use redis::IntoConnectionInfo; +use sea_orm::ActiveValue::Set; +use sea_orm::{ + ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, QueryOrder, + TransactionTrait, +}; +use serde::{Deserialize, Serialize}; +use std::future::Future; +use strum::IntoStaticStr; use thiserror::Error; +use tokio::select; +use tokio_util::sync::CancellationToken; use tracing::log::LevelFilter; +use tracing::{error, info}; #[derive(Debug)] pub struct ConnectorConfig { @@ -12,10 +25,10 @@ pub struct ConnectorConfig { #[derive(Clone, Debug)] pub struct CalckeyModel(DatabaseConnection); -#[derive(Debug, Error)] +#[derive(Debug, Error, IntoStaticStr)] pub enum CalckeyDbError { #[error("Database error: {0}")] - DbError(#[from] sea_orm::DbErr), + DbError(#[from] DbErr), } impl CalckeyModel { @@ -36,6 +49,10 @@ impl CalckeyModel { Ok(()) } + pub fn inner(&self) -> &DatabaseConnection { + &self.0 + } + pub async fn get_user_by_tag( &self, name: &str, @@ -61,10 +78,256 @@ impl CalckeyModel { Ok(user) } + pub async fn get_note_by_id(&self, id: &str) -> Result, CalckeyDbError> { + Ok(note::Entity::find() + .filter(note::Column::Id.eq(id)) + .one(&self.0) + .await?) + } + + pub async fn get_user_by_id(&self, id: &str) -> Result, CalckeyDbError> { + Ok(user::Entity::find() + .filter(user::Column::Id.eq(id)) + .one(&self.0) + .await?) + } + + pub async fn get_user_by_token( + &self, + token: &str, + ) -> Result, CalckeyDbError> { + Ok(user::Entity::find() + .filter( + user::Column::Token + .eq(token) + .and(user::Column::Host.is_null()), + ) + .one(&self.0) + .await?) + } + pub async fn get_user_by_uri(&self, uri: &str) -> Result, CalckeyDbError> { Ok(user::Entity::find() .filter(user::Column::Uri.eq(uri)) .one(&self.0) .await?) } + + pub async fn get_local_emoji(&self) -> Result, CalckeyDbError> { + Ok(emoji::Entity::find() + .filter(emoji::Column::Host.is_null()) + .order_by_asc(emoji::Column::Category) + .order_by_asc(emoji::Column::Name) + .all(&self.0) + .await?) + } + + pub async fn get_access_token( + &self, + token: &str, + ) -> Result, CalckeyDbError> { + let token = access_token::Entity::update(access_token::ActiveModel { + last_used_at: Set(Some(Utc::now().into())), + ..Default::default() + }) + .filter( + access_token::Column::Token + .eq(token) + .or(access_token::Column::Hash.eq(token.to_lowercase())), + ) + .exec(&self.0) + .await; + + if let Err(DbErr::RecordNotUpdated) = token { + return Ok(None); + } + + Ok(Some(token?)) + } + + pub async fn get_app_by_id(&self, id: &str) -> Result, CalckeyDbError> { + Ok(app::Entity::find() + .filter(app::Column::Id.eq(id)) + .one(&self.0) + .await?) + } + + pub async fn get_instance_meta(&self) -> Result { + let txn = self.0.begin().await?; + + let meta = meta::Entity::find().one(&txn).await?; + + if let Some(meta) = meta { + txn.commit().await?; + return Ok(meta); + } + + let model = meta::ActiveModel { + id: Set("x".to_string()), + ..Default::default() + }; + + let meta = meta::Entity::insert(model) + .exec_with_returning(&txn) + .await?; + + txn.commit().await?; + + Ok(meta) + } +} + +#[derive(Debug)] +pub struct CacheConnectorConfig { + pub url: String, +} + +#[derive(Clone, Debug)] +pub struct CalckeyCache(redis::Client); + +#[derive(Debug, Error, IntoStaticStr)] +pub enum CalckeyCacheError { + #[error("Redis error: {0}")] + RedisError(#[from] redis::RedisError), +} + +impl CalckeyCache { + pub fn new(config: CacheConnectorConfig) -> Result { + let conn_info = config.url.into_connection_info()?; + + // TODO: Allow overriding redis config with individual options (maybe) + let redis_config = redis::ConnectionInfo { + addr: conn_info.addr, + redis: conn_info.redis, + }; + + Ok(CalckeyCache(redis::Client::open(redis_config)?)) + } + + pub fn inner(&self) -> &redis::Client { + &self.0 + } + + pub async fn conn(&self) -> Result { + Ok(CalckeyCacheClient(self.0.get_async_connection().await?)) + } +} + +pub struct CalckeyCacheClient(redis::aio::Connection); + +#[derive(Clone, Debug, Deserialize)] +#[serde(tag = "channel", content = "message")] +pub enum SubMessage { + Internal(InternalStreamMessage), + #[serde(other)] + Other, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", content = "body")] +#[serde(rename_all = "camelCase")] +pub enum InternalStreamMessage { + #[serde(rename_all = "camelCase")] + UserChangeSuspendedState { + id: String, + is_suspended: bool, + }, + #[serde(rename_all = "camelCase")] + UserChangeSilencedState { + id: String, + is_silenced: bool, + }, + #[serde(rename_all = "camelCase")] + UserChangeModeratorState { + id: String, + is_moderator: bool, + }, + #[serde(rename_all = "camelCase")] + UserTokenRegenerated { + id: String, + old_token: String, + new_token: String, + }, + LocalUserUpdated { + id: String, + }, + RemoteUserUpdated { + id: String, + }, + WebhookCreated(webhook::Model), + WebhookDeleted(webhook::Model), + WebhookUpdated(webhook::Model), + AntennaCreated(antenna::Model), + AntennaDeleted(antenna::Model), + AntennaUpdated(antenna::Model), +} + +impl CalckeyCacheClient { + pub async fn subscribe + Send + 'static>( + self, + prefix: &str, + handler: impl Fn(SubMessage) -> F + Send + Sync + 'static, + ) -> Result { + CalckeySub::new(self.0, prefix, handler).await + } +} + +pub struct CalckeySub(CancellationToken); + +impl CalckeySub { + async fn new + Send + 'static>( + conn: redis::aio::Connection, + prefix: &str, + handler: impl Fn(SubMessage) -> F + Send + Sync + 'static, + ) -> Result { + let mut pub_sub = conn.into_pubsub(); + pub_sub.subscribe(prefix).await?; + + let token = CancellationToken::new(); + let token_rx = token.clone(); + let prefix = prefix.to_string(); + + tokio::spawn(async move { + let mut on_message = pub_sub.on_message(); + + while let Some(msg) = select! { + msg = on_message.next() => msg, + _ = token_rx.cancelled() => { + drop(on_message); + if let Err(e) = pub_sub.unsubscribe(prefix).await { + info!("Redis error: {:?}", e); + } + return; + } + } { + let data = &match msg.get_payload::() { + Ok(val) => val, + Err(e) => { + info!("Redis error: {:?}", e); + continue; + } + }; + + let parsed = match serde_json::from_str::(data) { + Ok(val) => val, + Err(e) => { + info!("Message parse error: {:?}", e); + continue; + } + }; + + println!("Got message: {:#?}", parsed); + + handler(parsed).await; + } + }); + + Ok(CalckeySub(token)) + } +} + +impl Drop for CalckeySub { + fn drop(&mut self) { + self.0.cancel(); + } } diff --git a/magnetar_common/src/config.rs b/magnetar_common/src/config.rs index 520039d..a1d6946 100644 --- a/magnetar_common/src/config.rs +++ b/magnetar_common/src/config.rs @@ -150,6 +150,7 @@ impl Default for MagnetarBranding { #[non_exhaustive] pub struct MagnetarData { pub database_url: String, + pub redis_url: String, } fn env_database_url() -> String { @@ -158,10 +159,16 @@ fn env_database_url() -> String { .expect("MAG_C_DATABASE_URL, DATABASE_URL or \"data.database_url\" in the default configuration must be set") } +fn env_redis_url() -> String { + std::env::var("MAG_C_REDIS_URL") + .expect("MAG_C_REDIS_URL or \"data.redis_url\" in the default configuration must be set") +} + impl Default for MagnetarData { fn default() -> Self { MagnetarData { database_url: env_database_url(), + redis_url: env_redis_url(), } } } diff --git a/src/api_v1/mod.rs b/src/api_v1/mod.rs new file mode 100644 index 0000000..9eefbb7 --- /dev/null +++ b/src/api_v1/mod.rs @@ -0,0 +1,20 @@ +mod user; + +use crate::api_v1::user::handle_user_info; +use crate::service::MagnetarService; +use crate::web::auth; +use crate::web::auth::AuthState; +use axum::middleware::from_fn_with_state; +use axum::routing::get; +use axum::Router; +use std::sync::Arc; + +pub fn create_api_router(service: Arc) -> Router { + Router::new() + .route("/user/@self", get(handle_user_info)) + .layer(from_fn_with_state( + AuthState::new(service.clone()), + auth::auth, + )) + .with_state(service) +} diff --git a/src/api_v1/user.rs b/src/api_v1/user.rs new file mode 100644 index 0000000..d8d8f95 --- /dev/null +++ b/src/api_v1/user.rs @@ -0,0 +1,18 @@ +use crate::service::MagnetarService; +use crate::web::auth::AuthenticatedUser; +use crate::web::ApiError; +use axum::extract::State; +use axum::response::IntoResponse; +use axum::Json; +use std::sync::Arc; + +// TODO: Not a real endpoint +pub async fn handle_user_info( + State(service): State>, + AuthenticatedUser(user): AuthenticatedUser, +) -> Result { + let db = service.db.clone(); + let user = db.get_user_by_id(&user.id).await?; + + Ok(Json(user)) +} diff --git a/src/main.rs b/src/main.rs index 585c444..c7681fd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,20 @@ +mod api_v1; pub mod nodeinfo; +pub mod service; +pub mod util; +pub mod web; pub mod webfinger; +use crate::api_v1::create_api_router; use crate::nodeinfo::{handle_nodeinfo, handle_nodeinfo_20, handle_nodeinfo_21}; +use crate::service::MagnetarService; use axum::routing::get; use axum::Router; use dotenvy::dotenv; -use magnetar_calckey_model::{CalckeyModel, ConnectorConfig}; +use magnetar_calckey_model::{CacheConnectorConfig, CalckeyCache, CalckeyModel, ConnectorConfig}; use miette::{miette, IntoDiagnostic}; use std::net::SocketAddr; +use std::sync::Arc; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; use tracing::info; @@ -39,6 +46,17 @@ async fn main() -> miette::Result<()> { db.migrate().await.into_diagnostic()?; + let redis = CalckeyCache::new(CacheConnectorConfig { + url: config.data.redis_url.clone(), + }) + .into_diagnostic()?; + + let service = Arc::new( + MagnetarService::new(config, db.clone(), redis) + .await + .into_diagnostic()?, + ); + let well_known_router = Router::new() .route( "/webfinger", @@ -52,9 +70,9 @@ async fn main() -> miette::Result<()> { .route("/2.1", get(handle_nodeinfo_21)); let app = Router::new() - .nest("/.well-known", well_known_router) - .nest("/nodeinfo", nodeinfo_router) - .with_state(config) + .nest("/.well-known", well_known_router.with_state(config)) + .nest("/nodeinfo", nodeinfo_router.with_state(config)) + .nest("/mag/v1", create_api_router(service)) .layer( CorsLayer::new() .allow_headers(Any) diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..7a88653 --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,36 @@ +use magnetar_calckey_model::{CalckeyCache, CalckeyModel}; +use magnetar_common::config::MagnetarConfig; +use thiserror::Error; + +pub mod user_cache; + +pub struct MagnetarService { + pub db: CalckeyModel, + pub cache: CalckeyCache, + pub config: &'static MagnetarConfig, + pub auth_cache: user_cache::UserCacheService, +} + +#[derive(Debug, Error)] +pub enum ServiceInitError { + #[error("Authentication cache initialization error: {0}")] + AuthCacheError(#[from] user_cache::UserCacheError), +} + +impl MagnetarService { + pub async fn new( + config: &'static MagnetarConfig, + db: CalckeyModel, + cache: CalckeyCache, + ) -> Result { + let auth_cache = + user_cache::UserCacheService::new(config, db.clone(), cache.clone()).await?; + + Ok(Self { + db, + cache, + config, + auth_cache, + }) + } +} diff --git a/src/service/user_cache.rs b/src/service/user_cache.rs new file mode 100644 index 0000000..3f8816a --- /dev/null +++ b/src/service/user_cache.rs @@ -0,0 +1,239 @@ +use crate::web::ApiError; +use cached::{Cached, TimedCache}; +use magnetar_calckey_model::{ + user, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub, + InternalStreamMessage, SubMessage, +}; +use magnetar_common::config::MagnetarConfig; +use std::collections::HashMap; +use std::sync::Arc; +use strum::EnumVariantNames; +use thiserror::Error; +use tokio::sync::Mutex; +use tracing::error; + +#[derive(Debug, Error, EnumVariantNames)] +pub enum UserCacheError { + #[error("Database error: {0}")] + DbError(#[from] CalckeyDbError), + #[error("Redis error: {0}")] + RedisError(#[from] CalckeyCacheError), +} + +impl From for ApiError { + fn from(err: UserCacheError) -> Self { + let mut api_error: ApiError = match err { + UserCacheError::DbError(err) => err.into(), + UserCacheError::RedisError(err) => err.into(), + }; + + api_error.message = format!("User cache error: {}", api_error.message); + + api_error + } +} + +struct UserCache { + lifetime: TimedCache, + id_to_user: HashMap>, + token_to_user: HashMap>, + uri_to_user: HashMap>, +} + +impl UserCache { + fn purge(&mut self, user: impl AsRef) { + let user = user.as_ref(); + + self.lifetime.cache_remove(&user.id); + + if let Some(user) = self.id_to_user.remove(&user.id) { + if let Some(token) = user.token.clone() { + self.token_to_user.remove(&token); + } + if let Some(uri) = user.uri.clone() { + self.uri_to_user.remove(&uri); + } + } + } + + fn refresh(&mut self, user: Arc) { + self.purge(&user); + + self.lifetime.cache_set(user.id.clone(), ()); + + self.id_to_user.insert(user.id.clone(), user.clone()); + + if let Some(token) = user.token.clone() { + self.token_to_user.insert(token, user.clone()); + } + + if let Some(uri) = user.uri.clone() { + self.uri_to_user.insert(uri, user); + } + } + + /// Low-priority refresh. Only refreshes the cache if the user is not there. + /// Used mostly for getters that would otherwise data race with more important refreshes. + fn maybe_refresh(&mut self, user: &Arc) { + if self.lifetime.cache_get(&user.id).is_none() { + self.refresh(user.clone()); + } + } + + fn get_by_id(&mut self, id: &str) -> Option> { + if let Some(user) = self.id_to_user.get(id).cloned() { + if self.lifetime.cache_get(id).is_none() { + self.purge(&user); + return None; + } + + return Some(user); + } + + None + } + + fn get_by_token(&mut self, token: &str) -> Option> { + if let Some(user) = self.token_to_user.get(token).cloned() { + if self.lifetime.cache_get(&user.id).is_none() { + self.purge(&user); + return None; + } + + return Some(user); + } + + None + } + + fn get_by_uri(&mut self, uri: &str) -> Option> { + if let Some(user) = self.uri_to_user.get(uri).cloned() { + if self.lifetime.cache_get(&user.id).is_none() { + self.purge(&user); + return None; + } + + return Some(user); + } + + None + } +} + +pub struct UserCacheService { + db: CalckeyModel, + #[allow(dead_code)] + token_watch: CalckeySub, + cache: Arc>, +} + +impl UserCacheService { + pub(super) async fn new( + config: &MagnetarConfig, + db: CalckeyModel, + redis: CalckeyCache, + ) -> Result { + let cache = Arc::new(Mutex::new(UserCache { + lifetime: TimedCache::with_lifespan(60 * 5), + id_to_user: HashMap::new(), + token_to_user: HashMap::new(), + uri_to_user: HashMap::new(), + })); + + let cache_clone = cache.clone(); + let db_clone = db.clone(); + + let token_watch = redis + .conn() + .await? + .subscribe(&config.networking.host, move |message| { + let cache = cache_clone.clone(); + let db = db_clone.clone(); + + async move { + let SubMessage::Internal(internal) = message else { + return; + }; + + match internal { + InternalStreamMessage::LocalUserUpdated { id } + | InternalStreamMessage::UserChangeModeratorState { id, .. } + | InternalStreamMessage::UserChangeSilencedState { id, .. } + | InternalStreamMessage::UserChangeSuspendedState { id, .. } + | InternalStreamMessage::RemoteUserUpdated { id } + | InternalStreamMessage::UserTokenRegenerated { id, .. } => { + let user = match db.get_user_by_id(&id).await { + Ok(Some(user)) => user, + Ok(None) => return, + Err(e) => { + error!("Error fetching user from database: {}", e); + return; + } + }; + + cache.lock().await.refresh(Arc::new(user)); + } + _ => {} + }; + } + }) + .await?; + + Ok(Self { + cache, + db, + token_watch, + }) + } + + async fn map_cache_user( + &self, + user: Option, + ) -> Result>, UserCacheError> { + if let Some(user) = user { + let user = Arc::new(user); + self.cache.lock().await.maybe_refresh(&user); + return Ok(Some(user)); + } + + Ok(None) + } + + pub async fn get_by_token( + &self, + token: &str, + ) -> Result>, UserCacheError> { + let result = self.cache.lock().await.get_by_token(token); + + if let Some(user) = result { + return Ok(Some(user)); + } + + self.map_cache_user(self.db.get_user_by_token(token).await?) + .await + } + + pub async fn get_by_uri(&self, uri: &str) -> Result>, UserCacheError> { + let result = self.cache.lock().await.get_by_uri(uri); + + if let Some(user) = result { + return Ok(Some(user)); + } + + let user = self.db.get_user_by_uri(uri).await?; + + self.map_cache_user(user).await + } + + pub async fn get_by_id(&self, id: &str) -> Result>, UserCacheError> { + let result = self.cache.lock().await.get_by_id(id); + + if let Some(user) = result { + return Ok(Some(user)); + } + + let user = self.db.get_user_by_id(id).await?; + + self.map_cache_user(user).await + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..b965d51 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,118 @@ +use cached::{Cached, TimedCache}; +use std::future::Future; +use std::hash::Hash; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Debug, Clone)] +pub struct SingleTimedAsyncCache { + inner: Arc>>, +} + +impl SingleTimedAsyncCache { + pub fn with_lifespan(lifespan: u64) -> Self { + Self { + inner: Arc::new(Mutex::new(TimedCache::with_lifespan(lifespan))), + } + } + + pub async fn put(&self, val: V) -> Option { + let mut cache = self.inner.lock().await; + cache.cache_set((), val) + } + + pub async fn get_opt(&self) -> Option { + let mut cache = self.inner.lock().await; + cache.cache_get(&()).cloned() + } + + pub async fn get(&self, f: F) -> Result + where + FT: Future>, + F: FnOnce() -> FT, + { + let mut cache = self.inner.lock().await; + + if let Some(val) = cache.cache_get(&()) { + return Ok(val.clone()); + } + + let val = f().await?; + cache.cache_set((), val.clone()); + Ok(val) + } + + pub async fn get_sync(&self, f: F) -> Result + where + F: FnOnce() -> Result, + { + let mut cache = self.inner.lock().await; + + if let Some(val) = cache.cache_get(&()) { + return Ok(val.clone()); + } + + let val = f()?; + cache.cache_set((), val.clone()); + Ok(val) + } +} + +#[derive(Debug, Clone)] +pub struct TimedAsyncCache { + inner: Arc>>, +} + +impl TimedAsyncCache { + pub fn with_lifespan(lifespan: u64) -> Self { + Self { + inner: Arc::new(Mutex::new(TimedCache::with_lifespan(lifespan))), + } + } + + pub async fn put(&self, key: K, val: V) -> Option { + let mut cache = self.inner.lock().await; + cache.cache_set(key, val) + } + + pub async fn remove(&self, key: &K) -> Option { + let mut cache = self.inner.lock().await; + cache.cache_remove(key) + } + + pub async fn get_opt(&self, key: &K) -> Option { + let mut cache = self.inner.lock().await; + cache.cache_get(key).cloned() + } + + pub async fn get(&self, key: K, f: F) -> Result + where + FT: Future>, + F: FnOnce() -> FT, + { + let mut cache = self.inner.lock().await; + + if let Some(val) = cache.cache_get(&key) { + return Ok(val.clone()); + } + + let val = f().await?; + cache.cache_set(key, val.clone()); + Ok(val) + } + + pub async fn get_sync(&self, key: K, f: F) -> Result + where + F: FnOnce() -> Result, + { + let mut cache = self.inner.lock().await; + + if let Some(val) = cache.cache_get(&key) { + return Ok(val.clone()); + } + + let val = f()?; + cache.cache_set(key, val.clone()); + Ok(val) + } +} diff --git a/src/web/auth.rs b/src/web/auth.rs new file mode 100644 index 0000000..1073f19 --- /dev/null +++ b/src/web/auth.rs @@ -0,0 +1,262 @@ +use crate::service::user_cache::UserCacheError; +use crate::service::MagnetarService; +use crate::web::{ApiError, IntoErrorCode}; +use axum::async_trait; +use axum::extract::rejection::ExtensionRejection; +use axum::extract::{FromRequestParts, State}; +use axum::http::request::Parts; +use axum::http::{Request, StatusCode}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use headers::authorization::Bearer; +use headers::{Authorization, HeaderMapExt}; +use magnetar_calckey_model::{access_token, user, CalckeyDbError}; +use std::convert::Infallible; +use std::sync::Arc; +use strum::IntoStaticStr; +use thiserror::Error; +use tracing::error; + +#[derive(Debug)] +pub enum AuthMode { + User { + user: Arc, + }, + AccessToken { + user: Arc, + access_token: Arc, + }, + Anonymous, +} + +impl AuthMode { + fn get_user(&self) -> Option<&Arc> { + match self { + AuthMode::User { user } | AuthMode::AccessToken { user, .. } => Some(user), + AuthMode::Anonymous => None, + } + } +} + +pub struct AuthUserRejection(ApiError); + +impl From for AuthUserRejection { + fn from(rejection: ExtensionRejection) -> Self { + AuthUserRejection(ApiError { + status: StatusCode::UNAUTHORIZED, + code: "Unauthorized".error_code(), + message: if cfg!(debug_assertions) { + format!("Missing auth extension: {}", rejection) + } else { + "Unauthorized".to_string() + }, + }) + } +} + +impl IntoResponse for AuthUserRejection { + fn into_response(self) -> Response { + self.0.into_response() + } +} + +#[derive(Clone, FromRequestParts)] +#[from_request(via(axum::Extension), rejection(AuthUserRejection))] +pub struct AuthenticatedUser(pub Arc); + +#[derive(Clone)] +pub struct MaybeUser(pub Option>); + +#[async_trait] +impl FromRequestParts for MaybeUser { + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + Ok(MaybeUser( + parts + .extensions + .get::() + .map(|part| part.0.clone()), + )) + } +} + +#[derive(Clone)] +pub struct AuthState { + service: Arc, +} + +#[derive(Debug, Error, IntoStaticStr)] +enum AuthError { + #[error("Unsupported authorization scheme")] + UnsupportedScheme, + #[error("Cache error: {0}")] + CacheError(#[from] UserCacheError), + #[error("Database error: {0}")] + DbError(#[from] CalckeyDbError), + #[error("Invalid token")] + InvalidToken, + #[error("Invalid token \"{token}\" referencing user \"{user}\"")] + InvalidTokenUser { token: String, user: String }, + #[error("Invalid access token \"{access_token}\" referencing app \"{app}\"")] + InvalidAccessTokenApp { access_token: String, app: String }, +} + +impl From for ApiError { + fn from(err: AuthError) -> Self { + match err { + AuthError::UnsupportedScheme => ApiError { + status: StatusCode::UNAUTHORIZED, + code: err.error_code(), + message: "Unsupported authorization scheme".to_string(), + }, + AuthError::CacheError(err) => err.into(), + AuthError::DbError(err) => err.into(), + AuthError::InvalidTokenUser { + ref token, + ref user, + } => { + error!("Invalid token \"{}\" referencing user \"{}\"", token, user); + + ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + code: err.error_code(), + message: if cfg!(debug_assertions) { + format!("Invalid token \"{}\" referencing user \"{}\"", token, user) + } else { + "Invalid token-user link".to_string() + }, + } + } + AuthError::InvalidAccessTokenApp { + ref access_token, + ref app, + } => { + error!( + "Invalid access token \"{}\" referencing app \"{}\"", + access_token, app + ); + + ApiError { + status: StatusCode::INTERNAL_SERVER_ERROR, + code: err.error_code(), + message: if cfg!(debug_assertions) { + format!( + "Invalid access token \"{}\" referencing app \"{}\"", + access_token, app + ) + } else { + "Invalid access token-app link".to_string() + }, + } + } + AuthError::InvalidToken => ApiError { + status: StatusCode::UNAUTHORIZED, + code: err.error_code(), + message: "Invalid token".to_string(), + }, + } + } +} + +pub fn is_user_token(token: &str) -> bool { + token.chars().count() == 16 +} + +impl AuthState { + pub fn new(magnetar: Arc) -> Self { + Self { service: magnetar } + } + + async fn authorize_token( + &self, + Authorization(token): &Authorization, + ) -> Result { + let token = token.token(); + + if is_user_token(token) { + let user_cache = &self.service.auth_cache; + let user = user_cache.get_by_token(token).await?; + + if let Some(user) = user { + return Ok(AuthMode::User { user }); + } + + Err(AuthError::InvalidToken) + } else { + let access_token = self.service.db.get_access_token(token).await?; + + if access_token.is_none() { + return Err(AuthError::InvalidToken); + } + + let access_token = access_token.unwrap(); + + let user = self + .service + .auth_cache + .get_by_id(&access_token.user_id) + .await?; + + if user.is_none() { + return Err(AuthError::InvalidTokenUser { + token: access_token.id, + user: access_token.user_id, + }); + } + + let user = user.unwrap(); + + if let Some(app_id) = &access_token.app_id { + return match self.service.db.get_app_by_id(app_id).await? { + Some(app) => Ok(AuthMode::AccessToken { + user, + access_token: Arc::new(access_token::Model { + permission: app.permission, + ..access_token + }), + }), + None => Err(AuthError::InvalidAccessTokenApp { + access_token: access_token.id, + app: access_token.user_id, + }), + }; + } + + let access_token = Arc::new(access_token); + + Ok(AuthMode::AccessToken { user, access_token }) + } + } +} + +pub async fn auth( + State(state): State, + mut req: Request, + next: Next, +) -> Result { + let auth_bearer = match req.headers().typed_try_get::>() { + Ok(Some(auth)) => auth, + Ok(None) => { + req.extensions_mut().insert(AuthMode::Anonymous); + return Ok(next.run(req).await); + } + Err(_) => { + return Err(AuthError::UnsupportedScheme.into()); + } + }; + + match state.authorize_token(&auth_bearer).await { + Ok(auth) => { + if let Some(user) = auth.get_user() { + let user = AuthenticatedUser(user.clone()); + req.extensions_mut().insert(user); + } + + req.extensions_mut().insert(auth); + + Ok(next.run(req).await) + } + Err(e) => Err(e.into()), + } +} diff --git a/src/web/mod.rs b/src/web/mod.rs new file mode 100644 index 0000000..bfc834e --- /dev/null +++ b/src/web/mod.rs @@ -0,0 +1,83 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use magnetar_calckey_model::{CalckeyCacheError, CalckeyDbError}; +use serde::Serialize; +use serde_json::json; + +pub mod auth; + +#[derive(Debug, Clone, Serialize)] +#[repr(transparent)] +pub struct ErrorCode(pub String); + +// This janky hack allows us to use `.error_code()` on enums with strum::IntoStaticStr +pub trait IntoErrorCode { + fn error_code<'a, 'b: 'a>(&'a self) -> ErrorCode + where + &'a Self: Into<&'b str>; +} + +impl IntoErrorCode for T { + fn error_code<'a, 'b: 'a>(&'a self) -> ErrorCode + where + &'a Self: Into<&'b str>, + { + ErrorCode(self.into().to_string()) + } +} + +impl ErrorCode { + pub fn join(&self, other: &str) -> Self { + Self(format!("{}:{}", other, self.0)) + } +} + +#[derive(Debug)] +pub struct ApiError { + pub status: StatusCode, + pub code: ErrorCode, + pub message: String, +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + ( + self.status, + Json(json!({ + "status": self.status.as_u16(), + "code": self.code, + "message": self.message, + })), + ) + .into_response() + } +} + +impl From for ApiError { + fn from(err: CalckeyDbError) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + code: err.error_code(), + message: if cfg!(debug_assertions) { + format!("Database error: {}", err) + } else { + "Database error".to_string() + }, + } + } +} + +impl From for ApiError { + fn from(err: CalckeyCacheError) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + code: err.error_code(), + message: if cfg!(debug_assertions) { + format!("Cache error: {}", err) + } else { + "Cache error".to_string() + }, + } + } +}