Skip to content

Commit 5493b12

Browse files
committed
remove duplicated code
1 parent 8ad286a commit 5493b12

File tree

1 file changed

+6
-71
lines changed

1 file changed

+6
-71
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 6 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,12 @@ impl BytePairEncoding {
555555
/// The result of the encoding is non-deterministic unless a deterministic `rng` is provided.
556556
/// The implementation loosely follows the original BPE-dropout paper: https://arxiv.org/abs/1910.13267
557557
#[cfg(feature = "rand")]
558-
pub fn encode_minimal_dropout<R: rand::Rng>(&self, text: &[u8], dropout: f32, mut rng: R) -> Vec<u32> {
558+
pub fn encode_minimal_dropout<R: rand::Rng>(
559+
&self,
560+
text: &[u8],
561+
dropout: f32,
562+
mut rng: R,
563+
) -> Vec<u32> {
559564
assert!(0.0 <= dropout);
560565
assert!(dropout <= 1.0);
561566

@@ -587,76 +592,6 @@ impl BytePairEncoding {
587592
}
588593
encoded
589594
}
590-
591-
/// This function computes the encoding while randomly rejecting some merges.
592-
/// The result of the encoding will be non-deterministic unless `seed` is provided.
593-
/// The implementation loosely follows the original BPE-dropout paper: https://arxiv.org/abs/1910.13267
594-
#[cfg(feature = "rand")]
595-
pub fn encode_minimal_dropout(&self, text: &[u8], dropout: f32, seed: Option<u64>) -> Vec<u32> {
596-
use rand::rngs::StdRng;
597-
use rand::seq::IndexedRandom;
598-
use rand::Rng;
599-
use rand::SeedableRng;
600-
use std::collections::HashSet;
601-
602-
fn get_rng(seed: Option<u64>) -> StdRng {
603-
match seed {
604-
Some(num) => {
605-
// Expand the u64 seed to 32 bytes
606-
let mut seed_bytes = [0u8; 32];
607-
seed_bytes[..8].copy_from_slice(&num.to_le_bytes());
608-
StdRng::from_seed(seed_bytes)
609-
}
610-
None => {
611-
// Seed StdRng with a random 32-byte array from ThreadRng
612-
let mut thread_rng = rand::rng();
613-
let mut seed_bytes = [0u8; 32];
614-
thread_rng.fill(&mut seed_bytes);
615-
StdRng::from_seed(seed_bytes)
616-
}
617-
}
618-
}
619-
620-
let mut rng = get_rng(seed);
621-
622-
assert!(0.0 <= dropout);
623-
assert!(dropout <= 1.0);
624-
625-
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
626-
627-
let allowed_tokens: Vec<u32> = self.pair_lookup.values().cloned().collect();
628-
let tokens_after_dropout = (allowed_tokens.len() as f32) * dropout;
629-
let forbidden_tokens_set: HashSet<&u32> = HashSet::from_iter(
630-
allowed_tokens.choose_multiple(&mut rng, tokens_after_dropout.floor() as usize),
631-
);
632-
633-
let mut state = self.overlapping_searcher.start_state();
634-
for (pos, c) in text.iter().enumerate() {
635-
let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
636-
state = s;
637-
let mut best = (0, u32::MAX);
638-
for m in iter {
639-
if m.start() == 0 {
640-
best = (m.value(), 1);
641-
break;
642-
} else if (last_token[m.start() - 1].1 + 1 < best.1)
643-
& (!(forbidden_tokens_set.contains(&m.value())) | ((m.end() - m.start()) == 1))
644-
{
645-
best = (m.value(), last_token[m.start() - 1].1 + 1);
646-
}
647-
}
648-
last_token.push(best);
649-
}
650-
let mut encoded = Vec::with_capacity(last_token.last().map(|l| l.1 as usize).unwrap_or(0));
651-
let mut pos = text.len();
652-
while pos > 0 {
653-
let token = last_token[pos - 1].0;
654-
encoded.push(token);
655-
pos -= self.token_len(token);
656-
}
657-
encoded.reverse();
658-
encoded
659-
}
660595
}
661596

662597
/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.

0 commit comments

Comments
 (0)