192 lines
5.2 KiB
Rust
192 lines
5.2 KiB
Rust
use std::io::Cursor;
|
|
|
|
use async_stream::stream;
|
|
use futures_util::{select, stream::StreamExt, FutureExt, Stream, TryStreamExt};
|
|
use hyper::body::Bytes;
|
|
use magnetar_common::config::MagnetarNetworkingProtocol;
|
|
use magnetar_core::web_model::content_type::ContentActivityStreams;
|
|
use magnetar_host_meta::Xrd;
|
|
use reqwest::{Client, RequestBuilder};
|
|
use serde_json::Value;
|
|
use thiserror::Error;
|
|
use tokio::pin;
|
|
use url::Url;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct FederationClient {
|
|
pub client: Client,
|
|
pub body_limit: usize,
|
|
pub timeout_seconds: u64,
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum FederationClientBuilderError {
|
|
#[error("Reqwest error: {0}")]
|
|
ReqwestError(#[from] reqwest::Error),
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum FederationClientError {
|
|
#[error("Fetch timed out")]
|
|
TimeoutError,
|
|
#[error("Reqwest error: {0}")]
|
|
ReqwestError(#[from] reqwest::Error),
|
|
#[error("JSON error: {0}")]
|
|
JsonError(#[from] serde_json::Error),
|
|
#[error("XML error: {0}")]
|
|
XmlError(#[from] quick_xml::de::DeError),
|
|
#[error("Body limit exceeded error")]
|
|
BodyLimitExceededError,
|
|
#[error("Invalid URL: {0}")]
|
|
InvalidUrl(#[from] url::ParseError),
|
|
#[error("Client error: {0}")]
|
|
Other(String),
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct FederationRequestBuilder<'a> {
|
|
client: &'a FederationClient,
|
|
builder: RequestBuilder,
|
|
}
|
|
|
|
impl FederationClient {
|
|
pub fn new(
|
|
force_https: bool,
|
|
body_limit: usize,
|
|
timeout_seconds: u64,
|
|
) -> Result<FederationClient, FederationClientBuilderError> {
|
|
let client = Client::builder().https_only(force_https).build()?;
|
|
|
|
Ok(FederationClient {
|
|
client,
|
|
body_limit,
|
|
timeout_seconds,
|
|
})
|
|
}
|
|
|
|
pub fn builder(&self, method: reqwest::Method, url: Url) -> FederationRequestBuilder<'_> {
|
|
FederationRequestBuilder {
|
|
client: &self,
|
|
builder: self.client.request(method, url),
|
|
}
|
|
}
|
|
|
|
pub fn get(&self, url: Url) -> FederationRequestBuilder<'_> {
|
|
self.builder(reqwest::Method::GET, url)
|
|
}
|
|
|
|
pub async fn host_meta(
|
|
&self,
|
|
protocol: MagnetarNetworkingProtocol,
|
|
host: &str,
|
|
) -> Result<Xrd, FederationClientError> {
|
|
let host_meta_xml = self
|
|
.get(Url::parse(&format!(
|
|
"{}://{}/.well-known/host-meta",
|
|
protocol.as_ref(),
|
|
host
|
|
))?)
|
|
.send()
|
|
.await?;
|
|
|
|
let reader = quick_xml::de::from_reader(Cursor::new(host_meta_xml))?;
|
|
|
|
Ok(reader)
|
|
}
|
|
}
|
|
|
|
impl FederationRequestBuilder<'_> {
|
|
pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
|
|
Self {
|
|
client: self.client,
|
|
builder: self.builder.headers(headers),
|
|
}
|
|
}
|
|
|
|
async fn send_stream(
|
|
self,
|
|
) -> Result<impl Stream<Item = Result<Bytes, FederationClientError>>, FederationClientError>
|
|
{
|
|
let mut body = self
|
|
.builder
|
|
.send()
|
|
.await?
|
|
.error_for_status()?
|
|
.bytes_stream()
|
|
.map(|b| b.map_err(FederationClientError::ReqwestError));
|
|
|
|
let body_limit = self.client.body_limit;
|
|
let mut partial_length: usize = 0;
|
|
Ok(stream! {
|
|
while let Some(chunk) = body.next().await.transpose()? {
|
|
if partial_length + chunk.len() > body_limit {
|
|
yield Err(FederationClientError::BodyLimitExceededError);
|
|
}
|
|
|
|
partial_length += chunk.len();
|
|
yield Ok(chunk);
|
|
}
|
|
})
|
|
}
|
|
|
|
async fn send(self) -> Result<Vec<u8>, FederationClientError> {
|
|
let sleep = tokio::time::sleep(tokio::time::Duration::from_secs(
|
|
self.client.timeout_seconds,
|
|
))
|
|
.fuse();
|
|
tokio::pin!(sleep);
|
|
|
|
let body = async move {
|
|
self.send_stream()
|
|
.await?
|
|
.try_fold(Vec::new(), |mut acc, b| async move {
|
|
acc.extend_from_slice(&b);
|
|
Ok(acc)
|
|
})
|
|
.await
|
|
}
|
|
.fuse();
|
|
|
|
pin!(body);
|
|
|
|
select! {
|
|
b = body => b,
|
|
_ = sleep => Err(FederationClientError::TimeoutError)
|
|
}
|
|
}
|
|
|
|
pub async fn dereference(self) -> Result<Value, FederationClientError> {
|
|
let mut headers = reqwest::header::HeaderMap::new();
|
|
|
|
headers.insert(
|
|
reqwest::header::ACCEPT,
|
|
reqwest::header::HeaderValue::from_static(ContentActivityStreams.as_ref()),
|
|
);
|
|
|
|
let data = self.send().await?;
|
|
let json =
|
|
serde_json::from_slice::<Value>(&data).map_err(FederationClientError::JsonError)?;
|
|
|
|
Ok(json)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use magnetar_common::config::MagnetarNetworkingProtocol;
    use miette::IntoDiagnostic;

    use super::FederationClient;

    /// Live smoke test: fetches and parses the host-meta document of a real instance.
    ///
    /// Requires outbound HTTPS access to `astolfo.social`. It will fail offline.
    #[tokio::test]
    async fn test() -> miette::Result<()> {
        let client = FederationClient::new(true, 1024 * 1024, 30).into_diagnostic()?;

        // Only a successful fetch and parse is checked. The underscore silences
        // the unused-variable warning the plain binding produced.
        let _host_meta = client
            .host_meta(MagnetarNetworkingProtocol::Https, "astolfo.social")
            .await
            .into_diagnostic()?;

        Ok(())
    }
}
|