Skip to content

Commit 8ad286a

Browse files
committed
merge aneubeck version
2 parents a10cce2 + c215210 commit 8ad286a

File tree

2 files changed

+64
-10
lines changed

2 files changed

+64
-10
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,9 @@ impl BytePairEncoding {
526526
/// tokenization produced by the original BPE algorithm.
527527
pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> {
528528
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
529-
let mut state = self.overlapping_searcher.start_state();
530-
for (pos, c) in text.iter().enumerate() {
531-
let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
529+
let mut state = self.overlapping_searcher_rev.start_state();
530+
for (pos, c) in text.iter().rev().enumerate() {
531+
let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c);
532532
state = s;
533533
let mut best = (0, u32::MAX);
534534
for m in iter {
@@ -548,7 +548,43 @@ impl BytePairEncoding {
548548
encoded.push(token);
549549
pos -= self.token_len(token);
550550
}
551-
encoded.reverse();
551+
encoded
552+
}
553+
554+
/// This function computes the encoding while randomly rejecting some merges.
555+
/// The result of the encoding is non-deterministic unless the supplied `rng` is a deterministically seeded random number generator.
556+
/// The implementation loosely follows the original BPE-dropout paper: <https://arxiv.org/abs/1910.13267>
557+
#[cfg(feature = "rand")]
558+
pub fn encode_minimal_dropout<R: rand::Rng>(&self, text: &[u8], dropout: f32, mut rng: R) -> Vec<u32> {
559+
assert!(0.0 <= dropout);
560+
assert!(dropout <= 1.0);
561+
562+
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
563+
let mut state = self.overlapping_searcher_rev.start_state();
564+
for (pos, c) in text.iter().rev().enumerate() {
565+
let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c);
566+
state = s;
567+
let mut best = (0, u32::MAX);
568+
for m in iter {
569+
if m.end() > m.start() + 1 && dropout >= rng.random() {
570+
continue;
571+
}
572+
if m.start() == 0 {
573+
best = (m.value(), 1);
574+
break;
575+
} else if last_token[m.start() - 1].1 + 1 < best.1 {
576+
best = (m.value(), last_token[m.start() - 1].1 + 1);
577+
}
578+
}
579+
last_token.push(best);
580+
}
581+
let mut encoded = Vec::with_capacity(last_token.last().map(|l| l.1 as usize).unwrap_or(0));
582+
let mut pos = text.len();
583+
while pos > 0 {
584+
let token = last_token[pos - 1].0;
585+
encoded.push(token);
586+
pos -= self.token_len(token);
587+
}
552588
encoded
553589
}
554590

crates/bpe/tests/src/lib.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#[cfg(test)]
22
mod tests {
3+
use std::time;
4+
35
use itertools::Itertools;
46
use rand::{rng, Rng};
57
use tiktoken_rs::cl100k_base_singleton;
@@ -144,25 +146,41 @@ mod tests {
144146

145147
#[test]
146148
fn test_bpe_dropout() {
149+
use rand::rngs::StdRng;
150+
use rand::SeedableRng;
151+
152+
fn get_rng(seed: u64) -> StdRng {
153+
// Expand the u64 seed to 32 bytes
154+
let mut seed_bytes = [0u8; 32];
155+
seed_bytes[..8].copy_from_slice(&seed.to_le_bytes());
156+
StdRng::from_seed(seed_bytes)
157+
}
158+
147159
let bpe = &cl100k_base().bpe;
148160
for bytes in [10000, 20000] {
149161
for _ in 0..8 {
150162
let input = create_test_bytes(bpe, bytes);
151163
let encoded = bpe.encode_minimal(&input);
152-
let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2, Some(0));
153-
let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9, Some(1));
154-
let encoded_d_max_again = bpe.encode_minimal_dropout(&input, 0.9, Some(2));
164+
let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0));
165+
let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1));
166+
let encoded_d_1_0 = bpe.encode_minimal_dropout(&input, 1.0, get_rng(2));
155167
let decoded = bpe.decode_tokens(&encoded);
156168
let decoded_min = bpe.decode_tokens(&encoded_d_min);
157169
let decoded_max = bpe.decode_tokens(&encoded_d_max);
158-
assert_eq!(decoded, decoded_min);
159-
assert_eq!(decoded, decoded_max);
170+
let decoded_max_again = bpe.decode_tokens(&encoded_d_1_0);
171+
println!("Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}",
172+
input.len(), encoded.len(), encoded_d_min.len(), encoded_d_max.len(), encoded_d_1_0.len());
173+
assert_eq!(input, decoded);
174+
assert_eq!(input, decoded_min);
175+
assert_eq!(input, decoded_max);
176+
assert_eq!(input, decoded_max_again);
177+
assert_eq!(input.len(), encoded_d_1_0.len());
160178
assert!(encoded_d_min.len() >= encoded.len());
161179
assert!(encoded_d_max.len() > encoded.len());
162180

163181
assert_ne!(encoded, encoded_d_min);
164182
assert_ne!(encoded, encoded_d_max);
165-
assert_ne!(encoded_d_max, encoded_d_max_again);
183+
assert_ne!(encoded_d_max, encoded_d_1_0);
166184
}
167185
}
168186
}

0 commit comments

Comments
 (0)