Fixed RPC and implemented responses

This commit is contained in:
Natty 2024-11-14 17:33:33 +01:00
parent b9160305f1
commit 80c5bf8ae6
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
1 changed files with 130 additions and 64 deletions

View File

@ -1,7 +1,9 @@
use crate::service::MagnetarService; use crate::service::MagnetarService;
use bytes::BufMut; use bytes::BufMut;
use futures::{FutureExt, Stream, StreamExt}; use futures::{sink, FutureExt, Stream, StreamExt};
use futures_util::SinkExt;
use miette::{miette, IntoDiagnostic}; use miette::{miette, IntoDiagnostic};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
@ -10,11 +12,11 @@ use std::net::SocketAddr;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
use tokio::net::{TcpListener, UnixSocket}; use tokio::net::{TcpListener, UnixSocket};
use tokio::select; use tokio::select;
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tracing::{debug, error, Instrument}; use tracing::{debug, error, info, Instrument};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum RpcSockAddr { pub enum RpcSockAddr {
@ -30,6 +32,7 @@ pub trait IntoRpcResponse: Send {
fn into_rpc_response(self) -> Option<RpcResponse>; fn into_rpc_response(self) -> Option<RpcResponse>;
} }
#[derive(Debug, Clone)]
pub struct RpcResponse(Vec<u8>); pub struct RpcResponse(Vec<u8>);
impl<T: Serialize + Send + 'static> IntoRpcResponse for RpcMessage<T> { impl<T: Serialize + Send + 'static> IntoRpcResponse for RpcMessage<T> {
@ -54,7 +57,7 @@ pub struct RpcResult<T> {
} }
impl<T: Serialize + Send + 'static, E: Serialize + Send + 'static> IntoRpcResponse impl<T: Serialize + Send + 'static, E: Serialize + Send + 'static> IntoRpcResponse
for Result<T, E> for Result<T, E>
{ {
fn into_rpc_response(self) -> Option<RpcResponse> { fn into_rpc_response(self) -> Option<RpcResponse> {
match self { match self {
@ -62,12 +65,12 @@ for Result<T, E>
success: true, success: true,
data, data,
}) })
.into_rpc_response(), .into_rpc_response(),
Err(data) => RpcMessage(RpcResult { Err(data) => RpcMessage(RpcResult {
success: false, success: false,
data, data,
}) })
.into_rpc_response(), .into_rpc_response(),
} }
} }
} }
@ -80,14 +83,14 @@ where
&self, &self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
message: RpcMessage<T>, message: RpcMessage<T>,
) -> impl Future<Output=Option<RpcResponse>> + Send; ) -> impl Future<Output = Option<RpcResponse>> + Send;
} }
impl<T, F, Fut, RR> RpcHandler<T> for F impl<T, F, Fut, RR> RpcHandler<T> for F
where where
T: Send + 'static, T: Send + 'static,
F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static, F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static,
Fut: Future<Output=RR> + Send, Fut: Future<Output = RR> + Send,
RR: IntoRpcResponse, RR: IntoRpcResponse,
{ {
async fn process( async fn process(
@ -99,18 +102,21 @@ where
} }
} }
type MessageRaw = Box<dyn Any + Send + 'static>;
type MagRpcHandlerMapped = dyn Fn( type MagRpcHandlerMapped = dyn Fn(
Arc<MagnetarService>, Arc<MagnetarService>,
Box<dyn Any + Send + 'static>, MessageRaw,
) -> Pin<Box<dyn Future<Output=Option<RpcResponse>> + Send + 'static>> ) -> Pin<Box<dyn Future<Output = Option<RpcResponse>> + Send + 'static>>
+ Send + Send
+ Sync; + Sync
+ 'static;
type MagRpcDecoderMapped = type MagRpcDecoderMapped =
dyn Fn(&[u8]) -> Result<Box<dyn Any + Send + 'static>, rmp_serde::decode::Error> + Send + Sync; dyn (Fn(&'_ [u8]) -> Result<MessageRaw, rmp_serde::decode::Error>) + Send + Sync + 'static;
pub struct MagRpc { pub struct MagRpc {
listeners: HashMap<String, Box<MagRpcHandlerMapped>>, listeners: HashMap<String, Arc<MagRpcHandlerMapped>>,
payload_decoders: HashMap<String, Box<MagRpcDecoderMapped>>, payload_decoders: HashMap<String, Box<MagRpcDecoderMapped>>,
} }
@ -124,22 +130,27 @@ impl MagRpc {
pub fn handle<H, T>(mut self, method: impl Into<String>, handler: H) -> Self pub fn handle<H, T>(mut self, method: impl Into<String>, handler: H) -> Self
where where
T: Send + 'static, T: DeserializeOwned + Send + 'static,
H: RpcHandler<T> + Sync + 'static, H: RpcHandler<T> + Sync + 'static,
{ {
let handler_ref = Arc::new(handler); let handler_ref = Arc::new(handler);
let method = method.into();
self.listeners.insert( self.listeners.insert(
method.into(), method.clone(),
Box::new(move |ctx, data| { Arc::new(move |ctx, data| {
let handler = handler_ref.clone(); let handler = handler_ref.clone();
async move { async move {
handler handler
.process(ctx, RpcMessage(*data.downcast().unwrap())) .process(ctx, RpcMessage(*data.downcast().unwrap()))
.await .await
} }
.boxed() .boxed()
}), }),
); );
self.payload_decoders.insert(
method,
Box::new(move |data| Ok(Box::new(rmp_serde::from_slice::<'_, T>(data)?))),
);
self self
} }
@ -148,7 +159,7 @@ impl MagRpc {
self, self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
addr: RpcSockAddr, addr: RpcSockAddr,
graceful_shutdown: Option<impl Future<Output=()>>, graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> { ) -> miette::Result<()> {
match addr { match addr {
RpcSockAddr::Ip(sock_addr) => { RpcSockAddr::Ip(sock_addr) => {
@ -162,10 +173,11 @@ impl MagRpc {
self, self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
sock_addr: &SocketAddr, sock_addr: &SocketAddr,
graceful_shutdown: Option<impl Future<Output=()>>, graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> { ) -> miette::Result<()> {
debug!("Binding RPC socket to {}", sock_addr); debug!("Binding RPC TCP socket to {}", sock_addr);
let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?; let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?;
info!("Listening for RPC calls on {}", sock_addr);
let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>(); let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>();
let mut cancellation_tokens = Vec::new(); let mut cancellation_tokens = Vec::new();
@ -175,11 +187,14 @@ impl MagRpc {
payload_decoders: Arc::new(self.payload_decoders), payload_decoders: Arc::new(self.payload_decoders),
}; };
let mut connections = JoinSet::new(); let mut connections = JoinSet::<miette::Result<_>>::new();
loop { loop {
let (stream, sock_addr) = select!( let (stream, remote_addr) = select!(
_ = connections.join_next() => continue, Some(c) = connections.join_next() => {
debug!("RPC TCP connection closed: {:?}", c);
continue;
},
conn = listener.accept() => { conn = listener.accept() => {
if let Err(e) = conn { if let Err(e) = conn {
error!("Connection error: {}", e); error!("Connection error: {}", e);
@ -191,14 +206,15 @@ impl MagRpc {
_ = &mut cancel => break _ = &mut cancel => break
); );
debug!("RPC TCP connection accepted: {:?}", sock_addr); debug!("RPC TCP connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let buf_read = BufReader::new(stream); let (read_half, mut write_half) = stream.into_split();
let buf_read = BufReader::new(read_half);
let context = context.clone(); let context = context.clone();
let rx_dec = rx_dec.clone(); let rx_dec = rx_dec.clone();
let fut = async move { let fut = async move {
rx_dec let src = rx_dec
.stream_decode(buf_read, cancel_recv) .stream_decode(buf_read, cancel_recv)
.filter_map(|r| async move { .filter_map(|r| async move {
if let Err(e) = &r { if let Err(e) = &r {
@ -207,18 +223,27 @@ impl MagRpc {
r.ok() r.ok()
}) })
.for_each_concurrent(Some(32), |(payload, listener)| async { .filter_map(|(payload, listener)| {
if let Some(response) = listener(context.clone(), payload).await { let ctx = context.clone();
// TODO: Respond async move { listener(ctx, payload).await }
} });
})
.await;
miette::Result::<()>::Ok(()) futures::pin_mut!(src);
while let Some(RpcResponse(bytes)) = src.next().await {
write_half
.write_u32(bytes.len() as u32)
.await
.into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?;
}
Ok(remote_addr)
} }
.instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr)); .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
connections.spawn_local(fut); connections.spawn(fut);
cancellation_tokens.push(cancel_send); cancellation_tokens.push(cancel_send);
} }
@ -228,6 +253,8 @@ impl MagRpc {
sender.send(()).ok(); sender.send(()).ok();
} }
info!("Awaiting shutdown of all RPC connections...");
connections.join_all().await; connections.join_all().await;
Ok(()) Ok(())
@ -237,9 +264,10 @@ impl MagRpc {
self, self,
context: Arc<MagnetarService>, context: Arc<MagnetarService>,
addr: &Path, addr: &Path,
graceful_shutdown: Option<impl Future<Output=()>>, graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> { ) -> miette::Result<()> {
let sock = UnixSocket::new_stream().into_diagnostic()?; let sock = UnixSocket::new_stream().into_diagnostic()?;
debug!("Binding RPC Unix socket to {}", addr.display());
sock.bind(addr).into_diagnostic()?; sock.bind(addr).into_diagnostic()?;
let listener = sock.listen(16).into_diagnostic()?; let listener = sock.listen(16).into_diagnostic()?;
@ -251,11 +279,14 @@ impl MagRpc {
payload_decoders: Arc::new(self.payload_decoders), payload_decoders: Arc::new(self.payload_decoders),
}; };
let mut connections = JoinSet::new(); let mut connections = JoinSet::<miette::Result<_>>::new();
loop { loop {
let (stream, sock_addr) = select!( let (stream, remote_addr) = select!(
_ = connections.join_next() => continue, Some(c) = connections.join_next() => {
debug!("RPC Unix connection closed: {:?}", c);
continue;
},
conn = listener.accept() => { conn = listener.accept() => {
if let Err(e) = conn { if let Err(e) = conn {
error!("Connection error: {}", e); error!("Connection error: {}", e);
@ -267,14 +298,15 @@ impl MagRpc {
_ = &mut cancel => break _ = &mut cancel => break
); );
debug!("RPC Unix connection accepted: {:?}", sock_addr); debug!("RPC Unix connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let buf_read = BufReader::new(stream); let (read_half, mut write_half) = stream.into_split();
let buf_read = BufReader::new(read_half);
let context = context.clone(); let context = context.clone();
let rx_dec = rx_dec.clone(); let rx_dec = rx_dec.clone();
let fut = async move { let fut = async move {
rx_dec let src = rx_dec
.stream_decode(buf_read, cancel_recv) .stream_decode(buf_read, cancel_recv)
.filter_map(|r| async move { .filter_map(|r| async move {
if let Err(e) = &r { if let Err(e) = &r {
@ -283,18 +315,27 @@ impl MagRpc {
r.ok() r.ok()
}) })
.for_each_concurrent(Some(32), |(payload, listener)| async { .filter_map(|(payload, listener)| {
if let Some(response) = listener(context.clone(), payload).await { let ctx = context.clone();
// TODO: Respond async move { listener(ctx, payload).await }
} });
})
.await; futures::pin_mut!(src);
while let Some(RpcResponse(bytes)) = src.next().await {
write_half
.write_u32(bytes.len() as u32)
.await
.into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?;
}
miette::Result::<()>::Ok(()) miette::Result::<()>::Ok(())
} }
.instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr)); .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
connections.spawn_local(fut); connections.spawn(fut.boxed());
cancellation_tokens.push(cancel_send); cancellation_tokens.push(cancel_send);
} }
@ -304,6 +345,8 @@ impl MagRpc {
sender.send(()).ok(); sender.send(()).ok();
} }
info!("Awaiting shutdown of all RPC connections...");
connections.join_all().await; connections.join_all().await;
Ok(()) Ok(())
@ -312,25 +355,47 @@ impl MagRpc {
#[derive(Clone)] #[derive(Clone)]
struct RpcCallDecoder { struct RpcCallDecoder {
listeners: Arc<HashMap<String, Box<MagRpcHandlerMapped>>>, listeners: Arc<HashMap<String, Arc<MagRpcHandlerMapped>>>,
payload_decoders: Arc<HashMap<String, Box<MagRpcDecoderMapped>>>, payload_decoders: Arc<HashMap<String, Box<MagRpcDecoderMapped>>>,
} }
impl RpcCallDecoder { impl RpcCallDecoder {
fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send>( fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send + 'static>(
&self, &self,
mut buf_read: BufReader<R>, mut buf_read: BufReader<R>,
mut cancel: tokio::sync::oneshot::Receiver<()>, mut cancel: tokio::sync::oneshot::Receiver<()>,
) -> impl Stream<Item=miette::Result<(Box<dyn Any + Send + 'static>, &MagRpcHandlerMapped)>> + Send ) -> impl Stream<Item = miette::Result<(MessageRaw, Arc<MagRpcHandlerMapped>)>> + Send + 'static
{ {
let decoders = self.payload_decoders.clone();
let listeners = self.listeners.clone();
async_stream::try_stream! { async_stream::try_stream! {
let mut name_buf = Vec::new(); let mut name_buf = Vec::new();
let mut buf = Vec::new(); let mut buf = Vec::new();
let mut messages = 0usize;
loop { loop {
let read_fut = async { let read_fut = async {
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize; let mut header = [0u8; 1];
if buf_read.read(&mut header).await.into_diagnostic()? == 0 {
return if messages > 0 {
Ok(None)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Unexpected end of stream, expected a header"
)).into_diagnostic()
}
}
if !matches!(header, [b'M']) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Unexpected data in stream, expected a header"
)).into_diagnostic();
}
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize;
if name_len > name_buf.capacity() { if name_len > name_buf.capacity() {
name_buf.reserve(name_len - name_buf.capacity()); name_buf.reserve(name_len - name_buf.capacity());
} }
@ -363,26 +428,26 @@ impl RpcCallDecoder {
buf_read.read_buf(&mut buf_write).await.into_diagnostic()?; buf_read.read_buf(&mut buf_write).await.into_diagnostic()?;
} }
miette::Result::<_>::Ok((name_buf_write, buf_write)) miette::Result::<_>::Ok(Some((name_buf_write, buf_write)))
}; };
let (name_buf_write, payload) = select! { let Some((name_buf_write, payload)) = select! {
read_result = read_fut => read_result, read_result = read_fut => read_result,
_ = &mut cancel => { break; } _ = &mut cancel => { break; }
}?; }? else {
break;
};
let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?; let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?;
let decoder = self let decoder = decoders
.payload_decoders
.get(name) .get(name)
.ok_or_else(|| miette!("No such RPC call name: {}", name))? .ok_or_else(|| miette!("No such RPC call name: {}", name))?
.as_ref(); .as_ref();
let listener = self let listener = listeners
.listeners
.get(name) .get(name)
.ok_or_else(|| miette!("No such RPC call name: {}", name))? .ok_or_else(|| miette!("No such RPC call name: {}", name))?
.as_ref(); .clone();
let packet = match decoder(payload.filled()) { let packet = match decoder(payload.filled()) {
Ok(p) => p, Ok(p) => p,
@ -393,6 +458,7 @@ impl RpcCallDecoder {
}; };
yield (packet, listener); yield (packet, listener);
messages += 1;
} }
} }
} }