From 5241b18b0d12464fffa7a40f6cc2c3c256c65558 Mon Sep 17 00:00:00 2001 From: Natty Date: Tue, 12 Nov 2024 22:07:48 +0100 Subject: [PATCH] Also cache the local user key pair --- ext_model/src/lib.rs | 62 ++++++++++++++++++++++----------- src/service/local_user_cache.rs | 40 +++++++++++++++------ src/web/mod.rs | 29 +++++++++++++++ 3 files changed, 101 insertions(+), 30 deletions(-) diff --git a/ext_model/src/lib.rs b/ext_model/src/lib.rs index 6579e29..ae99df1 100644 --- a/ext_model/src/lib.rs +++ b/ext_model/src/lib.rs @@ -9,15 +9,15 @@ use sea_orm::{ ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, TransactionTrait, }; -use serde::{Deserialize, Deserializer, Serialize}; use serde::de::Error; +use serde::{Deserialize, Deserializer, Serialize}; use serde_json::Value; use strum::IntoStaticStr; use thiserror::Error; use tokio::select; use tokio_util::sync::CancellationToken; -use tracing::{error, info, trace, warn}; use tracing::log::LevelFilter; +use tracing::{error, info, trace, warn}; use url::Host; pub use ck; @@ -122,13 +122,25 @@ impl CalckeyModel { .await?) } - pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result, CalckeyDbError> { - Ok(user::Entity::find() + pub async fn get_user_for_cache_by_id(&self, id: &str) -> Result, CalckeyDbError> { + let txn = self.0.begin().await?; + + let Some((user, Some(profile))) = user::Entity::find() .filter(user::Column::Id.eq(id)) .find_also_related(user_profile::Entity) - .one(&self.0) - .await? - .and_then(|(u, p)| p.map(|pp| (u, pp)))) + .one(&txn) + .await? else { + return Ok(None); + }; + + let Some(keys) = user_keypair::Entity::find() + .filter(user_keypair::Column::UserId.eq(id)) + .one(&txn) + .await? else { + return Ok(None); + }; + + Ok(Some((user, profile, keys))) } pub async fn get_user_security_keys_by_id( @@ -165,20 +177,30 @@ impl CalckeyModel { .await?) } - pub async fn get_user_and_profile_by_token( + pub async fn get_user_for_cache_by_token( &self, token: &str, - ) -> Result, CalckeyDbError> { - Ok(user::Entity::find() - .filter( - user::Column::Token - .eq(token) - .and(user::Column::Host.is_null()), - ) + ) -> Result, CalckeyDbError> { + let txn = self.0.begin().await?; + + let Some((user, Some(profile))) = user::Entity::find() + .filter(user::Column::Token + .eq(token) + .and(user::Column::Host.is_null())) .find_also_related(user_profile::Entity) - .one(&self.0) - .await? - .and_then(|(u, p)| p.map(|pp| (u, pp)))) + .one(&txn) + .await? else { + return Ok(None); + }; + + let Some(keys) = user_keypair::Entity::find() + .filter(user_keypair::Column::UserId.eq(&user.id)) + .one(&txn) + .await? else { + return Ok(None); + }; + + Ok(Some((user, profile, keys))) } pub async fn get_user_by_uri(&self, uri: &str) -> Result, CalckeyDbError> { @@ -399,8 +421,8 @@ struct RawMessage<'a> { impl<'de> Deserialize<'de> for SubMessage { fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where + D: Deserializer<'de>, { let raw = RawMessage::deserialize(deserializer)?; diff --git a/src/service/local_user_cache.rs b/src/service/local_user_cache.rs index 0ecb50f..b621996 100644 --- a/src/service/local_user_cache.rs +++ b/src/service/local_user_cache.rs @@ -7,34 +7,44 @@ use thiserror::Error; use tokio::sync::Mutex; use tracing::error; +use crate::web::ApiError; use magnetar_common::config::MagnetarConfig; +use magnetar_federation::crypto::{ApHttpPrivateKey, ApHttpPrivateKeyParseError, ApHttpPublicKey, ApHttpPublicKeyParseError}; use magnetar_model::{ ck, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub, InternalStreamMessage, SubMessage, }; -use crate::web::ApiError; - #[derive(Debug, Error, VariantNames)] pub enum UserCacheError { #[error("Database error: {0}")] DbError(#[from] CalckeyDbError), #[error("Redis error: {0}")] RedisError(#[from] CalckeyCacheError), + #[error("Private key parse error: {0}")] + PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError), + #[error("Public key parse error: {0}")] + PublicKeyParseError(#[from] ApHttpPublicKeyParseError), } #[derive(Debug, Clone)] pub struct CachedLocalUser { pub user: Arc, pub profile: Arc, + pub private_key: Arc>, + pub public_key: Arc>, } -impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser { - fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self { - CachedLocalUser { +impl TryFrom<(ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)> for CachedLocalUser { + type Error = UserCacheError; + + fn try_from((user, profile, key_pair): (ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)) -> Result { + Ok(CachedLocalUser { user: Arc::new(user), profile: Arc::new(profile), - } + private_key: Arc::new(key_pair.private_key.parse()?), + public_key: Arc::new(key_pair.public_key.parse()?), + }) } } @@ -43,6 +53,8 @@ impl From for ApiError { let mut api_error: ApiError = match err { UserCacheError::DbError(err) => err.into(), UserCacheError::RedisError(err) => err.into(), + UserCacheError::PublicKeyParseError(err) => err.into(), + UserCacheError::PrivateKeyParseError(err) => err.into() }; api_error.message = format!("Local user cache error: {}", api_error.message); @@ -156,7 +168,7 @@ impl LocalUserCacheService { | InternalStreamMessage::UserChangeSuspendedState { id, .. } | InternalStreamMessage::RemoteUserUpdated { id } | InternalStreamMessage::UserTokenRegenerated { id, .. } => { - let user_profile = match db.get_user_and_profile_by_id(&id).await { + let user_profile = match db.get_user_for_cache_by_id(&id).await { Ok(Some(m)) => m, Ok(None) => return, Err(e) => { @@ -165,7 +177,15 @@ impl LocalUserCacheService { } }; - cache.lock().await.refresh(&CachedLocalUser::from(user_profile)); + let cached: CachedLocalUser = match user_profile.try_into() { + Ok(c) => c, + Err(e) => { + error!("Error parsing user from database: {}", e); + return; + } + }; + + cache.lock().await.refresh(&cached); } _ => {} }; @@ -202,7 +222,7 @@ impl LocalUserCacheService { return Ok(Some(user)); } - self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from)) + self.map_cache_user(self.db.get_user_for_cache_by_token(token).await?.map(CachedLocalUser::try_from).transpose()?) .await } @@ -216,6 +236,6 @@ impl LocalUserCacheService { return Ok(Some(user)); } - self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await + self.map_cache_user(self.db.get_user_for_cache_by_id(id).await?.map(CachedLocalUser::try_from).transpose()?).await } } diff --git a/src/web/mod.rs b/src/web/mod.rs index 4a40641..5ea8d83 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -3,6 +3,7 @@ use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; use magnetar_common::util::FediverseTagParseError; +use magnetar_federation::crypto::{ApHttpPrivateKeyParseError, ApHttpPublicKeyParseError}; use magnetar_model::{CalckeyCacheError, CalckeyDbError}; use serde::Serialize; use serde_json::json; @@ -190,3 +191,31 @@ impl From for ApiError { } } } + +impl From for ApiError { + fn from(err: ApHttpPublicKeyParseError) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + code: "ApHttpPublicKeyParseError".error_code(), + message: if cfg!(debug_assertions) { + format!("User public key parse error: {}", err) + } else { + "User public key parse error".to_string() + }, + } + } +} + +impl From for ApiError { + fn from(err: ApHttpPrivateKeyParseError) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + code: "ApHttpPrivateKeyParseError".error_code(), + message: if cfg!(debug_assertions) { + format!("User private key parse error: {}", err) + } else { + "User private key parse error".to_string() + }, + } + } +}