Also fetch profile in cache

This commit is contained in:
Natty 2024-04-30 16:02:35 +02:00
parent 155e458806
commit e06906dd6e
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
3 changed files with 128 additions and 88 deletions

View File

@ -1,3 +1,34 @@
use std::future::Future;
use chrono::Utc;
use futures_util::{SinkExt, StreamExt};
use redis::IntoConnectionInfo;
pub use sea_orm;
use sea_orm::{ActiveValue::Set, ConnectionTrait};
use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait,
};
use serde::{Deserialize, Deserializer, Serialize};
use serde::de::Error;
use serde_json::Value;
use strum::IntoStaticStr;
use thiserror::Error;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace, warn};
use tracing::log::LevelFilter;
use url::Host;
pub use ck;
use ck::*;
use ext_model_migration::{Migrator, MigratorTrait};
use user_model::UserResolver;
use crate::model_ext::IdShape;
use crate::note_model::NoteResolver;
use crate::notification_model::NotificationResolver;
pub mod emoji; pub mod emoji;
pub mod model_ext; pub mod model_ext;
pub mod note_model; pub mod note_model;
@ -5,35 +36,6 @@ pub mod notification_model;
pub mod poll; pub mod poll;
pub mod user_model; pub mod user_model;
pub use ck;
use ck::*;
pub use sea_orm;
use url::Host;
use user_model::UserResolver;
use crate::model_ext::IdShape;
use crate::note_model::NoteResolver;
use crate::notification_model::NotificationResolver;
use chrono::Utc;
use ext_model_migration::{Migrator, MigratorTrait};
use futures_util::StreamExt;
use redis::IntoConnectionInfo;
use sea_orm::{ActiveValue::Set, ConnectionTrait};
use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait,
};
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use std::future::Future;
use strum::IntoStaticStr;
use thiserror::Error;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::log::LevelFilter;
use tracing::{error, info, trace, warn};
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectorConfig { pub struct ConnectorConfig {
pub url: String, pub url: String,
@ -120,6 +122,15 @@ impl CalckeyModel {
.await?) .await?)
} }
pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(user::Column::Id.eq(id))
.find_also_related(user_profile::Entity)
.one(&self.0)
.await?
.and_then(|(u, p)| p.map(|pp| (u, pp))))
}
pub async fn get_user_security_keys_by_id( pub async fn get_user_security_keys_by_id(
&self, &self,
id: &str, id: &str,
@ -154,6 +165,22 @@ impl CalckeyModel {
.await?) .await?)
} }
pub async fn get_user_and_profile_by_token(
&self,
token: &str,
) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(
user::Column::Token
.eq(token)
.and(user::Column::Host.is_null()),
)
.find_also_related(user_profile::Entity)
.one(&self.0)
.await?
.and_then(|(u, p)| p.map(|pp| (u, pp))))
}
pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> { pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
Ok(user::Entity::find() Ok(user::Entity::find()
.filter(user::Column::Uri.eq(uri)) .filter(user::Column::Uri.eq(uri))
@ -439,7 +466,7 @@ pub enum InternalStreamMessage {
} }
impl CalckeyCacheClient { impl CalckeyCacheClient {
pub async fn subscribe<F: Future<Output = ()> + Send + 'static>( pub async fn subscribe<F: Future<Output=()> + Send + 'static>(
self, self,
prefix: &str, prefix: &str,
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static, handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
@ -451,7 +478,7 @@ impl CalckeyCacheClient {
pub struct CalckeySub(CancellationToken); pub struct CalckeySub(CancellationToken);
impl CalckeySub { impl CalckeySub {
async fn new<F: Future<Output = ()> + Send + 'static>( async fn new<F: Future<Output=()> + Send + 'static>(
conn: redis::aio::Connection, conn: redis::aio::Connection,
prefix: &str, prefix: &str,
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static, handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,

View File

@ -23,6 +23,21 @@ pub enum UserCacheError {
RedisError(#[from] CalckeyCacheError), RedisError(#[from] CalckeyCacheError),
} }
#[derive(Debug, Clone)]
pub struct CachedLocalUser {
pub user: Arc<ck::user::Model>,
pub profile: Arc<ck::user_profile::Model>,
}
impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser {
fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self {
CachedLocalUser {
user: Arc::new(user),
profile: Arc::new(profile),
}
}
}
impl From<UserCacheError> for ApiError { impl From<UserCacheError> for ApiError {
fn from(err: UserCacheError) -> Self { fn from(err: UserCacheError) -> Self {
let mut api_error: ApiError = match err { let mut api_error: ApiError = match err {
@ -38,44 +53,42 @@ impl From<UserCacheError> for ApiError {
struct LocalUserCache { struct LocalUserCache {
lifetime: TimedCache<String, ()>, lifetime: TimedCache<String, ()>,
id_to_user: HashMap<String, Arc<ck::user::Model>>, id_to_user: HashMap<String, CachedLocalUser>,
token_to_user: HashMap<String, Arc<ck::user::Model>>, token_to_user: HashMap<String, CachedLocalUser>,
} }
impl LocalUserCache { impl LocalUserCache {
fn purge(&mut self, user: impl AsRef<ck::user::Model>) { fn purge(&mut self, user: &CachedLocalUser) {
let user = user.as_ref(); self.lifetime.cache_remove(&user.user.id);
self.lifetime.cache_remove(&user.id); if let Some(user) = self.id_to_user.remove(&user.user.id) {
if let Some(token) = user.user.token.clone() {
if let Some(user) = self.id_to_user.remove(&user.id) {
if let Some(token) = user.token.clone() {
self.token_to_user.remove(&token); self.token_to_user.remove(&token);
} }
} }
} }
fn refresh(&mut self, user: Arc<ck::user::Model>) { fn refresh(&mut self, user: &CachedLocalUser) {
self.purge(&user); self.purge(user);
self.lifetime.cache_set(user.id.clone(), ()); self.lifetime.cache_set(user.user.id.clone(), ());
self.id_to_user.insert(user.id.clone(), user.clone()); self.id_to_user.insert(user.user.id.clone(), user.clone());
if let Some(token) = user.token.clone() { if let Some(token) = user.user.token.clone() {
self.token_to_user.insert(token, user.clone()); self.token_to_user.insert(token, user.clone());
} }
} }
/// Low-priority refresh. Only refreshes the cache if the user is not there. /// Low-priority refresh. Only refreshes the cache if the user is not there.
/// Used mostly for getters that would otherwise data race with more important refreshes. /// Used mostly for getters that would otherwise data race with more important refreshes.
fn maybe_refresh(&mut self, user: &Arc<ck::user::Model>) { fn maybe_refresh(&mut self, user: &CachedLocalUser) {
if self.lifetime.cache_get(&user.id).is_none() { if self.lifetime.cache_get(&user.user.id).is_none() {
self.refresh(user.clone()); self.refresh(user);
} }
} }
fn get_by_id(&mut self, id: &str) -> Option<Arc<ck::user::Model>> { fn get_by_id(&mut self, id: &str) -> Option<CachedLocalUser> {
if let Some(user) = self.id_to_user.get(id).cloned() { if let Some(user) = self.id_to_user.get(id).cloned() {
if self.lifetime.cache_get(id).is_none() { if self.lifetime.cache_get(id).is_none() {
self.purge(&user); self.purge(&user);
@ -88,9 +101,9 @@ impl LocalUserCache {
None None
} }
fn get_by_token(&mut self, token: &str) -> Option<Arc<ck::user::Model>> { fn get_by_token(&mut self, token: &str) -> Option<CachedLocalUser> {
if let Some(user) = self.token_to_user.get(token).cloned() { if let Some(user) = self.token_to_user.get(token).cloned() {
if self.lifetime.cache_get(&user.id).is_none() { if self.lifetime.cache_get(&user.user.id).is_none() {
self.purge(&user); self.purge(&user);
return None; return None;
} }
@ -143,8 +156,8 @@ impl LocalUserCacheService {
| InternalStreamMessage::UserChangeSuspendedState { id, .. } | InternalStreamMessage::UserChangeSuspendedState { id, .. }
| InternalStreamMessage::RemoteUserUpdated { id } | InternalStreamMessage::RemoteUserUpdated { id }
| InternalStreamMessage::UserTokenRegenerated { id, .. } => { | InternalStreamMessage::UserTokenRegenerated { id, .. } => {
let user = match db.get_user_by_id(&id).await { let user_profile = match db.get_user_and_profile_by_id(&id).await {
Ok(Some(user)) => user, Ok(Some(m)) => m,
Ok(None) => return, Ok(None) => return,
Err(e) => { Err(e) => {
error!("Error fetching user from database: {}", e); error!("Error fetching user from database: {}", e);
@ -152,7 +165,7 @@ impl LocalUserCacheService {
} }
}; };
cache.lock().await.refresh(Arc::new(user)); cache.lock().await.refresh(&CachedLocalUser::from(user_profile));
} }
_ => {} _ => {}
}; };
@ -169,10 +182,9 @@ impl LocalUserCacheService {
async fn map_cache_user( async fn map_cache_user(
&self, &self,
user: Option<ck::user::Model>, user: Option<CachedLocalUser>,
) -> Result<Option<Arc<ck::user::Model>>, UserCacheError> { ) -> Result<Option<CachedLocalUser>, UserCacheError> {
if let Some(user) = user { if let Some(user) = user {
let user = Arc::new(user);
self.cache.lock().await.maybe_refresh(&user); self.cache.lock().await.maybe_refresh(&user);
return Ok(Some(user)); return Ok(Some(user));
} }
@ -183,29 +195,27 @@ impl LocalUserCacheService {
pub async fn get_by_token( pub async fn get_by_token(
&self, &self,
token: &str, token: &str,
) -> Result<Option<Arc<ck::user::Model>>, UserCacheError> { ) -> Result<Option<CachedLocalUser>, UserCacheError> {
let result = self.cache.lock().await.get_by_token(token); let result = self.cache.lock().await.get_by_token(token);
if let Some(user) = result { if let Some(user) = result {
return Ok(Some(user)); return Ok(Some(user));
} }
self.map_cache_user(self.db.get_user_by_token(token).await?) self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from))
.await .await
} }
pub async fn get_by_id( pub async fn get_by_id(
&self, &self,
id: &str, id: &str,
) -> Result<Option<Arc<ck::user::Model>>, UserCacheError> { ) -> Result<Option<CachedLocalUser>, UserCacheError> {
let result = self.cache.lock().await.get_by_id(id); let result = self.cache.lock().await.get_by_id(id);
if let Some(user) = result { if let Some(user) = result {
return Ok(Some(user)); return Ok(Some(user));
} }
let user = self.db.get_user_by_id(id).await?; self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await
self.map_cache_user(user).await
} }
} }

View File

@ -1,22 +1,25 @@
use crate::service::local_user_cache::UserCacheError;
use crate::service::MagnetarService;
use crate::web::{ApiError, IntoErrorCode};
use axum::async_trait;
use axum::extract::rejection::ExtensionRejection;
use axum::extract::{FromRequestParts, Request, State};
use axum::http::request::Parts;
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use headers::authorization::Bearer;
use headers::{Authorization, HeaderMapExt};
use magnetar_model::{ck, CalckeyDbError};
use std::convert::Infallible; use std::convert::Infallible;
use std::sync::Arc; use std::sync::Arc;
use axum::async_trait;
use axum::extract::{FromRequestParts, Request, State};
use axum::extract::rejection::ExtensionRejection;
use axum::http::{HeaderMap, StatusCode};
use axum::http::request::Parts;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use headers::{Authorization, HeaderMapExt};
use headers::authorization::Bearer;
use strum::IntoStaticStr; use strum::IntoStaticStr;
use thiserror::Error; use thiserror::Error;
use tracing::error; use tracing::error;
use magnetar_model::{CalckeyDbError, ck};
use crate::service::local_user_cache::{CachedLocalUser, UserCacheError};
use crate::service::MagnetarService;
use crate::web::{ApiError, IntoErrorCode};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum AuthMode { pub enum AuthMode {
User { User {
@ -178,7 +181,7 @@ impl AuthState {
let user_cache = &self.service.local_user_cache; let user_cache = &self.service.local_user_cache;
let user = user_cache.get_by_token(token).await?; let user = user_cache.get_by_token(token).await?;
if let Some(user) = user { if let Some(CachedLocalUser { user, .. }) = user {
return Ok(AuthMode::User { user }); return Ok(AuthMode::User { user });
} }
@ -205,7 +208,7 @@ impl AuthState {
}); });
} }
let user = user.unwrap(); let CachedLocalUser { user, .. } = user.unwrap();
if let Some(app_id) = &access_token.app_id { if let Some(app_id) = &access_token.app_id {
return match self.service.db.get_app_by_id(app_id).await? { return match self.service.db.get_app_by_id(app_id).await? {