Compare commits

..

10 Commits

Author SHA1 Message Date
Natty 5fb85e0db6
Implemented AP fetching via RPC
ci/woodpecker/push/ociImagePush Pipeline failed Details
2024-11-15 20:39:13 +01:00
Natty 80c5bf8ae6
Fixed RPC and implemented responses 2024-11-14 17:33:33 +01:00
Natty b9160305f1
Streamlined error handling 2024-11-13 13:41:06 +01:00
Natty 9c42b20fa9
Implemented rudimentary RPC 2024-11-12 22:37:18 +01:00
Natty 69c126d860
Code cleanup 2024-11-12 22:28:55 +01:00
Natty 7581ecf331
Code cleanup 2024-11-12 22:22:32 +01:00
Natty 766fd8ea7d
Added a service for generating IDs 2024-11-12 22:20:57 +01:00
Natty 5241b18b0d
Also cache the local user key pair 2024-11-12 22:08:18 +01:00
Natty 88df8eca55
Implemented key parsing 2024-11-12 22:00:37 +01:00
Natty 62fc36ff03
Updated comment 2024-09-06 00:04:28 +02:00
32 changed files with 1414 additions and 385 deletions

103
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "Inflector"
@ -1144,8 +1144,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
@ -1693,6 +1695,17 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "kdl"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "062c875482ccb676fd40c804a40e3824d4464c18c364547456d1c8e8e951ae47"
dependencies = [
"miette 5.10.0",
"nom",
"thiserror",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
@ -1769,6 +1782,7 @@ dependencies = [
"async-stream",
"axum",
"axum-extra",
"bytes",
"cached",
"cfg-if",
"chrono",
@ -1779,8 +1793,8 @@ dependencies = [
"futures-util",
"headers",
"hyper",
"idna 1.0.2",
"itertools",
"kdl",
"lru",
"magnetar_common",
"magnetar_core",
@ -1788,12 +1802,13 @@ dependencies = [
"magnetar_host_meta",
"magnetar_model",
"magnetar_nodeinfo",
"magnetar_runtime",
"magnetar_sdk",
"magnetar_webfinger",
"miette",
"miette 7.2.0",
"percent-encoding",
"quick-xml",
"regex",
"rmp-serde",
"serde",
"serde_json",
"serde_urlencoded",
@ -1806,6 +1821,7 @@ dependencies = [
"tower-http",
"tracing",
"tracing-subscriber",
"ulid",
"unicode-segmentation",
"url",
]
@ -1835,7 +1851,7 @@ dependencies = [
"headers",
"hyper",
"magnetar_common",
"miette",
"miette 7.2.0",
"percent-encoding",
"serde",
"serde_json",
@ -1893,7 +1909,7 @@ dependencies = [
"magnetar_core",
"magnetar_host_meta",
"magnetar_webfinger",
"miette",
"miette 7.2.0",
"percent-encoding",
"quick-xml",
"reqwest",
@ -1929,6 +1945,7 @@ dependencies = [
"nom_locate",
"quick-xml",
"serde",
"smallvec",
"strum",
"tracing",
"unicode-segmentation",
@ -1966,6 +1983,22 @@ dependencies = [
"serde_json",
]
[[package]]
name = "magnetar_runtime"
version = "0.3.0-alpha"
dependencies = [
"either",
"futures-channel",
"futures-core",
"futures-util",
"itertools",
"magnetar_core",
"magnetar_sdk",
"miette 7.2.0",
"thiserror",
"tracing",
]
[[package]]
name = "magnetar_sdk"
version = "0.3.0-alpha"
@ -2030,6 +2063,18 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "miette"
version = "5.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59bb584eaeeab6bd0226ccf3509a69d7936d148cf3d036ad350abe35e8c6856e"
dependencies = [
"miette-derive 5.10.0",
"once_cell",
"thiserror",
"unicode-width",
]
[[package]]
name = "miette"
version = "7.2.0"
@ -2039,7 +2084,7 @@ dependencies = [
"backtrace",
"backtrace-ext",
"cfg-if",
"miette-derive",
"miette-derive 7.2.0",
"owo-colors",
"supports-color",
"supports-hyperlinks",
@ -2050,6 +2095,17 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "miette-derive"
version = "5.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49e7bc1560b95a3c4a25d03de42fe76ca718ab92d1a22a55b9b4cf67b3ae635c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
]
[[package]]
name = "miette-derive"
version = "7.2.0"
@ -2822,6 +2878,28 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "rmp"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
dependencies = [
"byteorder",
"num-traits",
"paste",
]
[[package]]
name = "rmp-serde"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
dependencies = [
"byteorder",
"rmp",
"serde",
]
[[package]]
name = "rsa"
version = "0.9.6"
@ -4215,6 +4293,17 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9"
[[package]]
name = "ulid"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04f903f293d11f31c0c29e4148f6dc0d033a7f80cebc0282bea147611667d289"
dependencies = [
"getrandom",
"rand",
"web-time",
]
[[package]]
name = "unic-char-property"
version = "0.9.0"

View File

@ -15,6 +15,7 @@ members = [
"ext_model",
"fe_calckey",
"magnetar_common",
"magnetar_runtime",
"magnetar_sdk",
"magnetar_mmm_parser",
"core",
@ -30,6 +31,7 @@ async-stream = "0.3"
axum = "0.7"
axum-extra = "0.9"
base64 = "0.22"
bytes = "1.7"
cached = "0.53"
cfg-if = "1"
chrono = "0.4"
@ -49,6 +51,7 @@ hyper = "1.1"
idna = "1"
indexmap = "2.2"
itertools = "0.13"
kdl = "4"
lru = "0.12"
miette = "7"
nom = "7"
@ -58,6 +61,7 @@ priority-queue = "2.0"
quick-xml = "0.36"
redis = "0.26"
regex = "1.9"
rmp-serde = "1.3"
rsa = "0.9"
reqwest = "0.12"
sea-orm = "1"
@ -79,6 +83,7 @@ tower-http = "0.5"
tracing = "0.1"
tracing-subscriber = "0.3"
ts-rs = "7"
ulid = "1"
unicode-segmentation = "1.10"
url = "2.3"
walkdir = "2.3"
@ -91,6 +96,7 @@ magnetar_host_meta = { path = "./ext_host_meta" }
magnetar_webfinger = { path = "./ext_webfinger" }
magnetar_nodeinfo = { path = "./ext_nodeinfo" }
magnetar_model = { path = "./ext_model" }
magnetar_runtime = { path = "./magnetar_runtime" }
magnetar_sdk = { path = "./magnetar_sdk" }
cached = { workspace = true }
@ -108,16 +114,15 @@ tokio = { workspace = true, features = ["full"] }
tokio-stream = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true, features = ["cors", "trace", "fs"] }
ulid = { workspace = true }
url = { workspace = true }
idna = { workspace = true }
regex = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
tracing = { workspace = true }
cfg-if = { workspace = true }
bytes = { workspace = true }
compact_str = { workspace = true }
either = { workspace = true }
futures = { workspace = true }
@ -129,6 +134,8 @@ thiserror = { workspace = true }
percent-encoding = { workspace = true }
kdl = { workspace = true }
rmp-serde = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
serde_urlencoded = { workspace = true }

1
config/.gitignore vendored
View File

@ -1,3 +1,4 @@
*
!.gitignore
!default.toml
!default-vars.kdl

34
config/default-vars.kdl Normal file
View File

@ -0,0 +1,34 @@
cache {
local-user-cache {
// Size is unlimited
lifetime 300min
}
emoji-cache {
size 4096
}
remote-instance-cache {
size 256
lifetime 100s
}
drive-file-cache {
size 128
lifetime 10s
}
}
api-model {
note {
buffer 10
}
notification {
buffer 10
}
}
activity-pub {
user-agent "magnetar/$version ($host)"
}

View File

@ -59,6 +59,20 @@
# Environment variable: MAG_C_PROXY_REMOTE_FILES
# networking.proxy_remote_files = false
# ------------------------------[ RPC CONNECTION ]-----------------------------
# [Optional]
# A type of connection to use for the application's internal RPC
# Possible values: "none", "tcp", "unix"
# Default: "none"
# Environment variable: MAG_C_RPC_CONNECTION_TYPE
# rpc.connection_type = "none"
# [Optional]
# The corresponding bind address (or path for Unix-domain sockets) for the internal RPC
# Default: ""
# Environment variable: MAG_C_RPC_BIND_ADDR
# rpc.bind_addr = ""
# -----------------------------[ CALCKEY FRONTEND ]----------------------------
@ -83,7 +97,6 @@
# -------------------------------[ FEDERATION ]--------------------------------
# --------------------------------[ BRANDING ]---------------------------------
# [Optional]

View File

@ -11,13 +11,21 @@ use url::Url;
use magnetar_core::web_model::content_type::ContentActivityStreams;
use crate::{
ApClientService,
ApSignature,
ApSigningField, ApSigningHeaders, client::federation_client::{FederationClient, FederationClientError}, crypto::{ApSigningError, ApSigningKey, SigningAlgorithm}, SigningInput, SigningParts,
client::federation_client::{FederationClient, FederationClientError},
crypto::{ApSigningError, ApSigningKey, SigningAlgorithm},
ApClientService, ApSignature, ApSigningField, ApSigningHeaders, SigningInput, SigningParts,
};
pub struct ApClientServiceDefaultProvider {
client: Arc<FederationClient>,
client: Arc<dyn AsRef<FederationClient> + Send + Sync>,
}
impl ApClientServiceDefaultProvider {
pub fn new(client: impl AsRef<FederationClient> + Send + Sync + 'static) -> Self {
Self {
client: Arc::new(client),
}
}
}
impl Display for ApSignature {
@ -237,7 +245,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
&self,
signing_key: ApSigningKey<'_>,
signing_algorithm: SigningAlgorithm,
request: impl SigningInput,
request: &dyn SigningInput,
) -> Result<ApSignature, Self::Error> {
let components = request.create_signing_input();
@ -277,7 +285,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
SigningAlgorithm::RsaSha256 => self.sign_request(
signing_key,
signing_algorithm,
SigningInputGetRsaSha256 {
&SigningInputGetRsaSha256 {
request_target: RequestTarget {
url: &url,
method: Method::GET,
@ -290,7 +298,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
SigningAlgorithm::Hs2019 => self.sign_request(
signing_key,
signing_algorithm,
SigningInputGetHs2019 {
&SigningInputGetHs2019 {
request_target: RequestTarget {
url: &url,
method: Method::GET,
@ -318,6 +326,8 @@ impl ApClientService for ApClientServiceDefaultProvider {
Ok(self
.client
.as_ref()
.as_ref()
.get(url)
.accept(ContentActivityStreams)
.headers(headers)
@ -345,7 +355,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
SigningAlgorithm::RsaSha256 => self.sign_request(
signing_key,
signing_algorithm,
SigningInputPostRsaSha256 {
&SigningInputPostRsaSha256 {
request_target: RequestTarget {
url: &url,
method: Method::POST,
@ -359,7 +369,7 @@ impl ApClientService for ApClientServiceDefaultProvider {
SigningAlgorithm::Hs2019 => self.sign_request(
signing_key,
signing_algorithm,
SigningInputPostHs2019 {
&SigningInputPostHs2019 {
request_target: RequestTarget {
url: &url,
method: Method::POST,
@ -394,6 +404,8 @@ impl ApClientService for ApClientServiceDefaultProvider {
Ok(self
.client
.as_ref()
.as_ref()
.builder(Method::POST, url)
.accept(ContentActivityStreams)
.content_type(ContentActivityStreams)
@ -414,9 +426,9 @@ mod test {
use crate::{
ap_client::ApClientServiceDefaultProvider,
ApClientService,
client::federation_client::FederationClient,
crypto::{ApHttpPrivateKey, SigningAlgorithm},
ApClientService,
};
#[tokio::test]
@ -427,20 +439,20 @@ mod test {
let rsa_key = rsa::RsaPrivateKey::from_pkcs8_pem(key.trim()).into_diagnostic()?;
let ap_client = ApClientServiceDefaultProvider {
client: Arc::new(
client: Arc::new(Box::new(
FederationClient::new(
true,
128_000,
25,
UserAgent::from_static("magnetar/0.42 (https://astolfo.social)"),
)
.into_diagnostic()?,
),
.into_diagnostic()?,
)),
};
let val = ap_client
.signed_get(
ApHttpPrivateKey::Rsa(Box::new(Cow::Owned(rsa_key)))
ApHttpPrivateKey::Rsa(Cow::Owned(Box::new(rsa_key)))
.create_signing_key(&key_id, SigningAlgorithm::RsaSha256)
.into_diagnostic()?,
SigningAlgorithm::RsaSha256,

View File

@ -1,12 +1,16 @@
use std::{borrow::Cow, fmt::Display};
use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs1::DecodeRsaPublicKey;
use rsa::pkcs8::DecodePrivateKey;
use rsa::pkcs8::DecodePublicKey;
use rsa::signature::Verifier;
use rsa::{
sha2::{Sha256, Sha512},
signature::Signer,
};
use serde::{Deserialize, Serialize};
use std::fmt::Formatter;
use std::str::FromStr;
use std::{borrow::Cow, fmt::Display};
use strum::AsRefStr;
use thiserror::Error;
@ -35,7 +39,7 @@ pub enum SigningAlgorithm {
}
impl Display for SigningAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Hs2019 => write!(f, "hs2019"),
Self::RsaSha256 => write!(f, "rsa-sha256"),
@ -61,6 +65,51 @@ pub enum ApHttpPublicKey<'a> {
Ed25519(Cow<'a, ed25519_dalek::VerifyingKey>),
}
#[derive(Debug, Copy, Clone, Error)]
#[error("Failed to parse the public key: No available parser could decode the PEM string")]
pub struct ApHttpPublicKeyParseError;
impl FromStr for ApHttpPublicKey<'_> {
type Err = ApHttpPublicKeyParseError;
fn from_str(input_pem: &str) -> Result<Self, Self::Err> {
let pem = input_pem.trim();
let parse_pkcs1_rsa: &dyn Fn(_) -> _ = &|p| {
Some(ApHttpPublicKey::Rsa(Cow::Owned(
rsa::RsaPublicKey::from_pkcs1_pem(p).ok()?,
)))
};
let parse_spki_rsa: &dyn Fn(_) -> _ = &|p| {
Some(ApHttpPublicKey::Rsa(Cow::Owned(
rsa::RsaPublicKey::from_public_key_pem(p).ok()?,
)))
};
let parse_spki_ed25519: &dyn Fn(_) -> _ = &|p| {
Some(ApHttpPublicKey::Ed25519(Cow::Owned(
ed25519_dalek::VerifyingKey::from_public_key_pem(p).ok()?,
)))
};
// Some heuristics
let parsers: &[_] = match pem {
p if p.starts_with("-----BEGIN PUBLIC KEY-----") => {
&[parse_spki_rsa, parse_spki_ed25519]
}
p if p.starts_with("-----BEGIN RSA PUBLIC KEY-----") => &[parse_pkcs1_rsa],
_ => &[parse_spki_rsa, parse_spki_ed25519, parse_pkcs1_rsa],
};
for parser in parsers {
if let Some(k) = parser(pem) {
return Ok(k);
}
}
Err(ApHttpPublicKeyParseError)
}
}
impl ApHttpVerificationKey<'_> {
pub fn verify(&self, message: &[u8], signature: &[u8]) -> Result<(), ApVerificationError> {
match self {
@ -109,12 +158,10 @@ impl ApHttpPublicKey<'_> {
));
Ok(verification_key.verify(message, signature)?)
}
(_, SigningAlgorithm::RsaSha256) => {
return Err(ApVerificationError::KeyAlgorithmMismatch(
algorithm,
self.as_ref().to_owned(),
));
}
(_, SigningAlgorithm::RsaSha256) => Err(ApVerificationError::KeyAlgorithmMismatch(
algorithm,
self.as_ref().to_owned(),
)),
(Self::Ed25519(key), SigningAlgorithm::Hs2019) => {
let verification_key = ApHttpVerificationKey::Ed25519(Cow::Borrowed(key.as_ref()));
Ok(verification_key.verify(message, signature)?)
@ -126,17 +173,64 @@ impl ApHttpPublicKey<'_> {
#[derive(Debug, Clone, AsRefStr)]
pub enum ApHttpPrivateKey<'a> {
#[strum(serialize = "rsa")]
Rsa(Box<Cow<'a, rsa::RsaPrivateKey>>),
Rsa(Cow<'a, Box<rsa::RsaPrivateKey>>),
#[strum(serialize = "ed25519")]
Ed25519(Cow<'a, ed25519_dalek::SecretKey>),
}
#[derive(Debug, Copy, Clone, Error)]
#[error("Failed to parse the private key: No available parser could decode the PEM string")]
pub struct ApHttpPrivateKeyParseError;
impl FromStr for ApHttpPrivateKey<'_> {
type Err = ApHttpPrivateKeyParseError;
fn from_str(input_pem: &str) -> Result<Self, Self::Err> {
let pem = input_pem.trim();
let parse_pkcs1_rsa: &dyn Fn(_) -> _ = &|p| {
Some(ApHttpPrivateKey::Rsa(Cow::Owned(Box::new(
rsa::RsaPrivateKey::from_pkcs1_pem(p).ok()?,
))))
};
let parse_pkcs8_rsa: &dyn Fn(_) -> _ = &|p| {
Some(ApHttpPrivateKey::Rsa(Cow::Owned(Box::new(
rsa::RsaPrivateKey::from_pkcs8_pem(p).ok()?,
))))
};
let parse_pkcs8_ed25519: &dyn Fn(_) -> _ = &|p| {
Some(ApHttpPrivateKey::Ed25519(Cow::Owned(
ed25519_dalek::SigningKey::from_pkcs8_pem(p)
.ok()?
.to_bytes(),
)))
};
// Some heuristics
let parsers: &[_] = match pem {
p if p.contains("-----BEGIN PRIVATE KEY-----") => {
&[parse_pkcs8_rsa, parse_pkcs8_ed25519]
}
p if p.contains("-----BEGIN RSA PRIVATE KEY-----") => &[parse_pkcs1_rsa],
_ => &[parse_pkcs8_rsa, parse_pkcs8_ed25519, parse_pkcs1_rsa],
};
for parser in parsers {
if let Some(k) = parser(pem) {
return Ok(k);
}
}
Err(ApHttpPrivateKeyParseError)
}
}
#[derive(Debug, Clone, AsRefStr)]
pub enum ApHttpSigningKey<'a> {
#[strum(serialize = "rsa-sha256")]
RsaSha256(Cow<'a, rsa::pkcs1v15::SigningKey<rsa::sha2::Sha256>>),
RsaSha256(Cow<'a, rsa::pkcs1v15::SigningKey<Sha256>>),
#[strum(serialize = "rsa-sha512")]
RsaSha512(Cow<'a, rsa::pkcs1v15::SigningKey<rsa::sha2::Sha512>>),
RsaSha512(Cow<'a, rsa::pkcs1v15::SigningKey<Sha512>>),
#[strum(serialize = "ed25519")]
Ed25519(Cow<'a, ed25519_dalek::SigningKey>),
}
@ -192,7 +286,7 @@ impl ApHttpPrivateKey<'_> {
key: match (self, algorithm) {
(Self::Rsa(key), SigningAlgorithm::RsaSha256 | SigningAlgorithm::Hs2019) => {
ApHttpSigningKey::RsaSha256(Cow::Owned(rsa::pkcs1v15::SigningKey::new(
key.clone().into_owned(),
*key.as_ref().to_owned(),
)))
}
(Self::Ed25519(key), SigningAlgorithm::Hs2019) => ApHttpSigningKey::Ed25519(

View File

@ -154,7 +154,7 @@ pub trait ApClientService: Send + Sync {
&self,
signing_key: ApSigningKey<'_>,
signing_algorithm: SigningAlgorithm,
request: impl SigningInput,
request: &dyn SigningInput,
) -> Result<ApSignature, Self::Error>;
async fn signed_get(

View File

@ -1,23 +1,23 @@
use std::future::Future;
use chrono::Utc;
use futures_util::{SinkExt, StreamExt};
use futures_util::StreamExt;
use redis::IntoConnectionInfo;
pub use sea_orm;
use sea_orm::{ActiveValue::Set, ConnectionTrait};
use sea_orm::ActiveValue::Set;
use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait,
};
use serde::{Deserialize, Deserializer, Serialize};
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use strum::IntoStaticStr;
use thiserror::Error;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace, warn};
use tracing::log::LevelFilter;
use tracing::{error, info, trace, warn};
use url::Host;
pub use ck;
@ -122,13 +122,25 @@ impl CalckeyModel {
.await?)
}
pub async fn get_user_and_profile_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
Ok(user::Entity::find()
pub async fn get_user_for_cache_by_id(&self, id: &str) -> Result<Option<(user::Model, user_profile::Model, user_keypair::Model)>, CalckeyDbError> {
let txn = self.0.begin().await?;
let Some((user, Some(profile))) = user::Entity::find()
.filter(user::Column::Id.eq(id))
.find_also_related(user_profile::Entity)
.one(&self.0)
.await?
.and_then(|(u, p)| p.map(|pp| (u, pp))))
.one(&txn)
.await? else {
return Ok(None);
};
let Some(keys) = user_keypair::Entity::find()
.filter(user_keypair::Column::UserId.eq(id))
.one(&txn)
.await? else {
return Ok(None);
};
Ok(Some((user, profile, keys)))
}
pub async fn get_user_security_keys_by_id(
@ -165,20 +177,30 @@ impl CalckeyModel {
.await?)
}
pub async fn get_user_and_profile_by_token(
pub async fn get_user_for_cache_by_token(
&self,
token: &str,
) -> Result<Option<(user::Model, user_profile::Model)>, CalckeyDbError> {
Ok(user::Entity::find()
.filter(
user::Column::Token
.eq(token)
.and(user::Column::Host.is_null()),
)
) -> Result<Option<(user::Model, user_profile::Model, user_keypair::Model)>, CalckeyDbError> {
let txn = self.0.begin().await?;
let Some((user, Some(profile))) = user::Entity::find()
.filter(user::Column::Token
.eq(token)
.and(user::Column::Host.is_null()))
.find_also_related(user_profile::Entity)
.one(&self.0)
.await?
.and_then(|(u, p)| p.map(|pp| (u, pp))))
.one(&txn)
.await? else {
return Ok(None);
};
let Some(keys) = user_keypair::Entity::find()
.filter(user_keypair::Column::UserId.eq(&user.id))
.one(&txn)
.await? else {
return Ok(None);
};
Ok(Some((user, profile, keys)))
}
pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
@ -399,8 +421,8 @@ struct RawMessage<'a> {
impl<'de> Deserialize<'de> for SubMessage {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
where
D: Deserializer<'de>,
{
let raw = RawMessage::deserialize(deserializer)?;

View File

@ -11,7 +11,7 @@ use data::{sub_interaction_reaction, sub_interaction_renote, NoteData};
use ext_model_migration::SelectStatement;
use magnetar_sdk::types::SpanFilter;
use sea_orm::sea_query::{Asterisk, Expr, IntoIden, Query, SelectExpr, SimpleExpr};
use sea_orm::{ColumnTrait, Condition, EntityTrait, Iden, JoinType, QueryFilter, QueryOrder, QuerySelect, QueryTrait, Select, StatementBuilder};
use sea_orm::{ColumnTrait, Condition, EntityTrait, Iden, JoinType, QueryFilter, QueryOrder, QuerySelect, QueryTrait, Select};
use std::sync::Arc;
const PINS: &str = "pins.";
@ -45,7 +45,8 @@ impl NoteResolveMode {
match self {
NoteResolveMode::Single(id) => Ok(id_col.eq(id)),
NoteResolveMode::Multiple(ids) => Ok(id_col.is_in(ids)),
// We add a CTE for pins
// We do this in a separate query, because before we used an inner join, and it caused
// a massive performance penalty
NoteResolveMode::PinsFromUserId(user_id) => {
let cte_query = user_note_pining::Entity::find()
.column(user_note_pining::Column::NoteId)

View File

@ -1,6 +1,7 @@
use serde::Deserialize;
use std::fmt::{Display, Formatter};
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use thiserror::Error;
#[derive(Deserialize, Debug)]
@ -94,6 +95,56 @@ impl Default for MagnetarNetworking {
}
}
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(
rename_all = "snake_case",
tag = "connection_type",
content = "bind_addr"
)]
pub enum MagnetarRpcSocketKind {
#[default]
None,
Unix(PathBuf),
Tcp(SocketAddr),
}
#[derive(Deserialize, Debug)]
#[non_exhaustive]
pub struct MagnetarRpcConfig {
pub connection_settings: MagnetarRpcSocketKind,
}
fn env_rpc_connection() -> MagnetarRpcSocketKind {
match std::env::var("MAG_C_RPC_CONNECTION_TYPE")
.unwrap_or_else(|_| "none".to_owned())
.to_lowercase()
.as_str()
{
"none" => MagnetarRpcSocketKind::None,
"unix" => MagnetarRpcSocketKind::Unix(
std::env::var("MAG_C_RPC_BIND_ADDR")
.unwrap_or_default()
.parse()
.expect("MAG_C_RPC_BIND_ADDR must be a valid path"),
),
"tcp" => MagnetarRpcSocketKind::Tcp(
std::env::var("MAG_C_RPC_BIND_ADDR")
.unwrap_or_default()
.parse()
.expect("MAG_C_RPC_BIND_ADDR must be a valid socket address"),
),
_ => panic!("MAG_C_RPC_CONNECTION_TYPE must be a valid protocol or 'none'"),
}
}
impl Default for MagnetarRpcConfig {
fn default() -> Self {
MagnetarRpcConfig {
connection_settings: env_rpc_connection(),
}
}
}
#[derive(Deserialize, Debug)]
#[non_exhaustive]
pub struct MagnetarCalckeyFrontendConfig {
@ -196,6 +247,8 @@ pub struct MagnetarConfig {
#[serde(default)]
pub networking: MagnetarNetworking,
#[serde(default)]
pub rpc: MagnetarRpcConfig,
#[serde(default)]
pub branding: MagnetarBranding,
#[serde(default)]
pub calckey_frontend: MagnetarCalckeyFrontendConfig,

View File

@ -0,0 +1,21 @@
[package]
name = "magnetar_runtime"
version.workspace = true
edition.workspace = true
[lib]
crate-type = ["rlib"]
[dependencies]
magnetar_core = { path = "../core" }
magnetar_sdk = { path = "../magnetar_sdk" }
either = { workspace = true }
futures-channel = { workspace = true }
futures-util = { workspace = true }
futures-core = { workspace = true }
itertools = { workspace = true }
miette = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }

View File

@ -0,0 +1 @@

View File

@ -15,7 +15,7 @@ use serde::{Deserialize, Deserializer, Serialize};
use ts_rs::TS;
pub(crate) mod packed_time {
use chrono::{DateTime, NaiveDateTime, Utc};
use chrono::{DateTime, Utc};
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serializer};
@ -30,15 +30,12 @@ pub(crate) mod packed_time {
where
D: Deserializer<'de>,
{
Ok(DateTime::<Utc>::from_naive_utc_and_offset(
NaiveDateTime::from_timestamp_millis(
String::deserialize(deserializer)?
.parse::<i64>()
.map_err(Error::custom)?,
)
.ok_or_else(|| Error::custom("millisecond value out of range"))?,
Utc,
))
DateTime::<Utc>::from_timestamp_millis(
String::deserialize(deserializer)?
.parse::<i64>()
.map_err(Error::custom)?,
)
.ok_or_else(|| Error::custom("millisecond value out of range"))
}
}
@ -91,14 +88,14 @@ impl SpanFilter {
pub fn start(&self) -> Option<(DateTime<Utc>, String)> {
match self {
Self::Start(StartFilter {
time_start,
id_start,
})
time_start,
id_start,
})
| Self::Range(RangeFilter {
time_start,
id_start,
..
}) => Some((*time_start, id_start.clone())),
time_start,
id_start,
..
}) => Some((*time_start, id_start.clone())),
_ => None,
}
}
@ -107,8 +104,8 @@ impl SpanFilter {
match self {
Self::End(EndFilter { time_end, id_end })
| Self::Range(RangeFilter {
time_end, id_end, ..
}) => Some((*time_end, id_end.clone())),
time_end, id_end, ..
}) => Some((*time_end, id_end.clone())),
_ => None,
}
}
@ -145,10 +142,10 @@ impl SpanFilter {
})),
Self::Range(RangeFilter { time_start, .. }) if *time_start > last_date => None,
Self::Range(RangeFilter {
time_start,
id_start,
..
}) => Some(SpanFilter::Range(RangeFilter {
time_start,
id_start,
..
}) => Some(SpanFilter::Range(RangeFilter {
time_start: *time_start,
id_start: id_start.clone(),
time_end: last_date,

View File

@ -2,8 +2,10 @@ mod api_v1;
pub mod host_meta;
pub mod model;
pub mod nodeinfo;
mod rpc_v1;
pub mod service;
pub mod util;
pub mod vars;
pub mod web;
pub mod webfinger;
@ -14,8 +16,14 @@ use crate::service::MagnetarService;
use axum::routing::get;
use axum::Router;
use dotenvy::dotenv;
use futures::{select, FutureExt};
use magnetar_common::config::{MagnetarConfig, MagnetarRpcSocketKind};
use magnetar_model::{CacheConnectorConfig, CalckeyCache, CalckeyModel, ConnectorConfig};
use miette::{miette, IntoDiagnostic};
use rpc_v1::create_rpc_router;
use rpc_v1::proto::RpcSockAddr;
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
@ -57,16 +65,44 @@ async fn main() -> miette::Result<()> {
})
.into_diagnostic()?;
let service = Arc::new(
MagnetarService::new(config, db.clone(), redis)
.await
.into_diagnostic()?,
);
let service = Arc::new(MagnetarService::new(config, db.clone(), redis).await?);
let shutdown_signal = shutdown_signal().shared();
select! {
rpc_res = run_rpc(service.clone(), config, shutdown_signal.clone()).fuse() => rpc_res,
web_res = run_web(service, config, shutdown_signal).fuse() => web_res
}
}
async fn run_rpc(
service: Arc<MagnetarService>,
config: &'static MagnetarConfig,
shutdown_signal: impl Future<Output = ()> + Send + 'static,
) -> miette::Result<()> {
let rpc_bind_addr = match &config.rpc.connection_settings {
MagnetarRpcSocketKind::None => {
std::future::pending::<Infallible>().await;
unreachable!();
}
MagnetarRpcSocketKind::Unix(path) => RpcSockAddr::Unix(path.clone()),
MagnetarRpcSocketKind::Tcp(ip) => RpcSockAddr::Ip(*ip),
};
let rpc = create_rpc_router();
rpc.run(service, rpc_bind_addr, Some(shutdown_signal)).await
}
async fn run_web(
service: Arc<MagnetarService>,
config: &'static MagnetarConfig,
shutdown_signal: impl Future<Output = ()> + Send + 'static,
) -> miette::Result<()> {
let well_known_router = Router::new()
.route(
"/webfinger",
get(webfinger::handle_webfinger).with_state((config, db)),
get(webfinger::handle_webfinger).with_state((config, service.db.clone())),
)
.route("/host-meta", get(handle_host_meta))
.route("/nodeinfo", get(handle_nodeinfo));
@ -93,7 +129,7 @@ async fn main() -> miette::Result<()> {
let listener = TcpListener::bind(addr).await.into_diagnostic()?;
info!("Serving...");
axum::serve(listener, app.into_make_service())
.with_graceful_shutdown(shutdown_signal())
.with_graceful_shutdown(shutdown_signal)
.await
.map_err(|e| miette!("Error running server: {}", e))
}
@ -121,5 +157,5 @@ async fn shutdown_signal() {
_ = terminate => {},
}
info!("Shutting down...");
info!("Received a signal to shut down...");
}

View File

@ -5,6 +5,7 @@ use crate::service::instance_meta_cache::InstanceMetaCacheError;
use magnetar_model::sea_orm::DbErr;
use magnetar_model::CalckeyDbError;
use magnetar_sdk::mmm::Token;
use miette::Diagnostic;
use thiserror::Error;
pub mod drive;
@ -13,27 +14,37 @@ pub mod note;
pub mod notification;
pub mod user;
#[derive(Debug, Error, strum::IntoStaticStr)]
#[derive(Debug, Error, Diagnostic)]
pub enum PackError {
#[error("Database error: {0}")]
#[diagnostic(code(mag::pack_error::db_error))]
DbError(#[from] DbErr),
#[error("Calckey database wrapper error: {0}")]
#[diagnostic(code(mag::pack_error::db_wrapper_error))]
CalckeyDbError(#[from] CalckeyDbError),
#[error("Data error: {0}")]
#[diagnostic(code(mag::pack_error::data_error))]
DataError(String),
#[error("Emoji cache error: {0}")]
#[diagnostic(code(mag::pack_error::emoji_cache_error))]
EmojiCacheError(#[from] EmojiCacheError),
#[error("Instance cache error: {0}")]
#[diagnostic(code(mag::pack_error::instance_meta_cache_error))]
InstanceMetaCacheError(#[from] InstanceMetaCacheError),
#[error("Generic cache error: {0}")]
#[diagnostic(code(mag::pack_error::generic_id_cache_error))]
GenericCacheError(#[from] GenericIdCacheError),
#[error("Remote instance cache error: {0}")]
#[diagnostic(code(mag::pack_error::remote_instance_cache_error))]
RemoteInstanceCacheError(#[from] RemoteInstanceCacheError),
#[error("Deserializer error: {0}")]
#[diagnostic(code(mag::pack_error::deserializer_error))]
DeserializerError(#[from] serde_json::Error),
#[error("URL parse error: {0}")]
#[diagnostic(code(mag::pack_error::url_parse_error))]
UrlParseError(#[from] url::ParseError),
#[error("Parallel processing error: {0}")]
#[diagnostic(code(mag::pack_error::task_join_error))]
JoinError(#[from] tokio::task::JoinError),
}

View File

@ -7,7 +7,7 @@ use crate::model::{PackType, PackingContext};
use compact_str::CompactString;
use either::Either;
use futures_util::future::try_join_all;
use futures_util::{FutureExt, StreamExt, TryStreamExt};
use futures_util::{StreamExt, TryStreamExt};
use magnetar_model::ck::sea_orm_active_enums::NoteVisibilityEnum;
use magnetar_model::model_ext::AliasColumnExt;
use magnetar_model::note_model::data::{

98
src/rpc_v1/mod.rs Normal file
View File

@ -0,0 +1,98 @@
use std::sync::Arc;
use magnetar_federation::crypto::SigningAlgorithm;
use proto::{MagRpc, RpcMessage};
use serde::Deserialize;
use tracing::{debug, info};
use crate::{
model::{processing::note::NoteModel, PackingContext},
service::MagnetarService,
web::{ApiError, ObjectNotFound},
};
pub mod proto;
#[derive(Debug, Deserialize)]
struct RpcApGet {
user_id: String,
url: String,
}
#[derive(Debug, Deserialize)]
struct RpcApPost {
user_id: String,
url: String,
body: serde_json::Value,
}
pub fn create_rpc_router() -> MagRpc {
MagRpc::new()
.handle(
"/ping",
|_, RpcMessage(message): RpcMessage<String>| async move {
debug!("Received RPC ping: {}", message);
RpcMessage("pong".to_owned())
},
)
.handle(
"/note/by-id",
|service, RpcMessage(id): RpcMessage<String>| async move {
let ctx = PackingContext::new(service, None).await?;
let note = NoteModel {
attachments: true,
with_context: true,
}
.fetch_single(&ctx, &id)
.await?
.ok_or(ObjectNotFound(id))?;
Result::<_, ApiError>::Ok(note)
},
)
.handle(
"/ap/get",
|service: Arc<MagnetarService>,
RpcMessage(RpcApGet { user_id, url }): RpcMessage<RpcApGet>| async move {
let Some(user) = service.local_user_cache.get_by_id(&user_id).await? else {
return Err(ObjectNotFound(format!("LocalUserID:{user_id}")).into());
};
let key_id = format!(
"https://{}/users/{}#main-key",
service.config.networking.host, user_id
);
let signing_key = user
.private_key
.create_signing_key(&key_id, SigningAlgorithm::RsaSha256)?;
let result = service
.ap_client
.signed_get(signing_key, SigningAlgorithm::RsaSha256, None, &url)
.await?;
Result::<_, ApiError>::Ok(result)
},
)
.handle(
"/ap/post",
|service: Arc<MagnetarService>,
RpcMessage(RpcApPost { user_id, url, body }): RpcMessage<RpcApPost>| async move {
let Some(user) = service.local_user_cache.get_by_id(&user_id).await? else {
return Err(ObjectNotFound(format!("LocalUserID:{user_id}")).into());
};
let key_id = format!(
"https://{}/users/{}#main-key",
service.config.networking.host, user_id
);
let signing_key = user
.private_key
.create_signing_key(&key_id, SigningAlgorithm::RsaSha256)?;
let result = service
.ap_client
.signed_post(signing_key, SigningAlgorithm::RsaSha256, None, &url, &body)
.await?;
Result::<_, ApiError>::Ok(result)
},
)
}

470
src/rpc_v1/proto.rs Normal file
View File

@ -0,0 +1,470 @@
use crate::service::MagnetarService;
use bytes::BufMut;
use futures::{FutureExt, Stream, StreamExt};
use miette::{miette, IntoDiagnostic};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
use tokio::net::{TcpListener, UnixSocket};
use tokio::select;
use tokio::task::JoinSet;
use tracing::{debug, error, info, Instrument};
#[derive(Debug, Clone)]
pub enum RpcSockAddr {
Ip(SocketAddr),
Unix(PathBuf),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[repr(transparent)]
pub struct RpcMessage<T>(pub T);
pub trait IntoRpcResponse: Send {
fn into_rpc_response(self) -> Option<RpcResponse>;
}
#[derive(Debug, Clone)]
pub struct RpcResponse(Vec<u8>);
impl<T: Serialize + Send + 'static> IntoRpcResponse for RpcMessage<T> {
fn into_rpc_response(self) -> Option<RpcResponse> {
rmp_serde::to_vec_named(&self)
.inspect_err(|e| {
error!(
"Failed to serialize value of type {}: {}",
std::any::type_name::<T>(),
e
)
})
.ok()
.map(RpcResponse)
}
}
#[derive(Debug, Serialize)]
pub struct RpcResult<T> {
success: bool,
data: T,
}
impl<T: Serialize + Send + 'static, E: Serialize + Send + 'static> IntoRpcResponse
for Result<T, E>
{
fn into_rpc_response(self) -> Option<RpcResponse> {
match self {
Ok(data) => RpcMessage(RpcResult {
success: true,
data,
})
.into_rpc_response(),
Err(data) => RpcMessage(RpcResult {
success: false,
data,
})
.into_rpc_response(),
}
}
}
pub trait RpcHandler<T>: Send + Sync + 'static
where
T: Send + 'static,
{
fn process(
&self,
context: Arc<MagnetarService>,
message: RpcMessage<T>,
) -> impl Future<Output = Option<RpcResponse>> + Send;
}
impl<T, F, Fut, RR> RpcHandler<T> for F
where
T: Send + 'static,
F: Fn(Arc<MagnetarService>, RpcMessage<T>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = RR> + Send,
RR: IntoRpcResponse,
{
async fn process(
&self,
context: Arc<MagnetarService>,
message: RpcMessage<T>,
) -> Option<RpcResponse> {
self(context, message).await.into_rpc_response()
}
}
type MessageRaw = Box<dyn Any + Send + 'static>;
type MagRpcHandlerMapped = dyn Fn(
Arc<MagnetarService>,
MessageRaw,
) -> Pin<Box<dyn Future<Output = Option<RpcResponse>> + Send + 'static>>
+ Send
+ Sync
+ 'static;
type MagRpcDecoderMapped =
dyn (Fn(&'_ [u8]) -> Result<MessageRaw, rmp_serde::decode::Error>) + Send + Sync + 'static;
pub struct MagRpc {
listeners: HashMap<String, Arc<MagRpcHandlerMapped>>,
payload_decoders: HashMap<String, Box<MagRpcDecoderMapped>>,
}
impl MagRpc {
pub fn new() -> Self {
MagRpc {
listeners: HashMap::new(),
payload_decoders: HashMap::new(),
}
}
pub fn handle<H, T>(mut self, method: impl Into<String>, handler: H) -> Self
where
T: DeserializeOwned + Send + 'static,
H: RpcHandler<T> + Sync + 'static,
{
let handler_ref = Arc::new(handler);
let method = method.into();
self.listeners.insert(
method.clone(),
Arc::new(move |ctx, data| {
let handler = handler_ref.clone();
async move {
handler
.process(ctx, RpcMessage(*data.downcast().unwrap()))
.await
}
.boxed()
}),
);
self.payload_decoders.insert(
method,
Box::new(move |data| Ok(Box::new(rmp_serde::from_slice::<'_, T>(data)?))),
);
self
}
pub async fn run(
self,
context: Arc<MagnetarService>,
addr: RpcSockAddr,
graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> {
match addr {
RpcSockAddr::Ip(sock_addr) => {
self.run_tcp(context, &sock_addr, graceful_shutdown).await
}
RpcSockAddr::Unix(path) => self.run_unix(context, &path, graceful_shutdown).await,
}
}
async fn run_tcp(
self,
context: Arc<MagnetarService>,
sock_addr: &SocketAddr,
graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> {
debug!("Binding RPC TCP socket to {}", sock_addr);
let listener = TcpListener::bind(sock_addr).await.into_diagnostic()?;
info!("Listening for RPC calls on {}", sock_addr);
let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>();
let mut cancellation_tokens = Vec::new();
let rx_dec = RpcCallDecoder {
listeners: Arc::new(self.listeners),
payload_decoders: Arc::new(self.payload_decoders),
};
let mut connections = JoinSet::<miette::Result<_>>::new();
loop {
let (stream, remote_addr) = select!(
Some(c) = connections.join_next() => {
debug!("RPC TCP connection closed: {:?}", c);
continue;
},
conn = listener.accept() => {
if let Err(e) = conn {
error!("Connection error: {}", e);
break
}
conn.unwrap()
},
_ = &mut cancel => break
);
debug!("RPC TCP connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let (read_half, mut write_half) = stream.into_split();
let buf_read = BufReader::new(read_half);
let context = context.clone();
let rx_dec = rx_dec.clone();
let fut = async move {
let src = rx_dec
.stream_decode(buf_read, cancel_recv)
.filter_map(|r| async move {
if let Err(e) = &r {
error!("Stream decoding error: {e}");
}
r.ok()
})
.filter_map(|(serial, payload, listener)| {
let ctx = context.clone();
async move { Some((serial, listener(ctx, payload).await?)) }
});
futures::pin_mut!(src);
while let Some((serial, RpcResponse(bytes))) = src.next().await {
write_half.write_u8(b'M').await.into_diagnostic()?;
write_half.write_u64(serial).await.into_diagnostic()?;
write_half
.write_u32(bytes.len() as u32)
.await
.into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?;
}
Ok(remote_addr)
}
.instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
connections.spawn(fut);
cancellation_tokens.push(cancel_send);
}
if let Some(graceful_shutdown) = graceful_shutdown {
graceful_shutdown.await;
sender.send(()).ok();
}
info!("Awaiting shutdown of all RPC connections...");
connections.join_all().await;
Ok(())
}
async fn run_unix(
self,
context: Arc<MagnetarService>,
addr: &Path,
graceful_shutdown: Option<impl Future<Output = ()>>,
) -> miette::Result<()> {
let sock = UnixSocket::new_stream().into_diagnostic()?;
debug!("Binding RPC Unix socket to {}", addr.display());
sock.bind(addr).into_diagnostic()?;
let listener = sock.listen(16).into_diagnostic()?;
let (sender, mut cancel) = tokio::sync::oneshot::channel::<()>();
let mut cancellation_tokens = Vec::new();
let rx_dec = RpcCallDecoder {
listeners: Arc::new(self.listeners),
payload_decoders: Arc::new(self.payload_decoders),
};
let mut connections = JoinSet::<miette::Result<_>>::new();
loop {
let (stream, remote_addr) = select!(
Some(c) = connections.join_next() => {
debug!("RPC Unix connection closed: {:?}", c);
continue;
},
conn = listener.accept() => {
if let Err(e) = conn {
error!("Connection error: {}", e);
break
}
conn.unwrap()
},
_ = &mut cancel => break
);
debug!("RPC Unix connection accepted: {:?}", remote_addr);
let (cancel_send, cancel_recv) = tokio::sync::oneshot::channel::<()>();
let (read_half, mut write_half) = stream.into_split();
let buf_read = BufReader::new(read_half);
let context = context.clone();
let rx_dec = rx_dec.clone();
let fut = async move {
let src = rx_dec
.stream_decode(buf_read, cancel_recv)
.filter_map(|r| async move {
if let Err(e) = &r {
error!("Stream decoding error: {e}");
}
r.ok()
})
.filter_map(|(serial, payload, listener)| {
let ctx = context.clone();
async move { Some((serial, listener(ctx, payload).await?)) }
});
futures::pin_mut!(src);
while let Some((serial, RpcResponse(bytes))) = src.next().await {
write_half.write_u8(b'M').await.into_diagnostic()?;
write_half.write_u64(serial).await.into_diagnostic()?;
write_half
.write_u32(bytes.len() as u32)
.await
.into_diagnostic()?;
write_half.write_all(&bytes).await.into_diagnostic()?;
write_half.flush().await.into_diagnostic()?;
}
miette::Result::<()>::Ok(())
}
.instrument(tracing::info_span!("RPC", remote_addr = ?remote_addr));
connections.spawn(fut.boxed());
cancellation_tokens.push(cancel_send);
}
if let Some(graceful_shutdown) = graceful_shutdown {
graceful_shutdown.await;
sender.send(()).ok();
}
info!("Awaiting shutdown of all RPC connections...");
connections.join_all().await;
Ok(())
}
}
#[derive(Clone)]
struct RpcCallDecoder {
listeners: Arc<HashMap<String, Arc<MagRpcHandlerMapped>>>,
payload_decoders: Arc<HashMap<String, Box<MagRpcDecoderMapped>>>,
}
impl RpcCallDecoder {
fn stream_decode<R: AsyncRead + AsyncReadExt + Unpin + Send + 'static>(
&self,
mut buf_read: BufReader<R>,
mut cancel: tokio::sync::oneshot::Receiver<()>,
) -> impl Stream<Item = miette::Result<(u64, MessageRaw, Arc<MagRpcHandlerMapped>)>> + Send + 'static
{
let decoders = self.payload_decoders.clone();
let listeners = self.listeners.clone();
async_stream::try_stream! {
let mut name_buf = Vec::new();
let mut buf = Vec::new();
let mut messages = 0usize;
loop {
let read_fut = async {
let mut header = [0u8; 1];
if buf_read.read(&mut header).await.into_diagnostic()? == 0 {
return if messages > 0 {
Ok(None)
} else {
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Unexpected end of stream, expected a header"
)).into_diagnostic()
}
}
if !matches!(header, [b'M']) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Unexpected data in stream, expected a header"
)).into_diagnostic();
}
let serial = buf_read.read_u64().await.into_diagnostic()?;
let name_len = buf_read.read_u32().await.into_diagnostic()? as usize;
if name_len > name_buf.capacity() {
name_buf.reserve(name_len - name_buf.capacity());
}
// SAFETY: We use ReadBuf which expects uninit anyway
unsafe {
name_buf.set_len(name_len);
}
let mut name_buf_write = ReadBuf::uninit(&mut name_buf);
while name_buf_write.has_remaining_mut() {
buf_read
.read_buf(&mut name_buf_write)
.await
.into_diagnostic()?;
}
let payload_len = buf_read.read_u32().await.into_diagnostic()? as usize;
if payload_len > buf.capacity() {
buf.reserve(payload_len - buf.capacity());
}
// SAFETY: We use ReadBuf which expects uninit anyway
unsafe {
buf.set_len(payload_len);
}
let mut buf_write = ReadBuf::uninit(&mut buf);
while buf_write.has_remaining_mut() {
buf_read.read_buf(&mut buf_write).await.into_diagnostic()?;
}
miette::Result::<_>::Ok(Some((serial, name_buf_write, buf_write)))
};
let Some((serial, name_buf_write, payload)) = select! {
read_result = read_fut => read_result,
_ = &mut cancel => { break; }
}? else {
break;
};
let name = std::str::from_utf8(name_buf_write.filled()).into_diagnostic()?;
let decoder = decoders
.get(name)
.ok_or_else(|| miette!("No such RPC call name: {}", name))?
.as_ref();
let listener = listeners
.get(name)
.ok_or_else(|| miette!("No such RPC call name: {}", name))?
.clone();
let packet = match decoder(payload.filled()) {
Ok(p) => p,
Err(e) => {
error!("Failed to parse packet: {e}");
continue;
}
};
yield (serial, packet, listener);
messages += 1;
}
}
}
}

View File

@ -2,27 +2,23 @@ use crate::web::ApiError;
use lru::LruCache;
use magnetar_model::emoji::{EmojiResolver, EmojiTag};
use magnetar_model::{ck, CalckeyDbError, CalckeyModel};
use miette::Diagnostic;
use std::collections::HashSet;
use std::sync::Arc;
use strum::VariantNames;
use thiserror::Error;
use tokio::sync::Mutex;
#[derive(Debug, Error, VariantNames)]
#[derive(Debug, Error, Diagnostic)]
#[error("Emoji cache error: {}")]
pub enum EmojiCacheError {
#[error("Database error: {0}")]
#[diagnostic(code(mag::emoji_cache_error::db_error))]
DbError(#[from] CalckeyDbError),
}
impl From<EmojiCacheError> for ApiError {
fn from(err: EmojiCacheError) -> Self {
let mut api_error: ApiError = match err {
EmojiCacheError::DbError(err) => err.into(),
};
api_error.message = format!("Emoji cache error: {}", api_error.message);
api_error
Self::internal("Cache error", err)
}
}

View File

@ -0,0 +1,32 @@
use std::str::FromStr;
use headers::UserAgent;
use magnetar_common::config::MagnetarConfig;
use magnetar_federation::{
ap_client::{ApClientError, ApClientServiceDefaultProvider},
client::federation_client::FederationClient,
ApClientService,
};
use miette::IntoDiagnostic;
pub(super) fn new_federation_client_service(
config: &'static MagnetarConfig,
) -> miette::Result<FederationClient> {
FederationClient::new(
true,
256000,
20,
UserAgent::from_str(&format!(
"magnetar/{} (https://{})",
config.branding.version, config.networking.host
))
.into_diagnostic()?,
)
.into_diagnostic()
}
pub(super) fn new_ap_client_service(
federation_client: impl AsRef<FederationClient> + Send + Sync + 'static,
) -> impl ApClientService<Error = ApClientError> {
ApClientServiceDefaultProvider::new(federation_client)
}

29
src/service/gen_id.rs Normal file
View File

@ -0,0 +1,29 @@
use std::{sync::Arc, time::SystemTime};
use super::MagnetarService;
pub struct GenIdService;
impl GenIdService {
pub fn new_id(&self) -> ulid::Ulid {
ulid::Ulid::new()
}
pub fn new_id_str(&self) -> String {
self.new_id().to_string()
}
pub fn new_for_time(&self, time: impl Into<SystemTime>) -> ulid::Ulid {
ulid::Ulid::from_datetime(time.into())
}
pub fn new_str_for_time(&self, time: impl Into<SystemTime>) -> String {
self.new_for_time(time).to_string()
}
}
impl AsRef<GenIdService> for Arc<MagnetarService> {
fn as_ref(&self) -> &GenIdService {
&self.gen_id
}
}

View File

@ -2,14 +2,16 @@ use crate::web::ApiError;
use lru::LruCache;
use magnetar_model::sea_orm::{EntityTrait, PrimaryKeyTrait};
use magnetar_model::{CalckeyDbError, CalckeyModel};
use miette::Diagnostic;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::{Duration, Instant};
use strum::VariantNames;
use thiserror::Error;
use tokio::sync::Mutex;
#[derive(Debug, Error, VariantNames)]
#[derive(Debug, Error, Diagnostic)]
#[error("Generic ID cache error: {}")]
#[diagnostic(code(mag::generic_id_cache_error))]
pub enum GenericIdCacheError {
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
@ -17,13 +19,7 @@ pub enum GenericIdCacheError {
impl From<GenericIdCacheError> for ApiError {
fn from(err: GenericIdCacheError) -> Self {
let mut api_error: ApiError = match err {
GenericIdCacheError::DbError(err) => err.into(),
};
api_error.message = format!("Generic ID cache error: {}", api_error.message);
api_error
Self::internal("Cache error", err)
}
}

View File

@ -2,13 +2,15 @@ use crate::web::ApiError;
use lru::LruCache;
use magnetar_common::config::MagnetarConfig;
use magnetar_model::{ck, CalckeyDbError, CalckeyModel};
use miette::Diagnostic;
use std::sync::Arc;
use std::time::{Duration, Instant};
use strum::VariantNames;
use thiserror::Error;
use tokio::sync::Mutex;
#[derive(Debug, Error, VariantNames)]
#[derive(Debug, Error, Diagnostic)]
#[error("Remote instance cache error: {}")]
#[diagnostic(code(mag::remote_instance_cache_error))]
pub enum RemoteInstanceCacheError {
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
@ -16,13 +18,7 @@ pub enum RemoteInstanceCacheError {
impl From<RemoteInstanceCacheError> for ApiError {
fn from(err: RemoteInstanceCacheError) -> Self {
let mut api_error: ApiError = match err {
RemoteInstanceCacheError::DbError(err) => err.into(),
};
api_error.message = format!("Remote instance cache error: {}", api_error.message);
api_error
Self::internal("Cache error", err)
}
}

View File

@ -1,13 +1,15 @@
use crate::web::ApiError;
use magnetar_model::{ck, CalckeyDbError, CalckeyModel};
use miette::Diagnostic;
use std::sync::Arc;
use std::time::{Duration, Instant};
use strum::VariantNames;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tracing::error;
#[derive(Debug, Error, VariantNames)]
#[derive(Debug, Error, Diagnostic)]
#[error("Instance meta cache error: {}")]
#[diagnostic(code(mag::instance_meta_cache_error))]
pub enum InstanceMetaCacheError {
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
@ -17,14 +19,7 @@ pub enum InstanceMetaCacheError {
impl From<InstanceMetaCacheError> for ApiError {
fn from(err: InstanceMetaCacheError) -> Self {
let mut api_error: ApiError = match err {
InstanceMetaCacheError::DbError(err) => err.into(),
InstanceMetaCacheError::ChannelClosed => err.into(),
};
api_error.message = format!("Instance meta cache error: {}", api_error.message);
api_error
Self::internal("Cache error", err)
}
}

View File

@ -2,52 +2,58 @@ use std::collections::HashMap;
use std::sync::Arc;
use cached::{Cached, TimedCache};
use strum::VariantNames;
use miette::Diagnostic;
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::error;
use crate::web::ApiError;
use magnetar_common::config::MagnetarConfig;
use magnetar_federation::crypto::{ApHttpPrivateKey, ApHttpPrivateKeyParseError, ApHttpPublicKey, ApHttpPublicKeyParseError};
use magnetar_model::{
ck, CalckeyCache, CalckeyCacheError, CalckeyDbError, CalckeyModel, CalckeySub,
InternalStreamMessage, SubMessage,
};
use crate::web::ApiError;
#[derive(Debug, Error, VariantNames)]
#[derive(Debug, Error, Diagnostic)]
#[error("Local user cache error: {}")]
#[diagnostic(code(mag::local_user_cache_error))]
pub enum UserCacheError {
#[error("Database error: {0}")]
DbError(#[from] CalckeyDbError),
#[error("Redis error: {0}")]
RedisError(#[from] CalckeyCacheError),
#[error("Private key parse error: {0}")]
PrivateKeyParseError(#[from] ApHttpPrivateKeyParseError),
#[error("Public key parse error: {0}")]
PublicKeyParseError(#[from] ApHttpPublicKeyParseError),
}
impl From<UserCacheError> for ApiError {
fn from(err: UserCacheError) -> Self {
Self::internal("Cache error", err)
}
}
#[derive(Debug, Clone)]
pub struct CachedLocalUser {
pub user: Arc<ck::user::Model>,
pub profile: Arc<ck::user_profile::Model>,
pub private_key: Arc<ApHttpPrivateKey<'static>>,
pub public_key: Arc<ApHttpPublicKey<'static>>,
}
impl From<(ck::user::Model, ck::user_profile::Model)> for CachedLocalUser {
fn from((user, profile): (ck::user::Model, ck::user_profile::Model)) -> Self {
CachedLocalUser {
impl TryFrom<(ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)> for CachedLocalUser {
type Error = UserCacheError;
fn try_from((user, profile, key_pair): (ck::user::Model, ck::user_profile::Model, ck::user_keypair::Model)) -> Result<Self, Self::Error> {
Ok(CachedLocalUser {
user: Arc::new(user),
profile: Arc::new(profile),
}
}
}
impl From<UserCacheError> for ApiError {
fn from(err: UserCacheError) -> Self {
let mut api_error: ApiError = match err {
UserCacheError::DbError(err) => err.into(),
UserCacheError::RedisError(err) => err.into(),
};
api_error.message = format!("Local user cache error: {}", api_error.message);
api_error
private_key: Arc::new(key_pair.private_key.parse()?),
public_key: Arc::new(key_pair.public_key.parse()?),
})
}
}
@ -156,7 +162,7 @@ impl LocalUserCacheService {
| InternalStreamMessage::UserChangeSuspendedState { id, .. }
| InternalStreamMessage::RemoteUserUpdated { id }
| InternalStreamMessage::UserTokenRegenerated { id, .. } => {
let user_profile = match db.get_user_and_profile_by_id(&id).await {
let user_profile = match db.get_user_for_cache_by_id(&id).await {
Ok(Some(m)) => m,
Ok(None) => return,
Err(e) => {
@ -165,7 +171,15 @@ impl LocalUserCacheService {
}
};
cache.lock().await.refresh(&CachedLocalUser::from(user_profile));
let cached: CachedLocalUser = match user_profile.try_into() {
Ok(c) => c,
Err(e) => {
error!("Error parsing user from database: {}", e);
return;
}
};
cache.lock().await.refresh(&cached);
}
_ => {}
};
@ -202,7 +216,7 @@ impl LocalUserCacheService {
return Ok(Some(user));
}
self.map_cache_user(self.db.get_user_and_profile_by_token(token).await?.map(CachedLocalUser::from))
self.map_cache_user(self.db.get_user_for_cache_by_token(token).await?.map(CachedLocalUser::try_from).transpose()?)
.await
}
@ -216,6 +230,6 @@ impl LocalUserCacheService {
return Ok(Some(user));
}
self.map_cache_user(self.db.get_user_and_profile_by_id(id).await?.map(CachedLocalUser::from)).await
self.map_cache_user(self.db.get_user_for_cache_by_id(id).await?.map(CachedLocalUser::try_from).transpose()?).await
}
}

View File

@ -1,19 +1,27 @@
use federation_client::{new_ap_client_service, new_federation_client_service};
use gen_id::GenIdService;
use magnetar_common::config::MagnetarConfig;
use magnetar_federation::ap_client::ApClientError;
use magnetar_federation::client::federation_client::FederationClient;
use magnetar_federation::ApClientService;
use magnetar_model::{ck, CalckeyCache, CalckeyModel};
use std::fmt::{Debug, Formatter};
use std::time::Duration;
use thiserror::Error;
pub mod emoji_cache;
pub mod federation_client;
pub mod gen_id;
pub mod generic_id_cache;
pub mod instance_cache;
pub mod instance_meta_cache;
pub mod local_user_cache;
#[non_exhaustive]
pub type ApClient = dyn ApClientService<Error = ApClientError> + Send + Sync;
#[non_exhaustive]
pub struct MagnetarService {
pub db: CalckeyModel,
pub gen_id: GenIdService,
pub cache: CalckeyCache,
pub config: &'static MagnetarConfig,
pub local_user_cache: local_user_cache::LocalUserCacheService,
@ -21,6 +29,8 @@ pub struct MagnetarService {
pub remote_instance_cache: instance_cache::RemoteInstanceCacheService,
pub emoji_cache: emoji_cache::EmojiCacheService,
pub drive_file_cache: generic_id_cache::GenericIdCacheService<ck::drive_file::Entity>,
pub federation_client: FederationClient,
pub ap_client: Box<ApClient>,
}
impl Debug for MagnetarService {
@ -33,18 +43,12 @@ impl Debug for MagnetarService {
}
}
#[derive(Debug, Error)]
pub enum ServiceInitError {
#[error("Authentication cache initialization error: {0}")]
AuthCacheError(#[from] local_user_cache::UserCacheError),
}
impl MagnetarService {
pub async fn new(
config: &'static MagnetarConfig,
db: CalckeyModel,
cache: CalckeyCache,
) -> Result<Self, ServiceInitError> {
) -> miette::Result<Self> {
let local_user_cache =
local_user_cache::LocalUserCacheService::new(config, db.clone(), cache.clone()).await?;
let instance_meta_cache = instance_meta_cache::InstanceMetaCacheService::new(db.clone());
@ -58,6 +62,9 @@ impl MagnetarService {
let drive_file_cache =
generic_id_cache::GenericIdCacheService::new(db.clone(), 128, Duration::from_secs(10));
let federation_client = new_federation_client_service(config)?;
let ap_client = Box::new(new_ap_client_service(Box::new(federation_client.clone())));
Ok(Self {
db,
cache,
@ -67,6 +74,16 @@ impl MagnetarService {
remote_instance_cache,
emoji_cache,
drive_file_cache,
gen_id: GenIdService,
federation_client,
ap_client,
})
}
pub fn service<T>(&self) -> &T
where
Self: AsRef<T>,
{
self.as_ref()
}
}

18
src/vars.rs Normal file
View File

@ -0,0 +1,18 @@
//! Dynamic configuration variables that would be very messy to configure from the environment
//!
//! Most of these values were arbitrarily chosen, so your mileage may vary when adjusting these.
//!
//! Larger instances may benefit from higher cache sizes and increased parallelism, while smaller
//! ones may want to opt for memory savings.
use miette::IntoDiagnostic;
#[derive(Debug)]
pub struct MagVars {}
fn parse_vars() -> miette::Result<MagVars> {
let default_cfg = std::fs::read_to_string("config/default-vars.kdl").into_diagnostic()?;
let doc = default_cfg.parse::<kdl::KdlDocument>().into_diagnostic()?;
Ok(MagVars {})
}

View File

@ -2,23 +2,23 @@ use std::convert::Infallible;
use std::sync::Arc;
use axum::async_trait;
use axum::extract::{FromRequestParts, Request, State};
use axum::extract::rejection::ExtensionRejection;
use axum::http::{HeaderMap, StatusCode};
use axum::extract::{FromRequestParts, Request, State};
use axum::http::request::Parts;
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use headers::{Authorization, HeaderMapExt};
use headers::authorization::Bearer;
use strum::IntoStaticStr;
use headers::{Authorization, HeaderMapExt};
use miette::{miette, Diagnostic};
use thiserror::Error;
use tracing::error;
use magnetar_model::{CalckeyDbError, ck};
use magnetar_model::{ck, CalckeyDbError};
use crate::service::local_user_cache::{CachedLocalUser, UserCacheError};
use crate::service::MagnetarService;
use crate::web::{ApiError, IntoErrorCode};
use crate::web::ApiError;
#[derive(Clone, Debug)]
pub enum AuthMode {
@ -45,15 +45,13 @@ pub struct AuthUserRejection(ApiError);
impl From<ExtensionRejection> for AuthUserRejection {
fn from(rejection: ExtensionRejection) -> Self {
AuthUserRejection(ApiError {
status: StatusCode::UNAUTHORIZED,
code: "Unauthorized".error_code(),
message: if cfg!(debug_assertions) {
format!("Missing auth extension: {}", rejection)
} else {
"Unauthorized".to_string()
},
})
AuthUserRejection(
ApiError::new(
StatusCode::UNAUTHORIZED,
"Unauthorized",
miette!(code = "mag::auth_user_rejection", "Missing auth extension: {}", rejection),
)
)
}
}
@ -89,76 +87,46 @@ pub struct AuthState {
service: Arc<MagnetarService>,
}
#[derive(Debug, Error, IntoStaticStr)]
#[derive(Debug, Error, Diagnostic)]
#[error("Auth error: {}")]
enum AuthError {
#[error("Unsupported authorization scheme")]
#[diagnostic(code(mag::auth_error::unsupported_scheme))]
UnsupportedScheme,
#[error("Cache error: {0}")]
#[diagnostic(code(mag::auth_error::cache_error))]
CacheError(#[from] UserCacheError),
#[error("Database error: {0}")]
#[diagnostic(code(mag::auth_error::db_error))]
DbError(#[from] CalckeyDbError),
#[error("Invalid token")]
#[diagnostic(code(mag::auth_error::invalid_token))]
InvalidToken,
#[error("Invalid token \"{token}\" referencing user \"{user}\"")]
InvalidTokenUser { token: String, user: String },
#[error("Invalid access token \"{access_token}\" referencing app \"{app}\"")]
InvalidAccessTokenApp { access_token: String, app: String },
#[error("Invalid token referencing user \"{user}\"")]
#[diagnostic(code(mag::auth_error::invalid_token_user))]
InvalidTokenUser { user: String },
#[error("Invalid access token referencing app \"{app}\"")]
#[diagnostic(code(mag::auth_error::invalid_token_app))]
InvalidAccessTokenApp { app: String },
}
impl From<AuthError> for ApiError {
fn from(err: AuthError) -> Self {
match err {
AuthError::UnsupportedScheme => ApiError {
status: StatusCode::UNAUTHORIZED,
code: err.error_code(),
message: "Unsupported authorization scheme".to_string(),
},
AuthError::CacheError(err) => err.into(),
AuthError::DbError(err) => err.into(),
AuthError::InvalidTokenUser {
ref token,
ref user,
} => {
error!("Invalid token \"{}\" referencing user \"{}\"", token, user);
let code = match err {
AuthError::InvalidToken => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR
};
ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Invalid token \"{}\" referencing user \"{}\"", token, user)
} else {
"Invalid token-user link".to_string()
},
}
}
AuthError::InvalidAccessTokenApp {
ref access_token,
ref app,
} => {
error!(
"Invalid access token \"{}\" referencing app \"{}\"",
access_token, app
);
let message = match err {
AuthError::UnsupportedScheme => "Unsupported authorization scheme",
AuthError::InvalidTokenUser { .. } => "Invalid token and user combination",
AuthError::InvalidAccessTokenApp { .. } => "Invalid token and app combination",
AuthError::CacheError(_) => "Cache error",
AuthError::DbError(_) => "Database error",
AuthError::InvalidToken => "Invalid account token"
};
ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!(
"Invalid access token \"{}\" referencing app \"{}\"",
access_token, app
)
} else {
"Invalid access token-app link".to_string()
},
}
}
AuthError::InvalidToken => ApiError {
status: StatusCode::UNAUTHORIZED,
code: err.error_code(),
message: "Invalid token".to_string(),
},
}
ApiError::new(code, message, miette!(err))
}
}
@ -203,7 +171,6 @@ impl AuthState {
if user.is_none() {
return Err(AuthError::InvalidTokenUser {
token: access_token.id,
user: access_token.user_id,
});
}
@ -220,7 +187,6 @@ impl AuthState {
}),
}),
None => Err(AuthError::InvalidAccessTokenApp {
access_token: access_token.id,
app: access_token.user_id,
}),
};

View File

@ -1,9 +1,10 @@
use axum::{http::HeaderValue, response::IntoResponse};
use hyper::{header, StatusCode};
use hyper::header;
use magnetar_core::web_model::{content_type::ContentXrdXml, ContentType};
use miette::miette;
use serde::Serialize;
use crate::web::{ApiError, ErrorCode};
use crate::web::ApiError;
pub struct XrdXmlExt<T>(pub T);
@ -20,16 +21,11 @@ impl<T: Serialize> IntoResponse for XrdXmlExt<T> {
buf.into_bytes(),
)
.into_response(),
Err(e) => ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode("XmlSerializationError".into()),
message: if cfg!(debug_assertions) {
format!("Serialization error: {}", e)
} else {
"Serialization error".to_string()
},
}
.into_response(),
Err(e) => ApiError::internal(
"XmlSerializationError",
miette!(code = "mag::xrd_xml_ext_error", "XML serialization error: {}", e),
)
.into_response(),
}
}
}

View File

@ -3,75 +3,97 @@ use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use magnetar_common::util::FediverseTagParseError;
use magnetar_federation::ap_client::ApClientError;
use magnetar_federation::crypto::{
ApHttpPrivateKeyParseError, ApHttpPublicKeyParseError, ApSigningError,
};
use magnetar_model::{CalckeyCacheError, CalckeyDbError};
use miette::{diagnostic, miette, Diagnostic, Report};
use serde::Serialize;
use serde_json::json;
use std::fmt::{Display, Formatter};
use std::borrow::Cow;
use std::fmt::Display;
use thiserror::Error;
use tracing::warn;
use ulid::Ulid;
pub mod auth;
pub mod extractors;
pub mod pagination;
#[derive(Debug, Clone, Serialize)]
#[repr(transparent)]
pub struct ErrorCode(pub String);
// This janky hack allows us to use `.error_code()` on enums with strum::IntoStaticStr
pub trait IntoErrorCode {
fn error_code<'a, 'b: 'a>(&'a self) -> ErrorCode
where
&'a Self: Into<&'b str>;
}
impl<T: ?Sized> IntoErrorCode for T {
fn error_code<'a, 'b: 'a>(&'a self) -> ErrorCode
where
&'a Self: Into<&'b str>,
{
ErrorCode(<&Self as Into<&'b str>>::into(self).to_string())
}
}
impl ErrorCode {
pub fn join(&self, other: &str) -> Self {
Self(format!("{}:{}", other, self.0))
}
}
#[derive(Debug, Error)]
#[error("API Error")]
pub struct ApiError {
pub status: StatusCode,
pub code: ErrorCode,
pub message: String,
pub nonce: Ulid,
pub message: Cow<'static, str>,
pub cause: miette::Report,
}
impl Display for ApiError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"ApiError[status = \"{}\", code = \"{:?}\"]: \"{}\"",
self.status, self.code, self.message
)
#[derive(Debug, Serialize)]
pub struct ApiErrorBare<'a> {
pub status: u16,
pub nonce: &'a str,
pub message: &'a str,
}
impl Serialize for ApiError {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
ApiErrorBare {
status: self.status.as_u16(),
nonce: &self.nonce.to_string(),
message: &self.message,
}
.serialize(serializer)
}
}
#[derive(Debug)]
pub struct AccessForbidden(pub String);
impl ApiError {
pub fn new(
status: StatusCode,
message: impl Into<Cow<'static, str>>,
cause: impl Into<Report>,
) -> Self {
Self {
status,
nonce: Ulid::new(),
message: message.into(),
cause: cause.into(),
}
}
impl From<&AccessForbidden> for &str {
fn from(_: &AccessForbidden) -> &'static str {
"AccessForbidden"
pub fn internal(message: impl Into<Cow<'static, str>>, cause: impl Into<Report>) -> Self {
Self::new(StatusCode::INTERNAL_SERVER_ERROR, message, cause)
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let mut buf = [0; ulid::ULID_LEN];
let nonce = self.nonce.array_to_str(&mut buf);
warn!(
"[status={},nonce={}] {}",
self.status.as_str(),
nonce,
self.cause
);
let code = self
.cause
.code()
.as_deref()
.map(<dyn Display as ToString>::to_string);
(
self.status,
Json(json!({
"status": self.status.as_u16(),
"code": self.code,
"code": code,
"nonce": nonce,
"message": self.message,
})),
)
@ -79,114 +101,106 @@ impl IntoResponse for ApiError {
}
}
#[derive(Debug, Error, Diagnostic)]
#[error("Access forbidden: {0}")]
#[diagnostic(code(mag::access_forbidden))]
pub struct AccessForbidden(pub String);
impl From<AccessForbidden> for ApiError {
fn from(err: AccessForbidden) -> Self {
Self {
status: StatusCode::FORBIDDEN,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Forbidden: {}", err.0)
} else {
"Forbidden".to_string()
},
}
Self::new(StatusCode::FORBIDDEN, "Access forbidden", err)
}
}
impl From<FediverseTagParseError> for ApiError {
fn from(err: FediverseTagParseError) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Fediverse tag parse error: {}", err)
} else {
"Fediverse tag parse error".to_string()
},
}
Self::new(
StatusCode::BAD_REQUEST,
"Fediverse tag parse error",
miette!(code = "mag::access_forbidden", "{}", err),
)
}
}
impl From<CalckeyDbError> for ApiError {
fn from(err: CalckeyDbError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Database error: {}", err)
} else {
"Database error".to_string()
},
}
Self::internal("Database error", miette!(code = "mag::db_error", "{}", err))
}
}
impl From<CalckeyCacheError> for ApiError {
fn from(err: CalckeyCacheError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Cache error: {}", err)
} else {
"Cache error".to_string()
},
}
Self::internal("Cache error", miette!(code = "mag::cache_error", "{}", err))
}
}
impl From<PackError> for ApiError {
fn from(err: PackError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Data transformation error: {}", err)
} else {
"Data transformation error".to_string()
},
}
Self::internal(
"Data transformation error",
miette!(code = "mag::pack_error", "{}", err),
)
}
}
#[derive(Debug)]
#[derive(Debug, Error, Diagnostic)]
#[error("Object not found: {0}")]
#[diagnostic(code(mag::object_not_found))]
pub struct ObjectNotFound(pub String);
impl From<&ObjectNotFound> for &str {
fn from(_: &ObjectNotFound) -> Self {
"ObjectNotFound"
}
}
impl From<ObjectNotFound> for ApiError {
fn from(err: ObjectNotFound) -> Self {
Self {
status: StatusCode::NOT_FOUND,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Object not found: {}", err.0)
} else {
"Object not found".to_string()
},
}
Self::new(StatusCode::NOT_FOUND, "Object not found", miette!(err))
}
}
#[derive(Debug)]
#[derive(Debug, Error, Diagnostic)]
#[error("Argument out of range: {0}")]
#[diagnostic(code(mag::argument_out_of_range))]
pub struct ArgumentOutOfRange(pub String);
impl From<&ArgumentOutOfRange> for &str {
fn from(_: &ArgumentOutOfRange) -> Self {
"ArgumentOutOfRange"
}
}
impl From<ArgumentOutOfRange> for ApiError {
fn from(err: ArgumentOutOfRange) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
code: err.error_code(),
message: format!("Argument out of range: {}", err.0),
}
Self::new(
StatusCode::BAD_REQUEST,
format!("Argument out of range: {}", err.0),
err,
)
}
}
impl From<ApHttpPublicKeyParseError> for ApiError {
fn from(err: ApHttpPublicKeyParseError) -> Self {
Self::internal(
"User public key parse error",
miette!(code = "mag::ap_http_public_key_parse_error", "{}", err),
)
}
}
impl From<ApHttpPrivateKeyParseError> for ApiError {
fn from(err: ApHttpPrivateKeyParseError) -> Self {
Self::internal(
"User private key parse error",
miette!(code = "mag::ap_http_private_key_parse_error", "{}", err),
)
}
}
impl From<ApSigningError> for ApiError {
fn from(err: ApSigningError) -> Self {
Self::internal(
"ActivityPub HTTP signing error",
miette!(code = "mag::ap_signing_error", "{}", err),
)
}
}
impl From<ApClientError> for ApiError {
fn from(err: ApClientError) -> Self {
Self::internal(
"ActivityPub client error",
miette!(code = "mag::ap_client_error", "{}", err),
)
}
}

View File

@ -1,6 +1,6 @@
use crate::service::MagnetarService;
use crate::util::serialize_as_urlenc;
use crate::web::{ApiError, IntoErrorCode};
use crate::web::ApiError;
use axum::extract::rejection::QueryRejection;
use axum::extract::{FromRequestParts, OriginalUri, Query};
use axum::http::header::InvalidHeaderValue;
@ -14,10 +14,10 @@ use magnetar_core::web_model::rel::{RelNext, RelPrev};
use magnetar_model::sea_orm::prelude::async_trait::async_trait;
use magnetar_sdk::types::{PaginationShape, SpanFilter};
use magnetar_sdk::util_types::U64Range;
use miette::{miette, Diagnostic};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use strum::IntoStaticStr;
use thiserror::Error;
use tracing::error;
@ -39,32 +39,32 @@ struct PaginationQuery {
query_rest: HashMap<String, String>,
}
#[derive(Debug, Error, IntoStaticStr)]
#[derive(Debug, Error, Diagnostic)]
#[error("Pagination builder error: {}")]
pub enum PaginationBuilderError {
#[error("Query rejection: {0}")]
#[diagnostic(code(mag::pagination_builder_error::query_rejection))]
QueryRejection(#[from] QueryRejection),
#[error("HTTP error: {0}")]
#[diagnostic(code(mag::pagination_builder_error::http_error))]
HttpError(#[from] axum::http::Error),
#[error("Value of out of range error")]
OutOfRange,
#[error("Invalid header value")]
#[diagnostic(code(mag::pagination_builder_error::invalid_header_value))]
InvalidHeaderValue(#[from] InvalidHeaderValue),
#[error("Query string serialization error: {0}")]
#[diagnostic(code(mag::pagination_builder_error::serialization_error_query))]
SerializationErrorQuery(#[from] serde_urlencoded::ser::Error),
#[error("Query string serialization error: {0}")]
#[diagnostic(code(mag::pagination_builder_error::serialization_error_json))]
SerializationErrorJson(#[from] serde_json::Error),
}
impl From<PaginationBuilderError> for ApiError {
fn from(err: PaginationBuilderError) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: err.error_code(),
message: if cfg!(debug_assertions) {
format!("Pagination builder error: {}", err)
} else {
"Pagination builder error".to_string()
},
if matches!(err, PaginationBuilderError::QueryRejection(_)) {
Self::new(StatusCode::BAD_REQUEST, "Invalid pagination query", miette!(err))
} else {
Self::internal("Pagination error", miette!(err))
}
}
}
@ -92,9 +92,9 @@ impl FromRequestParts<Arc<MagnetarService>> for Pagination {
.build()?;
let Query(PaginationQuery {
pagination,
query_rest,
}) = parts.extract::<Query<PaginationQuery>>().await?;
pagination,
query_rest,
}) = parts.extract::<Query<PaginationQuery>>().await?;
Ok(Pagination {
base_uri,