Initial commit
This commit is contained in:
commit
e4e8908788
|
@ -0,0 +1,4 @@
|
|||
/target
|
||||
state.msg
|
||||
.env
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "mag-markov"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
dotenvy = "0.15"
|
||||
futures = "0.3"
|
||||
indicatif = { version = "0.17", features = ["futures"] }
|
||||
sprs = { version = "0.11", features = ["serde"] }
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "chrono"] }
|
||||
rand = "0.8"
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
regex = "1.10"
|
||||
serde_json = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
rmp-serde = "1.1"
|
||||
tokio = { version = "1.36", features = ["full"] }
|
|
@ -0,0 +1,271 @@
|
|||
use std::{collections::HashMap, fs::File, path::Path, time::Duration};
|
||||
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use futures::TryStreamExt;
|
||||
use indicatif::ProgressBar;
|
||||
use rand::distributions::{Distribution, WeightedIndex};
|
||||
use regex::Regex;
|
||||
use reqwest::{header, ClientBuilder, RequestBuilder, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sprs::{CsMat, TriMat};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
|
||||
/// A node of the Markov chain.
///
/// `Start` and `Stop` are the two sentinel nodes that begin and end a generated
/// text; every other node is a `Token` identified by a numeric id. Each node maps
/// to one row/column position of the transition matrix (see [`MarkovItem::index`]).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum MarkovItem {
    /// Sentinel placed before the first token of a text (matrix index 0).
    Start,
    /// Sentinel placed after the last token of a text (matrix index 1).
    Stop,
    /// A token, identified by its id (matrix index `id + 2`).
    Token(u32),
}

impl MarkovItem {
    /// Returns the matrix index of this item.
    ///
    /// Indices 0 and 1 are reserved for `Start` and `Stop`; token ids are shifted
    /// up by two to make room for them.
    fn index(self) -> usize {
        match self {
            MarkovItem::Start => 0,
            MarkovItem::Stop => 1,
            MarkovItem::Token(idx) => idx as usize + 2,
        }
    }

    /// Returns the matrix dimension needed to hold `items` tokens plus the two
    /// sentinel nodes.
    fn required_matrix_size(items: u32) -> usize {
        items as usize + 2
    }

