@@ -558,8 +558,10 @@ impl BytePairEncoding {
558558 #[ cfg( feature = "rand" ) ]
559559 pub fn encode_minimal_dropout ( & self , text : & [ u8 ] , dropout : f32 , seed : Option < u64 > ) -> Vec < u32 > {
560560 use rand:: rngs:: StdRng ;
561+ use rand:: seq:: IndexedRandom ;
561562 use rand:: Rng ;
562563 use rand:: SeedableRng ;
564+ use std:: collections:: HashSet ;
563565
564566 fn get_rng ( seed : Option < u64 > ) -> StdRng {
565567 match seed {
@@ -585,6 +587,13 @@ impl BytePairEncoding {
585587 assert ! ( dropout <= 1.0 ) ;
586588
587589 let mut last_token: Vec < ( u32 , u32 ) > = Vec :: with_capacity ( text. len ( ) ) ;
590+
591+ let allowed_tokens: Vec < u32 > = self . pair_lookup . values ( ) . cloned ( ) . collect ( ) ;
592+ let tokens_after_dropout = ( allowed_tokens. len ( ) as f32 ) * dropout;
593+ let forbidden_tokens_set: HashSet < & u32 > = HashSet :: from_iter (
594+ allowed_tokens. choose_multiple ( & mut rng, tokens_after_dropout. floor ( ) as usize ) ,
595+ ) ;
596+
588597 let mut state = self . overlapping_searcher . start_state ( ) ;
589598 for ( pos, c) in text. iter ( ) . enumerate ( ) {
590599 let ( s, iter) = self . overlapping_searcher . consume ( state, pos + 1 , * c) ;
@@ -594,12 +603,10 @@ impl BytePairEncoding {
594603 if m. start ( ) == 0 {
595604 best = ( m. value ( ) , 1 ) ;
596605 break ;
597- } else if last_token[ m. start ( ) - 1 ] . 1 + 1 < best. 1 {
598- if rng. random_range ( 0.0 ..=1.0 ) < dropout {
599- best = ( m. value ( ) , last_token[ m. start ( ) - 1 ] . 1 + 1 ) ;
600- } else {
601- best = ( m. value ( ) , 1 ) ;
602- }
606+ } else if ( last_token[ m. start ( ) - 1 ] . 1 + 1 < best. 1 )
607+ & !( forbidden_tokens_set. contains ( & m. value ( ) ) )
608+ {
609+ best = ( m. value ( ) , last_token[ m. start ( ) - 1 ] . 1 + 1 ) ;
603610 }
604611 }
605612 last_token. push ( best) ;
0 commit comments