diff --git a/Cargo.lock b/Cargo.lock index d210d42..f5ab938 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "Inflector" @@ -4198,6 +4198,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 068f26f..7c266c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,7 +113,7 @@ headers = { workspace = true } hyper = { workspace = true, features = ["full"] } reqwest = { workspace = true, features = ["hickory-dns"] } tokio = { workspace = true, features = ["full"] } -tokio-stream = { workspace = true } +tokio-stream = { workspace = true, features = ["full"] } tower = { workspace = true } tower-http = { workspace = true, features = ["cors", "trace", "fs"] } ulid = { workspace = true } diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs index 41c99ec..45b33d1 100644 --- a/src/rpc_v1/proto.rs +++ b/src/rpc_v1/proto.rs @@ -1,6 +1,9 @@ use crate::service::MagnetarService; use either::Either; -use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use futures::{FutureExt, Stream, TryFutureExt, TryStreamExt}; +use futures_util::future::BoxFuture; +use futures_util::stream::FuturesUnordered; +use futures_util::{pin_mut, StreamExt}; use miette::{miette, IntoDiagnostic}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -12,7 +15,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, UnixSocket}; use tokio::select; use tokio::task::JoinSet; @@ -224,39 +227,14 @@ impl MagRpc { let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (read_half, write_half) = stream.into_split(); - let buf_read = BufReader::new(read_half); - let mut buf_write = BufWriter::new(write_half); - let context = context.clone(); - let rx_dec = rx_dec.clone(); - let fut = async move { - let src = rx_dec - .stream_decode(buf_read, cancel_recv) - .map_ok(process(context)) - .try_buffer_unordered(400) - .boxed(); - - futures::pin_mut!(src); - - while let Some(result) = src.try_next().await? { - let Some((serial, RpcResponse(bytes))) = result else { - continue; - }; - - buf_write.write_u8(b'M').await.into_diagnostic()?; - buf_write.write_u64(serial).await.into_diagnostic()?; - buf_write - .write_u32(bytes.len() as u32) - .await - .into_diagnostic()?; - buf_write.write_all(&bytes).await.into_diagnostic()?; - buf_write.flush().await.into_diagnostic()?; - } - - Ok(remote_addr) - } + let handler_fut = handle_process( + rx_dec.stream_decode(BufReader::new(read_half), cancel_recv), + BufWriter::new(write_half), + context.clone(), + ) .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); - connections.spawn(fut); + connections.spawn(handler_fut); cancellation_tokens.push(cancel_send); } @@ -315,39 +293,14 @@ impl MagRpc { let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (read_half, write_half) = stream.into_split(); - let buf_read = BufReader::new(read_half); - let mut buf_write = BufWriter::new(write_half); - let context = context.clone(); - let rx_dec = rx_dec.clone(); - let fut = async move { - let src = rx_dec - .stream_decode(buf_read, cancel_recv) - .map_ok(process(context)) - .try_buffer_unordered(400) - .boxed(); - - futures::pin_mut!(src); - - while let Some(result) = src.try_next().await? { - let Some((serial, RpcResponse(bytes))) = result else { - continue; - }; - - buf_write.write_u8(b'M').await.into_diagnostic()?; - buf_write.write_u64(serial).await.into_diagnostic()?; - buf_write - .write_u32(bytes.len() as u32) - .await - .into_diagnostic()?; - buf_write.write_all(&bytes).await.into_diagnostic()?; - buf_write.flush().await.into_diagnostic()?; - } - - miette::Result::<()>::Ok(()) - } + let handler_fut = handle_process( + rx_dec.stream_decode(BufReader::new(read_half), cancel_recv), + BufWriter::new(write_half), + context.clone(), + ) .instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr)); - connections.spawn(fut.boxed()); + connections.spawn(handler_fut); cancellation_tokens.push(cancel_send); } @@ -365,16 +318,66 @@ impl MagRpc { } } +async fn write_response( + mut buf_write: Pin<&mut BufWriter>, + serial: u64, + result: Option, +) -> miette::Result<()> { + let header = if result.is_some() { b'M' } else { b'F' }; + buf_write.write_u8(header).await.into_diagnostic()?; + buf_write.write_u64(serial).await.into_diagnostic()?; + + let Some(RpcResponse(bytes)) = result else { + return Ok(()); + }; + + buf_write + .write_u32(bytes.len() as u32) + .await + .into_diagnostic()?; + buf_write.write_all(&bytes).await.into_diagnostic()?; + buf_write.flush().await.into_diagnostic()?; + Ok(()) +} + +async fn handle_process( + task_stream: impl Stream)>> + Send + 'static, + mut buf_write: BufWriter, + context: Arc, +) -> miette::Result<()> { + let results = FuturesUnordered::new(); + pin_mut!(results); + + let src = task_stream.map_ok(process(context)); + pin_mut!(src); + + loop { + select!( + Some(task) = src.next() => { + results.push(task?); + } + Some(res) = results.next() => { + let (serial, result) = res?; + write_response(Pin::new(&mut buf_write), serial, result) + .await?; + } + else => { + break; + } + ); + } + + Ok(()) +} + fn process( context: Arc, ) -> impl Fn( (u64, MessageRaw, Arc), -) -> Pin< - Box>> + Send + 'static>, -> { +) -> BoxFuture<'static, miette::Result<(u64, Option)>> { move |(serial, payload, listener)| { let ctx = context.clone(); - tokio::task::spawn(async move { Some((serial, listener(ctx, payload).await?)) }) + tokio::task::spawn(async move { (serial, listener(ctx, payload).await) }) .map_err(|e| miette!(e)) .boxed() }