-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmain.py
More file actions
113 lines (88 loc) · 4.24 KB
/
main.py
File metadata and controls
113 lines (88 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from marl_models.base_model import MARLModel
from environment.env import Env
from marl_models.utils import get_model, load_step_count
from train import train_on_policy, train_off_policy, train_random
from test import test_model
from utils.logger import Logger, load_configs
from utils.plot_logs import generate_plots
import config
import torch
import numpy as np
import argparse
import warnings
import os
from datetime import datetime
# Module-level side effect: sets PyTorch's float32 matmul precision mode to "high"
# as soon as this file is loaded, before any training/testing entry point runs.
torch.set_float32_matmul_precision("high")
def start_training(args: argparse.Namespace) -> None:
    """Run a training session (fresh or resumed) and plot the resulting logs.

    Args:
        args: Parsed command-line namespace. This function reads
            ``num_episodes``, ``resume_path`` and ``config_path``.

    Raises:
        ValueError: If ``resume_path`` is given without ``config_path``.
    """
    timestamp: str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print(f"\n🚀 Training started at {timestamp} for {args.num_episodes} episodes")

    resume_training: bool = args.resume_path is not None
    if resume_training:
        if args.config_path is None:
            raise ValueError("If --resume_path is provided, --config_path must also be provided.")
        load_configs(args.config_path)  # Resume training with old config
    else:  # Fresh training
        if args.config_path is not None:
            warnings.warn("--config_path is ignored during training unless --resume_path is also provided.")

    # Seed only after the config is (possibly) reloaded, so config.SEED is the value in effect.
    np.random.seed(config.SEED)
    torch.manual_seed(config.SEED)

    env: Env = Env()
    model_name: str = config.MODEL.lower()
    model: MARLModel = get_model(model_name)

    model_log_dir: str = f"train_logs/{model_name}"
    # exist_ok=True avoids the check-then-create race of `os.path.exists` + `os.makedirs`
    # (e.g. two runs started at the same time would otherwise crash with FileExistsError).
    os.makedirs(model_log_dir, exist_ok=True)
    logger: Logger = Logger(model_log_dir, timestamp)
    if not resume_training:
        logger.log_configs()  # Save config for fresh training

    total_step_count: int = 0  # for off policy models
    if resume_training:
        model.load(args.resume_path)
        total_step_count = load_step_count(args.resume_path)
        print(f"📥 Models loaded successfully from {args.resume_path}")
        print(f"📂 Resumed training from: {args.resume_path}\n")

    # Pick the training loop that matches the algorithm family.
    if model_name in ["maddpg", "attention_maddpg", "matd3", "attention_matd3", "masac", "attention_masac"]:
        train_off_policy(env, model, logger, args.num_episodes, total_step_count)
    elif model_name in ["mappo", "attention_mappo"]:
        train_on_policy(env, model, logger, args.num_episodes)
    else:  # "random"
        # NOTE(review): any model name not listed above also lands here — confirm
        # that get_model() rejects unknown names before this point.
        train_random(env, model, logger, args.num_episodes)

    print("✅ Training Completed!\n")
    print("📊 Generating plots...")
    # The log path follows the "<log_dir>/log_data_<timestamp>.json" pattern;
    # presumably that file is written by Logger — verify against utils.logger.
    generate_plots(f"{model_log_dir}/log_data_{timestamp}.json", f"train_plots/{model_name}/", "train", timestamp)
def start_testing(args: argparse.Namespace) -> None:
    """Evaluate a saved model and plot the resulting logs.

    Args:
        args: Parsed command-line namespace. This function reads
            ``num_episodes``, ``config_path`` and ``model_path``.
    """
    timestamp: str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print(f"\n🚀 Testing started at {timestamp} for {args.num_episodes} episodes")

    # Load the saved config first so config.SEED / config.MODEL below reflect it.
    load_configs(args.config_path)
    np.random.seed(config.SEED)
    torch.manual_seed(config.SEED)

    env: Env = Env()
    model_name: str = config.MODEL.lower()
    model: MARLModel = get_model(model_name)

    model_log_dir: str = f"test_logs/{model_name}"
    # exist_ok=True avoids the check-then-create race of `os.path.exists` + `os.makedirs`
    # (e.g. two runs started at the same time would otherwise crash with FileExistsError).
    os.makedirs(model_log_dir, exist_ok=True)
    logger: Logger = Logger(model_log_dir, timestamp)

    model.load(args.model_path)
    print(f"📥 Models loaded successfully from {args.model_path}")

    test_model(env, model, logger, args.num_episodes)

    print("✅ Testing Completed!\n")
    print("📊 Generating plots...")
    # The log path follows the "<log_dir>/log_data_<timestamp>.json" pattern;
    # presumably that file is written by Logger — verify against utils.logger.
    generate_plots(f"{model_log_dir}/log_data_{timestamp}.json", f"test_plots/{model_name}/", "test", timestamp, smoothing_window=2)
if __name__ == "__main__":
    # Command-line entry point: `train` and `test` sub-commands share --num_episodes.

    # Options common to every sub-command live on a help-less parent parser.
    shared_options = argparse.ArgumentParser(add_help=False)
    shared_options.add_argument("--num_episodes", type=int, required=True)

    cli = argparse.ArgumentParser()
    mode_parsers = cli.add_subparsers(dest="mode", required=True)

    # train: both paths are optional at parse time.
    train_cli = mode_parsers.add_parser("train", parents=[shared_options])
    train_cli.add_argument("--resume_path", type=str, default=None)
    train_cli.add_argument("--config_path", type=str, default=None)

    # test: both the model and its config are mandatory.
    test_cli = mode_parsers.add_parser("test", parents=[shared_options])
    test_cli.add_argument("--model_path", type=str, required=True)
    test_cli.add_argument("--config_path", type=str, required=True)

    args = cli.parse_args()

    # Dispatch on the chosen sub-command; an unmatched mode runs nothing.
    entry_points = {"train": start_training, "test": start_testing}
    run_mode = entry_points.get(args.mode)
    if run_mode is not None:
        run_mode(args)

    print("🎉 All done!")