Skip to content

Commit d37f8c9

Browse files
committed
Add downstream task evaluation script for llama3_native_te recipe
Adds eval_downstream.py that runs lm-eval benchmarks (arc_challenge, arc_easy, boolq, copa, hellaswag, piqa, winogrande) on trained Lingua 1B checkpoints. Supports both consolidated final_model directories and distributed FSDP2 step checkpoints. Made-with: Cursor
1 parent a047ec3 commit d37f8c9

1 file changed

Lines changed: 366 additions & 0 deletions

File tree

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
#!/usr/bin/env python
2+
3+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
# SPDX-License-Identifier: LicenseRef-Apache2
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
"""Evaluate a trained Llama checkpoint on downstream NLP benchmarks using lm-eval.
19+
20+
Supports loading from:
21+
1. A consolidated final_model directory (model.safetensors + config.json)
22+
2. A distributed FSDP2 training checkpoint (step_N directory)
23+
24+
Examples:
25+
# From a consolidated final_model (single GPU, no torchrun needed):
26+
python eval_downstream.py \
27+
--checkpoint-path /path/to/ckpt_dir/train_fsdp2/final_model
28+
29+
# From a distributed FSDP2 checkpoint (needs torchrun for weight gathering):
30+
torchrun --nproc_per_node=1 eval_downstream.py \
31+
--checkpoint-path /path/to/ckpt_dir/train_fsdp2/step_60000 \
32+
--from-distributed \
33+
--model-config ./model_configs/lingua-1B
34+
35+
# Custom tasks and batch size:
36+
python eval_downstream.py \
37+
--checkpoint-path /path/to/final_model \
38+
--tasks arc_easy,hellaswag \
39+
--batch-size 16
40+
41+
# Save results to a file:
42+
python eval_downstream.py \
43+
--checkpoint-path /path/to/final_model \
44+
--output-path ./eval_results
45+
"""
46+
47+
from __future__ import annotations
48+
49+
import argparse
50+
import json
51+
import os
52+
import shutil
53+
import subprocess
54+
import sys
55+
import tempfile
56+
import time
57+
from pathlib import Path
58+
59+
60+
DOWNSTREAM_TASKS = "arc_challenge,arc_easy,boolq,copa,hellaswag,piqa,winogrande"


