Also cache the local user key pair

This commit is contained in:
Natty 2024-11-12 22:07:48 +01:00
parent 88df8eca55
commit 5241b18b0d
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
3 changed files with 101 additions and 30 deletions

View File

@ -9,15 +9,15 @@ use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait, TransactionTrait,
}; };
use serde::{Deserialize, Deserializer, Serialize};
use serde::de::Error; use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value; use serde_json::Value;
use strum::IntoStaticStr; use strum::IntoStaticStr;
use thiserror::Error; use thiserror::Error;
use tokio::select; use tokio::select;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace, warn};
use tracing::log::LevelFilter; use tracing::log::LevelFilter;
use tracing::{error, info, trace, warn};
use url::Host; use url::Host;
pub use ck; pub use ck;
@ -122,13 +122,25 @@ impl CalckeyModel {
.await?) .await?)
} }
pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> { pub async fn get_user_for_cache_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model, user_keypair::Model)>, CalckeyDbError> {
Ok(user::Entity::find() let txn = self.0.begin().await?;
let Some((user, Some(profile))) = user::Entity::find()
.filter(user::Column::Id.eq(id)) .filter(user::Column::Id.eq(id))
.find_also_related(user_profile::Entity) .find_also_related(user_profile::Entity)
.one(&self.0) .one(&txn)
.await? .await? else {
.and_then(|(u, p)| p.map(|pp| (u, pp)))) return Ok(None);
};
let Some(keys) = user_keypair::Entity::find()
.filter(user_keypair::Column::UserId.eq(id))
.one(&txn)
.await? else {
return Ok(None);
};
Ok(Some((user, profile, keys)))
} }
pub async fn get_user_security_keys_by_id( pub async fn get_user_security_keys_by_id(
@ -165,20 +177,30 @@ impl CalckeyModel {
.await?) .await?)
} }
pub async fn get_user_and_profile_by_token( pub async fn get_user_for_cache_by_token(
&self, &self,
token: &str, token: &str,
) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> { ) -> Result<Option<(user::Model, user_profile::Model, user_keypair::Model)>, CalckeyDbError> {
Ok(user::Entity::find() let txn = self.0.begin().await?;
.filter(
user::Column::Token let Some((user, Some(profile))) = user::Entity::find()
.eq(token) .filter(user::Column::Token
.and(user::Column::Host.is_null()), .eq(token)
) .and(user::Column::Host.is_null()))
.find_also_related(user_profile::Entity) .find_also_related(user_profile::Entity)
.one(&self.0) .one(&txn)
.await? .await? else {
.and_then(|(u, p)| p.map(|pp| (u, pp)))) return Ok(None);
};
let Some(keys) = user_keypair::Entity::find()
.filter(user_keypair::Column::UserId.eq(&user.id))
.one(&txn)
.await? else {
return Ok(None);
};
Ok(Some((user, profile, keys)))
} }
pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> { pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
@ -399,8 +421,8 @@ struct RawMessage<'a> {
impl<'de> Deserialize<'de> for SubMessage { impl<'de> Deserialize<'de> for SubMessage {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let raw = RawMessage::deserialize(deserializer)?; let raw = RawMessage::deserialize(deserializer)?;

View File

@ -7,34 +7,44 @@ use thiserror::Error;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::error; use tracing::error;
use crate::web::ApiError;
use magnetar_common::config::MagnetarConfig; use magnetar_common::config::MagnetarConfig;
use magnetar_federation::crypto::{ApHttpPrivateKey, ApHttpPrivateKeyParseError, ApHttpPublicKey, ApHttpPublicKeyParseError};
use magnetar_model::{ use magnetar_model::{
ck, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub, ck, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub,
InternalStreamMessage, SubMessage, InternalStreamMessage, SubMessage,
}; };
use crate::web::ApiError;
#[derive(Debug, Error, VariantNames)] #[derive(Debug, Error, VariantNames)]
pub enum UserCacheError { pub enum UserCacheError {
#[error("Database error: {0}")] #[error("Database error: {0}")]
DbError(#[from] CalckeyDbError), DbError(#[from] CalckeyDbError),
#[error("Redis error: {0}")] #[error("Redis error: {0}")]
RedisError(#[from] CalckeyCacheError), RedisError(#[from] CalckeyCacheError),
#[error("Private key parse error: {0}")]
PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError),
#[error("Public key parse error: {0}")]
PublicKeyParseError(#[from] ApHttpPublicKeyParseError),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CachedLocalUser { pub struct CachedLocalUser {
pub user: Arc<ck::user::Model>, pub user: Arc<ck::user::Model>,
pub profile: Arc<ck::user_profile::Model>, pub profile: Arc<ck::user_profile::Model>,
pub private_key: Arc<ApHttpPrivateKey<'static>>,
pub public_key: Arc<ApHttpPublicKey<'static>>,
} }
impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser { impl TryFrom<(ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)> for CachedLocalUser {
fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self { type Error = UserCacheError;
CachedLocalUser {
fn try_from((user, profile, key_pair): (ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)) -> Result<Self, Self::Error> {
Ok(CachedLocalUser {
user: Arc::new(user), user: Arc::new(user),
profile: Arc::new(profile), profile: Arc::new(profile),
} private_key: Arc::new(key_pair.private_key.parse()?),
public_key: Arc::new(key_pair.public_key.parse()?),
})
} }
} }
@ -43,6 +53,8 @@ impl From<UserCacheError> for ApiError {
let mut api_error: ApiError = match err { let mut api_error: ApiError = match err {
UserCacheError::DbError(err) => err.into(), UserCacheError::DbError(err) => err.into(),
UserCacheError::RedisError(err) => err.into(), UserCacheError::RedisError(err) => err.into(),
UserCacheError::PublicKeyParseError(err) => err.into(),
UserCacheError::PrivateKeyParseError(err) => err.into()
}; };
api_error.message = format!("Local user cache error: {}", api_error.message); api_error.message = format!("Local user cache error: {}", api_error.message);
@ -156,7 +168,7 @@ impl LocalUserCacheService {
| InternalStreamMessage::UserChangeSuspendedState { id, .. } | InternalStreamMessage::UserChangeSuspendedState { id, .. }
| InternalStreamMessage::RemoteUserUpdated { id } | InternalStreamMessage::RemoteUserUpdated { id }
| InternalStreamMessage::UserTokenRegenerated { id, .. } => { | InternalStreamMessage::UserTokenRegenerated { id, .. } => {
let user_profile = match db.get_user_and_profile_by_id(&id).await { let user_profile = match db.get_user_for_cache_by_id(&id).await {
Ok(Some(m)) => m, Ok(Some(m)) => m,
Ok(None) => return, Ok(None) => return,
Err(e) => { Err(e) => {
@ -165,7 +177,15 @@ impl LocalUserCacheService {
} }
}; };
cache.lock().await.refresh(&CachedLocalUser::from(user_profile)); let cached: CachedLocalUser = match user_profile.try_into() {
Ok(c) => c,
Err(e) => {
error!("Error parsing user from database: {}", e);
return;
}
};
cache.lock().await.refresh(&cached);
} }
_ => {} _ => {}
}; };
@ -202,7 +222,7 @@ impl LocalUserCacheService {
return Ok(Some(user)); return Ok(Some(user));
} }
self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from)) self.map_cache_user(self.db.get_user_for_cache_by_token(token).await?.map(CachedLocalUser::try_from).transpose()?)
.await .await
} }
@ -216,6 +236,6 @@ impl LocalUserCacheService {
return Ok(Some(user)); return Ok(Some(user));
} }
self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await self.map_cache_user(self.db.get_user_for_cache_by_id(id).await?.map(CachedLocalUser::try_from).transpose()?).await
} }
} }

