|
1 | 1 | #[cfg(test)] |
2 | 2 | mod tests { |
| 3 | + use std::time; |
| 4 | + |
3 | 5 | use itertools::Itertools; |
4 | 6 | use rand::{rng, Rng}; |
5 | 7 | use tiktoken_rs::cl100k_base_singleton; |
@@ -144,25 +146,41 @@ mod tests { |
144 | 146 |
|
145 | 147 | #[test] |
146 | 148 | fn test_bpe_dropout() { |
| 149 | + use rand::rngs::StdRng; |
| 150 | + use rand::SeedableRng; |
| 151 | + |
| 152 | + fn get_rng(seed: u64) -> StdRng { |
| 153 | + // Expand the u64 seed to 32 bytes |
| 154 | + let mut seed_bytes = [0u8; 32]; |
| 155 | + seed_bytes[..8].copy_from_slice(&seed.to_le_bytes()); |
| 156 | + StdRng::from_seed(seed_bytes) |
| 157 | + } |
| 158 | + |
147 | 159 | let bpe = &cl100k_base().bpe; |
148 | 160 | for bytes in [10000, 20000] { |
149 | 161 | for _ in 0..8 { |
150 | 162 | let input = create_test_bytes(bpe, bytes); |
151 | 163 | let encoded = bpe.encode_minimal(&input); |
152 | | - let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2, Some(0)); |
153 | | - let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9, Some(1)); |
154 | | - let encoded_d_max_again = bpe.encode_minimal_dropout(&input, 0.9, Some(2)); |
| 164 | + let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0)); |
| 165 | + let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1)); |
| 166 | + let encoded_d_1_0 = bpe.encode_minimal_dropout(&input, 1.0, get_rng(2)); |
155 | 167 | let decoded = bpe.decode_tokens(&encoded); |
156 | 168 | let decoded_min = bpe.decode_tokens(&encoded_d_min); |
157 | 169 | let decoded_max = bpe.decode_tokens(&encoded_d_max); |
158 | | - assert_eq!(decoded, decoded_min); |
159 | | - assert_eq!(decoded, decoded_max); |
| 170 | + let decoded_max_again = bpe.decode_tokens(&encoded_d_1_0); |
| 171 | + println!("Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}", |
| 172 | + input.len(), encoded.len(), encoded_d_min.len(), encoded_d_max.len(), encoded_d_1_0.len()); |
| 173 | + assert_eq!(input, decoded); |
| 174 | + assert_eq!(input, decoded_min); |
| 175 | + assert_eq!(input, decoded_max); |
| 176 | + assert_eq!(input, decoded_max_again); |
| 177 | + assert_eq!(input.len(), encoded_d_1_0.len()); |
160 | 178 | assert!(encoded_d_min.len() >= encoded.len()); |
161 | 179 | assert!(encoded_d_max.len() > encoded.len()); |
162 | 180 |
|
163 | 181 | assert_ne!(encoded, encoded_d_min); |
164 | 182 | assert_ne!(encoded, encoded_d_max); |
165 | | - assert_ne!(encoded_d_max, encoded_d_max_again); |
| 183 | + assert_ne!(encoded_d_max, encoded_d_1_0); |
166 | 184 | } |
167 | 185 | } |
168 | 186 | } |
|
0 commit comments