// mag-markov/src/main.rs
//
// (Captured from a repository web view: 272 lines, 7.9 KiB, Rust;
// last modified 2024-03-28 20:09:45 +00:00.)
use std::{collections::HashMap, fs::File, path::Path, time::Duration};
use chrono::{DateTime, TimeZone, Utc};
use futures::TryStreamExt;
use indicatif::ProgressBar;
use rand::distributions::{Distribution, WeightedIndex};
use regex::Regex;
use reqwest::{header, ClientBuilder, RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sprs::{CsMat, TriMat};
use sqlx::postgres::PgPoolOptions;
/// A node of the Markov chain: the virtual sentence start, the virtual
/// sentence end, or a concrete token referenced by its numeric id.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum MarkovItem {
    /// Virtual node every walk begins from.
    Start,
    /// Virtual node that ends a walk.
    Stop,
    /// A token, identified by its id in the token mapping tables.
    Token(u32),
}
impl MarkovItem {
    /// Position of this item in the transition matrix:
    /// `Start` -> 0, `Stop` -> 1, `Token(n)` -> `n + 2`.
    fn index(self) -> usize {
        // The two sentinels occupy the first two slots; tokens follow them.
        const TOKEN_BASE: usize = 2;
        match self {
            Self::Start => 0,
            Self::Stop => 1,
            Self::Token(id) => TOKEN_BASE + id as usize,
        }
    }
    /// Side length a square matrix needs to hold both sentinels plus
    /// `items` tokens (ids `0..items`).
    fn required_matrix_size(items: u32) -> usize {
        Self::Token(items).index()
    }
    /// Inverse of [`MarkovItem::index`]: turns a matrix position back into an item.
    fn from_usize(ind: usize) -> Self {
        match ind.checked_sub(2) {
            Some(offset) => Self::Token(offset as u32),
            None if ind == 0 => Self::Start,
            None => Self::Stop,
        }
    }
}
/// Persisted state of the Markov chain: transition weights plus the token
/// dictionary and the note-fetch watermark.
///
/// NOTE(review): `main` saves this with `rmp_serde::encode::write`, which by
/// default encodes structs positionally — reordering or renaming fields likely
/// breaks loading existing `state.msg` files; verify before changing.
#[derive(Debug, Serialize, Deserialize)]
struct State {
    /// Transition weights: entry (`from.index()`, `to.index()`) accumulates the
    /// weight of every observed `from -> to` transition (see `MarkovItem::index`).
    matrix: CsMat<u32>,
    /// Token id -> token text.
    mappings: HashMap<u32, String>,
    /// Token text -> token id (inverse of `mappings`).
    reverse_mappings: HashMap<String, u32>,
    /// `createdAt` watermark of ingested notes; `main` only fetches notes created
    /// strictly after it. `None` until a note has been ingested.
    fetched_up_to: Option<DateTime<Utc>>,
}
impl State {
    /// Creates an empty chain.
    ///
    /// The matrix starts as 2x2 with a single `Start -> Stop` transition of
    /// weight 1, so `generate` always has a successor to sample from `Start`
    /// (yielding an empty string) even before any text has been ingested.
    fn new() -> Self {
        let mut tri_mat = TriMat::new((2, 2));
        tri_mat.add_triplet(MarkovItem::Start.index(), MarkovItem::Stop.index(), 1);
        State {
            matrix: tri_mat.to_csr(),
            mappings: HashMap::new(),
            reverse_mappings: HashMap::new(),
            fetched_up_to: None,
        }
    }

    /// Generates one sentence by a random walk from `Start` until `Stop`,
    /// choosing each successor with probability proportional to its weight.
    ///
    /// A walk whose joined text is empty or whitespace-only is retried, up to
    /// 10 extra attempts; after that the (possibly empty) text is returned.
    /// Every `@` in the result is followed by a zero-width space (U+200B),
    /// presumably so the generated text does not produce real mentions.
    ///
    /// # Panics
    /// Panics if the walk reaches an item whose row has no positive weight, or
    /// samples a token id missing from `mappings`; both mean corrupt state.
    fn generate(&self) -> String {
        let mut rng = rand::thread_rng();
        let mut attempts = 0;
        loop {
            let mut idx = MarkovItem::Start;
            let mut tokens: Vec<&str> = Vec::new();
            while !matches!(idx, MarkovItem::Stop) {
                let row = idx.index();
                // Dense weight vector for this row; absent entries count as 0.
                let weights = WeightedIndex::new(
                    (0..self.matrix.cols())
                        .map(|j| self.matrix.get(row, j).copied().unwrap_or_default()),
                )
                .expect("Weighted index failed to create");
                idx = MarkovItem::from_usize(weights.sample(&mut rng));
                if let MarkovItem::Token(ti) = idx {
                    tokens.push(
                        self.mappings
                            .get(&ti)
                            .expect("Item should be present in the mappings map")
                            .as_str(),
                    );
                }
            }
            let text = tokens.join(" ");
            if text.trim().is_empty() && attempts < 10 {
                attempts += 1;
                continue;
            }
            break text.replace('@', "@\u{200b}");
        }
    }

