@@ -166,6 +166,32 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
166166 ( ( hasher. finish ( ) . wrapping_mul ( factor) ) >> 32 ) as u32
167167}
168168
/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
/// when constructing a [`BytePairEncoding`] from those tokens.
#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
pub fn find_hash_factor_from_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
    // Delegate to the generic search, decoding each token id through tiktoken.
    find_hash_factor(|token_id| bpe._decode_native(&[token_id]), len)
}
175+
/// Find a suitable hash factor for a set of given tokens that prevents collisions when
/// constructing a [`BytePairEncoding`] from those tokens.
///
/// `tokens` maps each token id in `0..len` to its byte representation. Random factors
/// are drawn until one hashes all `len` tokens without a collision; the winning factor
/// is printed to stdout (so it can be copied into code) and returned.
#[cfg(feature = "rand")]
pub fn find_hash_factor(tokens: impl Fn(usize) -> Vec<u8>, len: usize) -> u64 {
    use std::collections::HashSet;

    use rand::Rng;

    let mut rnd = rand::thread_rng();
    loop {
        let factor: u64 = rnd.gen();
        // A factor of zero maps every token to the same hash in `hash_bytes` and is
        // rejected by the assertion in `from_dictionary`. For `len <= 1` the collision
        // check below would not filter it out, so skip it explicitly.
        if factor == 0 {
            continue;
        }
        let mut seen = HashSet::with_capacity(len);
        if (0..len).all(|i| seen.insert(hash_bytes(&tokens(i), factor))) {
            println!("hash factor: {factor}");
            return factor;
        }
    }
}
194+
169195fn find_token_by_bytes (
170196 all_tokens : & [ u8 ] ,
171197 token_starts : & [ u32 ] ,
@@ -191,8 +217,12 @@ impl BytePairEncoding {
191217 & BPE_O200K
192218 }
193219
194- /// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
195- /// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test.
220+ /// Construct a BytePairEncoding instance from a tiktoken dictionary.
221+ /// A suitable hash factor may be necessary to prevent hash collisions,
222+ /// which can be found using [`crate::data::find_hash_factor_from_tiktoken`].
223+ ///
224+ /// The recommended approach is to store the serialized value and reuse that,
225+ /// to prevent repeating the cost of computing the hash factor and encoding.
196226 #[ cfg( feature = "tiktoken-rs" ) ]
197227 pub fn from_tiktoken (
198228 tiktoken_bpe : & tiktoken_rs:: CoreBPE ,
@@ -205,8 +235,12 @@ impl BytePairEncoding {
205235 )
206236 }
207237
208- /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
209- /// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test.
238+ /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
239+ /// A suitable hash factor may be necessary to prevent hash collisions, which can be
240+ /// found using [`crate::data::find_hash_factor`].
241+ ///
242+ /// The recommended approach is to store the serialized value and reuse that,
243+ /// to prevent repeating the cost of computing the hash factor and encoding.
210244 pub fn from_dictionary ( iter : impl Iterator < Item = Vec < u8 > > , hash_factor : Option < u64 > ) -> Self {
211245 let hash_factor = hash_factor
212246 . inspect ( |f| assert_ne ! ( * f, 0 , "hash factor must be larger than zero" ) )
@@ -574,53 +608,25 @@ mod tests {
574608
575609#[ cfg( test) ]
576610mod data {
577- use std:: collections:: HashSet ;
578611 use std:: fs:: File ;
579612 use std:: path:: PathBuf ;
580613
581- use rand:: Rng ;
582614 use serde:: Serialize ;
583- use tiktoken_rs:: { cl100k_base, o200k_base, CoreBPE } ;
584-
585- use super :: * ;
586615
587- const BPE_CL100K_LEN : usize = 100256 ;
588- const BPE_O200K_LEN : usize = 199998 ;
589-
590- /// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
591- /// 1. Set the `(bpe, len)` value to the tiktoken tokenizer you want to find a hash factor for.
592- /// 2. Update the hash factor in [`hash_bytes`].
593- /// 3. Run [`update_token_dicts`] tests below to update data files.
594- /// Note: If you forget this, the next test run will update the files, but
595- /// all other tests might fail because the data was not up-to-date.
596- #[ test]
597- #[ ignore = "run manually to find a suitable hash factor" ]
598- #[ allow( unreachable_code, unused_variables) ]
599- fn find_hash_factor ( ) {
600- let ( bpe, len) : ( CoreBPE , _ ) = todo ! ( "replace with BPE instance and token count" ) ;
601- let mut rnd = rand:: thread_rng ( ) ;
602- loop {
603- let factor: u64 = rnd. gen ( ) ;
604- let mut seen = HashSet :: with_capacity ( len) ;
605- if ( 0 ..len) . all ( |i| seen. insert ( hash_bytes ( & bpe. _decode_native ( & [ i] ) , factor) ) ) {
606- println ! ( "hash factor: {factor}" ) ;
607- return ;
608- }
609- }
610- }
616+ use crate :: byte_pair_encoding:: BytePairEncoding ;
611617
612618 #[ test]
613619 fn update_token_dicts ( ) {
614620 serialize_tokens (
615621 "cl100k" ,
616- & cl100k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ,
617- BPE_CL100K_LEN ,
622+ & tiktoken_rs :: cl100k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ,
623+ 100256 ,
618624 17846336922010275747 ,
619625 ) ;
620626 serialize_tokens (
621627 "o200k" ,
622- & o200k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ,
623- BPE_O200K_LEN ,
628+ & tiktoken_rs :: o200k_base ( ) . expect ( "tiktoken initialization must not fail!" ) ,
629+ 199998 ,
624630 17846336922010275747 ,
625631 ) ;
626632 }
0 commit comments