Skip to content

Commit a336729

Browse files
committed
implement review comments
1 parent aa49a37 commit a336729

3 files changed

Lines changed: 32 additions & 6 deletions

File tree

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,33 @@ impl BytePairEncoding {
555555
/// This function computes the shortest possible encoding sequence, which will usually differ from the
556556
/// tokenization produced by the original BPE algorithm.
557557
#[cfg(feature = "rand")]
558-
pub fn encode_minimal_dropout(&self, text: &[u8], dropout: f32) -> Vec<u32> {
558+
pub fn encode_minimal_dropout(&self, text: &[u8], dropout: f32, seed: Option<u64>) -> Vec<u32> {
559+
use rand::rngs::StdRng;
559560
use rand::Rng;
561+
use rand::SeedableRng;
562+
563+
fn get_rng(seed: Option<u64>) -> StdRng {
564+
match seed {
565+
Some(num) => {
566+
// Expand the u64 seed to 32 bytes: little-endian in the first 8 bytes, the remaining 24 left as zero
567+
let mut seed_bytes = [0u8; 32];
568+
seed_bytes[..8].copy_from_slice(&num.to_le_bytes());
569+
StdRng::from_seed(seed_bytes)
570+
}
571+
None => {
572+
// Seed StdRng with a random 32-byte array from ThreadRng
573+
let mut thread_rng = rand::rng();
574+
let mut seed_bytes = [0u8; 32];
575+
thread_rng.fill(&mut seed_bytes);
576+
StdRng::from_seed(seed_bytes)
577+
}
578+
}
579+
}
580+
581+
let mut rng = get_rng(seed);
582+
560583
assert!(0.0 <= dropout);
561584
assert!(dropout <= 1.0);
562-
let mut rng = rand::rng();
563585

564586
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
565587
let mut state = self.overlapping_searcher.start_state();
@@ -572,8 +594,9 @@ impl BytePairEncoding {
572594
best = (m.value(), 1);
573595
break;
574596
} else if last_token[m.start() - 1].1 + 1 < best.1 {
575-
best = (m.value(), last_token[m.start() - 1].1 + 1);
576597
if rng.random_range(0.0..=1.0) < dropout {
598+
best = (m.value(), last_token[m.start() - 1].1 + 1);
599+
} else {
577600
best = (m.value(), 1);
578601
}
579602
}

crates/bpe/tests/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ bpe-openai = { path = "../../bpe-openai" }
88
itertools = "0.14"
99
rand = "0.9"
1010
tiktoken-rs = "0.9"
11+
12+
[dev-dependencies]
13+
rand_chacha = { version = "0.9" }

crates/bpe/tests/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ mod tests {
149149
for _ in 0..8 {
150150
let input = create_test_bytes(bpe, bytes);
151151
let encoded = bpe.encode_minimal(&input);
152-
let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2);
153-
let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9);
154-
let encoded_d_max_again = bpe.encode_minimal_dropout(&input, 0.9);
152+
let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2, Some(0));
153+
let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9, Some(1));
154+
let encoded_d_max_again = bpe.encode_minimal_dropout(&input, 0.9, Some(2));
155155
let decoded = bpe.decode_tokens(&encoded);
156156
let decoded_min = bpe.decode_tokens(&encoded_d_min);
157157
let decoded_max = bpe.decode_tokens(&encoded_d_max);

0 commit comments

Comments
 (0)