    /// Returns the id for `token`, assigning the next free id (and recording
    /// both directions of the mapping) the first time the token is seen.
    ///
    /// # Panics
    /// Panics if more than `u32::MAX` distinct tokens are interned, rather than
    /// silently wrapping and aliasing an existing id.
    fn intern(&mut self, token: &str) -> u32 {
        if let Some(&id) = self.reverse_mappings.get(token) {
            return id;
        }
        let id = u32::try_from(self.reverse_mappings.len())
            .expect("token id space (u32) exhausted");
        self.reverse_mappings.insert(token.to_owned(), id);
        self.mappings.insert(id, token.to_owned());
        id
    }

    /// Records one observed transition `tok1 -> tok2` in the matrix.
    ///
    /// `tok1 == None` means the sentence start and `tok2 == None` the sentence
    /// end. Unseen tokens get fresh ids (`tok1` is interned before `tok2`).
    /// A transition normally adds 10 to its weight; `Start -> @...` adds
    /// only 3, and any transition into `Stop` adds 8.
    fn insert(&mut self, (tok1, tok2): (Option<&str>, Option<&str>)) {
        let item1 = match tok1 {
            None => MarkovItem::Start,
            Some(t) => MarkovItem::Token(self.intern(t)),
        };
        let item2 = match tok2 {
            None => MarkovItem::Stop,
            Some(t) => MarkovItem::Token(self.intern(t)),
        };

        const BASE_WEIGHT: u32 = 10;
        let weight = match (item1, tok2) {
            // Make sentences much less likely to open with a mention.
            (MarkovItem::Start, Some(t)) if t.starts_with('@') => BASE_WEIGHT / 3,
            // Transitions into `Stop` get a slightly smaller weight.
            (_, None) => BASE_WEIGHT - 2,
            _ => BASE_WEIGHT,
        };

        let (row, col) = (item1.index(), item2.index());
        match self.matrix.get_mut(row, col) {
            Some(w) => *w += weight,
            // Relies on sprs growing the matrix shape for out-of-range indices:
            // the matrix starts at 2x2 and new tokens extend it here.
            None => self.matrix.insert(row, col, weight),
        }
    }

