Offload some blocking tasks to worker threads
ci/woodpecker/push/ociImagePush Pipeline was successful Details

This commit is contained in:
Natty 2024-11-25 22:48:27 +01:00
parent 5666bb4622
commit 5363a0c137
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
5 changed files with 88 additions and 27 deletions

View File

@ -1,12 +1,13 @@
use std::{borrow::Cow, fmt::Display, string::FromUtf8Error, sync::Arc};
use chrono::Utc; use chrono::Utc;
use futures::{FutureExt, TryFutureExt}; use futures::TryFutureExt;
use http::{HeaderMap, HeaderName, HeaderValue, Method}; use http::{HeaderMap, HeaderName, HeaderValue, Method};
use indexmap::IndexSet; use indexmap::IndexSet;
use serde_json::Value; use serde_json::Value;
use sha2::Digest; use sha2::Digest;
use std::{fmt::Display, string::FromUtf8Error, sync::Arc};
use thiserror::Error; use thiserror::Error;
use tokio::task;
use tokio::task::JoinError;
use url::Url; use url::Url;
use magnetar_core::web_model::content_type::ContentActivityStreams; use magnetar_core::web_model::content_type::ContentActivityStreams;
@ -77,6 +78,8 @@ pub enum ApClientError {
InvalidHeaderValue(#[from] http::header::InvalidHeaderValue), InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
#[error("UTF-8 parse error: {0}")] #[error("UTF-8 parse error: {0}")]
Utf8ParseError(#[from] FromUtf8Error), Utf8ParseError(#[from] FromUtf8Error),
#[error("Task join error: {0}")]
JoinError(#[from] JoinError),
} }
trait CreateField { trait CreateField {
@ -244,7 +247,7 @@ impl SigningParts for SigningInputPostHs2019<'_> {
impl ApClientService for ApClientServiceDefaultProvider { impl ApClientService for ApClientServiceDefaultProvider {
type Error = ApClientError; type Error = ApClientError;
fn sign_request( async fn sign_request(
&self, &self,
signing_key: ApSigningKey<'_>, signing_key: ApSigningKey<'_>,
signing_algorithm: SigningAlgorithm, signing_algorithm: SigningAlgorithm,
@ -258,12 +261,16 @@ impl ApClientService for ApClientServiceDefaultProvider {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
let signature = signing_key let key_id = signing_key.key_id.clone().into_owned();
.key let key = signing_key.into_owned();
.sign_base64(signing_algorithm, &message.into_bytes())?; let signature = task::spawn_blocking(move || {
key
.key
.sign_base64(signing_algorithm, &message.into_bytes())
}).await??;
Ok(ApSignature { Ok(ApSignature {
key_id: signing_key.key_id.clone().into_owned(), key_id,
algorithm: Some(signing_algorithm), algorithm: Some(signing_algorithm),
created: request.get_created().cloned(), created: request.get_created().cloned(),
expires: request.get_expires().cloned(), expires: request.get_expires().cloned(),
@ -287,8 +294,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
let signed = match signing_algorithm { let signed = match signing_algorithm {
SigningAlgorithm::RsaSha256 => self.sign_request( SigningAlgorithm::RsaSha256 => self.sign_request(
signing_key, signing_key,
signing_algorithm, signing_algorithm, &SigningInputGetRsaSha256 {
&SigningInputGetRsaSha256 {
request_target: RequestTarget { request_target: RequestTarget {
url: &url, url: &url,
method: Method::GET, method: Method::GET,
@ -297,7 +303,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
date: DateHeader(time_created), date: DateHeader(time_created),
expires: expires.map(ExpiresPseudoHeader), expires: expires.map(ExpiresPseudoHeader),
}, },
)?, ).await?,
SigningAlgorithm::Hs2019 => self.sign_request( SigningAlgorithm::Hs2019 => self.sign_request(
signing_key, signing_key,
signing_algorithm, signing_algorithm,
@ -310,7 +316,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
created: CreatedPseudoHeader(time_created), created: CreatedPseudoHeader(time_created),
expires: expires.map(ExpiresPseudoHeader), expires: expires.map(ExpiresPseudoHeader),
}, },
)?, ).await?,
}; };
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
@ -348,9 +354,13 @@ impl ApClientService for ApClientServiceDefaultProvider {
) -> Result<String, Self::Error> { ) -> Result<String, Self::Error> {
let url = url.parse()?; let url = url.parse()?;
let body_bytes = serde_json::to_vec(body)?; let body_bytes = serde_json::to_vec(body)?;
let mut sha = sha2::Sha256::new(); // Move in, move out :3
sha.update(&body_bytes); let (digest_raw, body_bytes) = task::spawn_blocking(move || {
let digest_raw = sha.finalize(); let mut sha = sha2::Sha256::new();
sha.update(&body_bytes);
(sha.finalize(), body_bytes)
}).await?;
use base64::prelude::*; use base64::prelude::*;
let digest_base64 = format!("sha-256={}", BASE64_STANDARD.encode(digest_raw)); let digest_base64 = format!("sha-256={}", BASE64_STANDARD.encode(digest_raw));
let time_created = Utc::now(); let time_created = Utc::now();
@ -368,7 +378,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
digest: DigestHeader(&digest_base64), digest: DigestHeader(&digest_base64),
expires: expires.map(ExpiresPseudoHeader), expires: expires.map(ExpiresPseudoHeader),
}, },
)?, ).await?,
SigningAlgorithm::Hs2019 => self.sign_request( SigningAlgorithm::Hs2019 => self.sign_request(
signing_key, signing_key,
signing_algorithm, signing_algorithm,
@ -382,7 +392,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
digest: DigestHeader(&digest_base64), digest: DigestHeader(&digest_base64),
expires: expires.map(ExpiresPseudoHeader), expires: expires.map(ExpiresPseudoHeader),
}, },
)?, ).await?,
}; };
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
@ -450,7 +460,7 @@ mod test {
25, 25,
UserAgent::from_static("magnetar/0.42 (https://astolfo.social)"), UserAgent::from_static("magnetar/0.42 (https://astolfo.social)"),
) )
.into_diagnostic()?, .into_diagnostic()?,
)), )),
}; };

