From 9c42b20fa9d9c335fa384036ad886c8c6ab03589 Mon Sep 17 00:00:00 2001 From: Natty Date: Tue, 12 Nov 2024 22:37:18 +0100 Subject: [PATCH] Implemented rudimentary RPC --- Cargo.lock | 58 ++++- Cargo.toml | 8 +- config/default.toml | 15 +- magnetar_common/src/config.rs | 55 ++++- magnetar_runtime/Cargo.toml | 21 ++ magnetar_runtime/src/lib.rs | 1 + src/main.rs | 58 ++++- src/rpc_v1/mod.rs | 1 + src/rpc_v1/proto.rs | 399 ++++++++++++++++++++++++++++++++++ 9 files changed, 602 insertions(+), 14 deletions(-) create mode 100644 magnetar_runtime/Cargo.toml create mode 100644 magnetar_runtime/src/lib.rs create mode 100644 src/rpc_v1/mod.rs create mode 100644 src/rpc_v1/proto.rs diff --git a/Cargo.lock b/Cargo.lock index 7e9bcc6..4335a93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1144,8 +1144,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1769,6 +1771,7 @@ dependencies = [ "async-stream", "axum", "axum-extra", + "bytes", "cached", "cfg-if", "chrono", @@ -1779,7 +1782,6 @@ dependencies = [ "futures-util", "headers", "hyper", - "idna 1.0.2", "itertools", "lru", "magnetar_common", @@ -1788,12 +1790,13 @@ dependencies = [ "magnetar_host_meta", "magnetar_model", "magnetar_nodeinfo", + "magnetar_runtime", "magnetar_sdk", "magnetar_webfinger", "miette", "percent-encoding", "quick-xml", - "regex", + "rmp-serde", "serde", "serde_json", "serde_urlencoded", @@ -1806,6 +1809,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "ulid", "unicode-segmentation", "url", ] @@ -1929,6 +1933,7 @@ dependencies = [ "nom_locate", "quick-xml", "serde", + "smallvec", "strum", "tracing", "unicode-segmentation", @@ -1966,6 +1971,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "magnetar_runtime" +version = "0.3.0-alpha" +dependencies = [ + "either", + "futures-channel", + "futures-core", + "futures-util", + "itertools", + "magnetar_core", + "magnetar_sdk", + "miette", + "thiserror", + "tracing", +] + [[package]] name = "magnetar_sdk" version = "0.3.0-alpha" @@ -2822,6 +2843,28 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rsa" version = "0.9.6" @@ -4215,6 +4258,17 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" +[[package]] +name = "ulid" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04f903f293d11f31c0c29e4148f6dc0d033a7f80cebc0282bea147611667d289" +dependencies = [ + "getrandom", + "rand", + "web-time", +] + [[package]] name = "unic-char-property" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index e14158d..09c1788 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "ext_model", "fe_calckey", "magnetar_common", + "magnetar_runtime", "magnetar_sdk", "magnetar_mmm_parser", "core", @@ -30,6 +31,7 @@ async-stream = "0.3" axum = "0.7" axum-extra = "0.9" base64 = "0.22" +bytes = "1.7" cached = "0.53" cfg-if = "1" chrono = "0.4" @@ -58,6 +60,7 @@ priority-queue = "2.0" quick-xml = "0.36" redis = "0.26" regex = "1.9" +rmp-serde = "1.3" rsa = "0.9" reqwest = "0.12" sea-orm = "1" @@ -92,6 +95,7 @@ magnetar_host_meta = { path = "./ext_host_meta" } magnetar_webfinger = { path = "./ext_webfinger" } magnetar_nodeinfo = { path = "./ext_nodeinfo" } magnetar_model = { path = "./ext_model" } +magnetar_runtime = { path = "./magnetar_runtime" } magnetar_sdk = { path = "./magnetar_sdk" } cached = { workspace = true } @@ -111,9 +115,6 @@ tower = { workspace = true } tower-http = { workspace = true, features = ["cors", "trace", "fs"] } ulid = { workspace = true } url = { workspace = true } -idna = { workspace = true } - -regex = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } tracing = { workspace = true } @@ -132,6 +133,7 @@ thiserror = { workspace = true } percent-encoding = { workspace = true } +rmp-serde = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } serde_urlencoded = { workspace = true } diff --git a/config/default.toml b/config/default.toml index a855fba..aa409c2 100644 --- a/config/default.toml +++ b/config/default.toml @@ -59,6 +59,20 @@ # Environment variable: MAG_C_PROXY_REMOTE_FILES # networking.proxy_remote_files = false +# ------------------------------[ RPC CONNECTION ]----------------------------- + +# [Optional] +# A type of connection to use for the application's internal RPC +# Possible values: "none", "tcp", "unix" +# Default: "none" +# Environment variable: MAG_C_RPC_CONNECTION_TYPE +# rpc.connection_type = "none" + +# [Optional] +# The corresponding bind address (or path for Unix-domain sockets) for the internal RPC +# Default: "" +# Environment variable: MAG_C_RPC_BIND_ADDR +# rpc.bind_addr = "" # -----------------------------[ CALCKEY FRONTEND ]---------------------------- @@ -83,7 +97,6 @@ # -------------------------------[ FEDERATION ]-------------------------------- - # --------------------------------[ BRANDING ]--------------------------------- # [Optional] diff --git a/magnetar_common/src/config.rs b/magnetar_common/src/config.rs index 02ec918..09a9f6e 100644 --- a/magnetar_common/src/config.rs +++ b/magnetar_common/src/config.rs @@ -1,6 +1,7 @@ use serde::Deserialize; use std::fmt::{Display, Formatter}; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; +use std::path::PathBuf; use thiserror::Error; #[derive(Deserialize, Debug)] @@ -94,6 +95,56 @@ impl Default for MagnetarNetworking { } } +#[derive(Deserialize, Debug, Default, Clone)] +#[serde( + rename_all = "snake_case", + tag = "connection_type", + content = "bind_addr" +)] +pub enum MagnetarRpcSocketKind { + #[default] + None, + Unix(PathBuf), + Tcp(SocketAddr), +} + +#[derive(Deserialize, Debug)] +#[non_exhaustive] +pub struct MagnetarRpcConfig { + pub connection_settings: MagnetarRpcSocketKind, +} + +fn env_rpc_connection() -> MagnetarRpcSocketKind { + match std::env::var("MAG_C_RPC_CONNECTION_TYPE") + .unwrap_or_else(|_| "none".to_owned()) + .to_lowercase() + .as_str() + { + "none" => MagnetarRpcSocketKind::None, + "unix" => MagnetarRpcSocketKind::Unix( + std::env::var("MAG_C_RPC_BIND_ADDR") + .unwrap_or_default() + .parse() + .expect("MAG_C_RPC_BIND_ADDR must be a valid path"), + ), + "tcp" => MagnetarRpcSocketKind::Tcp( + std::env::var("MAG_C_RPC_BIND_ADDR") + .unwrap_or_default() + .parse() + .expect("MAG_C_RPC_BIND_ADDR must be a valid socket address"), + ), + _ => panic!("MAG_C_RPC_CONNECTION_TYPE must be a valid protocol or 'none'"), + } +} + +impl Default for MagnetarRpcConfig { + fn default() -> Self { + MagnetarRpcConfig { + connection_settings: env_rpc_connection(), + } + } +} + #[derive(Deserialize, Debug)] #[non_exhaustive] pub struct MagnetarCalckeyFrontendConfig { @@ -196,6 +247,8 @@ pub struct MagnetarConfig { #[serde(default)] pub networking: MagnetarNetworking, #[serde(default)] + pub rpc: MagnetarRpcConfig, + #[serde(default)] pub branding: MagnetarBranding, #[serde(default)] pub calckey_frontend: MagnetarCalckeyFrontendConfig, diff --git a/magnetar_runtime/Cargo.toml b/magnetar_runtime/Cargo.toml new file mode 100644 index 0000000..2507600 --- /dev/null +++ b/magnetar_runtime/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "magnetar_runtime" +version.workspace = true +edition.workspace = true + +[lib] +crate-type = ["rlib"] + +[dependencies] +magnetar_core = { path = "../core" } +magnetar_sdk = { path = "../magnetar_sdk" } + +either = { workspace = true } +futures-channel = { workspace = true } +futures-util = { workspace = true } +futures-core = { workspace = true } +itertools = { workspace = true } +miette = { workspace = true } +thiserror = { workspace = true } + +tracing = { workspace = true } diff --git a/magnetar_runtime/src/lib.rs b/magnetar_runtime/src/lib.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/magnetar_runtime/src/lib.rs @@ -0,0 +1 @@ + diff --git a/src/main.rs b/src/main.rs index 0327886..b006b9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod api_v1; pub mod host_meta; pub mod model; pub mod nodeinfo; +mod rpc_v1; pub mod service; pub mod util; pub mod web; @@ -14,16 +15,21 @@ use crate::service::MagnetarService; use axum::routing::get; use axum::Router; use dotenvy::dotenv; +use futures::{select, FutureExt}; +use magnetar_common::config::{MagnetarConfig, MagnetarRpcSocketKind}; use magnetar_model::{CacheConnectorConfig, CalckeyCache, CalckeyModel, ConnectorConfig}; use miette::{miette, IntoDiagnostic}; +use rpc_v1::proto::{MagRpc, RpcMessage, RpcSockAddr}; +use std::convert::Infallible; +use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::TcpListener; use tokio::signal; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; -use tracing::info; use tracing::log::error; +use tracing::{debug, info}; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -47,15 +53,15 @@ async fn main() -> miette::Result<()> { let db = CalckeyModel::new(ConnectorConfig { url: config.data.database_url.clone(), }) - .await - .into_diagnostic()?; + .await + .into_diagnostic()?; db.migrate().await.into_diagnostic()?; let redis = CalckeyCache::new(CacheConnectorConfig { url: config.data.redis_url.clone(), }) - .into_diagnostic()?; + .into_diagnostic()?; let service = Arc::new( MagnetarService::new(config, db.clone(), redis) @@ -63,10 +69,48 @@ async fn main() -> miette::Result<()> { .into_diagnostic()?, ); + let shutdown_signal = shutdown_signal().shared(); + + select! { + rpc_res = run_rpc(service.clone(), config, shutdown_signal.clone()).fuse() => rpc_res, + web_res = run_web(service, config, shutdown_signal).fuse() => web_res + } +} + +async fn run_rpc( + service: Arc, + config: &'static MagnetarConfig, + shutdown_signal: impl Future + Send + 'static, +) -> miette::Result<()> { + let rpc_bind_addr = match &config.rpc.connection_settings { + MagnetarRpcSocketKind::None => { + std::future::pending::().await; + unreachable!(); + } + MagnetarRpcSocketKind::Unix(path) => RpcSockAddr::Unix(path.clone()), + MagnetarRpcSocketKind::Tcp(ip) => RpcSockAddr::Ip(*ip), + }; + + let rpc = MagRpc::new().handle( + "/ping", + |_, RpcMessage(message): RpcMessage| async move { + debug!("Received RPC ping: {}", message); + RpcMessage("pong".to_owned()) + }, + ); + + rpc.run(service, rpc_bind_addr, Some(shutdown_signal)).await +} + +async fn run_web( + service: Arc, + config: &'static MagnetarConfig, + shutdown_signal: impl Future + Send + 'static, +) -> miette::Result<()> { let well_known_router = Router::new() .route( "/webfinger", - get(webfinger::handle_webfinger).with_state((config, db)), + get(webfinger::handle_webfinger).with_state((config, service.db.clone())), ) .route("/host-meta", get(handle_host_meta)) .route("/nodeinfo", get(handle_nodeinfo)); @@ -93,7 +137,7 @@ async fn main() -> miette::Result<()> { let listener = TcpListener::bind(addr).await.into_diagnostic()?; info!("Serving..."); axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(shutdown_signal()) + .with_graceful_shutdown(shutdown_signal) .await .map_err(|e| miette!("Error running server: {}", e)) } @@ -121,5 +165,5 @@ async fn shutdown_signal() { _ = terminate => {}, } - info!("Shutting down..."); + info!("Received a signal to shut down..."); } diff --git a/src/rpc_v1/mod.rs b/src/rpc_v1/mod.rs new file mode 100644 index 0000000..febacec --- /dev/null +++ b/src/rpc_v1/mod.rs @@ -0,0 +1 @@ +pub mod proto; diff --git a/src/rpc_v1/proto.rs b/src/rpc_v1/proto.rs new file mode 100644 index 0000000..2795bd8 --- /dev/null +++ b/src/rpc_v1/proto.rs @@ -0,0 +1,399 @@ +use crate::service::MagnetarService; +use bytes::BufMut; +use futures::{FutureExt, Stream, StreamExt}; +use miette::{miette, IntoDiagnostic}; +use serde::{Deserialize, Serialize}; +use std::any::Any; +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, BufReader, ReadBuf}; +use tokio::net::{TcpListener, UnixSocket}; +use tokio::select; +use tokio::task::JoinSet; +use tracing::{debug, error, Instrument}; + +#[derive(Debug, Clone)] +pub enum RpcSockAddr { + Ip(SocketAddr), + Unix(PathBuf), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[repr(transparent)] +pub struct RpcMessage(pub T); + +pub trait IntoRpcResponse: Send { + fn into_rpc_response(self) -> Option; +} + +pub struct RpcResponse(Vec); + +impl IntoRpcResponse for RpcMessage { + fn into_rpc_response(self) -> Option { + rmp_serde::to_vec(&self) + .inspect_err(|e| { + error!( + "Failed to serialize value of type {}: {}", + std::any::type_name::(), + e + ) + }) + .ok() + .map(RpcResponse) + } +} + +#[derive(Debug, Serialize)] +pub struct RpcResult { + success: bool, + data: T, +} + +impl IntoRpcResponse +for Result +{ + fn into_rpc_response(self) -> Option { + match self { + Ok(data) => RpcMessage(RpcResult { + success: true, + data, + }) + .into_rpc_response(), + Err(data) => RpcMessage(RpcResult { + success: false, + data, + }) + .into_rpc_response(), + } + } +} + +pub trait RpcHandler: Send + Sync + 'static +where + T: Send + 'static, +{ + fn process( + &self, + context: Arc, + message: RpcMessage, + ) -> impl Future> + Send; +} + +impl RpcHandler for F +where + T: Send + 'static, + F: Fn(Arc, RpcMessage) -> Fut + Send + Sync + 'static, + Fut: Future + Send, + RR: IntoRpcResponse, +{ + async fn process( + &self, + context: Arc, + message: RpcMessage, + ) -> Option { + self(context, message).await.into_rpc_response() + } +} + +type MagRpcHandlerMapped = dyn Fn( + Arc, + Box, +) -> Pin> + Send + 'static>> ++ Send ++ Sync; + +type MagRpcDecoderMapped = +dyn Fn(&[u8]) -> Result, rmp_serde::decode::Error> + Send + Sync; + +pub struct MagRpc { + listeners: HashMap>, + payload_decoders: HashMap>, +} + +impl MagRpc { + pub fn new() -> Self { + MagRpc { + listeners: HashMap::new(), + payload_decoders: HashMap::new(), + } + } + + pub fn handle(mut self, method: impl Into, handler: H) -> Self + where + T: Send + 'static, + H: RpcHandler + Sync + 'static, + { + let handler_ref = Arc::new(handler); + self.listeners.insert( + method.into(), + Box::new(move |ctx, data| { + let handler = handler_ref.clone(); + async move { + handler + .process(ctx, RpcMessage(*data.downcast().unwrap())) + .await + } + .boxed() + }), + ); + + self + } + + pub async fn run( + self, + context: Arc, + addr: RpcSockAddr, + graceful_shutdown: Option>, + ) -> miette::Result<()> { + match addr { + RpcSockAddr::Ip(sock_addr) => { + self.run_tcp(context, &sock_addr, graceful_shutdown).await + } + RpcSockAddr::Unix(path) => self.run_unix(context, &path, graceful_shutdown).await, + } + } + + async fn run_tcp( + self, + context: Arc, + sock_addr: &SocketAddr, + graceful_shutdown: Option>, + ) -> miette::Result<()> { + debug!("Binding RPC socket to {}", sock_addr); + let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?; + + let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>(); + let mut cancellation_tokens = Vec::new(); + + let rx_dec = RpcCallDecoder { + listeners: Arc::new(self.listeners), + payload_decoders: Arc::new(self.payload_decoders), + }; + + let mut connections = JoinSet::new(); + + loop { + let (stream, sock_addr) = select!( + _ = connections.join_next() => continue, + conn = listener.accept() => { + if let Err(e) = conn { + error!("Connection error: {}", e); + break + } + + conn.unwrap() + }, + _ = &mut cancel => break + ); + + debug!("RPC TCP connection accepted: {:?}", sock_addr); + + let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); + let buf_read = BufReader::new(stream); + let context = context.clone(); + let rx_dec = rx_dec.clone(); + let fut = async move { + rx_dec + .stream_decode(buf_read, cancel_recv) + .filter_map(|r| async move { + if let Err(e) = &r { + error!("Stream decoding error: {e}"); + } + + r.ok() + }) + .for_each_concurrent(Some(32), |(payload, listener)| async { + if let Some(response) = listener(context.clone(), payload).await { + // TODO: Respond + } + }) + .await; + + miette::Result::<()>::Ok(()) + } + .instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr)); + + connections.spawn_local(fut); + + cancellation_tokens.push(cancel_send); + } + + if let Some(graceful_shutdown) = graceful_shutdown { + graceful_shutdown.await; + sender.send(()).ok(); + } + + connections.join_all().await; + + Ok(()) + } + + async fn run_unix( + self, + context: Arc, + addr: &Path, + graceful_shutdown: Option>, + ) -> miette::Result<()> { + let sock = UnixSocket::new_stream().into_diagnostic()?; + sock.bind(addr).into_diagnostic()?; + let listener = sock.listen(16).into_diagnostic()?; + + let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>(); + let mut cancellation_tokens = Vec::new(); + + let rx_dec = RpcCallDecoder { + listeners: Arc::new(self.listeners), + payload_decoders: Arc::new(self.payload_decoders), + }; + + let mut connections = JoinSet::new(); + + loop { + let (stream, sock_addr) = select!( + _ = connections.join_next() => continue, + conn = listener.accept() => { + if let Err(e) = conn { + error!("Connection error: {}", e); + break + } + + conn.unwrap() + }, + _ = &mut cancel => break + ); + + debug!("RPC Unix connection accepted: {:?}", sock_addr); + + let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>(); + let buf_read = BufReader::new(stream); + let context = context.clone(); + let rx_dec = rx_dec.clone(); + let fut = async move { + rx_dec + .stream_decode(buf_read, cancel_recv) + .filter_map(|r| async move { + if let Err(e) = &r { + error!("Stream decoding error: {e}"); + } + + r.ok() + }) + .for_each_concurrent(Some(32), |(payload, listener)| async { + if let Some(response) = listener(context.clone(), payload).await { + // TODO: Respond + } + }) + .await; + + miette::Result::<()>::Ok(()) + } + .instrument(tracing::info_span!("RPC", sock_addr = ?sock_addr)); + + connections.spawn_local(fut); + + cancellation_tokens.push(cancel_send); + } + + if let Some(graceful_shutdown) = graceful_shutdown { + graceful_shutdown.await; + sender.send(()).ok(); + } + + connections.join_all().await; + + Ok(()) + } +} + +#[derive(Clone)] +struct RpcCallDecoder { + listeners: Arc>>, + payload_decoders: Arc>>, +} + +impl RpcCallDecoder { + fn stream_decode( + &self, + mut buf_read: BufReader, + mut cancel: tokio::sync::oneshot::Receiver<()>, + ) -> impl Stream, &MagRpcHandlerMapped)>> + Send + { + async_stream::try_stream! { + let mut name_buf = Vec::new(); + let mut buf = Vec::new(); + + loop { + let read_fut = async { + let name_len = buf_read.read_u32().await.into_diagnostic()? as usize; + + if name_len > name_buf.capacity() { + name_buf.reserve(name_len - name_buf.capacity()); + } + + // SAFETY: We use ReadBuf which expects uninit anyway + unsafe { + name_buf.set_len(name_len); + } + + let mut name_buf_write = ReadBuf::uninit(&mut name_buf); + while name_buf_write.has_remaining_mut() { + buf_read + .read_buf(&mut name_buf_write) + .await + .into_diagnostic()?; + } + + let payload_len = buf_read.read_u32().await.into_diagnostic()? as usize; + if payload_len > buf.capacity() { + buf.reserve(payload_len - buf.capacity()); + } + + // SAFETY: We use ReadBuf which expects uninit anyway + unsafe { + buf.set_len(payload_len); + } + + let mut buf_write = ReadBuf::uninit(&mut buf); + while buf_write.has_remaining_mut() { + buf_read.read_buf(&mut buf_write).await.into_diagnostic()?; + } + + miette::Result::<_>::Ok((name_buf_write, buf_write)) + }; + + let (name_buf_write, payload) = select! { + read_result = read_fut => read_result, + _ = &mut cancel => { break; } + }?; + + let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?; + + let decoder = self + .payload_decoders + .get(name) + .ok_or_else(|| miette!("No such RPC call name: {}", name))? + .as_ref(); + let listener = self + .listeners + .get(name) + .ok_or_else(|| miette!("No such RPC call name: {}", name))? + .as_ref(); + + let packet = match decoder(payload.filled()) { + Ok(p) => p, + Err(e) => { + error!("Failed to parse packet: {e}"); + continue; + } + }; + + yield (packet, listener); + } + } + } +}