"""Public NUFFT API.
Thin wrappers around the JAX core primitives defined in `primitives.py`.
The primitives carry explicit JVP and transpose rules, so all of `jit`,
`grad`, `vjp`, `jvp`, `linear_transpose`, and `lax.custom_linear_solve`
(hence `jax.scipy.sparse.linalg.cg`/`gmres`/`bicgstab`) work through them.
Type 1 (nufftXd1) and Type 2 (nufftXd2) are mutual adjoints; Type 3 is
self-adjoint with source/target points swapped.
The `nufftXdY_jvp` aliases exist for backward compatibility with code that
imported the previous `custom_jvp`-only variants. They are now identical to
their non-`_jvp` counterparts.
"""
from jax import Array
from . import primitives as P
# ============================================================================
# Type 1 (nonuniform -> uniform)
# ============================================================================
[docs]
def nufft1d1(x: Array, c: Array, n_modes: int, eps: float = 1e-6, isign: int = 1) -> Array:
    """1-D type-1 NUFFT (nonuniform -> uniform).

    Binds the `nufft1d1_p` primitive from `primitives.py`.

    Args:
        x: First array operand of the primitive (presumably the nonuniform
            point coordinates -- confirm against `primitives.py`).
        c: Second array operand of the primitive (presumably the strengths
            attached to the points in `x` -- confirm against `primitives.py`).
        n_modes: Forwarded as the static `n_modes` parameter after coercion
            to a builtin `int` (presumably the number of output modes).
        eps: Forwarded unchanged as the static `eps` parameter (presumably
            the requested accuracy -- confirm).
        isign: Forwarded unchanged as the static `isign` parameter
            (presumably the sign of the exponent -- confirm).

    Returns:
        Whatever `nufft1d1_p.bind` returns.
    """
    # Normalize to a builtin int, exactly as `nufft1d3` does, so integer-like
    # inputs (e.g. NumPy scalars) reach the primitive as one canonical,
    # hashable static parameter.
    return P.nufft1d1_p.bind(x, c, n_modes=int(n_modes), eps=eps, isign=isign)
[docs]
def nufft2d1(x: Array, y: Array, c: Array, n_modes: tuple[int, int], eps: float = 1e-6, isign: int = 1) -> Array:
    """2-D type-1 NUFFT (nonuniform -> uniform).

    Binds the `nufft2d1_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably point coordinates along the first
            axis -- confirm against `primitives.py`).
        y: Second array operand (presumably point coordinates along the
            second axis -- confirm).
        c: Third array operand (presumably the strengths at the points --
            confirm).
        n_modes: Any iterable; converted to a `tuple` before being passed as
            the static `n_modes` parameter.
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter.

    Returns:
        Whatever `nufft2d1_p.bind` returns.
    """
    bind = P.nufft2d1_p.bind
    static_params = {"n_modes": tuple(n_modes), "eps": eps, "isign": isign}
    return bind(x, y, c, **static_params)
[docs]
def nufft3d1(
    x: Array,
    y: Array,
    z: Array,
    c: Array,
    n_modes: tuple[int, int, int],
    eps: float = 1e-6,
    isign: int = 1,
) -> Array:
    """3-D type-1 NUFFT (nonuniform -> uniform).

    Binds the `nufft3d1_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably point coordinates along the first
            axis -- confirm against `primitives.py`).
        y: Second array operand (presumably coordinates along the second
            axis -- confirm).
        z: Third array operand (presumably coordinates along the third
            axis -- confirm).
        c: Fourth array operand (presumably the strengths at the points --
            confirm).
        n_modes: Any iterable; converted to a `tuple` before being passed as
            the static `n_modes` parameter.
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter.

    Returns:
        Whatever `nufft3d1_p.bind` returns.
    """
    bind = P.nufft3d1_p.bind
    static_params = {"n_modes": tuple(n_modes), "eps": eps, "isign": isign}
    return bind(x, y, z, c, **static_params)
