Skip to content

Commit 4e1c521

Browse files
cache ase calculator
1 parent ffe658e commit 4e1c521

6 files changed

Lines changed: 28 additions & 34 deletions

File tree

src/atomate2/ase/jobs.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ class AseMaker(Maker, ABC):
5454
class EMTStaticMaker(AseMaker):
5555
name: str = "EMT static maker"
5656
57-
@property
58-
def calculator(self):
57+
def _get_calculator(self):
5958
return EMT()
6059
```
6160
@@ -95,6 +94,10 @@ def calculator(self):
9594
store_trajectory: StoreTrajectoryOption = StoreTrajectoryOption.NO
9695
tags: list[str] | None = None
9796

97+
def __post_init__(self) -> None:
98+
"""Enable caching of the ASE calculator via private attribute."""
99+
self._calculator: Calculator | None = None
100+
98101
@job(data=_ASE_DATA_OBJECTS)
99102
def make(
100103
self,
@@ -148,11 +151,18 @@ def run_ase(
148151
elapsed_time=t_f - t_i,
149152
)
150153

151-
@property
152154
@abstractmethod
155+
def _get_calculator(self) -> Calculator:
156+
"""Load ASE calculator, to be implemented by the user."""
157+
158+
@property
153159
def calculator(self) -> Calculator:
154-
"""ASE calculator, method to be implemented in subclasses."""
155-
raise NotImplementedError
160+
"""Retrieve cached ASE calculator."""
161+
if getattr(self, "_calculator", None) is None:
162+
self._calculator = self._get_calculator()
163+
if self._calculator is None:
164+
raise ValueError("ASE calculator not properly initialized.")
165+
return self._calculator
156166

157167

158168
@dataclass
@@ -208,8 +218,7 @@ class AseRelaxMaker(AseMaker):
208218

209219
def __post_init__(self) -> None:
210220
"""Ensure that physical relaxation settings are used."""
211-
if hasattr(super(), "__post_init__"):
212-
super().__post_init__() # type: ignore[misc]
221+
super().__post_init__()
213222
if self.relax_cell and self.relax_shape:
214223
raise ValueError(
215224
"You have set both `relax_cell` (relaxing the cell shape and volume) "
@@ -299,8 +308,7 @@ class EmtRelaxMaker(AseRelaxMaker):
299308

300309
name: str = "EMT relaxation"
301310

302-
@property
303-
def calculator(self) -> Calculator:
311+
def _get_calculator(self) -> Calculator:
304312
"""EMT calculator."""
305313
from ase.calculators.emt import EMT
306314

@@ -320,8 +328,7 @@ class LennardJonesRelaxMaker(AseRelaxMaker):
320328

321329
name: str = "Lennard-Jones 6-12 relaxation"
322330

323-
@property
324-
def calculator(self) -> Calculator:
331+
def _get_calculator(self) -> None:
325332
"""Lennard-Jones calculator."""
326333
from ase.calculators.lj import LennardJones
327334

@@ -378,8 +385,7 @@ class GFNxTBRelaxMaker(AseRelaxMaker):
378385
}
379386
)
380387

381-
@property
382-
def calculator(self) -> Calculator:
388+
def _get_calculator(self) -> None:
383389
"""GFN-xTB / TBLite calculator."""
384390
try:
385391
from tblite.ase import TBLite

src/atomate2/ase/md.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
import sys
1010
import time
11-
from abc import ABC, abstractmethod
11+
from abc import ABC
1212
from collections.abc import Sequence
1313
from dataclasses import dataclass, field
1414
from enum import Enum
@@ -189,6 +189,7 @@ class AseMDMaker(AseMaker, ABC):
189189

190190
def __post_init__(self) -> None:
191191
"""Ensure that ensemble is an enum."""
192+
super().__post_init__()
192193
if isinstance(self.ensemble, str):
193194
self.ensemble = MDEnsemble(self.ensemble.split("MDEnsemble.")[-1])
194195

@@ -444,12 +445,6 @@ def _callback(dyn: MolecularDynamics = md_runner) -> None:
444445
elapsed_time=t_f - t_i,
445446
)
446447

447-
@property
448-
@abstractmethod
449-
def calculator(self) -> Calculator:
450-
"""ASE calculator, to be overwritten by user."""
451-
raise NotImplementedError
452-
453448

454449
@dataclass
455450
class LennardJonesMDMaker(AseMDMaker):
@@ -461,8 +456,7 @@ class LennardJonesMDMaker(AseMDMaker):
461456

462457
name: str = "Lennard-Jones 6-12 MD"
463458

464-
@property
465-
def calculator(self) -> Calculator:
459+
def _get_calculator(self) -> Calculator:
466460
"""Lennard-Jones calculator."""
467461
from ase.calculators.lj import LennardJones
468462

@@ -495,8 +489,7 @@ class GFNxTBMDMaker(AseMDMaker):
495489
}
496490
)
497491

498-
@property
499-
def calculator(self) -> Calculator:
492+
def _get_calculator(self) -> Calculator:
500493
"""GFN-xTB / TBLite calculator."""
501494
try:
502495
from tblite.ase import TBLite

src/atomate2/ase/neb.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ class EmtNebFromImagesMaker(AseNebFromImagesMaker):
257257

258258
name: str = "EMT NEB from images maker"
259259

260-
@property
261-
def calculator(self) -> Calculator:
260+
def _get_calculator(self) -> Calculator:
262261
"""EMT calculator."""
263262
from ase.calculators.emt import EMT
264263

src/atomate2/forcefields/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,7 @@ def _run_ase_safe(self, *args, **kwargs) -> AseResult:
210210
with revert_default_dtype():
211211
return self.run_ase(*args, **kwargs)
212212

213-
@property
214-
def calculator(self) -> Calculator:
213+
def _get_calculator(self) -> Calculator:
215214
"""ASE calculator, can be overwritten by user."""
216215
return ase_calculator(
217216
self.calculator_meta,

tests/ase/test_jobs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,15 @@
3030
class EMTStaticMaker(AseMaker):
3131
name: str = "EMT static maker"
3232

33-
@property
34-
def calculator(self):
33+
def _get_calculator(self):
3534
return EMT()
3635

3736

3837
@dataclass
3938
class EMTRelaxMaker(AseRelaxMaker):
4039
name: str = "EMT relax maker"
4140

42-
@property
43-
def calculator(self):
41+
def _get_calculator(self):
4442
return EMT()
4543

4644

tests/ase/test_neb.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ class EmtNebFromEndpointsMaker(AseNebFromEndpointsMaker):
4747
default_factory=EmtRelaxMaker,
4848
)
4949

50-
@property
51-
def calculator(self):
50+
def _get_calculator(self):
5251
return EMT(**self.calculator_kwargs)
5352

5453

0 commit comments

Comments
 (0)