Basic backend SSE notification implemention

This commit is contained in:
Natty 2024-01-17 00:41:32 +01:00
parent ad3528055f
commit 7b02f84271
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
11 changed files with 283 additions and 11 deletions

2
Cargo.lock generated
View File

@ -1472,6 +1472,7 @@ dependencies = [
name = "magnetar" name = "magnetar"
version = "0.3.0-alpha" version = "0.3.0-alpha"
dependencies = [ dependencies = [
"async-stream",
"axum", "axum",
"axum-extra", "axum-extra",
"cached", "cached",
@ -1502,6 +1503,7 @@ dependencies = [
"strum", "strum",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream",
"toml", "toml",
"tower", "tower",
"tower-http", "tower-http",

View File

@ -24,6 +24,7 @@ edition = "2021"
[workspace.dependencies] [workspace.dependencies]
async-trait = "0.1" async-trait = "0.1"
async-stream = "0.3"
axum = "0.7" axum = "0.7"
axum-extra = "0.9" axum-extra = "0.9"
cached = "0.47" cached = "0.47"
@ -60,6 +61,7 @@ tera = { version = "1", default-features = false }
thiserror = "1" thiserror = "1"
tokio = "1.24" tokio = "1.24"
tokio-util = "0.7" tokio-util = "0.7"
tokio-stream = "0.1"
toml = "0.8" toml = "0.8"
tower = "0.4" tower = "0.4"
tower-http = "0.5" tower-http = "0.5"
@ -86,9 +88,11 @@ dotenvy = { workspace = true }
axum = { workspace = true, features = ["macros"] } axum = { workspace = true, features = ["macros"] }
axum-extra = { workspace = true, features = ["typed-header"]} axum-extra = { workspace = true, features = ["typed-header"]}
async-stream = { workspace = true }
headers = { workspace = true } headers = { workspace = true }
hyper = { workspace = true, features = ["full"] } hyper = { workspace = true, features = ["full"] }
tokio = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] }
tokio-stream = { workspace = true }
tower = { workspace = true } tower = { workspace = true }
tower-http = { workspace = true, features = ["cors", "trace", "fs"] } tower-http = { workspace = true, features = ["cors", "trace", "fs"] }
url = { workspace = true } url = { workspace = true }

View File

@ -10,6 +10,7 @@ use ck::*;
pub use sea_orm; pub use sea_orm;
use user_model::UserResolver; use user_model::UserResolver;
use crate::model_ext::IdShape;
use crate::note_model::NoteResolver; use crate::note_model::NoteResolver;
use crate::notification_model::NotificationResolver; use crate::notification_model::NotificationResolver;
use chrono::Utc; use chrono::Utc;
@ -21,14 +22,16 @@ use sea_orm::{
ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter, ColumnTrait, ConnectOptions, DatabaseConnection, DbErr, EntityTrait, QueryFilter,
TransactionTrait, TransactionTrait,
}; };
use serde::{Deserialize, Serialize}; use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use std::future::Future; use std::future::Future;
use strum::IntoStaticStr; use strum::IntoStaticStr;
use thiserror::Error; use thiserror::Error;
use tokio::select; use tokio::select;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::log::LevelFilter; use tracing::log::LevelFilter;
use tracing::{error, info, trace}; use tracing::{error, info, trace, warn};
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectorConfig { pub struct ConnectorConfig {
@ -353,12 +356,46 @@ impl CalckeyCache {
pub struct CalckeyCacheClient(redis::aio::Connection); pub struct CalckeyCacheClient(redis::aio::Connection);
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug)]
#[serde(tag = "channel", content = "message")]
pub enum SubMessage { pub enum SubMessage {
Internal(InternalStreamMessage), Internal(InternalStreamMessage),
#[serde(other)] MainStream(String, MainStreamMessage),
Other, Other(String, Value),
}
#[derive(Deserialize)]
struct RawMessage<'a> {
channel: &'a str,
message: Value,
}
impl<'de> Deserialize<'de> for SubMessage {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let raw = RawMessage::deserialize(deserializer)?;
Ok(match raw.channel {
"internal" => SubMessage::Internal(
InternalStreamMessage::deserialize(raw.message).map_err(Error::custom)?,
),
c if c.starts_with("mainStream") => SubMessage::MainStream(
c.strip_prefix("mainStream:")
.ok_or_else(|| Error::custom("Invalid mainStream prefix"))?
.to_string(),
MainStreamMessage::deserialize(raw.message).map_err(Error::custom)?,
),
_ => SubMessage::Other(raw.channel.to_string(), raw.message),
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", content = "body")]
#[serde(rename_all = "camelCase")]
pub enum MainStreamMessage {
Notification(IdShape),
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
@ -426,6 +463,7 @@ impl CalckeySub {
let prefix = prefix.to_string(); let prefix = prefix.to_string();
tokio::spawn(async move { tokio::spawn(async move {
trace!("Redis subscriber spawned");
let mut on_message = pub_sub.on_message(); let mut on_message = pub_sub.on_message();
while let Some(msg) = select! { while let Some(msg) = select! {
@ -433,7 +471,7 @@ impl CalckeySub {
_ = token_rx.cancelled() => { _ = token_rx.cancelled() => {
drop(on_message); drop(on_message);
if let Err(e) = pub_sub.unsubscribe(prefix).await { if let Err(e) = pub_sub.unsubscribe(prefix).await {
info!("Redis error: {:?}", e); warn!("Redis error: {:?}", e);
} }
return; return;
} }
@ -441,7 +479,7 @@ impl CalckeySub {
let data = &match msg.get_payload::<String>() { let data = &match msg.get_payload::<String>() {
Ok(val) => val, Ok(val) => val,
Err(e) => { Err(e) => {
info!("Redis error: {:?}", e); warn!("Redis error: {:?}", e);
continue; continue;
} }
}; };
@ -449,7 +487,7 @@ impl CalckeySub {
let parsed = match serde_json::from_str::<SubMessage>(data) { let parsed = match serde_json::from_str::<SubMessage>(data) {
Ok(val) => val, Ok(val) => val,
Err(e) => { Err(e) => {
info!("Message parse error: {:?}", e); warn!("Message parse error: {:?}", e);
continue; continue;
} }
}; };
@ -466,6 +504,7 @@ impl CalckeySub {
impl Drop for CalckeySub { impl Drop for CalckeySub {
fn drop(&mut self) { fn drop(&mut self) {
trace!("Redis subscriber dropped");
self.0.cancel(); self.0.cancel();
} }
} }

View File

@ -8,6 +8,7 @@ use sea_orm::{
Iden, IntoIdentity, Iterable, JoinType, QueryTrait, RelationDef, RelationTrait, Select, Iden, IntoIdentity, Iterable, JoinType, QueryTrait, RelationDef, RelationTrait, Select,
SelectModel, SelectorTrait, SelectModel, SelectorTrait,
}; };
use serde::{Deserialize, Serialize};
use std::fmt::Write; use std::fmt::Write;
#[derive(Clone)] #[derive(Clone)]
@ -332,3 +333,8 @@ pub trait ModelPagination {
fn id(&self) -> &str; fn id(&self) -> &str;
fn time(&self) -> DateTime<Utc>; fn time(&self) -> DateTime<Utc>;
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdShape {
pub id: String,
}

View File

@ -13,7 +13,7 @@ use ext_calckey_model_migration::{JoinType, SelectStatement};
use magnetar_sdk::types::SpanFilter; use magnetar_sdk::types::SpanFilter;
use sea_orm::prelude::Expr; use sea_orm::prelude::Expr;
use sea_orm::sea_query::{IntoCondition, Query}; use sea_orm::sea_query::{IntoCondition, Query};
use sea_orm::{ActiveEnum, Iden, IntoSimpleExpr, QueryTrait}; use sea_orm::{ActiveEnum, Iden, QueryTrait};
use sea_orm::{DbErr, EntityTrait, FromQueryResult, QueryFilter, QueryResult, QuerySelect}; use sea_orm::{DbErr, EntityTrait, FromQueryResult, QueryFilter, QueryResult, QuerySelect};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -130,6 +130,40 @@ impl NotificationResolver {
); );
} }
pub async fn get_single(
&self,
resolve_options: &NotificationResolveOptions,
notification_id: &str,
) -> Result<Option<NotificationData>, CalckeyDbError> {
let notification_tbl = notification::Entity.base_prefix();
let mut query = Query::select();
query.from_as(notification::Entity, notification_tbl.clone());
self.resolve(
&mut query,
&notification_tbl,
&resolve_options,
&self.note_resolver,
&self.user_resolver,
);
let mut select = notification::Entity::find();
*QuerySelect::query(&mut select) = query;
let notifications = select
.filter(
notification_tbl
.col(notification::Column::Id)
.eq(notification_id),
)
.into_model::<NotificationData>()
.one(self.db.inner())
.await?;
Ok(notifications)
}
pub async fn get( pub async fn get(
&self, &self,
resolve_options: &NotificationResolveOptions, resolve_options: &NotificationResolveOptions,

View File

@ -3,6 +3,7 @@ pub mod emoji;
pub mod instance; pub mod instance;
pub mod note; pub mod note;
pub mod notification; pub mod notification;
pub mod streaming;
pub mod timeline; pub mod timeline;
pub mod user; pub mod user;

View File

@ -0,0 +1,10 @@
use crate::types::notification::PackNotification;
use serde::{Deserialize, Serialize};
use ts_rs::TS;
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
#[serde(tag = "type", content = "body")]
#[ts(export)]
pub enum ChannelEvent {
Notification(PackNotification),
}

View File

@ -1,7 +1,9 @@
mod note; mod note;
mod streaming;
mod user; mod user;
use crate::api_v1::note::handle_note; use crate::api_v1::note::handle_note;
use crate::api_v1::streaming::handle_streaming;
use crate::api_v1::user::{ use crate::api_v1::user::{
handle_follow_requests_self, handle_followers, handle_followers_self, handle_following, handle_follow_requests_self, handle_followers, handle_followers_self, handle_following,
handle_following_self, handle_notifications, handle_user_by_id_many, handle_user_info, handle_following_self, handle_notifications, handle_user_by_id_many, handle_user_info,
@ -31,6 +33,7 @@ pub fn create_api_router(service: Arc<MagnetarService>) -> Router {
.route("/users/@self/followers", get(handle_followers_self)) .route("/users/@self/followers", get(handle_followers_self))
.route("/users/:id/followers", get(handle_followers)) .route("/users/:id/followers", get(handle_followers))
.route("/notes/:id", get(handle_note)) .route("/notes/:id", get(handle_note))
.route("/streaming", get(handle_streaming))
.layer(from_fn_with_state( .layer(from_fn_with_state(
AuthState::new(service.clone()), AuthState::new(service.clone()),
auth::auth, auth::auth,

100
src/api_v1/streaming.rs Normal file
View File

@ -0,0 +1,100 @@
use crate::model::processing::notification::NotificationModel;
use crate::model::PackingContext;
use crate::service::MagnetarService;
use crate::web::auth::AuthenticatedUser;
use crate::web::ApiError;
use axum::extract::State;
use axum::response::sse::{Event, KeepAlive};
use axum::response::Sse;
use futures::Stream;
use futures_util::StreamExt as _;
use magnetar_calckey_model::model_ext::IdShape;
use magnetar_calckey_model::{CalckeySub, MainStreamMessage, SubMessage};
use magnetar_sdk::types::streaming::ChannelEvent;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, trace, warn};
pub fn drop_on_close(sub: CalckeySub, tx: mpsc::Sender<MainStreamMessage>) {
tokio::spawn(async move {
tx.closed().await;
drop(sub);
debug!("Dropped the listener.");
});
}
pub async fn handle_streaming(
State(service): State<Arc<MagnetarService>>,
AuthenticatedUser(self_user): AuthenticatedUser,
) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, ApiError> {
trace!("SSE connection from user `{}` start", self_user.username);
let (tx, rx) = mpsc::channel(1024);
let sub_tx = tx.clone();
let sub_user_id = self_user.id.clone();
let sub = service
.cache
.conn()
.await?
.subscribe(&service.config.networking.host, move |message| {
let user_id = sub_user_id.clone();
let tx = sub_tx.clone();
async move {
let SubMessage::MainStream(id, msg) = message else {
return;
};
if id != user_id {
return;
}
if let Err(e) = tx.send(msg).await {
warn!("Failed to send stream channel message: {e}");
}
}
})
.await?;
drop_on_close(sub, tx);
let stream = ReceiverStream::new(rx).filter_map(move |m| {
let service = service.clone();
let self_user = self_user.clone();
async move {
match m {
MainStreamMessage::Notification(IdShape { id }) => {
let ctx = PackingContext::new(service, Some(self_user.clone()))
.await
.map_err(|e| {
error!("Failed to create notification packing context: {}", e);
e
})
.ok()?;
let notification_model = NotificationModel;
Some(
Event::default().json_data(ChannelEvent::Notification(
notification_model
.get_notification(&ctx, &id, &self_user.id)
.await
.map_err(|e| {
error!("Failed to fetch notification: {}", e);
e
})
.ok()
.flatten()?,
)),
)
}
}
}
});
Ok(Sse::new(stream).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(2))
.text("mag-keep-alive"),
))
}

View File

@ -211,6 +211,67 @@ impl NotificationModel {
}) })
} }
pub async fn get_notification(
&self,
ctx: &PackingContext,
notification_id: &str,
user_id: &str,
) -> PackResult<Option<PackNotification>> {
let user_resolve_options = UserResolveOptions {
with_avatar: true,
with_banner: false,
with_profile: false,
};
let self_id = ctx.self_user.as_deref().map(ck::user::Model::get_id);
let Some(notification_raw) = ctx
.service
.db
.get_notification_resolver()
.get_single(
&NotificationResolveOptions {
note_options: NoteResolveOptions {
ids: None,
visibility_filter: Arc::new(
NoteVisibilityFilterModel.new_note_visibility_filter(Some(user_id)),
),
time_range: None,
limit: None,
with_reply_target: true,
with_renote_target: true,
with_interactions_from: self_id.map(str::to_string),
only_pins_from: None,
user_options: user_resolve_options.clone(),
},
user_options: user_resolve_options,
},
notification_id,
)
.await?
else {
return Ok(None);
};
let note_model = NoteModel {
with_context: true,
attachments: false,
};
let user_model = UserModel;
let emoji_model = EmojiModel;
let notification = self
.pack_notification_single(
ctx,
&notification_raw,
&note_model,
&user_model,
&emoji_model,
)
.await?;
Ok(Some(notification))
}
pub async fn get_notifications( pub async fn get_notifications(
&self, &self,
ctx: &PackingContext, ctx: &PackingContext,

View File

@ -6,6 +6,8 @@ use magnetar_calckey_model::{CalckeyCacheError, CalckeyDbError};
use magnetar_common::util::FediverseTagParseError; use magnetar_common::util::FediverseTagParseError;
use serde::Serialize; use serde::Serialize;
use serde_json::json; use serde_json::json;
use std::fmt::{Display, Formatter};
use thiserror::Error;
pub mod auth; pub mod auth;
pub mod pagination; pub mod pagination;
@ -36,13 +38,23 @@ impl ErrorCode {
} }
} }
#[derive(Debug)] #[derive(Debug, Error)]
pub struct ApiError { pub struct ApiError {
pub status: StatusCode, pub status: StatusCode,
pub code: ErrorCode, pub code: ErrorCode,
pub message: String, pub message: String,
} }
impl Display for ApiError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"ApiError[status = \"{}\", code = \"{:?}\"]: \"{}\"",
self.status, self.code, self.message
)
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct AccessForbidden(pub String); pub struct AccessForbidden(pub String);