magnetar/src/client/dereferencing_client.rs

95 lines
2.7 KiB
Rust

use futures_util::stream::StreamExt;
use magnetar_activity_pub::Resolver;
use magnetar_core::web_model::content_type::ContentActivityStreams;
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value;
use thiserror::Error;
use url::Url;
/// HTTP client that fetches (dereferences) a URL and parses the response body as JSON.
#[derive(Debug, Clone)]
pub struct DereferencingClient {
    /// Underlying `reqwest` client used to send the requests.
    pub client: Client,
    /// Maximum number of response-body bytes accepted by `dereference`;
    /// larger bodies are rejected with `DereferencingClientError::BodyLimitExceededError`.
    pub body_limit: usize,
}
/// Error type named for failures while building a [`DereferencingClient`].
///
/// NOTE(review): `DereferencingClient::new` in this file returns `reqwest::Error`
/// directly and does not use this type — confirm whether it is used elsewhere.
#[derive(Debug, Error)]
pub enum DereferencingClientBuilderError {
    /// Wraps an error reported by `reqwest`.
    #[error("Reqwest error: {0}")]
    ReqwestError(#[from] reqwest::Error),
}
/// Errors returned while dereferencing a URL or resolving an object id.
#[derive(Debug, Error)]
pub enum DereferencingClientError {
    /// The HTTP request could not be completed, or reading the response body failed.
    #[error("Reqwest error: {0}")]
    ReqwestError(#[from] reqwest::Error),
    /// The response body was not valid JSON, or could not be deserialized
    /// into the requested target type.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// The response body was larger than the client's configured `body_limit`.
    #[error("Body limit exceeded error")]
    BodyLimitExceededError,
    /// The object id could not be parsed as a URL.
    #[error("Invalid URL: {0}")]
    InvalidUrl(#[from] url::ParseError),
    /// Free-form client error.
    /// NOTE(review): not constructed anywhere in this file — confirm its intended use.
    #[error("Client error: {0}")]
    Other(String),
}
impl DereferencingClient {
    /// Creates a client that rejects response bodies larger than `body_limit` bytes.
    ///
    /// When `force_https` is `true`, the underlying [`Client`] is built with
    /// `https_only(true)`, so requests to non-HTTPS URLs fail.
    ///
    /// # Errors
    ///
    /// Returns the [`reqwest::Error`] produced if the HTTP client cannot be built.
    pub fn new(
        force_https: bool,
        body_limit: usize,
    ) -> Result<DereferencingClient, reqwest::Error> {
        let client = Client::builder().https_only(force_https).build()?;
        Ok(DereferencingClient { client, body_limit })
    }

    /// Fetches `url` with an ActivityStreams `Accept` header and parses the body as JSON.
    ///
    /// The body is streamed. The request is abandoned as soon as more than
    /// `self.body_limit` bytes would have been buffered, so a remote server cannot
    /// make this client hold an unbounded response in memory.
    ///
    /// # Errors
    ///
    /// - [`DereferencingClientError::ReqwestError`] if the request fails, the server
    ///   answers with a 4xx/5xx status, or reading the body fails.
    /// - [`DereferencingClientError::BodyLimitExceededError`] if the declared
    ///   `Content-Length` or the received body exceeds `self.body_limit`.
    /// - [`DereferencingClientError::JsonError`] if the body is not valid JSON.
    pub async fn dereference(&self, url: Url) -> Result<Value, DereferencingClientError> {
        // Error pages (404, 410, 500, ...) frequently carry a JSON body of their own.
        // Without the status check, that body would be returned as if it were the
        // requested object.
        let response = self
            .client
            .get(url)
            .header(
                reqwest::header::ACCEPT,
                reqwest::header::HeaderValue::from_static(ContentActivityStreams.as_ref()),
            )
            .send()
            .await?
            .error_for_status()?;

        // Fail fast when the server announces a body we would refuse anyway.
        if let Some(declared) = response.content_length() {
            let too_large = match usize::try_from(declared) {
                Ok(declared) => declared > self.body_limit,
                // Larger than the address space: certainly over any limit.
                Err(_) => true,
            };
            if too_large {
                return Err(DereferencingClientError::BodyLimitExceededError);
            }
        }

        let mut body = response.bytes_stream();
        let mut data = Vec::with_capacity(4096);
        while let Some(chunk) = body.next().await {
            let chunk = chunk?;
            // Invariant: `data.len() <= self.body_limit`, so this subtraction cannot
            // underflow. Unlike `data.len() + chunk.len()`, it also cannot overflow.
            if chunk.len() > self.body_limit - data.len() {
                return Err(DereferencingClientError::BodyLimitExceededError);
            }
            data.extend_from_slice(&chunk);
        }

        Ok(serde_json::from_slice(&data)?)
    }
}
#[async_trait::async_trait]
impl Resolver for DereferencingClient {
type Error = DereferencingClientError;
async fn resolve<T: for<'a> Deserialize<'a>>(&self, id: &str) -> Result<T, Self::Error> {
let url = id.parse().map_err(DereferencingClientError::InvalidUrl)?;
let json = self.dereference(url).await?;
serde_json::from_value(json).map_err(DereferencingClientError::JsonError)
}
}