WIP: Make the MMM parser run in linear time #14

Draft
natty wants to merge 2 commits from main into mmm
8 changed files with 60 additions and 54 deletions
Showing only changes of commit 3e4ae86c38 - Show all commits

1
Cargo.lock generated
View File

@ -2028,6 +2028,7 @@ dependencies = [
"miette 7.2.0", "miette 7.2.0",
"percent-encoding", "percent-encoding",
"quick-xml", "quick-xml",
"rand",
"reqwest", "reqwest",
"rsa", "rsa",
"serde", "serde",

View File

@ -62,6 +62,7 @@ quick-xml = "0.36"
redis = "0.26" redis = "0.26"
regex = "1.9" regex = "1.9"
rmp-serde = "1.3" rmp-serde = "1.3"
rand = "0.8"
rsa = "0.9" rsa = "0.9"
reqwest = "0.12" reqwest = "0.12"
sea-orm = "1" sea-orm = "1"

View File

@ -40,6 +40,7 @@ hyper = { workspace = true, features = ["full"] }
percent-encoding = { workspace = true } percent-encoding = { workspace = true }
reqwest = { workspace = true, features = ["stream", "hickory-dns"] } reqwest = { workspace = true, features = ["stream", "hickory-dns"] }
rand = { workspace = true }
ed25519-dalek = { workspace = true, features = [ ed25519-dalek = { workspace = true, features = [
"pem", "pem",
"pkcs8", "pkcs8",

View File

@ -4,6 +4,7 @@ use http::{HeaderMap, HeaderName, HeaderValue, Method};
use indexmap::IndexSet; use indexmap::IndexSet;
use serde_json::Value; use serde_json::Value;
use sha2::Digest; use sha2::Digest;
use std::fmt::Write;
use std::{fmt::Display, string::FromUtf8Error, sync::Arc}; use std::{fmt::Display, string::FromUtf8Error, sync::Arc};
use thiserror::Error; use thiserror::Error;
use tokio::task; use tokio::task;
@ -257,16 +258,17 @@ impl ApClientService for ApClientServiceDefaultProvider {
let message = components let message = components
.iter() .iter()
.map(|(k, v)| format!("{}: {}", k.as_ref(), v)) .fold(String::new(), |mut acc, (k, v)| {
.collect::<Vec<_>>() writeln!(&mut acc, "{}: {}", k.as_ref(), v).unwrap();
.join("\n"); acc
});
let key_id = signing_key.key_id.clone().into_owned(); let key_id = signing_key.key_id.clone().into_owned();
let key = signing_key.into_owned(); let key = signing_key.into_owned();
let signature = task::spawn_blocking(move || { let signature = task::spawn_blocking(move || {
key key
.key .key
.sign_base64(signing_algorithm, &message.into_bytes()) .sign_base64(signing_algorithm, message.trim_end().as_bytes())
}).await??; }).await??;
Ok(ApSignature { Ok(ApSignature {
@ -329,7 +331,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
} }
headers.insert( headers.insert(
HeaderName::from_lowercase(b"signature").unwrap(), HeaderName::from_static("signature"),
HeaderValue::try_from(signed.to_string())?, HeaderValue::try_from(signed.to_string())?,
); );
@ -350,10 +352,9 @@ impl ApClientService for ApClientServiceDefaultProvider {
signing_algorithm: SigningAlgorithm, signing_algorithm: SigningAlgorithm,
expires: Option<chrono::DateTime<Utc>>, expires: Option<chrono::DateTime<Utc>>,
url: &str, url: &str,
body: &Value, body_bytes: Vec<u8>,
) -> Result<String, Self::Error> { ) -> Result<String, Self::Error> {
let url = url.parse()?; let url = url.parse()?;
let body_bytes = serde_json::to_vec(body)?;
// Move in, move out :3 // Move in, move out :3
let (digest_raw, body_bytes) = task::spawn_blocking(move || { let (digest_raw, body_bytes) = task::spawn_blocking(move || {
let mut sha = sha2::Sha256::new(); let mut sha = sha2::Sha256::new();
@ -406,12 +407,12 @@ impl ApClientService for ApClientServiceDefaultProvider {
} }
headers.insert( headers.insert(
HeaderName::from_lowercase(b"digest").unwrap(), HeaderName::from_static("digest"),
HeaderValue::try_from(digest_base64)?, HeaderValue::try_from(digest_base64)?,
); );
headers.insert( headers.insert(
HeaderName::from_lowercase(b"signature").unwrap(), HeaderName::from_static("signature"),
HeaderValue::try_from(signed.to_string())?, HeaderValue::try_from(signed.to_string())?,
); );

View File

@ -2,7 +2,7 @@ use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs1::DecodeRsaPublicKey; use rsa::pkcs1::DecodeRsaPublicKey;
use rsa::pkcs8::DecodePrivateKey; use rsa::pkcs8::DecodePrivateKey;
use rsa::pkcs8::DecodePublicKey; use rsa::pkcs8::DecodePublicKey;
use rsa::signature::Verifier; use rsa::signature::{RandomizedSigner, Verifier};
use rsa::{ use rsa::{
sha2::{Sha256, Sha512}, sha2::{Sha256, Sha512},
signature::Signer, signature::Signer,
@ -268,10 +268,10 @@ impl ApHttpSigningKey<'_> {
) -> Result<Vec<u8>, ApSigningError> { ) -> Result<Vec<u8>, ApSigningError> {
match (self, algorithm) { match (self, algorithm) {
(Self::RsaSha256(key), SigningAlgorithm::RsaSha256 | SigningAlgorithm::Hs2019) => { (Self::RsaSha256(key), SigningAlgorithm::RsaSha256 | SigningAlgorithm::Hs2019) => {
Ok(Box::<[u8]>::from(key.sign(message)).into_vec()) Ok(Box::<[u8]>::from(key.sign_with_rng(&mut rand::thread_rng(), message)).into_vec())
} }
(Self::RsaSha512(key), SigningAlgorithm::Hs2019) => { (Self::RsaSha512(key), SigningAlgorithm::Hs2019) => {
Ok(Box::<[u8]>::from(key.sign(message)).into_vec()) Ok(Box::<[u8]>::from(key.sign_with_rng(&mut rand::thread_rng(), message)).into_vec())
} }
(Self::Ed25519(key), SigningAlgorithm::Hs2019) => { (Self::Ed25519(key), SigningAlgorithm::Hs2019) => {
Ok(key.sign(message).to_bytes().to_vec()) Ok(key.sign(message).to_bytes().to_vec())

View File

@ -171,6 +171,6 @@ pub trait ApClientService: Send + Sync {
signing_algorithm: SigningAlgorithm, signing_algorithm: SigningAlgorithm,
expires: Option<chrono::DateTime<Utc>>, expires: Option<chrono::DateTime<Utc>>,
url: &str, url: &str,
body: &Value, body: Vec<u8>,
) -> Result<String, Self::Error>; ) -> Result<String, Self::Error>;
} }

View File

@ -23,7 +23,7 @@ struct RpcApGet {
struct RpcApPost { struct RpcApPost {
user_id: String, user_id: String,
url: String, url: String,
body: serde_json::Value, body: String,
} }
pub fn create_rpc_router() -> MagRpc { pub fn create_rpc_router() -> MagRpc {
@ -90,7 +90,7 @@ pub fn create_rpc_router() -> MagRpc {
.create_signing_key(&key_id, SigningAlgorithm::RsaSha256)?; .create_signing_key(&key_id, SigningAlgorithm::RsaSha256)?;
let result = service let result = service
.ap_client .ap_client
.signed_post(signing_key, SigningAlgorithm::RsaSha256, None, &url, &body) .signed_post(signing_key, SigningAlgorithm::RsaSha256, None, &url, body.into_bytes())
.await?; .await?;
Result::<_, DeliveryError>::Ok(result) Result::<_, DeliveryError>::Ok(result)
}, },

View File

@ -12,7 +12,7 @@ use std::net::SocketAddr;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, UnixSocket}; use tokio::net::{TcpListener, UnixSocket};
use tokio::select; use tokio::select;
use tokio::task::JoinSet; use tokio::task::JoinSet;
@ -57,7 +57,7 @@ pub struct RpcResult<T> {
} }
impl<T: Serialize + Send + 'static, E: Serialize + Debug + Send + 'static> IntoRpcResponse impl<T: Serialize + Send + 'static, E: Serialize + Debug + Send + 'static> IntoRpcResponse
for Result<T, E> for Result<T, E>
{ {
fn into_rpc_response(self) -> Option<RpcResponse> { fn into_rpc_response(self) -> Option<RpcResponse> {
match self { match self {
@ -79,7 +79,7 @@ impl<T: Serialize + Send + 'static, E: Serialize + Debug + Send + 'static> IntoR
} }
impl<L: IntoRpcResponse + Send + 'static, R: IntoRpcResponse + Send + 'static> IntoRpcResponse impl<L: IntoRpcResponse + Send + 'static, R: IntoRpcResponse + Send + 'static> IntoRpcResponse
for Either<L, R> for Either<L, R>
{ {
fn into_rpc_response(self) -> Option<RpcResponse> { fn into_rpc_response(self) -> Option<RpcResponse> {
match self { match self {
@ -97,14 +97,14 @@ where
&self, &self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
message: RpcMessage<T>, message: RpcMessage<T>,
) -> impl Future<Output = Option<RpcResponse>> + Send; ) -> impl Future<Output=Option<RpcResponse>> + Send;
} }
impl<T, F, Fut, RR> RpcHandler<T> for F impl<T, F, Fut, RR> RpcHandler<T> for F
where where
T: Send + 'static, T: Send + 'static,
F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static, F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = RR> + Send, Fut: Future<Output=RR> + Send,
RR: IntoRpcResponse, RR: IntoRpcResponse,
{ {
async fn process( async fn process(
@ -121,13 +121,13 @@ type MessageRaw = Box<dyn Any + Send + 'static>;
type MagRpcHandlerMapped = dyn Fn( type MagRpcHandlerMapped = dyn Fn(
Arc<MagnetarService>, Arc<MagnetarService>,
MessageRaw, MessageRaw,
) -> Pin<Box<dyn Future<Output = Option<RpcResponse>> + Send + 'static>> ) -> Pin<Box<dyn Future<Output=Option<RpcResponse>> + Send + 'static>>
+ Send + Send
+ Sync + Sync
+ 'static; + 'static;
type MagRpcDecoderMapped = type MagRpcDecoderMapped =
dyn (Fn(&'_ [u8]) -> Result<MessageRaw, rmp_serde::decode::Error>) + Send + Sync + 'static; dyn (Fn(&'_ [u8]) -> Result<MessageRaw, rmp_serde::decode::Error>) + Send + Sync + 'static;
pub struct MagRpc { pub struct MagRpc {
listeners: HashMap<String, Arc<MagRpcHandlerMapped>>, listeners: HashMap<String, Arc<MagRpcHandlerMapped>>,
@ -173,7 +173,7 @@ impl MagRpc {
self, self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
addr: RpcSockAddr, addr: RpcSockAddr,
graceful_shutdown: Option<impl Future<Output = ()>>, graceful_shutdown: Option<impl Future<Output=()>>,
) -> miette::Result<()> { ) -> miette::Result<()> {
match addr { match addr {
RpcSockAddr::Ip(sock_addr) => { RpcSockAddr::Ip(sock_addr) => {
@ -187,7 +187,7 @@ impl MagRpc {
self, self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
sock_addr: &SocketAddr, sock_addr: &SocketAddr,
graceful_shutdown: Option<impl Future<Output = ()>>, graceful_shutdown: Option<impl Future<Output=()>>,
) -> miette::Result<()> { ) -> miette::Result<()> {
debug!("Binding RPC TCP socket to {}", sock_addr); debug!("Binding RPC TCP socket to {}", sock_addr);
let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?; let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?;
@ -223,15 +223,16 @@ impl MagRpc {
debug!("RPC TCP connection accepted: {:?}", remote_addr); debug!("RPC TCP connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let (read_half, mut write_half) = stream.into_split(); let (read_half, write_half) = stream.into_split();
let buf_read = BufReader::new(read_half); let buf_read = BufReader::new(read_half);
let mut buf_write = BufWriter::new(write_half);
let context = context.clone(); let context = context.clone();
let rx_dec = rx_dec.clone(); let rx_dec = rx_dec.clone();
let fut = async move { let fut = async move {
let src = rx_dec let src = rx_dec
.stream_decode(buf_read, cancel_recv) .stream_decode(buf_read, cancel_recv)
.map_ok(process(context)) .map_ok(process(context))
.try_buffer_unordered(100) .try_buffer_unordered(400)
.boxed(); .boxed();
futures::pin_mut!(src); futures::pin_mut!(src);
@ -241,14 +242,14 @@ impl MagRpc {
continue; continue;
}; };
write_half.write_u8(b'M').await.into_diagnostic()?; buf_write.write_u8(b'M').await.into_diagnostic()?;
write_half.write_u64(serial).await.into_diagnostic()?; buf_write.write_u64(serial).await.into_diagnostic()?;
write_half buf_write
.write_u32(bytes.len() as u32) .write_u32(bytes.len() as u32)
.await .await
.into_diagnostic()?; .into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?; buf_write.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?; buf_write.flush().await.into_diagnostic()?;
} }
Ok(remote_addr) Ok(remote_addr)
@ -276,7 +277,7 @@ impl MagRpc {
self, self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
addr: &Path, addr: &Path,
graceful_shutdown: Option<impl Future<Output = ()>>, graceful_shutdown: Option<impl Future<Output=()>>,
) -> miette::Result<()> { ) -> miette::Result<()> {
let sock = UnixSocket::new_stream().into_diagnostic()?; let sock = UnixSocket::new_stream().into_diagnostic()?;
debug!("Binding RPC Unix socket to {}", addr.display()); debug!("Binding RPC Unix socket to {}", addr.display());
@ -313,15 +314,16 @@ impl MagRpc {
debug!("RPC Unix connection accepted: {:?}", remote_addr); debug!("RPC Unix connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let (read_half, mut write_half) = stream.into_split(); let (read_half, write_half) = stream.into_split();
let buf_read = BufReader::new(read_half); let buf_read = BufReader::new(read_half);
let mut buf_write = BufWriter::new(write_half);
let context = context.clone(); let context = context.clone();
let rx_dec = rx_dec.clone(); let rx_dec = rx_dec.clone();
let fut = async move { let fut = async move {
let src = rx_dec let src = rx_dec
.stream_decode(buf_read, cancel_recv) .stream_decode(buf_read, cancel_recv)
.map_ok(process(context)) .map_ok(process(context))
.try_buffer_unordered(100) .try_buffer_unordered(400)
.boxed(); .boxed();
futures::pin_mut!(src); futures::pin_mut!(src);
@ -331,14 +333,14 @@ impl MagRpc {
continue; continue;
}; };
write_half.write_u8(b'M').await.into_diagnostic()?; buf_write.write_u8(b'M').await.into_diagnostic()?;
write_half.write_u64(serial).await.into_diagnostic()?; buf_write.write_u64(serial).await.into_diagnostic()?;
write_half buf_write
.write_u32(bytes.len() as u32) .write_u32(bytes.len() as u32)
.await .await
.into_diagnostic()?; .into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?; buf_write.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?; buf_write.flush().await.into_diagnostic()?;
} }
miette::Result::<()>::Ok(()) miette::Result::<()>::Ok(())
@ -368,7 +370,7 @@ fn process(
) -> impl Fn( ) -> impl Fn(
(u64, MessageRaw, Arc<MagRpcHandlerMapped>), (u64, MessageRaw, Arc<MagRpcHandlerMapped>),
) -> Pin< ) -> Pin<
Box<dyn Future<Output = miette::Result<Option<(u64, RpcResponse)>>> + Send + 'static>, Box<dyn Future<Output=miette::Result<Option<(u64, RpcResponse)>>> + Send + 'static>,
> { > {
move |(serial, payload, listener)| { move |(serial, payload, listener)| {
let ctx = context.clone(); let ctx = context.clone();
@ -389,7 +391,7 @@ impl RpcCallDecoder {
&self, &self,
mut buf_read: BufReader<R>, mut buf_read: BufReader<R>,
mut cancel: tokio::sync::oneshot::Receiver<()>, mut cancel: tokio::sync::oneshot::Receiver<()>,
) -> impl Stream<Item = miette::Result<(u64, MessageRaw, Arc<MagRpcHandlerMapped>)>> + Send + 'static ) -> impl Stream<Item=miette::Result<(u64, MessageRaw, Arc<MagRpcHandlerMapped>)>> + Send + 'static
{ {
let decoders = self.payload_decoders.clone(); let decoders = self.payload_decoders.clone();
let listeners = self.listeners.clone(); let listeners = self.listeners.clone();