Do less allocating and copying when generating text

Before, on a low-capacity system (such as an inexpensive cloud host),
doing Markov-chain text generation was _extraordinarily_ slow, taking
half a second or more to produce a page, and if multiple requests came
in simultaneously they could easily swamp the capacity of such a system.

Most of the time was spent in the Words iterator, which did a bunch of
String cloning in the hot path.

This changes the Markov generator's internal representation: instead of
storing Strings, it now stores index-pairs into a single shared String,
normalized so that every occurrence of a given word collapses into the
same pair.  This also means that the hash map is working with fixed-size
values, which can't hurt.
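
To make that concrete, here is a condensed sketch of the idea (not the
real code; the actual `Substr` and `Interner` live in the new substrings
module at the bottom of this diff):

use std::collections::HashMap;

// A word is now a fixed-size pair of byte offsets into one shared String.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct Substr { start: usize, end: usize }

fn main() {
    let corpus = "the cat saw the dog";
    // Interning: the first (start, end) pair seen for a word wins, so
    // both occurrences of "the" collapse into the same Substr.
    let mut interner: HashMap<&str, Substr> = HashMap::new();
    let a = *interner.entry(&corpus[0..3]).or_insert(Substr { start: 0, end: 3 });
    let b = *interner.entry(&corpus[12..15]).or_insert(Substr { start: 12, end: 15 });
    assert_eq!(a, b); // one canonical pair per distinct word
    assert_eq!(&corpus[b.start..b.end], "the");
}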

In addition, it does only one hash-map lookup per generated word on the
happy path, i.e. for as long as the end of the chain hasn't been reached.
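
In isolation, the lookup pattern is roughly this (a sketch, with plain
integer pairs standing in for the real Substr bigrams):

use std::collections::HashMap;

fn main() {
    let map: HashMap<(u32, u32), Vec<u32>> = HashMap::from([((1, 2), vec![3])]);
    let state = (1, 2);
    // Old code: contains_key() in a loop plus indexing -- two or more
    // hashes per word.  New code: a single get(); the fallback closure
    // only runs when the current bigram has no successors.
    let next_words = map.get(&state).unwrap_or_else(|| {
        // here the real code re-seeds state from a random key (elided)
        &map[&(1, 2)]
    });
    assert_eq!(next_words, &vec![3]);
}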

The upshot of all this is that where it was taking a half-second or more
to generate a page, it now takes about 0.001 seconds.

On the downside, the initialization of WurstsalatGeneratorPro has become
rather less flexible.  Before, you created one and then taught it various
strings, or gave it a list of paths to read and teach itself from.  Now,
the _only_ way to create one is directly with a list of paths.  Changing
this is possible, but it means `Substr` would have to learn to distinguish
which source data it came from, which would mean a likely 50% increase in
its size.  It didn't seem worth it to preserve that capability, which
wasn't even being used.
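
For illustration, the hypothetical per-source variant would look roughly
like this (`SourcedSubstr` and its `source` field are invented here, not
part of this change):

// Substr today is two usizes, 16 bytes on a 64-bit target; a third
// field for the source brings it to 24 bytes -- the ~50% growth above.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct SourcedSubstr {
    source: usize, // hypothetical: index of the input file the span came from
    start: usize,
    end: usize,
}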
Author: iadd
Date: 2025-02-10 08:12:15 -08:00
Parent: d2bdbf750a
Commit: 284af56e68
4 changed files with 179 additions and 54 deletions

View file

@@ -29,9 +29,8 @@ pub struct Iocaine {
 impl Iocaine {
     pub fn new(config: Config) -> Result<Self> {
-        let mut chain = WurstsalatGeneratorPro::new();
         tracing::info!("Loading the markov chain corpus");
-        chain.learn_from_files(&config.sources.markov)?;
+        let chain = WurstsalatGeneratorPro::learn_from_files(&config.sources.markov)?;
         tracing::info!("Corpus loaded");
         tracing::info!("Loading word salad");

View file

@@ -68,8 +68,8 @@ impl AssembledStatisticalSequences {
         let mut rng = GobbledyGook::for_url(format!("{}://{}/{}", group, &static_seed, index));
         let words = (1..=count)
-            .map(|_| iocaine.words.0.choose(&mut rng).unwrap().to_string())
-            .collect::<Vec<String>>()
+            .map(|_| iocaine.words.0.choose(&mut rng).unwrap().as_str())
+            .collect::<Vec<_>>()
             .join("-");
         out.write(&words)?;

View file

@@ -10,12 +10,16 @@
 use rand::{seq::IndexedRandom, Rng};
 use std::collections::HashMap;
 use std::fs::File;
-use std::io;
+use std::io::Read as _;
+use substrings::{Interner, Substr, WhitespaceSplitIterator};
 
-pub type Bigram = (String, String);
+mod substrings;
+
+type Bigram = (Substr, Substr);
 
 #[derive(Debug, Default)]
 pub struct WurstsalatGeneratorPro {
-    map: HashMap<Bigram, Vec<String>>,
+    string: String,
+    map: HashMap<Bigram, Vec<Substr>>,
     keys: Vec<Bigram>,
 }
@@ -24,42 +28,63 @@ impl WurstsalatGeneratorPro {
         Default::default()
     }
 
-    fn learn(&mut self, text: String) {
-        let words = text.split_whitespace().collect::<Vec<&str>>();
-        for window in words.windows(3) {
-            let (a, b, c) = (
-                window[0].to_string(),
-                window[1].to_string(),
-                window[2].to_string(),
-            );
-            self.map.entry((a, b)).or_default().push(c);
+    fn learn(string: String, mut breaks: &[usize]) -> Self {
+        let mut interner = Interner::new();
+        let words = WhitespaceSplitIterator::new(&string);
+        let mut map = HashMap::<Bigram, Vec<Substr>>::new();
+        for window in words.collect::<Vec<_>>().windows(3) {
+            let (a, b, c) = (window[0], window[1], window[2]);
+            // This bit of weirdness is to preserve the behavior from
+            // learning from multiple files independently; if our
+            // current window spans a break, we don't add the triple.
+            let mut skip_triple = false;
+            while !breaks.is_empty() && breaks[0] <= c.start {
+                if breaks[0] <= a.start {
+                    // completely passed the first break, can remove it
+                    breaks = &breaks[1..];
+                } else {
+                    skip_triple = true;
+                    break;
+                }
+            }
+            if !skip_triple {
+                map.entry((interner.intern(&string, a), interner.intern(&string, b)))
+                    .or_default()
+                    .push(interner.intern(&string, c));
+            }
         }
-        self.keys = self.map.keys().cloned().collect();
-        self.keys.sort_unstable();
+        let mut keys = map.keys().copied().collect::<Vec<_>>();
+        keys.sort_unstable_by_key(|(s1, s2)| {
+            (&string[s1.start..s1.end], &string[s2.start..s2.end])
+        });
+        Self { string, keys, map }
     }
 
-    pub fn learn_from_files(&mut self, files: &[String]) -> Result<(), std::io::Error> {
+    pub fn learn_from_files(files: &[String]) -> Result<Self, std::io::Error> {
+        let mut s = String::new();
+        let mut breaks = Vec::new();
         for source in files.iter() {
-            let f = File::open(source)?;
-            let s = io::read_to_string(f)?.clone();
-            self.learn(s);
+            let mut f = File::open(source)?;
+            f.read_to_string(&mut s)?;
+            breaks.push(s.len());
+            s.push(' ');
         }
-        Ok(())
+        Ok(Self::learn(s, &breaks))
     }
 
     pub fn generate<R: Rng>(&self, mut rng: R) -> Words<'_, R> {
-        let initial_bigram = if self.map.is_empty() {
-            ("".to_string(), "".to_string())
-        } else {
-            let (a, b) = self.keys.choose(&mut rng).unwrap();
-            (a.to_string(), b.to_string())
-        };
+        let initial_bigram = self.keys.choose(&mut rng).cloned().unwrap_or_default();
         self.iter_with_rng_from(rng, initial_bigram)
     }
 
     fn iter_with_rng_from<R: Rng>(&self, rng: R, from: Bigram) -> Words<'_, R> {
         Words {
+            string: self.string.as_str(),
             map: &self.map,
             rng,
             keys: &self.keys,
@@ -70,30 +95,31 @@ impl WurstsalatGeneratorPro {
 #[derive(Clone)]
 pub struct Words<'a, R: Rng> {
-    map: &'a HashMap<Bigram, Vec<String>>,
+    string: &'a str,
+    map: &'a HashMap<Bigram, Vec<Substr>>,
     rng: R,
-    keys: &'a Vec<Bigram>,
+    keys: &'a [Bigram],
     state: Bigram,
 }
 
 impl<'a, R: Rng> Iterator for Words<'a, R> {
-    type Item = String;
+    type Item = &'a str;
 
-    fn next(&mut self) -> Option<String> {
+    fn next(&mut self) -> Option<&'a str> {
         if self.map.is_empty() {
             return None;
         }
 
-        let result = Some(self.state.0.clone());
+        let result = self.state.0.extract_str(self.string);
 
-        while !self.map.contains_key(&self.state) {
-            let (a, b) = self.keys.choose(&mut self.rng).unwrap();
-            self.state = (a.to_string(), b.to_string());
-        }
-        let next_words = &self.map[&self.state];
-        let next = next_words.choose(&mut self.rng).unwrap();
-        self.state = (self.state.1.clone(), next.to_string());
-        result
+        let next_words = self.map.get(&self.state).unwrap_or_else(|| {
+            self.state = *self.keys.choose(&mut self.rng).unwrap();
+            &self.map[&self.state]
+        });
+        let next = *next_words.choose(&mut self.rng).unwrap();
+        self.state = (self.state.1, next);
+
+        Some(result)
     }
 }
@@ -118,14 +144,14 @@ fn capitalize(word: &str) -> String {
 /// Join words from an iterator. The first word is always capitalized
 /// and the generated sentence will end with `'.'` if it doesn't
 /// already end with some other ASCII punctuation character.
-pub fn join_words<I: Iterator<Item = String>>(mut words: I) -> String {
+pub fn join_words<'a, I: Iterator<Item = &'a str>>(mut words: I) -> String {
     match words.next() {
         None => String::new(),
         Some(word) => {
             // Punctuation characters which ends a sentence.
             let punctuation: &[char] = &['.', '!', '?'];
 
-            let mut sentence = capitalize(&word);
+            let mut sentence = capitalize(word);
             let mut needs_cap = sentence.ends_with(punctuation);
 
             // Add remaining words.
@@ -133,9 +159,9 @@ pub fn join_words<I: Iterator<Item = String>>(mut words: I) -> String {
             sentence.push(' ');
 
             if needs_cap {
-                sentence.push_str(&capitalize(&word));
+                sentence.push_str(&capitalize(word));
             } else {
-                sentence.push_str(&word);
+                sentence.push_str(word);
             }
             needs_cap = word.ends_with(punctuation);
@@ -162,24 +188,21 @@ mod tests {
     #[test]
     fn test_load_error() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        let result = wurstsalat.learn_from_files(&["/does-not-exist".to_string()]);
+        let result = WurstsalatGeneratorPro::learn_from_files(&["/does-not-exist".to_string()]);
         assert!(result.is_err());
     }
 
     #[test]
     fn test_load_ok() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        let result = wurstsalat.learn_from_files(&["README.md".to_string()]);
+        let result = WurstsalatGeneratorPro::learn_from_files(&["README.md".to_string()]);
        assert!(result.is_ok());
     }
 
     #[test]
     fn test_generate() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        wurstsalat
-            .learn_from_files(&["tests/data/lorem-ipsum.txt".to_string()])
-            .unwrap();
+        let wurstsalat =
+            WurstsalatGeneratorPro::learn_from_files(&["tests/data/lorem-ipsum.txt".to_string()])
+                .unwrap();
         let mut rng = GobbledyGook::for_url("/test");
         let words = wurstsalat.generate(&mut rng).take(1);

View file

@@ -0,0 +1,103 @@
+use std::{collections::HashMap, str::CharIndices};
+
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
+pub struct Substr {
+    pub start: usize,
+    pub end: usize,
+}
+
+impl Substr {
+    pub fn extract_str<'a>(&'_ self, relative_to: &'a str) -> &'a str {
+        &relative_to[self.start..self.end]
+    }
+}
+
+// Normalizes Substrs so that the same substring gets turned into the
+// same Substr.
+pub struct Interner<'a>(HashMap<&'a str, Substr>);
+
+impl<'a> Interner<'a> {
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }
+
+    pub fn intern(&mut self, str: &'a str, substr: Substr) -> Substr {
+        *self
+            .0
+            .entry(&str[substr.start..substr.end])
+            .or_insert(substr)
+    }
+}
+
+// An iterator that splits a string into Substrs on whitespace.
+// Equivalent to the iterator returned by `str::split_whitespace`
+// but returns `Substr`s instead of string slices.
+pub struct WhitespaceSplitIterator<'a> {
+    underlying: CharIndices<'a>,
+}
+
+impl<'a> WhitespaceSplitIterator<'a> {
+    pub fn new(s: &'a str) -> Self {
+        Self {
+            underlying: s.char_indices(),
+        }
+    }
+}
+
+impl Iterator for WhitespaceSplitIterator<'_> {
+    type Item = Substr;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let start = loop {
+            let (pos, c) = self.underlying.next()?;
+            if !c.is_whitespace() {
+                break pos;
+            }
+        };
+        let end = loop {
+            let Some((pos, c)) = self.underlying.next() else {
+                break self.underlying.offset();
+            };
+            if c.is_whitespace() {
+                break pos;
+            }
+        };
+        Some(Substr { start, end })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn compare_same(s: &str) {
+        let substrs = WhitespaceSplitIterator::new(s)
+            .map(|ss| ss.extract_str(s))
+            .collect::<Vec<_>>();
+        let std_split = s.split_whitespace().collect::<Vec<_>>();
+        assert_eq!(substrs, std_split);
+    }
+
+    #[test]
+    fn splits_simple_whitespace() {
+        compare_same("hello there world");
+    }
+
+    #[test]
+    fn multiple_interior_whitespace() {
+        compare_same("hello\t\t\tthere world");
+    }
+
+    #[test]
+    fn leading_whitespace() {
+        compare_same(" hello there world");
+    }
+
+    #[test]
+    fn trailing_whitespace() {
+        compare_same("hello there world ");
+    }
+}