jaxbind.contrib.jaxducc0 module

c2c(a: numpy.ndarray, axes: list[int] | None = None, forward: bool = True, inorm: int = 0, out: numpy.ndarray | None = None, nthreads: int = 1) → numpy.ndarray

Performs a complex FFT.

Parameters:
  • a (numpy.ndarray (any complex or real type)) – The input data. If its type is real, a more efficient real-to-complex transform will be used.

  • axes (list of integers) – The axes along which the FFT is carried out (first axis has number 0). If not set, all axes will be transformed.

  • forward (bool) – If True, a negative sign is used in the exponent, else a positive one.

  • inorm (int) –

    Normalization type
    0 : no normalization
    1 : divide by sqrt(N)
    2 : divide by N

    where N is the product of the lengths of the transformed axes.

  • out (numpy.ndarray (same shape as a, complex type with same accuracy as a)) – May be identical to a, but if it isn’t, it must not overlap with a. If None, a new array is allocated to store the output.

  • nthreads (int) – Number of threads to use. If 0, use the system default (typically the number of hardware threads on the compute node).

Returns:

The transformed data.

Return type:

numpy.ndarray (same shape as a, complex type with same accuracy as a)
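The normalization and output-type rules above can be sketched in plain NumPy. The helpers c2c_scale and c2c_out_dtype are hypothetical and not part of jaxducc0. The dtype rule assumes that "same accuracy" means NumPy's usual pairing of float32 with complex64 and float64 with complex128.

```python
import numpy as np


def c2c_scale(shape, axes=None, inorm=0):
    """Factor multiplying the unnormalized sum for a given inorm (hypothetical helper)."""
    if axes is None:
        axes = range(len(shape))
    # N is the product of the lengths of the transformed axes only
    n = int(np.prod([shape[ax] for ax in axes]))
    return {0: 1.0, 1: 1.0 / np.sqrt(n), 2: 1.0 / n}[inorm]


def c2c_out_dtype(in_dtype):
    """Complex dtype with the same precision as the input (assumed rule)."""
    # float32 -> complex64, float64 -> complex128, complex types unchanged
    return np.result_type(in_dtype, np.complex64)
```

For example, c2c_scale(a.shape, axes, 2) * np.fft.fftn(a, axes=axes) should correspond to a forward transform with inorm=2 under the conventions given in the Notes.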

Notes

For one-dimensional arrays of length N, this function computes, for all k = 0, \dots, N-1,

Y_k = \frac{1}{\sqrt{N}^{\textrm{inorm}}} \sum_{j=0}^{N-1} X_j e^{s 2\pi i \frac{j k}{N}}

where

s = \begin{cases} -1 & \text{if forward} \\ +1 & \text{otherwise} \end{cases}
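To make the one-dimensional formula concrete, the sketch below evaluates the sum directly in O(N²) time and checks it against numpy.fft. The name dft_1d is hypothetical; it is a reference for the definition only, not how ducc0 computes the transform.

```python
import numpy as np


def dft_1d(x, forward=True, inorm=0):
    """Direct evaluation of Y_k = sqrt(N)^(-inorm) * sum_j X_j exp(s 2 pi i j k / N)."""
    x = np.asarray(x, dtype=np.complex128)
    n = x.shape[0]
    s = -1.0 if forward else 1.0
    idx = np.arange(n)
    # Kernel matrix K[k, j] = exp(s * 2*pi*i * j*k / N)
    kernel = np.exp(s * 2j * np.pi * np.outer(idx, idx) / n)
    return kernel @ x / np.sqrt(n) ** inorm
```

In NumPy terms, forward=True with inorm=0 corresponds to np.fft.fft, and forward=False with inorm=2 corresponds to np.fft.ifft.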

For multi-dimensional arrays, the function computes one-dimensional transforms along each of the specified axes in turn. For instance, for a two-dimensional array X of shape (N, M) with axes=(0, 1), it computes the array Z of the same shape as:

Y_{k,p} = \frac{1}{\sqrt{N}^{\textrm{inorm}}} \sum_{j=0}^{N-1} X_{j,p} e^{s 2\pi i \frac{j k}{N}} \\ Z_{k,q} = \frac{1}{\sqrt{M}^{\textrm{inorm}}} \sum_{p=0}^{M-1} Y_{k,p} e^{s 2\pi i \frac{p q}{M}}
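The axis-by-axis definition can be checked against NumPy's multi-dimensional FFT. This sketch scales each axis by the square root of its length raised to inorm; the product of these factors reproduces the division by \sqrt{N}^{\textrm{inorm}} with N the product of the transformed lengths. The name c2c_sequential is hypothetical.

```python
import numpy as np


def c2c_sequential(x, axes, forward=True, inorm=0):
    """Apply 1D transforms one axis at a time, each scaled by sqrt(len)^(-inorm)."""
    y = np.asarray(x, dtype=np.complex128)
    for ax in axes:
        n = y.shape[ax]
        # Unnormalized 1D transform along `ax`: s = -1 (fft) or s = +1 (ifft without 1/n)
        if forward:
            y = np.fft.fft(y, axis=ax)
        else:
            y = np.fft.ifft(y, axis=ax, norm="forward")
        y = y / np.sqrt(n) ** inorm
    return y
```

Because the one-dimensional transforms along different axes commute, the order in which the axes are listed does not change the result.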

genuine_fht(a: numpy.ndarray, axes: list[int] | None = None, inorm: int = 0, out: numpy.ndarray | None = None, nthreads: int = 1) → numpy.ndarray

Performs a full Hartley transform. A full forward Fourier transform is carried out over the requested axes, and the real part minus the imaginary part of the result is stored in the output array. For a single transformed axis, this is identical to separable_fht, but when transforming multiple axes, the results are different.
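The difference between the full and the separable Hartley transform can be illustrated in NumPy. Both functions are hypothetical sketches. In particular, hartley_separable assumes that separable_fht takes the real part minus the imaginary part after every one-dimensional transform, which this page does not state.

```python
import numpy as np


def hartley_genuine(a, axes):
    """Full forward FFT over all requested axes at once, then Re - Im."""
    y = np.fft.fftn(a, axes=axes)
    return y.real - y.imag


def hartley_separable(a, axes):
    """Assumed separable variant: Re - Im after each 1D forward FFT."""
    out = np.asarray(a, dtype=np.float64)
    for ax in axes:
        y = np.fft.fft(out, axis=ax)
        out = y.real - y.imag
    return out
```

The mismatch for two or more axes follows from the identity cas(α)cas(β) = cas(α+β) + 2 sin α sin β, where cas(θ) = cos θ + sin θ is the Hartley kernel.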

Parameters:
  • a (numpy.ndarray (any real type)) – The input data.

  • axes (list of integers) – The axes along which the transform is carried out. If not set, all axes will be transformed.

  • inorm (int) –

    Normalization type
    0 : no normalization
    1 : divide by sqrt(N)
    2 : divide by N

    where N is the product of the lengths of the transformed axes.

  • out (numpy.ndarray (same shape and data type as a)) – May be identical to a, but if it isn’t, it must not overlap with a. If None, a new array is allocated to store the output.

  • nthreads (int) – Number of threads to use. If 0, use the system default (typically the number of hardware threads on the compute node).

Returns:

The transformed data.

Return type:

numpy.ndarray (same shape and data type as a)

Notes

Mathematically, this function performs exactly the same operations as c2c() (with forward=True), but returns a real-valued array containing \Re(Y)-\Im(Y), where Y is the output of c2c().
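Following this note, a reference version can be written with NumPy: take the forward FFT over the chosen axes, apply the inorm scaling, and keep the real part minus the imaginary part. The name genuine_fht_reference is hypothetical and used only for illustration. A convenient property for testing is that the full Hartley transform is its own inverse up to a factor N, so with inorm=1 applying it twice returns the input.

```python
import numpy as np


def genuine_fht_reference(a, axes=None, inorm=0):
    """Re(Y) - Im(Y), where Y is the forward FFT of `a` over `axes` with inorm scaling."""
    a = np.asarray(a, dtype=np.float64)
    if axes is None:
        axes = tuple(range(a.ndim))
    # N is the product of the lengths of the transformed axes
    n = int(np.prod([a.shape[ax] for ax in axes]))
    y = np.fft.fftn(a, axes=axes) / np.sqrt(n) ** inorm
    return y.real - y.imag
```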

get_healpix_sht(nside, lmax, mmax, spin, nthreads=1)

Create a JAX primitive for the ducc0 spherical harmonic transform (SHT) synthesis on a HEALPix grid.

Parameters:
  • nside (int) – Resolution parameter of the HEALPix grid; the map has 12 * nside**2 pixels.

  • lmax (int) – Maximum l moment of the transformation (inclusive).

  • mmax (int) – Maximum m moment of the transformation (inclusive).

  • spin (int) – Spin to use for the transformation.

  • nthreads (int) – Number of threads to use for the computation. If 0, use as many threads as there are hardware threads available on the system.

Returns:

op – The JAX primitive of the SHT synthesis for HEALPix.

Return type:

JAX primitive
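The array sizes involved can be computed ahead of time. A HEALPix map with parameter nside has 12 * nside**2 pixels. The sketch below also assumes the ducc0 SHT convention of one component for spin 0 and two components for nonzero spin, stored as an array of shape (ncomp, npix); whether the jaxbind primitive uses exactly this layout is an assumption to check against the actual op. All function names are hypothetical.

```python
def healpix_npix(nside):
    """Number of pixels of a HEALPix map: 12 * nside**2."""
    if nside < 1:
        raise ValueError("nside must be a positive integer")
    return 12 * nside * nside


def sht_ncomp(spin):
    """Assumed component count: 1 for spin 0, 2 for nonzero spin."""
    return 1 if spin == 0 else 2


def healpix_map_shape(nside, spin):
    # Assumed (ncomp, npix) layout of the synthesized map
    return (sht_ncomp(spin), healpix_npix(nside))
```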

get_wgridder(*, pixsize_x, pixsize_y, npix_x, npix_y, epsilon, do_wgridding, nthreads=1, flip_v=False, verbosity=0, **kwargs)

Create a JAX primitive for the ducc0 wgridder.

Parameters:
  • pixsize_x (float) – Pixel size along the x direction, in radians.

  • pixsize_y (float) – Pixel size along the y direction, in radians.

  • npix_x (int) – Number of pixels along the x direction.

  • npix_y (int) – Number of pixels along the y direction.

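  • epsilon (float) – Sets the required accuracy of the wgridder evaluation.

  • do_wgridding (bool) – If True, the full w-gridding algorithm is carried out; otherwise the w coordinates are assumed to be zero.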

  • nthreads (int) – Sets the number of threads used for evaluation. Default 1.

  • flip_v (bool) – Whether or not to flip the v coordinate of the visibilities. Default False.

  • verbosity (int) – Sets the verbosity of the wgridder: 0 produces no output, values greater than 0 produce verbose output. Default 0.

  • **kwargs (dict) – Additional keyword arguments forwarded to the ducc0 wgridder.

Returns:

op – The JAX primitive, with signature (uvw, freq, image): uvw is an (N, 3) array holding the uvw coordinates of the visibilities in meters, freq is a 1D array of frequencies in hertz, and image is a 2D array of shape (npix_x, npix_y) holding the sky brightness in Jansky per steradian.

Return type:

JAX primitive evaluating the ducc wgridder.
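To make the roles of uvw, freq and image concrete, here is a brute-force direct Fourier evaluation of the radio-interferometric measurement equation, which a wgridder approximates (to within epsilon) far more efficiently. This is a sketch under stated assumptions, not ducc0's algorithm: the function name direct_vis, the pixel-centre convention, the sign of the exponent and the 1/n factor are assumptions and may differ from ducc0's exact conventions. It is only practical for tiny images.

```python
import numpy as np

SPEED_OF_LIGHT = 299_792_458.0  # m/s


def direct_vis(uvw, freq, image, pixsize_x, pixsize_y):
    """Brute force V[row, chan] = sum I / n * exp(-2 pi i (u l + v m + w (n - 1))).

    Assumed conventions (may differ from ducc0): the pixel at index
    (npix_x // 2, npix_y // 2) lies at l = m = 0, the exponent sign is
    negative, and flip_v is not applied.
    """
    uvw = np.asarray(uvw, dtype=np.float64)    # (nrow, 3), meters
    freq = np.asarray(freq, dtype=np.float64)  # (nchan,), Hz
    image = np.asarray(image, dtype=np.float64)
    npix_x, npix_y = image.shape
    # Direction cosines of the pixel centres
    l_dir = (np.arange(npix_x) - npix_x // 2) * pixsize_x
    m_dir = (np.arange(npix_y) - npix_y // 2) * pixsize_y
    ll, mm = np.meshgrid(l_dir, m_dir, indexing="ij")  # (npix_x, npix_y)
    nn = np.sqrt(1.0 - ll**2 - mm**2)
    # Convert baselines from meters to wavelengths: (nrow, nchan, 3)
    uvw_wl = uvw[:, None, :] * freq[None, :, None] / SPEED_OF_LIGHT
    # Add two trailing pixel axes so each coordinate broadcasts against the grid
    u, v, w = (uvw_wl[..., i, None, None] for i in range(3))
    phase = np.exp(-2j * np.pi * (u * ll + v * mm + w * (nn - 1.0)))
    return np.sum(image / nn * phase, axis=(-2, -1))  # (nrow, nchan)
```

A single point source at the phase centre gives the same visibility on every baseline and channel, and a real image gives Hermitian-symmetric visibilities, V(-uvw) = conj(V(uvw)); both are convenient sanity checks when comparing against the primitive.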

nalm(lmax, mmax)

Compute the number of a_lm coefficients for a given maximum l and m moment of the SHT.
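Assuming nalm counts only coefficients with m ≥ 0 (the usual packing for real-valued maps in ducc0 and healpy, taken here as an assumption), the count is the sum over m = 0, …, mmax of (lmax - m + 1). The hypothetical helpers below give the closed form and a brute-force check.

```python
def nalm_closed_form(lmax, mmax):
    """Number of a_lm with 0 <= m <= mmax and m <= l <= lmax."""
    if not 0 <= mmax <= lmax:
        raise ValueError("require 0 <= mmax <= lmax")
    # (mmax+1)(mmax+2) is always even, so integer division is exact
    return (mmax + 1) * (mmax + 2) // 2 + (mmax + 1) * (lmax - mmax)


def nalm_brute_force(lmax, mmax):
    # Count each (l, m) pair explicitly
    return sum(1 for m in range(mmax + 1) for ell in range(m, lmax + 1))
```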