jaxbind.contrib.jaxducc0 module
- c2c(a: numpy.ndarray, axes: list[int] | None = None, forward: bool = True, inorm: int = 0, out: numpy.ndarray | None = None, nthreads: int = 1) → numpy.ndarray
Performs a complex FFT.
- Parameters:
a (numpy.ndarray (any complex or real type)) – The input data. If its type is real, a more efficient real-to-complex transform will be used.
axes (list of integers) – The axes along which the FFT is carried out (first axis has number 0). If not set, all axes will be transformed.
forward (bool) – If True, a negative sign is used in the exponent, else a positive one.
inorm (int) – Normalization type, where N is the product of the lengths of the transformed axes:
- 0 : no normalization
- 1 : divide by sqrt(N)
- 2 : divide by N
out (numpy.ndarray (same shape as a, complex type with same accuracy as a)) – May be identical to a, but if it isn’t, it must not overlap with a. If None, a new array is allocated to store the output.
nthreads (int) – Number of threads to use. If 0, use the system default (typically the number of hardware threads on the compute node).
- Returns:
The transformed data.
- Return type:
numpy.ndarray (same shape as a, complex type with same accuracy as a)
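When checking results, it can help to have a NumPy stand-in for c2c. The sketch below is a minimal version built on numpy.fft. It maps the (forward, inorm) pair onto numpy's `norm` argument so that the output is scaled by 1/sqrt(N)**inorm, following the parameter description above. The helper name `c2c_reference` is hypothetical and is not part of jaxbind or ducc0. It also ignores `out` and `nthreads`.

```python
import numpy as np

# numpy.fft.fftn uses exp(-2*pi*i*jk/N) (s = -1); ifftn uses exp(+2*pi*i*jk/N) (s = +1).
# These tables pick numpy's `norm` so the transform is scaled by 1/sqrt(N)**inorm.
_FORWARD_NORM = {0: "backward", 1: "ortho", 2: "forward"}
_BACKWARD_NORM = {0: "forward", 1: "ortho", 2: "backward"}


def c2c_reference(a, axes=None, forward=True, inorm=0):
    """Hypothetical NumPy stand-in for c2c (sketch only, not the ducc0 code path)."""
    if inorm not in (0, 1, 2):
        raise ValueError("inorm must be 0, 1 or 2")
    a = np.asarray(a)
    if axes is None:
        axes = list(range(a.ndim))  # default: transform all axes
    if forward:
        return np.fft.fftn(a, axes=axes, norm=_FORWARD_NORM[inorm])
    return np.fft.ifftn(a, axes=axes, norm=_BACKWARD_NORM[inorm])
```

For example, a forward transform with inorm=0 followed by a backward transform with inorm=2 recovers the input. The normalization factors multiply to 1/N.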
Notes
For one-dimensional arrays of length N, this function computes, for all k = 0, \dots, N-1:
Y_k = \frac{1}{\sqrt{N}^{\textrm{inorm}}} \sum_{j=0}^{N-1} X_j \, e^{s \, 2\pi i \frac{j k}{N}}
where
s = \begin{cases} -1 & \text{if forward} \\ +1 & \text{otherwise} \end{cases}
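The one-dimensional formula can be checked numerically with a direct O(N²) evaluation. This is an illustrative sketch rather than how ducc0 computes the transform. The function name `dft_1d` is hypothetical.

```python
import numpy as np


def dft_1d(x, forward=True, inorm=0):
    """Direct evaluation of Y_k = sqrt(N)**(-inorm) * sum_j X_j exp(s*2*pi*i*j*k/N)."""
    x = np.asarray(x, dtype=complex)
    n = x.shape[0]
    s = -1 if forward else +1
    j = np.arange(n)             # input index, varies along columns
    k = j.reshape(-1, 1)         # output index, varies along rows
    w = np.exp(s * 2j * np.pi * k * j / n)  # w[k, j]
    return (w @ x) / np.sqrt(n) ** inorm
```

With forward=True and inorm=0 this matches numpy.fft.fft. With forward=False and inorm=2 it matches numpy.fft.ifft.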
For multi-dimensional arrays, the function computes one-dimensional transforms along each of the specified axes in turn. For instance, for a two-dimensional array X of shape (N, M) with axes=(0, 1), this function computes the two-dimensional array Z of the same shape as
Y_{k,p} = \frac{1}{\sqrt{N}^{\textrm{inorm}}} \sum_{j=0}^{N-1} X_{j,p} \, e^{s \, 2\pi i \frac{j k}{N}}
Z_{k,q} = \frac{1}{\sqrt{M}^{\textrm{inorm}}} \sum_{p=0}^{M-1} Y_{k,p} \, e^{s \, 2\pi i \frac{p q}{M}}
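The separable structure can be illustrated by applying a 1D transform along one axis at a time. The per-axis factors 1/sqrt(N)**inorm and 1/sqrt(M)**inorm multiply to 1/sqrt(NM)**inorm. This agrees with N being the product of the transformed axis lengths in the inorm description. The sketch below is illustrative only. The helper names `dft_along_axis` and `separable_dft` are hypothetical.

```python
import numpy as np


def dft_along_axis(x, axis, forward=True, inorm=0):
    """Apply the 1D transform (with its own 1/sqrt(len)**inorm factor) along one axis."""
    x = np.moveaxis(np.asarray(x, dtype=complex), axis, 0)
    n = x.shape[0]
    s = -1 if forward else +1
    idx = np.arange(n)
    w = np.exp(s * 2j * np.pi * np.outer(idx, idx) / n)  # w[k, j]
    # Contract w's input index j with the leading axis of x.
    y = np.tensordot(w, x, axes=(1, 0)) / np.sqrt(n) ** inorm
    return np.moveaxis(y, 0, axis)


def separable_dft(x, axes, forward=True, inorm=0):
    """Transform the given axes one after another, as described above."""
    for ax in axes:
        x = dft_along_axis(x, ax, forward=forward, inorm=inorm)
    return x
```

Because each 1D step acts on a different axis, the order of the axes does not change the result.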
- genuine_fht(a: numpy.ndarray, axes: list[int] | None = None, inorm: int = 0, out: numpy.ndarray | None = None, nthreads: int = 1) → numpy.ndarray
Performs a full Hartley transform. A full forward Fourier transform is carried out over the requested axes, and the real part minus the imaginary part of the result is stored in the output array. For a single transformed axis, this is identical to separable_fht, but when transforming multiple axes, the results are different.
- Parameters:
a (numpy.ndarray (any real type)) – The input data
axes (list of integers) – The axes along which the transform is carried out. If not set, all axes will be transformed.
inorm (int) – Normalization type, where N is the product of the lengths of the transformed axes:
- 0 : no normalization
- 1 : divide by sqrt(N)
- 2 : divide by N
out (numpy.ndarray (same shape and data type as a)) – May be identical to a, but if it isn’t, it must not overlap with a. If None, a new array is allocated to store the output.
nthreads (int) – Number of threads to use. If 0, use the system default (typically the number of hardware threads on the compute node).
- Returns:
The transformed data
- Return type:
numpy.ndarray (same shape and data type as a)
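The difference between the full and the separable Hartley transform can be made concrete with NumPy. The sketch below follows the description of genuine_fht above: it takes the real part minus the imaginary part of a full forward FFT, without normalization (inorm=0). For the separable variant it assumes that the same "real minus imaginary" step is applied after each 1D forward FFT, one axis at a time. Both helper names are hypothetical.

```python
import numpy as np


def genuine_fht_reference(a, axes=None):
    """Full Hartley transform: Re(F) - Im(F) of the multi-dimensional forward FFT."""
    f = np.fft.fftn(np.asarray(a, dtype=float), axes=axes)
    return f.real - f.imag


def separable_fht_reference(a, axes=None):
    """Assumed separable variant: a 1D Hartley transform applied axis by axis."""
    out = np.asarray(a, dtype=float)
    if axes is None:
        axes = range(out.ndim)
    for ax in axes:
        f = np.fft.fft(out, axis=ax)
        out = f.real - f.imag  # real result feeds the next axis
    return out
```

For one axis both functions agree. For several axes they differ, because the full transform uses the kernel cas(θ₁ + θ₂) whereas the separable one uses cas(θ₁)·cas(θ₂), with cas(θ) = cos(θ) + sin(θ). The full transform is its own inverse up to a factor equal to the number of elements.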
- get_healpix_sht(nside, lmax, mmax, spin, nthreads=1)
Create a JAX primitive for the ducc0 spherical harmonic transform (SHT) synthesis on a HEALPix grid.
- Parameters:
nside (int) – Resolution parameter of the HEALPix grid; the sphere is divided into 12·nside² pixels.
lmax (int) – Maximum l moment of the transformation (inclusive).
mmax (int) – Maximum m moment of the transformation (inclusive).
spin (int) – Spin to use for the transformation.
nthreads (int) – Number of threads to use for the computation. If 0, use as many threads as there are hardware threads available on the system.
- Returns:
op – The JAX primitive of the SHT synthesis for HEALPix.
- Return type:
JAX primitive
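Before calling the primitive it helps to know the sizes involved. The sketch below computes the number of HEALPix pixels (12·nside²). It also counts the spherical harmonic coefficients a_lm for the triangular truncation 0 ≤ m ≤ mmax, m ≤ l ≤ lmax. Two points are assumptions not stated above: that the primitive uses ducc0's packed a_lm layout, and that it follows ducc0's convention of one component for spin 0 and two components otherwise. All helper names are hypothetical.

```python
def healpix_npix(nside: int) -> int:
    """Number of pixels on a HEALPix sphere."""
    return 12 * nside * nside


def n_alm(lmax: int, mmax: int) -> int:
    """Number of a_lm with 0 <= m <= mmax and m <= l <= lmax."""
    if not 0 <= mmax <= lmax:
        raise ValueError("need 0 <= mmax <= lmax")
    # sum over m of (lmax - m + 1), in closed form
    return ((mmax + 1) * (mmax + 2)) // 2 + (mmax + 1) * (lmax - mmax)


def n_components(spin: int) -> int:
    """Assumed ducc0 convention: 1 component for spin 0, 2 for spin > 0."""
    return 1 if spin == 0 else 2
```

For example, nside=16 gives 3072 pixels, and lmax=mmax=3 gives 4 + 3 + 2 + 1 = 10 coefficients.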
- get_wgridder(*, pixsize_x, pixsize_y, npix_x, npix_y, epsilon, do_wgridding, nthreads=1, flip_v=False, verbosity=0, **kwargs)
Create a JAX primitive for the ducc0 wgridder.
- Parameters:
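pixsize_x (float) – Angular size of the pixels along the first image axis, in radians.
pixsize_y (float) – Angular size of the pixels along the second image axis, in radians.
npix_x (int) – Number of pixels along the first image axis.
npix_y (int) – Number of pixels along the second image axis.
epsilon (float) – Requested accuracy of the wgridder evaluation.
do_wgridding (bool) – If True, the w-term is taken into account (full w-gridding); if False, the w coordinates are treated as zero.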
nthreads (int) – Sets the number of threads used for evaluation. Default 1.
flip_v (bool) – Whether or not to flip the v coordinate of the visibilities. Default False.
verbosity (int) – Verbosity of the wgridder: 0 for no output, >0 for verbose output. Default 0.
**kwargs (dict) – Additional keyword arguments forwarded to the ducc0 wgridder.
- Returns:
op – A JAX primitive with the signature (uvw, freq, image). Here uvw is an (N, 3) array holding the uvw coordinates of the visibilities in meters, freq is a 1D array of frequencies in hertz, and image is a 2D array of shape (npix_x, npix_y) holding the sky brightness in janskys per steradian.
- Return type:
JAX primitive evaluating the ducc0 wgridder.
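The units above are easy to mix up, so the sketch below only prepares inputs in the documented shapes and units and does not call the wgridder. It converts a pixel size from arcseconds to radians. It then checks it against the common Nyquist rule of thumb: the pixel size should not exceed 1/(2·u_max), where u_max is the largest |u| or |v| in wavelengths (baseline in meters times frequency divided by c). This rule is a general imaging guideline, not something the wgridder enforces. The helper names and the synthetic input arrays are hypothetical.

```python
import numpy as np

SPEED_OF_LIGHT = 299_792_458.0  # m/s


def arcsec_to_rad(arcsec: float) -> float:
    """Convert an angle in arcseconds to radians."""
    return float(np.deg2rad(arcsec / 3600.0))


def max_pixsize_rad(uvw_m: np.ndarray, freq_hz: np.ndarray) -> float:
    """Rule-of-thumb upper bound on the pixel size: 1 / (2 * max |u|,|v| in wavelengths)."""
    uv_max_m = np.max(np.abs(uvw_m[:, :2]))
    uv_max_lambda = uv_max_m * np.max(freq_hz) / SPEED_OF_LIGHT
    return 1.0 / (2.0 * uv_max_lambda)


# Synthetic inputs in the documented shapes and units (illustrative only).
rng = np.random.default_rng(42)
uvw = rng.uniform(-1000.0, 1000.0, size=(100, 3))  # (N, 3) coordinates in meters
freq = np.linspace(1.0e9, 1.4e9, 4)                # 1D frequencies in Hz
npix_x, npix_y = 256, 256
image = np.zeros((npix_x, npix_y))                 # sky brightness in Jy/sr
pixsize = arcsec_to_rad(1.0)                       # would be passed as pixsize_x/pixsize_y
```

With baselines up to 1 km at 1.4 GHz the bound is roughly 1e-4 rad (about 22 arcsec), so 1 arcsec pixels oversample comfortably.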