Skip to content

Commit bae9f01

Browse files
author
Hendrik van Antwerpen
committed
Generate serialized data in build script
1 parent 1c2506d commit bae9f01

10 files changed

Lines changed: 144 additions & 110 deletions

File tree

crates/bpe-openai/Cargo.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[package]
2+
name = "bpe-openai"
3+
version = "0.0.1"
4+
edition = "2021"
5+
description = "Prebuilt fast byte-pair encoders for OpenAI."
6+
repository = "https://github.com/github/rust-gems"
7+
license = "MIT"
8+
keywords = ["tokenizer", "algorithm", "encoding", "bpe"]
9+
categories = ["algorithms", "data-structures", "encoding", "science"]
10+
11+
[lib]
12+
crate-type = ["lib", "staticlib"]
13+
bench = false
14+
15+
[dependencies]
16+
bpe = { version = "0.0.1", path = "../bpe" }
17+
rmp-serde = "1"
18+
serde = { version = "1" }
19+
20+
[build-dependencies]
21+
bpe = { version = "0.0.1", path = "../bpe", features = ["tiktoken-rs"] }
22+
rmp-serde = "1"
23+
tiktoken-rs = { version = "0.5" }
24+
serde = { version = "1" }

crates/bpe-openai/build.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use std::env;
2+
use std::fs::File;
3+
use std::path::PathBuf;
4+
5+
use bpe::byte_pair_encoding::BytePairEncoding;
6+
use serde::Serialize;
7+
use tiktoken_rs::CoreBPE;
8+
9+
fn main() {
10+
serialize_tokens(
11+
"cl100k",
12+
&tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
13+
100256,
14+
17846336922010275747,
15+
);
16+
serialize_tokens(
17+
"o200k",
18+
&tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
19+
199998,
20+
17846336922010275747,
21+
);
22+
println!("cargo::rerun-if-changed=build.rs");
23+
}
24+
25+
fn serialize_tokens(name: &str, bpe: &CoreBPE, num_tokens: usize, hash_factor: u64) {
26+
let mut path = PathBuf::from(env::var("OUT_DIR").unwrap());
27+
path.push(format!("bpe_{name}.dict"));
28+
let file = File::create(path).unwrap();
29+
let mut serializer = rmp_serde::Serializer::new(file);
30+
let bpe = BytePairEncoding::from_tiktoken(bpe, num_tokens, Some(hash_factor));
31+
bpe.serialize(&mut serializer).unwrap();
32+
}

crates/bpe-openai/src/lib.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use std::sync::LazyLock;
2+
3+
use bpe::byte_pair_encoding::BytePairEncoding;
4+
5+
static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
6+
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k.dict"));
7+
rmp_serde::from_slice(bytes).expect("")
8+
});
9+
10+
static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
11+
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k.dict"));
12+
rmp_serde::from_slice(bytes).expect("")
13+
});
14+
15+
pub use bpe::*;
16+
17+
pub fn cl100k() -> &'static BytePairEncoding {
18+
&BPE_CL100K
19+
}
20+
21+
pub fn o200k() -> &'static BytePairEncoding {
22+
&BPE_O200K
23+
}
24+
25+
#[cfg(test)]
26+
mod tests {
27+
use super::*;
28+
29+
#[test]
30+
fn can_load_cl100k() {
31+
cl100k().count("".as_bytes());
32+
}
33+
34+
#[test]
35+
fn can_load_o200k() {
36+
o200k().count("".as_bytes());
37+
}
38+
}

