Skip to content

Commit 8c9fda4

Browse files
authored
Merge pull request #98 from marinegor/feature/add-dropout
Implement dropout for `encode_minimal`
2 parents 49c1f0d + 83dd2f1 commit 8c9fda4

File tree

9 files changed

+387
-378
lines changed

9 files changed

+387
-378
lines changed

crates/bpe/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ We benchmarked the following scenarios:
203203
The data structure we built specifically for this purpose can answer those interval counting requests in typically constant times after the initial linear preprocessing of the text.
204204
This mode is not available in tiktoken, which only supports counting/encoding a complete text.
205205

206-
All benchmarks were run single-threaded on a MacBook Pro M1.
206+
All benchmarks were run single-threaded on a MacBook Air M4.
207207

208208
### Encoding
209209

@@ -219,6 +219,7 @@ Two additional encoders are included that are faster but deviate from the origin
219219

220220
- The greedy encoder picks the left-longest token.
221221
- The minimal encoder computes an encoding with the minimal number of tokens.
222+
- The minimal_dropout encoder implements the BPE-Dropout [algorithm](https://arxiv.org/abs/1910.13267), randomly ignoring some multi-byte tokens at runtime. Note that this implementation differs from the paper, and **has not** been tested in an actual language-model training pipeline.
222223

223224
The benchmark measured the runtime of encoding of slices of lengths 10, 100, 1000, and 10000 from a random 20000 token original text using the o200k token set.
224225
(All encodings were computed from scratch for each slice.)

crates/bpe/benchmarks/performance.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use bpe_benchmarks::*;
99
use criterion::{
1010
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
1111
};
12+
use rand::rngs::StdRng;
13+
use rand::SeedableRng;
1214
use rand::{rng, Rng};
1315

1416
fn counting_benchmark(c: &mut Criterion) {
@@ -92,6 +94,17 @@ fn encoding_benchmark(c: &mut Criterion) {
9294
criterion::BatchSize::SmallInput,
9395
)
9496
});
97+
group.bench_with_input(
98+
BenchmarkId::new("minimal_dropout", bytes),
99+
&bytes,
100+
|b, bytes| {
101+
b.iter_batched(
102+
|| select_test_string(&text, *bytes),
103+
|text| bpe.bpe.encode_minimal_dropout(text.as_bytes(), 0.1, rng()),
104+
criterion::BatchSize::SmallInput,
105+
)
106+
},
107+
);
95108
group.bench_with_input(
96109
BenchmarkId::new("huggingface", bytes),
97110
&bytes,

crates/bpe/images/performance-appending.svg

Lines changed: 35 additions & 62 deletions
Loading

crates/bpe/images/performance-comparison.svg

Lines changed: 44 additions & 74 deletions
Loading

crates/bpe/images/performance-counting.svg

Lines changed: 35 additions & 62 deletions
Loading

crates/bpe/images/performance-encoding.svg

Lines changed: 74 additions & 86 deletions
Loading

crates/bpe/images/performance-worstcase.svg

Lines changed: 73 additions & 89 deletions
Loading

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,9 @@ impl BytePairEncoding {
526526
/// tokenization produced by the original BPE algorithm.
527527
pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> {
528528
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
529-
let mut state = self.overlapping_searcher.start_state();
530-
for (pos, c) in text.iter().enumerate() {
531-
let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
529+
let mut state = self.overlapping_searcher_rev.start_state();
530+
for (pos, c) in text.iter().rev().enumerate() {
531+
let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c);
532532
state = s;
533533
let mut best = (0, u32::MAX);
534534
for m in iter {
@@ -548,7 +548,66 @@ impl BytePairEncoding {
548548
encoded.push(token);
549549
pos -= self.token_len(token);
550550
}
551-
encoded.reverse();
551+
encoded
552+
}
553+
554+
/// This function computes the encoding while randomly rejecting some merges.
555+
/// Result of the encoding will be non-deterministic unless `seed` is provided.
556+
/// Implementation loosely follows original BPE dropout paper: https://arxiv.org/abs/1910.13267
557+
///
558+
/// In more detail: the tokenization uses dynamic programming, i.e. it models the tokenization as a graph,
559+
/// where every position between text bytes is a node and two nodes are connected when the text slice between those two nodes matches a token.
560+
// It then tries to find the shortest possible path from the beginning of the text till the end, i.e. it finds the shortest possible encoding.
561+
// For this nodes are processed from right to left. At each node, edges starting at that node and ending on the right are tested and
562+
// the one producing the shortest path is stored together with the length of the shortest path to that node.
563+
// The length of the shortest path is stored as second value, the edge (or rather token) is stored as first value.
564+
// Then, we walk in reverse direction through the table along the shortest path.
565+
// Note: the reason for constructing the table from back to front is that
566+
// the reconstruction outputs the path from start till end (i.e. we don't have to reverse the path afterwards).
567+
//
568+
// For the dropout (when dropout > 0.0), we uniformly drop edges from the graph, but always keep the one-byte tokens such that the graph stays connected.
569+
// Note: this is very different from how BPE works and cannot produce the same output as the algorithm
570+
// in the [paper's repository](https://github.com/VProv/BPE-Dropout/blob/master/bpe.py#L98), for two main reasons:
571+
// - `encode_minimal` already doesn't follow the original heap-based BPE procedure
572+
// - BPE-dropout authors discard all multi-byte tokens for each word separately, while this implementation does not split the "sentence" into words first
573+
// and hence may include previously discarded token later down the byte stream. At the sentence level though we don't expect it to make much difference.
574+
// Also, this implementation of BPE constructs merges on the fly from the set of tokens, hence might come up with a different set of merges with the same dictionary.
575+
#[cfg(feature = "rand")]
576+
pub fn encode_minimal_dropout<R: rand::Rng>(
577+
&self,
578+
text: &[u8],
579+
dropout: f32,
580+
mut rng: R,
581+
) -> Vec<u32> {
582+
assert!(0.0 <= dropout);
583+
assert!(dropout <= 1.0);
584+
585+
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
586+
let mut state = self.overlapping_searcher_rev.start_state();
587+
for (pos, c) in text.iter().rev().enumerate() {
588+
let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c);
589+
state = s;
590+
let mut best = (0, u32::MAX);
591+
for m in iter {
592+
if m.end() > m.start() + 1 && dropout >= rng.random() {
593+
continue;
594+
}
595+
if m.start() == 0 {
596+
best = (m.value(), 1);
597+
break;
598+
} else if last_token[m.start() - 1].1 + 1 < best.1 {
599+
best = (m.value(), last_token[m.start() - 1].1 + 1);
600+
}
601+
}
602+
last_token.push(best);
603+
}
604+
let mut encoded = Vec::with_capacity(last_token.last().map(|l| l.1 as usize).unwrap_or(0));
605+
let mut pos = text.len();
606+
while pos > 0 {
607+
let token = last_token[pos - 1].0;
608+
encoded.push(token);
609+
pos -= self.token_len(token);
610+
}
552611
encoded
553612
}
554613
}

crates/bpe/tests/src/lib.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,52 @@ mod tests {
141141
assert_eq!(enc.token_count(), bpe.count(&input[i..]));
142142
}
143143
}
144+
145+
#[test]
146+
fn test_bpe_dropout() {
147+
use rand::rngs::StdRng;
148+
use rand::SeedableRng;
149+
150+
fn get_rng(seed: u64) -> StdRng {
151+
// Expand the u64 seed to 32 bytes
152+
let mut seed_bytes = [0u8; 32];
153+
seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
154+
StdRng::from_seed(seed_bytes)
155+
}
156+
157+
let bpe = &cl100k_base().bpe;
158+
let bytes = 10000;
159+
for _ in 0..8 {
160+
let input = create_test_bytes(bpe, bytes);
161+
let encoded = bpe.encode_minimal(&input);
162+
let encoded_d_0_2 = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0));
163+
let encoded_d_0_9 = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
164+
let encoded_d_1_0 = bpe.encode_minimal_dropout(&input, 1.0, get_rng(1));
165+
let encoded_d_0_9_again = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
166+
let decoded = bpe.decode_tokens(&encoded);
167+
let decoded_min = bpe.decode_tokens(&encoded_d_0_2);
168+
let decoded_max = bpe.decode_tokens(&encoded_d_0_9);
169+
let decoded_max_again = bpe.decode_tokens(&encoded_d_0_9_again);
170+
println!(
171+
"Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}",
172+
input.len(),
173+
encoded.len(),
174+
encoded_d_0_2.len(),
175+
encoded_d_0_9.len(),
176+
encoded_d_0_9_again.len()
177+
);
178+
assert_eq!(encoded_d_0_9, encoded_d_0_9_again);
179+
assert_eq!(input, decoded);
180+
assert_eq!(input, decoded_min);
181+
assert_eq!(input, decoded_max);
182+
assert_eq!(input, decoded_max_again);
183+
assert_eq!(input.len(), encoded_d_1_0.len());
184+
assert!(encoded_d_0_2.len() >= encoded.len());
185+
assert!(encoded_d_0_9.len() > encoded.len());
186+
187+
assert_ne!(encoded, encoded_d_0_2);
188+
assert_ne!(encoded, encoded_d_0_9);
189+
assert_ne!(encoded_d_0_9, encoded_d_1_0);
190+
}
191+
}
144192
}

0 commit comments

Comments
 (0)