# SPDX-License-Identifier: BSD-2-Clause
# Copyright(C) 2024 Max-Planck-Society
Binding the DUCC package to JAX
This file binds the fast Fourier transformation, the Hartley transformation, the spherical harmonic transformation, and the wgridder of the DUCC package to JAX.
It provides real-world examples of how to use JAXbind. Together with the demos, it is a good starting point for learning JAXbind.
Please note: this file is internal to JAXbind. When executing the code blocks below outside of JAXbind, you need to import `get_linear_call` and `load_kwargs` from JAXbind.
from functools import partial
import ducc0
import numpy as np
from .. import get_linear_call, load_kwargs
# Public API of this module: the JAX-bound DUCC transforms and helpers.
__all__ = [
    "c2c",
    "genuine_fht",
    "get_healpix_sht",
    "nalm",
    "get_wgridder",
]
# Lookup tables between real floating-point dtypes and the complex dtype of
# matching precision (and back). Keys are normalized ``np.dtype`` instances.
_r2cdict = {
    np.dtype(real): np.dtype(cplx)
    for real, cplx in ((np.float32, np.complex64), (np.float64, np.complex128))
}
_c2rdict = {cplx: real for real, cplx in _r2cdict.items()}
def _complextype(dtype):
    """Return the complex dtype whose precision matches the real ``dtype``.

    ``dtype`` is normalized with ``np.dtype`` before the lookup in
    ``_r2cdict``. Dtypes missing from that table raise ``KeyError``.
    """
    key = np.dtype(dtype)
    return _r2cdict[key]
def _realtype(dtype):
    """Return the real dtype whose precision matches the complex ``dtype``.

    ``dtype`` is normalized with ``np.dtype`` before the lookup in
    ``_c2rdict``. Dtypes missing from that table raise ``KeyError``.
    """
    key = np.dtype(dtype)
    return _c2rdict[key]
Binding the DUCC Hartley transform to JAX
In the following, we provide a Python function `_fht` that calls DUCC's C++ Hartley transformation `ducc0.fft.genuine_fht`. The C++ function natively supports batching along multiple axes. To make use of this, the Python function also translates potential JAX batching axes into DUCC axes. Detailed explanations of custom batching along axes can be found in the `01_linear_function.py` demo.
def _fht(out, args, kwargs_dump):
    """Evaluate DUCC's genuine Hartley transform as a JAXbind callback.

    Parameters
    ----------
    out : sequence of arrays
        Output buffers supplied by JAXbind; the result is written into
        ``out[0]``.
    args : tuple
        One-element tuple ``(x,)`` holding the input array.
    kwargs_dump
        Serialized keyword arguments, decoded with ``load_kwargs``. Two keys
        are consumed here. ``batch_axes`` holds the JAX batch axes of ``x``
        (set when the primitive is vmapped). ``axes`` holds the transform
        axes, relative to the user's unbatched array; negative values count
        from the end, and a scalar is accepted. All remaining keys are
        forwarded to ``ducc0.fft.genuine_fht``.
    """
    (x,) = args
    # Work on a copy so popping keys below never mutates the decoded mapping.
    kwargs = dict(load_kwargs(kwargs_dump))
    batch_axes = kwargs.pop("batch_axes", None)
    # Axes of ``x`` that belong to the user's (unbatched) array.
    axes = list(range(x.ndim))
    if batch_axes is not None:
        axes = [i for i in range(x.ndim) if i not in batch_axes[0]]
    orig_axis = kwargs.pop("axes", None)
    if orig_axis is not None:
        # Accept any integer scalar (including numpy integers), not only int.
        if np.ndim(orig_axis) == 0:
            orig_axis = [orig_axis]
        # User axes index into the unbatched array. Resolve negative values
        # (e.g. ``axes=-1``) against its rank; otherwise they would never
        # match an enumerate index and would be dropped from ``axes``.
        nuser = len(axes)
        wanted = {a + nuser if a < 0 else a for a in map(int, orig_axis)}
        axes = [i for idx, i in enumerate(axes) if idx in wanted]
    ducc0.fft.genuine_fht(x, out=out[0], axes=axes, **kwargs)
Additionally, we provide the abstract evaluation function.
def _fht_abstract(*args, **kwargs):
    """Abstract evaluation for ``_fht``: the output mirrors the input.

    Returns a one-element tuple ``((shape, dtype, out_axes),)``. The shape
    and dtype are those of the single input. ``out_axes`` is the last entry
    of ``batch_axes[0]`` when a non-empty batch-axis specification is
    given, and ``()`` otherwise.
    """
    (x,) = args
    batch_spec = kwargs.get("batch_axes", None)
    is_batched = batch_spec is not None and len(batch_spec[0]) > 0
    out_axes = batch_spec[0][-1] if is_batched else ()
    return ((x.shape, x.dtype, out_axes),)
