diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs index 7a7d0b9..8c140fd 100644 --- a/src/rpc_v1/proto.rs +++ b/src/rpc_v1/proto.rs @@ -3,19 +3,20 @@ use either::Either; use futures::{FutureExt, Stream, TryFutureExt, TryStreamExt}; use futures_util::future::BoxFuture; use futures_util::stream::FuturesUnordered; -use futures_util::{pin_mut, StreamExt}; +use futures_util::{pin_mut, SinkExt, StreamExt}; use miette::{miette, IntoDiagnostic}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; +use std::future; use std::future::Future; use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, UnixSocket}; use tokio::select; use tokio::task::JoinSet; @@ -209,6 +210,8 @@ impl MagRpc { loop { let (stream, remote_addr) = select!( + biased; + _ = &mut cancel => break, Some(c) = connections.join_next() => { debug!("RPC TCP connection closed: {:?}", c); continue; @@ -221,7 +224,6 @@ impl MagRpc { conn.unwrap() }, - _ = &mut cancel => break ); debug!("RPC TCP connection accepted: {:?}", remote_addr); @@ -275,6 +277,8 @@ impl MagRpc { loop { let (stream, remote_addr) = select!( + biased; + _ = &mut cancel => break, Some(c) = connections.join_next() => { debug!("RPC Unix connection closed: {:?}", c); continue; @@ -286,8 +290,7 @@ impl MagRpc { } conn.unwrap() - }, - _ = &mut cancel => break + } ); debug!("RPC Unix connection accepted: {:?}", remote_addr); @@ -295,7 +298,7 @@ impl MagRpc { let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); let (read_half, write_half) = stream.into_split(); let handler_fut = handle_process( - rx_dec.stream_decode(BufReader::new(read_half), cancel_recv), + rx_dec.stream_decode(BufReader::with_capacity(64 * 1024, read_half), cancel_recv), BufWriter::new(write_half), context.clone(), ) @@ -349,13 +352,24 @@ async fn handle_process( let results = FuturesUnordered::new(); pin_mut!(results); - let src = task_stream.map_ok(process(context)); - pin_mut!(src); + let (tx, rx) = futures::channel::mpsc::unbounded(); + pin_mut!(rx); + let input_stream = tokio::spawn( + task_stream + .map_ok(process(context)) + .boxed() + .forward(tx.sink_map_err(|e| miette!(e))) + ); + pin_mut!(input_stream); loop { select!( - Some(task) = src.next() => { - results.push(task?); + biased; + _ = &mut input_stream => { + break; + } + Some(task) = rx.next() => { + results.push(task); } Some(res) = results.next() => { let (serial, result) = res?; @@ -382,6 +396,7 @@ fn process( let start = Instant::now(); let res = listener(ctx, payload).await; let took = start.elapsed(); + // TODO: Extract this into a config if took.as_secs_f64() > 50.0 { warn!("Handler took long: {} sec", took.as_secs_f64()); } @@ -400,9 +415,9 @@ struct RpcCallDecoder { } impl RpcCallDecoder { - fn stream_decode( + fn stream_decode( &self, - mut buf_read: BufReader, + mut buf_read: impl AsyncBufRead + Send + Unpin + 'static, mut cancel: tokio::sync::oneshot::Receiver<()>, ) -> impl Stream)>> + Send + 'static { @@ -465,8 +480,9 @@ impl RpcCallDecoder { }; let Some((serial, name_slice, payload_slice)) = select! { - read_result = read_fut => read_result, + biased; _ = &mut cancel => { break; } + read_result = read_fut => read_result, }? else { break; };