use std::{collections::HashMap, fs::File, path::Path, time::Duration};

use chrono::{DateTime, TimeZone, Utc};
use futures::TryStreamExt;
use indicatif::ProgressBar;
use rand::distributions::{Distribution, WeightedIndex};
use regex::Regex;
use reqwest::{header, ClientBuilder};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sprs::{CsMat, TriMat};
use sqlx::postgres::PgPoolOptions;

/// A node in the Markov chain: sentence start, sentence stop, or an interned
/// token id.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum MarkovItem {
    Start,
    Stop,
    Token(u32),
}

impl MarkovItem {
    /// Row/column index of this item in the transition matrix.
    /// Indices 0 and 1 are reserved for Start/Stop; token ids are offset by 2.
    fn index(self) -> usize {
        match self {
            MarkovItem::Start => 0,
            MarkovItem::Stop => 1,
            MarkovItem::Token(idx) => idx as usize + 2,
        }
    }

    /// Matrix dimension needed for `items` distinct tokens plus Start/Stop.
    // NOTE(review): currently unused in this file — possibly intended for an
    // explicit matrix resize; kept for interface compatibility.
    fn required_matrix_size(items: u32) -> usize {
        items as usize + 2
    }

    /// Inverse of [`MarkovItem::index`]. The `checked_sub(2).unwrap()` cannot
    /// fail: the catch-all arm only matches `ind >= 2`.
    fn from_usize(ind: usize) -> Self {
        match ind {
            0 => MarkovItem::Start,
            1 => MarkovItem::Stop,
            x => MarkovItem::Token(x.checked_sub(2).unwrap() as u32),
        }
    }
}

/// Persisted Markov-chain state: transition weights plus token <-> id maps.
// NOTE(review): the generic parameters below were stripped in the mangled
// source; they are reconstructed from usage (integer weights added in
// `insert`, u32 token ids, chrono timestamp in `main`). Confirm against the
// on-disk `state.msg` before deploying, since they affect msgpack layout.
#[derive(Debug, Serialize, Deserialize)]
struct State {
    /// Sparse transition-weight matrix, indexed by `MarkovItem::index()`.
    matrix: CsMat<u32>,
    /// Token id -> token text.
    mappings: HashMap<u32, String>,
    /// Token text -> token id.
    reverse_mappings: HashMap<String, u32>,
    /// Timestamp of the newest note already ingested, if any.
    fetched_up_to: Option<DateTime<Utc>>,
}

impl State {
    /// Fresh state whose only transition is Start -> Stop (weight 1), so
    /// `generate` terminates even before any notes are ingested.
    fn new() -> Self {
        let mut tri_mat = TriMat::new((2, 2));
        tri_mat.add_triplet(MarkovItem::Start.index(), MarkovItem::Stop.index(), 1);
        State {
            matrix: tri_mat.to_csr(),
            mappings: HashMap::new(),
            reverse_mappings: HashMap::new(),
            fetched_up_to: None,
        }
    }

    /// Random-walk the chain from Start to Stop and join the visited tokens
    /// with spaces. Retries up to 10 times if the walk yields a blank string.
    fn generate(&self) -> String {
        let mut rand = rand::thread_rng();
        let mut attempts = 0;
        loop {
            let mut idx = MarkovItem::Start;
            let mut tokens = Vec::new();
            loop {
                if matches!(idx, MarkovItem::Stop) {
                    break;
                }
                // Sample the next item proportionally to the weights stored in
                // the current item's matrix row (absent entries weigh 0).
                let weights = WeightedIndex::new(
                    (0..self.matrix.cols())
                        .map(|j| self.matrix.get(idx.index(), j).copied().unwrap_or_default()),
                )
                .expect("Weighted index failed to create");
                idx = MarkovItem::from_usize(weights.sample(&mut rand));
                if let MarkovItem::Token(ti) = idx {
                    tokens.push(
                        self.mappings
                            .get(&ti)
                            .expect("Item should be present in the mappings map")
                            .clone(),
                    )
                }
            }
            if (tokens.is_empty() || tokens.join(" ").trim().is_empty()) && attempts < 10 {
                attempts += 1;
                continue;
            }
            // Zero-width space after '@' defuses accidental fediverse mentions.
            break tokens.join(" ").replace('@', "@\u{200b}");
        }
    }

    /// Intern token text `t`, assigning the next sequential id if unseen, and
    /// return it as a `MarkovItem::Token`. (Extracted from `insert`, where the
    /// same logic was duplicated for both halves of the bigram.)
    fn intern(&mut self, t: &str) -> MarkovItem {
        match self.reverse_mappings.get(t) {
            Some(i) => MarkovItem::Token(*i),
            None => {
                // `len()` before insertion is the next free sequential id.
                let new_idx = self.reverse_mappings.len() as u32;
                self.reverse_mappings.insert(t.to_owned(), new_idx);
                self.mappings.insert(new_idx, t.to_owned());
                MarkovItem::Token(new_idx)
            }
        }
    }

