Skip to content

Commit ad953d9

Browse files
author
Hendrik van Antwerpen
committed
Use create_test_string everywhere
1 parent 438c54e commit ad953d9

File tree

3 files changed, with 23 additions and 36 deletions.

crates/bpe/benchmarks/performance.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ use bpe::appendable_encoder::AppendableEncoder;
44
use bpe::byte_pair_encoding::{create_test_string, select_test_string};
55
use bpe::interval_encoding::IntervalEncoding;
66
use bpe_benchmarks::*;
7-
use bpe_tests::create_test_bytes;
87
use criterion::{
98
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
109
};
1110
use rand::{thread_rng, Rng};
1211

1312
fn counting_benchmark(c: &mut Criterion) {
1413
for (name, bpe, _, _) in TOKENIZERS.iter() {
15-
let input = create_test_bytes(&bpe.bpe, 20000);
16-
let fast = IntervalEncoding::new(&bpe.bpe, &input);
14+
let input = create_test_string(&bpe.bpe, 80000);
15+
let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes());
1716

1817
let mut group = c.benchmark_group(format!("counting-{name}"));
1918
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
@@ -32,7 +31,7 @@ fn counting_benchmark(c: &mut Criterion) {
3231
|b, bytes| {
3332
b.iter_batched(
3433
|| thread_rng().gen_range(0..input.len() - bytes),
35-
|start| bpe.bpe.count(&input[start..start + bytes]),
34+
|start| bpe.bpe.count(&input.as_bytes()[start..start + bytes]),
3635
criterion::BatchSize::SmallInput,
3736
)
3837
},

crates/bpe/tests/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "bpe-tests"
33
edition = "2021"
44

55
[dependencies]
6-
bpe = { path = "../../bpe" }
6+
bpe = { path = "../../bpe", features = ["rand"] }
77
bpe-openai = { path = "../../bpe-openai" }
88
itertools = "0.13"
99
rand = "0.8"

crates/bpe/tests/src/lib.rs

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
use bpe::byte_pair_encoding::BytePairEncoding;
2-
use rand::{thread_rng, Rng};
3-
4-
pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
5-
let mut text = vec![];
6-
for _ in 0..tokens {
7-
let i = thread_rng().gen_range(0..bpe.num_tokens());
8-
let s = bpe.token_bytes(i as u32);
9-
text.extend_from_slice(s);
10-
}
11-
text
12-
}
13-
141
#[cfg(test)]
152
mod tests {
163
use std::time::Instant;
@@ -20,13 +7,11 @@ mod tests {
207
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};
218

229
use bpe::appendable_encoder::AppendableEncoder;
23-
use bpe::byte_pair_encoding::BytePairEncoding;
10+
use bpe::byte_pair_encoding::{create_test_string, BytePairEncoding};
2411
use bpe::interval_encoding::IntervalEncoding;
2512
use bpe::prependable_encoder::PrependableEncoder;
2613
use bpe_openai::{cl100k_base, o200k_base};
2714

28-
use super::*;
29-
3015
/// This test produces the output for the encoding example in the README.
3116
#[test]
3217
fn readme_example() {
@@ -87,10 +72,13 @@ mod tests {
8772
fn test_appendable_encoder() {
8873
let bpe = &cl100k_base().bpe;
8974
let mut enc = AppendableEncoder::new(bpe);
90-
let input_string = create_test_bytes(bpe, 100);
91-
for (i, c) in input_string.iter().enumerate() {
92-
assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
93-
enc.push(*c);
75+
let input_string = create_test_string(bpe, 100);
76+
for (i, b) in input_string.as_bytes().iter().enumerate() {
77+
enc.push(*b);
78+
assert_eq!(
79+
enc.token_count(),
80+
bpe.count(&input_string.as_bytes()[0..i + 1])
81+
);
9482
}
9583
}
9684

@@ -149,11 +137,11 @@ mod tests {
149137
#[test]
150138
fn test_bpe_equivalence() {
151139
let bpe = &cl100k_base().bpe;
152-
for tokens in [10, 1000, 10000] {
140+
for bytes in [10, 1000, 10000] {
153141
for _ in 0..5 {
154-
let test_input = create_test_bytes(bpe, tokens);
155-
let encoded1 = bpe.encode_via_backtracking(&test_input);
156-
let encoded2 = bpe.encode_via_bitfield(&test_input);
142+
let test_input = create_test_string(bpe, bytes);
143+
let encoded1 = bpe.encode_via_backtracking(test_input.as_bytes());
144+
let encoded2 = bpe.encode_via_bitfield(test_input.as_bytes());
157145
assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
158146
}
159147
}
@@ -162,15 +150,15 @@ mod tests {
162150
#[test]
163151
fn test_interval_count() {
164152
let bpe = &cl100k_base().bpe;
165-
let text = create_test_bytes(bpe, 10000);
166-
let intervals = IntervalEncoding::new(bpe, &text);
153+
let text = create_test_string(bpe, 10000);
154+
let intervals = IntervalEncoding::new(bpe, text.as_bytes());
167155
for _ in 0..1000 {
168156
let start = thread_rng().gen_range(0..text.len());
169157
let end = thread_rng().gen_range(0..text.len());
170158
let range = start.min(end)..start.max(end);
171159
assert_eq!(
172160
intervals.count(range.clone()),
173-
bpe.encode_via_backtracking(&text[range]).len()
161+
bpe.encode_via_backtracking(&text.as_bytes()[range]).len()
174162
);
175163
}
176164
}
@@ -179,10 +167,10 @@ mod tests {
179167
fn test_prependable_encoder() {
180168
let bpe = &cl100k_base().bpe;
181169
let mut enc = PrependableEncoder::new(bpe);
182-
let input_string = create_test_bytes(bpe, 100);
183-
for (i, c) in input_string.iter().enumerate().rev() {
184-
enc.push(*c);
185-
assert_eq!(enc.token_count(), bpe.count(&input_string[i..]));
170+
let input_string = create_test_string(bpe, 100);
171+
for (i, b) in input_string.as_bytes().iter().enumerate().rev() {
172+
enc.push(*b);
173+
assert_eq!(enc.token_count(), bpe.count(&input_string.as_bytes()[i..]));
186174
}
187175
}
188176
}

0 commit comments

Comments
 (0)