Skip to content

Commit 8c574d5

Browse files
author
Hendrik van Antwerpen
committed
Expose hash factor in API
1 parent f0c9def commit 8c574d5

3 files changed

Lines changed: 49 additions & 28 deletions

File tree

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ pub struct BytePairEncoding {
6363
/// But we don't have efficient access to it and therefore store it here again.
6464
/// If there is none, then the value is set to u32::MAX.
6565
next_prefix_match: Vec<u32>,
66+
/// Hash factor used to prevent hash collisions.
67+
hash_factor: u64,
6668
}
6769

6870
fn serialize_daac<S: Serializer>(
@@ -156,11 +158,7 @@ fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) ->
156158
&all_tokens[token_range(token_starts, token_id)]
157159
}
158160

159-
fn hash_bytes(bytes: &[u8]) -> u32 {
160-
hash_bytes_with_factor(bytes, 17846336922010275747)
161-
}
162-
163-
fn hash_bytes_with_factor(bytes: &[u8], factor: u64) -> u32 {
161+
fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
164162
let mut hasher = FnvHasher::default();
165163
bytes.hash(&mut hasher);
166164
// Note: we save 1/3 of space for the hashmap by only using the most significant bits of the hash.
@@ -173,8 +171,9 @@ fn find_token_by_bytes(
173171
token_starts: &[u32],
174172
bytes_hash_to_token: &FnvHashMap<u32, u32>,
175173
bytes: &[u8],
174+
hash_factor: u64,
176175
) -> Option<u32> {
177-
let hash = hash_bytes(bytes);
176+
let hash = hash_bytes(bytes, hash_factor);
178177
let token = *bytes_hash_to_token.get(&hash)?;
179178
if token_bytes(all_tokens, token_starts, token) == bytes {
180179
Some(token)
@@ -193,19 +192,31 @@ impl BytePairEncoding {
193192
}
194193

195194
/// Construct a BytePairEncoding instance from a tiktoken dictionary.
195+
/// A suitable hash factor may be necessary to prevent hash collisions. You can find one using the [`find_hash_factor`] test.
196196
#[cfg(feature = "tiktoken-rs")]
197-
pub fn from_tiktoken(tiktoken_bpe: &tiktoken_rs::CoreBPE, num_tokens: usize) -> Self {
198-
Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
197+
pub fn from_tiktoken(
198+
tiktoken_bpe: &tiktoken_rs::CoreBPE,
199+
num_tokens: usize,
200+
hash_factor: Option<u64>,
201+
) -> Self {
202+
Self::from_dictionary(
203+
(0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])),
204+
hash_factor,
205+
)
199206
}
200207

201208
/// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
202-
pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
209+
/// A suitable hash factor may be necessary to prevent hash collisions. You can find one using the [`find_hash_factor`] test.
210+
pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self {
211+
let hash_factor = hash_factor
212+
.inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero"))
213+
.unwrap_or(1);
203214
let mut all_tokens = Vec::new();
204215
let mut all_tokens_rev = Vec::new();
205216
let mut token_starts = vec![0];
206217
let mut bytes_hash_to_token = FnvHashMap::default();
207218
for (i, token) in iter.enumerate() {
208-
bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
219+
bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32);
209220
all_tokens_rev.extend(token.iter().copied().rev());
210221
all_tokens.extend(token);
211222
token_starts.push(all_tokens.len() as u32);
@@ -236,9 +247,13 @@ impl BytePairEncoding {
236247
let mut token1 = next_prefix_match[id];
237248
while token1 != u32::MAX {
238249
let rest = &token[token_range(&token_starts, token1).len()..];
239-
if let Some(token2) =
240-
find_token_by_bytes(&all_tokens, &token_starts, &bytes_hash_to_token, rest)
241-
{
250+
if let Some(token2) = find_token_by_bytes(
251+
&all_tokens,
252+
&token_starts,
253+
&bytes_hash_to_token,
254+
rest,
255+
hash_factor,
256+
) {
242257
if token1 < id as u32
243258
&& token2 < id as u32
244259
&& is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
@@ -264,6 +279,7 @@ impl BytePairEncoding {
264279
next_prefix_match,
265280
pair_lookup,
266281
split_table,
282+
hash_factor,
267283
}
268284
}
269285

@@ -308,6 +324,7 @@ impl BytePairEncoding {
308324
&self.token_starts,
309325
&self.bytes_hash_to_token,
310326
bytes,
327+
self.hash_factor,
311328
)
312329
}
313330

@@ -563,32 +580,29 @@ mod data {
563580

564581
use rand::Rng;
565582
use serde::Serialize;
566-
use tiktoken_rs::{cl100k_base, o200k_base};
583+
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
567584

568585
use super::*;
569586

570587
const BPE_CL100K_LEN: usize = 100256;
571588
const BPE_O200K_LEN: usize = 199998;
572589

573590
/// Use this to find a hashing factor for [`hash_bytes`] that prevents collisions.
574-
/// 1. Ensure all supported tokenizers are in the list.
591+
/// 1. Set the `(bpe, len)` value to the tiktoken tokenizer you want to find a hash factor for.
575592
/// 2. Update the hash factor in [`hash_bytes`].
576593
/// 3. Run [`update_token_dicts`] tests below to update data files.
594+
/// Note: If you forget this, the next test run will update the files, but
595+
/// all other tests might fail because the data was not up-to-date.
577596
#[test]
578597
#[ignore = "run manually to find a suitable hash factor"]
598+
#[allow(unreachable_code, unused_variables)]
579599
fn find_hash_factor() {
580-
let bpes = &mut [
581-
(cl100k_base().unwrap(), BPE_CL100K_LEN),
582-
(o200k_base().unwrap(), BPE_O200K_LEN),
583-
];
600+
let (bpe, len): (CoreBPE, _) = todo!("replace with BPE instance and token count");
584601
let mut rnd = rand::thread_rng();
585602
loop {
586603
let factor: u64 = rnd.gen();
587-
if bpes.iter().all(|(bpe, len)| {
588-
let mut seen = HashSet::with_capacity(*len);
589-
(0..*len)
590-
.all(|i| seen.insert(hash_bytes_with_factor(&bpe._decode_native(&[i]), factor)))
591-
}) {
604+
let mut seen = HashSet::with_capacity(len);
605+
if (0..len).all(|i| seen.insert(hash_bytes(&bpe._decode_native(&[i]), factor))) {
592606
println!("hash factor: {factor}");
593607
return;
594608
}
@@ -598,27 +612,34 @@ mod data {
598612
#[test]
599613
fn update_token_dicts() {
600614
serialize_tokens(
615+
"cl100k",
601616
&cl100k_base().expect("tiktoken initialization must not fail!"),
602617
BPE_CL100K_LEN,
603-
"cl100k",
618+
17846336922010275747,
604619
);
605620
serialize_tokens(
621+
"o200k",
606622
&o200k_base().expect("tiktoken initialization must not fail!"),
607623
BPE_O200K_LEN,
608-
"o200k",
624+
17846336922010275747,
609625
);
610626
}
611627

612628
#[track_caller]
613-
fn serialize_tokens(dict: &tiktoken_rs::CoreBPE, num_tokens: usize, name: &str) {
629+
fn serialize_tokens(
630+
name: &str,
631+
dict: &tiktoken_rs::CoreBPE,
632+
num_tokens: usize,
633+
hash_factor: u64,
634+
) {
614635
let path = PathBuf::from(file!());
615636
let dir = path.parent().unwrap();
616637
let data_file = dir.join(format!("data/bpe_{name}.dict"));
617638
let current_dir = std::env::current_dir().unwrap();
618639
let abs_path = current_dir.parent().unwrap().parent().unwrap();
619640
let file = File::create(abs_path.join(data_file)).unwrap();
620641
let mut serializer = rmp_serde::Serializer::new(file);
621-
BytePairEncoding::from_tiktoken(dict, num_tokens)
642+
BytePairEncoding::from_tiktoken(dict, num_tokens, Some(hash_factor))
622643
.serialize(&mut serializer)
623644
.unwrap();
624645
}
9 Bytes
Binary file not shown.

crates/bpe/src/data/bpe_o200k.dict

9 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)