View File

@ -3,6 +3,7 @@ use axum::http::StatusCode;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::Json; use axum::Json;
use magnetar_common::util::FediverseTagParseError; use magnetar_common::util::FediverseTagParseError;
use magnetar_federation::crypto::{ApHttpPrivateKeyParseError, ApHttpPublicKeyParseError};
use magnetar_model::{CalckeyCacheError, CalckeyDbError}; use magnetar_model::{CalckeyCacheError, CalckeyDbError};
use serde::Serialize; use serde::Serialize;
use serde_json::json; use serde_json::json;
@ -190,3 +191,31 @@ impl From<ArgumentOutOfRange> for ApiError {
} }
} }
} }
impl From<ApHttpPublicKeyParseError> for ApiError {
fn from(err: ApHttpPublicKeyParseError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: "ApHttpPublicKeyParseError".error_code(),
message: if cfg!(debug_assertions) {
format!("User public key parse error: {}", err)
} else {
"User public key parse error".to_string()
},
}
}
}
impl From<ApHttpPrivateKeyParseError> for ApiError {
fn from(err: ApHttpPrivateKeyParseError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: "ApHttpPrivateKeyParseError".error_code(),
message: if cfg!(debug_assertions) {
format!("User private key parse error: {}", err)
} else {
"User private key parse error".to_string()
},
}
}
}