- from ast import Tuple
+ from ast import Dict, Tuple
  import time
  import torch
  import syncode.common as common
  from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
  from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel
  from syncode.parsers.grammars import Grammar
- from syncode.utils.generation import filter_code, fix_indents
- from typing import Callable, Iterable, Union
+ from typing import Any, Callable, Iterable, Union
  from transformers.generation.utils import GenerationMode
  from transformers.generation.configuration_utils import GenerationConfig
-
+ from transformers.generation.logits_process import (
+     TemperatureLogitsWarper,
+     TopKLogitsWarper,
+     TopPLogitsWarper,
+ )
+ from transformers.cache_utils import Cache

  class KeywordsStoppingCriteria(StoppingCriteria):
      '''
@@ -189,34 +193,50 @@ def _generate(
          """
          We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
          """
-         token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None
-
+
+         # Get the input ids and attention mask
+         token_ids = inputs['input_ids']
+         model_kwargs = {}
+         model_kwargs['attention_mask'] = inputs['attention_mask']
+         model_kwargs['use_cache'] = True
+         model_kwargs = self._get_initial_cache_position(token_ids, model_kwargs)
+
          # This does not include grammar decoder
-         self.model._prepare_special_tokens(gen_config, False, device=self.device)
+         self.model._prepare_special_tokens(gen_config, True, device=self.device)

          # Add logits processor for generation parameters such as top_k, top_p, temperature, etc.
-         logits_processor = self.model._get_logits_warper(gen_config, self.device)
+         logits_processor = self._get_logits_processors(gen_config)

          max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
-         self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+         self.model.config.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+
+         # Prepare the cache. (This is copied from the transformers generation_utils.py)
+         # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `gen_config`.
+         # - different models have a different cache name expected by the model (default = "past_key_values")
+         # - `max_length`, prepared above, is used to determine the maximum cache length
+         max_cache_length = max_tokens - 1
+         self.model._prepare_cache_for_generation(
+             gen_config,
+             model_kwargs,
+             assistant_model=None,
+             batch_size=token_ids.shape[0],
+             max_cache_length=max_cache_length,
+             device=self.device
+         )

          while True:
+             model_inputs = self.model.prepare_inputs_for_generation(token_ids, **model_kwargs)
              try:
-                 if past_key_values: # Get the last token if kv is cached for all previous tokens
-                     input_ids = token_ids[..., -1].unsqueeze(-1)
-                 else:
-                     input_ids = token_ids
-
-                 outputs = self.model(
-                     input_ids,
-                     attention_mask=attention_mask,
-                     past_key_values=past_key_values
-                 )
+                 outputs = self.model(**model_inputs, return_dict=True)
              except IndexError as e:
                  raise ValueError(f"The input length exceeds the context length of the model. {e}")

-             next_token_scores, past_key_values = outputs.logits[:, -1, :], outputs.past_key_values
+             model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)

+             # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+             # (the clone itself is always small)
+             next_token_scores = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=token_ids.device)
+
              if grammar_decoder is not None:
                  next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
                  is_valid = grammar_decoder.is_valid(token_ids, next_token)
@@ -240,12 +260,6 @@ def _generate(
              if finish_generation or next_token == self.tokenizer.eos_token_id or token_ids.size(1) >= max_tokens:
                  break

-             # Update attention mask
-             attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1)
-
-             if debug:
-                 grammar_decoder.print_debug()
-
          return token_ids

      def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scores):
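Note on the hunks above: the call to `self.model._get_logits_warper` is replaced by the new `_get_logits_processors` helper, which stacks the standard temperature/top-k/top-p warpers. The following standalone sketch is not part of this commit; the temperature, top-k, top-p values, the vocabulary size of 32000, and the dummy prompt ids are illustrative. It only shows how such a `LogitsProcessorList` filters next-token scores before sampling, assuming a recent transformers release.

import torch
from transformers import LogitsProcessorList
from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# The same kind of warper stack _get_logits_processors assembles when do_sample=True.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.8),
    TopKLogitsWarper(top_k=50, min_tokens_to_keep=1),
    TopPLogitsWarper(top_p=0.95, min_tokens_to_keep=1),
])

input_ids = torch.tensor([[101, 2023, 2003]])   # dummy prompt ids, batch size 1
scores = torch.randn(1, 32000)                  # dummy next-token logits
warped = warpers(input_ids, scores)             # filtered vocabulary entries become -inf
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)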
@@ -258,34 +272,140 @@ def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scor
          return next_token

      def _get_generation_mode(
-         self, generation_config: GenerationConfig
+         self, gen_config: GenerationConfig
      ) -> GenerationMode:
          """
          Returns the generation mode triggered by a [`GenerationConfig`] instance.
          """
-         if generation_config.constraints is not None or generation_config.force_words_ids is not None:
+         if gen_config.constraints is not None or gen_config.force_words_ids is not None:
              generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
-         elif generation_config.num_beams == 1:
-             if generation_config.do_sample is False:
+         elif gen_config.num_beams == 1:
+             if gen_config.do_sample is False:
                  if (
-                     generation_config.top_k is not None
-                     and generation_config.top_k > 1
-                     and generation_config.penalty_alpha is not None
-                     and generation_config.penalty_alpha > 0
+                     gen_config.top_k is not None
+                     and gen_config.top_k > 1
+                     and gen_config.penalty_alpha is not None
+                     and gen_config.penalty_alpha > 0
                  ):
                      generation_mode = GenerationMode.CONTRASTIVE_SEARCH
                  else:
                      generation_mode = GenerationMode.GREEDY_SEARCH
              else:
                  generation_mode = GenerationMode.SAMPLE
          else:
-             if generation_config.num_beam_groups > 1:
+             if gen_config.num_beam_groups > 1:
                  generation_mode = GenerationMode.GROUP_BEAM_SEARCH
-             elif generation_config.do_sample is True:
+             elif gen_config.do_sample is True:
                  generation_mode = GenerationMode.BEAM_SAMPLE
              else:
                  generation_mode = GenerationMode.BEAM_SEARCH
          return generation_mode

      def tokenize(self, s: str) -> 'Iterable[int]':
          return self.tokenizer.encode(s, add_special_tokens=False)
+
+     def _get_logits_processors(self, gen_config: GenerationConfig) -> LogitsProcessorList:
+         """
+         Returns a [`~transformers.generation.LogitsProcessorList`] with the appropriate [`LogitsProcessor`]s to use for
+         generation.
+         """
+         processors = LogitsProcessorList()
+         if gen_config.do_sample:
+             # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+             # better score (i.e. keep len(list(gen_config._eos_token_tensor)) + 1)
+             if gen_config.num_beams > 1:
+                 if isinstance(gen_config._eos_token_tensor, list):
+                     min_tokens_to_keep = len(gen_config._eos_token_tensor) + 1
+                 elif isinstance(gen_config._eos_token_tensor, torch.Tensor):
+                     min_tokens_to_keep = gen_config._eos_token_tensor.shape[0] + 1
+                 else:
+                     min_tokens_to_keep = 2
+             else:
+                 min_tokens_to_keep = 1
+
+             # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+             # all samplers can be found in `generation_utils_samplers.py`
+             if gen_config.temperature is not None and gen_config.temperature != 1.0:
+                 processors.append(TemperatureLogitsWarper(gen_config.temperature))
+             if gen_config.top_k is not None and gen_config.top_k != 0:
+                 processors.append(
+                     TopKLogitsWarper(top_k=gen_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
+                 )
+             if gen_config.top_p is not None and gen_config.top_p < 1.0:
+                 processors.append(
+                     TopPLogitsWarper(top_p=gen_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
+                 )
+         return processors
+
+     def _update_model_kwargs_for_generation(
+         self,
+         outputs,
+         model_kwargs: dict[str, Any],
+     ) -> dict[str, Any]:
+         # Variable names used to hold the cache at generation time
+         ALL_CACHE_NAMES = [
+             "past_key_values",  # default
+             "cache_params",  # mamba-based models
+             "state",  # rwkv
+             "mems",  # xlnet
+             "past_buckets_states",  # reformer
+         ]
+
+         # update past_key_values keeping its naming used in model code
+         for possible_cache_name in ALL_CACHE_NAMES:
+             if possible_cache_name in outputs:
+                 if possible_cache_name in ("past_buckets_states", "mems"):
+                     cache_name = "past_key_values"
+                 else:
+                     cache_name = possible_cache_name
+                 model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                 break
+
+         # update token_type_ids with last value
+         if "token_type_ids" in model_kwargs:
+             token_type_ids = model_kwargs["token_type_ids"]
+             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+         # assuming is_encoder_decoder = False
+         if "attention_mask" in model_kwargs:
+             attention_mask = model_kwargs["attention_mask"]
+             model_kwargs["attention_mask"] = torch.cat(
+                 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+             )
+
+         if model_kwargs.get("use_cache", True):
+             model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1  # num_new_tokens = 1
+         else:
+             past_positions = model_kwargs.pop("cache_position")
+             new_positions = torch.arange(
+                 past_positions[-1] + 1, past_positions[-1] + 2, dtype=past_positions.dtype
+             ).to(past_positions.device)
+             model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
+
+         return model_kwargs
+
+     def _get_initial_cache_position(self, input_ids, model_kwargs):
+         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+         # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
+         if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
+             cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+         elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
+             cache_position = (
+                 torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+             )
+         else:
+             cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
+
+         past_length = 0
+         if model_kwargs.get("past_key_values") is not None:
+             cache = model_kwargs["past_key_values"]
+             past_length = 0
+             if not isinstance(cache, Cache):
+                 past_length = cache[0][0].shape[2]
+             elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
+                 past_length = cache.get_seq_length()
+
+             cache_position = cache_position[past_length:]
+
+         model_kwargs["cache_position"] = cache_position
+         return model_kwargs
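
A minimal sketch of the per-step bookkeeping that `_get_initial_cache_position` and `_update_model_kwargs_for_generation` perform when `use_cache=True`: the pre-fill `cache_position` covers the prompt, and each decoding step extends the attention mask by one column and advances `cache_position` by one. The toy values (batch size 1, a 5-token prompt, three steps) are illustrative and not part of this commit, and `torch.arange` stands in for the `torch.compile`-friendly cumsum construction used above.

import torch

# Pre-fill: positions 0..4 for a 5-token prompt (what _get_initial_cache_position computes).
model_kwargs = {
    "attention_mask": torch.ones(1, 5, dtype=torch.long),
    "cache_position": torch.arange(5),
    "use_cache": True,
}

# Three decoding steps of the use_cache=True branch in _update_model_kwargs_for_generation.
for _ in range(3):
    attention_mask = model_kwargs["attention_mask"]
    model_kwargs["attention_mask"] = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
    model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

print(model_kwargs["attention_mask"].shape)  # torch.Size([1, 8])
print(model_kwargs["cache_position"])        # tensor([7])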