@@ -169,23 +169,27 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
169169/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
170170/// when constructing a [`BytePairEncoding`] from those tokens.
171171#[ cfg( all( feature = "tiktoken-rs" , feature = "rand" ) ) ]
172- pub fn find_hash_factor_from_tiktoken ( bpe : & tiktoken_rs:: CoreBPE , len : usize ) -> u64 {
173- find_hash_factor ( |i| bpe. _decode_native ( & [ i] ) , len )
172+ pub fn find_hash_factor_for_tiktoken ( bpe : & tiktoken_rs:: CoreBPE , len : usize ) -> u64 {
173+ find_hash_factor_for_dictionary ( ( 0 ..len ) . map ( |i| bpe. _decode_native ( & [ i] ) ) )
174174}
175175
176176/// Find a suitable hash factor for a set of given tokens that prevents collisions when
177177/// constructing a [`BytePairEncoding`] from those tokens.
178178#[ cfg( feature = "rand" ) ]
179- pub fn find_hash_factor ( tokens : impl Fn ( usize ) -> Vec < u8 > , len : usize ) -> u64 {
179+ pub fn find_hash_factor_for_dictionary ( iter : impl Iterator < Item = Vec < u8 > > ) -> u64 {
180180 use std:: collections:: HashSet ;
181181
182182 use rand:: Rng ;
183183
184+ let all_tokens = iter. collect_vec ( ) ;
184185 let mut rnd = rand:: thread_rng ( ) ;
185186 loop {
186187 let factor: u64 = rnd. gen ( ) ;
187- let mut seen = HashSet :: with_capacity ( len) ;
188- if ( 0 ..len) . all ( |i| seen. insert ( hash_bytes ( & tokens ( i) , factor) ) ) {
188+ let mut seen = HashSet :: new ( ) ;
189+ if all_tokens
190+ . iter ( )
191+ . all ( |token| seen. insert ( hash_bytes ( token, factor) ) )
192+ {
189193 println ! ( "hash factor: {factor}" ) ;
190194 return factor;
191195 }
@@ -219,7 +223,7 @@ impl BytePairEncoding {
219223
220224 /// Construct a BytePairEncoding instance from a tiktoken dictionary.
221225 /// A suitable hash factor may be necessary to prevent hash collisions,
222- /// which can by found using [`crate::data::find_hash_factor_from_tiktoken `].
226+ /// which can by found using [`find_hash_factor_for_tiktoken `].
223227 ///
224228 /// The recommended approach is to store the serialized value and reuse that,
225229 /// to prevent repeating the cost of computing the hash factor and encoding.
@@ -237,7 +241,7 @@ impl BytePairEncoding {
237241
238242 /// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
239243 /// A suitable hash factor may be necessary to prevent hash collisions, which can be
240- /// found using [`crate::data::find_hash_factor `].
244+ /// found using [`find_hash_factor_for_dictionary `].
241245 ///
242246 /// The recommended approach is to store the serialized value and reuse that,
243247 /// to prevent repeating the cost of computing the hash factor and encoding.
0 commit comments