nufftax
A pure JAX implementation of the Non-Uniform Fast Fourier Transform.
What is nufftax?
nufftax lets you compute Fourier transforms on data sampled at arbitrary (non-uniform) locations, while keeping full compatibility with JAX’s automatic differentiation and GPU acceleration.
Traditional FFTs require data on a regular grid. But real-world data is often irregularly sampled:
MRI scans acquire k-space data along spiral or radial trajectories
Astronomical observations happen at irregular time intervals
Sensor networks collect measurements at scattered spatial locations
The NUFFT bridges this gap, and nufftax makes it differentiable.
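For reference, the three transform types mentioned throughout this page are conventionally defined as follows. The notation (strengths c_j at points x_j, Fourier modes f_k, target frequencies s_k) follows the FINUFFT convention. The sign of the exponent and the mode ordering that nufftax uses by default are not stated on this page, so treat those details as assumptions. N is taken to be even here.

```latex
\begin{aligned}
\text{Type 1 (non-uniform to uniform):}\quad
  & f_k = \sum_{j=1}^{M} c_j \, e^{\pm i k x_j},
  \qquad k \in \{-N/2, \dots, N/2 - 1\}, \\
\text{Type 2 (uniform to non-uniform):}\quad
  & c_j = \sum_{k=-N/2}^{N/2-1} f_k \, e^{\pm i k x_j},
  \qquad j = 1, \dots, M, \\
\text{Type 3 (non-uniform to non-uniform):}\quad
  & f_k = \sum_{j=1}^{M} c_j \, e^{\pm i s_k x_j},
  \qquad k = 1, \dots, N.
\end{aligned}
```

Evaluating these sums directly costs O(MN) operations. A NUFFT approximates them to a requested tolerance in roughly O(N log N + M) time.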
Why nufftax?
A JAX package for NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library via Foreign Function Interface (FFI), exposing it through custom XLA calls. This approach can lead to:
Kernel fusion issues on GPU — custom XLA calls act as optimization barriers, preventing XLA from fusing operations
CUDA version matching — GPU support requires matching CUDA versions between JAX and the library
nufftax takes a different approach, a pure JAX implementation:
| Feature | Benefit |
|---|---|
| Fully differentiable | Compute gradients through the entire transform - both w.r.t. data values and sampling locations |
| JAX native | Works with |
| GPU ready | Runs on CPU/GPU without code changes, benefits from XLA fusion |
| Pallas GPU kernels | Fused Triton spreading kernels with 5–75× speedups on A100/H100 |
| No compilation step | Pure Python/JAX - no C++ extensions to build |
JAX Transformation Support
| Transform | | | | |
|---|---|---|---|---|
| Type 1 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 2 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 3 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
Differentiable inputs:
Type 1: grad w.r.t. c (strengths) and x, y, z (coordinates)
Type 2: grad w.r.t. f (Fourier modes) and x, y, z (coordinates)
Type 3: grad w.r.t. c (strengths), x, y, z (source coordinates), and s, t, u (target frequencies)
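To make "differentiable with respect to both strengths and coordinates" concrete, here is a minimal sketch that uses a slow direct-sum type 1 transform written in plain JAX as a stand-in for nufftax. The helper `ndft1d1`, its mode ordering, and the sign of the exponent are assumptions made for illustration; they are not the library's API.

```python
import jax
import jax.numpy as jnp


def ndft1d1(x, c, n_modes):
    """Direct O(M * N) type 1 sum: f[k] = sum_j c[j] * exp(1j * k * x[j])."""
    # Assumed convention: integer modes k = -N/2, ..., N/2 - 1 and a +i sign.
    k = jnp.arange(-(n_modes // 2), (n_modes + 1) // 2)
    return jnp.exp(1j * k[:, None] * x[None, :]) @ c


def loss(x, c):
    f = ndft1d1(x, c, n_modes=8)
    return jnp.sum(f.real**2 + f.imag**2)  # sum of |f_k|^2, a real scalar


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

# One compiled function returning gradients w.r.t. locations and strengths.
grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))
grad_x, grad_c = grad_fn(x, c)

print(grad_x.shape, grad_x.dtype)  # real gradient, one entry per location
print(grad_c.shape, grad_c.dtype)  # complex gradient, one entry per strength
```

This loss does not change when every point is shifted by the same amount, so the entries of `grad_x` should sum to roughly zero. That makes a handy sanity check when swapping in a fast transform.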
Quick Example
import jax
import jax.numpy as jnp
from nufftax import nufft1d1
# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
# Complex values at those locations
c = jnp.array([1.0+0.5j, 0.3-0.2j, 0.8+0.1j, 0.2+0.4j, 0.5-0.3j])
# Compute 32 Fourier modes
f = nufft1d1(x, c, n_modes=32, eps=1e-6)
# Differentiate through the transform
def loss(c):
    return jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2)
grad_c = jax.grad(loss)(c)
GPU Acceleration
On GPU, nufftax automatically dispatches spreading and interpolation to fused Pallas (Triton) kernels when the problem is large enough. This avoids materializing O(M × nspread^d) intermediate tensors (M non-uniform points, a kernel nspread grid points wide per dimension, d dimensions) and uses atomic scatter-add for spreading.
| Operation | Backend | Speedup vs pure JAX |
|---|---|---|
| 1D spread | A100 | 5–67× (M ≥ 100K) |
| 1D spread | H100 | 4–75× (M ≥ 100K) |
| 2D spread | A100/H100 | 2–3× (M ≥ 100K) |
The dispatch is transparent — no code changes required. On CPU or for small problems, the pure JAX path is used.
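To illustrate what "spreading" and "materializing intermediates" refer to, here is a rough 1D sketch of a spreading step in plain JAX. It is not nufftax's implementation: the function name `spread_1d`, the Gaussian window (a placeholder for whatever kernel the library actually uses), and the default width are all assumptions chosen for readability.

```python
import jax.numpy as jnp


def spread_1d(x, c, n_grid, nspread=6):
    """Spread strengths c at points x in [-pi, pi) onto a periodic grid."""
    h = 2 * jnp.pi / n_grid                 # grid spacing
    pos = (x + jnp.pi) / h                  # fractional grid coordinate
    # Leftmost of the nspread cells touched by each point.
    left = jnp.floor(pos).astype(jnp.int32) - nspread // 2 + 1
    idx = left[:, None] + jnp.arange(nspread)[None, :]   # (M, nspread)
    dist = pos[:, None] - idx               # distance in grid units
    sigma = nspread / 6.0                   # placeholder window width
    w = jnp.exp(-0.5 * (dist / sigma) ** 2)  # placeholder Gaussian window
    vals = w * c[:, None]                   # (M, nspread) intermediate
    grid = jnp.zeros(n_grid, dtype=vals.dtype)
    # Scatter-add with periodic wrap; duplicate indices accumulate.
    return grid.at[idx % n_grid].add(vals)


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])
grid = spread_1d(x, c, n_grid=64)
print(grid.shape)
```

The `idx` and `vals` arrays have shape (M, nspread); in d dimensions the analogous arrays grow to M × nspread^d entries. A fused kernel can compute each weight and add it to the grid without ever storing them.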
Installation
CPU only:
uv pip install nufftax
With CUDA 12 GPU support:
uv pip install "nufftax[cuda12]"
Development install (from source):
git clone https://github.com/GragasLab/nufftax.git
cd nufftax
uv pip install -e ".[dev]"
This installs the development dependencies: pytest, ruff, pre-commit, and finufft (for comparison testing).
Development install with CUDA 12:
uv pip install -e ".[dev,cuda12]"
With docs dependencies:
uv pip install -e ".[docs]"
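After installing, it can be useful to confirm which backend JAX has picked up, since the GPU path is only taken when JAX actually sees a GPU. The snippet below uses only standard JAX calls (`jax.default_backend` and `jax.devices`) and does not touch nufftax itself.

```python
import jax

backend = jax.default_backend()  # platform name, e.g. "cpu" or "gpu"
devices = jax.devices()          # devices visible to the default backend

print("default backend:", backend)
for d in devices:
    print(" -", d)
```

If this reports only CPU devices after a CUDA install, the usual suspect is a mismatch between the installed JAX CUDA wheels and the local driver.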
Documentation
Getting Started
User Guide
License
MIT License. Algorithm based on FINUFFT by the Flatiron Institute.