//! Markov-chain posting bot: ingests notes for one user from a Postgres
//! `note` table, trains a token-bigram transition matrix (persisted to
//! `state.msg` via MessagePack), and periodically posts generated text to
//! astolfo.social's `notes/create` API.
use std::{
    collections::HashMap, convert::TryFrom, fs::File, path::Path, sync::OnceLock, time::Duration,
};

use chrono::{DateTime, TimeZone, Utc};
use futures::TryStreamExt;
use indicatif::ProgressBar;
use rand::distributions::{Distribution, WeightedIndex};
use regex::Regex;
use reqwest::{header, ClientBuilder, RequestBuilder, Response};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sprs::{CsMat, TriMat};
use sqlx::postgres::PgPoolOptions;
/// A state in the Markov chain: either a virtual message boundary or a
/// vocabulary token identified by its id in the `State` mapping tables.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum MarkovItem {
    /// Virtual start-of-message state (matrix row/column 0).
    Start,
    /// Virtual end-of-message state (matrix row/column 1).
    Stop,
    /// A real token; the `u32` is its id in the token mappings.
    Token(u32),
}

impl MarkovItem {
    /// Maps this item to its row/column index in the transition matrix.
    /// `Start` and `Stop` occupy indices 0 and 1; token ids are offset by 2.
    fn index(self) -> usize {
        match self {
            MarkovItem::Start => 0,
            MarkovItem::Stop => 1,
            MarkovItem::Token(idx) => idx as usize + 2,
        }
    }

    /// Side length required for a square transition matrix holding `items`
    /// distinct tokens plus the two virtual Start/Stop states.
    fn required_matrix_size(items: u32) -> usize {
        items as usize + 2
    }

