magnetar/ext_calckey_model/src/lib.rs

472 lines
13 KiB
Rust

pub mod emoji;
pub mod model_ext;
pub mod note_model;
pub mod notification_model;
pub mod poll;
pub mod user_model;
pub use ck;
use ck::*;
pub use sea_orm;
use user_model::UserResolver;
use crate::note_model::NoteResolver;
use crate::notification_model::NotificationResolver;
use chrono::Utc;
use ext_calckey_model_migration::{Migrator, MigratorTrait};
use futures_util::StreamExt;
use redis::IntoConnectionInfo;
use sea_orm::ActiveValue::Set;
use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait,
};
use serde::{Deserialize, Serialize};
use std::future::Future;
use strum::IntoStaticStr;
use thiserror::Error;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::log::LevelFilter;
use tracing::{error, info, trace};
/// Settings used by [`CalckeyModel::new`] to open the database connection.
#[derive(Debug)]
pub struct ConnectorConfig {
    /// Database connection URL, handed unchanged to SeaORM's `ConnectOptions`.
    pub url: String,
}
/// Handle to the Calckey database, wrapping a SeaORM [`DatabaseConnection`].
///
/// All query helpers are methods on this type; the resolvers it hands out each hold their own
/// clone of the handle.
#[derive(Debug, Clone)]
pub struct CalckeyModel(DatabaseConnection);
#[derive(Debug, Error, IntoStaticStr)]
pub enum CalckeyDbError {
#[error("Database error: {0}")]
DbError(#[from] DbErr),
}
impl CalckeyModel {
    /// Opens a pooled connection to the database at `config.url`.
    ///
    /// The pool is configured with at most 64 and at least 8 connections. SQLx statement
    /// logging is enabled at `Debug` level.
    ///
    /// # Errors
    ///
    /// Returns [`CalckeyDbError::DbError`] if the connection cannot be established.
    pub async fn new(config: ConnectorConfig) -> Result<Self, CalckeyDbError> {
        // The `ConnectOptions` setters take and return `&mut Self`. Configuring a mutable
        // binding avoids cloning the whole options struct just to get an owned value back.
        let mut opt = ConnectOptions::new(config.url);
        opt.max_connections(64)
            .min_connections(8)
            .sqlx_logging(true)
            .sqlx_logging_level(LevelFilter::Debug);
        info!("Attempting database connection...");
        let connection = sea_orm::Database::connect(opt).await?;
        Ok(CalckeyModel(connection))
    }

    /// Applies all pending schema migrations.
    ///
    /// # Errors
    ///
    /// Returns [`CalckeyDbError::DbError`] if any migration fails.
    pub async fn migrate(&self) -> Result<(), CalckeyDbError> {
        Migrator::up(&self.0, None).await?;
        Ok(())
    }

    /// Borrows the underlying SeaORM connection for ad-hoc queries.
    pub fn inner(&self) -> &DatabaseConnection {
        &self.0
    }

    /// Finds a user by `name` and optional `instance` host, comparing both case-insensitively.
    ///
    /// `name` is matched against the lowercased username column. When `instance` is `None`,
    /// only users with a NULL host (local users) match.
    pub async fn get_user_by_tag(
        &self,
        name: &str,
        instance: Option<&str>,
    ) -> Result<Option<user::Model>, CalckeyDbError> {
        let by_name =
            user::Entity::find().filter(user::Column::UsernameLower.eq(name.to_lowercase()));
        let query = match instance {
            Some(host) => by_name.filter(user::Column::Host.eq(host.to_lowercase())),
            None => by_name.filter(user::Column::Host.is_null()),
        };
        Ok(query.one(&self.0).await?)
    }

    /// Fetches the note with the given `id`, if any.
    pub async fn get_note_by_id(&self, id: &str) -> Result<Option<note::Model>, CalckeyDbError> {
        Ok(note::Entity::find()
            .filter(note::Column::Id.eq(id))
            .one(&self.0)
            .await?)
    }

    /// Fetches the user with the given `id`, if any.
    pub async fn get_user_by_id(&self, id: &str) -> Result<Option<user::Model>, CalckeyDbError> {
        Ok(user::Entity::find()
            .filter(user::Column::Id.eq(id))
            .one(&self.0)
            .await?)
    }

