Pruning junk generated sentences
parent 8cbc9f2f2c
commit 439302b355

src/main.rs: 142 changed lines
@@ -73,7 +73,9 @@ struct Tokenizer {}
 impl Tokenizer {
     fn tokenize<'a>(&self, input: &'a str) -> impl Iterator<Item = &'a str> {
-        input.split_whitespace()
+        input
+            .split_whitespace()
+            .filter(|s| !s.chars().all(|c| "()[]{};,.".contains(c)))
     }
 }
 
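The tokenizer now drops tokens made up entirely of punctuation instead of feeding them into the chain. A standalone check of the new filter (hypothetical input string, not from the repo's tests):

```rust
fn main() {
    let tokens: Vec<&str> = "( hello , world ) ;"
        .split_whitespace()
        .filter(|s| !s.chars().all(|c| "()[]{};,.".contains(c)))
        .collect();

    // Punctuation-only tokens are dropped; punctuation glued to a word survives.
    assert_eq!(tokens, ["hello", "world"]);
}
```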
@@ -92,6 +94,12 @@ struct Model<'a> {
     tokenizer: Tokenizer,
 }
 
+#[derive(Debug)]
+struct ModelOutput {
+    sentence: String,
+    heat: f64,
+}
+
 impl<'a> Model<'a> {
     fn new(config: &'a Config) -> Self {
         let mut model = Trie::new();
@@ -119,12 +127,16 @@ impl<'a> Model<'a> {
         }
     }
 
-    #[tracing::instrument]
-    fn generate(&self) -> String {
+    #[tracing::instrument(skip(self))]
+    fn generate(&self) -> ModelOutput {
         let mut rand = rand::thread_rng();
 
         let mut attempts = 0;
 
+        let sentence;
+        let mut heats = Vec::new();
+        let mut heat;
+
         loop {
             let mut current = MarkovItem::Start;
             let mut tokens = Vec::new();
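The `skip(self)` matters here: `#[tracing::instrument]` records each argument as a span field via its `Debug` impl, so skipping `self` avoids both requiring `Debug` on `Model` and formatting the whole trie on every call. A minimal sketch of that behaviour, with hypothetical types:

```rust
use tracing::instrument;

struct Heavy; // deliberately no Debug impl

struct Model {
    state: Heavy,
}

impl Model {
    // Without skip(self), the macro would try to record `self` as a
    // Debug span field and this method would not compile.
    #[instrument(skip(self))]
    fn generate(&self) -> String {
        String::from("...")
    }
}

fn main() {
    let model = Model { state: Heavy };
    let _ = model.generate();
}
```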
@@ -143,9 +155,11 @@ impl<'a> Model<'a> {
             let mut options = HashMap::<u32, f32>::new();
             let prefix = history.make_contiguous().as_ref().into_bytes();
             let mut prefix_slice = prefix.as_ref();
+            let mut weight_max: f64 = 0.0;
+            let item_size = std::mem::size_of::<u32>();
+            let prefix_max_tokens = (prefix_slice.len() / item_size - 1).max(1);
 
             while !prefix_slice.is_empty() {
-                let item_size = std::mem::size_of::<u32>();
                 let prefix_len_bytes = prefix_slice.len();
                 let prefix_len_tokens = prefix_len_bytes / item_size;
 
@@ -180,20 +194,23 @@ impl<'a> Model<'a> {
                     .iter_prefix(prefix_slice)
                     .filter(|(k, _)| k.len() >= prefix_len_bytes + item_size);
 
-                for (p, t) in prefixes {
-                    let value = options
+                for (p, weight_local) in prefixes {
+                    let weight_global = options
                         .entry(u32::from_be_bytes(
                             p[prefix_len_bytes..][..item_size].try_into().unwrap(),
                         ))
                         .or_insert(0f32);
 
-                    *value += t * self.config.context_weights[prefix_len_tokens - 1];
+                    let mut weight_scaled = weight_local
+                        * self.config.context_weights[prefix_len_tokens - 1]
+                        / prefix_max_tokens as f32;
 
                     /*
-                    if matches!(current, MarkovItem::Start) {
-                        *value = value.powf(0.75);
+                    if matches!(current, MarkovItem::Stop) {
+                        weight_scaled = weight_scaled.powf(2.0);
                     }
                     */
+
+                    *weight_global += weight_scaled;
+                    weight_max = weight_max.max(*weight_global as f64);
                 }
 
                 prefix_slice = &prefix_slice[item_size..];
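The rename from `t`/`value` to `weight_local`/`weight_global` makes the accumulation easier to follow: every trie hit contributes its local weight, scaled by the configured weight for that context length and normalised by `prefix_max_tokens`, while `weight_max` tracks the heaviest accumulated candidate so the sampled token can later be scored against it. A condensed standalone sketch of that loop, with hypothetical hit data:

```rust
use std::collections::HashMap;

fn main() {
    // (token id, local weight, prefix length in tokens) triples, as if they
    // came from trie lookups at successively shorter context lengths.
    let hits = [(7u32, 2.0f32, 2usize), (7, 1.0, 1), (9, 3.0, 1)];
    let context_weights = [1.0f32, 2.0];
    let prefix_max_tokens = 2usize;

    let mut options = HashMap::<u32, f32>::new();
    let mut weight_max: f64 = 0.0;

    for (token, weight_local, prefix_len_tokens) in hits {
        let weight_global = options.entry(token).or_insert(0f32);
        let weight_scaled = weight_local
            * context_weights[prefix_len_tokens - 1]
            / prefix_max_tokens as f32;
        *weight_global += weight_scaled;
        weight_max = weight_max.max(*weight_global as f64);
    }

    // Token 7 accumulated weight from both context lengths: 2.0 + 0.5.
    println!("{options:?}, max = {weight_max}");
}
```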
@@ -203,9 +220,15 @@ impl<'a> Model<'a> {
             let weights = WeightedIndex::new(entries.iter().map(|v| v.1))
                 .expect("Weighted index failed to create");
 
-            current = MarkovItem::from_u32(entries[weights.sample(&mut rand)].0);
+            let (entry_token, entry_global_weight) = entries[weights.sample(&mut rand)];
+            assert_ne!(weight_max, 0.0);
+            current = MarkovItem::from_u32(entry_token);
             history.push_back(current);
 
+            // How likely was it to get this token?
+            let heat_raw = entry_global_weight as f64 / weight_max;
+            heats.push(heat_raw);
+
             if let MarkovItem::Token(ti) = current {
                 tokens.push(
                     self.state
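`heat_raw` is the drawn entry's accumulated weight relative to the heaviest candidate: 1.0 means the most predictable continuation was picked, values near 0 mean a long shot came up. A small self-contained version of the sampling step, assuming `rand`'s `WeightedIndex` as used in the file:

```rust
use rand::distributions::{Distribution, WeightedIndex};

fn main() {
    // Hypothetical (token id, accumulated weight) candidates.
    let entries: Vec<(u32, f32)> = vec![(1, 4.0), (2, 1.0), (3, 0.5)];
    let weight_max = 4.0f64;

    let weights = WeightedIndex::new(entries.iter().map(|v| v.1))
        .expect("Weighted index failed to create");
    let mut rand = rand::thread_rng();

    let (token, weight) = entries[weights.sample(&mut rand)];
    // Relative likelihood of the token we actually drew.
    let heat_raw = weight as f64 / weight_max;
    println!("token {token}, raw heat {heat_raw}");
}
```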
@@ -224,8 +247,48 @@ impl<'a> Model<'a> {
                 continue;
             }
 
-            break tokens.join(" ").replace('@', "@\u{200b}");
+            sentence = tokens
+                .iter()
+                .map(|tok| {
+                    if self.config.allowed_tags.contains(tok) {
+                        return tok.to_owned();
+                    }
+
+                    tok.replace('@', "@\u{200b}")
+                })
+                .collect::<Vec<_>>()
+                .join(" ");
+
+            let token_count = tokens.len() as f64;
+            let heat_count = heats.len() as f64;
+
+            // Geomean to get the final heat
+            heat = heats.iter().product::<f64>().powf(1.0 / heat_count);
+
+            // Are we just generating gibberish?
+            let variance_mod = (heats.iter().map(|x| x.powf(3.0)).sum::<f64>() / heat_count)
+                .powf(1.25)
+                .min(1.0);
+            heat *= 1.0 - variance_mod;
+
+            // Too many mentions
+            heat /= (tokens.iter().filter(|s| s.starts_with('@')).count() as f64)
+                .powf(2.0)
+                .max(1.0);
+
+            // Penalty for being too short
+            heat *= (0.5 + token_count / 10.0).min(1.0);
+
+            // Up until here the heat is inverted
+            heat = 1.0 - heat;
+
+            // Account for float imprecision
+            heat = heat.clamp(0.0, 1.0);
+
+            break;
         }
+
+        ModelOutput { sentence, heat }
     }
 
     fn insert(&mut self, items: &[MarkovItem]) {
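Extracted from the diff, the whole scoring pass is a pure function of the per-token `heat_raw` values and the tokens. This standalone restatement of the arithmetic above (same formulas, hypothetical inputs) is handy for sanity-checking the penalties in isolation:

```rust
// Same arithmetic as in generate(): per-token heats in [0, 1],
// where a high raw heat means a predictable token was drawn.
fn score(heats: &[f64], tokens: &[&str]) -> f64 {
    let token_count = tokens.len() as f64;
    let heat_count = heats.len() as f64;

    // Geometric mean of the per-token heats.
    let mut heat = heats.iter().product::<f64>().powf(1.0 / heat_count);

    // Cubing and summing emphasises runs of highly predictable tokens
    // ("are we just generating gibberish?").
    let variance_mod = (heats.iter().map(|x| x.powf(3.0)).sum::<f64>() / heat_count)
        .powf(1.25)
        .min(1.0);
    heat *= 1.0 - variance_mod;

    // Quadratic penalty once there is more than one mention.
    heat /= (tokens.iter().filter(|s| s.starts_with('@')).count() as f64)
        .powf(2.0)
        .max(1.0);

    // Sentences shorter than ~5 tokens are dampened.
    heat *= (0.5 + token_count / 10.0).min(1.0);

    // Invert: from here on, a high heat marks junk worth skipping,
    // then clamp against float imprecision.
    (1.0 - heat).clamp(0.0, 1.0)
}

fn main() {
    println!("{}", score(&[0.9, 0.8, 0.85], &["all", "very", "predictable"]));
}
```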
@@ -263,6 +326,11 @@ impl<'a> Model<'a> {
 
     fn insert_tokens(&mut self, input: &str) {
+        let input = input.trim();
+
+        if input.is_empty() {
+            return;
+        }
 
         // Remove quotes
         let input = Regex::new(r"^> ?.+")
             .unwrap()
@@ -277,6 +345,14 @@ impl<'a> Model<'a> {
             return;
         }
 
+        if input.contains('(') || input.contains(')') {
+            let paren_split = Regex::new("\\(|\\)").unwrap();
+            paren_split
+                .split(&input)
+                .for_each(|p| self.insert_tokens(p));
+            return;
+        }
+
         if self
             .config
             .nasty_words
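Parenthesised asides are now fed back through `insert_tokens` piece by piece, so a bracketed clause trains the model as its own sentence instead of welding its boundary words to the surrounding ones. The split itself, shown standalone (using the `regex` crate the file already depends on):

```rust
use regex::Regex;

fn main() {
    let paren_split = Regex::new(r"\(|\)").unwrap();
    let parts: Vec<&str> = paren_split.split("one (two three) four").collect();

    // Each part is recursed into insert_tokens; pieces that trim to
    // nothing are caught by the is_empty() guard added above.
    assert_eq!(parts, ["one ", "two three", " four"]);
}
```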
@@ -285,9 +361,10 @@ impl<'a> Model<'a> {
         {
             return;
         }
 
         let mut tokens = vec![MarkovItem::Start];
         tokens.extend(self.tokenizer.tokenize(&input).map(|t| {
-            if t.starts_with("https") {
+            if t.starts_with("https") || t.starts_with(':') && t.ends_with(':') {
                 self.convert_token(t)
             } else {
                 self.convert_token(&t.to_lowercase())
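One subtlety in the new condition: `&&` binds tighter than `||` in Rust, so the test parses as `https-link OR (starts-with-colon AND ends-with-colon)`, keeping both URLs and `:emoji:` shortcodes out of the lowercasing branch. A compact check of that parse, with hypothetical tokens:

```rust
fn keep_case(t: &str) -> bool {
    // Parses as: starts_with("https") || (starts_with(':') && ends_with(':'))
    t.starts_with("https") || t.starts_with(':') && t.ends_with(':')
}

fn main() {
    assert!(keep_case("https://example.com"));
    assert!(keep_case(":shrug:"));
    assert!(!keep_case(":half-open"));
    assert!(!keep_case("Hello"));
}
```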
@@ -332,6 +409,20 @@ struct Config {
     attempts: usize,
     stop_word_modifier: f32,
     mention_penalty: f32,
+    #[serde(default = "default_heat_exponent")]
+    heat_exponent: f64,
+    #[serde(default = "default_heat_retries")]
+    heat_max_retries: u32,
+    #[serde(default)]
+    allowed_tags: Vec<String>,
 }
 
+fn default_heat_exponent() -> f64 {
+    0.7
+}
+
+fn default_heat_retries() -> u32 {
+    100
+}
+
 impl Config {
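The `#[serde(default = "...")]` attributes make the new fields backwards-compatible: an existing config file without them still deserializes, with `heat_exponent = 0.7`, `heat_max_retries = 100`, and an empty `allowed_tags`. A minimal sketch of that behaviour, illustrated with `serde_json` for brevity (whatever format the real config uses):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    #[serde(default = "default_heat_exponent")]
    heat_exponent: f64,
    #[serde(default = "default_heat_retries")]
    heat_max_retries: u32,
    #[serde(default)]
    allowed_tags: Vec<String>,
}

fn default_heat_exponent() -> f64 {
    0.7
}

fn default_heat_retries() -> u32 {
    100
}

fn main() {
    // An old config with none of the new keys still parses.
    let config: Config = serde_json::from_str("{}").unwrap();
    assert_eq!(config.heat_exponent, 0.7);
    assert_eq!(config.heat_max_retries, 100);
    assert!(config.allowed_tags.is_empty());
}
```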
@@ -458,8 +549,27 @@ async fn main() {
     loop {
         timer.tick().await;
 
-        let text = model.generate();
-        info!("Generated: {:?}", text);
+        let mut retries = config.heat_max_retries;
+        let text = loop {
+            let text = model.generate();
+            let p_skip = text.heat.powf(config.heat_exponent);
+            let mut rand = rand::thread_rng();
+            if rand.gen_bool(p_skip) && retries > 0 {
+                retries -= 1;
+                info!(
+                    "[{}/{}] Skipped: {:?}, P(skip): {}",
+                    config.heat_max_retries - retries,
+                    config.heat_max_retries,
+                    text,
+                    p_skip
+                );
+                continue;
+            }
+
+            info!("Generated: {:?}, P(skip): {}", text, p_skip);
+
+            break text.sentence;
+        };
 
         if config.test_mode {
            continue;
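The main loop now rejection-samples its output: each candidate is skipped with probability `heat^heat_exponent`, so junk (heat near 1) almost always gets regenerated while clean output (heat near 0) almost always posts, and the retry cap keeps the loop terminating on a pathological model. Because `heat` is clamped to `[0, 1]` and the exponent keeps it there, `gen_bool` (which panics outside that range) is safe; note that the default exponent of 0.7 bends the curve upward, so mid-range heats are usually retried too. A toy demonstration of the skip probabilities:

```rust
use rand::Rng;

fn main() {
    let heat_exponent = 0.7;
    let mut rand = rand::thread_rng();

    for heat in [0.05f64, 0.5, 0.95] {
        // heat 0.05 -> P(skip) ~0.12, heat 0.5 -> ~0.62, heat 0.95 -> ~0.96
        let p_skip = heat.powf(heat_exponent);
        let skipped = rand.gen_bool(p_skip);
        println!("heat {heat:.2} -> P(skip) {p_skip:.2}, skipped: {skipped}");
    }
}
```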