// magnetar/ext_federation/src/client/federation_client.rs
//
// 163 lines
// 4.5 KiB
// Rust

use async_stream::stream;
use futures_util::{select, stream::StreamExt, FutureExt, Stream, TryStreamExt};
use headers::Header;
use hyper::body::Bytes;
use magnetar_core::web_model::{content_type::ContentActivityStreams, ContentType};
use reqwest::{redirect::Policy, Client, RequestBuilder};
use serde_json::Value;
use thiserror::Error;
use tokio::pin;
use url::Url;
/// HTTP client used for federation requests.
///
/// Bundles a `reqwest` [`Client`] with the per-request limits that
/// [`FederationRequestBuilder`] enforces when a request is sent.
#[derive(Clone, Debug)]
pub struct FederationClient {
    /// Underlying `reqwest` client that performs the requests.
    pub client: Client,
    /// Maximum cumulative size, in bytes, of a response body; exceeding it
    /// produces [`FederationClientError::BodyLimitExceededError`].
    pub body_limit: usize,
    /// Number of seconds a request may take before it fails with
    /// [`FederationClientError::TimeoutError`].
    pub timeout_seconds: u64,
}
#[derive(Debug, Error)]
pub enum FederationClientBuilderError {
#[error("Reqwest error: {0}")]
ReqwestError(#[from] reqwest::Error),
}
#[derive(Debug, Error)]
pub enum FederationClientError {
#[error("Fetch timed out")]
TimeoutError,
#[error("Reqwest error: {0}")]
ReqwestError(#[from] reqwest::Error),
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("XML error: {0}")]
XmlError(#[from] quick_xml::de::DeError),
#[error("Body limit exceeded error")]
BodyLimitExceededError,
#[error("Invalid URL: {0}")]
InvalidUrl(#[from] url::ParseError),
#[error("Client error: {0}")]
Other(String),
}
/// A request under construction, tied to the [`FederationClient`] that
/// created it.
#[derive(Debug)]
pub struct FederationRequestBuilder<'a> {
    /// Client whose body limit and timeout apply when the request is sent.
    client: &'a FederationClient,
    /// The `reqwest` request being assembled.
    builder: RequestBuilder,
}
impl FederationClient {
pub fn new(
force_https: bool,
body_limit: usize,
timeout_seconds: u64,
) -> Result<FederationClient, FederationClientBuilderError> {
let client = Client::builder()
.https_only(force_https)
.redirect(Policy::limited(5))
.build()?;
Ok(FederationClient {
client,
body_limit,
timeout_seconds,
})
}
pub fn builder(&self, method: reqwest::Method, url: Url) -> FederationRequestBuilder<'_> {
FederationRequestBuilder {
client: self,
builder: self.client.request(method, url),
}
}
pub fn get(&self, url: Url) -> FederationRequestBuilder<'_> {
self.builder(reqwest::Method::GET, url)
}
}
impl FederationRequestBuilder<'_> {
    /// Sets the `Content-Type` header of the request to the MIME type of
    /// `content_type`.
    pub fn content_type(self, content_type: impl ContentType) -> Self {
        let Self { client, builder } = self;

        Self {
            client,
            // Header name and value are passed as strings and converted by
            // the `reqwest` builder.
            builder: builder.header(
                headers::ContentType::name().to_string(),
                content_type.mime_type().to_string(),
            ),
        }
    }

    /// Adds `headers` to the request via the `reqwest` builder's `headers`
    /// method.
    pub fn headers(self, headers: reqwest::header::HeaderMap) -> Self {
        let Self { client, builder } = self;

        Self {
            client,
            builder: builder.headers(headers),
        }
    }

    /// Sends the request and returns the response body as a stream of chunks.
    ///
    /// The stream enforces the client's `body_limit`: as soon as the
    /// cumulative size of the received chunks exceeds it, the stream yields
    /// [`FederationClientError::BodyLimitExceededError`] and ends without
    /// yielding the offending chunk or anything after it.
    ///
    /// # Errors
    ///
    /// Returns [`FederationClientError::ReqwestError`] if the request cannot
    /// be sent or the response has an error status. Errors while reading the
    /// body are reported through the stream items.
    async fn send_stream(
        self,
    ) -> Result<impl Stream<Item = Result<Bytes, FederationClientError>>, FederationClientError>
    {
        let body_limit = self.client.body_limit;
        let mut body = self
            .builder
            .send()
            .await?
            .error_for_status()?
            .bytes_stream()
            .map(|b| b.map_err(FederationClientError::ReqwestError));

        let mut partial_length: usize = 0;
        Ok(stream! {
            while let Some(chunk) = body.next().await.transpose()? {
                // Saturate so the limit check itself can never overflow.
                partial_length = partial_length.saturating_add(chunk.len());
                if partial_length > body_limit {
                    yield Err(FederationClientError::BodyLimitExceededError);
                    // Stop here: nothing past the limit may reach the consumer.
                    break;
                }
                yield Ok(chunk);
            }
        })
    }

    /// Sends the request and collects the whole response body into memory.
    ///
    /// The request, including reading the body, is raced against a timer of
    /// the client's `timeout_seconds`; whichever completes first decides the
    /// result.
    ///
    /// # Errors
    ///
    /// * [`FederationClientError::TimeoutError`] if the timer fires first.
    /// * [`FederationClientError::BodyLimitExceededError`] if the body exceeds
    ///   the client's `body_limit`.
    /// * [`FederationClientError::ReqwestError`] for transport errors and
    ///   error statuses.
    pub async fn send(self) -> Result<Vec<u8>, FederationClientError> {
        let timeout = tokio::time::Duration::from_secs(self.client.timeout_seconds);
        let sleep = tokio::time::sleep(timeout).fuse();
        pin!(sleep);

        let body = async move {
            self.send_stream()
                .await?
                .try_fold(Vec::new(), |mut acc, b| async move {
                    acc.extend_from_slice(&b);
                    Ok(acc)
                })
                .await
        }
        .fuse();
        pin!(body);

        select! {
            b = body => b,
            _ = sleep => Err(FederationClientError::TimeoutError)
        }
    }

    /// Fetches the resource and parses the response body as JSON.
    ///
    /// The request is sent with an `Accept` header set to the
    /// `ContentActivityStreams` content type.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Self::send`], or
    /// [`FederationClientError::JsonError`] if the body is not valid JSON.
    pub async fn dereference(self) -> Result<Value, FederationClientError> {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::ACCEPT,
            reqwest::header::HeaderValue::from_static(ContentActivityStreams.as_ref()),
        );

        // Attach the Accept header before sending; without this the header map
        // built above would never reach the request.
        let data = self.headers(headers).send().await?;
        let json = serde_json::from_slice::<Value>(&data)?;
        Ok(json)
    }
}