    /// Fetches the profile row belonging to user `id`, if any.
    pub async fn get_user_profile_by_id(
        &self,
        id: &str,
    ) -> Result<Option<user_profile::Model>, CalckeyDbError> {
        Ok(user_profile::Entity::find()
            .filter(user_profile::Column::UserId.eq(id))
            .one(&self.0)
            .await?)
    }

    /// Lists every security key registered to user `id`.
    pub async fn get_user_security_keys_by_id(
        &self,
        id: &str,
    ) -> Result<Vec<user_security_key::Model>, CalckeyDbError> {
        Ok(user_security_key::Entity::find()
            .filter(user_security_key::Column::UserId.eq(id))
            .all(&self.0)
            .await?)
    }

    /// Fetches all users whose id appears in `id`. Order is whatever the database returns.
    ///
    /// An empty `id` slice yields an empty `Vec` without querying the database.
    pub async fn get_many_users_by_id(
        &self,
        id: &[String],
    ) -> Result<Vec<user::Model>, CalckeyDbError> {
        // Nothing to look up: skip the database round-trip entirely.
        if id.is_empty() {
            return Ok(Vec::new());
        }
        Ok(user::Entity::find()
            .filter(user::Column::Id.is_in(id))
            .all(&self.0)
            .await?)
    }

    /// Finds the local user (NULL host) whose `token` column equals `token`.
    pub async fn get_user_by_token(
        &self,
        token: &str,
    ) -> Result<Option<user::Model>, CalckeyDbError> {
        Ok(user::Entity::find()
            .filter(
                user::Column::Token
                    .eq(token)
                    .and(user::Column::Host.is_null()),
            )
            .one(&self.0)
            .await?)
    }

    /// Finds the user whose `uri` column equals `uri`.
    pub async fn get_user_by_uri(&self, uri: &str) -> Result<Option<user::Model>, CalckeyDbError> {
        Ok(user::Entity::find()
            .filter(user::Column::Uri.eq(uri))
            .one(&self.0)
            .await?)
    }

    /// Returns the `following` row where `from` is the follower and `to` the followee, if any.
    pub async fn get_follower_status(
        &self,
        from: &str,
        to: &str,
    ) -> Result<Option<following::Model>, CalckeyDbError> {
        Ok(following::Entity::find()
            .filter(
                following::Column::FollowerId
                    .eq(from)
                    .and(following::Column::FolloweeId.eq(to)),
            )
            .one(&self.0)
            .await?)
    }

    /// Returns the pending follow request from `from` to `to`, if any.
    pub async fn get_follow_request_status(
        &self,
        from: &str,
        to: &str,
    ) -> Result<Option<follow_request::Model>, CalckeyDbError> {
        Ok(follow_request::Entity::find()
            .filter(
                follow_request::Column::FollowerId
                    .eq(from)
                    .and(follow_request::Column::FolloweeId.eq(to)),
            )
            .one(&self.0)
            .await?)
    }

    /// Returns the `blocking` row where `from` blocks `to`, if any.
    pub async fn get_block_status(
        &self,
        from: &str,
        to: &str,
    ) -> Result<Option<blocking::Model>, CalckeyDbError> {
        Ok(blocking::Entity::find()
            .filter(
                blocking::Column::BlockerId
                    .eq(from)
                    .and(blocking::Column::BlockeeId.eq(to)),
            )
            .one(&self.0)
            .await?)
    }

    /// Returns the `muting` row where `from` mutes `to`, if any.
    pub async fn get_mute_status(
        &self,
        from: &str,
        to: &str,
    ) -> Result<Option<muting::Model>, CalckeyDbError> {
        Ok(muting::Entity::find()
            .filter(
                muting::Column::MuterId
                    .eq(from)
                    .and(muting::Column::MuteeId.eq(to)),
            )
            .one(&self.0)
            .await?)
    }

    /// Returns the `renote_muting` row where `from` mutes renotes of `to`, if any.
    pub async fn get_renote_mute_status(
        &self,
        from: &str,
        to: &str,
    ) -> Result<Option<renote_muting::Model>, CalckeyDbError> {
        Ok(renote_muting::Entity::find()
            .filter(
                renote_muting::Column::MuterId
                    .eq(from)
                    .and(renote_muting::Column::MuteeId.eq(to)),
            )
            .one(&self.0)
            .await?)
    }

