diff --git a/.gitignore b/.gitignore index 6f1325f..58d1889 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ /target -state.msg +/state*.msg .env - +/config*.toml diff --git a/Cargo.lock b/Cargo.lock index 221750d..ca108bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,6 +71,54 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" + +[[package]] +name = "anstyle-parse" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "approx" version = "0.3.2" @@ -80,6 +128,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + [[package]] name = "atoi" version = "2.0.0" @@ -191,6 +245,52 @@ dependencies = [ "windows-targets 0.52.4", ] +[[package]] +name = "clap" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.55", +] + +[[package]] +name = "clap_lex" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + [[package]] name = "console" version = "0.15.8" @@ -607,6 +707,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.3.9" @@ -797,6 +903,7 @@ dependencies = [ "number_prefix", "portable-atomic", "unicode-width", + "vt100", ] [[package]] @@ -897,18 +1004,36 @@ name = "mag-markov" version = "0.1.0" dependencies = [ "chrono", + "clap", "dotenvy", "futures", "indicatif", + "itertools", + "qp-trie", "rand", "regex", "reqwest", "rmp-serde", "serde", "serde_json", + "smallvec", "sprs", "sqlx", "tokio", + "toml", + "tracing", + "tracing-indicatif", + "tracing-subscriber", + "unicode-segmentation", +] + +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", ] [[package]] @@ -1000,6 +1125,12 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "nom" version = "7.1.3" @@ -1010,6 +1141,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -1151,6 +1292,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1275,6 +1422,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "qp-trie" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ec628a7d1fc2c5f5a551eb34e01e08df62d55203640959a79a9a2859c797a97" +dependencies = [ + "new_debug_unreachable", + "serde", + "unreachable", +] + [[package]] name = "quote" version = "1.0.35" @@ -1357,8 +1515,17 @@ checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.6", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -1369,9 +1536,15 @@ checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.3", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.3" @@ -1565,6 +1738,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1599,6 +1781,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -1632,6 +1823,9 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +dependencies = [ + "serde", +] [[package]] name = "socket2" @@ -1769,7 +1963,7 @@ checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" dependencies = [ "dotenvy", "either", - "heck", + "heck 0.4.1", "hex", "once_cell", "proc-macro2", @@ -1904,6 +2098,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" + [[package]] name = "subtle" version = "2.5.0" @@ -1991,6 +2191,16 @@ dependencies = [ "syn 2.0.55", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -2071,6 +2281,40 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tower" version = "0.4.13" @@ -2129,6 +2373,48 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-indicatif" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069580424efe11d97c3fef4197fa98c004fa26672cc71ad8770d224e23b1951d" +dependencies = [ + "indicatif", + "tracing", + "tracing-core", + "tracing-subscriber", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -2182,6 +2468,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + [[package]] name = "url" version = "2.5.0" @@ -2199,6 +2494,18 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" @@ -2211,6 +2518,45 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "vt100" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84cd863bf0db7e392ba3bd04994be3473491b31e66340672af5d11943c6274de" +dependencies = [ + "itoa", + "log", + "unicode-width", + "vte", +] + +[[package]] +name = "vte" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5022b5fbf9407086c180e9557be968742d839e68346af7792b8592489732197" +dependencies = [ + "arrayvec", + "utf8parse", + "vte_generate_state_changes", +] + +[[package]] +name = "vte_generate_state_changes" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d257817081c7dffcdbab24b9e62d2def62e2ff7d00b1c20062551e6cccc145ff" +dependencies = [ + "proc-macro2", + "quote", +] + [[package]] name = "want" version = "0.3.1" @@ -2318,6 +2664,28 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.52.0" @@ -2459,6 +2827,15 @@ version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +[[package]] +name = "winnow" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.50.0" diff --git a/Cargo.toml b/Cargo.toml index eb87f72..dd56074 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,15 +5,24 @@ edition = "2021" [dependencies] chrono = { version = "0.4", features = ["serde"] } +clap = { version = "4.5", features = ["derive"] } dotenvy = "0.15" futures = "0.3" indicatif = { version = "0.17", features = ["futures"] } +itertools = "0.12" +qp-trie = { version = "0.8", features = ["serde"] } +rand = "0.8" +regex = "1.10" +reqwest = { version = "0.12", features = ["json"] } +rmp-serde = "1.1" +serde_json = "1" +toml = "0.8" +serde = { version = "1", features = ["derive"] } +smallvec = { version = "1.13", features = ["serde"] } sprs = { version = "0.11", features = ["serde"] } sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "chrono"] } -rand = "0.8" -reqwest = { version = "0.12", features = ["json"] } -regex = "1.10" -serde_json = "1" -serde = { version = "1", features = ["derive"] } -rmp-serde = "1.1" tokio = { version = "1.36", features = ["full"] } +tracing = "0.1" +tracing-indicatif = "0.3" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +unicode-segmentation = "1.11" diff --git a/src/main.rs b/src/main.rs index 98f6bc1..8737c27 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,25 @@ -use std::{collections::HashMap, fs::File, path::Path, time::Duration}; +use std::{ + collections::{HashMap, VecDeque}, + fs::File, + io::Write, + path::Path, + time::Duration, +}; use chrono::{DateTime, TimeZone, Utc}; +use clap::Parser; use futures::TryStreamExt; use indicatif::ProgressBar; +use qp_trie::Trie; use rand::distributions::{Distribution, WeightedIndex}; use regex::Regex; use reqwest::{header, ClientBuilder}; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use sprs::{CsMat, TriMat}; +use smallvec::SmallVec; use sqlx::postgres::PgPoolOptions; +use tracing::{debug, error, info, level_filters::LevelFilter}; +use tracing_indicatif::IndicatifLayer; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] enum MarkovItem { @@ -19,71 +29,152 @@ enum MarkovItem { } impl MarkovItem { - fn index(self) -> usize { + fn index(self) -> u32 { match self { MarkovItem::Start => 0, MarkovItem::Stop => 1, - MarkovItem::Token(idx) => idx as usize + 2, + MarkovItem::Token(idx) => idx + 2, } } - fn required_matrix_size(items: u32) -> usize { - items as usize + 2 - } - - fn from_usize(ind: usize) -> Self { + fn from_u32(ind: u32) -> Self { match ind { 0 => MarkovItem::Start, 1 => MarkovItem::Stop, - x => MarkovItem::Token(x.checked_sub(2).unwrap() as u32), + x => MarkovItem::Token(x.checked_sub(2).unwrap()), } } } +const CONTEXT_SIZE_MAX: usize = 5; +const CONTEXT_SIZE_BYTES: usize = CONTEXT_SIZE_MAX * std::mem::size_of::(); + +type PrefixBytes = SmallVec<[u8; CONTEXT_SIZE_BYTES]>; + +trait IntoBytes { + fn into_bytes(self) -> PrefixBytes; +} + +impl<'a> IntoBytes for &'a [MarkovItem] { + fn into_bytes(self) -> PrefixBytes { + self.iter() + .copied() + .map(MarkovItem::index) + .flat_map(u32::to_be_bytes) + .collect::() + } +} + +#[derive(Debug, Default)] +struct Tokenizer {} + +impl Tokenizer { + fn tokenize<'a>(&self, input: &'a str) -> impl Iterator { + input.split_whitespace() + } +} + #[derive(Debug, Serialize, Deserialize)] struct State { - matrix: CsMat, + model: Trie, mappings: HashMap, reverse_mappings: HashMap, fetched_up_to: Option>, } -impl State { - fn new() -> Self { - let mut tri_mat = TriMat::new((2, 2)); - tri_mat.add_triplet(MarkovItem::Start.index(), MarkovItem::Stop.index(), 1); +#[derive(Debug)] +struct Model<'a> { + state: State, + config: &'a Config, + tokenizer: Tokenizer, +} - State { - matrix: tri_mat.to_csr(), - mappings: HashMap::new(), - reverse_mappings: HashMap::new(), - fetched_up_to: None, +impl<'a> Model<'a> { + fn new(config: &'a Config) -> Self { + let mut model = Trie::new(); + model.insert( + [MarkovItem::Start, MarkovItem::Stop].as_ref().into_bytes(), + 1.0, + ); + + Self::from_existing( + config, + State { + model, + mappings: HashMap::new(), + reverse_mappings: HashMap::new(), + fetched_up_to: None, + }, + ) + } + + fn from_existing(config: &'a Config, state: State) -> Self { + Self { + state, + config, + tokenizer: Default::default(), } } + #[tracing::instrument] fn generate(&self) -> String { let mut rand = rand::thread_rng(); + let mut attempts = 0; loop { - let mut idx = MarkovItem::Start; + let mut current = MarkovItem::Start; let mut tokens = Vec::new(); + let mut history = VecDeque::new(); + history.push_front(current); + loop { - if matches!(idx, MarkovItem::Stop) { + while history.len() > self.config.history_len() { + history.pop_front(); + } + + if matches!(current, MarkovItem::Stop) { break; } - let weights = WeightedIndex::new( - (0..self.matrix.cols()) - .map(|j| self.matrix.get(idx.index(), j).copied().unwrap_or_default()), - ) - .expect("Weighted index failed to create"); + let mut options = HashMap::::new(); + let prefix = history.make_contiguous().as_ref().into_bytes(); + let mut prefix_slice = prefix.as_ref(); - idx = MarkovItem::from_usize(weights.sample(&mut rand)); + while !prefix_slice.is_empty() { + let item_size = std::mem::size_of::(); + let prefix_len_bytes = prefix_slice.len(); + let prefix_len_tokens = prefix_len_bytes / item_size; + let prefixes = self + .state + .model + .iter_prefix(prefix_slice) + .filter(|(k, _)| k.len() >= prefix_len_bytes + item_size); - if let MarkovItem::Token(ti) = idx { + for (p, t) in prefixes { + let value = options + .entry(u32::from_be_bytes( + p[prefix_len_bytes..][..item_size].try_into().unwrap(), + )) + .or_insert(0f32); + + *value += t * self.config.context_weights[prefix_len_tokens]; + } + + prefix_slice = &prefix_slice[item_size..]; + } + + let entries = options.into_iter().collect::>(); + let weights = WeightedIndex::new(entries.iter().map(|v| v.1)) + .expect("Weighted index failed to create"); + + current = MarkovItem::from_u32(entries[weights.sample(&mut rand)].0); + history.push_back(current); + + if let MarkovItem::Token(ti) = current { tokens.push( - self.mappings + self.state + .mappings .get(&ti) .expect("Item should be present in the mappings map") .clone(), @@ -91,7 +182,9 @@ impl State { } } - if (tokens.is_empty() || tokens.join(" ").trim().is_empty()) && attempts < 10 { + if (tokens.is_empty() || tokens.join("").trim().is_empty()) + && attempts < self.config.attempts + { attempts += 1; continue; } @@ -100,49 +193,36 @@ impl State { } } - fn insert(&mut self, (tok1, tok2): (Option<&str>, Option<&str>)) { - let item1 = match tok1 { - None => MarkovItem::Start, - Some(t) => match self.reverse_mappings.get(t) { - Some(i) => MarkovItem::Token(*i), - None => { - let new_idx = self.reverse_mappings.len() as u32; - self.reverse_mappings.insert(t.to_owned(), new_idx); - self.mappings.insert(new_idx, t.to_owned()); + fn insert(&mut self, items: &[MarkovItem]) { + let mut item_value = 100.0; - MarkovItem::Token(new_idx) - } - }, - }; - - let item2 = match tok2 { - None => MarkovItem::Stop, - Some(t) => match self.reverse_mappings.get(t) { - Some(i) => MarkovItem::Token(*i), - None => { - let new_idx = self.reverse_mappings.len() as u32; - self.reverse_mappings.insert(t.to_owned(), new_idx); - self.mappings.insert(new_idx, t.to_owned()); - - MarkovItem::Token(new_idx) - } - }, - }; - - let mut item_value = 10; - - if matches!(item1, MarkovItem::Start) && matches!(tok2, Some(val) if val.starts_with('@')) { - item_value /= 3; + if matches!(items[0], MarkovItem::Start) + && matches!(items[1], MarkovItem::Token(t) if self.state.mappings.get(&t).is_some_and(|t| t.starts_with('@'))) + { + item_value *= 1.0 - self.config.mention_penalty; } - if tok2.is_none() { - item_value -= 2; + if matches!(items.last(), Some(MarkovItem::Stop)) { + item_value *= self.config.stop_word_modifier; } - if let Some(item) = self.matrix.get_mut(item1.index(), item2.index()) { + if let Some(item) = self.state.model.get_mut(&items.into_bytes()) { *item += item_value; } else { - self.matrix.insert(item1.index(), item2.index(), item_value) + self.state.model.insert(items.into_bytes(), item_value); + } + } + + fn convert_token(&mut self, tok: &str) -> MarkovItem { + match self.state.reverse_mappings.get(tok) { + Some(i) => MarkovItem::Token(*i), + None => { + let new_idx = self.state.reverse_mappings.len() as u32; + self.state.reverse_mappings.insert(tok.to_owned(), new_idx); + self.state.mappings.insert(new_idx, tok.to_owned()); + + MarkovItem::Token(new_idx) + } } } @@ -162,52 +242,115 @@ impl State { return; } - let regex = Regex::new(r"\s+").unwrap(); - - let nasty_words = vec![ - "hitler", - "natsura", - "eve@snug", - "lustlion", - "abuse", - "pedophi", - "murder", - "transpho", - "holocaust", - "cpluspatch", - "umbrellix", - "sachi", - "heonkey", - ]; - - if nasty_words.iter().any(|w| input.to_lowercase().contains(w)) { + if self + .config + .nasty_words + .iter() + .any(|w| input.to_lowercase().contains(w)) + { return; } + let mut tokens = vec![MarkovItem::Start]; + tokens.extend( + self.tokenizer + .tokenize(&input) + .map(|t| self.convert_token(&t.to_lowercase())), + ); + tokens.push(MarkovItem::Stop); - let mut last = None; - for tok in regex.split(&input) { - self.insert((last, Some(tok))); - last = Some(tok); + for i in 2..=self.config.history_len() { + for tok in tokens.windows(i) { + self.insert(tok); + } } - self.insert((last, None)); + } +} + +#[derive(Debug, Parser)] +#[command(version, about)] +struct Args { + #[arg(long, default_value = "config.toml")] + config_path: String, + #[arg(long, default_value = "state.msg")] + state_file: String, +} + +#[derive(Debug, Deserialize)] +struct Config { + #[serde(default)] + learn_from_cws: bool, + #[serde(default)] + allowed_cws: Vec, + user_id: String, + instance: String, + period_seconds: u32, + #[serde(default)] + test_mode: bool, + #[serde(default)] + display_cw: Option, + context_weights: Vec, + #[serde(default)] + nasty_words: Vec, + #[serde(default)] + attempts: usize, + stop_word_modifier: f32, + mention_penalty: f32, +} + +impl Config { + fn history_len(&self) -> usize { + self.context_weights.len() } } #[tokio::main] async fn main() { dotenvy::dotenv().ok(); + + let indicatif_layer = IndicatifLayer::new(); + + tracing_subscriber::registry() + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::layer().with_writer(indicatif_layer.get_stderr_writer())) + .with(indicatif_layer) + .init(); + let token = std::env::var("TOKEN").expect("token"); - let path = Path::new("state.msg"); + let args = Args::parse(); + let config = + toml::de::from_str::(&std::fs::read_to_string(&args.config_path).unwrap()).unwrap(); + + if config.period_seconds < 1 { + error!("Cannot have a period shorter than 1 second!"); + return; + } + + if !(2..=CONTEXT_SIZE_MAX).contains(&config.context_weights.len()) { + error!( + "Context size {} out of range {:?}", + config.context_weights.len(), + 2..=CONTEXT_SIZE_MAX + ); + return; + } + + let state_path = Path::new(&args.state_file); let progress = ProgressBar::new_spinner(); - let mut state = if path.is_file() { - let file = File::open(path).unwrap(); + let mut model = if state_path.is_file() { progress.set_message("Loading data..."); - rmp_serde::decode::from_read::<_, State>(progress.wrap_read(file)).unwrap() + let state = + rmp_serde::decode::from_slice::(&tokio::fs::read(state_path).await.unwrap()) + .unwrap(); + Model::from_existing(&config, state) } else { - State::new() + Model::new(&config) }; let pool = PgPoolOptions::new() @@ -215,70 +358,80 @@ async fn main() { .await .unwrap(); - let note_time_min = state + let note_time_min = model + .state .fetched_up_to .unwrap_or(Utc.with_ymd_and_hms(2010, 1, 1, 12, 0, 0).unwrap()); progress.set_message("Fetching notes..."); let text_stream = sqlx::query!( - r#"SELECT text, "createdAt" + r#"SELECT text, "createdAt" FROM note WHERE "note"."userId" = $1 AND "note"."createdAt" > $2 AND "note"."createdAt" < NOW() AND "note"."visibility" IN ('public', 'home') - AND ("note"."cw" IS NULL OR LOWER("note"."cw") IN ('', 'gay', 'cursed', 'what', 'shitpost', 'no', 'natty what', 'natty what the fuck'))"#, - "9awy7u3l76", - note_time_min + AND ($4 OR ("note"."cw" IS NULL OR LOWER("note"."cw") IN (SELECT UNNEST($3::VARCHAR[]))))"#, + &config.user_id, + note_time_min, + &config.allowed_cws, + config.learn_from_cws ) .fetch(&pool); let mut stream = progress.wrap_stream(text_stream); let mut cnt = 0; while let Ok(Some(item)) = stream.try_next().await.map_err(|e| { - eprintln!("Error fetching: {}", e); + error!("Error fetching: {}", e); e }) { - state.fetched_up_to = Some(item.createdAt); - state.insert_tokens(item.text.as_deref().unwrap_or_default()); - progress.set_message(item.text.unwrap_or_default().clone()); + model.state.fetched_up_to = Some(item.createdAt); + model.insert_tokens(item.text.as_deref().unwrap_or_default()); + progress.set_message(format!( + "Items: {} | Trie size: {} | Current: {:?}", + cnt, + model.state.model.count(), + item.text.unwrap_or_default().clone() + )); cnt += 1; progress.tick(); progress.set_length(cnt); } - drop(stream); - drop(pool); + info!("Keys: {}", model.state.model.count()); - println!("Shape: {:?}", state.matrix.shape()); - - let file = File::create(path).unwrap(); progress.set_message("Saving data..."); - rmp_serde::encode::write(&mut progress.wrap_write(file), &state).unwrap(); + let encoded = rmp_serde::encode::to_vec(&model.state).unwrap(); + let file = File::create(state_path).unwrap(); + progress.wrap_write(file).write_all(&encoded).unwrap(); - let mut timer = tokio::time::interval(Duration::from_secs(600)); + let mut timer = tokio::time::interval(Duration::from_secs(config.period_seconds.into())); loop { timer.tick().await; - let text = state.generate(); - println!("Generated: {:?}", text); + let text = model.generate(); + info!("Generated: {:?}", text); + + if config.test_mode { + continue; + } let client = ClientBuilder::new().https_only(true).build().unwrap(); match client - .post("https://astolfo.social/api/notes/create") + .post(format!("https://{}/api/notes/create", &config.instance)) .json(&serde_json::json!({ "text": text, "poll": Option::<()>::None, - "cw": "Automated Markov-Natty post", + "cw": config.display_cw.as_deref(), "visibility": "home" })) .header(header::AUTHORIZATION, &format!("Bearer {}", token)) .send() .await { - Err(e) => eprintln!("Fetch error: {}", e), - Ok(r) => println!("Response: {:#?}", r.json::().await), + Err(e) => error!("Fetch error: {}", e), + Ok(r) => debug!("Response: {:#?}", r.json::().await), } } }