Skip to content

Commit c215210

Browse files
committed
VERY Fast dropout implementation
1 parent 49c1f0d commit c215210

File tree

3 files changed

+86
-4
lines changed

3 files changed

+86
-4
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,9 @@ impl BytePairEncoding {
526526
/// tokenization produced by the original BPE algorithm.
527527
pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> {
528528
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
529-
let mut state = self.overlapping_searcher.start_state();
530-
for (pos, c) in text.iter().enumerate() {
531-
let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
529+
let mut state = self.overlapping_searcher_rev.start_state();
530+
for (pos, c) in text.iter().rev().enumerate() {
531+
let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c);
532532
state = s;
533533
let mut best = (0, u32::MAX);
534534
for m in iter {
@@ -548,7 +548,43 @@ impl BytePairEncoding {
548548
encoded.push(token);
549549
pos -= self.token_len(token);
550550
}
551-
encoded.reverse();
551+
encoded
552+
}
553+
554+
/// Computes an encoding of `text` while randomly rejecting some token matches.
///
/// Every match whose span (`end - start`) exceeds one position is skipped with
/// probability `dropout`; matches of span one are never skipped. A `dropout` of
/// `0.0` never skips a match and `1.0` skips every match longer than one position.
/// The result is only reproducible when the supplied `rng` is deterministic
/// (for example a generator created from a fixed seed).
/// Implementation loosely follows original BPE dropout paper: https://arxiv.org/abs/1910.13267
///
/// The text is consumed back to front through `overlapping_searcher_rev`, and the
/// tokens collected by the final back-tracking loop are returned without a reversal.
///
/// # Panics
///
/// Panics if `dropout` is not within `0.0..=1.0` (a NaN value panics as well).
#[cfg(feature = "rand")]
pub fn encode_minimal_dropout<R: rand::Rng>(
    &self,
    text: &[u8],
    dropout: f32,
    mut rng: R,
) -> Vec<u32> {
    assert!(
        (0.0..=1.0).contains(&dropout),
        "dropout must be within 0.0..=1.0, got {dropout}"
    );

    // last_token[i] = (token ending the best encoding of the first i + 1 consumed
    // bytes, number of tokens in that encoding).
    let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
    let mut state = self.overlapping_searcher_rev.start_state();
    for (pos, c) in text.iter().rev().enumerate() {
        let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c);
        state = s;
        let mut best = (0, u32::MAX);
        for m in iter {
            // Strict `<` keeps the probability exact at the boundaries: a sample of
            // exactly 0.0 must not drop a match when `dropout` is 0.0. The generator
            // is only advanced for matches longer than one position.
            if m.end() > m.start() + 1 && rng.random::<f32>() < dropout {
                continue;
            }
            if m.start() == 0 {
                // A single token covers everything consumed so far; nothing can be shorter.
                best = (m.value(), 1);
                break;
            }
            let candidate = last_token[m.start() - 1].1 + 1;
            if candidate < best.1 {
                best = (m.value(), candidate);
            }
        }
        last_token.push(best);
    }

    // Walk back from the last consumed position, hopping over one token at a time.
    let mut encoded = Vec::with_capacity(last_token.last().map_or(0, |l| l.1 as usize));
    let mut pos = text.len();
    while pos > 0 {
        let token = last_token[pos - 1].0;
        encoded.push(token);
        pos -= self.token_len(token);
    }
    encoded
}
554590
}

crates/bpe/tests/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ bpe-openai = { path = "../../bpe-openai" }
88
itertools = "0.14"
99
rand = "0.9"
1010
tiktoken-rs = "0.9"
11+
12+
[dev-dependencies]
13+
rand_chacha = { version = "0.9" }

crates/bpe/tests/src/lib.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#[cfg(test)]
22
mod tests {
3+
use std::time;
4+
35
use itertools::Itertools;
46
use rand::{rng, Rng};
57
use tiktoken_rs::cl100k_base_singleton;
@@ -141,4 +143,45 @@ mod tests {
141143
assert_eq!(enc.token_count(), bpe.count(&input[i..]));
142144
}
143145
}
146+
147+
/// Checks that dropout encodings still decode to the input and are never shorter
/// than the minimal encoding, using fixed-seed generators for each dropout rate.
#[test]
fn test_bpe_dropout() {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Builds a generator from a 32-byte seed whose first 8 bytes are the
    // little-endian bytes of `seed`; the remaining bytes stay zero.
    fn get_rng(seed: u64) -> StdRng {
        let mut expanded = [0u8; 32];
        expanded[..8].copy_from_slice(&seed.to_le_bytes());
        StdRng::from_seed(expanded)
    }

    let bpe = &cl100k_base().bpe;
    for size in [10000, 20000] {
        for _round in 0..8 {
            let input = create_test_bytes(bpe, size);

            // Encode once without dropout and once per dropout rate.
            let plain = bpe.encode_minimal(&input);
            let light = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0));
            let heavy = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
            let total = bpe.encode_minimal_dropout(&input, 1.0, get_rng(2));

            let plain_bytes = bpe.decode_tokens(&plain);
            let light_bytes = bpe.decode_tokens(&light);
            let heavy_bytes = bpe.decode_tokens(&heavy);
            let total_bytes = bpe.decode_tokens(&total);

            println!("Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}",
                input.len(), plain.len(), light.len(), heavy.len(), total.len());

            // Every encoding must decode back to the original bytes.
            assert_eq!(input, plain_bytes);
            assert_eq!(input, light_bytes);
            assert_eq!(input, heavy_bytes);
            assert_eq!(input, total_bytes);

            // With dropout 1.0 the encoding has exactly one token per input byte.
            assert_eq!(input.len(), total.len());
            assert!(light.len() >= plain.len());
            assert!(heavy.len() > plain.len());

            // The dropout encodings are expected to differ from the minimal one and from each other.
            assert_ne!(plain, light);
            assert_ne!(plain, heavy);
            assert_ne!(heavy, total);
        }
    }
}
144187
}

0 commit comments

Comments
 (0)