Also cache the local user key pair

This commit is contained in:
Natty 2024-11-12 22:07:48 +01:00
parent 88df8eca55
commit 5241b18b0d
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
3 changed files with 101 additions and 30 deletions

View File

@ -9,15 +9,15 @@ use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait,
};
use serde::{Deserialize, Deserializer, Serialize};
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use strum::IntoStaticStr;
use thiserror::Error;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace, warn};
use tracing::log::LevelFilter;
use tracing::{error, info, trace, warn};
use url::Host;
pub use ck;
@ -122,13 +122,25 @@ impl CalckeyModel {
.await?)
}
pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
Ok(user::Entity::find()
pub async fn get_user_for_cache_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model, user_keypair::Model)>, CalckeyDbError> {
let txn = self.0.begin().await?;
let Some((user, Some(profile))) = user::Entity::find()
.filter(user::Column::Id.eq(id))
.find_also_related(user_profile::Entity)
.one(&self.0)
.await?
.and_then(|(u, p)| p.map(|pp| (u, pp))))
.one(&txn)
.await? else {
return Ok(None);
};
let Some(keys) = user_keypair::Entity::find()
.filter(user_keypair::Column::UserId.eq(id))
.one(&txn)
.await? else {
return Ok(None);
};
Ok(Some((user, profile, keys)))
}
pub async fn get_user_security_keys_by_id(
@ -165,20 +177,30 @@ impl CalckeyModel {
.await?)
}
pub async fn get_user_and_profile_by_token(
pub async fn get_user_for_cache_by_token(
&self,
token: &str,
) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(
user::Column::Token
.eq(token)
.and(user::Column::Host.is_null()),
)
) -> Result<Option<(user::Model, user_profile::Model, user_keypair::Model)>, CalckeyDbError> {
let txn = self.0.begin().await?;
let Some((user, Some(profile))) = user::Entity::find()
.filter(user::Column::Token
.eq(token)
.and(user::Column::Host.is_null()))
.find_also_related(user_profile::Entity)
.one(&self.0)
.await?
.and_then(|(u, p)| p.map(|pp| (u, pp))))
.one(&txn)
.await? else {
return Ok(None);
};
let Some(keys) = user_keypair::Entity::find()
.filter(user_keypair::Column::UserId.eq(&user.id))
.one(&txn)
.await? else {
return Ok(None);
};
Ok(Some((user, profile, keys)))
}
pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
@ -399,8 +421,8 @@ struct RawMessage<'a> {
impl<'de> Deserialize<'de> for SubMessage {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
where
D: Deserializer<'de>,
{
let raw = RawMessage::deserialize(deserializer)?;

View File

@ -7,34 +7,44 @@ use thiserror::Error;
use tokio::sync::Mutex;
use tracing::error;
use crate::web::ApiError;
use magnetar_common::config::MagnetarConfig;
use magnetar_federation::crypto::{ApHttpPrivateKey, ApHttpPrivateKeyParseError, ApHttpPublicKey, ApHttpPublicKeyParseError};
use magnetar_model::{
ck, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub,
InternalStreamMessage, SubMessage,
};
use crate::web::ApiError;
#[derive(Debug, Error, VariantNames)]
pub enum UserCacheError {
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
#[error("Redis error: {0}")]
RedisError(#[from] CalckeyCacheError),
#[error("Private key parse error: {0}")]
PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError),
#[error("Public key parse error: {0}")]
PublicKeyParseError(#[from] ApHttpPublicKeyParseError),
}
#[derive(Debug, Clone)]
pub struct CachedLocalUser {
pub user: Arc<ck::user::Model>,
pub profile: Arc<ck::user_profile::Model>,
pub private_key: Arc<ApHttpPrivateKey<'static>>,
pub public_key: Arc<ApHttpPublicKey<'static>>,
}
impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser {
fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self {
CachedLocalUser {
impl TryFrom<(ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)> for CachedLocalUser {
type Error = UserCacheError;
fn try_from((user, profile, key_pair): (ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)) -> Result<Self, Self::Error> {
Ok(CachedLocalUser {
user: Arc::new(user),
profile: Arc::new(profile),
}
private_key: Arc::new(key_pair.private_key.parse()?),
public_key: Arc::new(key_pair.public_key.parse()?),
})
}
}
@ -43,6 +53,8 @@ impl From<UserCacheError> for ApiError {
let mut api_error: ApiError = match err {
UserCacheError::DbError(err) => err.into(),
UserCacheError::RedisError(err) => err.into(),
UserCacheError::PublicKeyParseError(err) => err.into(),
UserCacheError::PrivateKeyParseError(err) => err.into()
};
api_error.message = format!("Local user cache error: {}", api_error.message);
@ -156,7 +168,7 @@ impl LocalUserCacheService {
| InternalStreamMessage::UserChangeSuspendedState { id, .. }
| InternalStreamMessage::RemoteUserUpdated { id }
| InternalStreamMessage::UserTokenRegenerated { id, .. } => {
let user_profile = match db.get_user_and_profile_by_id(&id).await {
let user_profile = match db.get_user_for_cache_by_id(&id).await {
Ok(Some(m)) => m,
Ok(None) => return,
Err(e) => {
@ -165,7 +177,15 @@ impl LocalUserCacheService {
}
};
cache.lock().await.refresh(&CachedLocalUser::from(user_profile));
let cached: CachedLocalUser = match user_profile.try_into() {
Ok(c) => c,
Err(e) => {
error!("Error parsing user from database: {}", e);
return;
}
};
cache.lock().await.refresh(&cached);
}
_ => {}
};
@ -202,7 +222,7 @@ impl LocalUserCacheService {
return Ok(Some(user));
}
self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from))
self.map_cache_user(self.db.get_user_for_cache_by_token(token).await?.map(CachedLocalUser::try_from).transpose()?)
.await
}
@ -216,6 +236,6 @@ impl LocalUserCacheService {
return Ok(Some(user));
}
self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await
self.map_cache_user(self.db.get_user_for_cache_by_id(id).await?.map(CachedLocalUser::try_from).transpose()?).await
}
}

View File

@ -3,6 +3,7 @@ use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use magnetar_common::util::FediverseTagParseError;
use magnetar_federation::crypto::{ApHttpPrivateKeyParseError, ApHttpPublicKeyParseError};
use magnetar_model::{CalckeyCacheError, CalckeyDbError};
use serde::Serialize;
use serde_json::json;
@ -190,3 +191,31 @@ impl From<ArgumentOutOfRange> for ApiError {
}
}
}
impl From<ApHttpPublicKeyParseError> for ApiError {
fn from(err: ApHttpPublicKeyParseError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: "ApHttpPublicKeyParseError".error_code(),
message: if cfg!(debug_assertions) {
format!("User public key parse error: {}", err)
} else {
"User public key parse error".to_string()
},
}
}
}
impl From<ApHttpPrivateKeyParseError> for ApiError {
fn from(err: ApHttpPrivateKeyParseError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: "ApHttpPrivateKeyParseError".error_code(),
message: if cfg!(debug_assertions) {
format!("User private key parse error: {}", err)
} else {
"User private key parse error".to_string()
},
}
}
}