From 2acc41587a64bc13d21d12872a673348d55ce509 Mon Sep 17 00:00:00 2001 From: Natty Date: Sat, 16 Nov 2024 19:29:46 +0100 Subject: [PATCH] Drop RPC connections receiving bogus data --- src/rpc_v1/proto.rs | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs index 22c1cf9..1ca166b 100644 --- a/src/rpc_v1/proto.rs +++ b/src/rpc_v1/proto.rs @@ -1,6 +1,6 @@ use crate::service::MagnetarService; -use futures::{FutureExt, Stream, StreamExt}; -use miette::{miette, IntoDiagnostic}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use miette::{miette, Error, IntoDiagnostic}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::any::Any; @@ -218,21 +218,14 @@ impl MagRpc { let fut = async move { let src = rx_dec .stream_decode(buf_read, cancel_recv) - .filter_map(|r| async move { - if let Err(e) = &r { - error!("Stream decoding error: {e}"); - } - - r.ok() - }) - .map(process(context)) - .buffer_unordered(100) + .map_ok(process(context)) + .try_buffer_unordered(100) .boxed(); futures::pin_mut!(src); - while let Some(result) = src.next().await { - let Ok(Some((serial, RpcResponse(bytes)))) = result else { + while let Some(result) = src.try_next().await? { + let Some((serial, RpcResponse(bytes))) = result else { continue; }; @@ -315,21 +308,14 @@ impl MagRpc { let fut = async move { let src = rx_dec .stream_decode(buf_read, cancel_recv) - .filter_map(|r| async move { - if let Err(e) = &r { - error!("Stream decoding error: {e}"); - } - - r.ok() - }) - .map(process(context)) - .buffer_unordered(100) + .map_ok(process(context)) + .try_buffer_unordered(100) .boxed(); futures::pin_mut!(src); - while let Some(result) = src.next().await { - let Ok(Some((serial, RpcResponse(bytes)))) = result else { + while let Some(result) = src.try_next().await? { + let Some((serial, RpcResponse(bytes))) = result else { continue; }; @@ -370,11 +356,13 @@ fn process( ) -> impl Fn( (u64, MessageRaw, Arc), ) -> Pin< - Box, JoinError>> + Send + 'static>, + Box>> + Send + 'static>, > { move |(serial, payload, listener)| { let ctx = context.clone(); - tokio::task::spawn(async move { Some((serial, listener(ctx, payload).await?)) }).boxed() + tokio::task::spawn(async move { Some((serial, listener(ctx, payload).await?)) }) + .map_err(|e| miette!(e)) + .boxed() } }