Skip to content

Commit e2feb94

Browse files
committed
Add dask gestion to grid_to_sh and sh_to_grid
1 parent befcf21 commit e2feb94

1 file changed

Lines changed: 83 additions & 38 deletions

File tree

lenapy/utils/harmo.py

Lines changed: 83 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def _init_degrees(
238238
mmax: int = None,
239239
used_l: np.ndarray = None,
240240
used_m: np.ndarray = None,
241-
) -> tuple[np.ndarray, np.ndarray, int, int, int, int]:
241+
) -> tuple[xr.Dataset, np.ndarray, np.ndarray, int, int, int, int]:
242242
"""
243243
Initialize spherical harmonic degrees and orders to be used.
244244
@@ -257,6 +257,8 @@ def _init_degrees(
257257
258258
Returns
259259
-------
260+
sub_data : xr.Dataset
261+
Reduced dataset with selected degrees and orders.
260262
used_l: np.ndarray
261263
Degrees to use.
262264
used_m: np.ndarray
@@ -276,7 +278,8 @@ def _init_degrees(
276278
mmin = int(data.m.min()) if mmin is None else mmin
277279
used_l = np.arange(lmin, lmax + 1) if used_l is None else used_l
278280
used_m = np.arange(mmin, mmax + 1) if used_m is None else used_m
279-
return used_l, used_m, lmin, lmax, mmin, mmax
281+
sub_data = data.sel(l=used_l, m=used_m)
282+
return sub_data, used_l, used_m, lmin, lmax, mmin, mmax
280283

281284

282285
def sh_to_grid(
@@ -304,6 +307,8 @@ def sh_to_grid(
304307
include_elastic: bool = True,
305308
plm: xr.DataArray = None,
306309
normalization_plm: Literal["4pi", "ortho", "schmidt"] = "4pi",
310+
use_dask: bool = False,
311+
chunks_plm: dict | None = None,
307312
**kwargs,
308313
) -> xr.DataArray:
309314
"""
@@ -374,6 +379,11 @@ def sh_to_grid(
374379
Either '4pi', 'ortho', or 'schmidt' for 4pi normalized, orthonormalized, or Schmidt semi-normalized SH
375380
functions, respectively. Default is '4pi'.
376381
382+
use_dask : bool, optional
383+
If True, use dask to chunk plm for memory optimization. Default is False.
384+
chunks_plm : dict, optional
385+
Define the chunking of plm when use_dask is True. Default is None, which set the chunking to {'latitude': 1}.
386+
377387
**kwargs :
378388
Supplementary parameters used by the function l_factor_conv to modify defaults constants used in the computation
379389
for the unit conversion. These parameters include (see :func:`l_factor_conv` documentation for more details) :
@@ -386,7 +396,7 @@ def sh_to_grid(
386396
"""
387397
# add mask in output variable
388398

389-
used_l, used_m, lmin, lmax, mmin, mmax = _init_degrees(
399+
sub_data, used_l, used_m, lmin, lmax, mmin, mmax = _init_degrees(
390400
data, lmin, lmax, mmin, mmax, used_l, used_m
391401
)
392402
used_l, use_czero_coef, force_mass_conservation = _handle_mass_conservation(
@@ -415,22 +425,22 @@ def sh_to_grid(
415425
plm = compute_plm(
416426
lmax,
417427
np.cos(geocentric_colat),
428+
latitude=latitude,
418429
mmax=mmax,
419430
normalization=normalization_plm,
431+
use_dask=use_dask,
432+
chunks=chunks_plm,
420433
)
421434
else:
422435
plm = compute_plm(
423-
lmax, sin_latitude, mmax=mmax, normalization=normalization_plm
436+
lmax,
437+
sin_latitude,
438+
latitude=latitude,
439+
mmax=mmax,
440+
normalization=normalization_plm,
441+
use_dask=use_dask,
442+
chunks=chunks_plm,
424443
)
425-
plm = xr.DataArray(
426-
plm,
427-
dims=["l", "m", "latitude"],
428-
coords={
429-
"l": np.arange(lmax + 1),
430-
"m": np.arange(mmax + 1),
431-
"latitude": latitude,
432-
},
433-
)
434444

435445
else:
436446
# Verify plm integrity
@@ -463,7 +473,7 @@ def sh_to_grid(
463473
include_elastic=include_elastic,
464474
ellipsoidal_earth=ellipsoidal_earth,
465475
geocentric_colat=geocentric_colat,
466-
attrs=data.attrs,
476+
attrs=sub_data.attrs,
467477
**kwargs,
468478
)
469479

@@ -484,14 +494,14 @@ def sh_to_grid(
484494

485495
# summation over all spherical harmonic degrees
486496
if not errors:
487-
d_clm = (plm_lfactor * data.sel(l=used_l, m=used_m).clm).sum(dim="l")
488-
d_slm = (plm_lfactor * data.sel(l=used_l, m=used_m).slm).sum(dim="l")
497+
d_clm = (plm_lfactor * sub_data.clm).sum(dim="l")
498+
d_slm = (plm_lfactor * sub_data.slm).sum(dim="l")
489499

490500
# Final calcul on the grid
491501
xgrid = c_cos.dot(d_clm) + s_sin.dot(d_slm)
492502
else:
493-
d_clm = (plm_lfactor**2 * data.sel(l=used_l, m=used_m).clm ** 2).sum(dim="l")
494-
d_slm = (plm_lfactor**2 * data.sel(l=used_l, m=used_m).slm ** 2).sum(dim="l")
503+
d_clm = (plm_lfactor**2 * sub_data.clm ** 2).sum(dim="l")
504+
d_slm = (plm_lfactor**2 * sub_data.slm ** 2).sum(dim="l")
495505

496506
# Final calcul of sigma on the grid
497507
xgrid = np.sqrt((c_cos**2).dot(d_clm) + (s_sin**2).dot(d_slm))
@@ -518,17 +528,17 @@ def sh_to_grid(
518528
# restore C0 mass
519529
if use_czero_coef:
520530
lfactor_zero = l_factor_conv(
521-
np.array([0]), unit=unit, attrs=data.attrs, **kwargs
531+
np.array([0]), unit=unit, attrs=sub_data.attrs, **kwargs
522532
)[0]
523-
xgrid = xgrid + (lfactor_zero * data.clm.sel(l=0, m=0)).values
533+
xgrid = xgrid + (lfactor_zero * sub_data.clm.sel(l=0, m=0)).values
524534

525535
xgrid = xgrid.transpose("latitude", "longitude", ...)
526536

527537
xgrid.attrs = {"units": unit, "max_degree": int(lmax)}
528-
if "radius" in data.attrs:
529-
xgrid.attrs["radius"] = data.attrs["radius"]
530-
if "earth_gravity_constant" in data.attrs:
531-
xgrid.attrs["earth_gravity_constant"] = data.attrs["earth_gravity_constant"]
538+
if "radius" in sub_data.attrs:
539+
xgrid.attrs["radius"] = sub_data.attrs["radius"]
540+
if "earth_gravity_constant" in sub_data.attrs:
541+
xgrid.attrs["earth_gravity_constant"] = sub_data.attrs["earth_gravity_constant"]
532542

533543
return xgrid
534544

@@ -546,6 +556,8 @@ def grid_to_sh(
546556
include_elastic: bool = True,
547557
plm: xr.DataArray | None = None,
548558
normalization_plm: Literal["4pi", "ortho", "schmidt"] = "4pi",
559+
use_dask: bool = False,
560+
chunks_plm: dict | None = None,
549561
**kwargs,
550562
) -> xr.Dataset:
551563
"""
@@ -591,6 +603,11 @@ def grid_to_sh(
591603
4pi normalized, orthonormalized, or Schmidt semi-normalized SH functions, respectively. Default is '4pi'.
592604
Output SH coefficient will be normalized according to this parameter.
593605
606+
use_dask : bool, optional
607+
If True, use dask to chunk plm for memory optimization. Default is False.
608+
chunks_plm : dict, optional
609+
Define the chunking of plm when use_dask is True. Default is None, which set the chunking to {'latitude': 1}.
610+
594611
**kwargs :
595612
Supplementary parameters used by the function l_factor_conv to modify defaults constants used in the computation
596613
for the unit conversion. These parameters include (see :func:`l_factor_conv` documentation for more details) :
@@ -651,22 +668,21 @@ def grid_to_sh(
651668
plm = compute_plm(
652669
lmax,
653670
np.cos(geocentric_colat),
671+
latitude=grid.cf["latitude"],
654672
mmax=mmax,
655673
normalization=normalization_plm,
674+
use_dask=use_dask,
675+
chunks=chunks_plm,
656676
)
657677
else:
658678
plm = compute_plm(
659-
lmax, sin_latitude, mmax=mmax, normalization=normalization_plm
679+
lmax, sin_latitude,
680+
latitude=grid.cf["latitude"],
681+
mmax=mmax,
682+
normalization=normalization_plm,
683+
use_dask = use_dask,
684+
chunks = chunks_plm,
660685
)
661-
plm = xr.DataArray(
662-
plm,
663-
dims=["l", "m", "latitude"],
664-
coords={
665-
"l": np.arange(lmax + 1),
666-
"m": np.arange(lmax + 1),
667-
"latitude": grid.cf["latitude"],
668-
},
669-
)
670686

671687
else:
672688
# Verify plm integrity
@@ -736,10 +752,13 @@ def grid_to_sh(
736752
def compute_plm(
737753
lmax: int,
738754
z: np.ndarray,
755+
latitude: np.ndarray = None,
739756
mmax: int = None,
740757
normalization: Literal["4pi", "ortho", "schmidt"] = "4pi",
741758
dtype: complex | float | type[complex] | type[float] = np.float128,
742-
) -> np.ndarray:
759+
use_dask: bool = False,
760+
chunks: dict | None = None,
761+
) -> xr.DataArray:
743762
"""
744763
Compute all the associated Legendre functions up to a maximum degree and
745764
order using the recursion relation from [Holmes2002]_
@@ -751,6 +770,8 @@ def compute_plm(
751770
Maximum degree of legrendre functions.
752771
z : np.ndarray
753772
Argument of the associated Legendre functions.
773+
latitude : np.ndarray, optional
774+
Latitude values in degrees. Default is None and latitude is made from z.
754775
mmax : int or NoneType, optional
755776
Maximum order of associated legrendre functions.
756777
normalization : {'4pi', 'ortho', 'schmidt'}, optional
@@ -759,10 +780,15 @@ def compute_plm(
759780
dtype : dtype, optional
760781
Data type of the output array. Default is np.float128.
761782
783+
use_dask : bool, optional
784+
If True, use dask to chunk plm for memory optimization. Default is False.
785+
chunks : dict, optional
786+
Define the chunking of plm when use_dask is True. Default is None, which set the chunking to {'latitude': 1}.
787+
762788
Returns
763789
-------
764-
plm : np.ndarray
765-
Fully-normalized Legendre functions as a 3D array with "l", "m" and z dimensions.
790+
plm : xr.DataArray
791+
Fully-normalized Legendre functions as a DataArray with "l", "m" and "latitude" dimensions.
766792
767793
References
768794
----------
@@ -789,6 +815,9 @@ def compute_plm(
789815
# if default mmax, set mmax to be maximal degree
790816
mmax = lmax if mmax is None else mmax
791817

818+
# if default latitude, set it from z
819+
latitude = z if latitude is None else latitude
820+
792821
f1, f2, norm_p10, norm_4pi = _compute_factors(lmax, normalization)
793822

794823
# scale factor based on Holmes2002
@@ -851,8 +880,24 @@ def compute_plm(
851880
ind = np.tril_indices(lmax + 1)
852881
plm[ind] = p
853882

883+
plm = xr.DataArray(
884+
plm[:, : mmax + 1, :],
885+
dims=["l", "m", "latitude"],
886+
coords={
887+
"l": np.arange(lmax + 1),
888+
"m": np.arange(mmax + 1),
889+
"latitude": latitude, #grid.cf["latitude"]
890+
},
891+
)
892+
893+
# Chunking plm for dask usage and memory optimization
894+
if use_dask:
895+
if chunks is None:
896+
chunks = {"latitude": 1}
897+
plm = plm.chunk(chunks)
898+
854899
# return the legendre polynomials and truncating orders to mmax
855-
return plm[:, : mmax + 1, :]
900+
return plm
856901

857902

858903
def mid_month_grace_estimate(

0 commit comments

Comments
 (0)