Skip to content

Commit 75399ce

Browse files
batch mode extension in phonons + tests adapted from #1196
1 parent ae10e2c commit 75399ce

7 files changed

Lines changed: 190 additions & 76 deletions

File tree

src/atomate2/ase/jobs.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -101,24 +101,37 @@ def __post_init__(self) -> None:
101101
@job(data=_ASE_DATA_OBJECTS)
102102
def make(
103103
self,
104-
mol_or_struct: Molecule | Structure,
104+
mol_or_struct: Molecule | Structure | list[Molecule | Structure],
105105
prev_dir: str | Path | None = None,
106-
) -> AseStructureTaskDoc | AseMoleculeTaskDoc:
106+
) -> (
107+
AseStructureTaskDoc
108+
| AseMoleculeTaskDoc
109+
| list[AseStructureTaskDoc | AseMoleculeTaskDoc]
110+
):
107111
"""
108112
Run ASE as job, can be re-implemented in subclasses.
109113
110114
Parameters
111115
----------
112-
mol_or_struct: .Molecule or .Structure
113-
pymatgen molecule or structure
116+
mol_or_struct: .Molecule, .Structure, or a list thereof
117+
pymatgen molecule(s) or structure(s)
114118
prev_dir : str or Path or None
115119
A previous calculation directory to copy output files from. Unused, just
116120
added to match the method signature of other makers.
121+
122+
Returns
123+
-------
124+
AseStructureTaskDoc, AseMoleculeTaskDoc, or list thereof.
117125
"""
118-
return AseTaskDoc.to_mol_or_struct_metadata_doc(
119-
getattr(self.calculator, "name", type(self.calculator).__name__),
120-
self.run_ase(mol_or_struct, prev_dir=prev_dir),
121-
)
126+
batch_mode = isinstance(mol_or_struct, list)
127+
results = [
128+
AseTaskDoc.to_mol_or_struct_metadata_doc(
129+
getattr(self.calculator, "name", type(self.calculator).__name__),
130+
self.run_ase(atoms, prev_dir=prev_dir),
131+
)
132+
for atoms in (mol_or_struct if batch_mode else [mol_or_struct])
133+
]
134+
return results if batch_mode else results[0]
122135

123136
def run_ase(
124137
self,
@@ -229,38 +242,48 @@ def __post_init__(self) -> None:
229242
@job(data=_ASE_DATA_OBJECTS)
230243
def make(
231244
self,
232-
mol_or_struct: Molecule | Structure,
245+
mol_or_struct: Molecule | Structure | list[Molecule | Structure],
233246
prev_dir: str | Path | None = None,
234-
) -> AseStructureTaskDoc | AseMoleculeTaskDoc:
247+
) -> (
248+
AseStructureTaskDoc
249+
| AseMoleculeTaskDoc
250+
| list[AseStructureTaskDoc | AseMoleculeTaskDoc]
251+
):
235252
"""
236253
Relax a structure or molecule using ASE as a job.
237254
238255
Parameters
239256
----------
240-
mol_or_struct: .Molecule or .Structure
241-
pymatgen molecule or structure
257+
mol_or_struct: .Molecule or .Structure, or list thereof
258+
pymatgen molecule(s) or structure(s)
242259
prev_dir : str or Path or None
243260
A previous calculation directory to copy output files from. Unused, just
244261
added to match the method signature of other makers.
245262
246263
Returns
247264
-------
248-
AseStructureTaskDoc or AseMoleculeTaskDoc
265+
AseStructureTaskDoc or AseMoleculeTaskDoc, or list thereof
249266
"""
250-
return AseTaskDoc.to_mol_or_struct_metadata_doc(
251-
getattr(self.calculator, "name", type(self.calculator).__name__),
252-
self.run_ase(mol_or_struct, prev_dir=prev_dir),
253-
self.steps,
254-
relax_kwargs=self.relax_kwargs,
255-
optimizer_kwargs=self.optimizer_kwargs,
256-
relax_cell=self.relax_cell,
257-
relax_shape=self.relax_shape,
258-
fix_symmetry=self.fix_symmetry,
259-
symprec=self.symprec if self.fix_symmetry else None,
260-
ionic_step_data=self.ionic_step_data,
261-
store_trajectory=self.store_trajectory,
262-
tags=self.tags,
263-
)
267+
batch_mode = isinstance(mol_or_struct, list)
268+
269+
results = [
270+
AseTaskDoc.to_mol_or_struct_metadata_doc(
271+
getattr(self.calculator, "name", type(self.calculator).__name__),
272+
self.run_ase(atoms, prev_dir=prev_dir),
273+
self.steps,
274+
relax_kwargs=self.relax_kwargs,
275+
optimizer_kwargs=self.optimizer_kwargs,
276+
relax_cell=self.relax_cell,
277+
relax_shape=self.relax_shape,
278+
fix_symmetry=self.fix_symmetry,
279+
symprec=self.symprec if self.fix_symmetry else None,
280+
ionic_step_data=self.ionic_step_data,
281+
store_trajectory=self.store_trajectory,
282+
tags=self.tags,
283+
)
284+
for atoms in (mol_or_struct if batch_mode else [mol_or_struct])
285+
]
286+
return results if batch_mode else results[0]
264287