Now we register the Hartley transformation as a JAX primitive via the `get_linear_call` functionality of JAXbind.
# Register the Hartley transform as a linear JAX primitive. The same callback
# (and the same abstract evaluation) is passed for the forward and the
# transposed direction. This works because the Hartley matrix, with entries
# cas(2*pi*j*k/N), is symmetric and therefore its own transpose.
# ``func_can_batch=True`` tells JAXbind that ``_fht`` handles the
# ``batch_axes`` keyword itself (see its axis translation).
genuine_fht = get_linear_call(
    _fht, _fht, _fht_abstract, _fht_abstract, func_can_batch=True
)
# Reuse DUCC's docstring, since the remaining keyword arguments are forwarded
# to ducc0.fft.genuine_fht unchanged.
genuine_fht.__doc__ = ducc0.fft.genuine_fht.__doc__
Binding the DUCC fast Fourier transformation to JAX
In analogy to the Hartley transformation, we bind the DUCC fast Fourier transformation to JAX.
def _c2c(out, args, kwargs_dump):
    """Evaluate DUCC's complex-to-complex FFT as a JAXbind callback.

    Parameters
    ----------
    out : sequence of arrays
        Output buffers supplied by JAXbind; the result is written into
        ``out[0]``.
    args : tuple
        One-element tuple ``(x,)`` holding the input array.
    kwargs_dump
        Serialized keyword arguments, decoded with ``load_kwargs``. Two keys
        are consumed here. ``batch_axes`` holds the JAX batch axes of ``x``
        (set when the primitive is vmapped). ``axes`` holds the transform
        axes, relative to the user's unbatched array; negative values count
        from the end, and a scalar is accepted. All remaining keys are
        forwarded to ``ducc0.fft.c2c``.
    """
    (x,) = args
    # Work on a copy so popping keys below never mutates the decoded mapping.
    kwargs = dict(load_kwargs(kwargs_dump))
    batch_axes = kwargs.pop("batch_axes", None)
    # Axes of ``x`` that belong to the user's (unbatched) array.
    axes = list(range(x.ndim))
    if batch_axes is not None:
        axes = [i for i in range(x.ndim) if i not in batch_axes[0]]
    orig_axis = kwargs.pop("axes", None)
    if orig_axis is not None:
        # Accept any integer scalar (including numpy integers), not only int.
        if np.ndim(orig_axis) == 0:
            orig_axis = [orig_axis]
        # User axes index into the unbatched array. Resolve negative values
        # (e.g. ``axes=-1``) against its rank; otherwise they would never
        # match an enumerate index and would be dropped from ``axes``.
        nuser = len(axes)
        wanted = {a + nuser if a < 0 else a for a in map(int, orig_axis)}
        axes = [i for idx, i in enumerate(axes) if idx in wanted]
    ducc0.fft.c2c(x, out=out[0], axes=axes, **kwargs)
def _c2c_abstract(*args, **kwargs):
    """Abstract evaluation for ``_c2c``: the output mirrors the input.

    Returns a one-element tuple ``((shape, dtype, out_axes),)``. The shape
    and dtype are those of the single input. ``out_axes`` is the last entry
    of ``batch_axes[0]`` when a non-empty batch-axis specification is
    given, and ``()`` otherwise.
    """
    (x,) = args
    batch_spec = kwargs.get("batch_axes", None)
    if batch_spec is None or len(batch_spec[0]) == 0:
        out_axes = ()
    else:
        out_axes = batch_spec[0][-1]
    return ((x.shape, x.dtype, out_axes),)
