From 7b02f8427147c2776ec0f0d92408c11c5b878ef8 Mon Sep 17 00:00:00 2001 From: Natty Date: Wed, 17 Jan 2024 00:41:32 +0100 Subject: [PATCH] Basic backend SSE notification implemention --- Cargo.lock | 2 + Cargo.toml | 4 + ext_calckey_model/src/lib.rs | 57 +++++++++-- ext_calckey_model/src/model_ext.rs | 6 ++ ext_calckey_model/src/notification_model.rs | 36 ++++++- magnetar_sdk/src/types/mod.rs | 1 + magnetar_sdk/src/types/streaming.rs | 10 ++ src/api_v1/mod.rs | 3 + src/api_v1/streaming.rs | 100 ++++++++++++++++++++ src/model/processing/notification.rs | 61 ++++++++++++ src/web/mod.rs | 14 ++- 11 files changed, 283 insertions(+), 11 deletions(-) create mode 100644 magnetar_sdk/src/types/streaming.rs create mode 100644 src/api_v1/streaming.rs diff --git a/Cargo.lock b/Cargo.lock index bf41059..324e2c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1472,6 +1472,7 @@ dependencies = [ name = "magnetar" version = "0.3.0-alpha" dependencies = [ + "async-stream", "axum", "axum-extra", "cached", @@ -1502,6 +1503,7 @@ dependencies = [ "strum", "thiserror", "tokio", + "tokio-stream", "toml", "tower", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 9e005cc..9e53666 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ edition = "2021" [workspace.dependencies] async-trait = "0.1" +async-stream = "0.3" axum = "0.7" axum-extra = "0.9" cached = "0.47" @@ -60,6 +61,7 @@ tera = { version = "1", default-features = false } thiserror = "1" tokio = "1.24" tokio-util = "0.7" +tokio-stream = "0.1" toml = "0.8" tower = "0.4" tower-http = "0.5" @@ -86,9 +88,11 @@ dotenvy = { workspace = true } axum = { workspace = true, features = ["macros"] } axum-extra = { workspace = true, features = ["typed-header"]} +async-stream = { workspace = true } headers = { workspace = true } hyper = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] } +tokio-stream = { workspace = true } tower = { workspace = true } tower-http = { workspace = true, features = ["cors", "trace", "fs"] } url = { workspace = true } diff --git a/ext_calckey_model/src/lib.rs b/ext_calckey_model/src/lib.rs index c861b36..b83083f 100644 --- a/ext_calckey_model/src/lib.rs +++ b/ext_calckey_model/src/lib.rs @@ -10,6 +10,7 @@ use ck::*; pub use sea_orm; use user_model::UserResolver; +use crate::model_ext::IdShape; use crate::note_model::NoteResolver; use crate::notification_model::NotificationResolver; use chrono::Utc; @@ -21,14 +22,16 @@ use sea_orm::{ ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, TransactionTrait, }; -use serde::{Deserialize, Serialize}; +use serde::de::Error; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::Value; use std::future::Future; use strum::IntoStaticStr; use thiserror::Error; use tokio::select; use tokio_util::sync::CancellationToken; use tracing::log::LevelFilter; -use tracing::{error, info, trace}; +use tracing::{error, info, trace, warn}; #[derive(Debug)] pub struct ConnectorConfig { @@ -353,12 +356,46 @@ impl CalckeyCache { pub struct CalckeyCacheClient(redis::aio::Connection); -#[derive(Clone, Debug, Deserialize)] -#[serde(tag = "channel", content = "message")] +#[derive(Clone, Debug)] pub enum SubMessage { Internal(InternalStreamMessage), - #[serde(other)] - Other, + MainStream(String, MainStreamMessage), + Other(String, Value), +} + +#[derive(Deserialize)] +struct RawMessage<'a> { + channel: &'a str, + message: Value, +} + +impl<'de> Deserialize<'de> for SubMessage { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let raw = RawMessage::deserialize(deserializer)?; + + Ok(match raw.channel { + "internal" => SubMessage::Internal( + InternalStreamMessage::deserialize(raw.message).map_err(Error::custom)?, + ), + c if c.starts_with("mainStream") => SubMessage::MainStream( + c.strip_prefix("mainStream:") + .ok_or_else(|| Error::custom("Invalid mainStream prefix"))? + .to_string(), + MainStreamMessage::deserialize(raw.message).map_err(Error::custom)?, + ), + _ => SubMessage::Other(raw.channel.to_string(), raw.message), + }) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", content = "body")] +#[serde(rename_all = "camelCase")] +pub enum MainStreamMessage { + Notification(IdShape), } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -426,6 +463,7 @@ impl CalckeySub { let prefix = prefix.to_string(); tokio::spawn(async move { + trace!("Redis subscriber spawned"); let mut on_message = pub_sub.on_message(); while let Some(msg) = select! { @@ -433,7 +471,7 @@ impl CalckeySub { _ = token_rx.cancelled() => { drop(on_message); if let Err(e) = pub_sub.unsubscribe(prefix).await { - info!("Redis error: {:?}", e); + warn!("Redis error: {:?}", e); } return; } @@ -441,7 +479,7 @@ impl CalckeySub { let data = &match msg.get_payload::() { Ok(val) => val, Err(e) => { - info!("Redis error: {:?}", e); + warn!("Redis error: {:?}", e); continue; } }; @@ -449,7 +487,7 @@ impl CalckeySub { let parsed = match serde_json::from_str::(data) { Ok(val) => val, Err(e) => { - info!("Message parse error: {:?}", e); + warn!("Message parse error: {:?}", e); continue; } }; @@ -466,6 +504,7 @@ impl CalckeySub { impl Drop for CalckeySub { fn drop(&mut self) { + trace!("Redis subscriber dropped"); self.0.cancel(); } } diff --git a/ext_calckey_model/src/model_ext.rs b/ext_calckey_model/src/model_ext.rs index 4db43dc..3c966be 100644 --- a/ext_calckey_model/src/model_ext.rs +++ b/ext_calckey_model/src/model_ext.rs @@ -8,6 +8,7 @@ use sea_orm::{ Iden, IntoIdentity, Iterable, JoinType, QueryTrait, RelationDef, RelationTrait, Select, SelectModel, SelectorTrait, }; +use serde::{Deserialize, Serialize}; use std::fmt::Write; #[derive(Clone)] @@ -332,3 +333,8 @@ pub trait ModelPagination { fn id(&self) -> &str; fn time(&self) -> DateTime; } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IdShape { + pub id: String, +} diff --git a/ext_calckey_model/src/notification_model.rs b/ext_calckey_model/src/notification_model.rs index 66526d8..6e567f9 100644 --- a/ext_calckey_model/src/notification_model.rs +++ b/ext_calckey_model/src/notification_model.rs @@ -13,7 +13,7 @@ use ext_calckey_model_migration::{JoinType, SelectStatement}; use magnetar_sdk::types::SpanFilter; use sea_orm::prelude::Expr; use sea_orm::sea_query::{IntoCondition, Query}; -use sea_orm::{ActiveEnum, Iden, IntoSimpleExpr, QueryTrait}; +use sea_orm::{ActiveEnum, Iden, QueryTrait}; use sea_orm::{DbErr, EntityTrait, FromQueryResult, QueryFilter, QueryResult, QuerySelect}; use serde::{Deserialize, Serialize}; @@ -130,6 +130,40 @@ impl NotificationResolver { ); } + pub async fn get_single( + &self, + resolve_options: &NotificationResolveOptions, + notification_id: &str, + ) -> Result, CalckeyDbError> { + let notification_tbl = notification::Entity.base_prefix(); + + let mut query = Query::select(); + query.from_as(notification::Entity, notification_tbl.clone()); + + self.resolve( + &mut query, + ¬ification_tbl, + &resolve_options, + &self.note_resolver, + &self.user_resolver, + ); + + let mut select = notification::Entity::find(); + *QuerySelect::query(&mut select) = query; + + let notifications = select + .filter( + notification_tbl + .col(notification::Column::Id) + .eq(notification_id), + ) + .into_model::() + .one(self.db.inner()) + .await?; + + Ok(notifications) + } + pub async fn get( &self, resolve_options: &NotificationResolveOptions, diff --git a/magnetar_sdk/src/types/mod.rs b/magnetar_sdk/src/types/mod.rs index 0c0e7f5..4b3442b 100644 --- a/magnetar_sdk/src/types/mod.rs +++ b/magnetar_sdk/src/types/mod.rs @@ -3,6 +3,7 @@ pub mod emoji; pub mod instance; pub mod note; pub mod notification; +pub mod streaming; pub mod timeline; pub mod user; diff --git a/magnetar_sdk/src/types/streaming.rs b/magnetar_sdk/src/types/streaming.rs new file mode 100644 index 0000000..b8dad91 --- /dev/null +++ b/magnetar_sdk/src/types/streaming.rs @@ -0,0 +1,10 @@ +use crate::types::notification::PackNotification; +use serde::{Deserialize, Serialize}; +use ts_rs::TS; + +#[derive(Clone, Debug, Deserialize, Serialize, TS)] +#[serde(tag = "type", content = "body")] +#[ts(export)] +pub enum ChannelEvent { + Notification(PackNotification), +} diff --git a/src/api_v1/mod.rs b/src/api_v1/mod.rs index 1c62d28..a282f75 100644 --- a/src/api_v1/mod.rs +++ b/src/api_v1/mod.rs @@ -1,7 +1,9 @@ mod note; +mod streaming; mod user; use crate::api_v1::note::handle_note; +use crate::api_v1::streaming::handle_streaming; use crate::api_v1::user::{ handle_follow_requests_self, handle_followers, handle_followers_self, handle_following, handle_following_self, handle_notifications, handle_user_by_id_many, handle_user_info, @@ -31,6 +33,7 @@ pub fn create_api_router(service: Arc) -> Router { .route("/users/@self/followers", get(handle_followers_self)) .route("/users/:id/followers", get(handle_followers)) .route("/notes/:id", get(handle_note)) + .route("/streaming", get(handle_streaming)) .layer(from_fn_with_state( AuthState::new(service.clone()), auth::auth, diff --git a/src/api_v1/streaming.rs b/src/api_v1/streaming.rs new file mode 100644 index 0000000..fca38d2 --- /dev/null +++ b/src/api_v1/streaming.rs @@ -0,0 +1,100 @@ +use crate::model::processing::notification::NotificationModel; +use crate::model::PackingContext; +use crate::service::MagnetarService; +use crate::web::auth::AuthenticatedUser; +use crate::web::ApiError; +use axum::extract::State; +use axum::response::sse::{Event, KeepAlive}; +use axum::response::Sse; +use futures::Stream; +use futures_util::StreamExt as _; +use magnetar_calckey_model::model_ext::IdShape; +use magnetar_calckey_model::{CalckeySub, MainStreamMessage, SubMessage}; +use magnetar_sdk::types::streaming::ChannelEvent; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{debug, error, trace, warn}; + +pub fn drop_on_close(sub: CalckeySub, tx: mpsc::Sender) { + tokio::spawn(async move { + tx.closed().await; + drop(sub); + debug!("Dropped the listener."); + }); +} + +pub async fn handle_streaming( + State(service): State>, + AuthenticatedUser(self_user): AuthenticatedUser, +) -> Result>>, ApiError> { + trace!("SSE connection from user `{}` start", self_user.username); + + let (tx, rx) = mpsc::channel(1024); + let sub_tx = tx.clone(); + let sub_user_id = self_user.id.clone(); + let sub = service + .cache + .conn() + .await? + .subscribe(&service.config.networking.host, move |message| { + let user_id = sub_user_id.clone(); + let tx = sub_tx.clone(); + async move { + let SubMessage::MainStream(id, msg) = message else { + return; + }; + + if id != user_id { + return; + } + + if let Err(e) = tx.send(msg).await { + warn!("Failed to send stream channel message: {e}"); + } + } + }) + .await?; + + drop_on_close(sub, tx); + + let stream = ReceiverStream::new(rx).filter_map(move |m| { + let service = service.clone(); + let self_user = self_user.clone(); + async move { + match m { + MainStreamMessage::Notification(IdShape { id }) => { + let ctx = PackingContext::new(service, Some(self_user.clone())) + .await + .map_err(|e| { + error!("Failed to create notification packing context: {}", e); + e + }) + .ok()?; + let notification_model = NotificationModel; + + Some( + Event::default().json_data(ChannelEvent::Notification( + notification_model + .get_notification(&ctx, &id, &self_user.id) + .await + .map_err(|e| { + error!("Failed to fetch notification: {}", e); + e + }) + .ok() + .flatten()?, + )), + ) + } + } + } + }); + + Ok(Sse::new(stream).keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(2)) + .text("mag-keep-alive"), + )) +} diff --git a/src/model/processing/notification.rs b/src/model/processing/notification.rs index 411766d..7e69e40 100644 --- a/src/model/processing/notification.rs +++ b/src/model/processing/notification.rs @@ -211,6 +211,67 @@ impl NotificationModel { }) } + pub async fn get_notification( + &self, + ctx: &PackingContext, + notification_id: &str, + user_id: &str, + ) -> PackResult> { + let user_resolve_options = UserResolveOptions { + with_avatar: true, + with_banner: false, + with_profile: false, + }; + + let self_id = ctx.self_user.as_deref().map(ck::user::Model::get_id); + let Some(notification_raw) = ctx + .service + .db + .get_notification_resolver() + .get_single( + &NotificationResolveOptions { + note_options: NoteResolveOptions { + ids: None, + visibility_filter: Arc::new( + NoteVisibilityFilterModel.new_note_visibility_filter(Some(user_id)), + ), + time_range: None, + limit: None, + with_reply_target: true, + with_renote_target: true, + with_interactions_from: self_id.map(str::to_string), + only_pins_from: None, + user_options: user_resolve_options.clone(), + }, + user_options: user_resolve_options, + }, + notification_id, + ) + .await? + else { + return Ok(None); + }; + + let note_model = NoteModel { + with_context: true, + attachments: false, + }; + let user_model = UserModel; + let emoji_model = EmojiModel; + + let notification = self + .pack_notification_single( + ctx, + ¬ification_raw, + ¬e_model, + &user_model, + &emoji_model, + ) + .await?; + + Ok(Some(notification)) + } + pub async fn get_notifications( &self, ctx: &PackingContext, diff --git a/src/web/mod.rs b/src/web/mod.rs index 6ec31dc..9c58585 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -6,6 +6,8 @@ use magnetar_calckey_model::{CalckeyCacheError, CalckeyDbError}; use magnetar_common::util::FediverseTagParseError; use serde::Serialize; use serde_json::json; +use std::fmt::{Display, Formatter}; +use thiserror::Error; pub mod auth; pub mod pagination; @@ -36,13 +38,23 @@ impl ErrorCode { } } -#[derive(Debug)] +#[derive(Debug, Error)] pub struct ApiError { pub status: StatusCode, pub code: ErrorCode, pub message: String, } +impl Display for ApiError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ApiError[status = \"{}\", code = \"{:?}\"]: \"{}\"", + self.status, self.code, self.message + ) + } +} + #[derive(Debug)] pub struct AccessForbidden(pub String);