265288
def run_ase(
266289
self,

src/atomate2/common/flows/phonons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class BasePhononMaker(Maker, ABC):
132132
store_force_constants: bool
133133
if True, force constants will be stored
134134
socket: bool
135-
If True, use the socket for the calculation
135+
If True, use the socket/batch for the calculation
136136
"""
137137

138138
name: str = "phonon"

src/atomate2/common/jobs/phonons.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@
2121
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
2222
from pymatgen.phonon.dos import PhononDos
2323

24+
from atomate2.ase.jobs import AseRelaxMaker
2425
from atomate2.common.schemas.phonons import ForceConstants, PhononBSDOSDoc, get_factor
2526
from atomate2.common.utils import get_supercell_matrix
27+
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
28+
from atomate2.vasp.jobs.base import BaseVaspMaker
2629

2730
if TYPE_CHECKING:
2831
from pathlib import Path
2932

3033
from emmet.core.math import Matrix3D
3134

3235
from atomate2.aims.jobs.base import BaseAimsMaker
33-
from atomate2.forcefields.jobs import ForceFieldStaticMaker
34-
from atomate2.vasp.jobs.base import BaseVaspMaker
35-
3636

3737
logger = logging.getLogger(__name__)
3838

@@ -253,7 +253,10 @@ def run_phonon_displacements(
253253
displacements: list[Structure],
254254
structure: Structure,
255255
supercell_matrix: Matrix3D,
256-
phonon_maker: BaseVaspMaker | ForceFieldStaticMaker | BaseAimsMaker = None,
256+
phonon_maker: BaseVaspMaker
257+
| AseRelaxMaker
258+
| ForceFieldRelaxMaker
259+
| BaseAimsMaker = None,
257260
prev_dir: str | Path = None,
258261
prev_dir_argname: str = None,
259262
socket: bool = False,
@@ -272,14 +275,16 @@ def run_phonon_displacements(
272275
Fully optimized structure used for phonon computations.
273276
supercell_matrix: Matrix3D
274277
supercell matrix for meta data
275-
phonon_maker : .BaseVaspMaker or .ForceFieldStaticMaker or .BaseAimsMaker
276-
A maker to use to generate dispacement calculations
278+
phonon_maker : .BaseVaspMaker, .AseRelaxMaker,
279+
.ForceFieldRelaxMaker, or .BaseAimsMaker
280+
A maker to use to generate dispacement calculations.
281+
NB: this should be a static maker.
277282
prev_dir: str or Path
278283
The previous working directory
279284
prev_dir_argname: str
280285
argument name for the prev_dir variable
281286
socket: bool
282-
If True use the socket-io interface to increase performance
287+
If True use the socket-io (batch-mode) interface to increase performance
283288
"""
284289
phonon_jobs = []
285290
outputs: dict[str, list] = {
@@ -292,28 +297,39 @@ def run_phonon_displacements(
292297
if prev_dir is not None and prev_dir_argname is not None:
293298
phonon_job_kwargs[prev_dir_argname] = prev_dir
294299

300+
num_disp = len(displacements)
295301
if socket:
302+
if isinstance(phonon_maker, BaseVaspMaker):
303+
raise ValueError("VASP makers do not currently support socket/batch mode.")
304+
296305
phonon_job = phonon_maker.make(displacements, **phonon_job_kwargs)
297306
info = {
298307
"original_structure": structure,
299308
"supercell_matrix": supercell_matrix,
300309
"displaced_structures": displacements,
301310
}
302-
phonon_job.update_maker_kwargs(
303-
{"_set": {"write_additional_data->phonon_info:json": info}}, dict_mod=True
304-
)
311+
if not isinstance(phonon_maker, AseRelaxMaker | ForceFieldRelaxMaker):
312+
phonon_job.update_maker_kwargs(
313+
{"_set": {"write_additional_data->phonon_info:json": info}},
314+
dict_mod=True,
315+
)
316+
305317
phonon_jobs.append(phonon_job)
306-
outputs["displacement_number"] = list(range(len(displacements)))
307-
outputs["uuids"] = [phonon_job.output.uuid] * len(displacements)
308-
outputs["dirs"] = [phonon_job.output.dir_name] * len(displacements)
309-
outputs["forces"] = phonon_job.output.output.all_forces
318+
outputs["displacement_number"] = list(range(num_disp))
319+
if isinstance(phonon_maker, AseRelaxMaker | ForceFieldRelaxMaker):
320+
outputs["uuids"] = [phonon_job.output[0].uuid] * num_disp
321+
outputs["dirs"] = [phonon_job.output[0].dir_name] * num_disp
322+
outputs["forces"] = [
323+
phonon_job.output[idx].output.forces for idx in range(num_disp)
324+
]
325+
else:
326+
outputs["uuids"] = [phonon_job.output.uuid] * num_disp
327+
outputs["dirs"] = [phonon_job.output.dir_name] * num_disp
328+
outputs["forces"] = phonon_job.output.output.all_forces
310329
else:
311330
for idx, displacement in enumerate(displacements):
312-
if prev_dir is not None:
313-
phonon_job = phonon_maker.make(displacement, prev_dir=prev_dir)
314-
else:
315-
phonon_job = phonon_maker.make(displacement)
316-
phonon_job.append_name(f" {idx + 1}/{len(displacements)}")
331+
phonon_job = phonon_maker.make(displacement, prev_dir=prev_dir)
332+
phonon_job.append_name(f" {idx + 1}/{num_disp}")
317333

318334
# we will add some meta data
319335
info = {
@@ -323,10 +339,11 @@ def run_phonon_displacements(
323339
"displaced_structure": displacement,
324340
}
325341
with contextlib.suppress(Exception):
326-
phonon_job.update_maker_kwargs(
327-
{"_set": {"write_additional_data->phonon_info:json": info}},
328-
dict_mod=True,
329-
)
342+
if not isinstance(phonon_maker, AseRelaxMaker | ForceFieldRelaxMaker):
343+
phonon_job.update_maker_kwargs(
344+
{"_set": {"write_additional_data->phonon_info:json": info}},
345+
dict_mod=True,
346+
)
330347
phonon_jobs.append(phonon_job)
331348
outputs["displacement_number"].append(idx)
332349
outputs["uuids"].append(phonon_job.output.uuid)

src/atomate2/forcefields/jobs.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -122,21 +122,29 @@ class ForceFieldRelaxMaker(ForceFieldMixin, AseRelaxMaker):
122122

123123
@forcefield_job
124124
def make(
125-
self, structure: Molecule | Structure, prev_dir: str | Path | None = None
126-
) -> ForceFieldTaskDocument | ForceFieldMoleculeTaskDocument:
125+
self,
126+
structure: Molecule | Structure | list[Molecule | Structure],
127+
prev_dir: str | Path | None = None,
128+
) -> (
129+
ForceFieldTaskDocument
130+
| ForceFieldMoleculeTaskDocument
131+
| list[ForceFieldTaskDocument | ForceFieldMoleculeTaskDocument]
132+
):
127133
"""
128134
Perform a relaxation of a structure using a force field.
129135
130136
Parameters
131137
----------
132-
structure: .Structure or Molecule
133-
pymatgen structure or molecule.
138+
structure: .Molecule or .Structure, or a list thereof
139+
pymatgen molecule(s) or structure(s)
134140
prev_dir : str or Path or None
135141
A previous calculation directory to copy output files from. Unused, just
136142
added to match the method signature of other makers.
137-
"""
138-
ase_result = self._run_ase_safe(structure, prev_dir=prev_dir)
139143
144+
Returns
145+
-------
146+
ForceFieldTaskDocument, ForceFieldMoleculeTaskDocument, or a list thereof
147+
"""
140148
if len(self.task_document_kwargs) > 0:
141149
warnings.warn(
142150
"`task_document_kwargs` is now deprecated, please use the top-level "
@@ -145,22 +153,33 @@ def make(
145153
stacklevel=1,
146154
)
147155

148-
return ForceFieldTaskDocument.from_ase_compatible_result(
149-
self.ase_calculator_name,
150-
ase_result,
151-
self.steps,
152-
calculator_meta=self.calculator_meta,
153-
relax_kwargs=self.relax_kwargs,
154-
optimizer_kwargs=self.optimizer_kwargs,
155-
relax_cell=self.relax_cell,
156-
relax_shape=self.relax_shape,
157-
fix_symmetry=self.fix_symmetry,
158-
symprec=self.symprec if self.fix_symmetry else None,
159-
ionic_step_data=self.ionic_step_data,
160-
store_trajectory=self.store_trajectory,
161-
tags=self.tags,
162-
**self.task_document_kwargs,
163-
)
156+
batch_mode = isinstance(structure, list)
157+
158+
ase_results = [
159+
self._run_ase_safe(atoms, prev_dir=prev_dir)
160+
for atoms in (structure if batch_mode else [structure])
161+
]
162+
163+
task_docs = [
164+
ForceFieldTaskDocument.from_ase_compatible_result(
165+
self.ase_calculator_name,
166+
ase_result,
167+
self.steps,
168+
calculator_meta=self.calculator_meta,
169+
relax_kwargs=self.relax_kwargs,
170+
optimizer_kwargs=self.optimizer_kwargs,
171+
relax_cell=self.relax_cell,
172+
relax_shape=self.relax_shape,
173+
fix_symmetry=self.fix_symmetry,
174+
symprec=self.symprec if self.fix_symmetry else None,
175+
ionic_step_data=self.ionic_step_data,
176+
store_trajectory=self.store_trajectory,
177+
tags=self.tags,
178+
**self.task_document_kwargs,
179+
)
180+
for ase_result in ase_results
181+
]
182+
return task_docs if batch_mode else task_docs[0]
164183

165184

166185
@dataclass

tests/ase/test_jobs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,27 @@ def test_lennard_jones_relax_maker(lj_fcc_ne_pars, fcc_ne_structure):
8989
)
9090

9191

92+
def test_lennard_jones_batch_relax_maker(
93+
lj_fcc_ne_pars, fcc_ne_structure, memory_jobstore
94+
):
95+
job = LennardJonesRelaxMaker(
96+
calculator_kwargs=lj_fcc_ne_pars, relax_kwargs={"fmax": 0.001}
97+
).make([fcc_ne_structure, fcc_ne_structure])
98+
99+
response = run_locally(job, store=memory_jobstore)
100+
101+
output = response[job.uuid][1].output
102+
103+
assert [calc.output.structure.volume for calc in output] == pytest.approx(
104+
[22.304245, 22.304245]
105+
)
106+
assert [calc.output.energy for calc in output] == pytest.approx(
107+
[-0.018494767, -0.018494767]
108+
)
109+
assert all(isinstance(calc, AseStructureTaskDoc) for calc in output)
110+
assert fcc_ne_structure.matches(output[0].output.structure)
111+
112+
92113
def test_lennard_jones_static_maker(lj_fcc_ne_pars, fcc_ne_structure):
93114
job = LennardJonesStaticMaker(calculator_kwargs=lj_fcc_ne_pars).make(
94115
fcc_ne_structure

tests/forcefields/flows/test_phonon.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from importlib.util import find_spec
3+
from itertools import product
34
from pathlib import Path
45
from tempfile import TemporaryDirectory
56

@@ -109,10 +110,12 @@ def test_phonon_maker_initialization_with_all_mlff(
109110
) from exc
110111

111112

112-
@pytest.mark.skipif(not mlff_is_installed("CHGNet"), reason="matgl is not installed")
113-
@pytest.mark.parametrize("from_name", [False, True])
113+
@pytest.mark.skipif(
114+
not mlff_is_installed("CHGNet"), reason="matgl/chgnet is not installed"
115+
)
116+
@pytest.mark.parametrize("from_name, socket", list(product(*[[True, False]] * 2)))
114117
def test_phonon_wf_force_field(
115-
clean_dir, si_structure: Structure, tmp_path: Path, from_name: bool
118+
clean_dir, si_structure: Structure, tmp_path: Path, from_name: bool, socket: bool
116119
):
117120
# TODO brittle due to inability to adjust dtypes in CHGNetRelaxMaker
118121

@@ -148,6 +151,7 @@ def test_phonon_wf_force_field(
148151
"filename_bs": (filename_bs := f"{tmp_path}/phonon_bs_test.png"),
149152
"filename_dos": (filename_dos := f"{tmp_path}/phonon_dos_test.pdf"),
150153
},
154+
socket=socket,
151155
)
152156

153157
if from_name:

0 commit comments

Comments
 (0)