Do less allocating and copying when generating text
Before, on a low-capacity system (such as an inexpensive cloud host), doing Markov-chain text generation was _extraordinarily_ slow, taking half a second or more to produce a page, and if multiple requests came in simultaneously they could easily swamp the capacity of such a system. Most of the time was spent in the Words iterator, which did a bunch of cloning of Strings in what was the hot path.

This changes the Markov generator's internal representation: instead of storing Strings, it stores index pairs into a single shared String, normalized so that all references to a particular word are collapsed into a single pair. This also means the hash map works with fixed-size values, which can't hurt. In addition, it does only one hash-map lookup per generated word on the happy path of not reaching the end of the chain.

The upshot of all this is that where it was taking half a second or more to generate a page, it now takes about 0.001 seconds.

On the downside, the initialization of WurstsalatGeneratorPro has become rather less flexible. Before, you created one and then taught it various strings, or gave it a list of paths to read and teach itself from. Now, the _only_ way to create one is directly with a list of paths. Changing this is possible, but it would mean `Substr` would have to learn to distinguish which source data it came from, likely a 50% increase in its size. It didn't seem worth it to preserve that capability, which wasn't even being used.
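As a minimal standalone sketch of the idea (the names here, Span and corpus, are invented for illustration; the commit's real types are Substr and Interner, in the diff below): equal words intern to one fixed-size (start, end) pair into a single shared String, so map keys are Copy and hashing never touches a heap-allocated String.

    use std::collections::HashMap;

    /// A fixed-size, Copy-able reference to a word inside one shared corpus String.
    #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
    struct Span {
        start: usize,
        end: usize,
    }

    fn main() {
        let corpus = String::from("the cat saw the cat");

        // Intern each whitespace-separated word: equal text maps to one Span,
        // so comparisons work on two usizes instead of String contents.
        let mut interner: HashMap<&str, Span> = HashMap::new();
        let mut words: Vec<Span> = Vec::new();
        let mut pos = 0;
        for word in corpus.split_whitespace() {
            let start = corpus[pos..].find(word).unwrap() + pos;
            let span = *interner
                .entry(word)
                .or_insert(Span { start, end: start + word.len() });
            pos = start + word.len();
            words.push(span);
        }

        // Both occurrences of "the" (and of "cat") collapse to one Span value.
        assert_eq!(words[0], words[3]);
        assert_eq!(words[1], words[4]);
        println!("{}", &corpus[words[0].start..words[0].end]); // "the"
    }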
parent d2bdbf750a
commit 284af56e68
4 changed files with 179 additions and 54 deletions
@@ -29,9 +29,8 @@ pub struct Iocaine {
 impl Iocaine {
     pub fn new(config: Config) -> Result<Self> {
-        let mut chain = WurstsalatGeneratorPro::new();
         tracing::info!("Loading the markov chain corpus");
-        chain.learn_from_files(&config.sources.markov)?;
+        let chain = WurstsalatGeneratorPro::learn_from_files(&config.sources.markov)?;
         tracing::info!("Corpus loaded");
 
         tracing::info!("Loading word salad");

@@ -68,8 +68,8 @@ impl AssembledStatisticalSequences {
         let mut rng = GobbledyGook::for_url(format!("{}://{}/{}", group, &static_seed, index));
         let words = (1..=count)
-            .map(|_| iocaine.words.0.choose(&mut rng).unwrap().to_string())
-            .collect::<Vec<String>>()
+            .map(|_| iocaine.words.0.choose(&mut rng).unwrap().as_str())
+            .collect::<Vec<_>>()
             .join("-");
         out.write(&words)?;

@@ -10,12 +10,16 @@
 use rand::{seq::IndexedRandom, Rng};
 use std::collections::HashMap;
 use std::fs::File;
-use std::io;
+use std::io::Read as _;
+
+use substrings::{Interner, Substr, WhitespaceSplitIterator};
 
-pub type Bigram = (String, String);
+mod substrings;
+
+type Bigram = (Substr, Substr);
 
 #[derive(Debug, Default)]
 pub struct WurstsalatGeneratorPro {
-    map: HashMap<Bigram, Vec<String>>,
+    string: String,
+    map: HashMap<Bigram, Vec<Substr>>,
     keys: Vec<Bigram>,
 }

@@ -24,42 +28,63 @@ impl WurstsalatGeneratorPro {
         Default::default()
     }
 
-    fn learn(&mut self, text: String) {
-        let words = text.split_whitespace().collect::<Vec<&str>>();
-        for window in words.windows(3) {
-            let (a, b, c) = (
-                window[0].to_string(),
-                window[1].to_string(),
-                window[2].to_string(),
-            );
-            self.map.entry((a, b)).or_default().push(c);
-        }
-        self.keys = self.map.keys().cloned().collect();
-        self.keys.sort_unstable();
-    }
+    fn learn(string: String, mut breaks: &[usize]) -> Self {
+        let mut interner = Interner::new();
+        let words = WhitespaceSplitIterator::new(&string);
+        let mut map = HashMap::<Bigram, Vec<Substr>>::new();
+        for window in words.collect::<Vec<_>>().windows(3) {
+            let (a, b, c) = (window[0], window[1], window[2]);
+
+            // This bit of weirdness is to preserve the behavior from
+            // learning from multiple files independently; if our
+            // current window spans a break, we don't add the triple.
+            let mut skip_triple = false;
+            while !breaks.is_empty() && breaks[0] <= c.start {
+                if breaks[0] <= a.start {
+                    // completely passed the first break, can remove it
+                    breaks = &breaks[1..];
+                } else {
+                    skip_triple = true;
+                    break;
+                }
+            }
+
+            if !skip_triple {
+                map.entry((interner.intern(&string, a), interner.intern(&string, b)))
+                    .or_default()
+                    .push(interner.intern(&string, c));
+            }
+        }
+
+        let mut keys = map.keys().copied().collect::<Vec<_>>();
+        keys.sort_unstable_by_key(|(s1, s2)| {
+            (&string[s1.start..s1.end], &string[s2.start..s2.end])
+        });
+
+        Self { string, keys, map }
+    }
 
-    pub fn learn_from_files(&mut self, files: &[String]) -> Result<(), std::io::Error> {
+    pub fn learn_from_files(files: &[String]) -> Result<Self, std::io::Error> {
+        let mut s = String::new();
+        let mut breaks = Vec::new();
         for source in files.iter() {
-            let f = File::open(source)?;
-            let s = io::read_to_string(f)?.clone();
-            self.learn(s);
+            let mut f = File::open(source)?;
+            f.read_to_string(&mut s)?;
+            breaks.push(s.len());
+            s.push(' ');
         }
 
-        Ok(())
+        Ok(Self::learn(s, &breaks))
     }
 
     pub fn generate<R: Rng>(&self, mut rng: R) -> Words<'_, R> {
-        let initial_bigram = if self.map.is_empty() {
-            ("".to_string(), "".to_string())
-        } else {
-            let (a, b) = self.keys.choose(&mut rng).unwrap();
-            (a.to_string(), b.to_string())
-        };
+        let initial_bigram = self.keys.choose(&mut rng).cloned().unwrap_or_default();
         self.iter_with_rng_from(rng, initial_bigram)
     }
 
     fn iter_with_rng_from<R: Rng>(&self, rng: R, from: Bigram) -> Words<'_, R> {
         Words {
+            string: self.string.as_str(),
             map: &self.map,
             rng,
             keys: &self.keys,

@@ -70,30 +95,31 @@ impl WurstsalatGeneratorPro {
 
 #[derive(Clone)]
 pub struct Words<'a, R: Rng> {
-    map: &'a HashMap<Bigram, Vec<String>>,
+    string: &'a str,
+    map: &'a HashMap<Bigram, Vec<Substr>>,
     rng: R,
-    keys: &'a Vec<Bigram>,
+    keys: &'a [Bigram],
     state: Bigram,
 }
 
 impl<'a, R: Rng> Iterator for Words<'a, R> {
-    type Item = String;
+    type Item = &'a str;
 
-    fn next(&mut self) -> Option<String> {
+    fn next(&mut self) -> Option<&'a str> {
         if self.map.is_empty() {
             return None;
         }
 
-        let result = Some(self.state.0.clone());
+        let result = self.state.0.extract_str(self.string);
 
-        while !self.map.contains_key(&self.state) {
-            let (a, b) = self.keys.choose(&mut self.rng).unwrap();
-            self.state = (a.to_string(), b.to_string());
-        }
-        let next_words = &self.map[&self.state];
-        let next = next_words.choose(&mut self.rng).unwrap();
-        self.state = (self.state.1.clone(), next.to_string());
+        let next_words = self.map.get(&self.state).unwrap_or_else(|| {
+            self.state = *self.keys.choose(&mut self.rng).unwrap();
+            &self.map[&self.state]
+        });
+        let next = *next_words.choose(&mut self.rng).unwrap();
+        self.state = (self.state.1, next);
 
-        result
+        Some(result)
     }
 }

@@ -118,14 +144,14 @@ fn capitalize(word: &str) -> String {
 /// Join words from an iterator. The first word is always capitalized
 /// and the generated sentence will end with `'.'` if it doesn't
 /// already end with some other ASCII punctuation character.
-pub fn join_words<I: Iterator<Item = String>>(mut words: I) -> String {
+pub fn join_words<'a, I: Iterator<Item = &'a str>>(mut words: I) -> String {
     match words.next() {
         None => String::new(),
         Some(word) => {
             // Punctuation characters which ends a sentence.
             let punctuation: &[char] = &['.', '!', '?'];
 
-            let mut sentence = capitalize(&word);
+            let mut sentence = capitalize(word);
             let mut needs_cap = sentence.ends_with(punctuation);
 
             // Add remaining words.

@@ -133,9 +159,9 @@ pub fn join_words<I: Iterator<Item = String>>(mut words: I) -> String {
             sentence.push(' ');
 
             if needs_cap {
-                sentence.push_str(&capitalize(&word));
+                sentence.push_str(&capitalize(word));
             } else {
-                sentence.push_str(&word);
+                sentence.push_str(word);
             }
 
             needs_cap = word.ends_with(punctuation);

@@ -162,23 +188,20 @@ mod tests {
     #[test]
     fn test_load_error() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        let result = wurstsalat.learn_from_files(&["/does-not-exist".to_string()]);
+        let result = WurstsalatGeneratorPro::learn_from_files(&["/does-not-exist".to_string()]);
         assert!(result.is_err());
     }
 
     #[test]
     fn test_load_ok() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        let result = wurstsalat.learn_from_files(&["README.md".to_string()]);
+        let result = WurstsalatGeneratorPro::learn_from_files(&["README.md".to_string()]);
         assert!(result.is_ok());
     }
 
     #[test]
     fn test_generate() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        wurstsalat
-            .learn_from_files(&["tests/data/lorem-ipsum.txt".to_string()])
-            .unwrap();
+        let wurstsalat =
+            WurstsalatGeneratorPro::learn_from_files(&["tests/data/lorem-ipsum.txt".to_string()])
+                .unwrap();
 
         let mut rng = GobbledyGook::for_url("/test");
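Taken together, the new call shape looks roughly like this (a sketch only, as it might appear inside the crate; the file paths are invented and error handling is simplified):

    // Hypothetical usage of the new API; learn_from_files is now the only
    // constructor, concatenating the sources into one shared String.
    let chain = WurstsalatGeneratorPro::learn_from_files(&[
        "corpus/one.txt".to_string(),
        "corpus/two.txt".to_string(),
    ])
    .expect("failed to load corpus");

    // Words is an endless iterator of &str borrowed from the shared corpus
    // String, so generating no longer allocates per word; take() bounds it.
    let rng = GobbledyGook::for_url("/some/page");
    let page = join_words(chain.generate(rng).take(32));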
src/wurstsalat_generator_pro/substrings.rs (new file, 103 lines)

@@ -0,0 +1,103 @@
+use std::{collections::HashMap, str::CharIndices};
+
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
+pub struct Substr {
+    pub start: usize,
+    pub end: usize,
+}
+
+impl Substr {
+    pub fn extract_str<'a>(&'_ self, relative_to: &'a str) -> &'a str {
+        &relative_to[self.start..self.end]
+    }
+}
+
+// Normalizes Substrs so that the same substring gets turned into the
+// same Substr.
+pub struct Interner<'a>(HashMap<&'a str, Substr>);
+
+impl<'a> Interner<'a> {
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }
+
+    pub fn intern(&mut self, str: &'a str, substr: Substr) -> Substr {
+        *self
+            .0
+            .entry(&str[substr.start..substr.end])
+            .or_insert(substr)
+    }
+}
+
+// An iterator that splits a string into Substrs on whitespace.
+// Equivalent to the iterator returned by `str::split_whitespace`
+// but returns `Substr`s instead of string slices.
+pub struct WhitespaceSplitIterator<'a> {
+    underlying: CharIndices<'a>,
+}
+
+impl<'a> WhitespaceSplitIterator<'a> {
+    pub fn new(s: &'a str) -> Self {
+        Self {
+            underlying: s.char_indices(),
+        }
+    }
+}
+
+impl Iterator for WhitespaceSplitIterator<'_> {
+    type Item = Substr;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let start = loop {
+            let (pos, c) = self.underlying.next()?;
+            if !c.is_whitespace() {
+                break pos;
+            }
+        };
+
+        let end = loop {
+            let Some((pos, c)) = self.underlying.next() else {
+                break self.underlying.offset();
+            };
+            if c.is_whitespace() {
+                break pos;
+            }
+        };
+
+        Some(Substr { start, end })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn compare_same(s: &str) {
+        let substrs = WhitespaceSplitIterator::new(s)
+            .map(|ss| ss.extract_str(s))
+            .collect::<Vec<_>>();
+        let std_split = s.split_whitespace().collect::<Vec<_>>();
+
+        assert_eq!(substrs, std_split);
+    }
+
+    #[test]
+    fn splits_simple_whitespace() {
+        compare_same("hello there world");
+    }
+
+    #[test]
+    fn multiple_interior_whitespace() {
+        compare_same("hello\t\t\tthere world");
+    }
+
+    #[test]
+    fn leading_whitespace() {
+        compare_same(" hello there world");
+    }
+
+    #[test]
+    fn trailing_whitespace() {
+        compare_same("hello there world ");
+    }
+}
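A quick sketch of how these pieces compose (the input string is invented): the split iterator yields raw (start, end) spans, and the interner collapses spans over equal text to one canonical Substr, which is what makes Substr bigrams usable as hash-map keys.

    let s = "badger badger mushroom badger".to_string();
    let mut interner = Interner::new();
    let spans: Vec<Substr> = WhitespaceSplitIterator::new(&s)
        .map(|ss| interner.intern(&s, ss))
        .collect();

    // All three "badger"s intern to the identical Substr, so a map keyed
    // on (Substr, Substr) treats any bigram containing them as equal.
    assert_eq!(spans[0], spans[1]);
    assert_eq!(spans[0], spans[3]);
    assert_eq!(spans[0].extract_str(&s), "badger");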