# Register the complex-to-complex FFT as a linear JAX primitive. The DFT
# matrix exp(-/+2*pi*i*j*k/N) is symmetric, so its plain (non-conjugated)
# linear transpose is the same transform. Hence ``_c2c`` and
# ``_c2c_abstract`` serve both directions.
# ``func_can_batch=True`` tells JAXbind that ``_c2c`` handles ``batch_axes``
# itself.
c2c = get_linear_call(_c2c, _c2c, _c2c_abstract, _c2c_abstract, func_can_batch=True)
# Reuse DUCC's docstring, since the remaining keyword arguments are forwarded
# to ducc0.fft.c2c unchanged.
c2c.__doc__ = ducc0.fft.c2c.__doc__
Binding the DUCC HEALPix spherical harmonic transformation to JAX
def _alm2realalm(alm, lmax, dtype, out=None):
    """Pack complex a_lm of shape ``(ncomp, nalm)`` into a real-valued layout.

    The first ``lmax + 1`` coefficients of each row contribute only their
    real part (presumably the m == 0 block of DUCC's layout; confirm against
    the ducc0 a_lm ordering). Every remaining coefficient is stored as its
    (real, imag) pair, scaled by sqrt(2). The result has
    ``2 * nalm - lmax - 1`` columns. It is written into ``out`` when given,
    otherwise into a new array of ``dtype``, and returned.
    """
    nhead = lmax + 1
    if out is None:
        ncols = 2 * alm.shape[1] - nhead
        out = np.empty((alm.shape[0], ncols), dtype=dtype)
    out[:, :nhead] = alm[:, :nhead].real
    # Basic slicing gives a view, so the writes below land in ``out``.
    tail = out[:, nhead:]
    tail[...] = alm[:, nhead:].view(dtype)
    tail *= np.sqrt(2.0)
    return out
def _realalm2alm(alm, lmax, dtype, out=None):
    """Unpack the real-valued a_lm layout back into complex a_lm.

    This inverts ``_alm2realalm``. The first ``lmax + 1`` real entries of
    each row are copied as purely real coefficients. The remaining entries
    are read as (real, imag) pairs and scaled by sqrt(2)/2. The result has
    ``(ncols + lmax + 1) // 2`` columns. It is written into ``out`` when
    given, otherwise into a new array of complex ``dtype``, and returned.
    """
    nhead = lmax + 1
    if out is None:
        ncoef = (alm.shape[1] + nhead) // 2
        out = np.empty((alm.shape[0], ncoef), dtype=dtype)
    out[:, :nhead] = alm[:, :nhead]
    # Basic slicing gives a view, so the writes below land in ``out``.
    tail = out[:, nhead:]
    tail[...] = alm[:, nhead:].view(dtype)
    tail *= np.sqrt(2.0) / 2
    return out
We wrap DUCC's spherical harmonic transformation and its transpose as Python functions with the signature required by JAXbind. Additionally, we provide the required abstract evaluation functions.
def _healpix_sht(out, args, kwargs_dump):
    """JAXbind callback for the forward HEALPix SHT (DUCC synthesis).

    ``args`` is ``(theta, phi0, nphi, ringstart, x)``. The first four give
    the ring geometry. ``x`` holds the a_lm in the real-valued layout
    produced by ``_alm2realalm``. The synthesized map is written into
    ``out[0]``. ``kwargs_dump`` must decode to a mapping with ``lmax``,
    ``mmax``, ``spin`` and ``nthreads``.
    """
    theta, phi0, nphi, ringstart, x = args
    opts = load_kwargs(kwargs_dump).copy()
    lmax = opts["lmax"]
    # DUCC works on complex a_lm, so unpack the real-valued layout first.
    alm = _realalm2alm(x, lmax, _complextype(x.dtype))
    ducc0.sht.synthesis(
        map=out[0],
        alm=alm,
        theta=theta,
        phi0=phi0,
        nphi=nphi,
        ringstart=ringstart,
        spin=opts["spin"],
        lmax=lmax,
        mmax=opts["mmax"],
        nthreads=opts["nthreads"],
    )
def _healpix_sht_T(out, args, kwargs_dump):
    """JAXbind callback for the transposed HEALPix SHT (DUCC adjoint synthesis).

    ``args`` is ``(theta, phi0, nphi, ringstart, x)`` with ``x`` the map.
    The complex a_lm returned by DUCC are repacked into the real-valued
    layout (see ``_alm2realalm``) directly inside ``out[0]``.
    """
    theta, phi0, nphi, ringstart, x = args
    opts = load_kwargs(kwargs_dump).copy()
    alm = ducc0.sht.adjoint_synthesis(
        map=x,
        theta=theta,
        phi0=phi0,
        nphi=nphi,
        ringstart=ringstart,
        spin=opts["spin"],
        lmax=opts["lmax"],
        mmax=opts["mmax"],
        nthreads=opts["nthreads"],
    )
    _alm2realalm(alm, opts["lmax"], x.dtype, out[0])
