@@ -63,6 +63,8 @@ pub struct BytePairEncoding {
6363 /// But we don't have efficient access to it and therefore store it here again.
6464 /// If there is none, then the value is set to u32::MAX.
6565 next_prefix_match : Vec < u32 > ,
66+ /// Hash factor used to prevent hash collisions.
67+ hash_factor : u64 ,
6668}
6769
6870fn serialize_daac < S : Serializer > (
@@ -156,11 +158,7 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
156158 & all_tokens[ token_range ( token_starts, token_id) ]
157159}
158160
159- fn hash_bytes ( bytes : & [ u8 ] ) -> u32 {
160- hash_bytes_with_factor ( bytes, 17846336922010275747 )
161- }
162-
163- fn hash_bytes_with_factor ( bytes : & [ u8 ] , factor : u64 ) -> u32 {
161+ fn hash_bytes ( bytes : & [ u8 ] , factor : u64 ) -> u32 {
164162 let mut hasher = FnvHasher :: default ( ) ;
165163 bytes. hash ( & mut hasher) ;
166164 // Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
@@ -173,8 +171,9 @@ fn find_token_by_bytes(
173171 token_starts : & [ u32 ] ,
174172 bytes_hash_to_token : & FnvHashMap < u32 , u32 > ,
175173 bytes : & [ u8 ] ,
174+ hash_factor : u64 ,
176175) -> Option < u32 > {
177- let hash = hash_bytes ( bytes) ;
176+ let hash = hash_bytes ( bytes, hash_factor ) ;
178177 let token = * bytes_hash_to_token. get ( & hash) ?;
179178 if token_bytes ( all_tokens, token_starts, token) == bytes {
180179 Some ( token)
@@ -193,19 +192,31 @@ impl BytePairEncoding {
193192 }
194193
195194 /// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
195+ /// A suitable hash factor may be necessary to prevent hash collisions. You can find one using the [`find_hash_factor`] test.
196196 #[ cfg( feature = "tiktoken-rs" ) ]
197- pub fn from_tiktoken ( tiktoken_bpe : & tiktoken_rs:: CoreBPE , num_tokens : usize ) -> Self {
198- Self :: from_dictionary ( ( 0 ..num_tokens) . map ( |i| tiktoken_bpe. _decode_native ( & [ i] ) ) )
197+ pub fn from_tiktoken (
198+ tiktoken_bpe : & tiktoken_rs:: CoreBPE ,
199+ num_tokens : usize ,
200+ hash_factor : Option < u64 > ,
201+ ) -> Self {
202+ Self :: from_dictionary (
203+ ( 0 ..num_tokens) . map ( |i| tiktoken_bpe. _decode_native ( & [ i] ) ) ,
204+ hash_factor,
205+ )
199206 }
200207
201208 /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
202- pub fn from_dictionary ( iter : impl Iterator < Item = Vec < u8 > > ) -> Self {
209+ /// A suitable hash factor may be necessary to prevent hash collisions. You can find one using the [`find_hash_factor`] test.
210+ pub fn from_dictionary ( iter : impl Iterator < Item = Vec < u8 > > , hash_factor : Option < u64 > ) -> Self {
211+ let hash_factor = hash_factor
212+ . inspect ( |f| assert_ne ! ( * f, 0 , "hash factor must be larger than zero" ) )
213+ . unwrap_or ( 1 ) ;
203214 let mut all_tokens = Vec :: new ( ) ;
204215 let mut all_tokens_rev = Vec :: new ( ) ;
205216 let mut token_starts = vec ! [ 0 ] ;
206217 let mut bytes_hash_to_token = FnvHashMap :: default ( ) ;
207218 for ( i, token) in iter. enumerate ( ) {
208- bytes_hash_to_token. insert ( hash_bytes ( & token) , i as u32 ) ;
219+ bytes_hash_to_token. insert ( hash_bytes ( & token, hash_factor ) , i as u32 ) ;
209220 all_tokens_rev. extend ( token. iter ( ) . copied ( ) . rev ( ) ) ;
210221 all_tokens. extend ( token) ;
211222 token_starts. push ( all_tokens. len ( ) as u32 ) ;
@@ -236,9 +247,13 @@ impl BytePairEncoding {
236247 let mut token1 = next_prefix_match[ id] ;
237248 while token1 != u32:: MAX {
238249 let rest = & token[ token_range ( & token_starts, token1) . len ( ) ..] ;
239- if let Some ( token2) =
240- find_token_by_bytes ( & all_tokens, & token_starts, & bytes_hash_to_token, rest)
241- {
250+ if let Some ( token2) = find_token_by_bytes (
251+ & all_tokens,
252+ & token_starts,
253+ & bytes_hash_to_token,
254+ rest,
255+ hash_factor,
256+ ) {
242257 if token1 < id as u32
243258 && token2 < id as u32
244259 && is_valid_token_pair ( & pair_lookup, & split_table, token1, token2)
@@ -264,6 +279,7 @@ impl BytePairEncoding {
264279 next_prefix_match,
265280 pair_lookup,
266281 split_table,
282+ hash_factor,
267283 }
268284 }
269285
@@ -308,6 +324,7 @@ impl BytePairEncoding {
308324 & self . token_starts ,
309325 & self . bytes_hash_to_token ,
310326 bytes,
327+ self . hash_factor ,
311328 )
312329 }
313330
@@ -563,32 +580,29 @@ mod data {
563580
564581 use rand:: Rng ;
565582 use serde:: Serialize ;
566- use tiktoken_rs:: { cl100k_base, o200k_base} ;
583+ use tiktoken_rs:: { cl100k_base, o200k_base, CoreBPE } ;
567584
568585 use super :: * ;
569586
570587 const BPE_CL100K_LEN : usize = 100256 ;
571588 const BPE_O200K_LEN : usize = 199998 ;
572589
573590 /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
574- /// 1. Ensure all supported tokenizers are in the list .
591+ /// 1. Set the `(bpe, len)` value to the tiktoken tokenizer you want to find a hash factor for .
575592 /// 2. Update the hash factor in [`hash_bytes`].
576593 /// 3. Run [`update_token_dicts`] tests below to update data files.
594+ /// Note: If you forget this, the next test run will update the files, but
595+ /// all other tests might fail because the data was not up-to-date.
577596 #[ test]
578597 #[ ignore = "run manually to find a suitable hash factor" ]
598+ #[ allow( unreachable_code, unused_variables) ]
579599 fn find_hash_factor ( ) {
580- let bpes = & mut [
581- ( cl100k_base ( ) . unwrap ( ) , BPE_CL100K_LEN ) ,
582- ( o200k_base ( ) . unwrap ( ) , BPE_O200K_LEN ) ,
583- ] ;
600+ let ( bpe, len) : ( CoreBPE , _ ) = todo ! ( "replace with BPE instance and token count" ) ;
584601 let mut rnd = rand:: thread_rng ( ) ;
585602 loop {
586603 let factor: u64 = rnd. gen ( ) ;
587- if bpes. iter ( ) . all ( |( bpe, len) | {
588- let mut seen = HashSet :: with_capacity ( * len) ;
589- ( 0 ..* len)
590- . all ( |i| seen. insert ( hash_bytes_with_factor ( & bpe. _decode_native ( & [ i] ) , factor) ) )
591- } ) {
604+ let mut seen = HashSet :: with_capacity ( len) ;
605+ if ( 0 ..len) . all ( |i| seen. insert ( hash_bytes ( & bpe. _decode_native ( & [ i] ) , factor) ) ) {
592606 println ! ( "hash factor: {factor}" ) ;
593607 return ;
594608 }
@@ -598,27 +612,34 @@ mod data {
598612 #[ test]
599613 fn update_token_dicts ( ) {
600614 serialize_tokens (
615+ "cl100k" ,
601616 & cl100k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ,
602617 BPE_CL100K_LEN ,
603- "cl100k" ,
618+ 17846336922010275747 ,
604619 ) ;
605620 serialize_tokens (
621+ "o200k" ,
606622 & o200k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ,
607623 BPE_O200K_LEN ,
608- "o200k" ,
624+ 17846336922010275747 ,
609625 ) ;
610626 }
611627
612628 #[ track_caller]
613- fn serialize_tokens ( dict : & tiktoken_rs:: CoreBPE , num_tokens : usize , name : & str ) {
629+ fn serialize_tokens (
630+ name : & str ,
631+ dict : & tiktoken_rs:: CoreBPE ,
632+ num_tokens : usize ,
633+ hash_factor : u64 ,
634+ ) {
614635 let path = PathBuf :: from ( file ! ( ) ) ;
615636 let dir = path. parent ( ) . unwrap ( ) ;
616637 let data_file = dir. join ( format ! ( "data/bpe_{name}.dict" ) ) ;
617638 let current_dir = std:: env:: current_dir ( ) . unwrap ( ) ;
618639 let abs_path = current_dir. parent ( ) . unwrap ( ) . parent ( ) . unwrap ( ) ;
619640 let file = File :: create ( abs_path. join ( data_file) ) . unwrap ( ) ;
620641 let mut serializer = rmp_serde:: Serializer :: new ( file) ;
621- BytePairEncoding :: from_tiktoken ( dict, num_tokens)
642+ BytePairEncoding :: from_tiktoken ( dict, num_tokens, Some ( hash_factor ) )
622643 . serialize ( & mut serializer)
623644 . unwrap ( ) ;
624645 }
0 commit comments