# SPDX-License-Identifier: BSD-2-Clause
# Copyright(C) 2024 Max-Planck-Society

Binding the DUCC package to JAX

This file binds the fast Fourier transform, the Hartley transform, the spherical harmonic transform, and the wgridder of the DUCC package to JAX.

This provides real-world examples of how to use JAXbind. Together with the demos, it is a good starting point for learning JAXbind.

Please note: This file is internal to JAXbind. When executing the code blocks below outside of JAXbind, you need to import get_linear_call and load_kwargs from JAXbind.


from functools import partial

import ducc0
import numpy as np

from .. import get_linear_call, load_kwargs

# Public API of this module: the JAX-bound DUCC operators (``c2c``,
# ``genuine_fht``), the factories ``get_healpix_sht`` / ``get_wgridder`` and
# the helper ``nalm``.
__all__ = ["c2c", "genuine_fht", "get_healpix_sht", "nalm", "get_wgridder"]


_r2cdict = {
    np.dtype(np.float32): np.dtype(np.complex64),
    np.dtype(np.float64): np.dtype(np.complex128),
}

_c2rdict = {
    np.dtype(np.complex64): np.dtype(np.float32),
    np.dtype(np.complex128): np.dtype(np.float64),
}


def _complextype(dtype):
    """Return the complex dtype whose precision matches the real ``dtype``.

    Only ``float32`` and ``float64`` are supported; any other dtype raises
    ``KeyError`` from the lookup table.
    """
    real = np.dtype(dtype)
    return _r2cdict[real]


def _realtype(dtype):
    """Return the real dtype whose precision matches the complex ``dtype``.

    Only ``complex64`` and ``complex128`` are supported; any other dtype
    raises ``KeyError`` from the lookup table.
    """
    cplx = np.dtype(dtype)
    return _c2rdict[cplx]

Binding the DUCC Hartley transform to JAX

In the following, we provide a Python function _fht that calls DUCC's C++ Hartley transform ducc0.fft.genuine_fht. The C++ function natively supports batching along multiple axes. To make use of this, the Python function also translates potential JAX batching axes into DUCC axes. Detailed explanations of custom batching along axes can be found in the 01_linear_function.py demo.



def _fht(out, args, kwargs_dump):
    """Evaluate DUCC's genuine Hartley transform of ``args[0]`` into ``out[0]``.

    ``kwargs_dump`` is decoded with ``load_kwargs``. Two keys are consumed
    here; all remaining keys are forwarded to ``ducc0.fft.genuine_fht``.

    * ``batch_axes`` (optional): JAX batch axes; ``batch_axes[0]`` lists the
      batch axes of ``x``. They are excluded from the transform.
    * ``axes`` (optional): int or iterable of ints selecting the axes to
      transform, counted relative to the *unbatched* array. Negative values
      count from the end, as in NumPy.

    Raises
    ------
    ValueError
        If a requested axis is out of bounds for the unbatched array.
    """
    (x,) = args
    kwargs = load_kwargs(kwargs_dump)
    batch_axes = kwargs.pop("batch_axes", None)
    # Axes of ``x`` that belong to the user's (unbatched) array; JAX batch
    # axes are excluded so independent batch entries are never mixed.
    axes = list(range(x.ndim))
    if batch_axes is not None:
        axes = [i for i in range(x.ndim) if i not in batch_axes[0]]
    orig_axis = kwargs.pop("axes", None)
    if orig_axis is not None:
        if isinstance(orig_axis, (int, np.integer)):
            orig_axis = [orig_axis]
        ndim_unbatched = len(axes)
        # Normalize the user's axes before mapping them onto ``axes``; a plain
        # membership test would silently drop negative indices such as -1.
        requested = set()
        for ax in orig_axis:
            ax = int(ax)
            if not -ndim_unbatched <= ax < ndim_unbatched:
                raise ValueError(
                    f"axis {ax} is out of bounds for array of dimension "
                    f"{ndim_unbatched}"
                )
            requested.add(ax % ndim_unbatched)
        axes = [i for idx, i in enumerate(axes) if idx in requested]
    ducc0.fft.genuine_fht(x, out=out[0], axes=axes, **kwargs)

Additionally, we provide the abstract evaluation function.



def _fht_abstract(*args, **kwargs):
    (x,) = args
    batch_axes = kwargs.pop("batch_axes", None)
    out_ax = ()
    if batch_axes is not None and len(batch_axes[0]) > 0:
        out_ax = batch_axes[0][-1]
    return ((x.shape, x.dtype, out_ax),)

