Skip to content

Commit b1aef62

Browse files
authored
Merge pull request #435 from mdboom/config-handling
REFACTOR: Load config with dataclasses, earlier validation
2 parents 77e6394 + 71d0581 commit b1aef62

16 files changed

Lines changed: 343 additions & 250 deletions

File tree

bench_runner/bases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def get_bases() -> list[str]:
    """Return the base CPython versions every benchmark run is compared to."""
    cfg = config.get_config()
    return cfg.bases.versions
1616

1717

1818
@functools.cache

bench_runner/config.py

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,123 @@
22
Handles the loading of the bench_runner.toml configuration file.
33
"""
44

5+
import dataclasses
56
import functools
67
from pathlib import Path
78
import tomllib
8-
from typing import Any
99

1010

11-
from . import runners
11+
from . import flags as mflags
12+
from . import plot as mplot
13+
from . import runners as mrunners
1214
from .util import PathLike
1315

1416

17+
@dataclasses.dataclass
class Bases:
    """The `bases` section of `bench_runner.toml`."""

    # The base versions to compare every benchmark run to.
    # Should be a fully-specified version, e.g. "3.13.0".
    versions: list[str]
    # List of configuration flags that are compared against the default build of
    # its commit merge base.
    compare_to_default: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        if len(self.versions) == 0:
            raise RuntimeError(
                "No `bases.versions` are defined in `bench_runner.toml`. "
            )
        # Keep the normalized result: normalize_flags() validates, sorts and
        # de-duplicates, but the original code discarded its return value, so
        # compare_to_default was validated without actually being normalized
        # (Weekly.__post_init__ assigns the result; this now matches it).
        self.compare_to_default = mflags.normalize_flags(self.compare_to_default)
32+
33+
34+
@dataclasses.dataclass
class Notify:
    """The `notify` section of `bench_runner.toml`."""

    # Number of the Github issue used to send notification emails.
    # 0 (the default) disables notifications.
    notification_issue: int = 0
38+
39+
40+
@dataclasses.dataclass
class PublishMirror:
    """The `publish_mirror` section of `bench_runner.toml`."""

    # When True, publishing to the mirror is skipped.
    skip: bool = False
44+
45+
46+
@dataclasses.dataclass
class Benchmarks:
    """The `benchmarks` section of `bench_runner.toml`."""

    # Benchmarks to exclude from plots.
    excluded_benchmarks: set[str] = dataclasses.field(default_factory=set)

    def __post_init__(self):
        # TOML arrays are loaded as lists; coerce to a set for membership tests.
        self.excluded_benchmarks = set(self.excluded_benchmarks)
53+
54+
55+
@dataclasses.dataclass
class Weekly:
    """One entry of the `weekly` section of `bench_runner.toml`."""

    # Configuration flags for this weekly run (validated by normalize_flags).
    flags: list[str] = dataclasses.field(default_factory=list)
    # Runner nicknames this weekly run should execute on.
    runners: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Validate and canonicalize (sort, de-duplicate) the flags up front.
        self.flags = mflags.normalize_flags(self.flags)
62+
63+
64+
@dataclasses.dataclass
class Config:
    """Top-level, validated contents of `bench_runner.toml`.

    Raw TOML tables arrive as plain dicts; `__post_init__` converts each
    section into its typed dataclass and performs early validation of the
    required entries.
    """

    # The `bases` section (required).
    bases: Bases
    # Mapping of runner nickname -> runner definition (required, non-empty).
    runners: dict[str, mrunners.Runner]
    # The `publish_mirror` section.
    publish_mirror: PublishMirror = dataclasses.field(default_factory=PublishMirror)
    # The `benchmarks` section.
    benchmarks: Benchmarks = dataclasses.field(default_factory=Benchmarks)
    # The `notify` section.
    notify: Notify = dataclasses.field(default_factory=Notify)
    # Optional plot configurations; None when the section is absent.
    longitudinal_plot: mplot.LongitudinalPlotConfig | None = None
    flag_effect_plot: mplot.FlagEffectPlotConfig | None = None
    benchmark_longitudinal_plot: mplot.BenchmarkLongitudinalPlotConfig | None = None
    # Named weekly-run configurations; defaults to one covering all runners.
    weekly: dict[str, Weekly] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        # Every conversion below is guarded with isinstance(..., dict) so that
        # a Config built from already-typed sections (e.g. via
        # dataclasses.replace) keeps working; previously `bases`, `runners`
        # and `weekly` were splatted unconditionally, unlike the other
        # sections.
        if isinstance(self.bases, dict):
            self.bases = Bases(**self.bases)  # pyright: ignore[reportCallIssue]
        if len(self.runners) == 0:
            raise RuntimeError(
                "No runners are defined in `bench_runner.toml`. "
                "Please set up some runners first."
            )
        self.runners = {
            name: (
                runner
                if isinstance(runner, mrunners.Runner)
                else mrunners.Runner(
                    nickname=name, **runner  # pyright: ignore[reportCallIssue]
                )
            )
            for name, runner in self.runners.items()
        }
        if isinstance(self.publish_mirror, dict):
            self.publish_mirror = PublishMirror(**self.publish_mirror)
        if isinstance(self.benchmarks, dict):
            self.benchmarks = Benchmarks(**self.benchmarks)
        if isinstance(self.notify, dict):
            self.notify = Notify(**self.notify)
        if isinstance(self.longitudinal_plot, dict):
            self.longitudinal_plot = mplot.LongitudinalPlotConfig(
                **self.longitudinal_plot
            )
        if isinstance(self.flag_effect_plot, dict):
            self.flag_effect_plot = mplot.FlagEffectPlotConfig(**self.flag_effect_plot)
        if isinstance(self.benchmark_longitudinal_plot, dict):
            self.benchmark_longitudinal_plot = mplot.BenchmarkLongitudinalPlotConfig(
                **self.benchmark_longitudinal_plot
            )
        if len(self.weekly) == 0:
            # No explicit weekly section: run the default weekly job on every
            # configured runner.
            self.weekly = {"default": Weekly(runners=list(self.runners.keys()))}
        else:
            self.weekly = {
                name: (
                    weekly
                    if isinstance(weekly, Weekly)
                    else Weekly(**weekly)  # pyright: ignore[reportCallIssue]
                )
                for name, weekly in self.weekly.items()
            }
112+
113+
15114
@functools.cache
def get_config(filepath: PathLike | None = None) -> Config:
    """Load, parse, and validate the bench_runner configuration.

    The result is cached per *filepath*; when *filepath* is None the file
    `bench_runner.toml` in the current working directory is used.
    """
    path = Path("bench_runner.toml") if filepath is None else Path(filepath)

    with path.open("rb") as fd:
        raw = tomllib.load(fd)

    return Config(**raw)

bench_runner/flags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,11 @@ def flags_to_human(flags: list[str]) -> Iterable[str]:
5656
if flag_descr.name == flag:
5757
yield flag_descr.short_name
5858
break
59+
60+
61+
def normalize_flags(flags: list[str]) -> list[str]:
    """Return *flags* sorted and de-duplicated.

    Raises ValueError for the first (in sorted order) flag that is not a
    known flag in FLAG_MAPPING.
    """
    # Hoist the valid-flag set out of the loop: `x in dict.values()` is a
    # linear scan per flag, a set gives O(1) membership tests.
    valid = set(FLAG_MAPPING.values())
    result = sorted(set(flags))
    for flag in result:
        if flag not in valid:
            raise ValueError(f"Invalid flag {flag}")
    return result

bench_runner/gh.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from . import config
1313
from . import flags as mflags
14-
from . import runners
1514

1615

1716
def get_machines():
    """Return the names of all available runners, plus the special "all" entry."""
    cfg = config.get_config()
    names = [runner.name for runner in cfg.runners.values() if runner.available]
    names.append("all")
    return names
1919

2020

2121
def _get_flags(d: Mapping[str, Any]) -> list[str]:
@@ -71,8 +71,8 @@ def benchmark(
7171

7272

7373
def send_notification(body):
74-
conf = config.get_bench_runner_config()
75-
notification_issue = conf.get("notify", {}).get("notification_issue", 0)
74+
cfg = config.get_config()
75+
notification_issue = cfg.notify.notification_issue
7676

7777
if notification_issue == 0:
7878
print("Not sending Github notification.")

bench_runner/hpt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from numpy.typing import NDArray
3636

3737

38-
from . import util
38+
from . import config
3939
from .util import PathLike
4040

4141
ACC_MAXSU = 2
@@ -68,8 +68,9 @@ def load_data(data: Mapping[str, Any]) -> dict[str, NDArray[np.float64]]:
6868
def create_matrices(
    a: Mapping[str, NDArray[np.float64]], b: Mapping[str, NDArray[np.float64]]
) -> tuple[dict[str, NDArray[np.float64]], dict[str, NDArray[np.float64]]]:
    """Restrict *a* and *b* to their shared, non-excluded benchmarks.

    Returns a pair of dicts keyed by the same sorted benchmark names, so the
    two result matrices are aligned entry-for-entry.
    """
    cfg = config.get_config()
    excluded = cfg.benchmarks.excluded_benchmarks
    # Only benchmarks present in both runs are comparable.  (sorted() takes
    # any iterable; the redundant list() wrapper is gone.)
    shared = set(a.keys()) & set(b.keys())
    benchmarks = sorted(bm for bm in shared if bm not in excluded)
    return {bm: a[bm] for bm in benchmarks}, {bm: b[bm] for bm in benchmarks}
7576

0 commit comments

Comments
 (0)