# ============================================================================
# Type 2 (uniform -> nonuniform)
# ============================================================================
[docs]
def nufft1d2(x: Array, f: Array, eps: float = 1e-6, isign: int = -1) -> Array:
    """1-D type-2 NUFFT (uniform -> nonuniform).

    Binds the `nufft1d2_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably the nonuniform target points --
            confirm against `primitives.py`).
        f: Second array operand (presumably the uniform-grid coefficients --
            confirm).
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter; note the
            default here is -1, opposite to the type-1 wrappers.

    Returns:
        Whatever `nufft1d2_p.bind` returns.
    """
    bind = P.nufft1d2_p.bind
    static_params = {"eps": eps, "isign": isign}
    return bind(x, f, **static_params)
[docs]
def nufft2d2(x: Array, y: Array, f: Array, eps: float = 1e-6, isign: int = -1) -> Array:
    """2-D type-2 NUFFT (uniform -> nonuniform).

    Binds the `nufft2d2_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably target coordinates along the first
            axis -- confirm against `primitives.py`).
        y: Second array operand (presumably target coordinates along the
            second axis -- confirm).
        f: Third array operand (presumably the uniform-grid coefficients --
            confirm).
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter; note the
            default here is -1, opposite to the type-1 wrappers.

    Returns:
        Whatever `nufft2d2_p.bind` returns.
    """
    bind = P.nufft2d2_p.bind
    static_params = {"eps": eps, "isign": isign}
    return bind(x, y, f, **static_params)
[docs]
def nufft3d2(x: Array, y: Array, z: Array, f: Array, eps: float = 1e-6, isign: int = -1) -> Array:
    """3-D type-2 NUFFT (uniform -> nonuniform).

    Binds the `nufft3d2_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably target coordinates along the first
            axis -- confirm against `primitives.py`).
        y: Second array operand (presumably target coordinates along the
            second axis -- confirm).
        z: Third array operand (presumably target coordinates along the third
            axis -- confirm).
        f: Fourth array operand (presumably the uniform-grid coefficients --
            confirm).
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter; note the
            default here is -1, opposite to the type-1 wrappers.

    Returns:
        Whatever `nufft3d2_p.bind` returns.
    """
    bind = P.nufft3d2_p.bind
    static_params = {"eps": eps, "isign": isign}
    return bind(x, y, z, f, **static_params)
# ============================================================================
# Type 3 (nonuniform -> nonuniform)
# ============================================================================
[docs]
def nufft1d3(
    x: Array,
    c: Array,
    s: Array,
    n_modes: int,
    eps: float = 1e-6,
    isign: int = 1,
    upsampfac: float = 2.0,
) -> Array:
    """1-D type-3 NUFFT (nonuniform -> nonuniform).

    Binds the `nufft1d3_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably the source points -- confirm
            against `primitives.py`).
        c: Second array operand (presumably the strengths at the source
            points -- confirm).
        s: Third array operand (presumably the target points -- confirm).
        n_modes: Coerced with `int()` and passed as the static `n_modes`
            parameter.
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter.
        upsampfac: Forwarded unchanged as the static `upsampfac` parameter
            (presumably an internal oversampling factor -- confirm).

    Returns:
        Whatever `nufft1d3_p.bind` returns.
    """
    bind = P.nufft1d3_p.bind
    static_params = {
        "n_modes": int(n_modes),
        "eps": eps,
        "isign": isign,
        "upsampfac": upsampfac,
    }
    return bind(x, c, s, **static_params)