    /// Record one observed bigram. `None` in the first slot stands for
    /// sentence Start; `None` in the second slot stands for Stop.
    fn insert(&mut self, (tok1, tok2): (Option<&str>, Option<&str>)) {
        let item1 = match tok1 {
            None => MarkovItem::Start,
            Some(t) => self.intern(t),
        };
        let item2 = match tok2 {
            None => MarkovItem::Stop,
            Some(t) => self.intern(t),
        };
        let mut item_value = 10;
        // Down-weight posts that open with a mention (10 / 3 == 3)...
        if matches!(item1, MarkovItem::Start) && matches!(tok2, Some(val) if val.starts_with('@')) {
            item_value /= 3;
        }
        // ...and slightly de-emphasize end-of-post transitions.
        if tok2.is_none() {
            item_value -= 2;
        }
        if let Some(item) = self.matrix.get_mut(item1.index(), item2.index()) {
            *item += item_value;
        } else {
            // NOTE(review): relies on sprs `CsMat::insert` growing the matrix
            // for out-of-bounds indices (new tokens index past the current
            // shape) — confirm against the pinned sprs version.
            self.matrix.insert(item1.index(), item2.index(), item_value)
        }
    }

    /// Tokenize `input` on whitespace and feed all of its bigrams into the
    /// chain. Strips quoted lines, recurses per paragraph, and drops any
    /// input containing a blocklisted word.
    fn insert_tokens(&mut self, input: &str) {
        let input = input.trim();
        // Remove quotes
        let input = Regex::new(r"> ?.+")
            .unwrap()
            .replace_all(input, "")
            .to_string();
        if input.contains("\n\n") {
            let paragraph_split = Regex::new("\n\n+").unwrap();
            paragraph_split
                .split(&input)
                .for_each(|p| self.insert_tokens(p));
            return;
        }
        let regex = Regex::new(r"\s+").unwrap();
        let nasty_words = [
            "hitler",
            "natsura",
            "eve@snug",
            "lustlion",
            "abuse",
            "pedophi",
            "murder",
            "transpho",
            "holocaust",
            "cpluspatch",
            "umbrellix",
            "sachi",
            "heonkey",
        ];
        // Lowercase once, not once per blocklist entry.
        let lowered = input.to_lowercase();
        if nasty_words.iter().any(|w| lowered.contains(w)) {
            return;
        }
        let mut last = None;
        for tok in regex.split(&input) {
            self.insert((last, Some(tok)));
            last = Some(tok);
        }
        self.insert((last, None));
    }
}

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok();
    let token = std::env::var("TOKEN").expect("token");
    let path =
Path::new("state.msg");
    let progress = ProgressBar::new_spinner();

    // Load previously-saved state (msgpack) if present, else start fresh.
    let mut state = if path.is_file() {
        let file = File::open(path).unwrap();
        progress.set_message("Loading data...");
        rmp_serde::decode::from_read::<_, State>(progress.wrap_read(file)).unwrap()
    } else {
        State::new()
    };

    let pool = PgPoolOptions::new()
        .connect(std::env::var("DATABASE_URL").as_deref().unwrap())
        .await
        .unwrap();

    // Resume from the newest note already ingested; the 2010 epoch is just
    // "older than any note".
    let note_time_min = state
        .fetched_up_to
        .unwrap_or(Utc.with_ymd_and_hms(2010, 1, 1, 12, 0, 0).unwrap());
    progress.set_message("Fetching notes...");
    let text_stream = sqlx::query!(
        r#"SELECT text, "createdAt" FROM note WHERE "note"."userId" = $1 AND "note"."createdAt" > $2 AND "note"."createdAt" < NOW() AND "note"."visibility" IN ('public', 'home') AND ("note"."cw" IS NULL OR LOWER("note"."cw") IN ('', 'gay', 'cursed', 'what', 'shitpost', 'no', 'natty what', 'natty what the fuck'))"#,
        "9awy7u3l76",
        note_time_min
    )
    .fetch(&pool);
    let mut stream = progress.wrap_stream(text_stream);
    let mut cnt = 0;
    // Stops on the first fetch error (after logging it) or at end of stream.
    while let Ok(Some(item)) = stream.try_next().await.map_err(|e| {
        eprintln!("Error fetching: {}", e);
        e
    }) {
        state.fetched_up_to = Some(item.createdAt);
        state.insert_tokens(item.text.as_deref().unwrap_or_default());
        // `unwrap_or_default()` already yields an owned String — no clone needed.
        progress.set_message(item.text.unwrap_or_default());
        cnt += 1;
        progress.tick();
        progress.set_length(cnt);
    }
    drop(stream);
    drop(pool);
    println!("Shape: {:?}", state.matrix.shape());

    let file = File::create(path).unwrap();
    progress.set_message("Saving data...");
    rmp_serde::encode::write(&mut progress.wrap_write(file), &state).unwrap();

    // Post a generated note every 10 minutes, forever.
    let mut timer = tokio::time::interval(Duration::from_secs(600));
    // Build the HTTP client once instead of once per iteration.
    let client = ClientBuilder::new().https_only(true).build().unwrap();
    loop {
        timer.tick().await;
        let text = state.generate();
        println!("Generated: {:?}", text);
        match client
            .post("https://astolfo.social/api/notes/create")
            .json(&serde_json::json!({
                "text": text,
                "poll": Option::<()>::None,
                "cw": "Automated Markov-Natty post",
                "visibility": "home"
            }))
            .header(header::AUTHORIZATION, &format!("Bearer {}", token))
            .send()
            .await
        {
            Err(e) => eprintln!("Fetch error: {}", e),
            // The stripped turbofish was almost certainly `Value` — it is the
            // only use of the `serde_json::Value` import.
            Ok(r) => println!("Response: {:#?}", r.json::<Value>().await),
        }
    }
}