    /// Looks up an access token and stamps its `last_used_at` with the current time.
    ///
    /// A row matches when its `token` equals `token` exactly, or its `hash` equals `token`
    /// lowercased. The lookup is done as an UPDATE, so a successful call always has the
    /// side effect of touching `last_used_at`.
    ///
    /// Returns `Ok(None)` when SeaORM reports that no row was updated.
    pub async fn get_access_token(
        &self,
        token: &str,
    ) -> Result<Option<access_token::Model>, CalckeyDbError> {
        let updated = access_token::Entity::update(access_token::ActiveModel {
            last_used_at: Set(Some(Utc::now().into())),
            ..Default::default()
        })
        .filter(
            access_token::Column::Token
                .eq(token)
                .or(access_token::Column::Hash.eq(token.to_lowercase())),
        )
        .exec(&self.0)
        .await;

        match updated {
            Ok(model) => Ok(Some(model)),
            // "Nothing matched" is a normal outcome for a lookup, not an error.
            Err(DbErr::RecordNotUpdated) => Ok(None),
            Err(e) => Err(e.into()),
        }
    }

    /// Fetches the OAuth/API app with the given `id`, if any.
    pub async fn get_app_by_id(&self, id: &str) -> Result<Option<app::Model>, CalckeyDbError> {
        Ok(app::Entity::find()
            .filter(app::Column::Id.eq(id))
            .one(&self.0)
            .await?)
    }

    /// Fetches the federated instance record whose `host` equals `host`, if any.
    pub async fn get_instance(
        &self,
        host: &str,
    ) -> Result<Option<instance::Model>, CalckeyDbError> {
        let instance = instance::Entity::find()
            .filter(instance::Column::Host.eq(host))
            .one(&self.0)
            .await?;
        Ok(instance)
    }

    /// Returns the instance `meta` row, creating a default row with id `"x"` if none exists.
    ///
    /// The lookup and the insert run inside one transaction.
    ///
    /// NOTE(review): two concurrent first-time callers can both observe "no row" and both
    /// insert. Presumably the second insert then fails on the id constraint and surfaces as
    /// an error; confirm whether an upsert is needed.
    pub async fn get_instance_meta(&self) -> Result<meta::Model, CalckeyDbError> {
        let txn = self.0.begin().await?;
        if let Some(meta) = meta::Entity::find().one(&txn).await? {
            txn.commit().await?;
            return Ok(meta);
        }

        let model = meta::ActiveModel {
            id: Set("x".to_string()),
            ..Default::default()
        };
        let meta = meta::Entity::insert(model)
            .exec_with_returning(&txn)
            .await?;
        txn.commit().await?;
        Ok(meta)
    }

    /// Builds a [`NotificationResolver`] backed by this model and its user and note resolvers.
    pub fn get_notification_resolver(&self) -> NotificationResolver {
        NotificationResolver::new(
            self.clone(),
            self.get_user_resolver(),
            self.get_note_resolver(),
        )
    }

    /// Builds a [`NoteResolver`] backed by this model and a fresh user resolver.
    pub fn get_note_resolver(&self) -> NoteResolver {
        NoteResolver::new(self.clone(), self.get_user_resolver())
    }

