|
| 1 | +""" |
| 2 | +Quick inference test: TFMTokenizer on TUAB using local weightfiles/. |
| 3 | +
|
| 4 | +Two weight setups (ask your PI which matches their training): |
| 5 | +
|
| 6 | + 1) Default (matches conformal example scripts): |
| 7 | + - tokenizer: weightfiles/tfm_tokenizer_last.pth (multi-dataset tokenizer) |
| 8 | + - classifier: weightfiles/TFM_Tokenizer_multiple_finetuned_on_TUAB/.../best_model.pth |
| 9 | +
|
| 10 | + 2) PI benchmark TUAB-specific files (place in weightfiles/): |
| 11 | + - tokenizer: tfm_tokenizer_tuab.pth |
| 12 | + - classifier: tfm_encoder_best_model_tuab.pth |
| 13 | + Use: --pi-tuab-weights |
| 14 | +
|
| 15 | +Split modes: |
| 16 | + - conformal (default): same test set as conformal runs (TUH eval via patient conformal split). |
| 17 | + - pi_benchmark: train/val ratio [0.875, 0.125] on train partition; test = TUH eval (same patients as official eval). |
| 18 | +
|
| 19 | +Usage: |
| 20 | + python examples/conformal_eeg/test_tfm_tuab_inference.py |
| 21 | + python examples/conformal_eeg/test_tfm_tuab_inference.py --pi-tuab-weights |
| 22 | + python examples/conformal_eeg/test_tfm_tuab_inference.py \\ |
| 23 | + --tuab-pi-weights-dir /shared/eng/conformal_eeg --split pi_benchmark |
| 24 | + python examples/conformal_eeg/test_tfm_tuab_inference.py --tokenizer-weights PATH --classifier-weights PATH |
| 25 | +""" |
| 26 | + |
| 27 | +import argparse |
| 28 | +import os |
| 29 | +import time |
| 30 | + |
| 31 | +import torch |
| 32 | + |
| 33 | +from pyhealth.datasets import ( |
| 34 | + TUABDataset, |
| 35 | + get_dataloader, |
| 36 | + split_by_patient_conformal_tuh, |
| 37 | + split_by_patient_tuh, |
| 38 | +) |
| 39 | +from pyhealth.models import TFMTokenizer |
| 40 | +from pyhealth.tasks import EEGAbnormalTUAB |
| 41 | +from pyhealth.trainer import Trainer |
| 42 | + |
# Repository root, resolved relative to this file's location:
# examples/conformal_eeg/<this file> -> three dirname() hops up.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default directory holding all pretrained checkpoint files.
WEIGHTFILES = os.path.join(REPO_ROOT, "weightfiles")
# Multi-dataset tokenizer checkpoint (matches the conformal example scripts).
DEFAULT_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_last.pth")
# Parent directory of the per-seed fine-tuned classifier folders
# (TFM_Tokenizer_multiple_finetuned_on_TUAB_1 .. _5, selected via --seed).
CLASSIFIER_WEIGHTS_DIR = os.path.join(
    WEIGHTFILES, "TFM_Tokenizer_multiple_finetuned_on_TUAB"
)
# PI's TUAB-specific tokenizer/classifier checkpoints (used with --pi-tuab-weights).
PI_TOKENIZER = os.path.join(WEIGHTFILES, "tfm_tokenizer_tuab.pth")
PI_CLASSIFIER = os.path.join(WEIGHTFILES, "tfm_encoder_best_model_tuab.pth")
| 51 | + |
| 52 | + |
def _build_arg_parser():
    """Build the command-line parser for the TUAB inference sanity check."""
    parser = argparse.ArgumentParser(description="TFM TUAB inference sanity check")
    parser.add_argument(
        "--root",
        type=str,
        default="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf",
        help="Path to TUAB edf/ directory.",
    )
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        choices=[1, 2, 3, 4, 5],
        help="Which fine-tuned classifier folder _1.._5 (only if not using --classifier-weights).",
    )
    parser.add_argument(
        "--pi-tuab-weights",
        action="store_true",
        help="Use PI TUAB-specific files under weightfiles/: tfm_tokenizer_tuab.pth, "
        "tfm_encoder_best_model_tuab.pth",
    )
    parser.add_argument(
        "--tuab-pi-weights-dir",
        type=str,
        default=None,
        metavar="DIR",
        help="Directory containing PI's TUAB TFM files (e.g. /shared/eng/conformal_eeg). "
        "Loads tfm_tokenizer_tuab.pth + tfm_encoder_best_model_tuab.pth from there. "
        "Overrides --pi-tuab-weights and default weightfiles paths unless "
        "--tokenizer-weights / --classifier-weights are set explicitly.",
    )
    parser.add_argument(
        "--tokenizer-weights",
        type=str,
        default=None,
        help="Override tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--classifier-weights",
        type=str,
        default=None,
        help="Override classifier checkpoint path (single .pth file).",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["conformal", "pi_benchmark"],
        default="conformal",
        help="conformal: same as EEG conformal scripts; pi_benchmark: 0.875/0.125 train/val on train partition.",
    )
    parser.add_argument(
        "--split-seed",
        type=int,
        default=42,
        help="RNG seed for patient shuffle (pi_benchmark and conformal).",
    )
    return parser