crates/bpe/benches/performance.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,28 @@ use criterion::{
1010
use rand::{thread_rng, Rng};
1111
use tiktoken_rs::CoreBPE;
1212

13-
static TOKENIZERS: LazyLock<[(&'static str, &'static BytePairEncoding, CoreBPE); 2]> =
14-
LazyLock::new(|| {
15-
[
16-
(
17-
"cl100k",
18-
BytePairEncoding::cl100k(),
19-
tiktoken_rs::cl100k_base().unwrap(),
13+
static TOKENIZERS: LazyLock<[(&'static str, BytePairEncoding, CoreBPE); 2]> = LazyLock::new(|| {
14+
[
15+
(
16+
"cl100k",
17+
BytePairEncoding::from_tiktoken(
18+
&tiktoken_rs::cl100k_base_singleton().lock(),
19+
100256,
20+
Some(17846336922010275747),
2021
),
21-
(
22-
"o200k",
23-
BytePairEncoding::o200k(),
24-
tiktoken_rs::o200k_base().unwrap(),
22+
tiktoken_rs::cl100k_base().unwrap(),
23+
),
24+
(
25+
"o200k",
26+
BytePairEncoding::from_tiktoken(
27+
&tiktoken_rs::o200k_base_singleton().lock(),
28+
199998,
29+
Some(17846336922010275747),
2530
),
26-
]
27-
});
31+
tiktoken_rs::o200k_base().unwrap(),
32+
),
33+
]
34+
});
2835

2936
fn counting_benchmark(c: &mut Criterion) {
3037
for (name, bpe, _) in TOKENIZERS.iter() {

crates/bpe/src/appendable_encoder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ impl<'a> AppendableEncoder<'a> {
9090

9191
#[cfg(test)]
9292
mod tests {
93-
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
93+
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
9494

9595
use super::AppendableEncoder;
9696

9797
#[test]
9898
fn test_appendable_encoder() {
99-
let bpe = BytePairEncoding::cl100k();
99+
let bpe = &BPE_CL100K;
100100
let mut enc = AppendableEncoder::new(bpe);
101101
let input_string = create_test_bytes(bpe, 100);
102102
for (i, c) in input_string.iter().enumerate() {

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 24 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use std::cmp::Reverse;
22
use std::collections::BinaryHeap;
33
use std::hash::{Hash, Hasher};
44
use std::ops::Range;
5-
use std::sync::LazyLock;
65

76
use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
87
use fnv::{FnvHashMap, FnvHasher};
@@ -12,19 +11,26 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
1211

1312
use crate::backtrack_encoder::BacktrackEncoder;
1413
use crate::bitfield::BitField;
15-
use crate::byte_pair_encoding::data::TokenDict;
1614

17-
static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
18-
let bytes = include_bytes!("data/bpe_cl100k.dict");
19-
let dict: TokenDict = rmp_serde::from_slice(bytes).expect("");
20-
dict.into_bpe()
21-
});
15+
#[cfg(test)]
16+
pub(crate) static BPE_CL100K: std::sync::LazyLock<BytePairEncoding> =
17+
std::sync::LazyLock::new(|| {
18+
BytePairEncoding::from_tiktoken(
19+
&tiktoken_rs::cl100k_base_singleton().lock(),
20+
100256,
21+
Some(17846336922010275747),
22+
)
23+
});
2224

23-
static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
24-
let bytes = include_bytes!("data/bpe_o200k.dict");
25-
let dict: TokenDict = rmp_serde::from_slice(bytes).expect("");
26-
dict.into_bpe()
27-
});
25+
#[cfg(test)]
26+
pub(crate) static BPE_O200K: std::sync::LazyLock<BytePairEncoding> =
27+
std::sync::LazyLock::new(|| {
28+
BytePairEncoding::from_tiktoken(
29+
&tiktoken_rs::o200k_base_singleton().lock(),
30+
199998,
31+
Some(17846336922010275747),
32+
)
33+
});
2834

2935
/// Representation of the byte pair dictionary.
3036
/// This struct provides various conversions.
@@ -215,14 +221,6 @@ fn find_token_by_bytes(
215221
}
216222

217223
impl BytePairEncoding {
218-
pub fn cl100k() -> &'static Self {
219-
&BPE_CL100K
220-
}
221-
222-
pub fn o200k() -> &'static Self {
223-
&BPE_O200K
224-
}
225-
226224
/// Construct a BytePairEncoding instance from a tiktoken dictionary.
227225
/// A suitable hash factor may be necessary to prevent hash collisions,
228226
/// which can by found using [`find_hash_factor_for_tiktoken`].
@@ -572,7 +570,7 @@ mod tests {
572570
use itertools::Itertools;
573571
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
574572

575-
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
573+
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K, BPE_O200K};
576574

577575
#[test]
578576
fn test_correctness_cl100k() {
@@ -585,9 +583,9 @@ mod tests {
585583
])
586584
.unwrap();
587585
let time = Instant::now();
588-
let bpe = BytePairEncoding::o200k();
586+
let bpe = &BPE_CL100K;
589587
println!("{:?}", time.elapsed());
590-
let encoded1 = o200k_base_singleton()
588+
let encoded1 = cl100k_base_singleton()
591589
.lock()
592590
.encode_ordinary(test_string)
593591
.into_iter()
@@ -612,9 +610,9 @@ mod tests {
612610
])
613611
.unwrap();
614612
let time = Instant::now();
615-
let bpe = BytePairEncoding::cl100k();
613+
let bpe = &BPE_O200K;
616614
println!("{:?}", time.elapsed());
617-
let encoded1 = cl100k_base_singleton()
615+
let encoded1 = o200k_base_singleton()
618616
.lock()
619617
.encode_ordinary(test_string)
620618
.into_iter()
@@ -630,7 +628,7 @@ mod tests {
630628

631629
#[test]
632630
fn test_bpe_equivalence() {
633-
let bpe = BytePairEncoding::cl100k();
631+
let bpe = &BPE_CL100K;
634632
for tokens in [10, 1000, 10000] {
635633
for _ in 0..5 {
636634
let test_input = create_test_bytes(bpe, tokens);
@@ -641,68 +639,3 @@ mod tests {
641639
}
642640
}
643641
}
644-
645-
mod data {
646-
use serde::{Deserialize, Serialize};
647-
648-
use crate::byte_pair_encoding::BytePairEncoding;
649-
650-
#[derive(Serialize, Deserialize)]
651-
pub(crate) struct TokenDict {
652-
tokens: Vec<Vec<u8>>,
653-
hash_factor: u64,
654-
}
655-
656-
impl TokenDict {
657-
pub(crate) fn into_bpe(self) -> BytePairEncoding {
658-
BytePairEncoding::from_dictionary(self.tokens, Some(self.hash_factor))
659-
}
660-
}
661-
662-
#[test]
663-
fn update_token_dicts() {
664-
serialize_tokens(
665-
"cl100k",
666-
&tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
667-
100256,
668-
17846336922010275747,
669-
);
670-
serialize_tokens(
671-
"o200k",
672-
&tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
673-
199998,
674-
17846336922010275747,
675-
);
676-
}
677-
678-
#[cfg(test)]
679-
#[track_caller]
680-
fn serialize_tokens(
681-
name: &str,
682-
bpe: &tiktoken_rs::CoreBPE,
683-
num_tokens: usize,
684-
hash_factor: u64,
685-
) {
686-
use std::fs::File;
687-
use std::path::PathBuf;
688-
689-
use itertools::Itertools;
690-
use serde::Serialize;
691-
692-
let path = PathBuf::from(file!());
693-
let dir = path.parent().unwrap();
694-
let data_file = dir.join(format!("data/bpe_{name}.dict"));
695-
let current_dir = std::env::current_dir().unwrap();
696-
let abs_path = current_dir.parent().unwrap().parent().unwrap();
697-
let file = File::create(abs_path.join(data_file)).unwrap();
698-
let mut serializer = rmp_serde::Serializer::new(file);
699-
let tokens = (0..num_tokens)
700-
.map(|i| bpe._decode_native(&[i]))
701-
.collect_vec();
702-
let dict = TokenDict {
703-
tokens,
704-
hash_factor,
705-
};
706-
dict.serialize(&mut serializer).unwrap();
707-
}
708-
}
-745 KB
Binary file not shown.

crates/bpe/src/data/bpe_o200k.dict

-2.02 MB
Binary file not shown.

crates/bpe/src/interval_encoding.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ impl<'a> IntervalEncoding<'a> {
8686
mod tests {
8787
use rand::{thread_rng, Rng};
8888

89-
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
89+
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
9090

9191
use super::IntervalEncoding;
9292

9393
#[test]
9494
fn test_interval_count() {
95-
let bpe = BytePairEncoding::cl100k();
95+
let bpe = &BPE_CL100K;
9696
let text = create_test_bytes(bpe, 10000);
9797
let intervals = IntervalEncoding::new(bpe, &text);
9898
for _ in 0..1000 {

crates/bpe/src/prependable_encoder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ impl<'a> PrependableEncoder<'a> {
9090

9191
#[cfg(test)]
9292
mod tests {
93-
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
93+
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
9494

9595
use super::PrependableEncoder;
9696

9797
#[test]
9898
fn test_prependable_encoder() {
99-
let bpe = BytePairEncoding::cl100k();
99+
let bpe = &BPE_CL100K;
100100
let mut enc = PrependableEncoder::new(bpe);
101101
let input_string = create_test_bytes(bpe, 100);
102102
for (i, c) in input_string.iter().enumerate().rev() {

0 commit comments

Comments
 (0)