|
2 | 2 | Handles the loading of the bench_runner.toml configuration file. |
3 | 3 | """ |
4 | 4 |
|
| 5 | +import dataclasses |
5 | 6 | import functools |
6 | 7 | from pathlib import Path |
7 | 8 | import tomllib |
8 | | -from typing import Any |
9 | 9 |
|
10 | 10 |
|
11 | | -from . import runners |
| 11 | +from . import flags as mflags |
| 12 | +from . import plot as mplot |
| 13 | +from . import runners as mrunners |
12 | 14 | from .util import PathLike |
13 | 15 |
|
14 | 16 |
|
@dataclasses.dataclass
class Bases:
    """The base versions that every benchmark run is compared against."""

    # The base versions to compare every benchmark run to.
    # Should be a full-specified version, e.g. "3.13.0".
    versions: list[str]
    # List of configuration flags that are compared against the default build of
    # its commit merge base.
    compare_to_default: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # A config with no base versions is unusable -- fail early at load time.
        if len(self.versions) == 0:
            raise RuntimeError(
                "No `bases.versions` are defined in `bench_runner.toml`. "
            )
        # Fix: the original discarded the return value of normalize_flags(),
        # leaving the flags un-normalized. Weekly.__post_init__ assigns the
        # result of the same helper, so this must as well.
        self.compare_to_default = mflags.normalize_flags(self.compare_to_default)
| 33 | + |
@dataclasses.dataclass
class Notify:
    """Where benchmark notification emails are sent."""

    # The Github issue to use to send notification emails.
    notification_issue: int = dataclasses.field(default=0)
| 39 | + |
@dataclasses.dataclass
class PublishMirror:
    """Settings controlling publication to the results mirror."""

    # Whether to skip publishing to the mirror.
    skip: bool = dataclasses.field(default=False)
| 45 | + |
@dataclasses.dataclass
class Benchmarks:
    """Benchmark-selection settings."""

    # Benchmarks to exclude from plots.
    excluded_benchmarks: set[str] = dataclasses.field(default_factory=set)

    def __post_init__(self):
        # TOML has no native set type, so this field typically arrives as a
        # list; always materialize a fresh set from whatever was supplied.
        self.excluded_benchmarks = {name for name in self.excluded_benchmarks}
| 54 | + |
@dataclasses.dataclass
class Weekly:
    """A single entry in the weekly benchmarking schedule."""

    # Configuration flags applied to this weekly run.
    flags: list[str] = dataclasses.field(default_factory=list)
    # Nicknames of the runners this weekly entry should execute on.
    runners: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Canonicalize flag spellings via the shared helper (which returns
        # the normalized list).
        self.flags = mflags.normalize_flags(self.flags)
| 63 | + |
@dataclasses.dataclass
class Config:
    """Top-level contents of `bench_runner.toml`.

    Raw TOML tables arrive as plain dicts; `__post_init__` converts each
    section into its typed dataclass.
    """

    # Required sections.
    bases: Bases
    runners: dict[str, mrunners.Runner]
    # Optional sections; defaults correspond to an absent/empty TOML table.
    publish_mirror: PublishMirror = dataclasses.field(default_factory=PublishMirror)
    benchmarks: Benchmarks = dataclasses.field(default_factory=Benchmarks)
    notify: Notify = dataclasses.field(default_factory=Notify)
    longitudinal_plot: mplot.LongitudinalPlotConfig | None = None
    flag_effect_plot: mplot.FlagEffectPlotConfig | None = None
    benchmark_longitudinal_plot: mplot.BenchmarkLongitudinalPlotConfig | None = None
    weekly: dict[str, Weekly] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        # Convert raw TOML dicts into their typed dataclasses.  Every
        # conversion is guarded with isinstance() so already-typed values
        # (e.g. a Config constructed programmatically) pass through
        # unchanged; the original converted `bases`, `runners`, and `weekly`
        # unconditionally, which crashed on typed input and was inconsistent
        # with the other sections.
        if isinstance(self.bases, dict):
            self.bases = Bases(**self.bases)  # pyright: ignore[reportCallIssue]
        if len(self.runners) == 0:
            raise RuntimeError(
                "No runners are defined in `bench_runner.toml`. "
                "Please set up some runners first."
            )
        # Each runner's table key doubles as its nickname.
        self.runners = {
            name: (
                runner
                if isinstance(runner, mrunners.Runner)
                else mrunners.Runner(
                    nickname=name, **runner  # pyright: ignore[reportCallIssue]
                )
            )
            for name, runner in self.runners.items()
        }
        if isinstance(self.publish_mirror, dict):
            self.publish_mirror = PublishMirror(**self.publish_mirror)
        if isinstance(self.benchmarks, dict):
            self.benchmarks = Benchmarks(**self.benchmarks)
        if isinstance(self.notify, dict):
            self.notify = Notify(**self.notify)
        if isinstance(self.longitudinal_plot, dict):
            self.longitudinal_plot = mplot.LongitudinalPlotConfig(
                **self.longitudinal_plot
            )
        if isinstance(self.flag_effect_plot, dict):
            self.flag_effect_plot = mplot.FlagEffectPlotConfig(**self.flag_effect_plot)
        if isinstance(self.benchmark_longitudinal_plot, dict):
            self.benchmark_longitudinal_plot = mplot.BenchmarkLongitudinalPlotConfig(
                **self.benchmark_longitudinal_plot
            )
        if len(self.weekly) == 0:
            # No schedule configured: default to one entry that runs on
            # every configured runner.
            self.weekly = {"default": Weekly(runners=list(self.runners.keys()))}
        else:
            self.weekly = {
                name: (
                    weekly
                    if isinstance(weekly, Weekly)
                    else Weekly(**weekly)  # pyright: ignore[reportCallIssue]
                )
                for name, weekly in self.weekly.items()
            }
| 112 | + |
| 113 | + |
@functools.cache
def get_config(filepath: PathLike | None = None) -> Config:
    """Load `bench_runner.toml` and parse it into a `Config`.

    When *filepath* is None, "bench_runner.toml" in the current working
    directory is used.  Results are cached per *filepath* by
    `functools.cache`, so the file is read at most once per distinct path
    argument.
    """
    filepath = Path("bench_runner.toml") if filepath is None else Path(filepath)

    with filepath.open("rb") as fd:
        content = tomllib.load(fd)

    return Config(**content)
0 commit comments