Initial API authentication implementation
ci/woodpecker/push/ociImagePush Pipeline was successful Details

This commit is contained in:
Natty 2023-08-02 03:10:53 +02:00
parent c49871ec0f
commit 1a44dfe56f
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
15 changed files with 1285 additions and 12 deletions

View File

@ -15,6 +15,10 @@ nattyarch.local {
reverse_proxy 127.0.0.1:4939
}
handle /mag/* {
reverse_proxy 127.0.0.1:4939
}
@render_html {
not path /api* /proxy* /files* /avatar* /identicon* /streaming
header Accept text/html*

183
Cargo.lock generated
View File

@ -125,6 +125,7 @@ checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39"
dependencies = [
"async-trait",
"axum-core",
"axum-macros",
"bitflags",
"bytes",
"futures-util",
@ -167,6 +168,18 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum-macros"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdca6a10ecad987bda04e95606ef85a5417dcaac1a78455242d72e031e2b6b62"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.16",
]
[[package]]
name = "bae"
version = "0.1.7"
@ -325,6 +338,42 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
[[package]]
name = "cached"
version = "0.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b195e4fbc4b6862bbd065b991a34750399c119797efff72492f28a5864de8700"
dependencies = [
"async-trait",
"cached_proc_macro",
"cached_proc_macro_types",
"futures",
"hashbrown 0.13.2",
"instant",
"once_cell",
"thiserror",
"tokio",
]
[[package]]
name = "cached_proc_macro"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b48814962d2fd604c50d2b9433c2a41a0ab567779ee2c02f7fba6eca1221f082"
dependencies = [
"cached_proc_macro_types",
"darling",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "cached_proc_macro_types"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
[[package]]
name = "cc"
version = "1.0.79"
@ -419,6 +468,20 @@ dependencies = [
"os_str_bytes",
]
[[package]]
name = "combine"
version = "4.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.4"
@ -463,6 +526,41 @@ dependencies = [
"typenum",
]
[[package]]
name = "darling"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 1.0.109",
]
[[package]]
name = "darling_macro"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
dependencies = [
"darling_core",
"quote",
"syn 1.0.109",
]
[[package]]
name = "deunicode"
version = "0.4.3"
@ -594,6 +692,17 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]]
name = "futures-macro"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.16",
]
[[package]]
name = "futures-sink"
version = "0.3.28"
@ -615,6 +724,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@ -883,6 +993,12 @@ dependencies = [
"cc",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.3.0"
@ -989,7 +1105,11 @@ name = "magnetar"
version = "0.2.0"
dependencies = [
"axum",
"cached",
"cfg-if",
"chrono",
"dotenvy",
"headers",
"hyper",
"magnetar_calckey_model",
"magnetar_common",
@ -1000,6 +1120,7 @@ dependencies = [
"percent-encoding",
"serde",
"serde_json",
"strum",
"thiserror",
"tokio",
"toml 0.7.4",
@ -1041,12 +1162,17 @@ dependencies = [
"ck",
"dotenvy",
"ext_calckey_model_migration",
"futures-core",
"futures-util",
"magnetar_common",
"redis",
"sea-orm",
"serde",
"serde_json",
"strum",
"thiserror",
"tokio",
"tokio-util",
"tracing",
]
@ -1571,6 +1697,29 @@ dependencies = [
"getrandom",
]
[[package]]
name = "redis"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff5d95dd18a4d76650f0c2607ed8ebdbf63baf9cb934e1c233cd220c694db1d7"
dependencies = [
"async-trait",
"bytes",
"combine",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"ryu",
"serde",
"serde_json",
"sha1_smol",
"socket2",
"tokio",
"tokio-util",
"url",
]
[[package]]
name = "redox_syscall"
version = "0.2.16"
@ -1992,6 +2141,12 @@ dependencies = [
"digest",
]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]]
name = "sha2"
version = "0.10.6"
@ -2192,6 +2347,34 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "strum"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.25.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6069ca09d878a33f883cc06aaa9718ede171841d3832450354410b718b097232"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.16",
]
[[package]]
name = "subtle"
version = "2.4.1"

View File

@ -22,19 +22,27 @@ edition = "2021"
[workspace.dependencies]
axum = "0.6"
cached = "0.44"
cfg-if = "1"
chrono = "0.4"
dotenvy = "0.15"
futures-core = "0.3"
futures-util = "0.3"
headers = "0.3"
hyper = "0.14"
log = "0.4"
miette = "5.9"
percent-encoding = "2.2"
redis = "0.23"
sea-orm = "0.11"
sea-orm-migration = "0.11"
serde = "1"
serde_json = "1"
strum = "0.25"
tera = { version = "1", default-features = false }
thiserror = "1"
tokio = "1.24"
tokio-util = "0.7"
toml = "0.7"
tower = "0.4"
tower-http = "0.4"
@ -50,9 +58,12 @@ magnetar_webfinger = { path = "./ext_webfinger" }
magnetar_nodeinfo = { path = "./ext_nodeinfo" }
magnetar_calckey_model = { path = "./ext_calckey_model" }
cached = { workspace = true }
chrono = { workspace = true }
dotenvy = { workspace = true }
axum = { workspace = true }
axum = { workspace = true, features = ["macros"] }
headers = { workspace = true }
hyper = { workspace = true, features = ["full"] }
tokio = { workspace = true, features = ["full"] }
tower = { workspace = true }
@ -61,12 +72,14 @@ tower-http = { workspace = true, features = ["cors", "trace", "fs"] }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tracing = { workspace = true }
cfg-if = { workspace = true }
strum = { workspace = true, features = ["derive"] }
thiserror = { workspace = true }
miette = { workspace = true }
percent-encoding = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
toml = { workspace = true }
serde_json = { workspace = true }

View File

@ -58,10 +58,14 @@
# ----------------------------------[ DATA ]-----------------------------------
# [REQUIRED]
# An URL pointing to a Postgres database, with a Calckey database
# An URI pointing to a Postgres database, with a Calckey database
# Environment variables: MAG_C_DATABASE_URL, DATABASE_URL
# data.database_url = "postgres://username:password@db:5432/calckey"
# [REQUIRED]
# An URI pointing to a Redis instance
# Environment variables: MAG_C_REDIS_URL
# data.redis_url = "redis://redis:6379"
# -------------------------------[ FEDERATION ]--------------------------------

View File

@ -13,10 +13,15 @@ ext_calckey_model_migration = { path = "./migration" }
magnetar_common = { path = "../magnetar_common" }
dotenvy = { workspace = true}
futures-core = { workspace = true }
futures-util = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tokio-util = { workspace = true}
redis = { workspace = true, features = ["tokio-comp", "json", "serde_json"]}
sea-orm = { workspace = true, features = ["sqlx-postgres", "runtime-tokio-rustls", "macros"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
strum = { workspace = true }
chrono = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }

View File

@ -1,8 +1,21 @@
use ck::user;
use chrono::Utc;
pub use ck::*;
use ext_calckey_model_migration::{Migrator, MigratorTrait};
use sea_orm::{ColumnTrait, ConnectOptions, DatabaseConnection, EntityTrait, QueryFilter};
use futures_util::StreamExt;
use redis::IntoConnectionInfo;
use sea_orm::ActiveValue::Set;
use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, QueryOrder,
TransactionTrait,
};
use serde::{Deserialize, Serialize};
use std::future::Future;
use strum::IntoStaticStr;
use thiserror::Error;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::log::LevelFilter;
use tracing::{error, info};
#[derive(Debug)]
pub struct ConnectorConfig {
@ -12,10 +25,10 @@ pub struct ConnectorConfig {
#[derive(Clone, Debug)]
pub struct CalckeyModel(DatabaseConnection);
#[derive(Debug, Error)]
#[derive(Debug, Error, IntoStaticStr)]
pub enum CalckeyDbError {
#[error("Database error: {0}")]
DbError(#[from] sea_orm::DbErr),
DbError(#[from] DbErr),
}
impl CalckeyModel {
@ -36,6 +49,10 @@ impl CalckeyModel {
Ok(())
}
pub fn inner(&self) -> &DatabaseConnection {
&self.0
}
pub async fn get_user_by_tag(
&self,
name: &str,
@ -61,10 +78,256 @@ impl CalckeyModel {
Ok(user)
}
pub async fn get_note_by_id(&self, id: &str) -> Result<Option<note::Model>, CalckeyDbError> {
Ok(note::Entity::find()
.filter(note::Column::Id.eq(id))
.one(&self.0)
.await?)
}
pub async fn get_user_by_id(&self, id: &str) -> Result<Option<user::Model>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(user::Column::Id.eq(id))
.one(&self.0)
.await?)
}
pub async fn get_user_by_token(
&self,
token: &str,
) -> Result<Option<user::Model>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(
user::Column::Token
.eq(token)
.and(user::Column::Host.is_null()),
)
.one(&self.0)
.await?)
}
pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(user::Column::Uri.eq(uri))
.one(&self.0)
.await?)
}
pub async fn get_local_emoji(&self) -> Result<Vec<emoji::Model>, CalckeyDbError> {
Ok(emoji::Entity::find()
.filter(emoji::Column::Host.is_null())
.order_by_asc(emoji::Column::Category)
.order_by_asc(emoji::Column::Name)
.all(&self.0)
.await?)
}
pub async fn get_access_token(
&self,
token: &str,
) -> Result<Option<access_token::Model>, CalckeyDbError> {
let token = access_token::Entity::update(access_token::ActiveModel {
last_used_at: Set(Some(Utc::now().into())),
..Default::default()
})
.filter(
access_token::Column::Token
.eq(token)
.or(access_token::Column::Hash.eq(token.to_lowercase())),
)
.exec(&self.0)
.await;
if let Err(DbErr::RecordNotUpdated) = token {
return Ok(None);
}
Ok(Some(token?))
}
pub async fn get_app_by_id(&self, id: &str) -> Result<Option<app::Model>, CalckeyDbError> {
Ok(app::Entity::find()
.filter(app::Column::Id.eq(id))
.one(&self.0)
.await?)
}
pub async fn get_instance_meta(&self) -> Result<meta::Model, CalckeyDbError> {
let txn = self.0.begin().await?;
let meta = meta::Entity::find().one(&txn).await?;
if let Some(meta) = meta {
txn.commit().await?;
return Ok(meta);
}
let model = meta::ActiveModel {
id: Set("x".to_string()),
..Default::default()
};
let meta = meta::Entity::insert(model)
.exec_with_returning(&txn)
.await?;
txn.commit().await?;
Ok(meta)
}
}
#[derive(Debug)]
pub struct CacheConnectorConfig {
pub url: String,
}
#[derive(Clone, Debug)]
pub struct CalckeyCache(redis::Client);
#[derive(Debug, Error, IntoStaticStr)]
pub enum CalckeyCacheError {
#[error("Redis error: {0}")]
RedisError(#[from] redis::RedisError),
}
impl CalckeyCache {
pub fn new(config: CacheConnectorConfig) -> Result<Self, CalckeyCacheError> {
let conn_info = config.url.into_connection_info()?;
// TODO: Allow overriding redis config with individual options (maybe)
let redis_config = redis::ConnectionInfo {
addr: conn_info.addr,
redis: conn_info.redis,
};
Ok(CalckeyCache(redis::Client::open(redis_config)?))
}
pub fn inner(&self) -> &redis::Client {
&self.0
}
pub async fn conn(&self) -> Result<CalckeyCacheClient, CalckeyCacheError> {
Ok(CalckeyCacheClient(self.0.get_async_connection().await?))
}
}
pub struct CalckeyCacheClient(redis::aio::Connection);
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "channel", content = "message")]
pub enum SubMessage {
Internal(InternalStreamMessage),
#[serde(other)]
Other,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "body")]
#[serde(rename_all = "camelCase")]
pub enum InternalStreamMessage {
#[serde(rename_all = "camelCase")]
UserChangeSuspendedState {
id: String,
is_suspended: bool,
},
#[serde(rename_all = "camelCase")]
UserChangeSilencedState {
id: String,
is_silenced: bool,
},
#[serde(rename_all = "camelCase")]
UserChangeModeratorState {
id: String,
is_moderator: bool,
},
#[serde(rename_all = "camelCase")]
UserTokenRegenerated {
id: String,
old_token: String,
new_token: String,
},
LocalUserUpdated {
id: String,
},
RemoteUserUpdated {
id: String,
},
WebhookCreated(webhook::Model),
WebhookDeleted(webhook::Model),
WebhookUpdated(webhook::Model),
AntennaCreated(antenna::Model),
AntennaDeleted(antenna::Model),
AntennaUpdated(antenna::Model),
}
impl CalckeyCacheClient {
pub async fn subscribe<F: Future<Output = ()> + Send + 'static>(
self,
prefix: &str,
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
) -> Result<CalckeySub, CalckeyCacheError> {
CalckeySub::new(self.0, prefix, handler).await
}
}
pub struct CalckeySub(CancellationToken);
impl CalckeySub {
async fn new<F: Future<Output = ()> + Send + 'static>(
conn: redis::aio::Connection,
prefix: &str,
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
) -> Result<Self, CalckeyCacheError> {
let mut pub_sub = conn.into_pubsub();
pub_sub.subscribe(prefix).await?;
let token = CancellationToken::new();
let token_rx = token.clone();
let prefix = prefix.to_string();
tokio::spawn(async move {
let mut on_message = pub_sub.on_message();
while let Some(msg) = select! {
msg = on_message.next() => msg,
_ = token_rx.cancelled() => {
drop(on_message);
if let Err(e) = pub_sub.unsubscribe(prefix).await {
info!("Redis error: {:?}", e);
}
return;
}
} {
let data = &match msg.get_payload::<String>() {
Ok(val) => val,
Err(e) => {
info!("Redis error: {:?}", e);
continue;
}
};
let parsed = match serde_json::from_str::<SubMessage>(data) {
Ok(val) => val,
Err(e) => {
info!("Message parse error: {:?}", e);
continue;
}
};
println!("Got message: {:#?}", parsed);
handler(parsed).await;
}
});
Ok(CalckeySub(token))
}
}
impl Drop for CalckeySub {
fn drop(&mut self) {
self.0.cancel();
}
}

View File

@ -150,6 +150,7 @@ impl Default for MagnetarBranding {
#[non_exhaustive]
pub struct MagnetarData {
pub database_url: String,
pub redis_url: String,
}
fn env_database_url() -> String {
@ -158,10 +159,16 @@ fn env_database_url() -> String {
.expect("MAG_C_DATABASE_URL, DATABASE_URL or \"data.database_url\" in the default configuration must be set")
}
fn env_redis_url() -> String {
std::env::var("MAG_C_REDIS_URL")
.expect("MAG_C_REDIS_URL or \"data.redis_url\" in the default configuration must be set")
}
impl Default for MagnetarData {
fn default() -> Self {
MagnetarData {
database_url: env_database_url(),
redis_url: env_redis_url(),
}
}
}

20
src/api_v1/mod.rs Normal file
View File

@ -0,0 +1,20 @@
mod user;
use crate::api_v1::user::handle_user_info;
use crate::service::MagnetarService;
use crate::web::auth;
use crate::web::auth::AuthState;
use axum::middleware::from_fn_with_state;
use axum::routing::get;
use axum::Router;
use std::sync::Arc;
pub fn create_api_router(service: Arc<MagnetarService>) -> Router {
Router::new()
.route("/user/@self", get(handle_user_info))
.layer(from_fn_with_state(
AuthState::new(service.clone()),
auth::auth,
))
.with_state(service)
}

18
src/api_v1/user.rs Normal file
View File

@ -0,0 +1,18 @@
use crate::service::MagnetarService;
use crate::web::auth::AuthenticatedUser;
use crate::web::ApiError;
use axum::extract::State;
use axum::response::IntoResponse;
use axum::Json;
use std::sync::Arc;
// TODO: Not a real endpoint
pub async fn handle_user_info(
State(service): State<Arc<MagnetarService>>,
AuthenticatedUser(user): AuthenticatedUser,
) -> Result<impl IntoResponse, ApiError> {
let db = service.db.clone();
let user = db.get_user_by_id(&user.id).await?;
Ok(Json(user))
}

View File

@ -1,13 +1,20 @@
mod api_v1;
pub mod nodeinfo;
pub mod service;
pub mod util;
pub mod web;
pub mod webfinger;
use crate::api_v1::create_api_router;
use crate::nodeinfo::{handle_nodeinfo, handle_nodeinfo_20, handle_nodeinfo_21};
use crate::service::MagnetarService;
use axum::routing::get;
use axum::Router;
use dotenvy::dotenv;
use magnetar_calckey_model::{CalckeyModel, ConnectorConfig};
use magnetar_calckey_model::{CacheConnectorConfig, CalckeyCache, CalckeyModel, ConnectorConfig};
use miette::{miette, IntoDiagnostic};
use std::net::SocketAddr;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::info;
@ -39,6 +46,17 @@ async fn main() -> miette::Result<()> {
db.migrate().await.into_diagnostic()?;
let redis = CalckeyCache::new(CacheConnectorConfig {
url: config.data.redis_url.clone(),
})
.into_diagnostic()?;
let service = Arc::new(
MagnetarService::new(config, db.clone(), redis)
.await
.into_diagnostic()?,
);
let well_known_router = Router::new()
.route(
"/webfinger",
@ -52,9 +70,9 @@ async fn main() -> miette::Result<()> {
.route("/2.1", get(handle_nodeinfo_21));
let app = Router::new()
.nest("/.well-known", well_known_router)
.nest("/nodeinfo", nodeinfo_router)
.with_state(config)
.nest("/.well-known", well_known_router.with_state(config))
.nest("/nodeinfo", nodeinfo_router.with_state(config))
.nest("/mag/v1", create_api_router(service))
.layer(
CorsLayer::new()
.allow_headers(Any)

36
src/service/mod.rs Normal file
View File

@ -0,0 +1,36 @@
use magnetar_calckey_model::{CalckeyCache, CalckeyModel};
use magnetar_common::config::MagnetarConfig;
use thiserror::Error;
pub mod user_cache;
pub struct MagnetarService {
pub db: CalckeyModel,
pub cache: CalckeyCache,
pub config: &'static MagnetarConfig,
pub auth_cache: user_cache::UserCacheService,
}
#[derive(Debug, Error)]
pub enum ServiceInitError {
#[error("Authentication cache initialization error: {0}")]
AuthCacheError(#[from] user_cache::UserCacheError),
}
impl MagnetarService {
pub async fn new(
config: &'static MagnetarConfig,
db: CalckeyModel,
cache: CalckeyCache,
) -> Result<Self, ServiceInitError> {
let auth_cache =
user_cache::UserCacheService::new(config, db.clone(), cache.clone()).await?;
Ok(Self {
db,
cache,
config,
auth_cache,
})
}
}

239
src/service/user_cache.rs Normal file
View File

@ -0,0 +1,239 @@
use crate::web::ApiError;
use cached::{Cached, TimedCache};
use magnetar_calckey_model::{
user, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub,
InternalStreamMessage, SubMessage,
};
use magnetar_common::config::MagnetarConfig;
use std::collections::HashMap;
use std::sync::Arc;
use strum::EnumVariantNames;
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::error;
#[derive(Debug, Error, EnumVariantNames)]
pub enum UserCacheError {
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
#[error("Redis error: {0}")]
RedisError(#[from] CalckeyCacheError),
}
impl From<UserCacheError> for ApiError {
fn from(err: UserCacheError) -> Self {
let mut api_error: ApiError = match err {
UserCacheError::DbError(err) => err.into(),
UserCacheError::RedisError(err) => err.into(),
};
api_error.message = format!("User cache error: {}", api_error.message);
api_error
}
}
struct UserCache {
lifetime: TimedCache<String, ()>,
id_to_user: HashMap<String, Arc<user::Model>>,
token_to_user: HashMap<String, Arc<user::Model>>,
uri_to_user: HashMap<String, Arc<user::Model>>,
}
impl UserCache {
fn purge(&mut self, user: impl AsRef<user::Model>) {
let user = user.as_ref();
self.lifetime.cache_remove(&user.id);
if let Some(user) = self.id_to_user.remove(&user.id) {
if let Some(token) = user.token.clone() {
self.token_to_user.remove(&token);
}
if let Some(uri) = user.uri.clone() {
self.uri_to_user.remove(&uri);
}
}
}
fn refresh(&mut self, user: Arc<user::Model>) {
self.purge(&user);
self.lifetime.cache_set(user.id.clone(), ());
self.id_to_user.insert(user.id.clone(), user.clone());
if let Some(token) = user.token.clone() {
self.token_to_user.insert(token, user.clone());
}
if let Some(uri) = user.uri.clone() {
self.uri_to_user.insert(uri, user);
}
}
/// Low-priority refresh. Only refreshes the cache if the user is not there.
/// Used mostly for getters that would otherwise data race with more important refreshes.
fn maybe_refresh(&mut self, user: &Arc<user::Model>) {
if self.lifetime.cache_get(&user.id).is_none() {
self.refresh(user.clone());
}
}
fn get_by_id(&mut self, id: &str) -> Option<Arc<user::Model>> {
if let Some(user) = self.id_to_user.get(id).cloned() {
if self.lifetime.cache_get(id).is_none() {
self.purge(&user);
return None;
}
return Some(user);
}
None
}
fn get_by_token(&mut self, token: &str) -> Option<Arc<user::Model>> {
if let Some(user) = self.token_to_user.get(token).cloned() {
if self.lifetime.cache_get(&user.id).is_none() {
self.purge(&user);
return None;
}
return Some(user);
}
None
}
fn get_by_uri(&mut self, uri: &str) -> Option<Arc<user::Model>> {
if let Some(user) = self.uri_to_user.get(uri).cloned() {
if self.lifetime.cache_get(&user.id).is_none() {
self.purge(&user);
return None;
}
return Some(user);
}
None
}
}
pub struct UserCacheService {
db: CalckeyModel,
#[allow(dead_code)]
token_watch: CalckeySub,
cache: Arc<Mutex<UserCache>>,
}
impl UserCacheService {
pub(super) async fn new(
config: &MagnetarConfig,
db: CalckeyModel,
redis: CalckeyCache,
) -> Result<Self, UserCacheError> {
let cache = Arc::new(Mutex::new(UserCache {
lifetime: TimedCache::with_lifespan(60 * 5),
id_to_user: HashMap::new(),
token_to_user: HashMap::new(),
uri_to_user: HashMap::new(),
}));
let cache_clone = cache.clone();
let db_clone = db.clone();
let token_watch = redis
.conn()
.await?
.subscribe(&config.networking.host, move |message| {
let cache = cache_clone.clone();
let db = db_clone.clone();
async move {
let SubMessage::Internal(internal) = message else {
return;
};
match internal {
InternalStreamMessage::LocalUserUpdated { id }
| InternalStreamMessage::UserChangeModeratorState { id, .. }
| InternalStreamMessage::UserChangeSilencedState { id, .. }
| InternalStreamMessage::UserChangeSuspendedState { id, .. }
| InternalStreamMessage::RemoteUserUpdated { id }
| InternalStreamMessage::UserTokenRegenerated { id, .. } => {
let user = match db.get_user_by_id(&id).await {
Ok(Some(user)) => user,
Ok(None) => return,
Err(e) => {
error!("Error fetching user from database: {}", e);
return;
}
};
cache.lock().await.refresh(Arc::new(user));
}
_ => {}
};
}
})
.await?;
Ok(Self {
cache,
db,
token_watch,
})
}
async fn map_cache_user(
&self,
user: Option<user::Model>,
) -> Result<Option<Arc<user::Model>>, UserCacheError> {
if let Some(user) = user {
let user = Arc::new(user);
self.cache.lock().await.maybe_refresh(&user);
return Ok(Some(user));
}
Ok(None)
}
pub async fn get_by_token(
&self,
token: &str,
) -> Result<Option<Arc<user::Model>>, UserCacheError> {
let result = self.cache.lock().await.get_by_token(token);
if let Some(user) = result {
return Ok(Some(user));
}
self.map_cache_user(self.db.get_user_by_token(token).await?)
.await
}
pub async fn get_by_uri(&self, uri: &str) -> Result<Option<Arc<user::Model>>, UserCacheError> {
let result = self.cache.lock().await.get_by_uri(uri);
if let Some(user) = result {
return Ok(Some(user));
}
let user = self.db.get_user_by_uri(uri).await?;
self.map_cache_user(user).await
}
pub async fn get_by_id(&self, id: &str) -> Result<Option<Arc<user::Model>>, UserCacheError> {
let result = self.cache.lock().await.get_by_id(id);
if let Some(user) = result {
return Ok(Some(user));
}
let user = self.db.get_user_by_id(id).await?;
self.map_cache_user(user).await
}
}

118
src/util.rs Normal file
View File

@ -0,0 +1,118 @@
use cached::{Cached, TimedCache};
use std::future::Future;
use std::hash::Hash;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
pub struct SingleTimedAsyncCache<V: Clone + Send + 'static> {
inner: Arc<Mutex<TimedCache<(), V>>>,
}
impl<V: Clone + Send + 'static> SingleTimedAsyncCache<V> {
pub fn with_lifespan(lifespan: u64) -> Self {
Self {
inner: Arc::new(Mutex::new(TimedCache::with_lifespan(lifespan))),
}
}
pub async fn put(&self, val: V) -> Option<V> {
let mut cache = self.inner.lock().await;
cache.cache_set((), val)
}
pub async fn get_opt(&self) -> Option<V> {
let mut cache = self.inner.lock().await;
cache.cache_get(&()).cloned()
}
pub async fn get<E, FT, F>(&self, f: F) -> Result<V, E>
where
FT: Future<Output = Result<V, E>>,
F: FnOnce() -> FT,
{
let mut cache = self.inner.lock().await;
if let Some(val) = cache.cache_get(&()) {
return Ok(val.clone());
}
let val = f().await?;
cache.cache_set((), val.clone());
Ok(val)
}
pub async fn get_sync<E, F>(&self, f: F) -> Result<V, E>
where
F: FnOnce() -> Result<V, E>,
{
let mut cache = self.inner.lock().await;
if let Some(val) = cache.cache_get(&()) {
return Ok(val.clone());
}
let val = f()?;
cache.cache_set((), val.clone());
Ok(val)
}
}
#[derive(Debug, Clone)]
pub struct TimedAsyncCache<K: Clone + Send, V: Clone + Send + 'static> {
inner: Arc<Mutex<TimedCache<K, V>>>,
}
impl<K: Clone + Send + Eq + Hash, V: Clone + Send + 'static> TimedAsyncCache<K, V> {
pub fn with_lifespan(lifespan: u64) -> Self {
Self {
inner: Arc::new(Mutex::new(TimedCache::with_lifespan(lifespan))),
}
}
pub async fn put(&self, key: K, val: V) -> Option<V> {
let mut cache = self.inner.lock().await;
cache.cache_set(key, val)
}
pub async fn remove(&self, key: &K) -> Option<V> {
let mut cache = self.inner.lock().await;
cache.cache_remove(key)
}
pub async fn get_opt(&self, key: &K) -> Option<V> {
let mut cache = self.inner.lock().await;
cache.cache_get(key).cloned()
}
pub async fn get<E, FT, F>(&self, key: K, f: F) -> Result<V, E>
where
FT: Future<Output = Result<V, E>>,
F: FnOnce() -> FT,
{
let mut cache = self.inner.lock().await;
if let Some(val) = cache.cache_get(&key) {
return Ok(val.clone());
}
let val = f().await?;
cache.cache_set(key, val.clone());
Ok(val)
}
pub async fn get_sync<E, F>(&self, key: K, f: F) -> Result<V, E>
where
F: FnOnce() -> Result<V, E>,
{
let mut cache = self.inner.lock().await;
if let Some(val) = cache.cache_get(&key) {
return Ok(val.clone());
}
let val = f()?;
cache.cache_set(key, val.clone());
Ok(val)
}
}

262
src/web/auth.rs Normal file
View File

@ -0,0 +1,262 @@
use crate::service::user_cache::UserCacheError;
use crate::service::MagnetarService;
use crate::web::{ApiError, IntoErrorCode};
use axum::async_trait;
use axum::extract::rejection::ExtensionRejection;
use axum::extract::{FromRequestParts, State};
use axum::http::request::Parts;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use headers::authorization::Bearer;
use headers::{Authorization, HeaderMapExt};
use magnetar_calckey_model::{access_token, user, CalckeyDbError};
use std::convert::Infallible;
use std::sync::Arc;
use strum::IntoStaticStr;
use thiserror::Error;
use tracing::error;
#[derive(Debug)]
pub enum AuthMode {
User {
user: Arc<user::Model>,
},
AccessToken {
user: Arc<user::Model>,
access_token: Arc<access_token::Model>,
},
Anonymous,
}
impl AuthMode {
fn get_user(&self) -> Option<&Arc<user::Model>> {
match self {
AuthMode::User { user } | AuthMode::AccessToken { user, .. } => Some(user),
AuthMode::Anonymous => None,
}
}
}
pub struct AuthUserRejection(ApiError);
impl From<ExtensionRejection> for AuthUserRejection {
fn from(rejection: ExtensionRejection) -> Self {
AuthUserRejection(ApiError {
status: StatusCode::UNAUTHORIZED,
code: "Unauthorized".error_code(),
message: if cfg!(debug_assertions) {
format!("Missing auth extension: {}", rejection)
} else {
"Unauthorized".to_string()
},
})
}
}
impl IntoResponse for AuthUserRejection {
fn into_response(self) -> Response {
self.0.into_response()
}
}
#[derive(Clone, FromRequestParts)]
#[from_request(via(axum::Extension), rejection(AuthUserRejection))]
pub struct AuthenticatedUser(pub Arc<user::Model>);
#[derive(Clone)]
pub struct MaybeUser(pub Option<Arc<user::Model>>);
#[async_trait]
impl<S> FromRequestParts<S> for MaybeUser {
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(MaybeUser(
parts
.extensions
.get::<AuthenticatedUser>()
.map(|part| part.0.clone()),
))
}
}
#[derive(Clone)]
pub struct AuthState {
service: Arc<MagnetarService>,
}
#[derive(Debug, Error, IntoStaticStr)]
enum AuthError {
#[error("Unsupported authorization scheme")]
UnsupportedScheme,
#[error("Cache error: {0}")]
CacheError(#[from] UserCacheError),
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
#[error("Invalid token")]
InvalidToken,
#[error("Invalid token \"{token}\" referencing user \"{user}\"")]
InvalidTokenUser { token: String, user: String },
#[error("Invalid access token \"{access_token}\" referencing app \"{app}\"")]
InvalidAccessTokenApp { access_token: String, app: String },
}
impl From<AuthError> for ApiError {
fn from(err: AuthError) -> Self {
match err {
AuthError::UnsupportedScheme => ApiError {
status: StatusCode::UNAUTHORIZED,
code: err.error_code(),
message: "Unsupported authorization scheme".to_string(),
},
AuthError::CacheError(err) => err.into(),
AuthError::DbError(err) => err.into(),
AuthError::InvalidTokenUser {
ref token,
ref user,
} => {
error!("Invalid token \"{}\" referencing user \"{}\"", token, user);
ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Invalid token \"{}\" referencing user \"{}\"", token, user)
} else {
"Invalid token-user link".to_string()
},
}
}
AuthError::InvalidAccessTokenApp {
ref access_token,
ref app,
} => {
error!(
"Invalid access token \"{}\" referencing app \"{}\"",
access_token, app
);
ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!(
"Invalid access token \"{}\" referencing app \"{}\"",
access_token, app
)
} else {
"Invalid access token-app link".to_string()
},
}
}
AuthError::InvalidToken => ApiError {
status: StatusCode::UNAUTHORIZED,
code: err.error_code(),
message: "Invalid token".to_string(),
},
}
}
}
pub fn is_user_token(token: &str) -> bool {
token.chars().count() == 16
}
impl AuthState {
pub fn new(magnetar: Arc<MagnetarService>) -> Self {
Self { service: magnetar }
}
async fn authorize_token(
&self,
Authorization(token): &Authorization<Bearer>,
) -> Result<AuthMode, AuthError> {
let token = token.token();
if is_user_token(token) {
let user_cache = &self.service.auth_cache;
let user = user_cache.get_by_token(token).await?;
if let Some(user) = user {
return Ok(AuthMode::User { user });
}
Err(AuthError::InvalidToken)
} else {
let access_token = self.service.db.get_access_token(token).await?;
if access_token.is_none() {
return Err(AuthError::InvalidToken);
}
let access_token = access_token.unwrap();
let user = self
.service
.auth_cache
.get_by_id(&access_token.user_id)
.await?;
if user.is_none() {
return Err(AuthError::InvalidTokenUser {
token: access_token.id,
user: access_token.user_id,
});
}
let user = user.unwrap();
if let Some(app_id) = &access_token.app_id {
return match self.service.db.get_app_by_id(app_id).await? {
Some(app) => Ok(AuthMode::AccessToken {
user,
access_token: Arc::new(access_token::Model {
permission: app.permission,
..access_token
}),
}),
None => Err(AuthError::InvalidAccessTokenApp {
access_token: access_token.id,
app: access_token.user_id,
}),
};
}
let access_token = Arc::new(access_token);
Ok(AuthMode::AccessToken { user, access_token })
}
}
}
pub async fn auth<B>(
State(state): State<AuthState>,
mut req: Request<B>,
next: Next<B>,
) -> Result<Response, ApiError> {
let auth_bearer = match req.headers().typed_try_get::<Authorization<Bearer>>() {
Ok(Some(auth)) => auth,
Ok(None) => {
req.extensions_mut().insert(AuthMode::Anonymous);
return Ok(next.run(req).await);
}
Err(_) => {
return Err(AuthError::UnsupportedScheme.into());
}
};
match state.authorize_token(&auth_bearer).await {
Ok(auth) => {
if let Some(user) = auth.get_user() {
let user = AuthenticatedUser(user.clone());
req.extensions_mut().insert(user);
}
req.extensions_mut().insert(auth);
Ok(next.run(req).await)
}
Err(e) => Err(e.into()),
}
}

83
src/web/mod.rs Normal file
View File

@ -0,0 +1,83 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use magnetar_calckey_model::{CalckeyCacheError, CalckeyDbError};
use serde::Serialize;
use serde_json::json;
pub mod auth;
#[derive(Debug, Clone, Serialize)]
#[repr(transparent)]
pub struct ErrorCode(pub String);
// This janky hack allows us to use `.error_code()` on enums with strum::IntoStaticStr
pub trait IntoErrorCode {
fn error_code<'a, 'b: 'a>(&'a self) -> ErrorCode
where
&'a Self: Into<&'b str>;
}
impl<T: ?Sized> IntoErrorCode for T {
fn error_code<'a, 'b: 'a>(&'a self) -> ErrorCode
where
&'a Self: Into<&'b str>,
{
ErrorCode(self.into().to_string())
}
}
impl ErrorCode {
pub fn join(&self, other: &str) -> Self {
Self(format!("{}:{}", other, self.0))
}
}
#[derive(Debug)]
pub struct ApiError {
pub status: StatusCode,
pub code: ErrorCode,
pub message: String,
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
(
self.status,
Json(json!({
"status": self.status.as_u16(),
"code": self.code,
"message": self.message,
})),
)
.into_response()
}
}
impl From<CalckeyDbError> for ApiError {
fn from(err: CalckeyDbError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Database error: {}", err)
} else {
"Database error".to_string()
},
}
}
}
impl From<CalckeyCacheError> for ApiError {
fn from(err: CalckeyCacheError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Cache error: {}", err)
} else {
"Cache error".to_string()
},
}
}
}