Added CI, some tweaks, and Dockerfile
ci/woodpecker/push/ociImagePush: pipeline was successful

Natty 2024-04-17 01:35:21 +02:00
parent 4d784eeb16
commit 4f74e57190
Signed by: natty (GPG Key ID: BF6CB659ADEE60EC)
6 changed files with 145 additions and 33 deletions

.dockerignore Normal file

@@ -0,0 +1,4 @@
/target
/state*.msg
.env
/config*.toml

.gitignore vendored

@@ -2,3 +2,4 @@
/state*.msg
.env
/config*.toml
/.idea

.sqlx/query-a86fab3e2b1a1d2580a342221e82f8881c74a98f571b9063de9b7972cee15239.json Normal file

@@ -0,0 +1,31 @@
{
  "db_name": "PostgreSQL",
  "query": "SELECT text, \"createdAt\"\n FROM note\n WHERE \"note\".\"userId\" = $1 \n AND \"note\".\"createdAt\" > $2\n AND \"note\".\"createdAt\" < NOW() \n AND \"note\".\"visibility\" IN ('public', 'home') \n AND ($4 OR (\"note\".\"cw\" IS NULL OR LOWER(\"note\".\"cw\") IN (SELECT UNNEST($3::VARCHAR[]))))",
  "describe": {
    "columns": [
      {
        "ordinal": 0,
        "name": "text",
        "type_info": "Text"
      },
      {
        "ordinal": 1,
        "name": "createdAt",
        "type_info": "Timestamptz"
      }
    ],
    "parameters": {
      "Left": [
        "Text",
        "Timestamptz",
        "VarcharArray",
        "Bool"
      ]
    },
    "nullable": [
      true,
      false
    ]
  },
  "hash": "a86fab3e2b1a1d2580a342221e82f8881c74a98f571b9063de9b7972cee15239"
}
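This is sqlx's offline query metadata: `cargo sqlx prepare` hashes each checked query string and caches its column and parameter types (the "Left" list gives the Postgres types of $1..$4), so the query! macros can type-check with SQLX_OFFLINE=true and no live database. A sketch of the call site this file likely backs; the variable names are assumptions, and the query is reflowed here for readability even though the cached string must match the hash byte-for-byte:

let text_stream = sqlx::query!(
    r#"SELECT text, "createdAt"
       FROM note
       WHERE "note"."userId" = $1
         AND "note"."createdAt" > $2
         AND "note"."createdAt" < NOW()
         AND "note"."visibility" IN ('public', 'home')
         AND ($4 OR ("note"."cw" IS NULL
              OR LOWER("note"."cw") IN (SELECT UNNEST($3::VARCHAR[]))))"#,
    user_id,       // $1: Text
    fetched_up_to, // $2: Timestamptz
    cw_allowlist,  // $3: VarcharArray
    ignore_cws,    // $4: Bool
)
.fetch(&pool);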


@@ -0,0 +1,18 @@
steps:
  publish-docker-latest:
    image: docker.io/plugins/kaniko
    settings:
      repo: git.astolfo.cool/natty/mag-markov
      tags:
        - ${CI_COMMIT_SHA}
        - latest
      dockerfile: Dockerfile
      registry: git.astolfo.cool
      username:
        from_secret: docker_username
      password:
        from_secret: docker_password
    when:
      - event: push
        branch: main

Dockerfile Normal file

@@ -0,0 +1,10 @@
FROM docker.io/rust:1.77-bookworm AS build
WORKDIR /mag-markov
# COPY instead of the original ADD: no URL or archive-extraction semantics needed
COPY . .
RUN cargo build --release --locked

# Runtime stage. The original commit used docker.io/alpine:3.19, but a binary
# built on the Debian-based rust image links against glibc and would not start
# under Alpine's musl; a Debian-slim runtime matches the build environment.
FROM docker.io/debian:bookworm-slim
WORKDIR /app
COPY --from=build /mag-markov/target/release/mag-markov .
# Exec form avoids a wrapping shell and forwards signals to the binary
CMD ["/app/mag-markov"]

src/main.rs

@@ -11,7 +11,10 @@ use clap::Parser;
 use futures::TryStreamExt;
 use indicatif::ProgressBar;
 use qp_trie::Trie;
-use rand::distributions::{Distribution, WeightedIndex};
+use rand::{
+    distributions::{Distribution, WeightedIndex},
+    Rng,
+};
 use regex::Regex;
 use reqwest::{header, ClientBuilder};
 use serde::{Deserialize, Serialize};
@@ -46,7 +49,7 @@ impl MarkovItem {
     }
 }

-const CONTEXT_SIZE_MAX: usize = 5;
+const CONTEXT_SIZE_MAX: usize = 8;
 const CONTEXT_SIZE_BYTES: usize = CONTEXT_SIZE_MAX * std::mem::size_of::<u32>();

 type PrefixBytes = SmallVec<[u8; CONTEXT_SIZE_BYTES]>;
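For context on these constants: PrefixBytes sizes the qp-trie key's inline buffer so that a full-length context (now 8 tokens rather than 5) fits without a heap allocation, and tokens are serialized big-endian so byte-wise key order matches numeric token order. A minimal sketch of the key encoding implied by the from_be_bytes calls further down; the helper name is hypothetical:

use smallvec::SmallVec;

const CONTEXT_SIZE_MAX: usize = 8;
const CONTEXT_SIZE_BYTES: usize = CONTEXT_SIZE_MAX * std::mem::size_of::<u32>();
type PrefixBytes = SmallVec<[u8; CONTEXT_SIZE_BYTES]>;

/// Hypothetical helper: encode a token sequence as a trie key. Big-endian
/// keeps lexicographic byte order consistent with numeric token order.
fn key_of(tokens: &[u32]) -> PrefixBytes {
    let mut key = PrefixBytes::new();
    for t in tokens {
        key.extend_from_slice(&t.to_be_bytes());
    }
    key
}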
@@ -145,6 +148,32 @@ impl<'a> Model<'a> {
         let item_size = std::mem::size_of::<u32>();
         let prefix_len_bytes = prefix_slice.len();
         let prefix_len_tokens = prefix_len_bytes / item_size;

+        // Mix in all bigrams if we're at the beginning
+        if matches!(current, MarkovItem::Start) {
+            options.extend(
+                self.state
+                    .model
+                    .iter()
+                    .map(|(k, _)| k)
+                    .filter(|k| k.len() == 2 * item_size)
+                    .filter(|k| {
+                        !matches!(
+                            MarkovItem::from_u32(u32::from_be_bytes(
+                                k[item_size..].try_into().unwrap()
+                            )),
+                            MarkovItem::Stop
+                        )
+                    })
+                    .map(|item| {
+                        (
+                            u32::from_be_bytes(item[..item_size].try_into().unwrap()),
+                            rand.gen_range(0.5..2.0),
+                        )
+                    }),
+            );
+        }
+
         let prefixes = self
             .state
             .model
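The new block diversifies sentence openings: when generation sits at Start, every token that heads a stored bigram (skipping bigrams that lead straight to Stop) is added as a candidate, each with a random weight drawn from 0.5..2.0, so repeated runs don't always open with the single most frequent starter. A standalone sketch of the jittered choice, assuming the rand 0.8 API imported above; the names are hypothetical:

use rand::{
    distributions::{Distribution, WeightedIndex},
    Rng,
};

/// Give each candidate opener a random weight in 0.5..2.0, then sample
/// proportionally, mirroring how the mixed-in bigrams are weighted.
fn pick_opening(candidates: &[u32], rng: &mut impl Rng) -> u32 {
    let weights: Vec<f32> = candidates.iter().map(|_| rng.gen_range(0.5..2.0)).collect();
    let dist = WeightedIndex::new(&weights).expect("non-empty positive weights");
    candidates[dist.sample(rng)]
}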
@@ -158,7 +187,13 @@
                 ))
                 .or_insert(0f32);

-            *value += t * self.config.context_weights[prefix_len_tokens];
+            *value += t * self.config.context_weights[prefix_len_tokens - 1];
+            /*
+            if matches!(current, MarkovItem::Start) {
+                *value = value.powf(0.75);
+            }
+            */
         }

         prefix_slice = &prefix_slice[item_size..];
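The `- 1` is the matching half of the new training scheme: windows are now training_len = history_len + 1 tokens, so observed contexts run from 1 to history_len tokens, and a context of n tokens must read the weight stored at index n - 1. The old indexing left context_weights[0] unused and would go out of bounds once full-length contexts appear. A tiny hypothetical helper stating the invariant:

/// context_weights[i] weights a context of i + 1 tokens, so an n-token
/// prefix (1 <= n <= context_weights.len()) reads index n - 1.
fn context_weight(context_weights: &[f32], prefix_len_tokens: usize) -> f32 {
    debug_assert!((1..=context_weights.len()).contains(&prefix_len_tokens));
    context_weights[prefix_len_tokens - 1]
}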
@@ -229,7 +264,7 @@
     fn insert_tokens(&mut self, input: &str) {
         let input = input.trim();
         // Remove quotes
-        let input = Regex::new(r"> ?.+")
+        let input = Regex::new(r"^> ?.+")
             .unwrap()
             .replace_all(input, "")
             .to_string();
@@ -251,17 +286,19 @@
             return;
         }

         let mut tokens = vec![MarkovItem::Start];
-        tokens.extend(
-            self.tokenizer
-                .tokenize(&input)
-                .map(|t| self.convert_token(&t.to_lowercase())),
-        );
-        tokens.push(MarkovItem::Stop);
-        for i in 2..=self.config.history_len() {
-            for tok in tokens.windows(i) {
-                self.insert(tok);
-            }
-        }
+        tokens.extend(self.tokenizer.tokenize(&input).map(|t| {
+            if t.starts_with("https") {
+                self.convert_token(t)
+            } else {
+                self.convert_token(&t.to_lowercase())
+            }
+        }));
+        for _ in 0..self.config.history_len() {
+            tokens.push(MarkovItem::Stop);
+        }
+        for tok in tokens.windows(tokens.len().min(self.config.training_len()).max(2)) {
+            self.insert(tok);
+        }
     }
 }
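Three changes land in insert_tokens: tokens starting with "https" keep their case (lowercasing would corrupt URLs), the stream is padded with history_len trailing Stop markers instead of one so windows keep sliding to the sentence end and stopping is learned at every context length, and the nested loops over every window size collapse into one pass over windows of training_len tokens, clamped to at least bigrams for short inputs. A hypothetical extraction of that final pass:

/// One fixed-size training pass: the window length is clamped so short
/// inputs still produce at least bigrams, and long inputs use the full window.
fn train_windows<T>(tokens: &[T], training_len: usize, mut insert: impl FnMut(&[T])) {
    let win = tokens.len().min(training_len).max(2);
    for tok in tokens.windows(win) {
        insert(tok);
    }
}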
@@ -301,6 +338,10 @@ impl Config {
     fn history_len(&self) -> usize {
         self.context_weights.len()
     }
+
+    fn training_len(&self) -> usize {
+        self.history_len() + 1
+    }
 }

 #[tokio::main]
@@ -331,10 +372,10 @@
         return;
     }

-    if !(2..=CONTEXT_SIZE_MAX).contains(&config.context_weights.len()) {
+    if !(2..=CONTEXT_SIZE_MAX).contains(&config.training_len()) {
         error!(
             "Context size {} out of range {:?}",
-            config.context_weights.len(),
+            config.training_len(),
             2..=CONTEXT_SIZE_MAX
         );
         return;
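The validation now bounds the derived window size rather than the raw weight count: with training_len = context_weights.len() + 1, up to CONTEXT_SIZE_MAX - 1 = 7 weights are allowed, which the raised constant accommodates. Roughly, as a standalone check (hypothetical function):

const CONTEXT_SIZE_MAX: usize = 8;

/// The training window (context tokens plus the predicted token) must fit
/// within the trie-key budget of CONTEXT_SIZE_MAX tokens.
fn config_is_valid(context_weights_len: usize) -> bool {
    let training_len = context_weights_len + 1;
    (2..=CONTEXT_SIZE_MAX).contains(&training_len)
}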
@@ -354,6 +395,7 @@ };
     };

     let pool = PgPoolOptions::new()
+        .max_connections(1)
         .connect(std::env::var("DATABASE_URL").as_deref().unwrap())
         .await
         .unwrap();
@@ -379,25 +421,31 @@
     )
     .fetch(&pool);

-    let mut stream = progress.wrap_stream(text_stream);
-    let mut cnt = 0;
-    while let Ok(Some(item)) = stream.try_next().await.map_err(|e| {
-        error!("Error fetching: {}", e);
-        e
-    }) {
-        model.state.fetched_up_to = Some(item.createdAt);
-        model.insert_tokens(item.text.as_deref().unwrap_or_default());
-        progress.set_message(format!(
-            "Items: {} | Trie size: {} | Current: {:?}",
-            cnt,
-            model.state.model.count(),
-            item.text.unwrap_or_default().clone()
-        ));
-        cnt += 1;
-        progress.tick();
-        progress.set_length(cnt);
-    }
+    {
+        let mut stream = progress.wrap_stream(text_stream);
+        let mut cnt = 0;
+        while let Ok(Some(item)) = stream.try_next().await.map_err(|e| {
+            error!("Error fetching: {}", e);
+            e
+        }) {
+            model.state.fetched_up_to = Some(item.createdAt);
+            model.insert_tokens(item.text.as_deref().unwrap_or_default());
+            progress.set_message(format!(
+                "Items: {} | Trie size: {} | Current: {:?}",
+                cnt,
+                model.state.model.count(),
+                item.text.unwrap_or_default().clone()
+            ));
+            cnt += 1;
+            progress.tick();
+            progress.set_length(cnt);
+        }
+    }
+
+    tokio::spawn(async move {
+        drop(pool);
+    });

     info!("Keys: {}", model.state.model.count());
     progress.set_message("Saving data...");