From e06906dd6e9fd3e3e987b030801156ca7519d7c2 Mon Sep 17 00:00:00 2001 From: Natty Date: Tue, 30 Apr 2024 16:02:35 +0200 Subject: [PATCH] Also fetch profile in cache --- ext_model/src/lib.rs | 111 ++++++++++++++++++++------------ src/service/local_user_cache.rs | 72 ++++++++++++--------- src/web/auth.rs | 33 +++++----- 3 files changed, 128 insertions(+), 88 deletions(-) diff --git a/ext_model/src/lib.rs b/ext_model/src/lib.rs index e228fb8..6579e29 100644 --- a/ext_model/src/lib.rs +++ b/ext_model/src/lib.rs @@ -1,3 +1,34 @@ +use std::future::Future; + +use chrono::Utc; +use futures_util::{SinkExt, StreamExt}; +use redis::IntoConnectionInfo; +pub use sea_orm; +use sea_orm::{ActiveValue::Set, ConnectionTrait}; +use sea_orm::{ + ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, + TransactionTrait, +}; +use serde::{Deserialize, Deserializer, Serialize}; +use serde::de::Error; +use serde_json::Value; +use strum::IntoStaticStr; +use thiserror::Error; +use tokio::select; +use tokio_util::sync::CancellationToken; +use tracing::{error, info, trace, warn}; +use tracing::log::LevelFilter; +use url::Host; + +pub use ck; +use ck::*; +use ext_model_migration::{Migrator, MigratorTrait}; +use user_model::UserResolver; + +use crate::model_ext::IdShape; +use crate::note_model::NoteResolver; +use crate::notification_model::NotificationResolver; + pub mod emoji; pub mod model_ext; pub mod note_model; @@ -5,35 +36,6 @@ pub mod notification_model; pub mod poll; pub mod user_model; -pub use ck; -use ck::*; -pub use sea_orm; -use url::Host; -use user_model::UserResolver; - -use crate::model_ext::IdShape; -use crate::note_model::NoteResolver; -use crate::notification_model::NotificationResolver; -use chrono::Utc; -use ext_model_migration::{Migrator, MigratorTrait}; -use futures_util::StreamExt; -use redis::IntoConnectionInfo; -use sea_orm::{ActiveValue::Set, ConnectionTrait}; -use sea_orm::{ - ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, - TransactionTrait, -}; -use serde::de::Error; -use serde::{Deserialize, Deserializer, Serialize}; -use serde_json::Value; -use std::future::Future; -use strum::IntoStaticStr; -use thiserror::Error; -use tokio::select; -use tokio_util::sync::CancellationToken; -use tracing::log::LevelFilter; -use tracing::{error, info, trace, warn}; - #[derive(Debug)] pub struct ConnectorConfig { pub url: String, @@ -90,8 +92,8 @@ impl CalckeyModel { .and(user::Column::Host.is_null()), ) } - .one(&self.0) - .await?; + .one(&self.0) + .await?; Ok(user) } @@ -120,6 +122,15 @@ impl CalckeyModel { .await?) } + pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result, CalckeyDbError> { + Ok(user::Entity::find() + .filter(user::Column::Id.eq(id)) + .find_also_related(user_profile::Entity) + .one(&self.0) + .await? + .and_then(|(u, p)| p.map(|pp| (u, pp)))) + } + pub async fn get_user_security_keys_by_id( &self, id: &str, @@ -154,6 +165,22 @@ impl CalckeyModel { .await?) } + pub async fn get_user_and_profile_by_token( + &self, + token: &str, + ) -> Result, CalckeyDbError> { + Ok(user::Entity::find() + .filter( + user::Column::Token + .eq(token) + .and(user::Column::Host.is_null()), + ) + .find_also_related(user_profile::Entity) + .one(&self.0) + .await? + .and_then(|(u, p)| p.map(|pp| (u, pp)))) + } + pub async fn get_user_by_uri(&self, uri: &str) -> Result, CalckeyDbError> { Ok(user::Entity::find() .filter(user::Column::Uri.eq(uri)) @@ -244,13 +271,13 @@ impl CalckeyModel { last_used_at: Set(Some(Utc::now().into())), ..Default::default() }) - .filter( - access_token::Column::Token - .eq(token) - .or(access_token::Column::Hash.eq(token.to_lowercase())), - ) - .exec(&self.0) - .await; + .filter( + access_token::Column::Token + .eq(token) + .or(access_token::Column::Hash.eq(token.to_lowercase())), + ) + .exec(&self.0) + .await; if let Err(DbErr::RecordNotUpdated) = token { return Ok(None); @@ -372,8 +399,8 @@ struct RawMessage<'a> { impl<'de> Deserialize<'de> for SubMessage { fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where + D: Deserializer<'de>, { let raw = RawMessage::deserialize(deserializer)?; @@ -439,7 +466,7 @@ pub enum InternalStreamMessage { } impl CalckeyCacheClient { - pub async fn subscribe + Send + 'static>( + pub async fn subscribe + Send + 'static>( self, prefix: &str, handler: impl Fn(SubMessage) -> F + Send + Sync + 'static, @@ -451,7 +478,7 @@ impl CalckeyCacheClient { pub struct CalckeySub(CancellationToken); impl CalckeySub { - async fn new + Send + 'static>( + async fn new + Send + 'static>( conn: redis::aio::Connection, prefix: &str, handler: impl Fn(SubMessage) -> F + Send + Sync + 'static, diff --git a/src/service/local_user_cache.rs b/src/service/local_user_cache.rs index 10cf41d..fda63f6 100644 --- a/src/service/local_user_cache.rs +++ b/src/service/local_user_cache.rs @@ -23,6 +23,21 @@ pub enum UserCacheError { RedisError(#[from] CalckeyCacheError), } +#[derive(Debug, Clone)] +pub struct CachedLocalUser { + pub user: Arc, + pub profile: Arc, +} + +impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser { + fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self { + CachedLocalUser { + user: Arc::new(user), + profile: Arc::new(profile), + } + } +} + impl From for ApiError { fn from(err: UserCacheError) -> Self { let mut api_error: ApiError = match err { @@ -38,44 +53,42 @@ impl From for ApiError { struct LocalUserCache { lifetime: TimedCache, - id_to_user: HashMap>, - token_to_user: HashMap>, + id_to_user: HashMap, + token_to_user: HashMap, } impl LocalUserCache { - fn purge(&mut self, user: impl AsRef) { - let user = user.as_ref(); + fn purge(&mut self, user: &CachedLocalUser) { + self.lifetime.cache_remove(&user.user.id); - self.lifetime.cache_remove(&user.id); - - if let Some(user) = self.id_to_user.remove(&user.id) { - if let Some(token) = user.token.clone() { + if let Some(user) = self.id_to_user.remove(&user.user.id) { + if let Some(token) = user.user.token.clone() { self.token_to_user.remove(&token); } } } - fn refresh(&mut self, user: Arc) { - self.purge(&user); + fn refresh(&mut self, user: &CachedLocalUser) { + self.purge(user); - self.lifetime.cache_set(user.id.clone(), ()); + self.lifetime.cache_set(user.user.id.clone(), ()); - self.id_to_user.insert(user.id.clone(), user.clone()); + self.id_to_user.insert(user.user.id.clone(), user.clone()); - if let Some(token) = user.token.clone() { + if let Some(token) = user.user.token.clone() { self.token_to_user.insert(token, user.clone()); } } /// Low-priority refresh. Only refreshes the cache if the user is not there. /// Used mostly for getters that would otherwise data race with more important refreshes. - fn maybe_refresh(&mut self, user: &Arc) { - if self.lifetime.cache_get(&user.id).is_none() { - self.refresh(user.clone()); + fn maybe_refresh(&mut self, user: &CachedLocalUser) { + if self.lifetime.cache_get(&user.user.id).is_none() { + self.refresh(user); } } - fn get_by_id(&mut self, id: &str) -> Option> { + fn get_by_id(&mut self, id: &str) -> Option { if let Some(user) = self.id_to_user.get(id).cloned() { if self.lifetime.cache_get(id).is_none() { self.purge(&user); @@ -88,9 +101,9 @@ impl LocalUserCache { None } - fn get_by_token(&mut self, token: &str) -> Option> { + fn get_by_token(&mut self, token: &str) -> Option { if let Some(user) = self.token_to_user.get(token).cloned() { - if self.lifetime.cache_get(&user.id).is_none() { + if self.lifetime.cache_get(&user.user.id).is_none() { self.purge(&user); return None; } @@ -143,8 +156,8 @@ impl LocalUserCacheService { | InternalStreamMessage::UserChangeSuspendedState { id, .. } | InternalStreamMessage::RemoteUserUpdated { id } | InternalStreamMessage::UserTokenRegenerated { id, .. } => { - let user = match db.get_user_by_id(&id).await { - Ok(Some(user)) => user, + let user_profile = match db.get_user_and_profile_by_id(&id).await { + Ok(Some(m)) => m, Ok(None) => return, Err(e) => { error!("Error fetching user from database: {}", e); @@ -152,7 +165,7 @@ impl LocalUserCacheService { } }; - cache.lock().await.refresh(Arc::new(user)); + cache.lock().await.refresh(&CachedLocalUser::from(user_profile)); } _ => {} }; @@ -169,10 +182,9 @@ impl LocalUserCacheService { async fn map_cache_user( &self, - user: Option, - ) -> Result>, UserCacheError> { + user: Option, + ) -> Result, UserCacheError> { if let Some(user) = user { - let user = Arc::new(user); self.cache.lock().await.maybe_refresh(&user); return Ok(Some(user)); } @@ -183,29 +195,27 @@ impl LocalUserCacheService { pub async fn get_by_token( &self, token: &str, - ) -> Result>, UserCacheError> { + ) -> Result, UserCacheError> { let result = self.cache.lock().await.get_by_token(token); if let Some(user) = result { return Ok(Some(user)); } - self.map_cache_user(self.db.get_user_by_token(token).await?) + self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from)) .await } pub async fn get_by_id( &self, id: &str, - ) -> Result>, UserCacheError> { + ) -> Result, UserCacheError> { let result = self.cache.lock().await.get_by_id(id); if let Some(user) = result { return Ok(Some(user)); } - let user = self.db.get_user_by_id(id).await?; - - self.map_cache_user(user).await + self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await } } diff --git a/src/web/auth.rs b/src/web/auth.rs index dcf4efc..3bb092e 100644 --- a/src/web/auth.rs +++ b/src/web/auth.rs @@ -1,22 +1,25 @@ -use crate::service::local_user_cache::UserCacheError; -use crate::service::MagnetarService; -use crate::web::{ApiError, IntoErrorCode}; -use axum::async_trait; -use axum::extract::rejection::ExtensionRejection; -use axum::extract::{FromRequestParts, Request, State}; -use axum::http::request::Parts; -use axum::http::{HeaderMap, StatusCode}; -use axum::middleware::Next; -use axum::response::{IntoResponse, Response}; -use headers::authorization::Bearer; -use headers::{Authorization, HeaderMapExt}; -use magnetar_model::{ck, CalckeyDbError}; use std::convert::Infallible; use std::sync::Arc; + +use axum::async_trait; +use axum::extract::{FromRequestParts, Request, State}; +use axum::extract::rejection::ExtensionRejection; +use axum::http::{HeaderMap, StatusCode}; +use axum::http::request::Parts; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use headers::{Authorization, HeaderMapExt}; +use headers::authorization::Bearer; use strum::IntoStaticStr; use thiserror::Error; use tracing::error; +use magnetar_model::{CalckeyDbError, ck}; + +use crate::service::local_user_cache::{CachedLocalUser, UserCacheError}; +use crate::service::MagnetarService; +use crate::web::{ApiError, IntoErrorCode}; + #[derive(Clone, Debug)] pub enum AuthMode { User { @@ -178,7 +181,7 @@ impl AuthState { let user_cache = &self.service.local_user_cache; let user = user_cache.get_by_token(token).await?; - if let Some(user) = user { + if let Some(CachedLocalUser { user, .. }) = user { return Ok(AuthMode::User { user }); } @@ -205,7 +208,7 @@ impl AuthState { }); } - let user = user.unwrap(); + let CachedLocalUser { user, .. } = user.unwrap(); if let Some(app_id) = &access_token.app_id { return match self.service.db.get_app_by_id(app_id).await? {