def _resolve_weight_paths(args):
    """Return (tokenizer_path, classifier_path) from parsed CLI args.

    Precedence (highest wins):
      1. Explicit --tokenizer-weights / --classifier-weights overrides.
      2. --tuab-pi-weights-dir: PI TUAB files loaded from that directory.
      3. --pi-tuab-weights: PI TUAB files under the repo's weightfiles/.
      4. Default: multi-dataset tokenizer + per-seed fine-tuned classifier.
    """
    if args.tuab_pi_weights_dir is not None:
        base = os.path.expanduser(args.tuab_pi_weights_dir)
        tok = os.path.join(base, "tfm_tokenizer_tuab.pth")
        cls_path = os.path.join(base, "tfm_encoder_best_model_tuab.pth")
    elif args.pi_tuab_weights:
        tok = PI_TOKENIZER
        cls_path = PI_CLASSIFIER
    else:
        tok = DEFAULT_TOKENIZER
        cls_path = os.path.join(
            CLASSIFIER_WEIGHTS_DIR,
            f"TFM_Tokenizer_multiple_finetuned_on_TUAB_{args.seed}",
            "best_model.pth",
        )

    # Per-file overrides beat every other source.  expanduser() is applied here
    # too, for consistency with the --tuab-pi-weights-dir handling above.
    if args.tokenizer_weights is not None:
        tok = os.path.expanduser(args.tokenizer_weights)
    if args.classifier_weights is not None:
        cls_path = os.path.expanduser(args.classifier_weights)
    return tok, cls_path


def main():
    """Run a TFMTokenizer inference sanity check on the TUAB test split.

    Loads the TUAB dataset, applies the EEGAbnormalTUAB task, builds the test
    split (conformal or pi_benchmark mode), loads pretrained tokenizer +
    classifier weights, and prints evaluation metrics on the test set.
    """
    args = _build_arg_parser().parse_args()
    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"
    tok, cls_path = _resolve_weight_paths(args)

    print(f"Device: {device}")
    print(f"TUAB root: {args.root}")
    print(f"Split mode: {args.split}")
    print(f"Tokenizer weights: {tok}")
    print(f"Classifier weights: {cls_path}")

    t0 = time.time()
    base_dataset = TUABDataset(root=args.root, subset="both")
    print(f"Dataset loaded in {time.time() - t0:.1f}s")

    t0 = time.time()
    sample_dataset = base_dataset.set_task(
        EEGAbnormalTUAB(
            resample_rate=200,
            normalization="95th_percentile",
            compute_stft=True,
        ),
        num_workers=16,
    )
    print(f"Task set in {time.time() - t0:.1f}s | total samples: {len(sample_dataset)}")

    # Both split helpers return the test partition last; only the test set is
    # needed for this inference-only check.
    if args.split == "conformal":
        _, _, _, test_ds = split_by_patient_conformal_tuh(
            dataset=sample_dataset,
            ratios=[0.6, 0.2, 0.2],
            seed=args.split_seed,
        )
    else:
        _, _, test_ds = split_by_patient_tuh(
            sample_dataset,
            [0.875, 0.125],
            seed=args.split_seed,
        )

    # shuffle=False keeps evaluation deterministic and order-stable.
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    print(f"Test set size: {len(test_ds)}")

    model = TFMTokenizer(dataset=sample_dataset).to(device)
    model.load_pretrained_weights(
        tokenizer_checkpoint_path=tok,
        classifier_checkpoint_path=cls_path,
    )

    trainer = Trainer(
        model=model,
        device=device,
        metrics=[
            "accuracy",
            "balanced_accuracy",
            "f1_weighted",
            "f1_macro",
            "roc_auc_weighted_ovr",
        ],
        enable_logging=False,
    )
    t0 = time.time()
    results = trainer.evaluate(test_loader)
    print(f"\nEval time: {time.time() - t0:.1f}s")
    print("\n=== Test Results ===")
    for metric, value in results.items():
        print(f"  {metric}: {value:.4f}")
| 194 | + |
| 195 | + |
# Script entry point: only run the sanity check when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
0 commit comments