def parse_args() -> argparse.Namespace:
    """Build the CLI for this script and return the parsed arguments."""
    arg_parser = argparse.ArgumentParser(
        description="Evaluate a trained checkpoint on downstream NLP tasks with lm-eval.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # One (flag, add_argument-kwargs) entry per option, kept in a single
    # table so the full CLI surface is scannable at a glance.
    option_specs: list[tuple[str, dict]] = [
        (
            "--checkpoint-path",
            {
                "type": str,
                "required": True,
                "help": "Path to checkpoint. Either a final_model dir (with model.safetensors) "
                "or a step_N distributed checkpoint dir (with --from-distributed).",
            },
        ),
        (
            "--tokenizer",
            {
                "type": str,
                "default": "meta-llama/Meta-Llama-3-8B",
                "help": "Tokenizer name or path (default: meta-llama/Meta-Llama-3-8B).",
            },
        ),
        (
            "--tasks",
            {
                "type": str,
                "default": DOWNSTREAM_TASKS,
                "help": f"Comma-separated lm-eval task names (default: {DOWNSTREAM_TASKS}).",
            },
        ),
        (
            "--batch-size",
            {
                "type": str,
                "default": "auto",
                "help": "Batch size for lm-eval. Use 'auto' for automatic selection (default: auto).",
            },
        ),
        (
            "--device",
            {
                "type": str,
                "default": "cuda:0",
                "help": "Device for lm-eval inference (default: cuda:0).",
            },
        ),
        (
            "--eval-dir",
            {
                "type": str,
                "default": None,
                "help": "Directory to store the prepared eval checkpoint. Uses a temp directory if not set.",
            },
        ),
        (
            "--from-distributed",
            {
                "action": "store_true",
                "help": "Treat --checkpoint-path as a distributed FSDP2 checkpoint. Requires torchrun.",
            },
        ),
        (
            "--model-config",
            {
                "type": str,
                "default": "./model_configs/lingua-1B",
                "help": "Model config path for --from-distributed (default: ./model_configs/lingua-1B).",
            },
        ),
        (
            "--output-path",
            {
                "type": str,
                "default": None,
                "help": "Path to save lm-eval results JSON.",
            },
        ),
        (
            "--num-fewshot",
            {
                "type": int,
                "default": None,
                "help": "Number of few-shot examples (default: lm-eval task default).",
            },
        ),
    ]

    for flag, spec in option_specs:
        arg_parser.add_argument(flag, **spec)

    return arg_parser.parse_args()
131+
132+
133+
def export_distributed_checkpoint(checkpoint_path: str, model_config: str, output_path: str) -> bool:
    """Load a distributed FSDP2 checkpoint and export consolidated weights.

    Must be called inside a torchrun context. All ranks participate in loading
    and gathering, but only rank 0 saves the exported model.

    Args:
        checkpoint_path: Path to the step_N distributed checkpoint directory.
        model_config: Path to model config (e.g. ./model_configs/lingua-1B).
        output_path: Directory to save the consolidated model.

    Returns:
        True if this is rank 0 (should continue to evaluation), False otherwise.
    """
    # Heavy imports are kept local so that simply parsing this module (e.g. for
    # --help or the consolidated-checkpoint path) does not require torch or the
    # project-local training modules.
    import torch
    from safetensors.torch import save_file
    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
    from torch.distributed.checkpoint.state_dict_loader import load as dcp_load
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import fully_shard

    from checkpoint import AppState
    from distributed_config import DistributedConfig
    from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
    from scheduler import get_cosine_annealing_schedule_with_warmup

    # Process-group setup: gloo on CPU + nccl on CUDA, one flat "dp" mesh
    # spanning every rank launched by torchrun.
    dist_config = DistributedConfig()
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)
    device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))

    print(f"[Rank {dist_config.rank}] Loading distributed checkpoint from {checkpoint_path}")

    # Build the model skeleton on the meta device (no real allocations), then
    # shard it the same way training did so the DCP load can map shards back.
    # NOTE(review): attn_input_format="thd" here presumably mirrors the
    # training-time config; prepare_eval_directory later switches it to "bshd".
    config = NVLlamaConfig.from_pretrained(model_config, dtype=torch.bfloat16, attn_input_format="thd")
    with torch.device("meta"):
        model = NVLlamaForCausalLM(config)

    # Per-layer sharding first, then the root module — matching FSDP2 usage.
    for layer in model.model.layers:
        fully_shard(layer, mesh=device_mesh["dp"])
    fully_shard(model, mesh=device_mesh["dp"])

    # Materialize (empty) storage for the meta tensors before loading into them.
    model.init_empty_weights()

    # Placeholder optimizer/scheduler: their hyperparameters are irrelevant for
    # export, but AppState appears to need them so the checkpoint's "app" entry
    # matches the layout written during training — TODO confirm against
    # checkpoint.AppState.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, num_warmup_steps=1, num_decay_steps=1)

    app_state = AppState(model=model, optimizer=optimizer, scheduler=scheduler)
    state_dict = {"app": app_state}
    dcp_load(state_dict, checkpoint_id=checkpoint_path, process_group=device_mesh.get_group("dp"))

    print(f"[Rank {dist_config.rank}] Loaded checkpoint at step {app_state.step}")

    # Collective: every rank participates in gathering the full (unsharded)
    # state dict; cpu_offload keeps the consolidated copy off the GPU.
    model_state_dict = get_model_state_dict(
        model=model,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )

    # Only rank 0 writes the consolidated safetensors + config to disk.
    if dist_config.is_main_process():
        os.makedirs(output_path, exist_ok=True)
        save_file(model_state_dict, os.path.join(output_path, "model.safetensors"))
        config.save_pretrained(output_path)
        print(f"Exported consolidated model to {output_path}")

    # Keep non-zero ranks alive until rank 0 finishes writing, then tear down.
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()

    return dist_config.is_main_process()
201+
202+
203+
def prepare_eval_directory(checkpoint_path: str, output_path: str, tokenizer_name: str) -> str:
    """Stage a checkpoint directory containing everything lm-eval needs.

    Copies the model files, patches config.json with auto_map plus
    inference-compatible attention settings, ships modeling_llama_te.py
    alongside the weights, and saves the tokenizer into the same directory.

    Args:
        checkpoint_path: Source directory with model.safetensors + config.json.
        output_path: Destination directory for the eval-ready checkpoint.
        tokenizer_name: HuggingFace tokenizer name or local path.

    Returns:
        The output_path string.
    """
    from transformers import AutoTokenizer

    from modeling_llama_te import AUTO_MAP

    src = Path(checkpoint_path)
    dst = Path(output_path)

    # Copy top-level files only when source and destination actually differ
    # (with --from-distributed the export already wrote into the eval dir).
    if dst.resolve() != src.resolve():
        dst.mkdir(parents=True, exist_ok=True)
        for entry in src.iterdir():
            if entry.is_file():
                shutil.copy2(entry, dst / entry.name)

    # Patch config.json: register the custom model class via auto_map and
    # switch to attention settings usable for inference.
    config_path = dst / "config.json"
    config = json.loads(config_path.read_text())
    config["auto_map"] = AUTO_MAP
    config["attn_input_format"] = "bshd"
    config["self_attn_mask_type"] = "causal"
    config_path.write_text(json.dumps(config, indent=2, sort_keys=True))

    # The custom modeling code must sit next to the weights so that
    # trust_remote_code loading can resolve the auto_map entries.
    shutil.copy2(Path(__file__).parent / "modeling_llama_te.py", dst / "modeling_llama_te.py")

    AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(str(dst))

    print(f"Prepared eval directory: {output_path}")
    return output_path
