Skip to content

Commit 646deeb

Browse files
author
Hendrik van Antwerpen
committed
Make hash factor functions available to users
1 parent 8c574d5 commit 646deeb

1 file changed

Lines changed: 43 additions & 37 deletions

File tree

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,32 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
166166
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
167167
}
168168

/// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
/// when constructing a [`BytePairEncoding`] from those tokens.
///
/// This is a convenience wrapper around [`find_hash_factor`] that enumerates the tokens
/// of `bpe` by decoding each token id in `0..len`.
#[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
pub fn find_hash_factor_from_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
    // Decode a single token id back into its byte sequence.
    let token_bytes = |i: usize| bpe._decode_native(&[i]);
    find_hash_factor(token_bytes, len)
}
175+
/// Find a suitable hash factor for a set of given tokens that prevents collisions when
/// constructing a [`BytePairEncoding`] from those tokens.
///
/// `tokens` maps a token id in `0..len` to that token's byte sequence. The search draws
/// random factors until one hashes all `len` tokens without collision; it is probabilistic
/// but terminates quickly in practice when a collision-free factor exists.
#[cfg(feature = "rand")]
pub fn find_hash_factor(tokens: impl Fn(usize) -> Vec<u8>, len: usize) -> u64 {
    use std::collections::HashSet;

    use rand::Rng;

    let mut rnd = rand::thread_rng();
    // Reuse one set across attempts instead of reallocating its table on every retry.
    let mut seen = HashSet::with_capacity(len);
    loop {
        let factor: u64 = rnd.gen();
        seen.clear();
        // `insert` returns false on a duplicate hash, i.e. a collision under this factor.
        if (0..len).all(|i| seen.insert(hash_bytes(&tokens(i), factor))) {
            return factor;
        }
    }
}
194+
169195
fn find_token_by_bytes(
170196
all_tokens: &[u8],
171197
token_starts: &[u32],
@@ -191,8 +217,12 @@ impl BytePairEncoding {
191217
&BPE_O200K
192218
}
193219

194-
/// Construct a BytePairEncoding instance frmo a tiktoken dictionary.
195-
/// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test.
220+
/// Construct a BytePairEncoding instance from a tiktoken dictionary.
221+
/// A suitable hash factor may be necessary to prevent hash collisions,
222+
/// which can by found using [`crate::data::find_hash_factor_from_tiktoken`].
223+
///
224+
/// The recommended approach is to store the serialized value and reuse that,
225+
/// to prevent repeating the cost of computing the hash factor and encoding.
196226
#[cfg(feature = "tiktoken-rs")]
197227
pub fn from_tiktoken(
198228
tiktoken_bpe: &tiktoken_rs::CoreBPE,
@@ -205,8 +235,12 @@ impl BytePairEncoding {
205235
)
206236
}
207237

208-
/// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
209-
/// A suitable hash factor may be necessary to prevent hash collisions. You can find on eusing the [`find_hash_factor`] test.
238+
/// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
239+
/// A suitable hash factor may be necessary to prevent hash collisions, which can be
240+
/// found using [`crate::data::find_hash_factor`].
241+
///
242+
/// The recommended approach is to store the serialized value and reuse that,
243+
/// to prevent repeating the cost of computing the hash factor and encoding.
210244
pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self {
211245
let hash_factor = hash_factor
212246
.inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero"))
@@ -574,53 +608,25 @@ mod tests {
574608

575609
#[cfg(test)]
576610
mod data {
577-
use std::collections::HashSet;
578611
use std::fs::File;
579612
use std::path::PathBuf;
580613

581-
use rand::Rng;
582614
use serde::Serialize;
583-
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
584-
585-
use super::*;
586615

587-
const BPE_CL100K_LEN: usize = 100256;
588-
const BPE_O200K_LEN: usize = 199998;
589-
590-
/// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
591-
/// 1. Set the `(bpe, len)` value to the tiktoken tokenizer you want to find a hash factor for.
592-
/// 2. Update the hash factor in [`hash_bytes`].
593-
/// 3. Run [`update_token_dicts`] tests below to update data files.
594-
/// Note: If you forget this, the next test run will update the files, but
595-
/// all other tests might fail because the data was not up-to-date.
596-
#[test]
597-
#[ignore = "run manually to find a suitable hash factor"]
598-
#[allow(unreachable_code, unused_variables)]
599-
fn find_hash_factor() {
600-
let (bpe, len): (CoreBPE, _) = todo!("replace with BPE instance and token count");
601-
let mut rnd = rand::thread_rng();
602-
loop {
603-
let factor: u64 = rnd.gen();
604-
let mut seen = HashSet::with_capacity(len);
605-
if (0..len).all(|i| seen.insert(hash_bytes(&bpe._decode_native(&[i]), factor))) {
606-
println!("hash factor: {factor}");
607-
return;
608-
}
609-
}
610-
}
616+
use crate::byte_pair_encoding::BytePairEncoding;
611617

612618
#[test]
fn update_token_dicts() {
    // Hash factor known to be collision-free for both dictionaries below.
    const HASH_FACTOR: u64 = 17846336922010275747;
    serialize_tokens(
        "cl100k",
        &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
        100256,
        HASH_FACTOR,
    );
    serialize_tokens(
        "o200k",
        &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
        199998,
        HASH_FACTOR,
    );
}

0 commit comments

Comments
 (0)