Skip to content

Commit 35c047d

Browse files
aneubeck and Hendrik van Antwerpen
authored and committed
Replace look-ahead with multiple patterns ==> 3x speedup
1 parent f183341 commit 35c047d

2 files changed

Lines changed: 64 additions & 18 deletions

File tree

crates/bpe-openai/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ bench = false
1515
[dependencies]
1616
bpe = { version = "0.1.0", path = "../bpe" }
1717
either = "1.13"
18-
fancy-regex = "0.13"
18+
regex-automata = "0.4"
1919
rmp-serde = "1"
2020

2121
[dev-dependencies]

crates/bpe-openai/src/lib.rs

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::LazyLock;
22

33
use bpe::byte_pair_encoding::BytePairEncoding;
44
use either::Either;
5-
use fancy_regex::Regex;
5+
use regex_automata::{meta::Regex, util::captures::Captures, Anchored, Input};
66

77
// Note: Below we rewrite the negative look-ahead with a positive pseudo look-ahead.
88
// The look-ahead character is dropped from the match by the Pretokenizer iterator.
@@ -11,23 +11,27 @@ use fancy_regex::Regex;
1111
static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
1212
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
1313
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
14-
let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
15-
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
14+
let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+";
15+
// Note: Rewrite the negative look-ahead with a positive pseudo look-ahead.
16+
// The look-ahead character is dropped from the match by the SpecialRegexp iterator.
17+
let pat2 = "\\s+\\s";
18+
let pat3 = "\\s+";
19+
Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex")
1620
});
1721

1822
static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
1923
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
2024
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
21-
let pat = [
25+
let pat1 = [
2226
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
2327
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
2428
"\\p{N}{1,3}",
2529
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
2630
"\\s*[\\r\\n]+",
27-
"\\s+(?!\\S)",
28-
"\\s+",
2931
].join("|");
30-
Tokenizer::new(bpe, Some(&pat)).expect("valid regex")
32+
let pat2 = "\\s+\\s";
33+
let pat3 = "\\s+";
34+
Tokenizer::with_many(bpe, &[pat1.as_str(), pat2, pat3]).expect("valid regex")
3135
});
3236

3337
pub use bpe::*;
@@ -48,8 +52,15 @@ pub struct Tokenizer {
4852
impl Tokenizer {
4953
/// Build a tokenizer with an optional pretokenization regex pattern.
5054
#[allow(clippy::result_large_err)]
51-
pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> fancy_regex::Result<Self> {
52-
let pat = pat.map(fancy_regex::Regex::new).transpose()?;
55+
pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, ()> {
56+
let pat = pat.map(Regex::new).transpose().map_err(|_| ())?;
57+
Ok(Self { bpe, pat })
58+
}
59+
60+
/// When using multiple patterns, the second pattern is assumed to be a look-ahead pattern with
61+
/// exactly one look-ahead character!
62+
pub fn with_many(bpe: BytePairEncoding, patterns: &[&str]) -> Result<Self, ()> {
63+
let pat = Some(Regex::new_many(patterns).map_err(|_| ())?);
5364
Ok(Self { bpe, pat })
5465
}
5566

@@ -69,16 +80,51 @@ impl Tokenizer {
6980
String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
7081
}
7182

72-
pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
83+
pub fn split<'a>(&'a self, input: &'a str) -> impl Iterator<Item = &str> + 'a {
7384
match &self.pat {
74-
Some(pat) => Either::Left(pat.find_iter(text).scan(0, |start, m| {
75-
let m = m.expect("match succeeded");
76-
assert_eq!(*start, m.start(), "pattern should match all input text");
77-
*start = m.end();
78-
Some(m.as_str())
79-
})),
80-
None => Either::Right(std::iter::once(text)),
85+
Some(pat) => Either::Left(SpecialRegexp {
86+
pat,
87+
input,
88+
last: 0,
89+
caps: Captures::matches(pat.group_info().clone()),
90+
}),
91+
None => Either::Right(std::iter::once(input)),
92+
}
93+
}
94+
}
95+
96+
/// This is a small wrapper around the regex which emulates the behaviour of look-ahead by
/// dropping the look-ahead character from the match. The assumption here is that the
/// second pattern is always a look-ahead pattern, and that just a single character needs
/// to be dropped. With this little hack, we can keep most of the regex patterns as they are,
/// but achieve a >3x speedup.
///
/// Alternatively, this could have been implemented with capture groups, but those were ~30%
/// slower than this approach with multiple patterns.
struct SpecialRegexp<'a> {
    // The compiled multi-pattern regex used for pretokenization.
    pat: &'a Regex,
    // The full text being split.
    input: &'a str,
    // Byte offset just past the previous match; the next search starts here.
    last: usize,
    // Reusable capture buffer, cleared before every search.
    caps: Captures,
}
110+
111+
impl<'a> Iterator for SpecialRegexp<'a> {
112+
type Item = &'a str;
113+
114+
fn next(&mut self) -> Option<Self::Item> {
115+
let input = Input::new(&self.input[self.last..]).anchored(Anchored::Yes);
116+
self.caps.clear();
117+
self.pat.captures(input, &mut self.caps);
118+
let m = self.caps.get_match()?;
119+
let start = self.last;
120+
let mut end = self.last + m.range().end;
121+
if m.pattern() == 1.into() {
122+
let last = self.input[start..end].chars().rev().next().unwrap();
123+
end -= last.len_utf8();
124+
assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
81125
}
126+
self.last = end;
127+
Some(&self.input[start..end])
82128
}
83129
}
84130

0 commit comments

Comments
 (0)