Skip to content

Commit 45c1067

Browse files
author
Hendrik van Antwerpen
committed
Align function names and parameters
1 parent 646deeb commit 45c1067

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,23 +169,27 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
169169
/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
170170
/// when constructing a [`BytePairEncoding`] from those tokens.
171171
#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
172-
pub fn find_hash_factor_from_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
173-
find_hash_factor(|i| bpe._decode_native(&[i]), len)
172+
pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
173+
find_hash_factor_for_dictionary((0..len).map(|i| bpe._decode_native(&[i])))
174174
}
175175

176176
/// Find a suitable hash factor for a set of given tokens that prevents collisions when
177177
/// constructing a [`BytePairEncoding`] from those tokens.
178178
#[cfg(feature = "rand")]
179-
pub fn find_hash_factor(tokens: impl Fn(usize) -> Vec<u8>, len: usize) -> u64 {
179+
pub fn find_hash_factor_for_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> u64 {
180180
use std::collections::HashSet;
181181

182182
use rand::Rng;
183183

184+
let all_tokens = iter.collect_vec();
184185
let mut rnd = rand::thread_rng();
185186
loop {
186187
let factor: u64 = rnd.gen();
187-
let mut seen = HashSet::with_capacity(len);
188-
if (0..len).all(|i| seen.insert(hash_bytes(&tokens(i), factor))) {
188+
let mut seen = HashSet::new();
189+
if all_tokens
190+
.iter()
191+
.all(|token| seen.insert(hash_bytes(token, factor)))
192+
{
189193
println!("hash factor: {factor}");
190194
return factor;
191195
}
@@ -219,7 +223,7 @@ impl BytePairEncoding {
219223

220224
/// Construct a BytePairEncoding instance from a tiktoken dictionary.
221225
/// A suitable hash factor may be necessary to prevent hash collisions,
222-
/// which can by found using [`crate::data::find_hash_factor_from_tiktoken`].
226+
/// which can by found using [`find_hash_factor_for_tiktoken`].
223227
///
224228
/// The recommended approach is to store the serialized value and reuse that,
225229
/// to prevent repeating the cost of computing the hash factor and encoding.
@@ -237,7 +241,7 @@ impl BytePairEncoding {
237241

238242
/// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
239243
/// A suitable hash factor may be necessary to prevent hash collisions, which can be
240-
/// found using [`crate::data::find_hash_factor`].
244+
/// found using [`find_hash_factor_for_dictionary`].
241245
///
242246
/// The recommended approach is to store the serialized value and reuse that,
243247
/// to prevent repeating the cost of computing the hash factor and encoding.

0 commit comments

Comments
 (0)