    /// Inverse of [`MarkovItem::index`]: recovers the item from a matrix index.
    ///
    /// # Panics
    /// Panics if `ind - 2` does not fit in a `u32` — previously this was a
    /// silently-truncating `as u32` cast, which on 64-bit targets would wrap
    /// to a wrong token id instead of failing loudly.
    fn from_usize(ind: usize) -> Self {
        match ind {
            0 => MarkovItem::Start,
            1 => MarkovItem::Stop,
            // `ind >= 2` in this arm, so the subtraction cannot underflow;
            // the narrowing to u32 is checked rather than truncating.
            x => MarkovItem::Token(
                u32::try_from(x - 2).expect("matrix index exceeds u32 token range"),
            ),
        }
    }
}
/// Serializable Markov model plus bookkeeping for incremental fetching.
/// Persisted to disk with MessagePack between runs.
#[derive(Debug, Serialize, Deserialize)]
struct State {
    // Sparse transition-count matrix; entry (i, j) is the accumulated weight
    // of moving from the MarkovItem with index i to the one with index j.
    matrix: CsMat<u32>,
    // Token id -> token text.
    mappings: HashMap<u32, String>,
    // Token text -> token id (inverse of `mappings`).
    reverse_mappings: HashMap<String, u32>,
    // `createdAt` of the newest note already ingested; `None` before the
    // first fetch, in which case a fixed 2010 epoch is used as the floor.
    fetched_up_to: Option<DateTime<Utc>>,
}
impl State {
|
||
|
fn new() -> Self {
|
||
|
let mut tri_mat = TriMat::new((2, 2));
|
||
|
tri_mat.add_triplet(MarkovItem::Start.index(), MarkovItem::Stop.index(), 1);
|
||
|
|
||
|
State {
|
||
|
matrix: tri_mat.to_csr(),
|
||
|
mappings: HashMap::new(),
|
||
|
reverse_mappings: HashMap::new(),
|
||
|
fetched_up_to: None,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn generate(&self) -> String {
|
||
|
let mut rand = rand::thread_rng();
|
||
|
let mut attempts = 0;
|
||
|
|
||
|
loop {
|
||
|
let mut idx = MarkovItem::Start;
|
||
|
let mut tokens = Vec::new();
|
||
|
loop {
|
||
|
if matches!(idx, MarkovItem::Stop) {
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
let weights = WeightedIndex::new(
|
||
|
(0..self.matrix.cols())
|
||
|
.map(|j| self.matrix.get(idx.index(), j).copied().unwrap_or_default()),
|
||
|
)
|
||
|
.expect("Weighted index failed to create");
|
||
|
|
||
|
idx = MarkovItem::from_usize(weights.sample(&mut rand));
|
||
|
|
||
|
if let MarkovItem::Token(ti) = idx {
|
||
|
tokens.push(
|
||
|
self.mappings
|
||
|
.get(&ti)
|
||
|
.expect("Item should be present in the mappings map")
|
||
|
.clone(),
|
||
|
)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (tokens.is_empty() || tokens.join(" ").trim().is_empty()) && attempts < 10 {
|
||
|
attempts += 1;
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
break tokens.join(" ").replace('@', "@\u{200b}");
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn insert(&mut self, (tok1, tok2): (Option<&str>, Option<&str>)) {
|
||
|
let item1 = match tok1 {
|
||
|
None => MarkovItem::Start,
|
||
|
Some(t) => match self.reverse_mappings.get(t) {
|
||
|
Some(i) => MarkovItem::Token(*i),
|
||
|
None => {
|
||
|
let new_idx = self.reverse_mappings.len() as u32;
|
||
|
self.reverse_mappings.insert(t.to_owned(), new_idx);
|
||
|
self.mappings.insert(new_idx, t.to_owned());
|
||
|
|
||
|
MarkovItem::Token(new_idx)
|
||
|
}
|
||
|
},
|
||
|
};
|
||
|
|
||
|
let item2 = match tok2 {
|
||
|
None => MarkovItem::Stop,
|
||
|
Some(t) => match self.reverse_mappings.get(t) {
|
||
|
Some(i) => MarkovItem::Token(*i),
|
||
|
None => {
|
||
|
let new_idx = self.reverse_mappings.len() as u32;
|
||
|
self.reverse_mappings.insert(t.to_owned(), new_idx);
|
||
|
self.mappings.insert(new_idx, t.to_owned());
|
||
|
|
||
|
MarkovItem::Token(new_idx)
|
||
|
}
|
||
|
},
|
||
|
};
|
||
|
|
||
|
let mut item_value = 10;
|
||
|
|
||
|
if matches!(item1, MarkovItem::Start) && matches!(tok2, Some(val) if val.starts_with('@')) {
|
||
|
item_value /= 3;
|
||
|
}
|
||
|
|
||
|
if tok2.is_none() {
|
||
|
item_value -= 2;
|
||
|
}
|
||
|
|
||
|
if let Some(item) = self.matrix.get_mut(item1.index(), item2.index()) {
|
||
|
*item += item_value;
|
||
|
} else {
|
||
|
self.matrix.insert(item1.index(), item2.index(), item_value)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn insert_tokens(&mut self, input: &str) {
|
||
|
let input = input.trim();
|
||
|
// Remove quotes
|
||
|
let input = Regex::new(r"> ?.+")
|
||
|
.unwrap()
|
||
|
.replace_all(input, "")
|
||
|
.to_string();
|
||
|
|
||
|
let regex = Regex::new(r"\s+").unwrap();
|
||
|
|
||
|
let nasty_words = vec![
|
||
|
"hitler",
|
||
|
"natsura",
|
||
|
"eve@snug",
|
||
|
"lustlion",
|
||
|
"abuse",
|
||
|
"pedophi",
|
||
|
"murder",
|
||
|
"transpho",
|
||
|
"holocaust",
|
||
|
"cpluspatch",
|
||
|
"umbrellix",
|
||
|
"sachi",
|
||
|
"heonkey",
|
||
|
];
|
||
|
|
||
|
if nasty_words.iter().any(|w| input.to_lowercase().contains(w)) {
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
let mut last = None;
|
||
|
for tok in regex.split(&input) {
|
||
|
self.insert((last, Some(tok)));
|
||
|
last = Some(tok);
|
||
|
}
|
||
|
self.insert((last, None));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[tokio::main]
|
||
|
async fn main() {
|
||
|
dotenvy::dotenv().ok();
|
||
|
let token = std::env::var("TOKEN").expect("token");
|
||
|
|
||
|
let path = Path::new("state.msg");
|
||
|
|
||
|
let progress = ProgressBar::new_spinner();
|
||
|
|
||
|
let mut state = if path.is_file() {
|
||
|
let file = File::open(path).unwrap();
|
||
|
progress.set_message("Loading data...");
|
||
|
rmp_serde::decode::from_read::<_, State>(progress.wrap_read(file)).unwrap()
|
||
|
} else {
|
||
|
State::new()
|
||
|
};
|
||
|
progress.disable_steady_tick();
|
||
|
|
||
|
let pool = PgPoolOptions::new()
|
||
|
.connect(std::env::var("DATABASE_URL").as_deref().unwrap())
|
||
|
.await
|
||
|
.unwrap();
|
||
|
|
||
|
let note_time_min = state
|
||
|
.fetched_up_to
|
||
|
.unwrap_or(Utc.with_ymd_and_hms(2010, 1, 1, 12, 0, 0).unwrap());
|
||
|
|
||
|
progress.set_message("Fetching notes...");
|
||
|
let text_stream = sqlx::query!(
|
||
|
r#"SELECT text, "createdAt"
|
||
|
FROM note
|
||
|
WHERE "note"."userId" = $1
|
||
|
AND "note"."createdAt" > $2
|
||
|
AND "note"."visibility" IN ('public', 'home')
|
||
|
AND ("note"."cw" IS NULL OR LOWER("note"."cw") IN ('', 'gay', 'cursed', 'what', 'shitpost', 'no', 'natty what', 'natty what the fuck'))"#,
|
||
|
"9awy7u3l76",
|
||
|
note_time_min
|
||
|
)
|
||
|
.fetch(&pool);
|
||
|
|
||
|
let mut stream = progress.wrap_stream(text_stream);
|
||
|
let mut cnt = 0;
|
||
|
while let Ok(Some(item)) = stream.try_next().await.map_err(|e| {
|
||
|
eprintln!("Error fetching: {}", e);
|
||
|
e
|
||
|
}) {
|
||
|
state.fetched_up_to = Some(item.createdAt);
|
||
|
state.insert_tokens(item.text.as_deref().unwrap_or_default());
|
||
|
progress.set_message(item.text.unwrap_or_default().clone());
|
||
|
cnt += 1;
|
||
|
progress.tick();
|
||
|
progress.set_length(cnt);
|
||
|
}
|
||
|
|
||
|
let file = File::create(path).unwrap();
|
||
|
progress.set_message("Saving data...");
|
||
|
rmp_serde::encode::write(&mut progress.wrap_write(file), &state).unwrap();
|
||
|
|
||
|
let mut timer = tokio::time::interval(Duration::from_secs(600));
|
||
|
|
||
|
loop {
|
||
|
timer.tick().await;
|
||
|
|
||
|
let text = state.generate();
|
||
|
println!("Generated: {:?}", text);
|
||
|
|
||
|
let client = ClientBuilder::new().https_only(true).build().unwrap();
|
||
|
match client
|
||
|
.post("https://astolfo.social/api/notes/create")
|
||
|
.json(&serde_json::json!({
|
||
|
"text": text,
|
||
|
"poll": Option::<()>::None,
|
||
|
"cw": "Automated Markov-Natty post",
|
||
|
"visibility": "home"
|
||
|
}))
|
||
|
.header(header::AUTHORIZATION, &format!("Bearer {}", token))
|
||
|
.send()
|
||
|
.await
|
||
|
{
|
||
|
Err(e) => eprintln!("Fetch error: {}", e),
|
||
|
Ok(r) => println!("Response: {:#?}", r.json::<Value>().await),
|
||
|
}
|
||
|
}
|
||
|
}
|