def _healpix_sht_abstract(*args, **kwargs):
    """Abstract evaluation of the forward HEALPix SHT.

    The output map has ``ncomp`` rows, 1 for spin 0 and 2 otherwise. It has
    ``12 * nside**2`` pixels per row and keeps the dtype of the a_lm input.
    """
    _theta, _phi0, _nphi, _ringstart, x = args
    ncomp = 1 if kwargs["spin"] == 0 else 2
    npix = 12 * kwargs["nside"] ** 2
    return (((ncomp, npix), x.dtype),)
def _healpix_sht_abstract_T(*args, **kwargs):
    """Abstract evaluation of the transposed HEALPix SHT.

    The output holds a_lm in the real-valued layout of ``_alm2realalm``:
    ``2 * nalm - lmax - 1`` real entries per component, with ``ncomp`` equal
    to 1 for spin 0 and 2 otherwise. The dtype is that of the map input.
    """
    _theta, _phi0, _nphi, _ringstart, x = args
    ncomp = 1 if kwargs["spin"] == 0 else 2
    lmax = kwargs["lmax"]
    mmax = kwargs["mmax"]
    ncomplex = ((mmax + 1) * (mmax + 2)) // 2 + (mmax + 1) * (lmax - mmax)
    # The first lmax + 1 coefficients occupy one real slot each; all others
    # occupy two (real and imaginary part).
    nreal = 2 * ncomplex - (lmax + 1)
    return (((ncomp, nreal), x.dtype),)
Now we register the JAX primitive. The spherical harmonic transformation is neither linear nor differentiable in the arguments `theta`, `phi0`, `nphi`, and `ringstart`. We communicate this to JAXbind via `first_n_args_fixed=4`. A more detailed example of the usage of `first_n_args_fixed` can be found in the demo `03_nonlinear_function.py`.
# Register the HEALPix SHT as a JAX primitive that is linear in its last
# argument (the real-valued a_lm) only. ``first_n_args_fixed=4`` marks the
# ring-geometry arguments theta, phi0, nphi and ringstart as fixed,
# non-differentiable inputs. The arguments follow the order used throughout
# this file: forward callback, transposed callback (DUCC adjoint synthesis),
# then the forward and transposed abstract evaluations.
_hp_sht = get_linear_call(
    _healpix_sht,
    _healpix_sht_T,
    _healpix_sht_abstract,
    _healpix_sht_abstract_T,
    first_n_args_fixed=4,
)
To the user, we expose a JAX function in which the non-differentiable arguments are already filled in.
def get_healpix_sht(nside, lmax, mmax, spin, nthreads=1):
    """Create a JAX primitive for the ducc0 SHT synthesis on a HEALPix grid.

    Parameters
    ----------
    nside : int
        Resolution parameter of the HEALPix sphere; the resulting map has
        ``12 * nside**2`` pixels.
    lmax, mmax : int
        Maximum l and m moment of the transformation, respectively
        (inclusive).
    spin : int
        Spin of the transformation. Spin 0 yields one map component; any
        other spin yields two.
    nthreads : int
        Number of threads to use for the computation. If 0, use as many
        threads as there are hardware threads available on the system.

    Returns
    -------
    op : JAX primitive
        The registered SHT primitive with the RING-ordered HEALPix geometry
        and all settings bound. It takes only the a_lm, in real-valued
        layout, as argument.
    """
    geometry = ducc0.healpix.Healpix_Base(nside, "RING").sht_info()
    ring_args = [geometry[key] for key in ("theta", "phi0", "nphi", "ringstart")]
    return partial(
        _hp_sht,
        *ring_args,
        lmax=lmax,
        mmax=mmax,
        spin=spin,
        nthreads=nthreads,
        nside=nside,
    )
def nalm(lmax, mmax):
    """Compute the number of a_lm for a given maximum l and m moment of the SHT.

    Counts the pairs (l, m) with 0 <= m <= mmax and m <= l <= lmax. The
    first term covers the triangle l <= mmax; the second covers the
    rectangle mmax < l <= lmax.
    """
    nm = mmax + 1
    triangle = (nm * (mmax + 2)) // 2
    rectangle = nm * (lmax - mmax)
    return triangle + rectangle
Binding the DUCC wgridder to JAX
Again, we define Python functions that call the DUCC wgridder and its transpose, along with abstract evaluation functions for the forward and transposed directions.
def _dirty2vis(out, args, kwargs_dump):
    """JAXbind callback for the forward wgridder direction (image -> visibilities).

    ``args`` is ``(uvw, freq, dirty)``. The visibilities are written into
    ``out[0]``. ``npix_x`` and ``npix_y`` are removed before the remaining
    options are forwarded. They are used by the transposed direction, and
    presumably dirty2vis takes the image size from ``dirty`` itself.
    """
    uvw, freq, dirty = args
    opts = load_kwargs(kwargs_dump)
    for key in ("npix_x", "npix_y"):
        opts.pop(key)
    ducc0.wgridder.experimental.dirty2vis(
        uvw=uvw, freq=freq, dirty=dirty, vis=out[0], **opts
    )