[docs]
def nufft2d3(
    x: Array,
    y: Array,
    c: Array,
    s: Array,
    t: Array,
    n_modes: tuple[int, int],
    eps: float = 1e-6,
    isign: int = 1,
    upsampfac: float = 2.0,
) -> Array:
    """2-D type-3 NUFFT (nonuniform -> nonuniform).

    Binds the `nufft2d3_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably source coordinates along the first
            axis -- confirm against `primitives.py`).
        y: Second array operand (presumably source coordinates along the
            second axis -- confirm).
        c: Third array operand (presumably the strengths at the source
            points -- confirm).
        s: Fourth array operand (presumably target coordinates along the
            first axis -- confirm).
        t: Fifth array operand (presumably target coordinates along the
            second axis -- confirm).
        n_modes: Any iterable; converted to a `tuple` before being passed as
            the static `n_modes` parameter.
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter.
        upsampfac: Forwarded unchanged as the static `upsampfac` parameter
            (presumably an internal oversampling factor -- confirm).

    Returns:
        Whatever `nufft2d3_p.bind` returns.
    """
    bind = P.nufft2d3_p.bind
    static_params = {
        "n_modes": tuple(n_modes),
        "eps": eps,
        "isign": isign,
        "upsampfac": upsampfac,
    }
    return bind(x, y, c, s, t, **static_params)
[docs]
def nufft3d3(
    x: Array,
    y: Array,
    z: Array,
    c: Array,
    s: Array,
    t: Array,
    u: Array,
    n_modes: tuple[int, int, int],
    eps: float = 1e-6,
    isign: int = 1,
    upsampfac: float = 2.0,
) -> Array:
    """3-D type-3 NUFFT (nonuniform -> nonuniform).

    Binds the `nufft3d3_p` primitive from `primitives.py`.

    Args:
        x: First array operand (presumably source coordinates along the first
            axis -- confirm against `primitives.py`).
        y: Second array operand (presumably source coordinates along the
            second axis -- confirm).
        z: Third array operand (presumably source coordinates along the third
            axis -- confirm).
        c: Fourth array operand (presumably the strengths at the source
            points -- confirm).
        s: Fifth array operand (presumably target coordinates along the first
            axis -- confirm).
        t: Sixth array operand (presumably target coordinates along the
            second axis -- confirm).
        u: Seventh array operand (presumably target coordinates along the
            third axis -- confirm).
        n_modes: Any iterable; converted to a `tuple` before being passed as
            the static `n_modes` parameter.
        eps: Forwarded unchanged as the static `eps` parameter.
        isign: Forwarded unchanged as the static `isign` parameter.
        upsampfac: Forwarded unchanged as the static `upsampfac` parameter
            (presumably an internal oversampling factor -- confirm).

    Returns:
        Whatever `nufft3d3_p.bind` returns.
    """
    bind = P.nufft3d3_p.bind
    static_params = {
        "n_modes": tuple(n_modes),
        "eps": eps,
        "isign": isign,
        "upsampfac": upsampfac,
    }
    return bind(x, y, z, c, s, t, u, **static_params)
# Legacy `*_jvp` names, kept so existing imports keep working. Each one is the
# very same function object as its non-`_jvp` counterpart: forward- and
# reverse-mode AD now go through the same primitive, so no separate variant
# is needed.

# Type 1 aliases.
nufft1d1_jvp = nufft1d1
nufft2d1_jvp = nufft2d1
nufft3d1_jvp = nufft3d1

# Type 2 aliases.
nufft1d2_jvp = nufft1d2
nufft2d2_jvp = nufft2d2
nufft3d2_jvp = nufft3d2

# Type 3 aliases.
nufft1d3_jvp = nufft1d3
nufft2d3_jvp = nufft2d3
nufft3d3_jvp = nufft3d3
# Names exported by `from <module> import *`: the nine transforms first, then
# their backward-compatibility `_jvp` aliases, each group ordered by dimension
# and then by transform type.
__all__ = [
    # Transforms.
    "nufft1d1", "nufft1d2", "nufft1d3",
    "nufft2d1", "nufft2d2", "nufft2d3",
    "nufft3d1", "nufft3d2", "nufft3d3",
    # Legacy aliases.
    "nufft1d1_jvp", "nufft1d2_jvp", "nufft1d3_jvp",
    "nufft2d1_jvp", "nufft2d2_jvp", "nufft2d3_jvp",
    "nufft3d1_jvp", "nufft3d2_jvp", "nufft3d3_jvp",
]