@@ -555,7 +555,12 @@ impl BytePairEncoding {
555555 /// Result of the encoding will be non-deterministic unless `seed` is provided.
556556 /// Implementation loosely follows original BPE dropout paper: https://arxiv.org/abs/1910.13267
557557 #[ cfg( feature = "rand" ) ]
558- pub fn encode_minimal_dropout < R : rand:: Rng > ( & self , text : & [ u8 ] , dropout : f32 , mut rng : R ) -> Vec < u32 > {
558+ pub fn encode_minimal_dropout < R : rand:: Rng > (
559+ & self ,
560+ text : & [ u8 ] ,
561+ dropout : f32 ,
562+ mut rng : R ,
563+ ) -> Vec < u32 > {
559564 assert ! ( 0.0 <= dropout) ;
560565 assert ! ( dropout <= 1.0 ) ;
561566
@@ -587,76 +592,6 @@ impl BytePairEncoding {
587592 }
588593 encoded
589594 }
590-
591- /// This function computes the encoding while randomly rejecting some merges.
592- /// Result of the encoding will be non-deterministic unless `seed` is provided.
593- /// Implementation loosely follows original BPE dropout paper: https://arxiv.org/abs/1910.13267
594- #[ cfg( feature = "rand" ) ]
595- pub fn encode_minimal_dropout ( & self , text : & [ u8 ] , dropout : f32 , seed : Option < u64 > ) -> Vec < u32 > {
596- use rand:: rngs:: StdRng ;
597- use rand:: seq:: IndexedRandom ;
598- use rand:: Rng ;
599- use rand:: SeedableRng ;
600- use std:: collections:: HashSet ;
601-
602- fn get_rng ( seed : Option < u64 > ) -> StdRng {
603- match seed {
604- Some ( num) => {
605- // Expand the u64 seed to 32 bytes
606- let mut seed_bytes = [ 0u8 ; 32 ] ;
607- seed_bytes[ ..8 ] . copy_from_slice ( & num. to_le_bytes ( ) ) ;
608- StdRng :: from_seed ( seed_bytes)
609- }
610- None => {
611- // Seed StdRng with a random 32-byte array from ThreadRng
612- let mut thread_rng = rand:: rng ( ) ;
613- let mut seed_bytes = [ 0u8 ; 32 ] ;
614- thread_rng. fill ( & mut seed_bytes) ;
615- StdRng :: from_seed ( seed_bytes)
616- }
617- }
618- }
619-
620- let mut rng = get_rng ( seed) ;
621-
622- assert ! ( 0.0 <= dropout) ;
623- assert ! ( dropout <= 1.0 ) ;
624-
625- let mut last_token: Vec < ( u32 , u32 ) > = Vec :: with_capacity ( text. len ( ) ) ;
626-
627- let allowed_tokens: Vec < u32 > = self . pair_lookup . values ( ) . cloned ( ) . collect ( ) ;
628- let tokens_after_dropout = ( allowed_tokens. len ( ) as f32 ) * dropout;
629- let forbidden_tokens_set: HashSet < & u32 > = HashSet :: from_iter (
630- allowed_tokens. choose_multiple ( & mut rng, tokens_after_dropout. floor ( ) as usize ) ,
631- ) ;
632-
633- let mut state = self . overlapping_searcher . start_state ( ) ;
634- for ( pos, c) in text. iter ( ) . enumerate ( ) {
635- let ( s, iter) = self . overlapping_searcher . consume ( state, pos + 1 , * c) ;
636- state = s;
637- let mut best = ( 0 , u32:: MAX ) ;
638- for m in iter {
639- if m. start ( ) == 0 {
640- best = ( m. value ( ) , 1 ) ;
641- break ;
642- } else if ( last_token[ m. start ( ) - 1 ] . 1 + 1 < best. 1 )
643- & ( !( forbidden_tokens_set. contains ( & m. value ( ) ) ) | ( ( m. end ( ) - m. start ( ) ) == 1 ) )
644- {
645- best = ( m. value ( ) , last_token[ m. start ( ) - 1 ] . 1 + 1 ) ;
646- }
647- }
648- last_token. push ( best) ;
649- }
650- let mut encoded = Vec :: with_capacity ( last_token. last ( ) . map ( |l| l. 1 as usize ) . unwrap_or ( 0 ) ) ;
651- let mut pos = text. len ( ) ;
652- while pos > 0 {
653- let token = last_token[ pos - 1 ] . 0 ;
654- encoded. push ( token) ;
655- pos -= self . token_len ( token) ;
656- }
657- encoded. reverse ( ) ;
658- encoded
659- }
660595}
661596
662597/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
0 commit comments