Skip to content

Commit 348abf0

Browse files
author
Hendrik van Antwerpen
committed
Move tiktoken data reading to bpe
1 parent 3132551 commit 348abf0

File tree

4 files changed

+53
-34
lines changed

4 files changed

+53
-34
lines changed

crates/bpe-openai/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ tiktoken-rs = "0.6"
2424

2525
[build-dependencies]
2626
base64 = "0.22.1"
27-
bpe = { version = "0.1.0", path = "../bpe" }
27+
bpe = { version = "0.1.0", path = "../bpe", features = ["tiktoken"] }
2828
flate2 = "1.0"
2929
rmp-serde = "1"
3030
serde = { version = "1" }

crates/bpe-openai/build.rs

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,30 @@ use std::fs::File;
33
use std::io::Read;
44
use std::path::PathBuf;
55

6-
use base64::prelude::*;
7-
use bpe::byte_pair_encoding::BytePairEncoding;
6+
use bpe::byte_pair_encoding::{read_tiktoken, BytePairEncoding};
87
use serde::Serialize;
98

109
fn main() {
11-
serialize_tokens(
12-
"r50k_base",
13-
load_tiktoken_gz(include_bytes!("data/r50k_base.tiktoken.gz")),
14-
1,
15-
);
16-
serialize_tokens(
17-
"p50k_base",
18-
load_tiktoken_gz(include_bytes!("data/p50k_base.tiktoken.gz")),
19-
1,
20-
);
21-
serialize_tokens(
10+
serialize_tiktoken_bpe("r50k_base", include_bytes!("data/r50k_base.tiktoken.gz"), 1);
11+
serialize_tiktoken_bpe("p50k_base", include_bytes!("data/p50k_base.tiktoken.gz"), 1);
12+
serialize_tiktoken_bpe(
2213
"cl100k_base",
23-
load_tiktoken_gz(include_bytes!("data/cl100k_base.tiktoken.gz")),
14+
include_bytes!("data/cl100k_base.tiktoken.gz"),
2415
17846336922010275747,
2516
);
26-
serialize_tokens(
17+
serialize_tiktoken_bpe(
2718
"o200k_base",
28-
load_tiktoken_gz(include_bytes!("data/o200k_base.tiktoken.gz")),
19+
include_bytes!("data/o200k_base.tiktoken.gz"),
2920
17846336922010275747,
3021
);
3122
println!("cargo::rerun-if-changed=build.rs");
3223
}
3324

34-
fn serialize_tokens(name: &str, tokens: Vec<Vec<u8>>, hash_factor: u64) {
25+
fn serialize_tiktoken_bpe(name: &str, data: &[u8], hash_factor: u64) {
26+
let mut dec = flate2::read::GzDecoder::new(data);
27+
let mut tiktoken = String::new();
28+
dec.read_to_string(&mut tiktoken).expect("can decode data");
29+
let tokens = read_tiktoken(&tiktoken).expect("can read data");
3530
let mut path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set during build"));
3631
path.push(format!("bpe_{name}.dict"));
3732
let file = File::create(path).expect("can create output file");
@@ -40,19 +35,3 @@ fn serialize_tokens(name: &str, tokens: Vec<Vec<u8>>, hash_factor: u64) {
4035
bpe.serialize(&mut serializer)
4136
.expect("serialization succeeds");
4237
}
43-
44-
fn load_tiktoken_gz(data: &[u8]) -> Vec<Vec<u8>> {
45-
let mut dec = flate2::read::GzDecoder::new(data);
46-
let mut tiktoken = String::new();
47-
dec.read_to_string(&mut tiktoken).expect("can decode data");
48-
let tokens: Vec<_> = tiktoken
49-
.lines()
50-
.filter(|line| !line.is_empty())
51-
.map(|line| {
52-
BASE64_STANDARD
53-
.decode(line.split_whitespace().next().expect("token field on line"))
54-
.expect("base64 token field")
55-
})
56-
.collect();
57-
tokens
58-
}

crates/bpe/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ bench = false
1414

1515
[features]
1616
rand = ["dep:rand"]
17+
tiktoken = ["dep:base64"]
1718

1819
[dependencies]
1920
aneubeck-daachorse = "1.1.1"
21+
base64 = { version = "0.22", optional = true }
2022
fnv = "1.0"
2123
itertools = "0.12"
2224
rand = { version = "0.8", optional = true }

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,13 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
155155
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
156156
}
157157

158+
/// Compute a hash factor for the tokens in a tiktoken data file such that building a
/// [`BytePairEncoding`] from them produces no hash collisions.
///
/// Returns a `base64::DecodeError` if a token field in the data is not valid base64.
#[cfg(all(feature = "rand", feature = "tiktoken"))]
pub fn find_hash_factor_for_tiktoken(data: &str) -> Result<u64, base64::DecodeError> {
    // Decode the token dictionary first, then delegate the search to the
    // generic dictionary-based helper.
    let tokens = read_tiktoken(data)?;
    Ok(find_hash_factor_for_dictionary(tokens))
}
164+
158165
/// Find a suitable hash factor for a set of given tokens that prevents collisions when
159166
/// constructing a [`BytePairEncoding`] from those tokens.
160167
#[cfg(feature = "rand")]
@@ -193,7 +200,38 @@ fn find_token_by_bytes(
193200
}
194201
}
195202

203+
/// Read the tokens from a tiktoken data file, which contains base64 encoded tokens at
/// the start of each line, in descending frequency order.
///
/// Returns the decoded tokens, or a `base64::DecodeError` if a token field is not
/// valid base64.
#[cfg(feature = "tiktoken")]
pub fn read_tiktoken(data: &str) -> Result<Vec<Vec<u8>>, base64::DecodeError> {
    use base64::prelude::*;
    data.lines()
        // Take the first whitespace-separated field of each line. Lines that are
        // empty or whitespace-only yield no field and are skipped; the previous
        // `filter(!is_empty)` + `expect(..)` pair panicked on whitespace-only lines.
        .filter_map(|line| line.split_whitespace().next())
        .map(|encoded_token| BASE64_STANDARD.decode(encoded_token))
        // Collecting an iterator of `Result<Vec<u8>, E>` into
        // `Result<Vec<Vec<u8>>, E>` short-circuits on the first decode error
        // (std `FromIterator` for `Result`; no itertools needed).
        .collect()
}
219+
196220
impl BytePairEncoding {
221+
/// Build a [`BytePairEncoding`] from the contents of a tiktoken data file.
///
/// A suitable `hash_factor` may be required to avoid hash collisions; one can be
/// computed with [`find_hash_factor_for_tiktoken`].
///
/// Since both the hash-factor search and the encoding construction are costly,
/// prefer serializing the resulting instance once and deserializing it thereafter.
///
/// Returns a `base64::DecodeError` if a token field in the data is not valid base64.
#[cfg(feature = "tiktoken")]
pub fn from_tiktoken(
    data: &str,
    hash_factor: Option<u64>,
) -> Result<Self, base64::DecodeError> {
    // Decode the token dictionary, then defer to the generic constructor.
    let tokens = read_tiktoken(data)?;
    Ok(Self::from_dictionary(tokens, hash_factor))
}
234+
197235
/// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
198236
/// A suitable hash factor may be necessary to prevent hash collisions, which can be
199237
/// found using [`find_hash_factor_for_dictionary`].

0 commit comments

Comments
 (0)