View File

@ -241,6 +241,25 @@ pub struct ApSigningKey<'a> {
pub key_id: Cow<'a, str>, pub key_id: Cow<'a, str>,
} }
impl<'a> ApSigningKey<'a> {
pub fn into_owned(self) -> ApSigningKey<'static> {
ApSigningKey {
key: self.key.into_owned(),
key_id: Cow::Owned(self.key_id.into_owned()),
}
}
}
impl<'a> ApHttpSigningKey<'a> {
pub fn into_owned(self) -> ApHttpSigningKey<'static> {
match self {
ApHttpSigningKey::RsaSha256(k) => ApHttpSigningKey::RsaSha256(Cow::Owned(k.into_owned())),
ApHttpSigningKey::RsaSha512(k) => ApHttpSigningKey::RsaSha512(Cow::Owned(k.into_owned())),
ApHttpSigningKey::Ed25519(k) => ApHttpSigningKey::Ed25519(Cow::Owned(k.into_owned())),
}
}
}
impl ApHttpSigningKey<'_> { impl ApHttpSigningKey<'_> {
pub fn sign( pub fn sign(
&self, &self,

View File

@ -62,10 +62,10 @@ pub trait WebFingerResolverService: Send + Sync {
resolved_uri, resolved_uri,
percent_encoding::NON_ALPHANUMERIC, percent_encoding::NON_ALPHANUMERIC,
) )
.to_string(), .to_string(),
), ),
) )
.await .await
} }
} }
@ -137,12 +137,12 @@ pub struct ApSignature {
#[derive(Debug)] #[derive(Debug)]
pub struct ApSigningHeaders(pub(crate) IndexSet<ApSigningField>); pub struct ApSigningHeaders(pub(crate) IndexSet<ApSigningField>);
pub trait SigningParts { pub trait SigningParts: Send {
fn get_created(&self) -> Option<&DateTime<Utc>>; fn get_created(&self) -> Option<&DateTime<Utc>>;
fn get_expires(&self) -> Option<&DateTime<Utc>>; fn get_expires(&self) -> Option<&DateTime<Utc>>;
} }
pub trait SigningInput: SigningParts { pub trait SigningInput: SigningParts + Sync {
fn create_signing_input(&self) -> Vec<(ApSigningField, String)>; fn create_signing_input(&self) -> Vec<(ApSigningField, String)>;
} }
@ -150,7 +150,7 @@ pub trait SigningInput: SigningParts {
pub trait ApClientService: Send + Sync { pub trait ApClientService: Send + Sync {
type Error; type Error;
fn sign_request( async fn sign_request(
&self, &self,
signing_key: ApSigningKey<'_>, signing_key: ApSigningKey<'_>,
signing_algorithm: SigningAlgorithm, signing_algorithm: SigningAlgorithm,

View File

@ -41,6 +41,8 @@ pub enum RetriableLocalDeliveryTaskError {
JsonSerialization(String), JsonSerialization(String),
#[error("User cache error: {0}")] #[error("User cache error: {0}")]
UserCache(String), UserCache(String),
#[error("Task join error: {0}")]
JoinError(String),
} }
#[derive(Debug, Error, Serialize)] #[derive(Debug, Error, Serialize)]
@ -178,6 +180,9 @@ impl From<ApClientError> for DeliveryErrorKind {
ApClientError::Utf8ParseError(e) => { ApClientError::Utf8ParseError(e) => {
RetriableRemoteDeliveryError::Utf8(e.to_string()).into() RetriableRemoteDeliveryError::Utf8(e.to_string()).into()
} }
ApClientError::JoinError(e) => {
RetriableLocalDeliveryTaskError::JoinError(e.to_string()).into()
}
} }
} }
} }
@ -188,7 +193,7 @@ impl From<FederationClientError> for DeliveryErrorKind {
FederationClientError::TimeoutError => RetriableRemoteDeliveryError::Timeout( FederationClientError::TimeoutError => RetriableRemoteDeliveryError::Timeout(
"Reached maximum time for response".to_string(), "Reached maximum time for response".to_string(),
) )
.into(), .into(),
FederationClientError::ReqwestError(e) => e.into(), FederationClientError::ReqwestError(e) => e.into(),
FederationClientError::JsonError(e) => { FederationClientError::JsonError(e) => {
RetriableRemoteDeliveryError::Json(e.to_string()).into() RetriableRemoteDeliveryError::Json(e.to_string()).into()

View File

@ -2,9 +2,12 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use cached::{Cached, TimedCache}; use cached::{Cached, TimedCache};
use futures_util::future::OptionFuture;
use miette::Diagnostic; use miette::Diagnostic;
use thiserror::Error; use thiserror::Error;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task;
use tokio::task::JoinError;
use tracing::error; use tracing::error;
use crate::web::ApiError; use crate::web::ApiError;
@ -27,6 +30,8 @@ pub enum UserCacheError {
PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError), PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError),
#[error("Public key parse error: {0}")] #[error("Public key parse error: {0}")]
PublicKeyParseError(#[from] ApHttpPublicKeyParseError), PublicKeyParseError(#[from] ApHttpPublicKeyParseError),
#[error("Task join error: {0}")]
JoinError(#[from] JoinError),
} }
impl From<UserCacheError> for ApiError { impl From<UserCacheError> for ApiError {
@ -216,8 +221,18 @@ impl LocalUserCacheService {
return Ok(Some(user)); return Ok(Some(user));
} }
self.map_cache_user(self.db.get_user_for_cache_by_token(token).await?.map(CachedLocalUser::try_from).transpose()?) let fetch: OptionFuture<_> = self.db
.await .get_user_for_cache_by_token(token)
.await?
.map(|p| task::spawn_blocking(move || CachedLocalUser::try_from(p)))
.into();
self.map_cache_user(
fetch
.await
.transpose()?
.transpose()?
).await
} }
pub async fn get_by_id( pub async fn get_by_id(
@ -230,6 +245,18 @@ impl LocalUserCacheService {
return Ok(Some(user)); return Ok(Some(user));
} }
self.map_cache_user(self.db.get_user_for_cache_by_id(id).await?.map(CachedLocalUser::try_from).transpose()?).await
let fetch: OptionFuture<_> = self.db
.get_user_for_cache_by_id(id)
.await?
.map(|p| task::spawn_blocking(move || CachedLocalUser::try_from(p)))
.into();
self.map_cache_user(
fetch
.await
.transpose()?
.transpose()?
).await
} }
} }