    /// Converts a matrix index back into the item it stands for.
    ///
    /// This is the inverse of [`MarkovItem::index`].
    ///
    /// # Panics
    ///
    /// Panics if `ind - 2` does not fit in a `u32`. Such an index cannot come from
    /// [`MarkovItem::index`], and silently truncating it would yield the wrong token.
    fn from_usize(ind: usize) -> Self {
        match ind {
            0 => MarkovItem::Start,
            1 => MarkovItem::Stop,
            // `x >= 2` in this arm, so the subtraction cannot underflow.
            x => MarkovItem::Token(
                u32::try_from(x - 2).expect("matrix index does not fit in a u32 token id"),
            ),
        }
    }
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct State {
|
||||
matrix: CsMat<u32>,
|
||||
mappings: HashMap<u32, String>,
|
||||
reverse_mappings: HashMap<String, u32>,
|
||||
fetched_up_to: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new() -> Self {
|
||||
let mut tri_mat = TriMat::new((2, 2));
|
||||
tri_mat.add_triplet(MarkovItem::Start.index(), MarkovItem::Stop.index(), 1);
|
||||
|
||||
State {
|
||||
matrix: tri_mat.to_csr(),
|
||||
mappings: HashMap::new(),
|
||||
reverse_mappings: HashMap::new(),
|
||||
fetched_up_to: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn generate(&self) -> String {
|
||||
let mut rand = rand::thread_rng();
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let mut idx = MarkovItem::Start;
|
||||
let mut tokens = Vec::new();
|
||||
loop {
|
||||
if matches!(idx, MarkovItem::Stop) {
|
||||
break;
|
||||
}
|
||||
|
||||
let weights = WeightedIndex::new(
|
||||
(0..self.matrix.cols())
|
||||
.map(|j| self.matrix.get(idx.index(), j).copied().unwrap_or_default()),
|
||||
)
|
||||
.expect("Weighted index failed to create");
|
||||
|
||||
idx = MarkovItem::from_usize(weights.sample(&mut rand));
|
||||
|
||||
if let MarkovItem::Token(ti) = idx {
|
||||
tokens.push(
|
||||
self.mappings
|
||||
.get(&ti)
|
||||
.expect("Item should be present in the mappings map")
|
||||
.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens.is_empty() || tokens.join(" ").trim().is_empty()) && attempts < 10 {
|
||||
attempts += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
break tokens.join(" ").replace('@', "@\u{200b}");
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(&mut self, (tok1, tok2): (Option<&str>, Option<&str>)) {
|
||||
let item1 = match tok1 {
|
||||
None => MarkovItem::Start,
|
||||
Some(t) => match self.reverse_mappings.get(t) {
|
||||
Some(i) => MarkovItem::Token(*i),
|
||||
None => {
|
||||
let new_idx = self.reverse_mappings.len() as u32;
|
||||
self.reverse_mappings.insert(t.to_owned(), new_idx);
|
||||
self.mappings.insert(new_idx, t.to_owned());
|
||||
|
||||
MarkovItem::Token(new_idx)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let item2 = match tok2 {
|
||||
None => MarkovItem::Stop,
|
||||
Some(t) => match self.reverse_mappings.get(t) {
|
||||
Some(i) => MarkovItem::Token(*i),
|
||||
None => {
|
||||
let new_idx = self.reverse_mappings.len() as u32;
|
||||
self.reverse_mappings.insert(t.to_owned(), new_idx);
|
||||
self.mappings.insert(new_idx, t.to_owned());
|
||||
|
||||
MarkovItem::Token(new_idx)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let mut item_value = 10;
|
||||
|
||||
if matches!(item1, MarkovItem::Start) && matches!(tok2, Some(val) if val.starts_with('@')) {
|
||||
item_value /= 3;
|
||||
}
|
||||
|
||||
if tok2.is_none() {
|
||||
item_value -= 2;
|
||||
}
|
||||
|
||||
if let Some(item) = self.matrix.get_mut(item1.index(), item2.index()) {
|
||||
*item += item_value;
|
||||
} else {
|
||||
self.matrix.insert(item1.index(), item2.index(), item_value)
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_tokens(&mut self, input: &str) {
|
||||
let input = input.trim();
|
||||
// Remove quotes
|
||||
let input = Regex::new(r"> ?.+")
|
||||
.unwrap()
|
||||
.replace_all(input, "")
|
||||
.to_string();
|
||||
|
||||
let regex = Regex::new(r"\s+").unwrap();
|
||||
|
||||
let nasty_words = vec![
|
||||
"hitler",
|
||||
"natsura",
|
||||
"eve@snug",
|
||||
"lustlion",
|
||||
"abuse",
|
||||
"pedophi",
|
||||
"murder",
|
||||
"transpho",
|
||||
"holocaust",
|
||||
"cpluspatch",
|
||||
"umbrellix",
|
||||
"sachi",
|
||||
"heonkey",
|
||||
];
|
||||
|
||||
if nasty_words.iter().any(|w| input.to_lowercase().contains(w)) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut last = None;
|
||||
for tok in regex.split(&input) {
|
||||
self.insert((last, Some(tok)));
|
||||
last = Some(tok);
|
||||
}
|
||||
self.insert((last, None));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok();
|
||||
let token = std::env::var("TOKEN").expect("token");
|
||||
|
||||
let path = Path::new("state.msg");
|
||||
|
||||
let progress = ProgressBar::new_spinner();
|
||||
|
||||
let mut state = if path.is_file() {
|
||||
let file = File::open(path).unwrap();
|
||||
progress.set_message("Loading data...");
|
||||
rmp_serde::decode::from_read::<_, State>(progress.wrap_read(file)).unwrap()
|
||||
} else {
|
||||
State::new()
|
||||
};
|
||||
progress.disable_steady_tick();
|
||||
|
||||
let pool = PgPoolOptions::new()
|
||||
.connect(std::env::var("DATABASE_URL").as_deref().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let note_time_min = state
|
||||
.fetched_up_to
|
||||
.unwrap_or(Utc.with_ymd_and_hms(2010, 1, 1, 12, 0, 0).unwrap());
|
||||
|
||||
progress.set_message("Fetching notes...");
|
||||
let text_stream = sqlx::query!(
|
||||
r#"SELECT text, "createdAt"
|
||||
FROM note
|
||||
WHERE "note"."userId" = $1
|
||||
AND "note"."createdAt" > $2
|
||||
AND "note"."visibility" IN ('public', 'home')
|
||||
AND ("note"."cw" IS NULL OR LOWER("note"."cw") IN ('', 'gay', 'cursed', 'what', 'shitpost', 'no', 'natty what', 'natty what the fuck'))"#,
|
||||
"9awy7u3l76",
|
||||
note_time_min
|
||||
)
|
||||
.fetch(&pool);
|
||||
|
||||
let mut stream = progress.wrap_stream(text_stream);
|
||||
let mut cnt = 0;
|
||||
while let Ok(Some(item)) = stream.try_next().await.map_err(|e| {
|
||||
eprintln!("Error fetching: {}", e);
|
||||
e
|
||||
}) {
|
||||
state.fetched_up_to = Some(item.createdAt);
|
||||
state.insert_tokens(item.text.as_deref().unwrap_or_default());
|
||||
progress.set_message(item.text.unwrap_or_default().clone());
|
||||
cnt += 1;
|
||||
progress.tick();
|
||||
progress.set_length(cnt);
|
||||
}
|
||||
|
||||
let file = File::create(path).unwrap();
|
||||
progress.set_message("Saving data...");
|
||||
rmp_serde::encode::write(&mut progress.wrap_write(file), &state).unwrap();
|
||||
|
||||
let mut timer = tokio::time::interval(Duration::from_secs(600));
|
||||
|
||||
loop {
|
||||
timer.tick().await;
|
||||
|
||||
let text = state.generate();
|
||||
println!("Generated: {:?}", text);
|
||||
|
||||
let client = ClientBuilder::new().https_only(true).build().unwrap();
|
||||
match client
|
||||
.post("https://astolfo.social/api/notes/create")
|
||||
.json(&serde_json::json!({
|
||||
"text": text,
|
||||
"poll": Option::<()>::None,
|
||||
"cw": "Automated Markov-Natty post",
|
||||
"visibility": "home"
|
||||
}))
|
||||
.header(header::AUTHORIZATION, &format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Err(e) => eprintln!("Fetch error: {}", e),
|
||||
Ok(r) => println!("Response: {:#?}", r.json::<Value>().await),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue