Skip to content

Commit 2e3deac

Browse files
authored
Merge branch 'sunlabuiuc:master' into califorest-abhinav
2 parents c9d23be + 7d95dea commit 2e3deac

File tree

63 files changed

+12702
-954
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+12702
-954
lines changed

README.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,6 @@ Module 1: <pyhealth.datasets>
189189
root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
190190
# raw CSV table name
191191
tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
192-
# map all NDC codes to CCS codes in these tables
193-
code_mapping={"NDC": "CCSCM"},
194192
)
195193
196194
.. image:: figure/structured-dataset.png
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""
2+
Quick inference test: TFMTokenizer on TUAB using local weightfiles/.
3+
4+
Two weight setups (ask your PI which matches their training):
5+
6+
1) Default (matches conformal example scripts):
7+
- tokenizer: weightfiles/tfm_tokenizer_last.pth (multi-dataset tokenizer)
8+
- classifier: weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB/.../best_model.pth
9+
10+
2) PI benchmark TUAB-specific files (place in weightfiles/):
11+
- tokenizer: tfm_tokenizer_tuab.pth
12+
- classifier: tfm_encoder_best_model_tuab.pth
13+
Use: --pi-tuab-weights
14+
15+
Split modes:
16+
- conformal (default): same test set as conformal runs (TUH eval via patient conformal split).
17+
- pi_benchmark: train/val ratio [0.875, 0.125] on train partition; test = TUH eval (same patients as official eval).
18+
19+
Usage:
20+
python examples/conformal_eeg/test_tfm_tuab_inference.py
21+
python examples/conformal_eeg/test_tfm_tuab_inference.py --pi-tuab-weights
22+
python examples/conformal_eeg/test_tfm_tuab_inference.py \\
23+
--tuab-pi-weights-dir /shared/eng/conformal_eeg --split pi_benchmark
24+
python examples/conformal_eeg/test_tfm_tuab_inference.py --tokenizer-weights PATH --classifier-weights PATH
25+
"""
26+
27+
import argparse
28+
import os
29+
import time
30+
31+
import torch
32+
33+
from pyhealth.datasets import (
34+
TUABDataset,
35+
get_dataloader,
36+
split_by_patient_conformal_tuh,
37+
split_by_patient_tuh,
38+
)
39+
from pyhealth.models import TFMTokenizer
40+
from pyhealth.tasks import EEGAbnormalTUAB
41+
from pyhealth.trainer import Trainer
42+
43+
# Resolve the repository root: this file lives in examples/conformal_eeg/,
# so three dirname() hops from the absolute file path land at the repo root.
_THIS_FILE = os.path.abspath(__file__)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_FILE)))
# All local checkpoints live under <repo>/weightfiles/.
WEIGHTFILES = os.path.join(REPO_ROOT, "weightfiles")

# Setup 1 (default): multi-dataset tokenizer + per-seed fine-tuned classifiers.
DEFAULT_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_last.pth")
CLASSIFIER_WEIGHTS_DIR = os.path.join(
    WEIGHTFILES, "TFM_Tokenizer_multiple_finetuned_on_TUAB"
)

# Setup 2: PI's TUAB-specific tokenizer/classifier pair (see module docstring).
PI_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_tuab.pth")
PI_CLASSIFIER = os.path.join(WEIGHTFILES, "tfm_encoder_best_model_tuab.pth")
51+
52+
53+
def _build_parser():
    """Assemble the command-line interface for the TUAB inference sanity check."""
    parser = argparse.ArgumentParser(description="TFM TUAB inference sanity check")
    parser.add_argument(
        "--root",
        type=str,
        default="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf",
        help="Path to TUAB edf/ directory.",
    )
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        choices=[1, 2, 3, 4, 5],
        help="Which fine-tuned classifier folder _1.._5 (only if not using --classifier-weights).",
    )
    parser.add_argument(
        "--pi-tuab-weights",
        action="store_true",
        help="Use PI TUAB-specific files under weightfiles/: tfm_tokenizer_tuab.pth, "
        "tfm_encoder_best_model_tuab.pth",
    )
    parser.add_argument(
        "--tuab-pi-weights-dir",
        type=str,
        default=None,
        metavar="DIR",
        help="Directory containing PI's TUAB TFM files (e.g. /shared/eng/conformal_eeg). "
        "Loads tfm_tokenizer_tuab.pth + tfm_encoder_best_model_tuab.pth from there. "
        "Overrides --pi-tuab-weights and default weightfiles paths unless "
        "--tokenizer-weights / --classifier-weights are set explicitly.",
    )
    parser.add_argument(
        "--tokenizer-weights",
        type=str,
        default=None,
        help="Override tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--classifier-weights",
        type=str,
        default=None,
        help="Override classifier checkpoint path (single .pth file).",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["conformal", "pi_benchmark"],
        default="conformal",
        help="conformal: same as EEG conformal scripts; pi_benchmark: 0.875/0.125 train/val on train partition.",
    )
    parser.add_argument(
        "--split-seed",
        type=int,
        default=42,
        help="RNG seed for patient shuffle (pi_benchmark and conformal).",
    )
    return parser


def _resolve_checkpoints(args):
    """Return (tokenizer_path, classifier_path) chosen from the CLI flags.

    Precedence: explicit --tokenizer-weights / --classifier-weights beat
    --tuab-pi-weights-dir, which beats --pi-tuab-weights, which beats the
    default weightfiles/ layout (per-seed classifier subfolder).
    """
    if args.tuab_pi_weights_dir is not None:
        pi_dir = os.path.expanduser(args.tuab_pi_weights_dir)
        tokenizer_path = os.path.join(pi_dir, "tfm_tokenizer_tuab.pth")
        classifier_path = os.path.join(pi_dir, "tfm_encoder_best_model_tuab.pth")
    elif args.pi_tuab_weights:
        tokenizer_path, classifier_path = PI_TOKENIZER, PI_CLASSIFIER
    else:
        tokenizer_path = DEFAULT_TOKENIZER
        classifier_path = os.path.join(
            CLASSIFIER_WEIGHTS_DIR,
            f"TFM_Tokenizer_multiple_finetuned_on_TUAB_{args.seed}",
            "best_model.pth",
        )

    # Explicit single-file overrides always win.
    if args.tokenizer_weights is not None:
        tokenizer_path = args.tokenizer_weights
    if args.classifier_weights is not None:
        classifier_path = args.classifier_weights
    return tokenizer_path, classifier_path


def main():
    """Run a pure-inference evaluation of TFMTokenizer on a TUAB test split."""
    args = _build_parser().parse_args()
    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"

    tokenizer_path, classifier_path = _resolve_checkpoints(args)

    print(f"Device: {device}")
    print(f"TUAB root: {args.root}")
    print(f"Split mode: {args.split}")
    print(f"Tokenizer weights: {tokenizer_path}")
    print(f"Classifier weights: {classifier_path}")

    start = time.time()
    base_dataset = TUABDataset(root=args.root, subset="both")
    print(f"Dataset loaded in {time.time() - start:.1f}s")

    # Task settings (resample rate, normalization, STFT) mirror those used in
    # the other conformal EEG scripts in this directory.
    start = time.time()
    sample_dataset = base_dataset.set_task(
        EEGAbnormalTUAB(
            resample_rate=200,
            normalization="95th_percentile",
            compute_stft=True,
        ),
        num_workers=16,
    )
    print(f"Task set in {time.time() - start:.1f}s | total samples: {len(sample_dataset)}")

    if args.split == "conformal":
        # Same fixed test partition as the conformal runs.
        _, _, _, test_ds = split_by_patient_conformal_tuh(
            dataset=sample_dataset,
            ratios=[0.6, 0.2, 0.2],
            seed=args.split_seed,
        )
    else:
        # PI benchmark split: 0.875/0.125 train/val on the train partition.
        _, _, test_ds = split_by_patient_tuh(
            sample_dataset,
            [0.875, 0.125],
            seed=args.split_seed,
        )

    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    print(f"Test set size: {len(test_ds)}")

    # No training: load the pre-trained tokenizer and fine-tuned classifier.
    model = TFMTokenizer(dataset=sample_dataset).to(device)
    model.load_pretrained_weights(
        tokenizer_checkpoint_path=tokenizer_path,
        classifier_checkpoint_path=classifier_path,
    )

    trainer = Trainer(
        model=model,
        device=device,
        metrics=[
            "accuracy",
            "balanced_accuracy",
            "f1_weighted",
            "f1_macro",
            "roc_auc_weighted_ovr",
        ],
        enable_logging=False,
    )
    start = time.time()
    results = trainer.evaluate(test_loader)
    print(f"\nEval time: {time.time() - start:.1f}s")
    print("\n=== Test Results ===")
    for metric, value in results.items():
        print(f" {metric}: {value:.4f}")


if __name__ == "__main__":
    main()
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""
2+
Quick inference test: TFMTokenizer on TUEV using local weightfiles/.
3+
4+
Mirrors the PI's benchmark script but uses the weightfiles/ paths already
5+
present in this repo. No training — pure inference to verify weights and
6+
normalization are correct.
7+
8+
Usage:
9+
python examples/conformal_eeg/test_tfm_tuev_inference.py
10+
python examples/conformal_eeg/test_tfm_tuev_inference.py --gpu_id 1
11+
python examples/conformal_eeg/test_tfm_tuev_inference.py --seed 2 # use _2/best_model.pth
12+
"""
13+
14+
import argparse
15+
import os
16+
import time
17+
18+
import torch
19+
20+
from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_patient_conformal_tuh
21+
from pyhealth.models import TFMTokenizer
22+
from pyhealth.tasks import EEGEventsTUEV
23+
from pyhealth.trainer import Trainer
24+
25+
# Default location of the TUEV corpus on the lab server.
TUEV_ROOT = "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/"

# Repository root: this file lives in examples/conformal_eeg/, so three
# dirname() hops from the absolute file path land at the repo root.
_HERE = os.path.abspath(__file__)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_HERE)))
# Pre-trained multi-dataset tokenizer checkpoint.
TOKENIZER_WEIGHTS = os.path.join(REPO_ROOT, "weightfiles", "tfm_tokenizer_last.pth")
# Per-seed fine-tuned classifier checkpoints live in subfolders of this directory.
CLASSIFIER_WEIGHTS_DIR = os.path.join(
    REPO_ROOT, "weightfiles", "TFM_Tokenizer_multiple_finetuned_on_TUEV"
)
32+
33+
34+
def main():
    """Run a pure-inference evaluation of TFMTokenizer on the TUEV test split.

    No training is performed: the pre-trained tokenizer and the fine-tuned
    classifier selected by --seed are loaded from weightfiles/ and evaluated
    on the fixed TUH eval partition.
    """
    # Consistency fix with the sibling TUAB script: expose the dataset root as
    # --root (default unchanged: TUEV_ROOT) instead of hard-coding it, and give
    # the parser a description for --help.
    parser = argparse.ArgumentParser(description="TFM TUEV inference sanity check")
    parser.add_argument(
        "--root",
        type=str,
        default=TUEV_ROOT,
        help="Path to TUEV edf/ directory.",
    )
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument(
        "--seed", type=int, default=1, choices=[1, 2, 3, 4, 5],
        help="Which fine-tuned classifier to use (1-5)."
    )
    args = parser.parse_args()
    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"

    # Per-seed classifier checkpoint: {dir}/TFM_Tokenizer_multiple_finetuned_on_TUEV_{seed}/best_model.pth
    classifier_weights = os.path.join(
        CLASSIFIER_WEIGHTS_DIR,
        f"TFM_Tokenizer_multiple_finetuned_on_TUEV_{args.seed}",
        "best_model.pth",
    )

    print(f"Device: {device}")
    print(f"Tokenizer weights: {TOKENIZER_WEIGHTS}")
    print(f"Classifier weights: {classifier_weights}")

    # ------------------------------------------------------------------ #
    # STEP 1: Load dataset
    # ------------------------------------------------------------------ #
    t0 = time.time()
    base_dataset = TUEVDataset(root=args.root, subset="both")
    print(f"Dataset loaded in {time.time() - t0:.1f}s")

    # ------------------------------------------------------------------ #
    # STEP 2: Set task — normalization="95th_percentile" matches training
    # ------------------------------------------------------------------ #
    t0 = time.time()
    sample_dataset = base_dataset.set_task(
        EEGEventsTUEV(
            resample_rate=200,
            normalization="95th_percentile",
            compute_stft=True,
        )
    )
    print(f"Task set in {time.time() - t0:.1f}s | total samples: {len(sample_dataset)}")

    # ------------------------------------------------------------------ #
    # STEP 3: Extract fixed test set (TUH eval partition)
    # ------------------------------------------------------------------ #
    _, _, _, test_ds = split_by_patient_conformal_tuh(
        dataset=sample_dataset,
        ratios=[0.6, 0.2, 0.2],
        seed=42,
    )
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    print(f"Test set size: {len(test_ds)}")

    # ------------------------------------------------------------------ #
    # STEP 4: Load TFMTokenizer with pre-trained weights (no training)
    # ------------------------------------------------------------------ #
    model = TFMTokenizer(dataset=sample_dataset).to(device)
    model.load_pretrained_weights(
        tokenizer_checkpoint_path=TOKENIZER_WEIGHTS,
        classifier_checkpoint_path=classifier_weights,
    )

    # ------------------------------------------------------------------ #
    # STEP 5: Evaluate
    # ------------------------------------------------------------------ #
    trainer = Trainer(
        model=model,
        device=device,
        metrics=["accuracy", "f1_weighted", "f1_macro"],
        enable_logging=False,
    )
    t0 = time.time()
    results = trainer.evaluate(test_loader)
    print(f"\nEval time: {time.time() - t0:.1f}s")
    print("\n=== Test Results ===")
    for metric, value in results.items():
        print(f" {metric}: {value:.4f}")


