From 5363a0c137fbfc9e6fb25f6342f318011ebd86ef Mon Sep 17 00:00:00 2001 From: Natty Date: Mon, 25 Nov 2024 22:48:27 +0100 Subject: [PATCH] Offload some blocking tasks to worker threads --- ext_federation/src/ap_client.rs | 46 ++++++++++++++++++++------------- ext_federation/src/crypto.rs | 19 ++++++++++++++ ext_federation/src/lib.rs | 10 +++---- src/model/activity/delivery.rs | 7 ++++- src/service/local_user_cache.rs | 33 ++++++++++++++++++++--- 5 files changed, 88 insertions(+), 27 deletions(-) diff --git a/ext_federation/src/ap_client.rs b/ext_federation/src/ap_client.rs index b02a0fa..264d38d 100644 --- a/ext_federation/src/ap_client.rs +++ b/ext_federation/src/ap_client.rs @@ -1,12 +1,13 @@ -use std::{borrow::Cow, fmt::Display, string::FromUtf8Error, sync::Arc}; - use chrono::Utc; -use futures::{FutureExt, TryFutureExt}; +use futures::TryFutureExt; use http::{HeaderMap, HeaderName, HeaderValue, Method}; use indexmap::IndexSet; use serde_json::Value; use sha2::Digest; +use std::{fmt::Display, string::FromUtf8Error, sync::Arc}; use thiserror::Error; +use tokio::task; +use tokio::task::JoinError; use url::Url; use magnetar_core::web_model::content_type::ContentActivityStreams; @@ -77,6 +78,8 @@ pub enum ApClientError { InvalidHeaderValue(#[from] http::header::InvalidHeaderValue), #[error("UTF-8 parse error: {0}")] Utf8ParseError(#[from] FromUtf8Error), + #[error("Task join error: {0}")] + JoinError(#[from] JoinError), } trait CreateField { @@ -244,7 +247,7 @@ impl SigningParts for SigningInputPostHs2019<'_> { impl ApClientService for ApClientServiceDefaultProvider { type Error = ApClientError; - fn sign_request( + async fn sign_request( &self, signing_key: ApSigningKey<'_>, signing_algorithm: SigningAlgorithm, @@ -258,12 +261,16 @@ impl ApClientService for ApClientServiceDefaultProvider { .collect::>() .join("\n"); - let signature = signing_key - .key - .sign_base64(signing_algorithm, &message.into_bytes())?; + let key_id = signing_key.key_id.clone().into_owned(); + let key = signing_key.into_owned(); + let signature = task::spawn_blocking(move || { + key + .key + .sign_base64(signing_algorithm, &message.into_bytes()) + }).await??; Ok(ApSignature { - key_id: signing_key.key_id.clone().into_owned(), + key_id, algorithm: Some(signing_algorithm), created: request.get_created().cloned(), expires: request.get_expires().cloned(), @@ -287,8 +294,7 @@ impl ApClientService for ApClientServiceDefaultProvider { let signed = match signing_algorithm { SigningAlgorithm::RsaSha256 => self.sign_request( signing_key, - signing_algorithm, - &SigningInputGetRsaSha256 { + signing_algorithm, &SigningInputGetRsaSha256 { request_target: RequestTarget { url: &url, method: Method::GET, @@ -297,7 +303,7 @@ impl ApClientService for ApClientServiceDefaultProvider { date: DateHeader(time_created), expires: expires.map(ExpiresPseudoHeader), }, - )?, + ).await?, SigningAlgorithm::Hs2019 => self.sign_request( signing_key, signing_algorithm, @@ -310,7 +316,7 @@ impl ApClientService for ApClientServiceDefaultProvider { created: CreatedPseudoHeader(time_created), expires: expires.map(ExpiresPseudoHeader), }, - )?, + ).await?, }; let mut headers = HeaderMap::new(); @@ -348,9 +354,13 @@ impl ApClientService for ApClientServiceDefaultProvider { ) -> Result { let url = url.parse()?; let body_bytes = serde_json::to_vec(body)?; - let mut sha = sha2::Sha256::new(); - sha.update(&body_bytes); - let digest_raw = sha.finalize(); + // Move in, move out :3 + let (digest_raw, body_bytes) = task::spawn_blocking(move || { + let mut sha = sha2::Sha256::new(); + sha.update(&body_bytes); + (sha.finalize(), body_bytes) + }).await?; + use base64::prelude::*; let digest_base64 = format!("sha-256={}", BASE64_STANDARD.encode(digest_raw)); let time_created = Utc::now(); @@ -368,7 +378,7 @@ impl ApClientService for ApClientServiceDefaultProvider { digest: DigestHeader(&digest_base64), expires: expires.map(ExpiresPseudoHeader), }, - )?, + ).await?, SigningAlgorithm::Hs2019 => self.sign_request( signing_key, signing_algorithm, @@ -382,7 +392,7 @@ impl ApClientService for ApClientServiceDefaultProvider { digest: DigestHeader(&digest_base64), expires: expires.map(ExpiresPseudoHeader), }, - )?, + ).await?, }; let mut headers = HeaderMap::new(); @@ -450,7 +460,7 @@ mod test { 25, UserAgent::from_static("magnetar/0.42 (https://astolfo.social)"), ) - .into_diagnostic()?, + .into_diagnostic()?, )), }; diff --git a/ext_federation/src/crypto.rs b/ext_federation/src/crypto.rs index d69324e..1b25886 100644 --- a/ext_federation/src/crypto.rs +++ b/ext_federation/src/crypto.rs @@ -241,6 +241,25 @@ pub struct ApSigningKey<'a> { pub key_id: Cow<'a, str>, } +impl<'a> ApSigningKey<'a> { + pub fn into_owned(self) -> ApSigningKey<'static> { + ApSigningKey { + key: self.key.into_owned(), + key_id: Cow::Owned(self.key_id.into_owned()), + } + } +} + +impl<'a> ApHttpSigningKey<'a> { + pub fn into_owned(self) -> ApHttpSigningKey<'static> { + match self { + ApHttpSigningKey::RsaSha256(k) => ApHttpSigningKey::RsaSha256(Cow::Owned(k.into_owned())), + ApHttpSigningKey::RsaSha512(k) => ApHttpSigningKey::RsaSha512(Cow::Owned(k.into_owned())), + ApHttpSigningKey::Ed25519(k) => ApHttpSigningKey::Ed25519(Cow::Owned(k.into_owned())), + } + } +} + impl ApHttpSigningKey<'_> { pub fn sign( &self, diff --git a/ext_federation/src/lib.rs b/ext_federation/src/lib.rs index 4b0801c..15852c1 100644 --- a/ext_federation/src/lib.rs +++ b/ext_federation/src/lib.rs @@ -62,10 +62,10 @@ pub trait WebFingerResolverService: Send + Sync { resolved_uri, percent_encoding::NON_ALPHANUMERIC, ) - .to_string(), + .to_string(), ), ) - .await + .await } } @@ -137,12 +137,12 @@ pub struct ApSignature { #[derive(Debug)] pub struct ApSigningHeaders(pub(crate) IndexSet); -pub trait SigningParts { +pub trait SigningParts: Send { fn get_created(&self) -> Option<&DateTime>; fn get_expires(&self) -> Option<&DateTime>; } -pub trait SigningInput: SigningParts { +pub trait SigningInput: SigningParts + Sync { fn create_signing_input(&self) -> Vec<(ApSigningField, String)>; } @@ -150,7 +150,7 @@ pub trait SigningInput: SigningParts { pub trait ApClientService: Send + Sync { type Error; - fn sign_request( + async fn sign_request( &self, signing_key: ApSigningKey<'_>, signing_algorithm: SigningAlgorithm, diff --git a/src/model/activity/delivery.rs b/src/model/activity/delivery.rs index 2e2d088..31bb9b1 100644 --- a/src/model/activity/delivery.rs +++ b/src/model/activity/delivery.rs @@ -41,6 +41,8 @@ pub enum RetriableLocalDeliveryTaskError { JsonSerialization(String), #[error("User cache error: {0}")] UserCache(String), + #[error("Task join error: {0}")] + JoinError(String), } #[derive(Debug, Error, Serialize)] @@ -178,6 +180,9 @@ impl From for DeliveryErrorKind { ApClientError::Utf8ParseError(e) => { RetriableRemoteDeliveryError::Utf8(e.to_string()).into() } + ApClientError::JoinError(e) => { + RetriableLocalDeliveryTaskError::JoinError(e.to_string()).into() + } } } } @@ -188,7 +193,7 @@ impl From for DeliveryErrorKind { FederationClientError::TimeoutError => RetriableRemoteDeliveryError::Timeout( "Reached maximum time for response".to_string(), ) - .into(), + .into(), FederationClientError::ReqwestError(e) => e.into(), FederationClientError::JsonError(e) => { RetriableRemoteDeliveryError::Json(e.to_string()).into() diff --git a/src/service/local_user_cache.rs b/src/service/local_user_cache.rs index 296c346..5e639be 100644 --- a/src/service/local_user_cache.rs +++ b/src/service/local_user_cache.rs @@ -2,9 +2,12 @@ use std::collections::HashMap; use std::sync::Arc; use cached::{Cached, TimedCache}; +use futures_util::future::OptionFuture; use miette::Diagnostic; use thiserror::Error; use tokio::sync::Mutex; +use tokio::task; +use tokio::task::JoinError; use tracing::error; use crate::web::ApiError; @@ -27,6 +30,8 @@ pub enum UserCacheError { PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError), #[error("Public key parse error: {0}")] PublicKeyParseError(#[from] ApHttpPublicKeyParseError), + #[error("Task join error: {0}")] + JoinError(#[from] JoinError), } impl From for ApiError { @@ -216,8 +221,18 @@ impl LocalUserCacheService { return Ok(Some(user)); } - self.map_cache_user(self.db.get_user_for_cache_by_token(token).await?.map(CachedLocalUser::try_from).transpose()?) - .await + let fetch: OptionFuture<_> = self.db + .get_user_for_cache_by_token(token) + .await? + .map(|p| task::spawn_blocking(move || CachedLocalUser::try_from(p))) + .into(); + + self.map_cache_user( + fetch + .await + .transpose()? + .transpose()? + ).await } pub async fn get_by_id( @@ -230,6 +245,18 @@ impl LocalUserCacheService { return Ok(Some(user)); } - self.map_cache_user(self.db.get_user_for_cache_by_id(id).await?.map(CachedLocalUser::try_from).transpose()?).await + + let fetch: OptionFuture<_> = self.db + .get_user_for_cache_by_id(id) + .await? + .map(|p| task::spawn_blocking(move || CachedLocalUser::try_from(p))) + .into(); + + self.map_cache_user( + fetch + .await + .transpose()? + .transpose()? + ).await } }