    /// Learns the sentence `Start -> t1 -> ... -> tn -> Stop` from `input`.
    ///
    /// Quoted text (from a `>` to the end of its line) is stripped first. The
    /// input is then ignored entirely if it contains any blocklisted substring
    /// (case-insensitive), or if no whitespace-separated tokens remain.
    fn insert_tokens(&mut self, input: &str) {
        // Notes containing any of these substrings are never learned from.
        const NASTY_WORDS: &[&str] = &[
            "hitler",
            "natsura",
            "eve@snug",
            "lustlion",
            "abuse",
            "pedophi",
            "murder",
            "transpho",
            "holocaust",
            "cpluspatch",
            "umbrellix",
            "sachi",
            "heonkey",
        ];

        // Remove quotes
        let input = Regex::new(r"> ?.+")
            .unwrap()
            .replace_all(input.trim(), "");

        let lowered = input.to_lowercase();
        if NASTY_WORDS.iter().any(|w| lowered.contains(w)) {
            return;
        }

        // `split_whitespace` never yields empty pieces. Splitting on `\s+`
        // produced "" tokens from whitespace left behind by quote removal
        // (e.g. "hi\n> q" -> ["hi", ""]) and learned empty notes as "" sentences.
        let mut tokens = input.split_whitespace().peekable();
        if tokens.peek().is_none() {
            return;
        }

        let mut last = None;
        for tok in tokens {
            self.insert((last, Some(tok)));
            last = Some(tok);
        }
        self.insert((last, None));
    }
}
/// Entry point. It loads or creates the chain, ingests new notes from Postgres
/// and saves the chain. It then posts a generated note every 600 seconds.
///
/// Configuration is read from the environment, after loading `.env` if present:
/// - `TOKEN`: the bearer token used when posting.
/// - `DATABASE_URL`: the Postgres connection string.
///
/// The chain is persisted in `state.msg` in the working directory.
///
/// # Panics
/// Panics in these cases:
/// - a required environment variable is missing;
/// - the state file cannot be read, decoded or written;
/// - the database connection fails;
/// - the HTTP client cannot be built.
#[tokio::main]
async fn main() {
    use std::io::{BufReader, BufWriter, Write};

    dotenvy::dotenv().ok();
    let token = std::env::var("TOKEN").expect("token");
    let path = Path::new("state.msg");
    let progress = ProgressBar::new_spinner();
    let mut state = if path.is_file() {
        let file = File::open(path).unwrap();
        progress.set_message("Loading data...");
        // rmp_serde reads directly from the reader it is given; buffer it so
        // decoding does not turn into many tiny reads on the raw file.
        rmp_serde::decode::from_read::<_, State>(BufReader::new(progress.wrap_read(file)))
            .unwrap()
    } else {
        State::new()
    };
    progress.disable_steady_tick();
    let pool = PgPoolOptions::new()
        .connect(std::env::var("DATABASE_URL").as_deref().unwrap())
        .await
        .unwrap();
    let note_time_min = state
        .fetched_up_to
        .unwrap_or(Utc.with_ymd_and_hms(2010, 1, 1, 12, 0, 0).unwrap());
    progress.set_message("Fetching notes...");
    let text_stream = sqlx::query!(
        r#"SELECT text, "createdAt"
FROM note
WHERE "note"."userId" = $1
AND "note"."createdAt" > $2
AND "note"."visibility" IN ('public', 'home')
AND ("note"."cw" IS NULL OR LOWER("note"."cw") IN ('', 'gay', 'cursed', 'what', 'shitpost', 'no', 'natty what', 'natty what the fuck'))"#,
        "9awy7u3l76",
        note_time_min
    )
    .fetch(&pool);
    let mut stream = progress.wrap_stream(text_stream);
    let mut cnt = 0;
    // Stop on the first fetch error; whatever was ingested so far is still saved.
    while let Ok(Some(item)) = stream.try_next().await.map_err(|e| {
        eprintln!("Error fetching: {}", e);
        e
    }) {
        // The query has no ORDER BY, so rows can arrive in any order. Keep the
        // newest timestamp seen, not the last one streamed; otherwise the next
        // run would re-ingest (and double-count) notes newer than that row.
        // Adding ORDER BY to the SQL would be the fuller fix, but `query!` is
        // checked against the database at build time, so the SQL is left as is.
        state.fetched_up_to = state.fetched_up_to.max(Some(item.createdAt));
        let text = item.text.unwrap_or_default();
        state.insert_tokens(&text);
        progress.set_message(text);
        cnt += 1;
        progress.tick();
        progress.set_length(cnt);
    }

    progress.set_message("Saving data...");
    // Write a sibling temp file and move it over `state.msg` only after the
    // encode and flush succeed. A failure mid-write then cannot leave a
    // truncated state file that fails to load on the next start.
    let tmp_path = path.with_extension("msg.tmp");
    {
        let file = File::create(&tmp_path).unwrap();
        let mut writer = BufWriter::new(progress.wrap_write(file));
        rmp_serde::encode::write(&mut writer, &state).unwrap();
        // Flush explicitly: dropping a BufWriter swallows write errors.
        writer.flush().unwrap();
    }
    std::fs::rename(&tmp_path, path).unwrap();

    // Build the HTTP client once and reuse it for every post.
    let client = ClientBuilder::new().https_only(true).build().unwrap();
    let mut timer = tokio::time::interval(Duration::from_secs(600));
    loop {
        // The first tick of a tokio interval completes immediately.
        timer.tick().await;
        let text = state.generate();
        println!("Generated: {:?}", text);
        match client
            .post("https://astolfo.social/api/notes/create")
            .json(&serde_json::json!({
                "text": text,
                "poll": Option::<()>::None,
                "cw": "Automated Markov-Natty post",
                "visibility": "home"
            }))
            .header(header::AUTHORIZATION, &format!("Bearer {}", token))
            .send()
            .await
        {
            Err(e) => eprintln!("Fetch error: {}", e),
            Ok(r) => println!("Response: {:#?}", r.json::<Value>().await),
        }
    }
}