diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs index 2795bd8..44270cc 100644 --- a/src/rpc_v1/proto.rs +++ b/src/rpc_v1/proto.rs @@ -1,7 +1,9 @@ use crate::service::MagnetarService; use bytes::BufMut; -use futures::{FutureExt, Stream, StreamExt}; +use futures::{sink, FutureExt, Stream, StreamExt}; +use futures_util::SinkExt; use miette::{miette, IntoDiagnostic}; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::any::Any; use std::collections::HashMap; @@ -10,11 +12,11 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf}; use tokio::net::{TcpListener, UnixSocket}; use tokio::select; use tokio::task::JoinSet; -use tracing::{debug, error, Instrument}; +use tracing::{debug, error, info, Instrument}; #[derive(Debug, Clone)] pub enum RpcSockAddr { @@ -30,6 +32,7 @@ pub trait IntoRpcResponse: Send { fn into_rpc_response(self) -> Option; } +#[derive(Debug, Clone)] pub struct RpcResponse(Vec); impl IntoRpcResponse for RpcMessage { @@ -54,7 +57,7 @@ pub struct RpcResult { } impl IntoRpcResponse -for Result + for Result { fn into_rpc_response(self) -> Option { match self { @@ -62,12 +65,12 @@ for Result success: true, data, }) - .into_rpc_response(), + .into_rpc_response(), Err(data) => RpcMessage(RpcResult { success: false, data, }) - .into_rpc_response(), + .into_rpc_response(), } } } @@ -80,14 +83,14 @@ where &self, context: Arc, message: RpcMessage, - ) -> impl Future> + Send; + ) -> impl Future> + Send; } impl RpcHandler for F where T: Send + 'static, F: Fn(Arc, RpcMessage) -> Fut + Send + Sync + 'static, - Fut: Future + Send, + Fut: Future + Send, RR: IntoRpcResponse, { async fn process( @@ -99,18 +102,21 @@ where } } +type MessageRaw = Box; + type MagRpcHandlerMapped = dyn Fn( - Arc, - Box, -) -> Pin> + Send + 'static>> -+ Send -+ Sync; + Arc, + MessageRaw, + ) -> Pin> + Send + 'static>> + + Send + + Sync + + 'static; type MagRpcDecoderMapped = -dyn Fn(&[u8]) -> Result, rmp_serde::decode::Error> + Send + Sync; + dyn (Fn(&'_ [u8]) -> Result) + Send + Sync + 'static; pub struct MagRpc { - listeners: HashMap>, + listeners: HashMap>, payload_decoders: HashMap>, } @@ -124,22 +130,27 @@ impl MagRpc { pub fn handle(mut self, method: impl Into, handler: H) -> Self where - T: Send + 'static, + T: DeserializeOwned + Send + 'static, H: RpcHandler + Sync + 'static, { let handler_ref = Arc::new(handler); + let method = method.into(); self.listeners.insert( - method.into(), - Box::new(move |ctx, data| { + method.clone(), + Arc::new(move |ctx, data| { let handler = handler_ref.clone(); async move { handler .process(ctx, RpcMessage(*data.downcast().unwrap())) .await } - .boxed() + .boxed() }), ); + self.payload_decoders.insert( + method, + Box::new(move |data| Ok(Box::new(rmp_serde::from_slice::<'_, T>(data)?))), + ); self } @@ -148,7 +159,7 @@ impl MagRpc { self, context: Arc, addr: RpcSockAddr, - graceful_shutdown: Option>, + graceful_shutdown: Option>, ) -> miette::Result<()> { match addr { RpcSockAddr::Ip(sock_addr) => { @@ -162,10 +173,11 @@ impl MagRpc { self, context: Arc, sock_addr: &SocketAddr, - graceful_shutdown: Option>, + graceful_shutdown: Option>, ) -> miette::Result<()> { - debug!("Binding RPC socket to {}", sock_addr); + debug!("Binding RPC TCP socket to {}", sock_addr); let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?; + info!("Listening for RPC calls on {}", sock_addr); let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>(); let mut cancellation_tokens = Vec::new(); @@ -175,11 +187,14 @@ impl MagRpc { payload_decoders: Arc::new(self.payload_decoders), }; - let mut connections = JoinSet::new(); + let mut connections = JoinSet::>::new(); loop { - let (stream, sock_addr) = select!( - _ = connections.join_next() => continue, + let (stream, remote_addr) = select!( + Some(c) = connections.join_next() => { + debug!("RPC TCP connection closed: {:?}", c); + continue; + }, conn = listener.accept() => { if let Err(e) = conn { error!("Connection error: {}", e); @@ -191,14 +206,15 @@ impl MagRpc { _ = &mut cancel => break ); - debug!("RPC TCP connection accepted: {:?}", sock_addr); + debug!("RPC TCP connection accepted: {:?}", remote_addr); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); - let buf_read = BufReader::new(stream); + let (read_half, mut write_half) = stream.into_split(); + let buf_read = BufReader::new(read_half); let context = context.clone(); let rx_dec = rx_dec.clone(); let fut = async move { - rx_dec + let src = rx_dec .stream_decode(buf_read, cancel_recv) .filter_map(|r| async move { if let Err(e) = &r { @@ -207,18 +223,27 @@ impl MagRpc { r.ok() }) - .for_each_concurrent(Some(32), |(payload, listener)| async { - if let Some(response) = listener(context.clone(), payload).await { - // TODO: Respond - } - }) - .await; + .filter_map(|(payload, listener)| { + let ctx = context.clone(); + async move { listener(ctx, payload).await } + }); - miette::Result::<()>::Ok(()) + futures::pin_mut!(src); + + while let Some(RpcResponse(bytes)) = src.next().await { + write_half + .write_u32(bytes.len() as u32) + .await + .into_diagnostic()?; + write_half.write_all(&bytes).await.into_diagnostic()?; + write_half.flush().await.into_diagnostic()?; + } + + Ok(remote_addr) } - .instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr)); + .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); - connections.spawn_local(fut); + connections.spawn(fut); cancellation_tokens.push(cancel_send); } @@ -228,6 +253,8 @@ impl MagRpc { sender.send(()).ok(); } + info!("Awaiting shutdown of all RPC connections..."); + connections.join_all().await; Ok(()) @@ -237,9 +264,10 @@ impl MagRpc { self, context: Arc, addr: &Path, - graceful_shutdown: Option>, + graceful_shutdown: Option>, ) -> miette::Result<()> { let sock = UnixSocket::new_stream().into_diagnostic()?; + debug!("Binding RPC Unix socket to {}", addr.display()); sock.bind(addr).into_diagnostic()?; let listener = sock.listen(16).into_diagnostic()?; @@ -251,11 +279,14 @@ impl MagRpc { payload_decoders: Arc::new(self.payload_decoders), }; - let mut connections = JoinSet::new(); + let mut connections = JoinSet::>::new(); loop { - let (stream, sock_addr) = select!( - _ = connections.join_next() => continue, + let (stream, remote_addr) = select!( + Some(c) = connections.join_next() => { + debug!("RPC Unix connection closed: {:?}", c); + continue; + }, conn = listener.accept() => { if let Err(e) = conn { error!("Connection error: {}", e); @@ -267,14 +298,15 @@ impl MagRpc { _ = &mut cancel => break ); - debug!("RPC Unix connection accepted: {:?}", sock_addr); + debug!("RPC Unix connection accepted: {:?}", remote_addr); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); - let buf_read = BufReader::new(stream); + let (read_half, mut write_half) = stream.into_split(); + let buf_read = BufReader::new(read_half); let context = context.clone(); let rx_dec = rx_dec.clone(); let fut = async move { - rx_dec + let src = rx_dec .stream_decode(buf_read, cancel_recv) .filter_map(|r| async move { if let Err(e) = &r { @@ -283,18 +315,27 @@ impl MagRpc { r.ok() }) - .for_each_concurrent(Some(32), |(payload, listener)| async { - if let Some(response) = listener(context.clone(), payload).await { - // TODO: Respond - } - }) - .await; + .filter_map(|(payload, listener)| { + let ctx = context.clone(); + async move { listener(ctx, payload).await } + }); + + futures::pin_mut!(src); + + while let Some(RpcResponse(bytes)) = src.next().await { + write_half + .write_u32(bytes.len() as u32) + .await + .into_diagnostic()?; + write_half.write_all(&bytes).await.into_diagnostic()?; + write_half.flush().await.into_diagnostic()?; + } miette::Result::<()>::Ok(()) } - .instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr)); + .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); - connections.spawn_local(fut); + connections.spawn(fut.boxed()); cancellation_tokens.push(cancel_send); } @@ -304,6 +345,8 @@ impl MagRpc { sender.send(()).ok(); } + info!("Awaiting shutdown of all RPC connections..."); + connections.join_all().await; Ok(()) @@ -312,25 +355,47 @@ impl MagRpc { #[derive(Clone)] struct RpcCallDecoder { - listeners: Arc>>, + listeners: Arc>>, payload_decoders: Arc>>, } impl RpcCallDecoder { - fn stream_decode( + fn stream_decode( &self, mut buf_read: BufReader, mut cancel: tokio::sync::oneshot::Receiver<()>, - ) -> impl Stream, &MagRpcHandlerMapped)>> + Send + ) -> impl Stream)>> + Send + 'static { + let decoders = self.payload_decoders.clone(); + let listeners = self.listeners.clone(); + async_stream::try_stream! { let mut name_buf = Vec::new(); let mut buf = Vec::new(); + let mut messages = 0usize; loop { let read_fut = async { - let name_len = buf_read.read_u32().await.into_diagnostic()? as usize; + let mut header = [0u8; 1]; + if buf_read.read(&mut header).await.into_diagnostic()? == 0 { + return if messages > 0 { + Ok(None) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Unexpected end of stream, expected a header" + )).into_diagnostic() + } + } + if !matches!(header, [b'M']) { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Unexpected data in stream, expected a header" + )).into_diagnostic(); + } + + let name_len = buf_read.read_u32().await.into_diagnostic()? as usize; if name_len > name_buf.capacity() { name_buf.reserve(name_len - name_buf.capacity()); } @@ -363,26 +428,26 @@ impl RpcCallDecoder { buf_read.read_buf(&mut buf_write).await.into_diagnostic()?; } - miette::Result::<_>::Ok((name_buf_write, buf_write)) + miette::Result::<_>::Ok(Some((name_buf_write, buf_write))) }; - let (name_buf_write, payload) = select! { + let Some((name_buf_write, payload)) = select! { read_result = read_fut => read_result, _ = &mut cancel => { break; } - }?; + }? else { + break; + }; let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?; - let decoder = self - .payload_decoders + let decoder = decoders .get(name) .ok_or_else(|| miette!("No such RPC call name: {}", name))? .as_ref(); - let listener = self - .listeners + let listener = listeners .get(name) .ok_or_else(|| miette!("No such RPC call name: {}", name))? - .as_ref(); + .clone(); let packet = match decoder(payload.filled()) { Ok(p) => p, @@ -393,6 +458,7 @@ impl RpcCallDecoder { }; yield (packet, listener); + messages += 1; } } }