API Reference
Complete reference for all public functions in nufftax.
Overview
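nufftax provides three types of nonuniform FFT (NUFFT):

| Type | Direction | Functions |
|---|---|---|
| Type 1 | Nonuniform → Uniform | `nufft1d1`, `nufft2d1`, `nufft3d1` |
| Type 2 | Uniform → Nonuniform | `nufft1d2`, `nufft2d2`, `nufft3d2` |
| Type 3 | Nonuniform → Nonuniform | `nufft1d3`, `nufft2d3`, `nufft3d3` |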
JAX Transformation Support
All transform functions support JAX transformations. Apply jax.grad, jax.vjp, or jax.jvp to them directly:
| Transform | | | | |
|---|---|---|---|---|
| Type 1 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 2 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 3 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
Differentiable inputs:
| Type | Differentiable w.r.t. |
|---|---|
| Type 1 | |
| Type 2 | |
| Type 3 | |
Example:
```python
import jax
import jax.numpy as jnp
from nufftax import nufft1d1

x = jnp.array([0.1, 0.5, 1.0, 2.0])
c = jnp.array([1+1j, 2-1j, 0.5, 1j])

# Gradient w.r.t. strengths (real-valued loss, complex input)
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, 32))**2))(c)

# Gradient w.r.t. coordinates
grad_x = jax.grad(lambda x: jnp.sum(jnp.abs(nufft1d1(x, c, 32))**2))(x)

# Forward-mode AD (JVP): returns the output and its directional derivative
f, f_dot = jax.jvp(
    lambda c: nufft1d1(x, c, 32),
    (c,), (jnp.ones_like(c),)
)

# Reverse-mode AD (VJP): pull a cotangent back through the transform
f, vjp_fn = jax.vjp(lambda c: nufft1d1(x, c, 32), c)
(c_bar,) = vjp_fn(jnp.ones_like(f))
```
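To make "gradient w.r.t. coordinates" concrete, the sketch below differentiates a direct (slow, O(NM)) evaluation of the Type 1 sum with jax.grad and compares it with the hand-derived derivative. It uses only jax.numpy, assumes the isign = +1 convention and 16 modes, and does not call nufftax; `power` is a hypothetical name used for illustration.

```python
import jax
import jax.numpy as jnp

N_MODES = 16
# Mode indices k = -8, ..., 7
K = jnp.arange(-(N_MODES // 2), N_MODES - N_MODES // 2)


def power(x, c):
    # Direct Type 1 sum f_k = sum_j c_j exp(i k x_j), then L = sum_k |f_k|^2.
    f = jnp.exp(1j * K[:, None] * x[None, :]) @ c
    return jnp.sum(jnp.abs(f) ** 2)


x = jnp.array([0.1, 0.5, 1.0, 2.0])
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j])

# Autodiff gradient of the real-valued loss w.r.t. the real coordinates.
grad_auto = jax.grad(power)(x, c)

# Hand-derived: df_k/dx_j = i k c_j exp(i k x_j) and dL = sum_k 2 Re(conj(f_k) df_k).
E = jnp.exp(1j * K[:, None] * x[None, :])
f = E @ c
grad_hand = jnp.sum(
    2 * jnp.real(jnp.conj(f)[:, None] * 1j * K[:, None] * E * c[None, :]),
    axis=0,
)
```

The two gradients agree up to float32 round-off, which is the same derivative that jax.grad propagates through the fast transform.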
Type 1: Nonuniform → Uniform
Compute Fourier coefficients from scattered data.
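Mathematical definition:
\[ f_k = \sum_{j=1}^{M} c_j \, e^{\pm i k x_j}, \]
where \(k = -N/2, \ldots, N/2-1\), \(M\) is the number of nonuniform points, \(N\) is `n_modes`, and the sign of the exponent is set by `isign`.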
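A direct O(NM) evaluation of this sum is useful as a reference when checking accuracy on small inputs. A minimal sketch, assuming output modes are returned in increasing-k order; `direct_type1` is a hypothetical helper, not part of nufftax:

```python
import jax.numpy as jnp


def direct_type1(x, c, n_modes, isign=1):
    """Naive O(N*M) evaluation of f_k = sum_j c_j exp(isign * i * k * x_j)."""
    # Mode indices k = -N/2, ..., N/2 - 1 (for even N), in increasing order.
    k = jnp.arange(-(n_modes // 2), n_modes - n_modes // 2)
    # (N, M) matrix of complex exponentials, one row per mode k.
    phase = jnp.exp(isign * 1j * k[:, None] * x[None, :])
    return phase @ c


# Sanity check: on the uniform grid x_j = 2*pi*j/N, Type 1 with isign=+1
# reduces to N * ifft(c), with modes reindexed from k to k mod N.
n = 8
x = 2 * jnp.pi * jnp.arange(n) / n
c = jnp.arange(n) + 1j * jnp.ones(n)
f = direct_type1(x, c, n)
k = jnp.arange(-(n // 2), n - n // 2)
expected = n * jnp.fft.ifft(c)[k % n]
max_err = jnp.max(jnp.abs(f - expected))  # float32 round-off only
```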
1D Transform
Example:
```python
from nufftax import nufft1d1
import jax.numpy as jnp

x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
c = jnp.array([1+1j, 2-1j, 0.5, 1j, -1+0.5j])
f = nufft1d1(x, c, n_modes=64, eps=1e-6, isign=1)
# f.shape = (64,)
```
2D Transform
Example:
```python
import jax.numpy as jnp
from nufftax import nufft2d1

x = jnp.array([0.1, 0.5, 1.0])
y = jnp.array([0.2, -0.3, 0.8])
c = jnp.array([1+1j, 2-1j, 0.5])
f = nufft2d1(x, y, c, n_modes=(32, 32), eps=1e-6)
# f.shape = (32, 32)
```
3D Transform
Type 2: Uniform → Nonuniform
Evaluate Fourier series at scattered points.
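Mathematical definition:
\[ c_j = \sum_{k=-N/2}^{N/2-1} f_k \, e^{\pm i k x_j}, \quad j = 1, \ldots, M, \]
where the sign of the exponent is set by `isign`.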
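The Type 2 sum can likewise be evaluated directly for testing. A minimal sketch, assuming `f` is stored in increasing-k order (k = -N/2, ..., N/2-1); `direct_type2` is a hypothetical helper, not part of nufftax:

```python
import jax.numpy as jnp


def direct_type2(x, f, isign=-1):
    """Naive O(N*M) evaluation of c_j = sum_k f_k exp(isign * i * k * x_j)."""
    n_modes = f.shape[0]
    # Assumed storage order of f: k = -N/2, ..., N/2 - 1.
    k = jnp.arange(-(n_modes // 2), n_modes - n_modes // 2)
    # (M, N) matrix: one row per nonuniform point x_j.
    phase = jnp.exp(isign * 1j * x[:, None] * k[None, :])
    return phase @ f


x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
f = jnp.ones(64, dtype=jnp.complex64)
c = direct_type2(x, f)
# c.shape == (5,)
```

With opposite signs, Type 2 is the adjoint of Type 1: for any `c` and `f`, the inner product of Type 1(`c`, isign=+1) with `f` equals the inner product of `c` with Type 2(`f`, isign=-1). This is a convenient correctness check for either transform.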
1D Transform
Example:
```python
import jax.numpy as jnp
from nufftax import nufft1d2

x = jnp.array([0.1, 0.5, 1.0, 2.0, -1.5])
f = jnp.ones(64, dtype=jnp.complex64)
c = nufft1d2(x, f, eps=1e-6, isign=-1)
# c.shape = (5,)
```
2D Transform
3D Transform
Type 3: Nonuniform → Nonuniform
Transform between two sets of nonuniform points.
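Mathematical definition:
\[ f_k = \sum_{j=1}^{M} c_j \, e^{\pm i s_k x_j}, \quad k = 1, \ldots, N, \]
where \(s_k\) are the target frequencies and the sign of the exponent is set by `isign`.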
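A direct evaluation of the Type 3 sum, for testing on small inputs. This is a sketch: `direct_type3` is a hypothetical helper, and its default `isign=1` is an assumption, not nufftax's documented default.

```python
import jax.numpy as jnp


def direct_type3(x, c, s, isign=1):
    """Naive O(N*M) evaluation of f_k = sum_j c_j exp(isign * i * s_k * x_j)."""
    # (N, M) matrix: one row per target frequency s_k.
    phase = jnp.exp(isign * 1j * s[:, None] * x[None, :])
    return phase @ c


x = jnp.array([0.1, 0.5, 1.0, 2.0])
c = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j])
s = jnp.linspace(-10, 10, 50)
f = direct_type3(x, c, s)
# f.shape == (50,)
```

When `s` holds the integers -N/2, ..., N/2-1, this reduces to the Type 1 sum; Type 3 generalizes Type 1 to arbitrary, non-integer output frequencies.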
Important
Type 3 transforms require the internal grid size to be computed before JIT compilation, because it determines array shapes that JAX must know at trace time. Use the helper functions below.
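The underlying constraint is generic to JAX: under jax.jit, array shapes must be known when the function is traced, so a size derived from data *values* has to be computed eagerly and passed in as a static argument. A minimal sketch of that pattern, with a hypothetical size rule (`needed_grid_size`) and a hypothetical `zero_padded_fft` standing in for a transform; neither is part of nufftax:

```python
from functools import partial

import jax
import jax.numpy as jnp


def needed_grid_size(x):
    # Hypothetical size rule that depends on the *values* of x. int() needs a
    # concrete number, so this must run outside jax.jit.
    return 2 * int(jnp.ceil(jnp.max(jnp.abs(x)))) + 8


@partial(jax.jit, static_argnames="grid_size")
def zero_padded_fft(values, grid_size):
    # grid_size fixes an array shape, so JAX requires it to be static.
    grid = jnp.zeros(grid_size, dtype=values.dtype)
    grid = grid.at[: values.shape[0]].set(values)
    return jnp.fft.fft(grid)


x = jnp.array([0.1, 0.5, 1.0, 2.0])
values = jnp.array([1 + 1j, 2 - 1j, 0.5, 1j])

grid_size = needed_grid_size(x)  # computed eagerly: 2 * 2 + 8 = 12
out = zero_padded_fft(values, grid_size=grid_size)
```

Each distinct static `grid_size` triggers a new compilation, so computing the size once and reusing it across calls avoids retracing.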
1D Transform
Example:
```python
import jax.numpy as jnp
from nufftax import nufft1d3, compute_type3_grid_size

x = jnp.array([0.1, 0.5, 1.0, 2.0])
c = jnp.array([1+1j, 2-1j, 0.5, 1j])
s = jnp.linspace(-10, 10, 50)

# Pre-compute the grid size outside any JIT-compiled function
n_modes = compute_type3_grid_size(x, s, eps=1e-6)
f = nufft1d3(x, c, s, n_modes=n_modes, eps=1e-6)
# f.shape = (50,)
```
2D Transform
3D Transform
Utility Functions
Grid Size Computation
These functions compute the internal grid size needed for Type 3 transforms. Call them before JIT-compiling your code.
- nufftax.compute_type3_grid_size(x_or_x_extent, s_or_s_extent, eps=1e-06, upsampfac=2.0)
Compute appropriate grid size for 1D Type 3 NUFFT.
This helper function can be used to pre-compute grid sizes for JIT compilation.
- Parameters:
x_or_x_extent – Either the source points (array of shape (M,)) or their half-width (float). If an array is given, the half-width is computed as (max - min) / 2.
s_or_s_extent – Either the target frequencies (array of shape (N,)) or their half-width (float). If an array is given, the half-width is computed as (max - min) / 2.
eps (float) – Requested relative precision
upsampfac (float) – Oversampling factor
- Returns:
nf, the grid size: a smooth integer whose only prime factors are 2, 3, and 5
- Return type:
int
Example
```python
>>> import jax.numpy as jnp
>>> from nufftax import compute_type3_grid_size, nufft1d3
>>> x = jnp.array([...])  # source points
>>> s = jnp.array([...])  # target frequencies
>>> # Method 1: Pass arrays directly (recommended)
>>> nf = compute_type3_grid_size(x, s, eps=1e-6)
>>> # Method 2: Pass extents manually
>>> nf = compute_type3_grid_size((x.max()-x.min())/2, (s.max()-s.min())/2, eps=1e-6)
>>> # Now use nf in JIT-compiled code:
>>> f = nufft1d3(x, c, s, n_modes=nf, eps=1e-6)
```
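The returned size is 5-smooth (its only prime factors are 2, 3, and 5), a class of sizes for which FFTs are efficient. A sketch of rounding a required size up to the next such number; `next_smooth_235` is a hypothetical helper, and nufftax's actual rule may impose further constraints (such as evenness or a minimum size):

```python
def next_smooth_235(n):
    """Smallest integer >= n whose only prime factors are 2, 3 and 5."""
    m = max(int(n), 1)
    while True:
        # Strip all factors of 2, 3 and 5; a remainder of 1 means m is 5-smooth.
        r = m
        for p in (2, 3, 5):
            while r % p == 0:
                r //= p
        if r == 1:
            return m
        m += 1
```

For example, `next_smooth_235(97)` returns 100, since 98 = 2·7², 99 = 3²·11, and 100 = 2²·5².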
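- nufftax.compute_type3_grid_sizes_2d(x_extent, y_extent, s_extent, t_extent, eps=1e-06, upsampfac=2.0)
Compute appropriate grid sizes for 2D Type 3 NUFFT.
- Parameters:
x_extent (float) – Half-width of the source point range in x
y_extent (float) – Half-width of the source point range in y
s_extent (float) – Half-width of the target frequency range in s
t_extent (float) – Half-width of the target frequency range in t
eps (float) – Requested relative precision
upsampfac (float) – Oversampling factor
- Returns:
Grid sizes for each dimension, (nf1, nf2)
- Return type:
tuple of int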
- nufftax.compute_type3_grid_sizes_3d(x_extent, y_extent, z_extent, s_extent, t_extent, u_extent, eps=1e-06, upsampfac=2.0)
Compute appropriate grid sizes for 3D Type 3 NUFFT.
- Parameters:
x_extent (float) – Half-width of the source point range in x
y_extent (float) – Half-width of the source point range in y
z_extent (float) – Half-width of the source point range in z
s_extent (float) – Half-width of the target frequency range in s
t_extent (float) – Half-width of the target frequency range in t
u_extent (float) – Half-width of the target frequency range in u
eps (float) – Requested relative precision
upsampfac (float) – Oversampling factor
- Returns:
Grid sizes for each dimension, (nf1, nf2, nf3)
- Return type:
tuple of int
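Common Parameters
The transform functions use the following parameters (not every function takes every parameter):

| Parameter | Type | Description |
|---|---|---|
| `x`, `y`, `z` | real array | Nonuniform point coordinates in \([-\pi, \pi)\). Shape `(M,)`. |
| `c` | complex array | Complex values at the nonuniform points. Shape `(M,)`. |
| `f` | complex array | Fourier mode coefficients on the uniform grid. |
| `s`, `t`, `u` | real array | Target frequencies for Type 3 transforms. |
| `n_modes` | int or tuple of int | Number of output Fourier modes. For 2D/3D, use a tuple like `(32, 32)`. For Type 3, pass the internal grid size from the grid-size helpers. |
| `eps` | float | Requested relative precision. |
| `isign` | int | Sign of the exponent: `+1` or `-1`. |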
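For Types 1 and 2 the kernel \(e^{\pm i k x}\) is \(2\pi\)-periodic in \(x\) for integer \(k\), so periodic data can be wrapped into \([-\pi, \pi)\) without changing the result. A sketch, where `wrap_to_pi` and `rescale_to_pi` are hypothetical helpers rather than nufftax functions. This does not carry over to Type 3: its frequencies \(s_k\) need not be integers, so wrapping `x` would change the output.

```python
import jax.numpy as jnp


def wrap_to_pi(x):
    """Map angles to the half-open interval [-pi, pi) using 2*pi periodicity."""
    return jnp.mod(x + jnp.pi, 2 * jnp.pi) - jnp.pi


def rescale_to_pi(t, period):
    """Map coordinates t with the given period onto [-pi, pi).

    After this mapping, mode k corresponds to frequency k / period in t's units.
    """
    return wrap_to_pi(2 * jnp.pi * t / period)


x = jnp.array([3 * jnp.pi / 2, -3 * jnp.pi / 2, 0.5, 7.0])
w = wrap_to_pi(x)  # same angles, now inside [-pi, pi)
```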
Return Values
- Type 1: complex array of Fourier modes with shape (n_modes,) for 1D, (n_modes_x, n_modes_y) for 2D, and so on.
- Type 2: complex array of values at the nonuniform points, with shape (M,).
- Type 3: complex array of values at the target frequencies, with shape (n_targets,).
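The definitions index modes as \(k = -N/2, \ldots, N/2-1\). If Type 1 output is stored in that increasing-k order (an assumption here; check it against your nufftax version) and you need the ordering used by jnp.fft.fft, where \(k = 0\) comes first, jnp.fft.ifftshift converts between them and jnp.fft.fftshift converts back. A small sketch:

```python
import jax.numpy as jnp

n = 8
# Mode indices in increasing order: k = -N/2, ..., N/2 - 1.
k_centered = jnp.arange(-(n // 2), n - n // 2)

# ifftshift moves k = 0 to index 0, the ordering used by jnp.fft.fft.
k_fft_order = jnp.fft.ifftshift(k_centered)  # [0, 1, 2, 3, -4, -3, -2, -1]

# The same call reorders a (hypothetical) centered 2D mode array along all
# axes; fftshift undoes it.
f2d_centered = jnp.arange(16.0).reshape(4, 4)
f2d_fft_order = jnp.fft.ifftshift(f2d_centered)
roundtrip = jnp.fft.fftshift(f2d_fft_order)
```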