249+
250+
251+
def run_lm_eval(
    eval_dir: str,
    tasks: str,
    batch_size: str,
    device: str,
    output_path: str | None = None,
    num_fewshot: int | None = None,
) -> float:
    """Run lm-eval on the prepared checkpoint directory.

    Args:
        eval_dir: Path to the prepared eval checkpoint directory.
        tasks: Comma-separated list of lm-eval task names.
        batch_size: Batch size string (integer or "auto").
        device: Device string (e.g. "cuda:0").
        output_path: Optional path to save results JSON.
        num_fewshot: Optional number of few-shot examples.

    Returns:
        Wall-clock time in seconds.
    """
    # Build the lm-eval command incrementally; optional flags are appended
    # only when the caller supplied them.
    command = [sys.executable, "-m", "lm_eval"]
    command += ["--model", "hf"]
    command += ["--model_args", f"pretrained={eval_dir},tokenizer={eval_dir}"]
    command += ["--trust_remote_code"]
    command += ["--tasks", tasks]
    command += ["--device", device]
    command += ["--batch_size", batch_size]

    if output_path:
        command += ["--output_path", output_path]

    if num_fewshot is not None:
        command += ["--num_fewshot", str(num_fewshot)]

    print(f"\nRunning lm-eval:\n {' '.join(command)}\n")
    print("=" * 80)

    started = time.time()
    proc = subprocess.run(command, check=False)
    elapsed = time.time() - started

    print("=" * 80)
    print(f"\nlm-eval completed in {elapsed:.1f}s ({elapsed / 60:.1f} min)")

    # Propagate a failing benchmark run as this script's exit code.
    if proc.returncode != 0:
        print(f"lm-eval failed with exit code {proc.returncode}", file=sys.stderr)
        sys.exit(proc.returncode)

    return elapsed
310+
311+
312+
def main() -> None:
    """Entry point: stage the checkpoint for lm-eval, run it, clean up."""
    args = parse_args()
    ckpt_path = Path(args.checkpoint_path)

    # Fall back to a throwaway directory when the user gave no --eval-dir;
    # only that throwaway directory is removed afterwards.
    use_temp = args.eval_dir is None
    eval_dir = args.eval_dir if args.eval_dir else tempfile.mkdtemp(prefix="lm_eval_checkpoint_")

    if use_temp:
        print(f"Using temporary eval directory: {eval_dir}")

    try:
        if args.from_distributed:
            # Gather FSDP2 shards into a consolidated model inside eval_dir.
            is_main = export_distributed_checkpoint(
                checkpoint_path=str(ckpt_path),
                model_config=args.model_config,
                output_path=eval_dir,
            )
            if not is_main:
                # Non-zero torchrun ranks only assist the gather; they exit here.
                return
            src_dir = eval_dir
        else:
            if not (ckpt_path / "model.safetensors").exists():
                print(
                    f"Error: {ckpt_path / 'model.safetensors'} not found.\n"
                    "If this is a distributed FSDP2 checkpoint, use --from-distributed with torchrun.\n"
                    "If this is a final_model directory, ensure it contains model.safetensors.",
                    file=sys.stderr,
                )
                sys.exit(1)
            src_dir = str(ckpt_path)

        prepare_eval_directory(
            checkpoint_path=src_dir,
            output_path=eval_dir,
            tokenizer_name=args.tokenizer,
        )

        run_lm_eval(
            eval_dir=eval_dir,
            tasks=args.tasks,
            batch_size=args.batch_size,
            device=args.device,
            output_path=args.output_path,
            num_fewshot=args.num_fewshot,
        )

    finally:
        # Remove the staging directory even if lm-eval (or staging) failed.
        if use_temp and os.path.exists(eval_dir):
            print(f"\nCleaning up temporary directory: {eval_dir}")
            shutil.rmtree(eval_dir)
363+
364+
365+
# Script entry point; works under plain python or torchrun (see module docstring).
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)