@@ -12,15 +12,18 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
1212
1313use crate :: backtrack_encoder:: BacktrackEncoder ;
1414use crate :: bitfield:: BitField ;
15+ use crate :: byte_pair_encoding:: data:: TokenDict ;
1516
1617static BPE_CL100K : LazyLock < BytePairEncoding > = LazyLock :: new ( || {
1718 let bytes = include_bytes ! ( "data/bpe_cl100k.dict" ) ;
18- rmp_serde:: from_slice ( bytes) . expect ( "" )
19+ let dict: TokenDict = rmp_serde:: from_slice ( bytes) . expect ( "" ) ;
20+ dict. into_bpe ( )
1921} ) ;
2022
2123static BPE_O200K : LazyLock < BytePairEncoding > = LazyLock :: new ( || {
2224 let bytes = include_bytes ! ( "data/bpe_o200k.dict" ) ;
23- rmp_serde:: from_slice ( bytes) . expect ( "" )
25+ let dict: TokenDict = rmp_serde:: from_slice ( bytes) . expect ( "" ) ;
26+ dict. into_bpe ( )
2427} ) ;
2528
2629/// Representation of the byte pair dictionary.
@@ -612,15 +615,23 @@ mod tests {
612615 }
613616}
614617
615- #[ cfg( test) ]
616618mod data {
617- use std:: fs:: File ;
618- use std:: path:: PathBuf ;
619-
620- use serde:: Serialize ;
619+ use serde:: { Deserialize , Serialize } ;
621620
622621 use crate :: byte_pair_encoding:: BytePairEncoding ;
623622
623+ #[ derive( Serialize , Deserialize ) ]
624+ pub ( crate ) struct TokenDict {
625+ tokens : Vec < Vec < u8 > > ,
626+ hash_factor : u64 ,
627+ }
628+
629+ impl TokenDict {
630+ pub ( crate ) fn into_bpe ( self ) -> BytePairEncoding {
631+ BytePairEncoding :: from_dictionary ( self . tokens , Some ( self . hash_factor ) )
632+ }
633+ }
634+
624635 #[ test]
625636 fn update_token_dicts ( ) {
626637 serialize_tokens (
@@ -637,22 +648,34 @@ mod data {
637648 ) ;
638649 }
639650
/// Regenerates the serialized token dictionary `data/bpe_{name}.dict`.
///
/// Decodes the token ids `0..num_tokens` one at a time through `bpe`, bundles
/// the resulting byte strings with `hash_factor` into a [`TokenDict`], and
/// serializes that with `rmp_serde` into the data file next to this source
/// file.
///
/// # Panics
///
/// Panics if the target path cannot be derived, the file cannot be created,
/// or serializing/flushing the dictionary fails.
#[cfg(test)]
#[track_caller]
fn serialize_tokens(
    name: &str,
    bpe: &tiktoken_rs::CoreBPE,
    num_tokens: usize,
    hash_factor: u64,
) {
    use std::fs::File;
    use std::io::{BufWriter, Write};
    use std::path::PathBuf;

    use itertools::Itertools;
    use serde::Serialize;

    // `file!()` yields this source file's path; it is joined below onto the
    // directory two levels above the current working directory.
    // NOTE(review): presumably that is the workspace root — confirm.
    let path = PathBuf::from(file!());
    let dir = path.parent().unwrap();
    let data_file = dir.join(format!("data/bpe_{name}.dict"));
    let current_dir = std::env::current_dir().unwrap();
    let abs_path = current_dir.parent().unwrap().parent().unwrap();
    let file = File::create(abs_path.join(data_file)).unwrap();

    // Buffer the output: writing straight to the `File` would turn every
    // small serializer write into its own OS-level write.
    let mut writer = BufWriter::new(file);

    let tokens = (0..num_tokens)
        .map(|i| bpe._decode_native(&[i]))
        .collect_vec();
    let dict = TokenDict {
        tokens,
        hash_factor,
    };
    dict.serialize(&mut rmp_serde::Serializer::new(&mut writer))
        .unwrap();
    // Flush explicitly; dropping a `BufWriter` silently discards write errors.
    writer.flush().unwrap();
}
658681}
0 commit comments