Implement dropout for encode_minimal
#98
```rust
@@ -551,6 +551,45 @@ impl BytePairEncoding {
        encoded.reverse();
        encoded
    }

    /// This function computes the shortest possible encoding sequence which will usually differ from the
    /// tokenization produced by the original BPE algorithm.
    #[cfg(feature = "rand")]
    pub fn encode_minimal_dropout(&self, text: &[u8], dropout: f32) -> Vec<u32> {
        use rand::Rng;
        assert!(0.0 <= dropout);
        assert!(dropout <= 1.0);
        let mut rng = rand::rng();

        let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
        let mut state = self.overlapping_searcher.start_state();
        for (pos, c) in text.iter().enumerate() {
            let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
            state = s;
            let mut best = (0, u32::MAX);
            for m in iter {
                if m.start() == 0 {
```
Collaborator

Is there some paper explaining in more detail how the randomization is supposed to work? Also, some documentation would be nice (as part of a readme and/or doc comment). If this is a one-to-one implementation of some paper, then we can probably just link to that paper.
Contributor (Author)

The paper: https://arxiv.org/abs/1910.13267. We're interested in Algorithm 1 (page 3); the rationale for the improvement can be seen in Figure 6. I don't think it's a one-to-one implementation, since …
Contributor (Author)

...although I admit I don't really understand it. The intuition is that the dropout rate roughly equals the fraction of rejected merges in the final encoding, e.g. dropout ≈ 1 should result in an almost single-byte encoding. However, I don't see that with …, where the dictionary is ….

So I'd appreciate any directions if you have any :)
Collaborator

Note: this is very different from how BPE works and cannot produce the same output as the algorithm in the paper. The only implementation in this crate which follows the "standard" BPE algorithm is ….

The problem with the algorithm in the paper is that it is VERY slow. So maybe it is good enough to pick a different randomization process which follows the idea of the paper in spirit?
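For reference, the paper's Algorithm 1 (the slow "standard" variant discussed above) can be sketched on strings roughly like this. This is a toy sketch, not this crate's API: `bpe_dropout`, the `(left, right) -> rank` merge table, and the tiny `Lcg` RNG are all illustrative assumptions.

```rust
use std::collections::HashMap;

/// Tiny deterministic LCG so the example needs no external crates.
struct Lcg(u64);
impl Lcg {
    fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // top 24 bits -> uniform float in [0, 1)
        ((self.0 >> 40) as f32) / ((1u64 << 24) as f32)
    }
}

/// BPE-dropout sketch: standard greedy BPE merging, but every candidate
/// merge is independently dropped with probability `p` in each round.
fn bpe_dropout(
    text: &str,
    ranks: &HashMap<(String, String), usize>,
    p: f32,
    rng: &mut Lcg,
) -> Vec<String> {
    let mut toks: Vec<String> = text.chars().map(|c| c.to_string()).collect();
    loop {
        // Find the best-ranked adjacent pair that survives dropout this round.
        let mut best: Option<(usize, usize)> = None; // (rank, position)
        for i in 0..toks.len().saturating_sub(1) {
            if let Some(&r) = ranks.get(&(toks[i].clone(), toks[i + 1].clone())) {
                if rng.next_f32() < p {
                    continue; // this merge candidate is dropped
                }
                if best.map_or(true, |(br, _)| r < br) {
                    best = Some((r, i));
                }
            }
        }
        match best {
            Some((_, i)) => {
                // apply the surviving merge with the lowest rank
                let merged = format!("{}{}", toks[i], toks[i + 1]);
                toks[i] = merged;
                toks.remove(i + 1);
            }
            None => return toks,
        }
    }
}
```

With `p = 0.0` this reduces to deterministic BPE; with `p = 1.0` every merge is rejected and the text stays split into single characters. The quadratic rescan per round is exactly why this direct formulation is slow.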
Contributor (Author)

@aneubeck thanks for the explanation, that's actually very helpful. I guess the only thing that matters is being able to drop some merges before actually building the tokenization.

Could you have a look at the updated approach? I've changed the approach I had before (which I think was very wrong): instead, I now consider "best" tokens only if they are not in "forbidden_tokens", which is constructed prior to tokenization.

My only worry is the single-byte tokens: I'm not sure how they're handled, and I wouldn't like to discard them from the allowed tokens, but I'm not sure how to handle that properly. I'm talking about this line:

```rust
& (!(forbidden_tokens_set.contains(&m.value())) | ((m.end() - m.start()) == 1))
```

I'm not sure if the second condition should be present or not, basically.
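The pre-sampled forbidden set described above could be built roughly like this (a hedged sketch; `sample_forbidden_tokens`, `token_lens`, and the toy `Lcg` RNG are hypothetical names, not this crate's API). Single-byte tokens are never forbidden, which is what the `(m.end() - m.start()) == 1` escape hatch preserves:

```rust
use std::collections::HashSet;

/// Tiny deterministic LCG so the example needs no external crates.
struct Lcg(u64);
impl Lcg {
    fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 40) as f32) / ((1u64 << 24) as f32)
    }
}

/// `token_lens[id]` is the byte length of token `id`. Each multi-byte token
/// is independently forbidden with probability `dropout`; single-byte tokens
/// are always kept so every input remains encodable.
fn sample_forbidden_tokens(token_lens: &[usize], dropout: f32, rng: &mut Lcg) -> HashSet<u32> {
    let mut forbidden = HashSet::new();
    for (id, &len) in token_lens.iter().enumerate() {
        if len > 1 && rng.next_f32() < dropout {
            forbidden.insert(id as u32);
        }
    }
    forbidden
}
```

Note this variant forbids token *ids* globally for the whole input, whereas dropping individual matches rerolls the decision at every occurrence; the two give different distributions over tokenizations.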
Collaborator

Thanks for the changes! There was a little bug in how you treated tokens which started at the beginning of the text (you didn't filter larger tokens out there...). I also got rid of the pretty expensive lookup tables you were computing; those would slow down processing drastically!

It would be nice if you could extend the comment of this function to describe in more detail what it does (i.e. we uniformly drop edges from the graph I described above, but always keep the one-byte tokens so that the graph stays connected).

On my MacBook I measured about 30 million input characters/sec with dropout and 40 million/sec with the "standard" minimal encoding implementation.
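The edge-dropout idea described here (uniformly drop edges of the tokenization graph, but always keep one-byte edges so the graph stays connected) can be sketched as a plain dynamic program over a toy byte dictionary. This is an illustrative sketch, not this crate's implementation; it assumes the dictionary contains every single byte of the input, and the `Lcg` RNG is a stand-in for a real one.

```rust
use std::collections::HashSet;

/// Tiny deterministic LCG so the example needs no external crates.
struct Lcg(u64);
impl Lcg {
    fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 40) as f32) / ((1u64 << 24) as f32)
    }
}

/// Nodes are text positions, edges are dictionary tokens; take the
/// fewest-token path after dropping each multi-byte edge with probability
/// `p`. One-byte edges are always kept, so a path always exists.
fn encode_minimal_dropout(
    text: &[u8],
    dict: &HashSet<Vec<u8>>,
    p: f32,
    rng: &mut Lcg,
) -> Vec<Vec<u8>> {
    let n = text.len();
    let mut cost = vec![u32::MAX; n + 1]; // cost[i]: fewest tokens for text[..i]
    let mut back = vec![0usize; n + 1]; // length of the chosen token ending at i
    cost[0] = 0;
    for end in 1..=n {
        for start in 0..end {
            if cost[start] == u32::MAX {
                continue;
            }
            let tok = &text[start..end];
            if !dict.contains(tok) {
                continue;
            }
            // drop multi-byte edges with probability p; always keep 1-byte edges
            if tok.len() > 1 && rng.next_f32() < p {
                continue;
            }
            if cost[start] + 1 < cost[end] {
                cost[end] = cost[start] + 1;
                back[end] = tok.len();
            }
        }
    }
    // walk back from the end to recover the chosen tokens
    let (mut pos, mut out) = (n, Vec::new());
    while pos > 0 {
        let len = back[pos];
        out.push(text[pos - len..pos].to_vec());
        pos -= len;
    }
    out.reverse();
    out
}
```

With `p = 0.0` this is exactly the minimal encoding; with `p = 1.0` it degenerates to single bytes. The diff above does the same backward walk over `last_token`, just with an Aho-Corasick-style searcher supplying the edges instead of a naive O(n²) scan.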
Contributor (Author)

Thanks for the changes as well!

Will do!

I'll spend some time playing around with a toy example (with …

That's pretty cool :)
Contributor (Author)

@aneubeck I've added some explanation and updated the README slightly. I'm running benchmarks now; I guess it's simply …

Also, I'm running them on an M4. Should I update the description in the README accordingly, or would you prefer to run it on your machine?
Collaborator

Will try to review the changes tomorrow.
```rust
                    best = (m.value(), 1);
                    break;
                } else if last_token[m.start() - 1].1 + 1 < best.1 {
                    best = (m.value(), last_token[m.start() - 1].1 + 1);
                    if rng.random_range(0.0..=1.0) < dropout {
                        best = (m.value(), 1);
                    }
                }
            }
            last_token.push(best);
        }
        let mut encoded = Vec::with_capacity(last_token.last().map(|l| l.1 as usize).unwrap_or(0));
        let mut pos = text.len();
        while pos > 0 {
            let token = last_token[pos - 1].0;
            encoded.push(token);
            pos -= self.token_len(token);
        }
        encoded.reverse();
        encoded
    }
}
```
```rust
/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
```
In order to get reproducible results, the random number generator should be passed in as an argument.

Added a `seed: Option<u64>` argument.