    /// Builds a [`UserResolver`] backed by this model.
    pub fn get_user_resolver(&self) -> UserResolver {
        UserResolver::new(self.clone())
    }
}
/// Settings used by [`CalckeyCache::new`] to configure the Redis client.
#[derive(Debug)]
pub struct CacheConnectorConfig {
    /// Redis connection URL, parsed with `redis::IntoConnectionInfo`.
    pub url: String,
}
/// Handle to the Calckey Redis cache, wrapping a [`redis::Client`].
///
/// Async connections are opened on demand with [`CalckeyCache::conn`].
#[derive(Debug, Clone)]
pub struct CalckeyCache(redis::Client);
#[derive(Debug, Error, IntoStaticStr)]
pub enum CalckeyCacheError {
#[error("Redis error: {0}")]
RedisError(#[from] redis::RedisError),
}
impl CalckeyCache {
pub fn new(config: CacheConnectorConfig) -> Result<Self, CalckeyCacheError> {
let conn_info = config.url.into_connection_info()?;
// TODO: Allow overriding redis config with individual options (maybe)
let redis_config = redis::ConnectionInfo {
addr: conn_info.addr,
redis: conn_info.redis,
};
Ok(CalckeyCache(redis::Client::open(redis_config)?))
}
pub fn inner(&self) -> &redis::Client {
&self.0
}
pub async fn conn(&self) -> Result<CalckeyCacheClient, CalckeyCacheError> {
Ok(CalckeyCacheClient(self.0.get_async_connection().await?))
}
}
/// A single async Redis connection obtained from [`CalckeyCache::conn`].
///
/// Calling [`CalckeyCacheClient::subscribe`] consumes the client and turns the connection into
/// a pub/sub subscription.
pub struct CalckeyCacheClient(redis::aio::Connection);
/// A message received on the subscribed Redis channel.
///
/// Payloads are adjacently tagged JSON objects. The `channel` field selects the variant and
/// the `message` field carries its content.
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "channel", content = "message")]
pub enum SubMessage {
    /// A message published on the `internal` channel.
    // Calckey publishes internal events with `channel: "internal"` (lowercase). Without this
    // rename, serde only matches the Rust identifier `"Internal"`, so every internal event
    // would fall through to `Other`. The alias keeps the capitalised spelling accepted too.
    #[serde(rename = "internal", alias = "Internal")]
    Internal(InternalStreamMessage),
    /// Any message whose `channel` tag matches no other variant.
    #[serde(other)]
    Other,
}
/// Events carried by [`SubMessage::Internal`].
///
/// Encoded adjacently tagged: `type` holds the variant name in camelCase (for example
/// `"userTokenRegenerated"`) and `body` holds the variant's payload. Struct-variant field
/// names are camelCase as well.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type", content = "body")]
pub enum InternalStreamMessage {
    // Per-user state flags; `id` identifies the affected user.
    #[serde(rename_all = "camelCase")]
    UserChangeSuspendedState { id: String, is_suspended: bool },
    #[serde(rename_all = "camelCase")]
    UserChangeSilencedState { id: String, is_silenced: bool },
    #[serde(rename_all = "camelCase")]
    UserChangeModeratorState { id: String, is_moderator: bool },
    #[serde(rename_all = "camelCase")]
    UserTokenRegenerated {
        id: String,
        old_token: String,
        new_token: String,
    },
    // `rename_all` is a no-op for a lone `id` field. It is applied anyway so every struct
    // variant is declared the same way.
    #[serde(rename_all = "camelCase")]
    LocalUserUpdated { id: String },
    #[serde(rename_all = "camelCase")]
    RemoteUserUpdated { id: String },
    // These variants carry the full entity `Model` of the affected webhook or antenna.
    WebhookCreated(webhook::Model),
    WebhookDeleted(webhook::Model),
    WebhookUpdated(webhook::Model),
    AntennaCreated(antenna::Model),
    AntennaDeleted(antenna::Model),
    AntennaUpdated(antenna::Model),
}
impl CalckeyCacheClient {
pub async fn subscribe<F: Future<Output = ()> + Send + 'static>(
self,
prefix: &str,
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
) -> Result<CalckeySub, CalckeyCacheError> {
CalckeySub::new(self.0, prefix, handler).await
}
}
/// Handle to an active Redis subscription.
///
/// Dropping the handle cancels the wrapped token, which stops the background listener task.
pub struct CalckeySub(CancellationToken);
impl CalckeySub {
async fn new<F: Future<Output = ()> + Send + 'static>(
conn: redis::aio::Connection,
prefix: &str,
handler: impl Fn(SubMessage) -> F + Send + Sync + 'static,
) -> Result<Self, CalckeyCacheError> {
let mut pub_sub = conn.into_pubsub();
pub_sub.subscribe(prefix).await?;
let token = CancellationToken::new();
let token_rx = token.clone();
let prefix = prefix.to_string();
tokio::spawn(async move {
let mut on_message = pub_sub.on_message();
while let Some(msg) = select! {
msg = on_message.next() => msg,
_ = token_rx.cancelled() => {
drop(on_message);
if let Err(e) = pub_sub.unsubscribe(prefix).await {
info!("Redis error: {:?}", e);
}
return;
}
} {
let data = &match msg.get_payload::<String>() {
Ok(val) => val,
Err(e) => {
info!("Redis error: {:?}", e);
continue;
}
};
let parsed = match serde_json::from_str::<SubMessage>(data) {
Ok(val) => val,
Err(e) => {
info!("Message parse error: {:?}", e);
continue;
}
};
trace!("Got message: {:#?}", parsed);
handler(parsed).await;
}
});
Ok(CalckeySub(token))
}
}
impl Drop for CalckeySub {
fn drop(&mut self) {
self.0.cancel();
}
}