def _dirty2vis_abstract(*args, **kwargs):
    """Abstract evaluation of the forward wgridder.

    The output has one row per uvw row and one column per frequency. Its
    complex dtype matches the precision of the real image ``dirty``.
    """
    uvw, freq, dirty = args
    nrows = uvw.shape[0]
    nchan = freq.shape[0]
    return (((nrows, nchan), _complextype(dirty.dtype)),)
def _vis2dirty(out, args, kwargs_dump):
    """JAXbind callback for the transposed wgridder direction (visibilities -> image).

    ``args`` is ``(uvw, freq, vis)``. The image is written into ``out[0]``.
    The visibilities are conjugated before being handed to DUCC's
    vis2dirty. Presumably this turns DUCC's adjoint (conjugate transpose)
    into the plain linear transpose that JAX expects; TODO confirm.
    """
    uvw, freq, vis = args
    opts = load_kwargs(kwargs_dump)
    vis_conj = vis.conj()
    ducc0.wgridder.experimental.vis2dirty(
        uvw=uvw, freq=freq, vis=vis_conj, dirty=out[0], **opts
    )
def _vis2dirty_abstract(*args, **kwargs):
    """Abstract evaluation of the transposed wgridder.

    The output image has shape ``(npix_x, npix_y)``. Its real dtype matches
    the precision of the complex visibilities.
    """
    _uvw, _freq, vis = args
    image_shape = (kwargs["npix_x"], kwargs["npix_y"])
    return ((image_shape, _realtype(vis.dtype)),)
As with the spherical harmonic transformation, not all arguments of the wgridder are differentiable. Again, we mark the non-differentiable arguments `uvw` and `freq` as fixed via `first_n_args_fixed=2`.
# Register the wgridder as a JAX primitive that is linear in its last
# argument (the image) only. ``first_n_args_fixed=2`` marks ``uvw`` and
# ``freq`` as fixed, non-differentiable inputs. ``_dirty2vis`` is the
# forward callback and ``_vis2dirty`` the transposed one, followed by their
# abstract evaluations in the same order.
_wgridder = get_linear_call(
    _dirty2vis,
    _vis2dirty,
    _dirty2vis_abstract,
    _vis2dirty_abstract,
    first_n_args_fixed=2,
)
def get_wgridder(
    *,
    pixsize_x,
    pixsize_y,
    npix_x,
    npix_y,
    epsilon,
    do_wgridding,
    nthreads=1,
    flip_v=False,
    verbosity=0,
    **kwargs,
):
    """Create a JAX primitive for the ducc0 wgridder.

    Parameters
    ----------
    pixsize_x, pixsize_y : float
        Size of the pixels in radian.
    npix_x, npix_y : int
        Number of pixels along each image axis.
    epsilon : float
        Required accuracy of the wgridder evaluation.
    do_wgridding : bool
        Forwarded unchanged to the ducc wgridder. It controls the treatment
        of the w-term; see the ducc0 wgridder documentation.
    nthreads : int
        Number of threads used for the evaluation. Default 1.
    flip_v : bool
        Whether or not to flip the v coordinate of the visibilities.
        Default `False`.
    verbosity : int
        Verbosity of the wgridder: 0 prints nothing, >0 gives verbose
        output. Default 0.
    **kwargs : dict
        Additional keyword arguments forwarded to the ducc wgridder.

    Returns
    -------
    op : JAX primitive evaluating the ducc wgridder.
        The primitive has the signature `(uvw, freq, image)`. `uvw` is an
        (N, 3) array with the uvw coordinates of the visibilities in meters.
        `freq` is a 1D array with the frequencies in Hertz. `image` is a 2D
        array of shape `(npix_x, npix_y)` with the sky brightness in Jansky
        per steradian. The result is a complex array of shape
        `(N, len(freq))`.
    """
    settings = dict(
        pixsize_x=pixsize_x,
        pixsize_y=pixsize_y,
        npix_x=npix_x,
        npix_y=npix_y,
        epsilon=epsilon,
        do_wgridding=do_wgridding,
        nthreads=nthreads,
        flip_v=flip_v,
        verbosity=verbosity,
    )
    # Extra options go last so that the keyword order matches the named ones.
    settings.update(kwargs)
    return partial(_wgridder, **settings)