From 439302b35520a657270c53d4864f6aceee31b03c Mon Sep 17 00:00:00 2001
From: Natty
Date: Wed, 25 Sep 2024 01:19:14 +0200
Subject: [PATCH] Pruning junk generated sentences

---
 src/main.rs | 142 ++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 126 insertions(+), 16 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 08cb82f..6b62313 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -73,7 +73,9 @@ struct Tokenizer {}
 
 impl Tokenizer {
     fn tokenize<'a>(&self, input: &'a str) -> impl Iterator<Item = &'a str> {
-        input.split_whitespace()
+        input
+            .split_whitespace()
+            .filter(|s| !s.chars().all(|c| "()[]{};,.".contains(c)))
     }
 }
 
@@ -92,6 +94,12 @@ struct Model<'a> {
     tokenizer: Tokenizer,
 }
 
+#[derive(Debug)]
+struct ModelOutput {
+    sentence: String,
+    heat: f64,
+}
+
 impl<'a> Model<'a> {
     fn new(config: &'a Config) -> Self {
         let mut model = Trie::new();
@@ -119,12 +127,16 @@ impl<'a> Model<'a> {
         }
     }
 
-    #[tracing::instrument]
-    fn generate(&self) -> String {
+    #[tracing::instrument(skip(self))]
+    fn generate(&self) -> ModelOutput {
         let mut rand = rand::thread_rng();
 
         let mut attempts = 0;
 
+        let sentence;
+        let mut heats = Vec::new();
+        let mut heat;
+
         loop {
             let mut current = MarkovItem::Start;
             let mut tokens = Vec::new();
@@ -143,9 +155,11 @@
             let mut options = HashMap::<u32, f32>::new();
             let prefix = history.make_contiguous().as_ref().into_bytes();
             let mut prefix_slice = prefix.as_ref();
+            let mut weight_max: f64 = 0.0;
+            let item_size = std::mem::size_of::<u32>();
+            let prefix_max_tokens = (prefix_slice.len() / item_size - 1).max(1);
 
             while !prefix_slice.is_empty() {
-                let item_size = std::mem::size_of::<u32>();
                 let prefix_len_bytes = prefix_slice.len();
                 let prefix_len_tokens = prefix_len_bytes / item_size;
 
@@ -180,20 +194,23 @@
                     .iter_prefix(prefix_slice)
                     .filter(|(k, _)| k.len() >= prefix_len_bytes + item_size);
 
-                for (p, t) in prefixes {
-                    let value = options
+                for (p, weight_local) in prefixes {
+                    let weight_global = options
                         .entry(u32::from_be_bytes(
                             p[prefix_len_bytes..][..item_size].try_into().unwrap(),
                         ))
                         .or_insert(0f32);
 
-                    *value += t * self.config.context_weights[prefix_len_tokens - 1];
+                    let mut weight_scaled = weight_local
+                        * self.config.context_weights[prefix_len_tokens - 1]
+                        / prefix_max_tokens as f32;
 
-                    /*
-                    if matches!(current, MarkovItem::Start) {
-                        *value = value.powf(0.75);
+                    if matches!(current, MarkovItem::Stop) {
+                        weight_scaled = weight_scaled.powf(2.0);
                     }
-                    */
+
+                    *weight_global += weight_scaled;
+                    weight_max = weight_max.max(*weight_global as f64);
                 }
 
                 prefix_slice = &prefix_slice[item_size..];
@@ -203,9 +220,15 @@
 
                 let weights = WeightedIndex::new(entries.iter().map(|v| v.1))
                     .expect("Weighted index failed to create");
-                current = MarkovItem::from_u32(entries[weights.sample(&mut rand)].0);
+                let (entry_token, entry_global_weight) = entries[weights.sample(&mut rand)];
+                assert_ne!(weight_max, 0.0);
+                current = MarkovItem::from_u32(entry_token);
                 history.push_back(current);
 
+                // How likely was it to get this token?
+                let heat_raw = entry_global_weight as f64 / weight_max;
+                heats.push(heat_raw);
+
                 if let MarkovItem::Token(ti) = current {
                     tokens.push(
                         self.state
@@ -224,8 +247,48 @@
                 continue;
             }
 
-            break tokens.join(" ").replace('@', "@\u{200b}");
+            sentence = tokens
+                .iter()
+                .map(|tok| {
+                    if self.config.allowed_tags.contains(tok) {
+                        return tok.to_owned();
+                    }
+
+                    tok.replace('@', "@\u{200b}")
+                })
+                .collect::<Vec<_>>()
+                .join(" ");
+
+            let token_count = tokens.len() as f64;
+            let heat_count = heats.len() as f64;
+
+            // Geomean to get the final heat
+            heat = heats.iter().product::<f64>().powf(1.0 / heat_count);
+
+            // Are we just generating gibberish?
+            let variance_mod = (heats.iter().map(|x| x.powf(3.0)).sum::<f64>() / heat_count)
+                .powf(1.25)
+                .min(1.0);
+            heat *= 1.0 - variance_mod;
+
+            // Too many mentions
+            heat /= (tokens.iter().filter(|s| s.starts_with('@')).count() as f64)
+                .powf(2.0)
+                .max(1.0);
+
+            // Penalty for being too short
+            heat *= (0.5 + token_count / 10.0).min(1.0);
+
+            // Up until here the heat is inverted
+            heat = 1.0 - heat;
+
+            // Account for float imprecision
+            heat = heat.clamp(0.0, 1.0);
+
+            break;
         }
+
+        ModelOutput { sentence, heat }
     }
 
     fn insert(&mut self, items: &[MarkovItem]) {
@@ -263,6 +326,11 @@
 
     fn insert_tokens(&mut self, input: &str) {
         let input = input.trim();
+
+        if input.is_empty() {
+            return;
+        }
+
         // Remove quotes
         let input = Regex::new(r"^> ?.+")
             .unwrap()
@@ -277,6 +345,14 @@
             return;
         }
 
+        if input.contains('(') || input.contains(')') {
+            let paren_split = Regex::new("\\(|\\)").unwrap();
+            paren_split
+                .split(&input)
+                .for_each(|p| self.insert_tokens(p));
+            return;
+        }
+
         if self
             .config
             .nasty_words
@@ -285,9 +361,10 @@
         {
             return;
         }
+
         let mut tokens = vec![MarkovItem::Start];
         tokens.extend(self.tokenizer.tokenize(&input).map(|t| {
-            if t.starts_with("https") {
+            if t.starts_with("https") || t.starts_with(':') && t.ends_with(':') {
                 self.convert_token(t)
             } else {
                 self.convert_token(&t.to_lowercase())
             }
         }
@@ -332,6 +409,20 @@ struct Config {
     attempts: usize,
     stop_word_modifier: f32,
     mention_penalty: f32,
+    #[serde(default = "default_heat_exponent")]
+    heat_exponent: f64,
+    #[serde(default = "default_heat_retries")]
+    heat_max_retries: u32,
+    #[serde(default)]
+    allowed_tags: Vec<String>,
+}
+
+fn default_heat_exponent() -> f64 {
+    0.7
+}
+
+fn default_heat_retries() -> u32 {
+    100
 }
 
 impl Config {
@@ -458,8 +549,27 @@ async fn main() {
     loop {
         timer.tick().await;
 
-        let text = model.generate();
-        info!("Generated: {:?}", text);
+        let mut retries = config.heat_max_retries;
+        let text = loop {
+            let text = model.generate();
+            let p_skip = text.heat.powf(config.heat_exponent);
+            let mut rand = rand::thread_rng();
+            if rand.gen_bool(p_skip) && retries > 0 {
+                retries -= 1;
+                info!(
+                    "[{}/{}] Skipped: {:?}, P(skip): {}",
+                    config.heat_max_retries - retries,
+                    config.heat_max_retries,
+                    text,
+                    p_skip
+                );
+                continue;
+            }
+
+            info!("Generated: {:?}, P(skip): {}", text, p_skip);
+
+            break text.sentence;
+        };
 
         if config.test_mode {
            continue;
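
A rough standalone sketch (not part of the patch) of what the new heat scoring and the skip decision in main() amount to. The helper name sentence_heat and the sample tokens/heats are made up for illustration; 0.7 stands in for the default_heat_exponent value the patch adds, and gen_bool expects a probability in [0, 1], which is why the heat is clamped before use.

    use rand::Rng;

    // Fold the per-token heats (weight of the sampled entry divided by the
    // heaviest candidate, so higher = a more predictable pick) into one score.
    fn sentence_heat(heats: &[f64], tokens: &[&str]) -> f64 {
        let heat_count = heats.len() as f64;
        let token_count = tokens.len() as f64;

        // Geometric mean of the per-token heats.
        let mut heat = heats.iter().product::<f64>().powf(1.0 / heat_count);

        // Gibberish modifier: a run of "safe bet" tokens drags the score down
        // here, which becomes a high skip probability after the inversion below.
        let variance_mod = (heats.iter().map(|x| x.powf(3.0)).sum::<f64>() / heat_count)
            .powf(1.25)
            .min(1.0);
        heat *= 1.0 - variance_mod;

        // Penalties for mention spam and for very short sentences.
        let mentions = tokens.iter().filter(|t| t.starts_with('@')).count() as f64;
        heat /= mentions.powf(2.0).max(1.0);
        heat *= (0.5 + token_count / 10.0).min(1.0);

        // Up to this point higher meant "more plausible"; invert and clamp so
        // the result reads as "how likely is this sentence to be junk".
        (1.0 - heat).clamp(0.0, 1.0)
    }

    fn main() {
        let tokens = ["some", "@user", "words"];
        let heats = [0.4, 0.9, 0.6];

        let heat = sentence_heat(&heats, &tokens);
        let p_skip = heat.powf(0.7); // heat_exponent default from the patch
        let skip = rand::thread_rng().gen_bool(p_skip);
        println!("heat = {heat:.3}, P(skip) = {p_skip:.3}, skip = {skip}");
    }

The score stays in "confidence" form until the very end so each penalty can simply multiply it down; only the final 1.0 - heat flip, plus the clamp against float error, turns it into the skip probability that main() raises to heat_exponent and feeds to gen_bool.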