Commit 3c36ba0

Merge pull request #183 from structuredllm/bump

Bump transformers version to v4.51.0 and syncode version to v0.4.11

2 parents: a7b657f + d343567

20 files changed: 204 additions, 85 deletions

.github/workflows/run_tests.yml (1 addition, 1 deletion)

```diff
@@ -27,7 +27,7 @@ jobs:
         uses: actions/cache@v3
         with:
           path: /home/runner/work/syncode/syncode/cache/mask_stores/
-          key: files-${{ hashFiles('syncode/parsers/grammars/python_grammar.lark', 'syncode/dfa_mask_store.py') }}
+          key: files-${{ hashFiles('syncode/parsers/grammars/python.lark', 'syncode/dfa_mask_store.py') }}
       - name: Run Tests
         run: |
           python3 -m unittest tests.test_misc
```
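
The updated cache key points at `syncode/parsers/grammars/python.lark` instead of `python_grammar.lark`, so mask-store caches keyed on the old grammar path are invalidated. As a rough illustration of what the `hashFiles(...)` expression computes, here is a minimal Python sketch; the `cache_key` helper is hypothetical, not part of SynCode or GitHub Actions:

```python
import hashlib

def cache_key(paths):
    """Hypothetical rough analogue of the workflow's hashFiles() expression."""
    digest = hashlib.sha256()
    for path in sorted(paths):
        with open(path, "rb") as f:
            digest.update(f.read())  # any change to an input file changes the key
    return f"files-{digest.hexdigest()}"

# The mask-store cache is rebuilt only when the grammar or the DFA mask
# store implementation changes:
# cache_key(["syncode/parsers/grammars/python.lark", "syncode/dfa_mask_store.py"])
```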

README.md (1 addition, 1 deletion)

```diff
@@ -69,7 +69,7 @@ SynCode depends on HuggingFace [transformers](https://github.com/huggingface/tra
 
 | SynCode version | Required transformers version | Python version |
 | -------------- | ----------------------------- | -------------- |
-| `v0.4.10` (latest) | `v4.44.0` | 3.6 - 3.12 |
+| `v0.4.11` (latest) | `v4.51.0` | 3.6 - 3.12 |
 
 **Note:** Python 3.13 is not currently supported due to dependency constraints.
```
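
To verify that an installed environment matches the row above, one option (on Python 3.8+) is `importlib.metadata`; a minimal sketch:

```python
from importlib.metadata import version

print(version("syncode"))       # expected: 0.4.11
print(version("transformers"))  # expected: 4.51.0
```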

pyproject.toml (2 additions, 2 deletions)

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "syncode"
-version="0.4.10"
+version="0.4.11"
 requires-python = ">=3.6,<3.13"
 description = "Grammar-guided code generation tool"
 readme = "README.md"
@@ -24,7 +24,7 @@ dependencies = [
     "regex==2023.8.8",
     "torch",
     "tqdm",
-    "transformers==4.44.0",
+    "transformers==4.51.0",
     "datasets",
     "jsonschema",
 ]
```

requirements.txt (2 additions, 1 deletion)

```diff
@@ -1,8 +1,9 @@
+accelerate
 fire
 interegular
 regex==2023.8.8
 torch
 tqdm
-transformers==4.44.0; python_version < "3.13"
+transformers==4.51.0; python_version < "3.13"
 datasets
 jsonschema
```
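
The new `accelerate` requirement lines up with the `device_map="auto"` default introduced in syncode/common.py below, since transformers delegates "auto" device placement to accelerate. The transformers pin carries an environment marker so pip skips it on Python 3.13+; a minimal sketch of how such a marker evaluates, using the packaging library that pip relies on internally:

```python
from packaging.markers import Marker

# The marker string is copied from the requirement above.
marker = Marker('python_version < "3.13"')
print(marker.evaluate())  # True on Python 3.6-3.12, False on 3.13 and later
```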

setup.py (2 additions, 2 deletions)

```diff
@@ -11,14 +11,14 @@
     "regex==2023.8.8",
     "torch",
     "tqdm",
-    "transformers==4.44.0",
+    "transformers==4.51.0",
     "datasets",
     "jsonschema"
 ]
 
 setuptools.setup(
     name="syncode",
-    version="0.4.10",
+    version="0.4.11",
     author="Shubham Ugare",
     author_email="shubhamugare@gmail.com",
     description="This package provides the tool for grammar augmented LLM generation.",
```

syncode/common.py (26 additions, 10 deletions)

```diff
@@ -12,21 +12,32 @@
 
 
 def load_model(model_name, device, quantize, device_map = None):
+    torch_dtype = torch.bfloat16 if quantize else "auto"
+    device_map = device_map if device_map is not None else "auto"
+
+    attn_implementation = None
+    if "gemma-3" in model_name:
+        # This is due to the gemma-3 issue with SDPA implementation
+        # https://github.com/google-deepmind/gemma/issues/169
+        attn_implementation = "eager"
+        logging.info("Using slower \"eager\" attention implementation for gemma-3 due to issue with SDPA implementation")
+
     if model_name == 'test':
         model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
     elif model_name == 'test-instruct':
         model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
         if device_map is not None:
-            if (quantize):
-                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
-        else:
-            if (quantize):
-                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            logging.info(f"Loading model {model_name} with device:{device}, device_map:{device_map}, torch_dtype:{torch_dtype}")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch_dtype,
+            cache_dir=HF_CACHE,
+            token=HF_ACCESS_TOKEN,
+            trust_remote_code=True,
+            device_map = device_map,
+            attn_implementation=attn_implementation
+        ).eval()
     return model
 
 def load_tokenizer(model_name):
@@ -35,7 +46,12 @@ def load_tokenizer(model_name):
     elif model_name == 'test-instruct':
         tokenizer = AutoTokenizer.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            cache_dir=HF_CACHE,
+            token=HF_ACCESS_TOKEN,
+            trust_remote_code=True
+        )
     return tokenizer
 
 def get_output_path(model_name, grammar, dataset, num_samples, mode):
```
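
This refactor collapses four near-duplicate `from_pretrained` calls into one: `torch_dtype` and `device_map` are resolved up front, and `attn_implementation` stays `None` (letting transformers pick its default, normally SDPA) except for gemma-3, which is pinned to the slower eager implementation. A hedged usage sketch of the consolidated loader; the checkpoint names are illustrative and the import path assumes the installed syncode package:

```python
from syncode.common import load_model

# A gemma-3 checkpoint takes the "eager" attention path and, with
# quantize=True, loads in bfloat16 with device_map="auto".
model = load_model("google/gemma-3-1b-it", device="cuda", quantize=True)

# Any other checkpoint keeps the transformers default attention and
# torch_dtype="auto" (the dtype stored in the checkpoint).
model = load_model("bigcode/starcoder2-3b", device="cuda", quantize=False)
```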

syncode/evaluation/json_eval.py (4 additions, 1 deletion)

```diff
@@ -72,7 +72,10 @@ def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, tas
     else:
         problem["prompt"][0]['content'] = f"{problem['prompt'][0]['content']}\nOnly output JSON.\nJSON:\n"
 
-    prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
+    if syncode.model.tokenizer.chat_template is not None:
+        prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)
+    else:
+        prompt = problem["prompt"][0]['content']
 
     batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task)
     for completion_id, completion in enumerate(batch_completions):
```
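
In recent transformers versions, `apply_chat_template` raises for tokenizers that define no chat template, so this guard lets the JSON evaluation run on base (non-instruct) models by falling back to the raw prompt text. A minimal standalone sketch of the same pattern; the checkpoint is just an example of a tokenizer without a chat template:

```python
from transformers import AutoTokenizer

messages = [{"role": "user", "content": "Describe the user as JSON.\nOnly output JSON.\nJSON:\n"}]

tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
if tokenizer.chat_template is not None:
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
else:
    prompt = messages[0]["content"]  # plain completion prompt for base models
```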
