Pruning junk generated sentences
ci/woodpecker/push/ociImagePush Pipeline was successful Details

This commit is contained in:
Natty 2024-09-25 01:19:14 +02:00
parent 8cbc9f2f2c
commit 439302b355
Signed by: natty
GPG Key ID: BF6CB659ADEE60EC
1 changed files with 126 additions and 16 deletions

View File

@ -73,7 +73,9 @@ struct Tokenizer {}
impl Tokenizer { impl Tokenizer {
fn tokenize<'a>(&self, input: &'a str) -> impl Iterator<Item = &'a str> { fn tokenize<'a>(&self, input: &'a str) -> impl Iterator<Item = &'a str> {
input.split_whitespace() input
.split_whitespace()
.filter(|s| !s.chars().all(|c| "()[]{};,.".contains(c)))
} }
} }
@ -92,6 +94,12 @@ struct Model<'a> {
tokenizer: Tokenizer, tokenizer: Tokenizer,
} }
#[derive(Debug)]
struct ModelOutput {
sentence: String,
heat: f64,
}
impl<'a> Model<'a> { impl<'a> Model<'a> {
fn new(config: &'a Config) -> Self { fn new(config: &'a Config) -> Self {
let mut model = Trie::new(); let mut model = Trie::new();
@ -119,12 +127,16 @@ impl<'a> Model<'a> {
} }
} }
#[tracing::instrument] #[tracing::instrument(skip(self))]
fn generate(&self) -> String { fn generate(&self) -> ModelOutput {
let mut rand = rand::thread_rng(); let mut rand = rand::thread_rng();
let mut attempts = 0; let mut attempts = 0;
let sentence;
let mut heats = Vec::new();
let mut heat;
loop { loop {
let mut current = MarkovItem::Start; let mut current = MarkovItem::Start;
let mut tokens = Vec::new(); let mut tokens = Vec::new();
@ -143,9 +155,11 @@ impl<'a> Model<'a> {
let mut options = HashMap::<u32, f32>::new(); let mut options = HashMap::<u32, f32>::new();
let prefix = history.make_contiguous().as_ref().into_bytes(); let prefix = history.make_contiguous().as_ref().into_bytes();
let mut prefix_slice = prefix.as_ref(); let mut prefix_slice = prefix.as_ref();
let mut weight_max: f64 = 0.0;
let item_size = std::mem::size_of::<u32>();
let prefix_max_tokens = (prefix_slice.len() / item_size - 1).max(1);
while !prefix_slice.is_empty() { while !prefix_slice.is_empty() {
let item_size = std::mem::size_of::<u32>();
let prefix_len_bytes = prefix_slice.len(); let prefix_len_bytes = prefix_slice.len();
let prefix_len_tokens = prefix_len_bytes / item_size; let prefix_len_tokens = prefix_len_bytes / item_size;
@ -180,20 +194,23 @@ impl<'a> Model<'a> {
.iter_prefix(prefix_slice) .iter_prefix(prefix_slice)
.filter(|(k, _)| k.len() >= prefix_len_bytes + item_size); .filter(|(k, _)| k.len() >= prefix_len_bytes + item_size);
for (p, t) in prefixes { for (p, weight_local) in prefixes {
let value = options let weight_global = options
.entry(u32::from_be_bytes( .entry(u32::from_be_bytes(
p[prefix_len_bytes..][..item_size].try_into().unwrap(), p[prefix_len_bytes..][..item_size].try_into().unwrap(),
)) ))
.or_insert(0f32); .or_insert(0f32);
*value += t * self.config.context_weights[prefix_len_tokens - 1]; let mut weight_scaled = weight_local
* self.config.context_weights[prefix_len_tokens - 1]
/ prefix_max_tokens as f32;
/* if matches!(current, MarkovItem::Stop) {
if matches!(current, MarkovItem::Start) { weight_scaled = weight_scaled.powf(2.0);
*value = value.powf(0.75);
} }
*/
*weight_global += weight_scaled;
weight_max = weight_max.max(*weight_global as f64);
} }
prefix_slice = &prefix_slice[item_size..]; prefix_slice = &prefix_slice[item_size..];
@ -203,9 +220,15 @@ impl<'a> Model<'a> {
let weights = WeightedIndex::new(entries.iter().map(|v| v.1)) let weights = WeightedIndex::new(entries.iter().map(|v| v.1))
.expect("Weighted index failed to create"); .expect("Weighted index failed to create");
current = MarkovItem::from_u32(entries[weights.sample(&mut rand)].0); let (entry_token, entry_global_weight) = entries[weights.sample(&mut rand)];
assert_ne!(weight_max, 0.0);
current = MarkovItem::from_u32(entry_token);
history.push_back(current); history.push_back(current);
// How likely was it to get this token?
let heat_raw = entry_global_weight as f64 / weight_max;
heats.push(heat_raw);
if let MarkovItem::Token(ti) = current { if let MarkovItem::Token(ti) = current {
tokens.push( tokens.push(
self.state self.state
@ -224,8 +247,48 @@ impl<'a> Model<'a> {
continue; continue;
} }
break tokens.join(" ").replace('@', "@\u{200b}"); sentence = tokens
.iter()
.map(|tok| {
if self.config.allowed_tags.contains(tok) {
return tok.to_owned();
} }
tok.replace('@', "@\u{200b}")
})
.collect::<Vec<_>>()
.join(" ");
let token_count = tokens.len() as f64;
let heat_count = heats.len() as f64;
// Geomean to get the final heat
heat = heats.iter().product::<f64>().powf(1.0 / heat_count);
// Are we just generatic gibberish?
let variance_mod = (heats.iter().map(|x| x.powf(3.0)).sum::<f64>() / heat_count)
.powf(1.25)
.min(1.0);
heat *= 1.0 - variance_mod;
// Too many mentions
heat /= (tokens.iter().filter(|s| s.starts_with('@')).count() as f64)
.powf(2.0)
.max(1.0);
// Penalty for being too short
heat *= (0.5 + token_count / 10.0).min(1.0);
// Up until here the heat is inverted
heat = 1.0 - heat;
// Account for float imprecision
heat = heat.clamp(0.0, 1.0);
break;
}
ModelOutput { sentence, heat }
} }
fn insert(&mut self, items: &[MarkovItem]) { fn insert(&mut self, items: &[MarkovItem]) {
@ -263,6 +326,11 @@ impl<'a> Model<'a> {
fn insert_tokens(&mut self, input: &str) { fn insert_tokens(&mut self, input: &str) {
let input = input.trim(); let input = input.trim();
if input.is_empty() {
return;
}
// Remove quotes // Remove quotes
let input = Regex::new(r"^> ?.+") let input = Regex::new(r"^> ?.+")
.unwrap() .unwrap()
@ -277,6 +345,14 @@ impl<'a> Model<'a> {
return; return;
} }
if input.contains('(') || input.contains(')') {
let paren_split = Regex::new("\\(|\\)").unwrap();
paren_split
.split(&input)
.for_each(|p| self.insert_tokens(p));
return;
}
if self if self
.config .config
.nasty_words .nasty_words
@ -285,9 +361,10 @@ impl<'a> Model<'a> {
{ {
return; return;
} }
let mut tokens = vec![MarkovItem::Start]; let mut tokens = vec![MarkovItem::Start];
tokens.extend(self.tokenizer.tokenize(&input).map(|t| { tokens.extend(self.tokenizer.tokenize(&input).map(|t| {
if t.starts_with("https") { if t.starts_with("https") || t.starts_with(':') && t.ends_with(':') {
self.convert_token(t) self.convert_token(t)
} else { } else {
self.convert_token(&t.to_lowercase()) self.convert_token(&t.to_lowercase())
@ -332,6 +409,20 @@ struct Config {
attempts: usize, attempts: usize,
stop_word_modifier: f32, stop_word_modifier: f32,
mention_penalty: f32, mention_penalty: f32,
#[serde(default = "default_heat_exponent")]
heat_exponent: f64,
#[serde(default = "default_heat_retries")]
heat_max_retries: u32,
#[serde(default)]
allowed_tags: Vec<String>,
}
fn default_heat_exponent() -> f64 {
0.7
}
fn default_heat_retries() -> u32 {
100
} }
impl Config { impl Config {
@ -458,8 +549,27 @@ async fn main() {
loop { loop {
timer.tick().await; timer.tick().await;
let mut retries = config.heat_max_retries;
let text = loop {
let text = model.generate(); let text = model.generate();
info!("Generated: {:?}", text); let p_skip = text.heat.powf(config.heat_exponent);
let mut rand = rand::thread_rng();
if rand.gen_bool(p_skip) && retries > 0 {
retries -= 1;
info!(
"[{}/{}] Skipped: {:?}, P(skip): {}",
config.heat_max_retries - retries,
config.heat_max_retries,
text,
p_skip
);
continue;
}
info!("Generated: {:?}, P(skip): {}", text, p_skip);
break text.sentence;
};
if config.test_mode { if config.test_mode {
continue; continue;