From 3e4ae86c38c4cef9f2af6a808de9c777c2affa69 Mon Sep 17 00:00:00 2001 From: Natty Date: Thu, 19 Dec 2024 16:58:10 +0100 Subject: [PATCH] Optimizations of RPC and delivery --- Cargo.lock | 1 + Cargo.toml | 1 + ext_federation/Cargo.toml | 1 + ext_federation/src/ap_client.rs | 19 +++++---- ext_federation/src/crypto.rs | 6 +-- ext_federation/src/lib.rs | 2 +- src/rpc_v1/mod.rs | 10 ++--- src/rpc_v1/proto.rs | 74 +++++++++++++++++---------------- 8 files changed, 60 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 343d5ad..d210d42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2028,6 +2028,7 @@ dependencies = [ "miette 7.2.0", "percent-encoding", "quick-xml", + "rand", "reqwest", "rsa", "serde", diff --git a/Cargo.toml b/Cargo.toml index 50145eb..068f26f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ quick-xml = "0.36" redis = "0.26" regex = "1.9" rmp-serde = "1.3" +rand = "0.8" rsa = "0.9" reqwest = "0.12" sea-orm = "1" diff --git a/ext_federation/Cargo.toml b/ext_federation/Cargo.toml index 15e3f58..fa9e8e7 100644 --- a/ext_federation/Cargo.toml +++ b/ext_federation/Cargo.toml @@ -40,6 +40,7 @@ hyper = { workspace = true, features = ["full"] } percent-encoding = { workspace = true } reqwest = { workspace = true, features = ["stream", "hickory-dns"] } +rand = { workspace = true } ed25519-dalek = { workspace = true, features = [ "pem", "pkcs8", diff --git a/ext_federation/src/ap_client.rs b/ext_federation/src/ap_client.rs index 264d38d..4e42b10 100644 --- a/ext_federation/src/ap_client.rs +++ b/ext_federation/src/ap_client.rs @@ -4,6 +4,7 @@ use http::{HeaderMap, HeaderName, HeaderValue, Method}; use indexmap::IndexSet; use serde_json::Value; use sha2::Digest; +use std::fmt::Write; use std::{fmt::Display, string::FromUtf8Error, sync::Arc}; use thiserror::Error; use tokio::task; @@ -257,16 +258,17 @@ impl ApClientService for ApClientServiceDefaultProvider { let message = components .iter() - .map(|(k, v)| format!("{}: {}", k.as_ref(), v)) - .collect::>() - .join("\n"); + .fold(String::new(), |mut acc, (k, v)| { + writeln!(&mut acc, "{}: {}", k.as_ref(), v).unwrap(); + acc + }); let key_id = signing_key.key_id.clone().into_owned(); let key = signing_key.into_owned(); let signature = task::spawn_blocking(move || { key .key - .sign_base64(signing_algorithm, &message.into_bytes()) + .sign_base64(signing_algorithm, message.trim_end().as_bytes()) }).await??; Ok(ApSignature { @@ -329,7 +331,7 @@ impl ApClientService for ApClientServiceDefaultProvider { } headers.insert( - HeaderName::from_lowercase(b"signature").unwrap(), + HeaderName::from_static("signature"), HeaderValue::try_from(signed.to_string())?, ); @@ -350,10 +352,9 @@ impl ApClientService for ApClientServiceDefaultProvider { signing_algorithm: SigningAlgorithm, expires: Option>, url: &str, - body: &Value, + body_bytes: Vec, ) -> Result { let url = url.parse()?; - let body_bytes = serde_json::to_vec(body)?; // Move in, move out :3 let (digest_raw, body_bytes) = task::spawn_blocking(move || { let mut sha = sha2::Sha256::new(); @@ -406,12 +407,12 @@ impl ApClientService for ApClientServiceDefaultProvider { } headers.insert( - HeaderName::from_lowercase(b"digest").unwrap(), + HeaderName::from_static("digest"), HeaderValue::try_from(digest_base64)?, ); headers.insert( - HeaderName::from_lowercase(b"signature").unwrap(), + HeaderName::from_static("signature"), HeaderValue::try_from(signed.to_string())?, ); diff --git a/ext_federation/src/crypto.rs b/ext_federation/src/crypto.rs index 1b25886..e8af944 100644 --- a/ext_federation/src/crypto.rs +++ b/ext_federation/src/crypto.rs @@ -2,7 +2,7 @@ use rsa::pkcs1::DecodeRsaPrivateKey; use rsa::pkcs1::DecodeRsaPublicKey; use rsa::pkcs8::DecodePrivateKey; use rsa::pkcs8::DecodePublicKey; -use rsa::signature::Verifier; +use rsa::signature::{RandomizedSigner, Verifier}; use rsa::{ sha2::{Sha256, Sha512}, signature::Signer, @@ -268,10 +268,10 @@ impl ApHttpSigningKey<'_> { ) -> Result, ApSigningError> { match (self, algorithm) { (Self::RsaSha256(key), SigningAlgorithm::RsaSha256 | SigningAlgorithm::Hs2019) => { - Ok(Box::<[u8]>::from(key.sign(message)).into_vec()) + Ok(Box::<[u8]>::from(key.sign_with_rng(&mut rand::thread_rng(), message)).into_vec()) } (Self::RsaSha512(key), SigningAlgorithm::Hs2019) => { - Ok(Box::<[u8]>::from(key.sign(message)).into_vec()) + Ok(Box::<[u8]>::from(key.sign_with_rng(&mut rand::thread_rng(), message)).into_vec()) } (Self::Ed25519(key), SigningAlgorithm::Hs2019) => { Ok(key.sign(message).to_bytes().to_vec()) diff --git a/ext_federation/src/lib.rs b/ext_federation/src/lib.rs index 15852c1..0136ffe 100644 --- a/ext_federation/src/lib.rs +++ b/ext_federation/src/lib.rs @@ -171,6 +171,6 @@ pub trait ApClientService: Send + Sync { signing_algorithm: SigningAlgorithm, expires: Option>, url: &str, - body: &Value, + body: Vec, ) -> Result; } diff --git a/src/rpc_v1/mod.rs b/src/rpc_v1/mod.rs index 6ddb1ce..4e0f8fe 100644 --- a/src/rpc_v1/mod.rs +++ b/src/rpc_v1/mod.rs @@ -23,7 +23,7 @@ struct RpcApGet { struct RpcApPost { user_id: String, url: String, - body: serde_json::Value, + body: String, } pub fn create_rpc_router() -> MagRpc { @@ -43,9 +43,9 @@ pub fn create_rpc_router() -> MagRpc { attachments: true, with_context: true, } - .fetch_single(&ctx, &id) - .await? - .ok_or(ObjectNotFound(id))?; + .fetch_single(&ctx, &id) + .await? + .ok_or(ObjectNotFound(id))?; Result::<_, ApiError>::Ok(note) }, @@ -90,7 +90,7 @@ pub fn create_rpc_router() -> MagRpc { .create_signing_key(&key_id, SigningAlgorithm::RsaSha256)?; let result = service .ap_client - .signed_post(signing_key, SigningAlgorithm::RsaSha256, None, &url, &body) + .signed_post(signing_key, SigningAlgorithm::RsaSha256, None, &url, body.into_bytes()) .await?; Result::<_, DeliveryError>::Ok(result) }, diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs index 9e9cb59..41c99ec 100644 --- a/src/rpc_v1/proto.rs +++ b/src/rpc_v1/proto.rs @@ -12,7 +12,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, UnixSocket}; use tokio::select; use tokio::task::JoinSet; @@ -57,7 +57,7 @@ pub struct RpcResult { } impl IntoRpcResponse - for Result +for Result { fn into_rpc_response(self) -> Option { match self { @@ -65,21 +65,21 @@ impl IntoR success: true, data, }) - .into_rpc_response(), + .into_rpc_response(), Err(data) => { warn!("{:?}", data); RpcMessage(RpcResult { success: false, data, }) - .into_rpc_response() + .into_rpc_response() } } } } impl IntoRpcResponse - for Either +for Either { fn into_rpc_response(self) -> Option { match self { @@ -97,14 +97,14 @@ where &self, context: Arc, message: RpcMessage, - ) -> impl Future> + Send; + ) -> impl Future> + Send; } impl RpcHandler for F where T: Send + 'static, F: Fn(Arc, RpcMessage) -> Fut + Send + Sync + 'static, - Fut: Future + Send, + Fut: Future + Send, RR: IntoRpcResponse, { async fn process( @@ -119,15 +119,15 @@ where type MessageRaw = Box; type MagRpcHandlerMapped = dyn Fn( - Arc, - MessageRaw, - ) -> Pin> + Send + 'static>> - + Send - + Sync - + 'static; + Arc, + MessageRaw, +) -> Pin> + Send + 'static>> ++ Send ++ Sync ++ 'static; type MagRpcDecoderMapped = - dyn (Fn(&'_ [u8]) -> Result) + Send + Sync + 'static; +dyn (Fn(&'_ [u8]) -> Result) + Send + Sync + 'static; pub struct MagRpc { listeners: HashMap>, @@ -158,7 +158,7 @@ impl MagRpc { .process(ctx, RpcMessage(*data.downcast().unwrap())) .await } - .boxed() + .boxed() }), ); self.payload_decoders.insert( @@ -173,7 +173,7 @@ impl MagRpc { self, context: Arc, addr: RpcSockAddr, - graceful_shutdown: Option>, + graceful_shutdown: Option>, ) -> miette::Result<()> { match addr { RpcSockAddr::Ip(sock_addr) => { @@ -187,7 +187,7 @@ impl MagRpc { self, context: Arc, sock_addr: &SocketAddr, - graceful_shutdown: Option>, + graceful_shutdown: Option>, ) -> miette::Result<()> { debug!("Binding RPC TCP socket to {}", sock_addr); let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?; @@ -223,15 +223,16 @@ impl MagRpc { debug!("RPC TCP connection accepted: {:?}", remote_addr); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); - let (read_half, mut write_half) = stream.into_split(); + let (read_half, write_half) = stream.into_split(); let buf_read = BufReader::new(read_half); + let mut buf_write = BufWriter::new(write_half); let context = context.clone(); let rx_dec = rx_dec.clone(); let fut = async move { let src = rx_dec .stream_decode(buf_read, cancel_recv) .map_ok(process(context)) - .try_buffer_unordered(100) + .try_buffer_unordered(400) .boxed(); futures::pin_mut!(src); @@ -241,19 +242,19 @@ impl MagRpc { continue; }; - write_half.write_u8(b'M').await.into_diagnostic()?; - write_half.write_u64(serial).await.into_diagnostic()?; - write_half + buf_write.write_u8(b'M').await.into_diagnostic()?; + buf_write.write_u64(serial).await.into_diagnostic()?; + buf_write .write_u32(bytes.len() as u32) .await .into_diagnostic()?; - write_half.write_all(&bytes).await.into_diagnostic()?; - write_half.flush().await.into_diagnostic()?; + buf_write.write_all(&bytes).await.into_diagnostic()?; + buf_write.flush().await.into_diagnostic()?; } Ok(remote_addr) } - .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); + .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); connections.spawn(fut); @@ -276,7 +277,7 @@ impl MagRpc { self, context: Arc, addr: &Path, - graceful_shutdown: Option>, + graceful_shutdown: Option>, ) -> miette::Result<()> { let sock = UnixSocket::new_stream().into_diagnostic()?; debug!("Binding RPC Unix socket to {}", addr.display()); @@ -313,15 +314,16 @@ impl MagRpc { debug!("RPC Unix connection accepted: {:?}", remote_addr); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); - let (read_half, mut write_half) = stream.into_split(); + let (read_half, write_half) = stream.into_split(); let buf_read = BufReader::new(read_half); + let mut buf_write = BufWriter::new(write_half); let context = context.clone(); let rx_dec = rx_dec.clone(); let fut = async move { let src = rx_dec .stream_decode(buf_read, cancel_recv) .map_ok(process(context)) - .try_buffer_unordered(100) + .try_buffer_unordered(400) .boxed(); futures::pin_mut!(src); @@ -331,19 +333,19 @@ impl MagRpc { continue; }; - write_half.write_u8(b'M').await.into_diagnostic()?; - write_half.write_u64(serial).await.into_diagnostic()?; - write_half + buf_write.write_u8(b'M').await.into_diagnostic()?; + buf_write.write_u64(serial).await.into_diagnostic()?; + buf_write .write_u32(bytes.len() as u32) .await .into_diagnostic()?; - write_half.write_all(&bytes).await.into_diagnostic()?; - write_half.flush().await.into_diagnostic()?; + buf_write.write_all(&bytes).await.into_diagnostic()?; + buf_write.flush().await.into_diagnostic()?; } miette::Result::<()>::Ok(()) } - .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); + .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); connections.spawn(fut.boxed()); @@ -368,7 +370,7 @@ fn process( ) -> impl Fn( (u64, MessageRaw, Arc), ) -> Pin< - Box>> + Send + 'static>, + Box>> + Send + 'static>, > { move |(serial, payload, listener)| { let ctx = context.clone(); @@ -389,7 +391,7 @@ impl RpcCallDecoder { &self, mut buf_read: BufReader, mut cancel: tokio::sync::oneshot::Receiver<()>, - ) -> impl Stream)>> + Send + 'static + ) -> impl Stream)>> + Send + 'static { let decoders = self.payload_decoders.clone(); let listeners = self.listeners.clone();