Mirror of https://git.madhouse-project.org/algernon/iocaine.git
Do less allocating and copying when generating text
Before, on a low-capacity system (such as an inexpensive cloud host), doing Markov-chain text generation was _extraordinarily_ slow, taking half a second or more to produce a page, and if multiple requests came in simultaneously they could easily swamp the capacity of such a system. Most of the time was spent in the Words iterator, which did a lot of cloning of Strings in the hot path.

This changes the Markov generator's internal representation: instead of storing Strings, it stores index pairs into a single shared String, normalized so that all references to a particular word are collapsed into a single pair. This also means that the hash map is working with fixed-size values, which can't hurt. In addition, it does only one hash-map lookup per generated word on the happy path of not reaching the end of the chain. The upshot of all this is that where it was taking a half-second or more to generate a page, it now takes about 0.001 seconds.

On the downside, the initialization of WurstsalatGeneratorPro has become rather less flexible. Before, you created one and then taught it various strings, or gave it a list of paths to read and teach itself from. Now, the _only_ way to create one is directly with a list of paths. Changing this is possible, but it means `Substr` would have to learn to distinguish which source data it came from, which would mean a likely 50% increase in its size. It didn't seem worth it to preserve that capability, which wasn't even being used.
This commit is contained in:
parent d2bdbf750a
commit 284af56e68

4 changed files with 179 additions and 54 deletions
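The heart of the change is the representation swap: owned `String`s become index pairs into one shared corpus string, and an interner collapses every occurrence of a word to a single canonical pair. A minimal, self-contained sketch of that idea (illustrative only; the names mirror the diff below, but this code is not from the commit):

```rust
use std::collections::HashMap;

/// A half-open byte range into a shared corpus string: plain `Copy`
/// data, no heap allocation per word.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct Substr {
    start: usize,
    end: usize,
}

fn main() {
    let corpus = String::from("the cat saw the dog");

    // Intern: every occurrence of the same word maps to one canonical
    // range, so ranges can stand in for the words themselves.
    let mut interner: HashMap<&str, Substr> = HashMap::new();
    let mut words = Vec::new();
    let mut pos = 0;
    for w in corpus.split_whitespace() {
        let start = pos + corpus[pos..].find(w).unwrap();
        pos = start + w.len();
        let ss = *interner
            .entry(w)
            .or_insert(Substr { start, end: start + w.len() });
        words.push(ss);
    }

    // Both occurrences of "the" collapse to the same fixed-size value...
    assert_eq!(words[0], words[3]);
    // ...and recovering the text is a slice, not an allocation.
    assert_eq!(&corpus[words[0].start..words[0].end], "the");
}
```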
```diff
@@ -29,9 +29,8 @@ pub struct Iocaine {
 impl Iocaine {
     pub fn new(config: Config) -> Result<Self> {
-        let mut chain = WurstsalatGeneratorPro::new();
         tracing::info!("Loading the markov chain corpus");
-        chain.learn_from_files(&config.sources.markov)?;
+        let chain = WurstsalatGeneratorPro::learn_from_files(&config.sources.markov)?;
         tracing::info!("Corpus loaded");
 
         tracing::info!("Loading word salad");
 
```
```diff
@@ -68,8 +68,8 @@ impl AssembledStatisticalSequences {
         let mut rng = GobbledyGook::for_url(format!("{}://{}/{}", group, &static_seed, index));
         let words = (1..=count)
-            .map(|_| iocaine.words.0.choose(&mut rng).unwrap().to_string())
-            .collect::<Vec<String>>()
+            .map(|_| iocaine.words.0.choose(&mut rng).unwrap().as_str())
+            .collect::<Vec<_>>()
             .join("-");
         out.write(&words)?;
 
```
```diff
@@ -10,12 +10,16 @@
 use rand::{seq::IndexedRandom, Rng};
 use std::collections::HashMap;
 use std::fs::File;
-use std::io;
+use std::io::Read as _;
+use substrings::{Interner, Substr, WhitespaceSplitIterator};
 
-pub type Bigram = (String, String);
+mod substrings;
+
+type Bigram = (Substr, Substr);
 
 #[derive(Debug, Default)]
 pub struct WurstsalatGeneratorPro {
-    map: HashMap<Bigram, Vec<String>>,
+    string: String,
+    map: HashMap<Bigram, Vec<Substr>>,
     keys: Vec<Bigram>,
 }
```
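A side effect worth noting: since `Bigram` keys are now plain index pairs, equality and hashing operate on the offsets rather than the text, which is exactly why the interner must collapse equal words to equal `Substr`s first. It also makes the keys small, fixed-size `Copy` data. A quick check of the sizes involved (a sketch, assuming a typical 64-bit target):

```rust
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct Substr {
    start: usize,
    end: usize,
}

fn main() {
    // Two pointer/length/capacity triples, plus two heap allocations
    // behind them...
    assert_eq!(std::mem::size_of::<(String, String)>(), 48);
    // ...versus 32 bytes of inline data, with no pointers to chase
    // while hashing.
    assert_eq!(std::mem::size_of::<(Substr, Substr)>(), 32);
}
```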
```diff
@@ -24,42 +28,63 @@ impl WurstsalatGeneratorPro {
         Default::default()
     }
 
-    fn learn(&mut self, text: String) {
-        let words = text.split_whitespace().collect::<Vec<&str>>();
-        for window in words.windows(3) {
-            let (a, b, c) = (
-                window[0].to_string(),
-                window[1].to_string(),
-                window[2].to_string(),
-            );
-            self.map.entry((a, b)).or_default().push(c);
+    fn learn(string: String, mut breaks: &[usize]) -> Self {
+        let mut interner = Interner::new();
+        let words = WhitespaceSplitIterator::new(&string);
+        let mut map = HashMap::<Bigram, Vec<Substr>>::new();
+        for window in words.collect::<Vec<_>>().windows(3) {
+            let (a, b, c) = (window[0], window[1], window[2]);
+
+            // This bit of weirdness is to preserve the behavior from
+            // learning from multiple files independently; if our
+            // current window spans a break, we don't add the triple.
+            let mut skip_triple = false;
+            while !breaks.is_empty() && breaks[0] <= c.start {
+                if breaks[0] <= a.start {
+                    // completely passed the first break, can remove it
+                    breaks = &breaks[1..];
+                } else {
+                    skip_triple = true;
+                    break;
+                }
+            }
+
+            if !skip_triple {
+                map.entry((interner.intern(&string, a), interner.intern(&string, b)))
+                    .or_default()
+                    .push(interner.intern(&string, c));
+            }
         }
-        self.keys = self.map.keys().cloned().collect();
-        self.keys.sort_unstable();
+
+        let mut keys = map.keys().copied().collect::<Vec<_>>();
+        keys.sort_unstable_by_key(|(s1, s2)| {
+            (&string[s1.start..s1.end], &string[s2.start..s2.end])
+        });
+
+        Self { string, keys, map }
     }
 
-    pub fn learn_from_files(&mut self, files: &[String]) -> Result<(), std::io::Error> {
+    pub fn learn_from_files(files: &[String]) -> Result<Self, std::io::Error> {
+        let mut s = String::new();
+        let mut breaks = Vec::new();
         for source in files.iter() {
-            let f = File::open(source)?;
-            let s = io::read_to_string(f)?.clone();
-            self.learn(s);
+            let mut f = File::open(source)?;
+            f.read_to_string(&mut s)?;
+            breaks.push(s.len());
+            s.push(' ');
         }
-        Ok(())
+
+        Ok(Self::learn(s, &breaks))
     }
 
     pub fn generate<R: Rng>(&self, mut rng: R) -> Words<'_, R> {
-        let initial_bigram = if self.map.is_empty() {
-            ("".to_string(), "".to_string())
-        } else {
-            let (a, b) = self.keys.choose(&mut rng).unwrap();
-            (a.to_string(), b.to_string())
-        };
+        let initial_bigram = self.keys.choose(&mut rng).cloned().unwrap_or_default();
         self.iter_with_rng_from(rng, initial_bigram)
     }
 
     fn iter_with_rng_from<R: Rng>(&self, rng: R, from: Bigram) -> Words<'_, R> {
         Words {
+            string: self.string.as_str(),
             map: &self.map,
             rng,
             keys: &self.keys,
```
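The break-handling above is easiest to see on a tiny input. Here is a self-contained trace of the same rule (a hypothetical two-file corpus, with the incremental while-loop condition restated as a single predicate):

```rust
fn main() {
    // "a b c" (file one) + ' ' + "d e f" (file two); the break is
    // recorded at byte 5, the length of the first file's text.
    let s = "a b c d e f";
    let breaks = [5usize];

    // Byte offset of each word within s, like Substr::start.
    let words: Vec<(usize, &str)> = s
        .split_whitespace()
        .map(|w| (w.as_ptr() as usize - s.as_ptr() as usize, w))
        .collect();

    for w in words.windows(3) {
        let (a, c) = (w[0], w[2]);
        // One-shot form of the loop in `learn`: a break at or before
        // c's start, but after a's start, means the window spans it.
        let spans_break = breaks.iter().any(|&b| b <= c.0 && b > a.0);
        println!("({} {} {}): {}", w[0].1, w[1].1, w[2].1,
                 if spans_break { "skipped" } else { "kept" });
    }
    // Prints: (a b c): kept, (b c d): skipped, (c d e): skipped,
    // (d e f): kept -- matching what learning from the two files
    // independently would produce: (a,b)->c and (d,e)->f, nothing
    // bridging them.
}
```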
```diff
@@ -70,30 +95,31 @@ impl WurstsalatGeneratorPro {
 
 #[derive(Clone)]
 pub struct Words<'a, R: Rng> {
-    map: &'a HashMap<Bigram, Vec<String>>,
+    string: &'a str,
+    map: &'a HashMap<Bigram, Vec<Substr>>,
     rng: R,
-    keys: &'a Vec<Bigram>,
+    keys: &'a [Bigram],
     state: Bigram,
 }
 
 impl<'a, R: Rng> Iterator for Words<'a, R> {
-    type Item = String;
+    type Item = &'a str;
 
-    fn next(&mut self) -> Option<String> {
+    fn next(&mut self) -> Option<&'a str> {
         if self.map.is_empty() {
             return None;
         }
 
-        let result = Some(self.state.0.clone());
+        let result = self.state.0.extract_str(self.string);
 
-        while !self.map.contains_key(&self.state) {
-            let (a, b) = self.keys.choose(&mut self.rng).unwrap();
-            self.state = (a.to_string(), b.to_string());
-        }
-        let next_words = &self.map[&self.state];
-        let next = next_words.choose(&mut self.rng).unwrap();
-        self.state = (self.state.1.clone(), next.to_string());
-        result
+        let next_words = self.map.get(&self.state).unwrap_or_else(|| {
+            self.state = *self.keys.choose(&mut self.rng).unwrap();
+            &self.map[&self.state]
+        });
+        let next = *next_words.choose(&mut self.rng).unwrap();
+        self.state = (self.state.1, next);
+
+        Some(result)
     }
 }
```
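This is where the per-word hash lookups drop from two (a `contains_key` probe plus an index) to one `get`, whose fallback closure only runs when the chain dead-ends and has to be reseeded. The pattern in miniature (a toy map, not the generator's types):

```rust
use std::collections::HashMap;

fn main() {
    let map = HashMap::from([((1, 2), "three")]);
    let key = (1, 2);

    // Before: two lookups on the happy path -- a probe, then an index.
    let _slow = if map.contains_key(&key) { map[&key] } else { "reseed" };

    // After: one lookup; the closure is the rare dead-end path.
    let _fast = map.get(&key).copied().unwrap_or_else(|| "reseed");
}
```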
```diff
@@ -118,14 +144,14 @@ fn capitalize(word: &str) -> String {
 /// Join words from an iterator. The first word is always capitalized
 /// and the generated sentence will end with `'.'` if it doesn't
 /// already end with some other ASCII punctuation character.
-pub fn join_words<I: Iterator<Item = String>>(mut words: I) -> String {
+pub fn join_words<'a, I: Iterator<Item = &'a str>>(mut words: I) -> String {
     match words.next() {
         None => String::new(),
         Some(word) => {
             // Punctuation characters which ends a sentence.
             let punctuation: &[char] = &['.', '!', '?'];
 
-            let mut sentence = capitalize(&word);
+            let mut sentence = capitalize(word);
             let mut needs_cap = sentence.ends_with(punctuation);
 
             // Add remaining words.
```
```diff
@@ -133,9 +159,9 @@ pub fn join_words<I: Iterator<Item = String>>(mut words: I) -> String {
             sentence.push(' ');
 
             if needs_cap {
-                sentence.push_str(&capitalize(&word));
+                sentence.push_str(&capitalize(word));
             } else {
-                sentence.push_str(&word);
+                sentence.push_str(word);
             }
 
             needs_cap = word.ends_with(punctuation);
```
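With `Words` now yielding `&str`, `join_words` consumes borrowed words with no per-word allocation. A hypothetical call (assuming `join_words` is in scope; the expected output follows from the doc comment and the capitalization logic above):

```rust
let sentence = join_words(["hello", "world!", "again"].into_iter());
// First word capitalized, "again" re-capitalized after the '!',
// and a final '.' appended per the doc comment:
assert_eq!(sentence, "Hello world! Again.");
```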
```diff
@@ -162,24 +188,21 @@ mod tests {
 
     #[test]
     fn test_load_error() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        let result = wurstsalat.learn_from_files(&["/does-not-exist".to_string()]);
+        let result = WurstsalatGeneratorPro::learn_from_files(&["/does-not-exist".to_string()]);
         assert!(result.is_err());
     }
 
     #[test]
     fn test_load_ok() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        let result = wurstsalat.learn_from_files(&["README.md".to_string()]);
+        let result = WurstsalatGeneratorPro::learn_from_files(&["README.md".to_string()]);
         assert!(result.is_ok());
     }
 
     #[test]
     fn test_generate() {
-        let mut wurstsalat = WurstsalatGeneratorPro::new();
-        wurstsalat
-            .learn_from_files(&["tests/data/lorem-ipsum.txt".to_string()])
-            .unwrap();
+        let wurstsalat =
+            WurstsalatGeneratorPro::learn_from_files(&["tests/data/lorem-ipsum.txt".to_string()])
+                .unwrap();
 
         let mut rng = GobbledyGook::for_url("/test");
         let words = wurstsalat.generate(&mut rng).take(1);
```
src/wurstsalat_generator_pro/substrings.rs (new file, 103 lines)
```diff
@@ -0,0 +1,103 @@
+use std::{collections::HashMap, str::CharIndices};
+
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
+pub struct Substr {
+    pub start: usize,
+    pub end: usize,
+}
+
+impl Substr {
+    pub fn extract_str<'a>(&'_ self, relative_to: &'a str) -> &'a str {
+        &relative_to[self.start..self.end]
+    }
+}
+
+// Normalizes Substrs so that the same substring gets turned into the
+// same Substr.
+pub struct Interner<'a>(HashMap<&'a str, Substr>);
+
+impl<'a> Interner<'a> {
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }
+
+    pub fn intern(&mut self, str: &'a str, substr: Substr) -> Substr {
+        *self
+            .0
+            .entry(&str[substr.start..substr.end])
+            .or_insert(substr)
+    }
+}
+
+// An iterator that splits a string into Substrs on whitespace.
+// Equivalent to the iterator returned by `str::split_whitespace`
+// but returns `Substr`s instead of string slices.
+pub struct WhitespaceSplitIterator<'a> {
+    underlying: CharIndices<'a>,
+}
+
+impl<'a> WhitespaceSplitIterator<'a> {
+    pub fn new(s: &'a str) -> Self {
+        Self {
+            underlying: s.char_indices(),
+        }
+    }
+}
+
+impl Iterator for WhitespaceSplitIterator<'_> {
+    type Item = Substr;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let start = loop {
+            let (pos, c) = self.underlying.next()?;
+            if !c.is_whitespace() {
+                break pos;
+            }
+        };
+
+        let end = loop {
+            let Some((pos, c)) = self.underlying.next() else {
+                break self.underlying.offset();
+            };
+            if c.is_whitespace() {
+                break pos;
+            }
+        };
+
+        Some(Substr { start, end })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn compare_same(s: &str) {
+        let substrs = WhitespaceSplitIterator::new(s)
+            .map(|ss| ss.extract_str(s))
+            .collect::<Vec<_>>();
+        let std_split = s.split_whitespace().collect::<Vec<_>>();
+
+        assert_eq!(substrs, std_split);
+    }
+
+    #[test]
+    fn splits_simple_whitespace() {
+        compare_same("hello there world");
+    }
+
+    #[test]
+    fn multiple_interior_whitespace() {
+        compare_same("hello\t\t\tthere world");
+    }
+
+    #[test]
+    fn leading_whitespace() {
+        compare_same(" hello there world");
+    }
+
+    #[test]
+    fn trailing_whitespace() {
+        compare_same("hello there world ");
+    }
+}
```
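Taken together, the new module's pieces compose like this (an illustrative snippet, assuming `Substr`, `Interner`, and `WhitespaceSplitIterator` are imported):

```rust
fn main() {
    let corpus = "the cat and the dog".to_string();
    let mut interner = Interner::new();
    let words: Vec<Substr> = WhitespaceSplitIterator::new(&corpus)
        .map(|ss| interner.intern(&corpus, ss))
        .collect();

    // Interning collapses both occurrences of "the" into one Substr...
    assert_eq!(words[0], words[3]);
    // ...and extraction is a zero-copy slice of the shared corpus.
    assert_eq!(words[4].extract_str(&corpus), "dog");
}
```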