@@ -555,11 +555,33 @@ impl BytePairEncoding {
555555 /// This function computes the shortest possible encoding sequence which will usually differ from the
556556 /// tokenization produced by the original BPE algorithm.
557557 #[ cfg( feature = "rand" ) ]
558- pub fn encode_minimal_dropout ( & self , text : & [ u8 ] , dropout : f32 ) -> Vec < u32 > {
558+ pub fn encode_minimal_dropout ( & self , text : & [ u8 ] , dropout : f32 , seed : Option < u64 > ) -> Vec < u32 > {
559+ use rand:: rngs:: StdRng ;
559560 use rand:: Rng ;
561+ use rand:: SeedableRng ;
562+
563+ fn get_rng ( seed : Option < u64 > ) -> StdRng {
564+ match seed {
565+ Some ( num) => {
566+ // Expand the u64 seed to 32 bytes
567+ let mut seed_bytes = [ 0u8 ; 32 ] ;
568+ seed_bytes[ ..8 ] . copy_from_slice ( & num. to_le_bytes ( ) ) ;
569+ StdRng :: from_seed ( seed_bytes)
570+ }
571+ None => {
572+ // Seed StdRng with a random 32-byte array from ThreadRng
573+ let mut thread_rng = rand:: rng ( ) ;
574+ let mut seed_bytes = [ 0u8 ; 32 ] ;
575+ thread_rng. fill ( & mut seed_bytes) ;
576+ StdRng :: from_seed ( seed_bytes)
577+ }
578+ }
579+ }
580+
581+ let mut rng = get_rng ( seed) ;
582+
560583 assert ! ( 0.0 <= dropout) ;
561584 assert ! ( dropout <= 1.0 ) ;
562- let mut rng = rand:: rng ( ) ;
563585
564586 let mut last_token: Vec < ( u32 , u32 ) > = Vec :: with_capacity ( text. len ( ) ) ;
565587 let mut state = self . overlapping_searcher . start_state ( ) ;
@@ -572,8 +594,9 @@ impl BytePairEncoding {
572594 best = ( m. value ( ) , 1 ) ;
573595 break ;
574596 } else if last_token[ m. start ( ) - 1 ] . 1 + 1 < best. 1 {
575- best = ( m. value ( ) , last_token[ m. start ( ) - 1 ] . 1 + 1 ) ;
576597 if rng. random_range ( 0.0 ..=1.0 ) < dropout {
598+ best = ( m. value ( ) , last_token[ m. start ( ) - 1 ] . 1 + 1 ) ;
599+ } else {
577600 best = ( m. value ( ) , 1 ) ;
578601 }
579602 }
0 commit comments