if __name__ == "__main__":
    main()

examples/conformal_eeg/tuab_conventional_conformal.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,16 @@ def parse_args() -> argparse.Namespace:
123123
parser.add_argument(
124124
"--weights-dir",
125125
type=str,
126-
default="weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB",
127-
help="Root folder of fine-tuned TFM classifier checkpoints (only with --model tfm).",
126+
default="/shared/eng/conformal_eeg",
127+
help="Root folder of TFM classifier checkpoints (only with --model tfm). "
128+
"If the directory contains tfm_encoder_best_model_tuab.pth directly, "
129+
"that single checkpoint is used for all seeds (PI TUAB setup). "
130+
"Otherwise expects per-seed subdirs {base}_1..N/best_model.pth.",
128131
)
129132
parser.add_argument(
130133
"--tokenizer-weights",
131134
type=str,
132-
default="weightfiles/tfm_tokenizer_last.pth",
135+
default="/shared/eng/conformal_eeg/tfm_tokenizer_tuab.pth",
133136
help="Path to the pre-trained TFM tokenizer weights (only with --model tfm).",
134137
)
135138
parser.add_argument(
@@ -152,9 +155,19 @@ def _do_split(dataset, ratios, seed, split_type):
152155

153156

154157
def _load_tfm_weights(model, args, run_idx: int) -> None:
155-
"""Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based)."""
156-
base = os.path.basename(args.weights_dir)
157-
classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth")
158+
"""Load pre-trained tokenizer + fine-tuned classifier for run_idx (0-based).
159+
160+
Supports two layouts:
161+
- Single classifier (PI TUAB setup): weights_dir/tfm_encoder_best_model_tuab.pth
162+
Used for all seeds — only the data split varies across runs.
163+
- Per-seed subdirs: weights_dir/{base}_{run_idx+1}/best_model.pth
164+
"""
165+
single = os.path.join(args.weights_dir, "tfm_encoder_best_model_tuab.pth")
166+
if os.path.isfile(single):
167+
classifier_path = single
168+
else:
169+
base = os.path.basename(args.weights_dir)
170+
classifier_path = os.path.join(args.weights_dir, f"{base}_{run_idx + 1}", "best_model.pth")
158171
print(f" Loading TFM weights (run {run_idx + 1}): {classifier_path}")
159172
model.load_pretrained_weights(
160173
tokenizer_checkpoint_path=args.tokenizer_weights,
@@ -283,7 +296,7 @@ def _print_multi_seed_summary(
283296
n_runs = len(all_metrics)
284297

285298
print("\n" + "=" * 80)
286-
print("Per-run LABEL results (fixed test set = TUH eval partition)")
299+
print(f"Per-run results — alpha={alpha} (LABEL, fixed test set = TUH eval partition)")
287300
print("=" * 80)
288301
print(f" {'Run':<4} {'Seed':<6} {'Accuracy':<10} {'ROC-AUC':<10} {'F1':<8} "
289302
f"{'Coverage':<10} {'Miscoverage':<12} {'Avg set size':<12}")
@@ -295,7 +308,8 @@ def _print_multi_seed_summary(
295308
f"{m['miscoverage']:<12.4f} {m['avg_set_size']:<12.2f}")
296309

297310
print("\n" + "=" * 80)
298-
print(f"LABEL summary (mean \u00b1 std over {n_runs} runs, fixed test set)")
311+
print(f"Summary — alpha={alpha} (mean \u00b1 std over {n_runs} runs, fixed test set)")
312+
print(" Method: LABEL")
299313
print("=" * 80)
300314
print(f" Accuracy: {accs.mean():.4f} \u00b1 {accs.std():.4f}")
301315
print(f" ROC-AUC: {roc_aucs.mean():.4f} \u00b1 {roc_aucs.std():.4f}")
@@ -334,7 +348,7 @@ def _main(args: argparse.Namespace) -> None:
334348
print("STEP 1: Load TUAB + build task dataset (shared across all seeds)")
335349
print("=" * 80)
336350
dataset = TUABDataset(root=str(root), subset=args.subset, dev=args.quick_test)
337-
sample_dataset = dataset.set_task(EEGAbnormalTUAB())
351+
sample_dataset = dataset.set_task(EEGAbnormalTUAB(normalization="95th_percentile"), num_workers=16)
338352
if args.quick_test and len(sample_dataset) > quick_test_max_samples:
339353
sample_dataset = sample_dataset.subset(range(quick_test_max_samples))
340354
print(f"Capped to {quick_test_max_samples} samples for quick-test.")

0 commit comments

Comments
 (0)