Fixed RPC and implemented responses

This commit is contained in:
Natty 2024-11-14 17:33:33 +01:00
parent b9160305f1
commit 80c5bf8ae6
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
1 changed files with 130 additions and 64 deletions

View File

@ -1,7 +1,9 @@
use crate::service::MagnetarService;
use bytes::BufMut;
use futures::{FutureExt, Stream, StreamExt};
use futures::{sink, FutureExt, Stream, StreamExt};
use futures_util::SinkExt;
use miette::{miette, IntoDiagnostic};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
@ -10,11 +12,11 @@ use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
use tokio::net::{TcpListener, UnixSocket};
use tokio::select;
use tokio::task::JoinSet;
use tracing::{debug, error, Instrument};
use tracing::{debug, error, info, Instrument};
#[derive(Debug, Clone)]
pub enum RpcSockAddr {
@ -30,6 +32,7 @@ pub trait IntoRpcResponse: Send {
fn into_rpc_response(self) -> Option<RpcResponse>;
}
#[derive(Debug, Clone)]
pub struct RpcResponse(Vec<u8>);
impl<T: Serialize + Send + 'static> IntoRpcResponse for RpcMessage<T> {
@ -54,7 +57,7 @@ pub struct RpcResult<T> {
}
impl<T: Serialize + Send + 'static, E: Serialize + Send + 'static> IntoRpcResponse
for Result<T, E>
for Result<T, E>
{
fn into_rpc_response(self) -> Option<RpcResponse> {
match self {
@ -62,12 +65,12 @@ for Result<T, E>
success: true,
data,
})
.into_rpc_response(),
.into_rpc_response(),
Err(data) => RpcMessage(RpcResult {
success: false,
data,
})
.into_rpc_response(),
.into_rpc_response(),
}
}
}
@ -80,14 +83,14 @@ where
&self,
context: Arc<MagnetarService>,
message: RpcMessage<T>,
) -> impl Future<Output=Option<RpcResponse>> + Send;
) -> impl Future<Output = Option<RpcResponse>> + Send;
}
impl<T, F, Fut, RR> RpcHandler<T> for F
where
T: Send + 'static,
F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static,
Fut: Future<Output=RR> + Send,
Fut: Future<Output = RR> + Send,
RR: IntoRpcResponse,
{
async fn process(
@ -99,18 +102,21 @@ where
}
}
type MessageRaw = Box<dyn Any + Send + 'static>;
type MagRpcHandlerMapped = dyn Fn(
Arc<MagnetarService>,
Box<dyn Any + Send + 'static>,
) -> Pin<Box<dyn Future<Output=Option<RpcResponse>> + Send + 'static>>
+ Send
+ Sync;
Arc<MagnetarService>,
MessageRaw,
) -> Pin<Box<dyn Future<Output = Option<RpcResponse>> + Send + 'static>>
+ Send
+ Sync
+ 'static;
type MagRpcDecoderMapped =
dyn Fn(&[u8]) -> Result<Box<dyn Any + Send + 'static>, rmp_serde::decode::Error> + Send + Sync;
dyn (Fn(&'_ [u8]) -> Result<MessageRaw, rmp_serde::decode::Error>) + Send + Sync + 'static;
pub struct MagRpc {
listeners: HashMap<String, Box<MagRpcHandlerMapped>>,
listeners: HashMap<String, Arc<MagRpcHandlerMapped>>,
payload_decoders: HashMap<String, Box<MagRpcDecoderMapped>>,
}
@ -124,22 +130,27 @@ impl MagRpc {
pub fn handle<H, T>(mut self, method: impl Into<String>, handler: H) -> Self
where
T: Send + 'static,
T: DeserializeOwned + Send + 'static,
H: RpcHandler<T> + Sync + 'static,
{
let handler_ref = Arc::new(handler);
let method = method.into();
self.listeners.insert(
method.into(),
Box::new(move |ctx, data| {
method.clone(),
Arc::new(move |ctx, data| {
let handler = handler_ref.clone();
async move {
handler
.process(ctx, RpcMessage(*data.downcast().unwrap()))
.await
}
.boxed()
.boxed()
}),
);
self.payload_decoders.insert(
method,
Box::new(move |data| Ok(Box::new(rmp_serde::from_slice::<'_, T>(data)?))),
);
self
}
@ -148,7 +159,7 @@ impl MagRpc {
self,
context: Arc<MagnetarService>,
addr: RpcSockAddr,
graceful_shutdown: Option<impl Future<Output=()>>,
graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> {
match addr {
RpcSockAddr::Ip(sock_addr) => {
@ -162,10 +173,11 @@ impl MagRpc {
self,
context: Arc<MagnetarService>,
sock_addr: &SocketAddr,
graceful_shutdown: Option<impl Future<Output=()>>,
graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> {
debug!("Binding RPC socket to {}", sock_addr);
debug!("Binding RPC TCP socket to {}", sock_addr);
let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?;
info!("Listening for RPC calls on {}", sock_addr);
let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>();
let mut cancellation_tokens = Vec::new();
@ -175,11 +187,14 @@ impl MagRpc {
payload_decoders: Arc::new(self.payload_decoders),
};
let mut connections = JoinSet::new();
let mut connections = JoinSet::<miette::Result<_>>::new();
loop {
let (stream, sock_addr) = select!(
_ = connections.join_next() => continue,
let (stream, remote_addr) = select!(
Some(c) = connections.join_next() => {
debug!("RPC TCP connection closed: {:?}", c);
continue;
},
conn = listener.accept() => {
if let Err(e) = conn {
error!("Connection error: {}", e);
@ -191,14 +206,15 @@ impl MagRpc {
_ = &mut cancel => break
);
debug!("RPC TCP connection accepted: {:?}", sock_addr);
debug!("RPC TCP connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let buf_read = BufReader::new(stream);
let (read_half, mut write_half) = stream.into_split();
let buf_read = BufReader::new(read_half);
let context = context.clone();
let rx_dec = rx_dec.clone();
let fut = async move {
rx_dec
let src = rx_dec
.stream_decode(buf_read, cancel_recv)
.filter_map(|r| async move {
if let Err(e) = &r {
@ -207,18 +223,27 @@ impl MagRpc {
r.ok()
})
.for_each_concurrent(Some(32), |(payload, listener)| async {
if let Some(response) = listener(context.clone(), payload).await {
// TODO: Respond
}
})
.await;
.filter_map(|(payload, listener)| {
let ctx = context.clone();
async move { listener(ctx, payload).await }
});
miette::Result::<()>::Ok(())
futures::pin_mut!(src);
while let Some(RpcResponse(bytes)) = src.next().await {
write_half
.write_u32(bytes.len() as u32)
.await
.into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?;
}
Ok(remote_addr)
}
.instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr));
.instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
connections.spawn_local(fut);
connections.spawn(fut);
cancellation_tokens.push(cancel_send);
}
@ -228,6 +253,8 @@ impl MagRpc {
sender.send(()).ok();
}
info!("Awaiting shutdown of all RPC connections...");
connections.join_all().await;
Ok(())
@ -237,9 +264,10 @@ impl MagRpc {
self,
context: Arc<MagnetarService>,
addr: &Path,
graceful_shutdown: Option<impl Future<Output=()>>,
graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> {
let sock = UnixSocket::new_stream().into_diagnostic()?;
debug!("Binding RPC Unix socket to {}", addr.display());
sock.bind(addr).into_diagnostic()?;
let listener = sock.listen(16).into_diagnostic()?;
@ -251,11 +279,14 @@ impl MagRpc {
payload_decoders: Arc::new(self.payload_decoders),
};
let mut connections = JoinSet::new();
let mut connections = JoinSet::<miette::Result<_>>::new();
loop {
let (stream, sock_addr) = select!(
_ = connections.join_next() => continue,
let (stream, remote_addr) = select!(
Some(c) = connections.join_next() => {
debug!("RPC Unix connection closed: {:?}", c);
continue;
},
conn = listener.accept() => {
if let Err(e) = conn {
error!("Connection error: {}", e);
@ -267,14 +298,15 @@ impl MagRpc {
_ = &mut cancel => break
);
debug!("RPC Unix connection accepted: {:?}", sock_addr);
debug!("RPC Unix connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let buf_read = BufReader::new(stream);
let (read_half, mut write_half) = stream.into_split();
let buf_read = BufReader::new(read_half);
let context = context.clone();
let rx_dec = rx_dec.clone();
let fut = async move {
rx_dec
let src = rx_dec
.stream_decode(buf_read, cancel_recv)
.filter_map(|r| async move {
if let Err(e) = &r {
@ -283,18 +315,27 @@ impl MagRpc {
r.ok()
})
.for_each_concurrent(Some(32), |(payload, listener)| async {
if let Some(response) = listener(context.clone(), payload).await {
// TODO: Respond
}
})
.await;
.filter_map(|(payload, listener)| {
let ctx = context.clone();
async move { listener(ctx, payload).await }
});
futures::pin_mut!(src);
while let Some(RpcResponse(bytes)) = src.next().await {
write_half
.write_u32(bytes.len() as u32)
.await
.into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?;
}
miette::Result::<()>::Ok(())
}
.instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr));
.instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
connections.spawn_local(fut);
connections.spawn(fut.boxed());
cancellation_tokens.push(cancel_send);
}
@ -304,6 +345,8 @@ impl MagRpc {
sender.send(()).ok();
}
info!("Awaiting shutdown of all RPC connections...");
connections.join_all().await;
Ok(())
@ -312,25 +355,47 @@ impl MagRpc {
#[derive(Clone)]
struct RpcCallDecoder {
listeners: Arc<HashMap<String, Box<MagRpcHandlerMapped>>>,
listeners: Arc<HashMap<String, Arc<MagRpcHandlerMapped>>>,
payload_decoders: Arc<HashMap<String, Box<MagRpcDecoderMapped>>>,
}
impl RpcCallDecoder {
fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send>(
fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send + 'static>(
&self,
mut buf_read: BufReader<R>,
mut cancel: tokio::sync::oneshot::Receiver<()>,
) -> impl Stream<Item=miette::Result<(Box<dyn Any + Send + 'static>, &MagRpcHandlerMapped)>> + Send
) -> impl Stream<Item = miette::Result<(MessageRaw, Arc<MagRpcHandlerMapped>)>> + Send + 'static
{
let decoders = self.payload_decoders.clone();
let listeners = self.listeners.clone();
async_stream::try_stream! {
let mut name_buf = Vec::new();
let mut buf = Vec::new();
let mut messages = 0usize;
loop {
let read_fut = async {
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize;
let mut header = [0u8; 1];
if buf_read.read(&mut header).await.into_diagnostic()? == 0 {
return if messages > 0 {
Ok(None)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Unexpected end of stream, expected a header"
)).into_diagnostic()
}
}
if !matches!(header, [b'M']) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Unexpected data in stream, expected a header"
)).into_diagnostic();
}
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize;
if name_len > name_buf.capacity() {
name_buf.reserve(name_len - name_buf.capacity());
}
@ -363,26 +428,26 @@ impl RpcCallDecoder {
buf_read.read_buf(&mut buf_write).await.into_diagnostic()?;
}
miette::Result::<_>::Ok((name_buf_write, buf_write))
miette::Result::<_>::Ok(Some((name_buf_write, buf_write)))
};
let (name_buf_write, payload) = select! {
let Some((name_buf_write, payload)) = select! {
read_result = read_fut => read_result,
_ = &mut cancel => { break; }
}?;
}? else {
break;
};
let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?;
let decoder = self
.payload_decoders
let decoder = decoders
.get(name)
.ok_or_else(|| miette!("No such RPC call name: {}", name))?
.as_ref();
let listener = self
.listeners
let listener = listeners
.get(name)
.ok_or_else(|| miette!("No such RPC call name: {}", name))?
.as_ref();
.clone();
let packet = match decoder(payload.filled()) {
Ok(p) => p,
@ -393,6 +458,7 @@ impl RpcCallDecoder {
};
yield (packet, listener);
messages += 1;
}
}
}