Commit 8e1d6f6

Bump transformers version to v4.51.0 and syncode version to v0.4.11
1 parent a7b657f commit 8e1d6f6

19 files changed: 171 additions & 69 deletions

.github/workflows/run_tests.yml
Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ jobs:
         uses: actions/cache@v3
         with:
           path: /home/runner/work/syncode/syncode/cache/mask_stores/
-          key: files-${{ hashFiles('syncode/parsers/grammars/python_grammar.lark', 'syncode/dfa_mask_store.py') }}
+          key: files-${{ hashFiles('syncode/parsers/grammars/python.lark', 'syncode/dfa_mask_store.py') }}
       - name: Run Tests
         run: |
           python3 -m unittest tests.test_misc

README.md
Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ SynCode depends on HuggingFace [transformers](https://github.com/huggingface/tra
 
 | SynCode version | Required transformers version | Python version |
 | -------------- | ----------------------------- | -------------- |
-| `v0.4.10` (latest) | `v4.44.0` | 3.6 - 3.12 |
+| `v0.4.11` (latest) | `v4.51.0` | 3.6 - 3.12 |
 
 **Note:** Python 3.13 is not currently supported due to dependency constraints.
 
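Since SynCode pins an exact transformers release, it can be worth verifying the installed environment matches the table above before running anything. A minimal sanity check, assuming both packages were installed from PyPI (the script below is illustrative, not part of SynCode):

    # check_env.py -- illustrative check against the v0.4.11 compatibility table
    import sys
    from importlib.metadata import version

    assert sys.version_info < (3, 13), "Python 3.13 is not supported"
    assert version("transformers") == "4.51.0", version("transformers")
    assert version("syncode") == "0.4.11", version("syncode")
    print("environment matches the v0.4.11 compatibility table")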

pyproject.toml
Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "syncode"
-version="0.4.10"
+version="0.4.11"
 requires-python = ">=3.6,<3.13"
 description = "Grammar-guided code generation tool"
 readme = "README.md"
@@ -24,7 +24,7 @@ dependencies = [
     "regex==2023.8.8",
     "torch",
     "tqdm",
-    "transformers==4.44.0",
+    "transformers==4.51.0",
     "datasets",
     "jsonschema",
 ]

requirements.txt
Lines changed: 1 addition & 1 deletion

@@ -3,6 +3,6 @@ interegular
 regex==2023.8.8
 torch
 tqdm
-transformers==4.44.0; python_version < "3.13"
+transformers==4.51.0; python_version < "3.13"
 datasets
 jsonschema

setup.py
Lines changed: 2 additions & 2 deletions

@@ -11,14 +11,14 @@
     "regex==2023.8.8",
     "torch",
     "tqdm",
-    "transformers==4.44.0",
+    "transformers==4.51.0",
     "datasets",
     "jsonschema"
 ]
 
 setuptools.setup(
     name="syncode",
-    version="0.4.10",
+    version="0.4.11",
     author="Shubham Ugare",
     author_email="shubhamugare@gmail.com",
     description="This package provides the tool for grammar augmented LLM generation.",

syncode/evaluation/json_eval.py
Lines changed: 4 additions & 1 deletion

@@ -72,7 +72,10 @@ def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, tas
     else:
         problem["prompt"][0]['content'] = f"{problem['prompt'][0]['content']}\nOnly output JSON.\nJSON:\n"
 
-    prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
+    if syncode.model.tokenizer.chat_template is not None:
+        prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
+    else:
+        prompt = problem["prompt"][0]['content']
 
     batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task)
     for completion_id, completion in enumerate(batch_completions):
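Background on this change: recent transformers releases raise an error when `apply_chat_template` is called on a tokenizer that has no chat template, instead of silently falling back to a legacy default as older versions did, so base (non-chat) models need the guard above. A minimal sketch of the same pattern, where "gpt2" stands in for any base model:

    # Sketch of the chat-template guard; "gpt2" is an illustrative base model
    from transformers import AutoTokenizer

    messages = [{"role": "user", "content": "Only output JSON.\nJSON:\n"}]
    tok = AutoTokenizer.from_pretrained("gpt2")  # base tokenizer: chat_template is None

    if tok.chat_template is not None:
        prompt = tok.apply_chat_template(messages, tokenize=False)
    else:
        # Fall back to the raw message content for plain completion models
        prompt = messages[0]["content"]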

syncode/language_model.py
Lines changed: 156 additions & 36 deletions

@@ -1,15 +1,19 @@
-from ast import Tuple
+from ast import Dict, Tuple
 import time
 import torch
 import syncode.common as common
 from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
 from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel
 from syncode.parsers.grammars import Grammar
-from syncode.utils.generation import filter_code, fix_indents
-from typing import Callable, Iterable, Union
+from typing import Any, Callable, Iterable, Union
 from transformers.generation.utils import GenerationMode
 from transformers.generation.configuration_utils import GenerationConfig
-
+from transformers.generation.logits_process import (
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+)
+from transformers.cache_utils import Cache
 
 class KeywordsStoppingCriteria(StoppingCriteria):
     '''
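These added imports replace the `_get_logits_warper` helper that was removed from `GenerationMixin` in newer transformers; the warpers are now constructed directly in `_get_logits_processors` (final hunk below). A minimal sketch of how the three warpers compose, with illustrative batch and vocabulary sizes:

    # Sketch: composing the imported warpers into a single processor list
    import torch
    from transformers import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )

    processors = LogitsProcessorList([
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50, min_tokens_to_keep=1),
        TopPLogitsWarper(top_p=0.9, min_tokens_to_keep=1),
    ])

    input_ids = torch.tensor([[1, 2, 3]])
    scores = torch.randn(1, 100)            # (batch, vocab_size) next-token logits
    warped = processors(input_ids, scores)  # warpers applied left to right
    probs = torch.softmax(warped, dim=-1)   # only top-k/top-p survivors keep mass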
@@ -189,34 +193,50 @@ def _generate(
         """
         We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
         """
-        token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None
-
+
+        # Get the input ids and attention mask
+        token_ids = inputs['input_ids']
+        model_kwargs = {}
+        model_kwargs['attention_mask'] = inputs['attention_mask']
+        model_kwargs['use_cache'] = True
+        model_kwargs = self._get_initial_cache_position(token_ids, model_kwargs)
+
         # This does not include grammar decoder
-        self.model._prepare_special_tokens(gen_config, False, device=self.device)
+        self.model._prepare_special_tokens(gen_config, True, device=self.device)
 
         # Add logits processor for generation parameters such as top_k, top_p, temperature, etc.
-        logits_processor = self.model._get_logits_warper(gen_config, self.device)
+        logits_processor = self._get_logits_processors(gen_config)
 
         max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
-        self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+        self.model.config.pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+
+        # Prepare the cache. (This is copied from the transformers generation_utils.py)
+        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `gen_config`.
+        # - different models have a different cache name expected by the model (default = "past_key_values")
+        # - `max_length`, prepared above, is used to determine the maximum cache length
+        max_cache_length = max_tokens-1
+        self.model._prepare_cache_for_generation(
+            gen_config,
+            model_kwargs,
+            assistant_model=None,
+            batch_size=token_ids.shape[0],
+            max_cache_length=max_cache_length,
+            device=self.device
+        )
 
         while True:
+            model_inputs = self.model.prepare_inputs_for_generation(token_ids, **model_kwargs)
             try:
-                if past_key_values: # Get the last token if kv is cached for all previous tokens
-                    input_ids = token_ids[..., -1].unsqueeze(-1)
-                else:
-                    input_ids = token_ids
-
-                outputs = self.model(
-                    input_ids,
-                    attention_mask=attention_mask,
-                    past_key_values=past_key_values
-                )
+                outputs = self.model(**model_inputs, return_dict=True)
             except IndexError as e:
                 raise ValueError(f"The input length exceeds the context length of the model. {e}")
 
-            next_token_scores, past_key_values = outputs.logits[:, -1, :], outputs.past_key_values
+            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
+            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_scores = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=token_ids.device)
+
             if grammar_decoder is not None:
                 next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
                 is_valid = grammar_decoder.is_valid(token_ids, next_token)
@@ -240,12 +260,6 @@ def _generate(
             if finish_generation or next_token == self.tokenizer.eos_token_id or token_ids.size(1) >= max_tokens:
                 break
 
-            # Update attention mask
-            attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)], dim=-1)
-
-            if debug:
-                grammar_decoder.print_debug()
-
         return token_ids
 
     def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scores):
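The rewritten loop above delegates input slicing and cache bookkeeping to helpers that mirror transformers' own decode step: slice inputs to the uncached suffix, run one forward pass, then advance the cache state by one position. A condensed, standalone sketch of that pattern, where "gpt2" stands in for any decoder-only model:

    # Condensed sketch of the KV-cached decode step; "gpt2" is illustrative
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("def hello():", return_tensors="pt")
    token_ids = inputs["input_ids"]
    model_kwargs = {
        "attention_mask": inputs["attention_mask"],
        "use_cache": True,
        "cache_position": torch.arange(token_ids.shape[1]),  # pre-fill positions
    }

    for _ in range(10):
        # prepare_inputs_for_generation trims token_ids to the uncached suffix
        model_inputs = model.prepare_inputs_for_generation(token_ids, **model_kwargs)
        outputs = model(**model_inputs, return_dict=True)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy
        token_ids = torch.cat([token_ids, next_token], dim=-1)
        # advance cache, attention mask, and cache_position by one token
        model_kwargs["past_key_values"] = outputs.past_key_values
        am = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([am, am.new_ones((am.shape[0], 1))], dim=-1)
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

    print(tok.decode(token_ids[0]))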
@@ -258,34 +272,140 @@ def _get_next_token(self, gen_mode, token_ids, logits_processor, next_token_scor
         return next_token
 
     def _get_generation_mode(
-        self, generation_config: GenerationConfig
+        self, gen_config: GenerationConfig
     ) -> GenerationMode:
         """
         Returns the generation mode triggered by a [`GenerationConfig`] instance.
         """
-        if generation_config.constraints is not None or generation_config.force_words_ids is not None:
+        if gen_config.constraints is not None or gen_config.force_words_ids is not None:
             generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
-        elif generation_config.num_beams == 1:
-            if generation_config.do_sample is False:
+        elif gen_config.num_beams == 1:
+            if gen_config.do_sample is False:
                 if (
-                    generation_config.top_k is not None
-                    and generation_config.top_k > 1
-                    and generation_config.penalty_alpha is not None
-                    and generation_config.penalty_alpha > 0
+                    gen_config.top_k is not None
+                    and gen_config.top_k > 1
+                    and gen_config.penalty_alpha is not None
+                    and gen_config.penalty_alpha > 0
                 ):
                     generation_mode = GenerationMode.CONTRASTIVE_SEARCH
                 else:
                     generation_mode = GenerationMode.GREEDY_SEARCH
             else:
                 generation_mode = GenerationMode.SAMPLE
         else:
-            if generation_config.num_beam_groups > 1:
+            if gen_config.num_beam_groups > 1:
                 generation_mode = GenerationMode.GROUP_BEAM_SEARCH
-            elif generation_config.do_sample is True:
+            elif gen_config.do_sample is True:
                 generation_mode = GenerationMode.BEAM_SAMPLE
             else:
                 generation_mode = GenerationMode.BEAM_SEARCH
         return generation_mode
 
     def tokenize(self, s: str) -> 'Iterable[int]':
         return self.tokenizer.encode(s, add_special_tokens=False)
+
+    def _get_logits_processors(self, gen_config: GenerationConfig) -> LogitsProcessorList:
+        """
+        Returns a [`~transformers.generation.LogitsProcessorList`] with the appropriate [`LogitsProcessor`]s to use for
+        generation.
+        """
+        processors = LogitsProcessorList()
+        if gen_config.do_sample:
+            # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+            # better score (i.e. keep len(list(gen_config._eos_token_tensor)) + 1)
+            if gen_config.num_beams > 1:
+                if isinstance(gen_config._eos_token_tensor, list):
+                    min_tokens_to_keep = len(gen_config._eos_token_tensor) + 1
+                elif isinstance(gen_config._eos_token_tensor, torch.Tensor):
+                    min_tokens_to_keep = gen_config._eos_token_tensor.shape[0] + 1
+                else:
+                    min_tokens_to_keep = 2
+            else:
+                min_tokens_to_keep = 1
+
+            # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+            # all samplers can be found in `generation_utils_samplers.py`
+            if gen_config.temperature is not None and gen_config.temperature != 1.0:
+                processors.append(TemperatureLogitsWarper(gen_config.temperature))
+            if gen_config.top_k is not None and gen_config.top_k != 0:
+                processors.append(
+                    TopKLogitsWarper(top_k=gen_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
+                )
+            if gen_config.top_p is not None and gen_config.top_p < 1.0:
+                processors.append(
+                    TopPLogitsWarper(top_p=gen_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
+                )
+        return processors
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs,
+        model_kwargs: dict[str, Any],
+    ) -> dict[str, Any]:
+        # Variable names used to hold the cache at generation time
+        ALL_CACHE_NAMES = [
+            "past_key_values",  # default
+            "cache_params",  # mamba-based models
+            "state",  # rwkv
+            "mems",  # xlnet
+            "past_buckets_states",  # reformer
+        ]
+
+        # update past_key_values keeping its naming used in model code
+        for possible_cache_name in ALL_CACHE_NAMES:
+            if possible_cache_name in outputs:
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        # assuming is_encoder_decoder = False
+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
+        if model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1  # num_new_tokens = 1
+        else:
+            past_positions = model_kwargs.pop("cache_position")
+            new_positions = torch.arange(
+                past_positions[-1] + 1, past_positions[-1] + 2, dtype=past_positions.dtype
+            ).to(past_positions.device)
+            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
+
+        return model_kwargs
+
+    def _get_initial_cache_position(self, input_ids, model_kwargs):
+        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
+        if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
+            cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+        elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
+            cache_position = (
+                torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+            )
+        else:
+            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
+
+        past_length = 0
+        if model_kwargs.get("past_key_values") is not None:
+            cache = model_kwargs["past_key_values"]
+            past_length = 0
+            if not isinstance(cache, Cache):
+                past_length = cache[0][0].shape[2]
+            elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
+                past_length = cache.get_seq_length()
+
+            cache_position = cache_position[past_length:]
+
+        model_kwargs["cache_position"] = cache_position
+        return model_kwargs
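`cache_position` records which absolute sequence positions the current forward pass writes into the KV cache: the full range during pre-fill, then one position per decoded token. A small worked example of the arithmetic used in the two helpers above:

    # Worked example of the cache_position bookkeeping
    import torch

    input_ids = torch.tensor([[10, 11, 12, 13, 14]])  # prompt of length 5

    # Pre-fill: torch.compile-friendly equivalent of torch.arange(5)
    cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
    print(cache_position)   # tensor([0, 1, 2, 3, 4])

    # Each cached decode step keeps only the newest position
    cache_position = cache_position[-1:] + 1
    print(cache_position)   # tensor([5])
    cache_position = cache_position[-1:] + 1
    print(cache_position)   # tensor([6])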
File renamed without changes.