Now we register the Hartley transform as a JAX primitive via the get_linear_call functionality of JAXbind.

# Register the Hartley transform as a linear JAX primitive. ``_fht`` is passed
# in both the forward and the transpose slot (the Hartley matrix is symmetric,
# so the transform is its own transpose), and ``_fht_abstract`` likewise
# serves both abstract evaluations. ``func_can_batch=True`` tells JAXbind that
# ``_fht`` handles the ``batch_axes`` keyword itself.
genuine_fht = get_linear_call(
    _fht, _fht, _fht_abstract, _fht_abstract, func_can_batch=True
)
# Reuse DUCC's documentation for the JAX-bound function.
genuine_fht.__doc__ = ducc0.fft.genuine_fht.__doc__

Binding the DUCC fast Fourier transform to JAX

In analogy to the Hartley transform, we bind the DUCC fast Fourier transform to JAX.



def _c2c(out, args, kwargs_dump):
    """Evaluate DUCC's complex-to-complex FFT of ``args[0]`` into ``out[0]``.

    ``kwargs_dump`` is decoded with ``load_kwargs``. Two keys are consumed
    here; all remaining keys are forwarded to ``ducc0.fft.c2c``.

    * ``batch_axes`` (optional): JAX batch axes; ``batch_axes[0]`` lists the
      batch axes of ``x``. They are excluded from the transform.
    * ``axes`` (optional): int or iterable of ints selecting the axes to
      transform, counted relative to the *unbatched* array. Negative values
      count from the end, as in NumPy.

    Raises
    ------
    ValueError
        If a requested axis is out of bounds for the unbatched array.
    """
    (x,) = args
    kwargs = load_kwargs(kwargs_dump)
    batch_axes = kwargs.pop("batch_axes", None)
    # Axes of ``x`` that belong to the user's (unbatched) array; JAX batch
    # axes are excluded so independent batch entries are never mixed.
    axes = list(range(x.ndim))
    if batch_axes is not None:
        axes = [i for i in range(x.ndim) if i not in batch_axes[0]]
    orig_axis = kwargs.pop("axes", None)
    if orig_axis is not None:
        if isinstance(orig_axis, (int, np.integer)):
            orig_axis = [orig_axis]
        ndim_unbatched = len(axes)
        # Normalize the user's axes before mapping them onto ``axes``; a plain
        # membership test would silently drop negative indices such as -1.
        requested = set()
        for ax in orig_axis:
            ax = int(ax)
            if not -ndim_unbatched <= ax < ndim_unbatched:
                raise ValueError(
                    f"axis {ax} is out of bounds for array of dimension "
                    f"{ndim_unbatched}"
                )
            requested.add(ax % ndim_unbatched)
        axes = [i for idx, i in enumerate(axes) if idx in requested]
    ducc0.fft.c2c(x, out=out[0], axes=axes, **kwargs)


def _c2c_abstract(*args, **kwargs):
    (x,) = args
    batch_axes = kwargs.pop("batch_axes", None)
    out_ax = ()
    if batch_axes is not None and len(batch_axes[0]) > 0:
        out_ax = batch_axes[0][-1]
    return ((x.shape, x.dtype, out_ax),)


# Register the complex FFT as a linear JAX primitive. ``_c2c`` is passed in
# both the forward and the transpose slot (the DFT matrix is symmetric, so the
# transform is its own transpose), and ``func_can_batch=True`` tells JAXbind
# that ``_c2c`` handles the ``batch_axes`` keyword itself.
c2c = get_linear_call(_c2c, _c2c, _c2c_abstract, _c2c_abstract, func_can_batch=True)
# Reuse DUCC's documentation for the JAX-bound function.
c2c.__doc__ = ducc0.fft.c2c.__doc__

Binding the DUCC HEALPix spherical harmonic transform to JAX



def _alm2realalm(alm, lmax, dtype, out=None):
    if out is None:
        out = np.empty((alm.shape[0], alm.shape[1] * 2 - lmax - 1), dtype=dtype)
    out[:, 0 : lmax + 1] = alm[:, 0 : lmax + 1].real
    out[:, lmax + 1 :] = alm[:, lmax + 1 :].view(dtype)
    out[:, lmax + 1 :] *= np.sqrt(2.0)
    return out


