Fixed RPC and implemented responses
This commit is contained in:
parent
b9160305f1
commit
80c5bf8ae6
|
@ -1,7 +1,9 @@
|
|||
use crate::service::MagnetarService;
|
||||
use bytes::BufMut;
|
||||
use futures::{FutureExt, Stream, StreamExt};
|
||||
use futures::{sink, FutureExt, Stream, StreamExt};
|
||||
use futures_util::SinkExt;
|
||||
use miette::{miette, IntoDiagnostic};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::any::Any;
|
||||
use std::collections::HashMap;
|
||||
|
@ -10,11 +12,11 @@ use std::net::SocketAddr;
|
|||
use std::path::{Path, PathBuf};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
|
||||
use tokio::net::{TcpListener, UnixSocket};
|
||||
use tokio::select;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{debug, error, Instrument};
|
||||
use tracing::{debug, error, info, Instrument};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RpcSockAddr {
|
||||
|
@ -30,6 +32,7 @@ pub trait IntoRpcResponse: Send {
|
|||
fn into_rpc_response(self) -> Option<RpcResponse>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RpcResponse(Vec<u8>);
|
||||
|
||||
impl<T: Serialize + Send + 'static> IntoRpcResponse for RpcMessage<T> {
|
||||
|
@ -54,7 +57,7 @@ pub struct RpcResult<T> {
|
|||
}
|
||||
|
||||
impl<T: Serialize + Send + 'static, E: Serialize + Send + 'static> IntoRpcResponse
|
||||
for Result<T, E>
|
||||
for Result<T, E>
|
||||
{
|
||||
fn into_rpc_response(self) -> Option<RpcResponse> {
|
||||
match self {
|
||||
|
@ -62,12 +65,12 @@ for Result<T, E>
|
|||
success: true,
|
||||
data,
|
||||
})
|
||||
.into_rpc_response(),
|
||||
.into_rpc_response(),
|
||||
Err(data) => RpcMessage(RpcResult {
|
||||
success: false,
|
||||
data,
|
||||
})
|
||||
.into_rpc_response(),
|
||||
.into_rpc_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -80,14 +83,14 @@ where
|
|||
&self,
|
||||
context: Arc<MagnetarService>,
|
||||
message: RpcMessage<T>,
|
||||
) -> impl Future<Output=Option<RpcResponse>> + Send;
|
||||
) -> impl Future<Output = Option<RpcResponse>> + Send;
|
||||
}
|
||||
|
||||
impl<T, F, Fut, RR> RpcHandler<T> for F
|
||||
where
|
||||
T: Send + 'static,
|
||||
F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output=RR> + Send,
|
||||
Fut: Future<Output = RR> + Send,
|
||||
RR: IntoRpcResponse,
|
||||
{
|
||||
async fn process(
|
||||
|
@ -99,18 +102,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
type MessageRaw = Box<dyn Any + Send + 'static>;
|
||||
|
||||
type MagRpcHandlerMapped = dyn Fn(
|
||||
Arc<MagnetarService>,
|
||||
Box<dyn Any + Send + 'static>,
|
||||
) -> Pin<Box<dyn Future<Output=Option<RpcResponse>> + Send + 'static>>
|
||||
+ Send
|
||||
+ Sync;
|
||||
Arc<MagnetarService>,
|
||||
MessageRaw,
|
||||
) -> Pin<Box<dyn Future<Output = Option<RpcResponse>> + Send + 'static>>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static;
|
||||
|
||||
type MagRpcDecoderMapped =
|
||||
dyn Fn(&[u8]) -> Result<Box<dyn Any + Send + 'static>, rmp_serde::decode::Error> + Send + Sync;
|
||||
dyn (Fn(&'_ [u8]) -> Result<MessageRaw, rmp_serde::decode::Error>) + Send + Sync + 'static;
|
||||
|
||||
pub struct MagRpc {
|
||||
listeners: HashMap<String, Box<MagRpcHandlerMapped>>,
|
||||
listeners: HashMap<String, Arc<MagRpcHandlerMapped>>,
|
||||
payload_decoders: HashMap<String, Box<MagRpcDecoderMapped>>,
|
||||
}
|
||||
|
||||
|
@ -124,22 +130,27 @@ impl MagRpc {
|
|||
|
||||
pub fn handle<H, T>(mut self, method: impl Into<String>, handler: H) -> Self
|
||||
where
|
||||
T: Send + 'static,
|
||||
T: DeserializeOwned + Send + 'static,
|
||||
H: RpcHandler<T> + Sync + 'static,
|
||||
{
|
||||
let handler_ref = Arc::new(handler);
|
||||
let method = method.into();
|
||||
self.listeners.insert(
|
||||
method.into(),
|
||||
Box::new(move |ctx, data| {
|
||||
method.clone(),
|
||||
Arc::new(move |ctx, data| {
|
||||
let handler = handler_ref.clone();
|
||||
async move {
|
||||
handler
|
||||
.process(ctx, RpcMessage(*data.downcast().unwrap()))
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
.boxed()
|
||||
}),
|
||||
);
|
||||
self.payload_decoders.insert(
|
||||
method,
|
||||
Box::new(move |data| Ok(Box::new(rmp_serde::from_slice::<'_, T>(data)?))),
|
||||
);
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -148,7 +159,7 @@ impl MagRpc {
|
|||
self,
|
||||
context: Arc<MagnetarService>,
|
||||
addr: RpcSockAddr,
|
||||
graceful_shutdown: Option<impl Future<Output=()>>,
|
||||
graceful_shutdown: Option<impl Future<Output = ()>>,
|
||||
) -> miette::Result<()> {
|
||||
match addr {
|
||||
RpcSockAddr::Ip(sock_addr) => {
|
||||
|
@ -162,10 +173,11 @@ impl MagRpc {
|
|||
self,
|
||||
context: Arc<MagnetarService>,
|
||||
sock_addr: &SocketAddr,
|
||||
graceful_shutdown: Option<impl Future<Output=()>>,
|
||||
graceful_shutdown: Option<impl Future<Output = ()>>,
|
||||
) -> miette::Result<()> {
|
||||
debug!("Binding RPC socket to {}", sock_addr);
|
||||
debug!("Binding RPC TCP socket to {}", sock_addr);
|
||||
let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?;
|
||||
info!("Listening for RPC calls on {}", sock_addr);
|
||||
|
||||
let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>();
|
||||
let mut cancellation_tokens = Vec::new();
|
||||
|
@ -175,11 +187,14 @@ impl MagRpc {
|
|||
payload_decoders: Arc::new(self.payload_decoders),
|
||||
};
|
||||
|
||||
let mut connections = JoinSet::new();
|
||||
let mut connections = JoinSet::<miette::Result<_>>::new();
|
||||
|
||||
loop {
|
||||
let (stream, sock_addr) = select!(
|
||||
_ = connections.join_next() => continue,
|
||||
let (stream, remote_addr) = select!(
|
||||
Some(c) = connections.join_next() => {
|
||||
debug!("RPC TCP connection closed: {:?}", c);
|
||||
continue;
|
||||
},
|
||||
conn = listener.accept() => {
|
||||
if let Err(e) = conn {
|
||||
error!("Connection error: {}", e);
|
||||
|
@ -191,14 +206,15 @@ impl MagRpc {
|
|||
_ = &mut cancel => break
|
||||
);
|
||||
|
||||
debug!("RPC TCP connection accepted: {:?}", sock_addr);
|
||||
debug!("RPC TCP connection accepted: {:?}", remote_addr);
|
||||
|
||||
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
|
||||
let buf_read = BufReader::new(stream);
|
||||
let (read_half, mut write_half) = stream.into_split();
|
||||
let buf_read = BufReader::new(read_half);
|
||||
let context = context.clone();
|
||||
let rx_dec = rx_dec.clone();
|
||||
let fut = async move {
|
||||
rx_dec
|
||||
let src = rx_dec
|
||||
.stream_decode(buf_read, cancel_recv)
|
||||
.filter_map(|r| async move {
|
||||
if let Err(e) = &r {
|
||||
|
@ -207,18 +223,27 @@ impl MagRpc {
|
|||
|
||||
r.ok()
|
||||
})
|
||||
.for_each_concurrent(Some(32), |(payload, listener)| async {
|
||||
if let Some(response) = listener(context.clone(), payload).await {
|
||||
// TODO: Respond
|
||||
}
|
||||
})
|
||||
.await;
|
||||
.filter_map(|(payload, listener)| {
|
||||
let ctx = context.clone();
|
||||
async move { listener(ctx, payload).await }
|
||||
});
|
||||
|
||||
miette::Result::<()>::Ok(())
|
||||
futures::pin_mut!(src);
|
||||
|
||||
while let Some(RpcResponse(bytes)) = src.next().await {
|
||||
write_half
|
||||
.write_u32(bytes.len() as u32)
|
||||
.await
|
||||
.into_diagnostic()?;
|
||||
write_half.write_all(&bytes).await.into_diagnostic()?;
|
||||
write_half.flush().await.into_diagnostic()?;
|
||||
}
|
||||
|
||||
Ok(remote_addr)
|
||||
}
|
||||
.instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr));
|
||||
.instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
|
||||
|
||||
connections.spawn_local(fut);
|
||||
connections.spawn(fut);
|
||||
|
||||
cancellation_tokens.push(cancel_send);
|
||||
}
|
||||
|
@ -228,6 +253,8 @@ impl MagRpc {
|
|||
sender.send(()).ok();
|
||||
}
|
||||
|
||||
info!("Awaiting shutdown of all RPC connections...");
|
||||
|
||||
connections.join_all().await;
|
||||
|
||||
Ok(())
|
||||
|
@ -237,9 +264,10 @@ impl MagRpc {
|
|||
self,
|
||||
context: Arc<MagnetarService>,
|
||||
addr: &Path,
|
||||
graceful_shutdown: Option<impl Future<Output=()>>,
|
||||
graceful_shutdown: Option<impl Future<Output = ()>>,
|
||||
) -> miette::Result<()> {
|
||||
let sock = UnixSocket::new_stream().into_diagnostic()?;
|
||||
debug!("Binding RPC Unix socket to {}", addr.display());
|
||||
sock.bind(addr).into_diagnostic()?;
|
||||
let listener = sock.listen(16).into_diagnostic()?;
|
||||
|
||||
|
@ -251,11 +279,14 @@ impl MagRpc {
|
|||
payload_decoders: Arc::new(self.payload_decoders),
|
||||
};
|
||||
|
||||
let mut connections = JoinSet::new();
|
||||
let mut connections = JoinSet::<miette::Result<_>>::new();
|
||||
|
||||
loop {
|
||||
let (stream, sock_addr) = select!(
|
||||
_ = connections.join_next() => continue,
|
||||
let (stream, remote_addr) = select!(
|
||||
Some(c) = connections.join_next() => {
|
||||
debug!("RPC Unix connection closed: {:?}", c);
|
||||
continue;
|
||||
},
|
||||
conn = listener.accept() => {
|
||||
if let Err(e) = conn {
|
||||
error!("Connection error: {}", e);
|
||||
|
@ -267,14 +298,15 @@ impl MagRpc {
|
|||
_ = &mut cancel => break
|
||||
);
|
||||
|
||||
debug!("RPC Unix connection accepted: {:?}", sock_addr);
|
||||
debug!("RPC Unix connection accepted: {:?}", remote_addr);
|
||||
|
||||
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
|
||||
let buf_read = BufReader::new(stream);
|
||||
let (read_half, mut write_half) = stream.into_split();
|
||||
let buf_read = BufReader::new(read_half);
|
||||
let context = context.clone();
|
||||
let rx_dec = rx_dec.clone();
|
||||
let fut = async move {
|
||||
rx_dec
|
||||
let src = rx_dec
|
||||
.stream_decode(buf_read, cancel_recv)
|
||||
.filter_map(|r| async move {
|
||||
if let Err(e) = &r {
|
||||
|
@ -283,18 +315,27 @@ impl MagRpc {
|
|||
|
||||
r.ok()
|
||||
})
|
||||
.for_each_concurrent(Some(32), |(payload, listener)| async {
|
||||
if let Some(response) = listener(context.clone(), payload).await {
|
||||
// TODO: Respond
|
||||
}
|
||||
})
|
||||
.await;
|
||||
.filter_map(|(payload, listener)| {
|
||||
let ctx = context.clone();
|
||||
async move { listener(ctx, payload).await }
|
||||
});
|
||||
|
||||
futures::pin_mut!(src);
|
||||
|
||||
while let Some(RpcResponse(bytes)) = src.next().await {
|
||||
write_half
|
||||
.write_u32(bytes.len() as u32)
|
||||
.await
|
||||
.into_diagnostic()?;
|
||||
write_half.write_all(&bytes).await.into_diagnostic()?;
|
||||
write_half.flush().await.into_diagnostic()?;
|
||||
}
|
||||
|
||||
miette::Result::<()>::Ok(())
|
||||
}
|
||||
.instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr));
|
||||
.instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
|
||||
|
||||
connections.spawn_local(fut);
|
||||
connections.spawn(fut.boxed());
|
||||
|
||||
cancellation_tokens.push(cancel_send);
|
||||
}
|
||||
|
@ -304,6 +345,8 @@ impl MagRpc {
|
|||
sender.send(()).ok();
|
||||
}
|
||||
|
||||
info!("Awaiting shutdown of all RPC connections...");
|
||||
|
||||
connections.join_all().await;
|
||||
|
||||
Ok(())
|
||||
|
@ -312,25 +355,47 @@ impl MagRpc {
|
|||
|
||||
#[derive(Clone)]
|
||||
struct RpcCallDecoder {
|
||||
listeners: Arc<HashMap<String, Box<MagRpcHandlerMapped>>>,
|
||||
listeners: Arc<HashMap<String, Arc<MagRpcHandlerMapped>>>,
|
||||
payload_decoders: Arc<HashMap<String, Box<MagRpcDecoderMapped>>>,
|
||||
}
|
||||
|
||||
impl RpcCallDecoder {
|
||||
fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send>(
|
||||
fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send + 'static>(
|
||||
&self,
|
||||
mut buf_read: BufReader<R>,
|
||||
mut cancel: tokio::sync::oneshot::Receiver<()>,
|
||||
) -> impl Stream<Item=miette::Result<(Box<dyn Any + Send + 'static>, &MagRpcHandlerMapped)>> + Send
|
||||
) -> impl Stream<Item = miette::Result<(MessageRaw, Arc<MagRpcHandlerMapped>)>> + Send + 'static
|
||||
{
|
||||
let decoders = self.payload_decoders.clone();
|
||||
let listeners = self.listeners.clone();
|
||||
|
||||
async_stream::try_stream! {
|
||||
let mut name_buf = Vec::new();
|
||||
let mut buf = Vec::new();
|
||||
let mut messages = 0usize;
|
||||
|
||||
loop {
|
||||
let read_fut = async {
|
||||
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize;
|
||||
let mut header = [0u8; 1];
|
||||
if buf_read.read(&mut header).await.into_diagnostic()? == 0 {
|
||||
return if messages > 0 {
|
||||
Ok(None)
|
||||
} else {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"Unexpected end of stream, expected a header"
|
||||
)).into_diagnostic()
|
||||
}
|
||||
}
|
||||
|
||||
if !matches!(header, [b'M']) {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"Unexpected data in stream, expected a header"
|
||||
)).into_diagnostic();
|
||||
}
|
||||
|
||||
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize;
|
||||
if name_len > name_buf.capacity() {
|
||||
name_buf.reserve(name_len - name_buf.capacity());
|
||||
}
|
||||
|
@ -363,26 +428,26 @@ impl RpcCallDecoder {
|
|||
buf_read.read_buf(&mut buf_write).await.into_diagnostic()?;
|
||||
}
|
||||
|
||||
miette::Result::<_>::Ok((name_buf_write, buf_write))
|
||||
miette::Result::<_>::Ok(Some((name_buf_write, buf_write)))
|
||||
};
|
||||
|
||||
let (name_buf_write, payload) = select! {
|
||||
let Some((name_buf_write, payload)) = select! {
|
||||
read_result = read_fut => read_result,
|
||||
_ = &mut cancel => { break; }
|
||||
}?;
|
||||
}? else {
|
||||
break;
|
||||
};
|
||||
|
||||
let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?;
|
||||
|
||||
let decoder = self
|
||||
.payload_decoders
|
||||
let decoder = decoders
|
||||
.get(name)
|
||||
.ok_or_else(|| miette!("No such RPC call name: {}", name))?
|
||||
.as_ref();
|
||||
let listener = self
|
||||
.listeners
|
||||
let listener = listeners
|
||||
.get(name)
|
||||
.ok_or_else(|| miette!("No such RPC call name: {}", name))?
|
||||
.as_ref();
|
||||
.clone();
|
||||
|
||||
let packet = match decoder(payload.filled()) {
|
||||
Ok(p) => p,
|
||||
|
@ -393,6 +458,7 @@ impl RpcCallDecoder {
|
|||
};
|
||||
|
||||
yield (packet, listener);
|
||||
messages += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue