Also fetch profile in cache
This commit is contained in:
parent
155e458806
commit
e06906dd6e
|
@ -1,3 +1,34 @@
|
|||
use std::future::Future;
|
||||
|
||||
use chrono::Utc;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use redis::IntoConnectionInfo;
|
||||
pub use sea_orm;
|
||||
use sea_orm::{ActiveValue::Set, ConnectionTrait};
|
||||
use sea_orm::{
|
||||
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
|
||||
TransactionTrait,
|
||||
};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde::de::Error;
|
||||
use serde_json::Value;
|
||||
use strum::IntoStaticStr;
|
||||
use thiserror::Error;
|
||||
use tokio::select;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, trace, warn};
|
||||
use tracing::log::LevelFilter;
|
||||
use url::Host;
|
||||
|
||||
pub use ck;
|
||||
use ck::*;
|
||||
use ext_model_migration::{Migrator, MigratorTrait};
|
||||
use user_model::UserResolver;
|
||||
|
||||
use crate::model_ext::IdShape;
|
||||
use crate::note_model::NoteResolver;
|
||||
use crate::notification_model::NotificationResolver;
|
||||
|
||||
pub mod emoji;
|
||||
pub mod model_ext;
|
||||
pub mod note_model;
|
||||
|
@ -5,35 +36,6 @@ pub mod notification_model;
|
|||
pub mod poll;
|
||||
pub mod user_model;
|
||||
|
||||
pub use ck;
|
||||
use ck::*;
|
||||
pub use sea_orm;
|
||||
use url::Host;
|
||||
use user_model::UserResolver;
|
||||
|
||||
use crate::model_ext::IdShape;
|
||||
use crate::note_model::NoteResolver;
|
||||
use crate::notification_model::NotificationResolver;
|
||||
use chrono::Utc;
|
||||
use ext_model_migration::{Migrator, MigratorTrait};
|
||||
use futures_util::StreamExt;
|
||||
use redis::IntoConnectionInfo;
|
||||
use sea_orm::{ActiveValue::Set, ConnectionTrait};
|
||||
use sea_orm::{
|
||||
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
|
||||
TransactionTrait,
|
||||
};
|
||||
use serde::de::Error;
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::future::Future;
|
||||
use strum::IntoStaticStr;
|
||||
use thiserror::Error;
|
||||
use tokio::select;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::log::LevelFilter;
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectorConfig {
|
||||
pub url: String,
|
||||
|
@ -90,8 +92,8 @@ impl CalckeyModel {
|
|||
.and(user::Column::Host.is_null()),
|
||||
)
|
||||
}
|
||||
.one(&self.0)
|
||||
.await?;
|
||||
.one(&self.0)
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
@ -120,6 +122,15 @@ impl CalckeyModel {
|
|||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
|
||||
Ok(user::Entity::find()
|
||||
.filter(user::Column::Id.eq(id))
|
||||
.find_also_related(user_profile::Entity)
|
||||
.one(&self.0)
|
||||
.await?
|
||||
.and_then(|(u, p)| p.map(|pp| (u, pp))))
|
||||
}
|
||||
|
||||
pub async fn get_user_security_keys_by_id(
|
||||
&self,
|
||||
id: &str,
|
||||
|
@ -154,6 +165,22 @@ impl CalckeyModel {
|
|||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_user_and_profile_by_token(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
|
||||
Ok(user::Entity::find()
|
||||
.filter(
|
||||
user::Column::Token
|
||||
.eq(token)
|
||||
.and(user::Column::Host.is_null()),
|
||||
)
|
||||
.find_also_related(user_profile::Entity)
|
||||
.one(&self.0)
|
||||
.await?
|
||||
.and_then(|(u, p)| p.map(|pp| (u, pp))))
|
||||
}
|
||||
|
||||
pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
|
||||
Ok(user::Entity::find()
|
||||
.filter(user::Column::Uri.eq(uri))
|
||||
|
@ -244,13 +271,13 @@ impl CalckeyModel {
|
|||
last_used_at: Set(Some(Utc::now().into())),
|
||||
..Default::default()
|
||||
})
|
||||
.filter(
|
||||
access_token::Column::Token
|
||||
.eq(token)
|
||||
.or(access_token::Column::Hash.eq(token.to_lowercase())),
|
||||
)
|
||||
.exec(&self.0)
|
||||
.await;
|
||||
.filter(
|
||||
access_token::Column::Token
|
||||
.eq(token)
|
||||
.or(access_token::Column::Hash.eq(token.to_lowercase())),
|
||||
)
|
||||
.exec(&self.0)
|
||||
.await;
|
||||
|
||||
if let Err(DbErr::RecordNotUpdated) = token {
|
||||
return Ok(None);
|
||||
|
@ -372,8 +399,8 @@ struct RawMessage<'a> {
|
|||
|
||||
impl<'de> Deserialize<'de> for SubMessage {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let raw = RawMessage::deserialize(deserializer)?;
|
||||
|
||||
|
@ -439,7 +466,7 @@ pub enum InternalStreamMessage {
|
|||
}
|
||||
|
||||
impl CalckeyCacheClient {
|
||||
pub async fn subscribe<F: Future<Output = ()> + Send + 'static>(
|
||||
pub async fn subscribe<F: Future<Output=()> + Send + 'static>(
|
||||
self,
|
||||
prefix: &str,
|
||||
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
|
||||
|
@ -451,7 +478,7 @@ impl CalckeyCacheClient {
|
|||
pub struct CalckeySub(CancellationToken);
|
||||
|
||||
impl CalckeySub {
|
||||
async fn new<F: Future<Output = ()> + Send + 'static>(
|
||||
async fn new<F: Future<Output=()> + Send + 'static>(
|
||||
conn: redis::aio::Connection,
|
||||
prefix: &str,
|
||||
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
|
||||
|
|
|
@ -23,6 +23,21 @@ pub enum UserCacheError {
|
|||
RedisError(#[from] CalckeyCacheError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CachedLocalUser {
|
||||
pub user: Arc<ck::user::Model>,
|
||||
pub profile: Arc<ck::user_profile::Model>,
|
||||
}
|
||||
|
||||
impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser {
|
||||
fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self {
|
||||
CachedLocalUser {
|
||||
user: Arc::new(user),
|
||||
profile: Arc::new(profile),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserCacheError> for ApiError {
|
||||
fn from(err: UserCacheError) -> Self {
|
||||
let mut api_error: ApiError = match err {
|
||||
|
@ -38,44 +53,42 @@ impl From<UserCacheError> for ApiError {
|
|||
|
||||
struct LocalUserCache {
|
||||
lifetime: TimedCache<String, ()>,
|
||||
id_to_user: HashMap<String, Arc<ck::user::Model>>,
|
||||
token_to_user: HashMap<String, Arc<ck::user::Model>>,
|
||||
id_to_user: HashMap<String, CachedLocalUser>,
|
||||
token_to_user: HashMap<String, CachedLocalUser>,
|
||||
}
|
||||
|
||||
impl LocalUserCache {
|
||||
fn purge(&mut self, user: impl AsRef<ck::user::Model>) {
|
||||
let user = user.as_ref();
|
||||
fn purge(&mut self, user: &CachedLocalUser) {
|
||||
self.lifetime.cache_remove(&user.user.id);
|
||||
|
||||
self.lifetime.cache_remove(&user.id);
|
||||
|
||||
if let Some(user) = self.id_to_user.remove(&user.id) {
|
||||
if let Some(token) = user.token.clone() {
|
||||
if let Some(user) = self.id_to_user.remove(&user.user.id) {
|
||||
if let Some(token) = user.user.token.clone() {
|
||||
self.token_to_user.remove(&token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn refresh(&mut self, user: Arc<ck::user::Model>) {
|
||||
self.purge(&user);
|
||||
fn refresh(&mut self, user: &CachedLocalUser) {
|
||||
self.purge(user);
|
||||
|
||||
self.lifetime.cache_set(user.id.clone(), ());
|
||||
self.lifetime.cache_set(user.user.id.clone(), ());
|
||||
|
||||
self.id_to_user.insert(user.id.clone(), user.clone());
|
||||
self.id_to_user.insert(user.user.id.clone(), user.clone());
|
||||
|
||||
if let Some(token) = user.token.clone() {
|
||||
if let Some(token) = user.user.token.clone() {
|
||||
self.token_to_user.insert(token, user.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Low-priority refresh. Only refreshes the cache if the user is not there.
|
||||
/// Used mostly for getters that would otherwise data race with more important refreshes.
|
||||
fn maybe_refresh(&mut self, user: &Arc<ck::user::Model>) {
|
||||
if self.lifetime.cache_get(&user.id).is_none() {
|
||||
self.refresh(user.clone());
|
||||
fn maybe_refresh(&mut self, user: &CachedLocalUser) {
|
||||
if self.lifetime.cache_get(&user.user.id).is_none() {
|
||||
self.refresh(user);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_by_id(&mut self, id: &str) -> Option<Arc<ck::user::Model>> {
|
||||
fn get_by_id(&mut self, id: &str) -> Option<CachedLocalUser> {
|
||||
if let Some(user) = self.id_to_user.get(id).cloned() {
|
||||
if self.lifetime.cache_get(id).is_none() {
|
||||
self.purge(&user);
|
||||
|
@ -88,9 +101,9 @@ impl LocalUserCache {
|
|||
None
|
||||
}
|
||||
|
||||
fn get_by_token(&mut self, token: &str) -> Option<Arc<ck::user::Model>> {
|
||||
fn get_by_token(&mut self, token: &str) -> Option<CachedLocalUser> {
|
||||
if let Some(user) = self.token_to_user.get(token).cloned() {
|
||||
if self.lifetime.cache_get(&user.id).is_none() {
|
||||
if self.lifetime.cache_get(&user.user.id).is_none() {
|
||||
self.purge(&user);
|
||||
return None;
|
||||
}
|
||||
|
@ -143,8 +156,8 @@ impl LocalUserCacheService {
|
|||
| InternalStreamMessage::UserChangeSuspendedState { id, .. }
|
||||
| InternalStreamMessage::RemoteUserUpdated { id }
|
||||
| InternalStreamMessage::UserTokenRegenerated { id, .. } => {
|
||||
let user = match db.get_user_by_id(&id).await {
|
||||
Ok(Some(user)) => user,
|
||||
let user_profile = match db.get_user_and_profile_by_id(&id).await {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => return,
|
||||
Err(e) => {
|
||||
error!("Error fetching user from database: {}", e);
|
||||
|
@ -152,7 +165,7 @@ impl LocalUserCacheService {
|
|||
}
|
||||
};
|
||||
|
||||
cache.lock().await.refresh(Arc::new(user));
|
||||
cache.lock().await.refresh(&CachedLocalUser::from(user_profile));
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
|
@ -169,10 +182,9 @@ impl LocalUserCacheService {
|
|||
|
||||
async fn map_cache_user(
|
||||
&self,
|
||||
user: Option<ck::user::Model>,
|
||||
) -> Result<Option<Arc<ck::user::Model>>, UserCacheError> {
|
||||
user: Option<CachedLocalUser>,
|
||||
) -> Result<Option<CachedLocalUser>, UserCacheError> {
|
||||
if let Some(user) = user {
|
||||
let user = Arc::new(user);
|
||||
self.cache.lock().await.maybe_refresh(&user);
|
||||
return Ok(Some(user));
|
||||
}
|
||||
|
@ -183,29 +195,27 @@ impl LocalUserCacheService {
|
|||
pub async fn get_by_token(
|
||||
&self,
|
||||
token: &str,
|
||||
) -> Result<Option<Arc<ck::user::Model>>, UserCacheError> {
|
||||
) -> Result<Option<CachedLocalUser>, UserCacheError> {
|
||||
let result = self.cache.lock().await.get_by_token(token);
|
||||
|
||||
if let Some(user) = result {
|
||||
return Ok(Some(user));
|
||||
}
|
||||
|
||||
self.map_cache_user(self.db.get_user_by_token(token).await?)
|
||||
self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_by_id(
|
||||
&self,
|
||||
id: &str,
|
||||
) -> Result<Option<Arc<ck::user::Model>>, UserCacheError> {
|
||||
) -> Result<Option<CachedLocalUser>, UserCacheError> {
|
||||
let result = self.cache.lock().await.get_by_id(id);
|
||||
|
||||
if let Some(user) = result {
|
||||
return Ok(Some(user));
|
||||
}
|
||||
|
||||
let user = self.db.get_user_by_id(id).await?;
|
||||
|
||||
self.map_cache_user(user).await
|
||||
self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,22 +1,25 @@
|
|||
use crate::service::local_user_cache::UserCacheError;
|
||||
use crate::service::MagnetarService;
|
||||
use crate::web::{ApiError, IntoErrorCode};
|
||||
use axum::async_trait;
|
||||
use axum::extract::rejection::ExtensionRejection;
|
||||
use axum::extract::{FromRequestParts, Request, State};
|
||||
use axum::http::request::Parts;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use headers::authorization::Bearer;
|
||||
use headers::{Authorization, HeaderMapExt};
|
||||
use magnetar_model::{ck, CalckeyDbError};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::async_trait;
|
||||
use axum::extract::{FromRequestParts, Request, State};
|
||||
use axum::extract::rejection::ExtensionRejection;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::http::request::Parts;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use headers::{Authorization, HeaderMapExt};
|
||||
use headers::authorization::Bearer;
|
||||
use strum::IntoStaticStr;
|
||||
use thiserror::Error;
|
||||
use tracing::error;
|
||||
|
||||
use magnetar_model::{CalckeyDbError, ck};
|
||||
|
||||
use crate::service::local_user_cache::{CachedLocalUser, UserCacheError};
|
||||
use crate::service::MagnetarService;
|
||||
use crate::web::{ApiError, IntoErrorCode};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AuthMode {
|
||||
User {
|
||||
|
@ -178,7 +181,7 @@ impl AuthState {
|
|||
let user_cache = &self.service.local_user_cache;
|
||||
let user = user_cache.get_by_token(token).await?;
|
||||
|
||||
if let Some(user) = user {
|
||||
if let Some(CachedLocalUser { user, .. }) = user {
|
||||
return Ok(AuthMode::User { user });
|
||||
}
|
||||
|
||||
|
@ -205,7 +208,7 @@ impl AuthState {
|
|||
});
|
||||
}
|
||||
|
||||
let user = user.unwrap();
|
||||
let CachedLocalUser { user, .. } = user.unwrap();
|
||||
|
||||
if let Some(app_id) = &access_token.app_id {
|
||||
return match self.service.db.get_app_by_id(app_id).await? {
|
||||
|
|
Loading…
Reference in New Issue