def _realalm2alm(alm, lmax, dtype, out=None):
    if out is None:
        out = np.empty((alm.shape[0], (alm.shape[1] + lmax + 1) // 2), dtype=dtype)
    out[:, 0 : lmax + 1] = alm[:, 0 : lmax + 1]
    out[:, lmax + 1 :] = alm[:, lmax + 1 :].view(dtype)
    out[:, lmax + 1 :] *= np.sqrt(2.0) / 2
    return out

We wrap DUCC's spherical harmonic transform and its transpose as Python functions with the signature required by JAXbind. Additionally, we provide the required abstract evaluation functions.



def _healpix_sht(out, args, kwargs_dump):
    """Forward SHT: synthesize a map from real-packed a_lm into ``out[0]``.

    ``args`` holds the fixed ring geometry (theta, phi0, nphi, ringstart)
    followed by the real-packed coefficients ``x``. The decoded kwargs must
    provide ``spin``, ``lmax``, ``mmax`` and ``nthreads``.
    """
    theta, phi0, nphi, ringstart, x = args
    opts = load_kwargs(kwargs_dump)
    lmax = opts["lmax"]
    # Undo the real packing so DUCC receives complex a_lm of matching precision.
    alm = _realalm2alm(x, lmax, _complextype(x.dtype))
    geometry = dict(theta=theta, phi0=phi0, nphi=nphi, ringstart=ringstart)
    ducc0.sht.synthesis(
        map=out[0],
        alm=alm,
        spin=opts["spin"],
        lmax=lmax,
        mmax=opts["mmax"],
        nthreads=opts["nthreads"],
        **geometry,
    )


def _healpix_sht_T(out, args, kwargs_dump):
    """Transpose of ``_healpix_sht``: map ``x`` -> real-packed a_lm in ``out[0]``.

    ``args`` holds the fixed ring geometry (theta, phi0, nphi, ringstart)
    followed by the map ``x``. The decoded kwargs must provide ``spin``,
    ``lmax``, ``mmax`` and ``nthreads``.
    """
    theta, phi0, nphi, ringstart, x = args
    opts = load_kwargs(kwargs_dump)
    alm = ducc0.sht.adjoint_synthesis(
        map=x,
        theta=theta,
        phi0=phi0,
        nphi=nphi,
        ringstart=ringstart,
        spin=opts["spin"],
        lmax=opts["lmax"],
        mmax=opts["mmax"],
        nthreads=opts["nthreads"],
    )
    # Pack the complex coefficients into the real layout, writing straight
    # into the preallocated output buffer.
    _alm2realalm(alm, opts["lmax"], x.dtype, out[0])


def _healpix_sht_abstract(*args, **kwargs):
    _, _, _, _, x = args
    spin = kwargs["spin"]
    ncomp = 1 if spin == 0 else 2
    shape_out = (ncomp, 12 * kwargs["nside"] ** 2)
    return ((shape_out, x.dtype),)


def _healpix_sht_abstract_T(*args, **kwargs):
    _, _, _, _, x = args
    spin = kwargs["spin"]
    ncomp = 1 if spin == 0 else 2
    lmax, mmax = kwargs["lmax"], kwargs["mmax"]
    nalm = ((mmax + 1) * (mmax + 2)) // 2 + (mmax + 1) * (lmax - mmax)
    nalm = nalm * 2 - lmax - 1
    shape_out = (ncomp, nalm)
    return ((shape_out, x.dtype),)

Now we register the JAX primitive. The spherical harmonic transform is not linear/differentiable in the arguments theta, phi0, nphi, and ringstart. We communicate this to JAXbind via first_n_args_fixed=4. A more detailed example of the usage of first_n_args_fixed is in the demo 03_nonlinear_function.py.


# Register the SHT as a linear JAX primitive: ``_healpix_sht`` is the forward
# map (real-packed a_lm -> map) and ``_healpix_sht_T`` its transpose, each with
# its own abstract evaluation. ``first_n_args_fixed=4`` marks the ring
# geometry (theta, phi0, nphi, ringstart) as fixed, non-differentiable inputs.
_hp_sht = get_linear_call(
    _healpix_sht,
    _healpix_sht_T,
    _healpix_sht_abstract,
    _healpix_sht_abstract_T,
    first_n_args_fixed=4,
)

To the user, we expose a JAX function in which the non-differentiable arguments are already inserted.



def get_healpix_sht(nside, lmax, mmax, spin, nthreads=1):
    """Create a JAX primitive for the ducc0 SHT synthesis on a HEALPix grid.

    Parameters
    ----------
    nside : int
        Resolution parameter of the HEALPix sphere (RING ordering); the map
        has ``12 * nside**2`` pixels.
    lmax, mmax : int
        Maximum l and m moments of the transformation (inclusive). Must
        satisfy ``0 <= mmax <= lmax``.
    spin : int
        Spin of the transformation. For ``spin == 0`` a single component is
        transformed, otherwise two.
    nthreads : int
        Number of threads to use for the computation. If 0, use as many threads
        as there are hardware threads available on the system.

    Returns
    -------
    op : callable
        The JAX primitive of the SHT synthesis with the HEALPix geometry
        already bound. It takes a single real array of shape
        ``(ncomp, 2 * nalm(lmax, mmax) - lmax - 1)`` and returns the map of
        shape ``(ncomp, 12 * nside**2)``, where ``ncomp`` is 1 for spin 0 and
        2 otherwise. In the input, the first ``lmax + 1`` entries per
        component are the (real) m = 0 coefficients. The remaining complex
        coefficients are stored as interleaved (real, imag) pairs multiplied
        by sqrt(2).

    Raises
    ------
    ValueError
        If ``mmax`` is negative or larger than ``lmax``.
    """
    # Reject invalid band limits up front: otherwise the a_lm count used for
    # the abstract shapes becomes meaningless and DUCC only fails at call time.
    if not 0 <= mmax <= lmax:
        raise ValueError(f"need 0 <= mmax <= lmax, got lmax={lmax}, mmax={mmax}")

    hpxparam = ducc0.healpix.Healpix_Base(nside, "RING").sht_info()

    # Bind the fixed ring geometry positionally (these are the
    # first_n_args_fixed arguments) and the transform settings as kwargs;
    # ``nside`` is only needed by the abstract evaluation.
    hpp = partial(
        _hp_sht,
        hpxparam["theta"],
        hpxparam["phi0"],
        hpxparam["nphi"],
        hpxparam["ringstart"],
        lmax=lmax,
        mmax=mmax,
        spin=spin,
        nthreads=nthreads,
        nside=nside,
    )
    return hpp


def nalm(lmax, mmax):
    """Return the number of complex a_lm for band limits ``lmax`` and ``mmax``.

    Counts all pairs ``(l, m)`` with ``0 <= m <= mmax`` and ``m <= l <= lmax``.
    """
    nm = mmax + 1
    # Triangle: pairs with l <= mmax.
    triangle = (nm * (mmax + 2)) // 2
    # Rectangle: for each of the nm values of m, the l in (mmax, lmax].
    rectangle = nm * (lmax - mmax)
    return triangle + rectangle

Binding the DUCC wgridder to JAX

Again, we define Python functions calling the DUCC wgridder and its transpose, along with abstract evaluation functions for the forward and transposed directions.



def _dirty2vis(out, args, kwargs_dump):
    """Forward wgridder: predict visibilities from ``dirty`` into ``out[0]``.

    ``args`` is ``(uvw, freq, dirty)``. All decoded kwargs except
    ``npix_x``/``npix_y`` are forwarded to DUCC's ``dirty2vis``.
    """
    uvw, freq, dirty = args
    params = load_kwargs(kwargs_dump)
    # The image size is carried by ``dirty`` itself; npix_x/npix_y are only
    # consumed by the transpose and its abstract evaluation, so drop them here.
    for key in ("npix_x", "npix_y"):
        params.pop(key)
    ducc0.wgridder.experimental.dirty2vis(
        uvw=uvw, freq=freq, dirty=dirty, vis=out[0], **params
    )


def _dirty2vis_abstract(*args, **kwargs):
    """Abstract evaluation of ``_dirty2vis``: complex visibilities.

    The shape is ``(len(uvw), len(freq))`` and the dtype is the complex type
    matching the precision of the real image.
    """
    uvw, freq, dirty = args
    nrow = uvw.shape[0]
    nchan = freq.shape[0]
    return (((nrow, nchan), _complextype(dirty.dtype)),)


def _vis2dirty(out, args, kwargs_dump):
    """Transpose of ``_dirty2vis``: grid ``vis`` into the real image ``out[0]``.

    ``args`` is ``(uvw, freq, vis)``. All decoded kwargs (including
    ``npix_x``/``npix_y``) are forwarded to DUCC's ``vis2dirty``.
    """
    uvw, freq, vis = args
    params = load_kwargs(kwargs_dump)
    # NOTE(review): vis2dirty presumably applies the adjoint (conjugate
    # transpose) of dirty2vis; conjugating the input turns it into the plain
    # transpose that JAX's linear transposition expects.
    vis_conj = vis.conj()
    ducc0.wgridder.experimental.vis2dirty(
        uvw=uvw, freq=freq, vis=vis_conj, dirty=out[0], **params
    )


def _vis2dirty_abstract(*args, **kwargs):
    """Abstract evaluation of ``_vis2dirty``: a real ``(npix_x, npix_y)`` image.

    The dtype is the real type matching the precision of the visibilities.
    """
    _uvw, _freq, vis = args
    nx, ny = kwargs["npix_x"], kwargs["npix_y"]
    return (((nx, ny), _realtype(vis.dtype)),)

Similarly to the spherical harmonic transform, not all arguments of the wgridder are differentiable. Again, we specify the non-differentiable arguments as fixed via first_n_args_fixed=2.


# Register the wgridder as a linear JAX primitive: ``_dirty2vis`` is the
# forward map (real image -> complex visibilities) and ``_vis2dirty`` its
# transpose, each with its own abstract evaluation. ``first_n_args_fixed=2``
# marks ``uvw`` and ``freq`` as fixed, non-differentiable inputs.
_wgridder = get_linear_call(
    _dirty2vis,
    _vis2dirty,
    _dirty2vis_abstract,
    _vis2dirty_abstract,
    first_n_args_fixed=2,
)


def get_wgridder(
    *,
    pixsize_x,
    pixsize_y,
    npix_x,
    npix_y,
    epsilon,
    do_wgridding,
    nthreads=1,
    flip_v=False,
    verbosity=0,
    **kwargs,
):
    """Create a JAX primitive for the ducc0 wgridder

    Parameters
    ----------
    pixsize_x, pixsize_y : float
        Size of the pixels in radians.
    npix_x, npix_y : int
        Number of pixels along the x and y axes of the image.
    epsilon : float
        Requested accuracy of the wgridder evaluation.
    do_wgridding : bool
        Forwarded to DUCC. If True, the full w-gridding algorithm is applied;
        otherwise the w coordinates are treated as zero.
    nthreads : int
        Sets the number of threads used for evaluation. Default 1.
    flip_v : bool
        Whether or not to flip the v coordinate of the visibilities. Default
        `False`.
    verbosity : int
        Sets the verbosity of the wgridder. For 0 no print out, for >0 verbose
        output. Default 0.
    **kwargs : dict
        Additional keyword arguments forwarded to the ducc wgridder. The
        names ``uvw``, ``freq``, ``vis`` and ``dirty`` are reserved.

    Returns
    -------
    op : callable
        JAX primitive evaluating the ducc wgridder with signature
        `(uvw, freq, image)`. `uvw` is an (N, 3) array of the uvw coordinates
        of the visibilities in meters, `freq` is a 1D array of frequencies in
        Hertz, and `image` is a real 2D array of shape `(npix_x, npix_y)`
        holding the sky brightness in Jansky per steradian. The op returns
        complex visibilities of shape `(N, len(freq))` whose precision matches
        `image`. Only `image` is differentiable; `uvw` and `freq` are fixed.

    Raises
    ------
    TypeError
        If a reserved name is passed via ``**kwargs``.
    """
    # The forward/transpose callbacks pass uvw, freq, vis and dirty to DUCC
    # explicitly alongside **kwargs. The same name in kwargs would only fail
    # later, deep inside the JAX callback, with "multiple values".
    reserved = sorted({"uvw", "freq", "vis", "dirty"}.intersection(kwargs))
    if reserved:
        raise TypeError(
            "get_wgridder() got reserved keyword argument(s): "
            + ", ".join(reserved)
        )

    wgridder = partial(
        _wgridder,
        pixsize_x=pixsize_x,
        pixsize_y=pixsize_y,
        npix_x=npix_x,
        npix_y=npix_y,
        epsilon=epsilon,
        do_wgridding=do_wgridding,
        nthreads=nthreads,
        flip_v=flip_v,
        verbosity=verbosity,
        **kwargs,
    )
    return wgridder