Pruning junk generated sentences
ci/woodpecker/push/ociImagePush Pipeline was successful
Details
ci/woodpecker/push/ociImagePush Pipeline was successful
Details
This commit is contained in:
parent
8cbc9f2f2c
commit
439302b355
140
src/main.rs
140
src/main.rs
|
@ -73,7 +73,9 @@ struct Tokenizer {}
|
||||||
|
|
||||||
impl Tokenizer {
|
impl Tokenizer {
|
||||||
fn tokenize<'a>(&self, input: &'a str) -> impl Iterator<Item = &'a str> {
|
fn tokenize<'a>(&self, input: &'a str) -> impl Iterator<Item = &'a str> {
|
||||||
input.split_whitespace()
|
input
|
||||||
|
.split_whitespace()
|
||||||
|
.filter(|s| !s.chars().all(|c| "()[]{};,.".contains(c)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,6 +94,12 @@ struct Model<'a> {
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ModelOutput {
|
||||||
|
sentence: String,
|
||||||
|
heat: f64,
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> Model<'a> {
|
impl<'a> Model<'a> {
|
||||||
fn new(config: &'a Config) -> Self {
|
fn new(config: &'a Config) -> Self {
|
||||||
let mut model = Trie::new();
|
let mut model = Trie::new();
|
||||||
|
@ -119,12 +127,16 @@ impl<'a> Model<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument(skip(self))]
|
||||||
fn generate(&self) -> String {
|
fn generate(&self) -> ModelOutput {
|
||||||
let mut rand = rand::thread_rng();
|
let mut rand = rand::thread_rng();
|
||||||
|
|
||||||
let mut attempts = 0;
|
let mut attempts = 0;
|
||||||
|
|
||||||
|
let sentence;
|
||||||
|
let mut heats = Vec::new();
|
||||||
|
let mut heat;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let mut current = MarkovItem::Start;
|
let mut current = MarkovItem::Start;
|
||||||
let mut tokens = Vec::new();
|
let mut tokens = Vec::new();
|
||||||
|
@ -143,9 +155,11 @@ impl<'a> Model<'a> {
|
||||||
let mut options = HashMap::<u32, f32>::new();
|
let mut options = HashMap::<u32, f32>::new();
|
||||||
let prefix = history.make_contiguous().as_ref().into_bytes();
|
let prefix = history.make_contiguous().as_ref().into_bytes();
|
||||||
let mut prefix_slice = prefix.as_ref();
|
let mut prefix_slice = prefix.as_ref();
|
||||||
|
let mut weight_max: f64 = 0.0;
|
||||||
|
let item_size = std::mem::size_of::<u32>();
|
||||||
|
let prefix_max_tokens = (prefix_slice.len() / item_size - 1).max(1);
|
||||||
|
|
||||||
while !prefix_slice.is_empty() {
|
while !prefix_slice.is_empty() {
|
||||||
let item_size = std::mem::size_of::<u32>();
|
|
||||||
let prefix_len_bytes = prefix_slice.len();
|
let prefix_len_bytes = prefix_slice.len();
|
||||||
let prefix_len_tokens = prefix_len_bytes / item_size;
|
let prefix_len_tokens = prefix_len_bytes / item_size;
|
||||||
|
|
||||||
|
@ -180,20 +194,23 @@ impl<'a> Model<'a> {
|
||||||
.iter_prefix(prefix_slice)
|
.iter_prefix(prefix_slice)
|
||||||
.filter(|(k, _)| k.len() >= prefix_len_bytes + item_size);
|
.filter(|(k, _)| k.len() >= prefix_len_bytes + item_size);
|
||||||
|
|
||||||
for (p, t) in prefixes {
|
for (p, weight_local) in prefixes {
|
||||||
let value = options
|
let weight_global = options
|
||||||
.entry(u32::from_be_bytes(
|
.entry(u32::from_be_bytes(
|
||||||
p[prefix_len_bytes..][..item_size].try_into().unwrap(),
|
p[prefix_len_bytes..][..item_size].try_into().unwrap(),
|
||||||
))
|
))
|
||||||
.or_insert(0f32);
|
.or_insert(0f32);
|
||||||
|
|
||||||
*value += t * self.config.context_weights[prefix_len_tokens - 1];
|
let mut weight_scaled = weight_local
|
||||||
|
* self.config.context_weights[prefix_len_tokens - 1]
|
||||||
|
/ prefix_max_tokens as f32;
|
||||||
|
|
||||||
/*
|
if matches!(current, MarkovItem::Stop) {
|
||||||
if matches!(current, MarkovItem::Start) {
|
weight_scaled = weight_scaled.powf(2.0);
|
||||||
*value = value.powf(0.75);
|
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
*weight_global += weight_scaled;
|
||||||
|
weight_max = weight_max.max(*weight_global as f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix_slice = &prefix_slice[item_size..];
|
prefix_slice = &prefix_slice[item_size..];
|
||||||
|
@ -203,9 +220,15 @@ impl<'a> Model<'a> {
|
||||||
let weights = WeightedIndex::new(entries.iter().map(|v| v.1))
|
let weights = WeightedIndex::new(entries.iter().map(|v| v.1))
|
||||||
.expect("Weighted index failed to create");
|
.expect("Weighted index failed to create");
|
||||||
|
|
||||||
current = MarkovItem::from_u32(entries[weights.sample(&mut rand)].0);
|
let (entry_token, entry_global_weight) = entries[weights.sample(&mut rand)];
|
||||||
|
assert_ne!(weight_max, 0.0);
|
||||||
|
current = MarkovItem::from_u32(entry_token);
|
||||||
history.push_back(current);
|
history.push_back(current);
|
||||||
|
|
||||||
|
// How likely was it to get this token?
|
||||||
|
let heat_raw = entry_global_weight as f64 / weight_max;
|
||||||
|
heats.push(heat_raw);
|
||||||
|
|
||||||
if let MarkovItem::Token(ti) = current {
|
if let MarkovItem::Token(ti) = current {
|
||||||
tokens.push(
|
tokens.push(
|
||||||
self.state
|
self.state
|
||||||
|
@ -224,8 +247,48 @@ impl<'a> Model<'a> {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
break tokens.join(" ").replace('@', "@\u{200b}");
|
sentence = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tok| {
|
||||||
|
if self.config.allowed_tags.contains(tok) {
|
||||||
|
return tok.to_owned();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tok.replace('@', "@\u{200b}")
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ");
|
||||||
|
|
||||||
|
let token_count = tokens.len() as f64;
|
||||||
|
let heat_count = heats.len() as f64;
|
||||||
|
|
||||||
|
// Geomean to get the final heat
|
||||||
|
heat = heats.iter().product::<f64>().powf(1.0 / heat_count);
|
||||||
|
|
||||||
|
// Are we just generatic gibberish?
|
||||||
|
let variance_mod = (heats.iter().map(|x| x.powf(3.0)).sum::<f64>() / heat_count)
|
||||||
|
.powf(1.25)
|
||||||
|
.min(1.0);
|
||||||
|
heat *= 1.0 - variance_mod;
|
||||||
|
|
||||||
|
// Too many mentions
|
||||||
|
heat /= (tokens.iter().filter(|s| s.starts_with('@')).count() as f64)
|
||||||
|
.powf(2.0)
|
||||||
|
.max(1.0);
|
||||||
|
|
||||||
|
// Penalty for being too short
|
||||||
|
heat *= (0.5 + token_count / 10.0).min(1.0);
|
||||||
|
|
||||||
|
// Up until here the heat is inverted
|
||||||
|
heat = 1.0 - heat;
|
||||||
|
|
||||||
|
// Account for float imprecision
|
||||||
|
heat = heat.clamp(0.0, 1.0);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelOutput { sentence, heat }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert(&mut self, items: &[MarkovItem]) {
|
fn insert(&mut self, items: &[MarkovItem]) {
|
||||||
|
@ -263,6 +326,11 @@ impl<'a> Model<'a> {
|
||||||
|
|
||||||
fn insert_tokens(&mut self, input: &str) {
|
fn insert_tokens(&mut self, input: &str) {
|
||||||
let input = input.trim();
|
let input = input.trim();
|
||||||
|
|
||||||
|
if input.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Remove quotes
|
// Remove quotes
|
||||||
let input = Regex::new(r"^> ?.+")
|
let input = Regex::new(r"^> ?.+")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -277,6 +345,14 @@ impl<'a> Model<'a> {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.contains('(') || input.contains(')') {
|
||||||
|
let paren_split = Regex::new("\\(|\\)").unwrap();
|
||||||
|
paren_split
|
||||||
|
.split(&input)
|
||||||
|
.for_each(|p| self.insert_tokens(p));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if self
|
if self
|
||||||
.config
|
.config
|
||||||
.nasty_words
|
.nasty_words
|
||||||
|
@ -285,9 +361,10 @@ impl<'a> Model<'a> {
|
||||||
{
|
{
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tokens = vec![MarkovItem::Start];
|
let mut tokens = vec![MarkovItem::Start];
|
||||||
tokens.extend(self.tokenizer.tokenize(&input).map(|t| {
|
tokens.extend(self.tokenizer.tokenize(&input).map(|t| {
|
||||||
if t.starts_with("https") {
|
if t.starts_with("https") || t.starts_with(':') && t.ends_with(':') {
|
||||||
self.convert_token(t)
|
self.convert_token(t)
|
||||||
} else {
|
} else {
|
||||||
self.convert_token(&t.to_lowercase())
|
self.convert_token(&t.to_lowercase())
|
||||||
|
@ -332,6 +409,20 @@ struct Config {
|
||||||
attempts: usize,
|
attempts: usize,
|
||||||
stop_word_modifier: f32,
|
stop_word_modifier: f32,
|
||||||
mention_penalty: f32,
|
mention_penalty: f32,
|
||||||
|
#[serde(default = "default_heat_exponent")]
|
||||||
|
heat_exponent: f64,
|
||||||
|
#[serde(default = "default_heat_retries")]
|
||||||
|
heat_max_retries: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
allowed_tags: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_heat_exponent() -> f64 {
|
||||||
|
0.7
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_heat_retries() -> u32 {
|
||||||
|
100
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
@ -458,8 +549,27 @@ async fn main() {
|
||||||
loop {
|
loop {
|
||||||
timer.tick().await;
|
timer.tick().await;
|
||||||
|
|
||||||
|
let mut retries = config.heat_max_retries;
|
||||||
|
let text = loop {
|
||||||
let text = model.generate();
|
let text = model.generate();
|
||||||
info!("Generated: {:?}", text);
|
let p_skip = text.heat.powf(config.heat_exponent);
|
||||||
|
let mut rand = rand::thread_rng();
|
||||||
|
if rand.gen_bool(p_skip) && retries > 0 {
|
||||||
|
retries -= 1;
|
||||||
|
info!(
|
||||||
|
"[{}/{}] Skipped: {:?}, P(skip): {}",
|
||||||
|
config.heat_max_retries - retries,
|
||||||
|
config.heat_max_retries,
|
||||||
|
text,
|
||||||
|
p_skip
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Generated: {:?}, P(skip): {}", text, p_skip);
|
||||||
|
|
||||||
|
break text.sentence;
|
||||||
|
};
|
||||||
|
|
||||||
if config.test_mode {
|
if config.test_mode {
|
||||||
continue;
|
continue;
|
||||||
|
|
Loading…
Reference in New Issue