Skip to content

Commit 08d4200

Browse files
committed
implement dropout with forbidden_tokens
1 parent 23c8ec9 commit 08d4200

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,10 @@ impl BytePairEncoding {
558558
#[cfg(feature = "rand")]
559559
pub fn encode_minimal_dropout(&self, text: &[u8], dropout: f32, seed: Option<u64>) -> Vec<u32> {
560560
use rand::rngs::StdRng;
561+
use rand::seq::IndexedRandom;
561562
use rand::Rng;
562563
use rand::SeedableRng;
564+
use std::collections::HashSet;
563565

564566
fn get_rng(seed: Option<u64>) -> StdRng {
565567
match seed {
@@ -585,6 +587,13 @@ impl BytePairEncoding {
585587
assert!(dropout <= 1.0);
586588

587589
let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
590+
591+
let allowed_tokens: Vec<u32> = self.pair_lookup.values().cloned().collect();
592+
let tokens_after_dropout = (allowed_tokens.len() as f32) * dropout;
593+
let forbidden_tokens_set: HashSet<&u32> = HashSet::from_iter(
594+
allowed_tokens.choose_multiple(&mut rng, tokens_after_dropout.floor() as usize),
595+
);
596+
588597
let mut state = self.overlapping_searcher.start_state();
589598
for (pos, c) in text.iter().enumerate() {
590599
let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
@@ -594,12 +603,10 @@ impl BytePairEncoding {
594603
if m.start() == 0 {
595604
best = (m.value(), 1);
596605
break;
597-
} else if last_token[m.start() - 1].1 + 1 < best.1 {
598-
if rng.random_range(0.0..=1.0) < dropout {
599-
best = (m.value(), last_token[m.start() - 1].1 + 1);
600-
} else {
601-
best = (m.value(), 1);
602-
}
606+
} else if (last_token[m.start() - 1].1 + 1 < best.1)
607+
& !(forbidden_tokens_set.contains(&m.value()))
608+
{
609+
best = (m.value(), last_token[m.start() - 1].1 + 1);
603610
}
604611
}
605612
last_token.push(best);

0 commit comments

Comments
 (0)