From 52cced9537248246ea4d24312f3268865d1f855d Mon Sep 17 00:00:00 2001 From: Natty Date: Sat, 16 Nov 2024 03:34:43 +0100 Subject: [PATCH] Allow concurrent RPC calls on the server --- src/rpc_v1/proto.rs | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs index aa708eb..abf011c 100644 --- a/src/rpc_v1/proto.rs +++ b/src/rpc_v1/proto.rs @@ -14,7 +14,7 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::net::{TcpListener, UnixSocket}; use tokio::select; -use tokio::task::JoinSet; +use tokio::task::{JoinError, JoinSet}; use tracing::{debug, error, info, warn, Instrument}; #[derive(Debug, Clone)] @@ -225,14 +225,17 @@ impl MagRpc { r.ok() }) - .filter_map(|(serial, payload, listener)| { - let ctx = context.clone(); - async move { Some((serial, listener(ctx, payload).await?)) } - }); + .map(process(context)) + .buffer_unordered(100) + .boxed(); futures::pin_mut!(src); - while let Some((serial, RpcResponse(bytes))) = src.next().await { + while let Some(result) = src.next().await { + let Ok(Some((serial, RpcResponse(bytes)))) = result else { + continue; + }; + write_half.write_u8(b'M').await.into_diagnostic()?; write_half.write_u64(serial).await.into_diagnostic()?; write_half @@ -319,14 +322,17 @@ impl MagRpc { r.ok() }) - .filter_map(|(serial, payload, listener)| { - let ctx = context.clone(); - async move { Some((serial, listener(ctx, payload).await?)) } - }); + .map(process(context)) + .buffer_unordered(100) + .boxed(); futures::pin_mut!(src); - while let Some((serial, RpcResponse(bytes))) = src.next().await { + while let Some(result) = src.next().await { + let Ok(Some((serial, RpcResponse(bytes)))) = result else { + continue; + }; + write_half.write_u8(b'M').await.into_diagnostic()?; write_half.write_u64(serial).await.into_diagnostic()?; write_half @@ -359,6 +365,19 @@ impl MagRpc { } } +fn process( + context: Arc, +) -> impl Fn( + (u64, MessageRaw, Arc), +) -> Pin< + Box, JoinError>> + Send + 'static>, +> { + move |(serial, payload, listener)| { + let ctx = context.clone(); + tokio::task::spawn(async move { Some((serial, listener(ctx, payload).await?)) }).boxed() + } +} + #[derive(Clone)] struct RpcCallDecoder { listeners: Arc>>,