# Source code for nufftax.transforms.nufft3

"""
Type 3 NUFFT transforms: Nonuniform to Nonuniform.

Type 3 computes the Fourier transform from nonuniform source points
to nonuniform target frequencies:
    f[k] = sum_j c[j] * exp(isign * i * s[k] * x[j])

where both x[j] (source points) and s[k] (target frequencies) are nonuniform.

Algorithm (from FINUFFT):
1. Pre-phase: c' = c * exp(isign * i * D * x)
2. Rescale source points: x' = (x - C) / gamma
3. Spread c' at x' to get fw on fine grid (spatial representation)
4. Use inner Type 2: treat fw as "Fourier modes" and evaluate at s'
   (Type 2 does: deconvolve -> IFFT -> interpolate)
5. Deconvolve: f = result * (1/phihat(s')) * exp(isign * i * (s-D) * C)

JIT Compilation:
Type 3 requires the `n_modes` parameter for JIT compilation (grid sizes must
be static). Use `compute_type3_grid_size()` to pre-compute appropriate values.

Reference: FINUFFT finufft_core.cpp
"""

import jax
import jax.numpy as jnp
import numpy as np

from ..core.kernel import compute_kernel_params, es_kernel
from ..core.spread import _spread_1d_dispatch, _spread_2d_dispatch, _spread_3d_dispatch
from ..utils.grid import next_smooth_int
from .nufft2 import nufft1d2, nufft2d2, nufft3d2


def _next_smooth_even(n: int) -> int:
    """Return an even value produced by ``next_smooth_int`` that is >= n.

    Inputs of 1 or less map directly to 2. Otherwise the search starts at
    the first even integer >= n and keeps asking ``next_smooth_int`` for
    the next candidate until an even one is found.

    Type 3 NUFFT requires even grid sizes for proper FFT alignment.
    """
    # Smallest admissible grid size.
    if n <= 1:
        return 2

    # Round odd inputs up to the next even integer (n > 1 here, so n % 2
    # is exactly 0 or 1).
    start = n + (n % 2)

    smooth = next_smooth_int(start)
    # A smooth integer can still be odd (e.g. 15 = 3 * 5); skip past those.
    while smooth % 2:
        smooth = next_smooth_int(smooth + 1)
    return smooth


def compute_type3_grid_size(
    x_or_x_extent,
    s_or_s_extent,
    eps: float = 1e-6,
    upsampfac: float = 2.0,
) -> int:
    """Compute appropriate grid size for 1D Type 3 NUFFT.

    This helper can be used to pre-compute grid sizes for JIT compilation.

    Args:
        x_or_x_extent: Either source points array (shape M,) OR half-width
            float. If array, the extent is computed as (max - min) / 2.
        s_or_s_extent: Either target frequencies array (shape N,) OR
            half-width float. If array, the extent is computed as
            (max - min) / 2.
        eps: Requested precision
        upsampfac: Oversampling factor

    Returns:
        nf: Grid size, as returned by ``_next_smooth_even`` (an even,
            smooth integer).

    Example:
        >>> import jax.numpy as jnp
        >>> x = jnp.array([...])  # source points
        >>> s = jnp.array([...])  # target frequencies
        >>> # Method 1: Pass arrays directly (recommended)
        >>> nf = compute_type3_grid_size(x, s, eps=1e-6)
        >>> # Method 2: Pass extents manually
        >>> nf = compute_type3_grid_size((x.max()-x.min())/2, (s.max()-s.min())/2, eps=1e-6)
        >>> # Now use nf in JIT-compiled code:
        >>> f = nufft1d3(x, c, s, n_modes=nf, eps=1e-6)
    """
    # Normalise both inputs to NumPy first; arrays (ndim > 0) are reduced
    # to a half-width, 0-d values are taken as the half-width itself.
    sources = np.asarray(x_or_x_extent)
    targets = np.asarray(s_or_s_extent)

    half_x = float((np.max(sources) - np.min(sources)) / 2) if sources.ndim > 0 else float(sources)
    half_s = float((np.max(targets) - np.min(targets)) / 2) if targets.ndim > 0 else float(targets)

    kernel_width = compute_kernel_params(eps, upsampfac).nspread

    # Ensure non-zero extents
    half_x = max(half_x, 1e-10)
    half_s = max(half_s, 1e-10)

    # Space-bandwidth product scaled by the oversampling factor, plus room
    # for the spreading kernel.
    required = 2.0 * upsampfac * half_s * half_x / np.pi + kernel_width
    return _next_smooth_even(int(np.ceil(required)))
def compute_type3_grid_sizes_2d(
    x_extent,
    y_extent,
    s_extent,
    t_extent,
    eps: float = 1e-6,
    upsampfac: float = 2.0,
) -> tuple[int, int]:
    """Compute appropriate grid sizes for 2D Type 3 NUFFT.

    Each argument may be given either as a scalar half-width or, like the
    1D helper ``compute_type3_grid_size``, as an array of coordinates, in
    which case the half-width is computed as (max - min) / 2.

    Args:
        x_extent, y_extent: Half-widths of source point ranges (scalars),
            or arrays of source coordinates.
        s_extent, t_extent: Half-widths of target frequency ranges
            (scalars), or arrays of target frequencies.
        eps: Requested precision
        upsampfac: Oversampling factor

    Returns:
        (nf1, nf2): Grid sizes for each dimension
    """

    def _half_width(value) -> float:
        # Arrays are reduced to a half-width; 0-d values pass through.
        arr = np.asarray(value)
        if arr.ndim > 0:
            return float((np.max(arr) - np.min(arr)) / 2)
        return float(arr)

    nspread = compute_kernel_params(eps, upsampfac).nspread

    def _grid_size(source, target) -> int:
        # Clamp to avoid degenerate (zero-width) intervals.
        src = max(_half_width(source), 1e-10)
        tgt = max(_half_width(target), 1e-10)
        return _next_smooth_even(int(np.ceil(2.0 * upsampfac * tgt * src / np.pi + nspread)))

    # Dimension 1 pairs (x, s); dimension 2 pairs (y, t).
    return _grid_size(x_extent, s_extent), _grid_size(y_extent, t_extent)
def compute_type3_grid_sizes_3d(
    x_extent,
    y_extent,
    z_extent,
    s_extent,
    t_extent,
    u_extent,
    eps: float = 1e-6,
    upsampfac: float = 2.0,
) -> tuple[int, int, int]:
    """Compute appropriate grid sizes for 3D Type 3 NUFFT.

    Each argument may be given either as a scalar half-width or, like the
    1D helper ``compute_type3_grid_size``, as an array of coordinates, in
    which case the half-width is computed as (max - min) / 2.

    Args:
        x_extent, y_extent, z_extent: Half-widths of source point ranges
            (scalars), or arrays of source coordinates.
        s_extent, t_extent, u_extent: Half-widths of target frequency
            ranges (scalars), or arrays of target frequencies.
        eps: Requested precision
        upsampfac: Oversampling factor

    Returns:
        (nf1, nf2, nf3): Grid sizes for each dimension
    """

    def _half_width(value) -> float:
        # Arrays are reduced to a half-width; 0-d values pass through.
        arr = np.asarray(value)
        if arr.ndim > 0:
            return float((np.max(arr) - np.min(arr)) / 2)
        return float(arr)

    nspread = compute_kernel_params(eps, upsampfac).nspread

    def _grid_size(source, target) -> int:
        # Clamp to avoid degenerate (zero-width) intervals.
        src = max(_half_width(source), 1e-10)
        tgt = max(_half_width(target), 1e-10)
        return _next_smooth_even(int(np.ceil(2.0 * upsampfac * tgt * src / np.pi + nspread)))

    # Dimensions pair up as (x, s), (y, t), (z, u).
    return (
        _grid_size(x_extent, s_extent),
        _grid_size(y_extent, t_extent),
        _grid_size(z_extent, u_extent),
    )
def _kernel_ft_at_point(k: jax.Array, nspread: int, beta: float, c: float, dtype=None) -> jax.Array:
    """Evaluate Fourier transform of ES kernel at arbitrary frequency k.

    Uses trapezoidal quadrature on [-J/2, J/2] (J = nspread) to compute:
        phi_hat(k) = integral_{-J/2}^{J/2} phi(z) * cos(k*z) dz

    Args:
        k: Frequency points (can be array)
        nspread: Kernel width J
        beta: Kernel shape parameter, forwarded to ``es_kernel``
        c: Kernel parameter, forwarded to ``es_kernel``
        dtype: Output dtype (default: inferred from k)

    Returns:
        phi_hat(k): Kernel Fourier transform at frequencies k, 1D array
    """
    out_dtype = k.dtype if dtype is None else dtype

    half_support = nspread / 2.0
    # At least 20 nodes, growing with the kernel width.
    num_nodes = max(int(4 + 3.0 * half_support), 20)
    nodes = jnp.linspace(-half_support, half_support, num_nodes, dtype=out_dtype)
    step = nodes[1] - nodes[0]
    kernel_vals = es_kernel(nodes, beta, c)

    # Trapezoidal weights: full step inside, half step at both ends.
    weights = jnp.ones(num_nodes, dtype=out_dtype) * step
    weights = weights.at[0].set(step / 2)
    weights = weights.at[-1].set(step / 2)

    freqs = jnp.atleast_1d(k).astype(out_dtype)
    angles = jnp.outer(freqs, nodes)
    return jnp.sum(kernel_vals[None, :] * weights[None, :] * jnp.cos(angles), axis=1)


def _wrap_to_pi(values: jax.Array) -> jax.Array:
    """Fold values periodically into the interval [-pi, pi)."""
    return jnp.mod(values + jnp.pi, 2.0 * jnp.pi) - jnp.pi


def _midpoint(values: jax.Array) -> jax.Array:
    """Return the centre (min + max) / 2 of the range spanned by values."""
    return (jnp.min(values) + jnp.max(values)) / 2.0


def _center_and_halfwidth(values: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Return the centre and the half-width (floored at 1e-10) of values."""
    lo, hi = jnp.min(values), jnp.max(values)
    return (lo + hi) / 2.0, jnp.maximum((hi - lo) / 2.0, 1e-10)


def nufft1d3(
    x: jax.Array,
    c: jax.Array,
    s: jax.Array,
    n_modes: int,
    eps: float = 1e-6,
    isign: int = 1,
    upsampfac: float = 2.0,
) -> jax.Array:
    """
    1D Type-3 NUFFT: nonuniform to nonuniform.

    Computes:
        f[k] = sum_{j=0}^{M-1} c[j] * exp(isign * i * s[k] * x[j])

    for k = 0, ..., N-1 where s[k] are the target frequencies.

    Use `compute_type3_grid_size()` to compute `n_modes` from data bounds.

    Args:
        x: Source points (nonuniform), shape (M,)
        c: Complex strengths at source points, shape (M,) or (n_trans, M)
        s: Target frequencies (nonuniform), shape (N,)
        n_modes: Grid size. Use compute_type3_grid_size() to compute.
        eps: Requested precision (tolerance)
        isign: Sign in the exponential (+1 or -1)
        upsampfac: Oversampling factor (default 2.0)

    Returns:
        f: Values at target frequencies, shape (N,) or (n_trans, N)
    """
    # Work on a (n_trans, M) layout internally; remember whether to undo it.
    single = c.ndim != 2
    if single:
        c = c[None, :]

    kernel_params = compute_kernel_params(eps, upsampfac)
    nspread, beta, kc = kernel_params.nspread, kernel_params.beta, kernel_params.c

    # Interval centre of the sources; centre and half-width of the targets
    # (all traceable operations).
    C = _midpoint(x)
    D, S = _center_and_halfwidth(s)

    # Fine grid parameters (n_modes must be static for JIT).
    nf = n_modes
    h = 2.0 * jnp.pi / nf
    gamma = nf / (2.0 * upsampfac * S)

    # Rescaled sources, folded onto the periodic grid domain.
    x_grid = _wrap_to_pi((x - C) / gamma)

    # Imaginary unit in the dtype of the strengths.
    imag_unit = jnp.array(1j, dtype=c.dtype)

    # Pre-phase: c' = c * exp(isign * i * D * x)
    prephase = jnp.exp(imag_unit * isign * D * x)
    strengths = c * prephase[None, :]

    # Spread the pre-phased strengths onto the fine grid.
    fw = _spread_1d_dispatch(x_grid, strengths, nf, kernel_params)

    # Rescale target frequencies: s' = h * gamma * (s - D)
    s_scaled = h * gamma * (s - D)
    s_grid = _wrap_to_pi(s_scaled)

    # Inner Type 2: fw plays the role of the mode coefficients and is
    # evaluated at the rescaled targets.
    f = nufft1d2(s_grid, fw, eps=eps, isign=isign, upsampfac=upsampfac, modeord=0)

    # Deconvolution: multiply by phase_correction / phihat(s').
    real_dtype = jnp.finfo(c.dtype).dtype
    phi_hat = _kernel_ft_at_point(s_scaled, nspread, beta, kc, dtype=real_dtype)
    phase_correction = jnp.exp(imag_unit * isign * (s - D) * C)
    deconv = phase_correction / phi_hat
    f = f * deconv[None, :]

    return f[0] if single else f


def nufft2d3(
    x: jax.Array,
    y: jax.Array,
    c: jax.Array,
    s: jax.Array,
    t: jax.Array,
    n_modes: tuple[int, int],
    eps: float = 1e-6,
    isign: int = 1,
    upsampfac: float = 2.0,
) -> jax.Array:
    """
    2D Type-3 NUFFT: nonuniform to nonuniform.

    Computes:
        f[k] = sum_j c[j] * exp(isign * i * (s[k]*x[j] + t[k]*y[j]))

    Use `compute_type3_grid_sizes_2d()` to compute `n_modes` from data bounds.

    Args:
        x, y: Source points (nonuniform), shape (M,)
        c: Complex strengths, shape (M,) or (n_trans, M)
        s, t: Target frequencies (nonuniform), shape (N,)
        n_modes: Grid sizes (nf1, nf2). Use compute_type3_grid_sizes_2d().
        eps: Requested precision
        isign: Sign in the exponential
        upsampfac: Oversampling factor

    Returns:
        f: Values at target frequencies, shape (N,) or (n_trans, N)
    """
    # Work on a (n_trans, M) layout internally; remember whether to undo it.
    single = c.ndim != 2
    if single:
        c = c[None, :]

    kernel_params = compute_kernel_params(eps, upsampfac)
    nspread, beta, kc = kernel_params.nspread, kernel_params.beta, kernel_params.c

    # Source centres; target centres and half-widths (traceable operations).
    Cx = _midpoint(x)
    Cy = _midpoint(y)
    Ds, Ss = _center_and_halfwidth(s)
    Dt, St = _center_and_halfwidth(t)

    # Fine grid parameters (n_modes must be static for JIT).
    nf1, nf2 = n_modes
    h1 = 2.0 * jnp.pi / nf1
    h2 = 2.0 * jnp.pi / nf2
    gamma1 = nf1 / (2.0 * upsampfac * Ss)
    gamma2 = nf2 / (2.0 * upsampfac * St)

    # Rescaled sources, folded onto the periodic grid domain.
    x_grid = _wrap_to_pi((x - Cx) / gamma1)
    y_grid = _wrap_to_pi((y - Cy) / gamma2)

    # Imaginary unit in the dtype of the strengths.
    imag_unit = jnp.array(1j, dtype=c.dtype)

    # Pre-phase with the target centres.
    prephase = jnp.exp(imag_unit * isign * (Ds * x + Dt * y))
    strengths = c * prephase[None, :]

    # Spread onto the fine grid.
    fw = _spread_2d_dispatch(x_grid, y_grid, strengths, nf1, nf2, kernel_params)

    # Rescaled target frequencies and their periodic images.
    s_scaled = h1 * gamma1 * (s - Ds)
    t_scaled = h2 * gamma2 * (t - Dt)
    s_grid = _wrap_to_pi(s_scaled)
    t_grid = _wrap_to_pi(t_scaled)

    # Inner Type 2 evaluated at the rescaled targets.
    f = nufft2d2(s_grid, t_grid, fw, eps=eps, isign=isign, upsampfac=upsampfac)

    # Deconvolution: phase correction divided by the per-axis kernel
    # transforms.
    real_dtype = jnp.finfo(c.dtype).dtype
    phi_hat1 = _kernel_ft_at_point(s_scaled, nspread, beta, kc, dtype=real_dtype)
    phi_hat2 = _kernel_ft_at_point(t_scaled, nspread, beta, kc, dtype=real_dtype)
    phase_correction = jnp.exp(imag_unit * isign * ((s - Ds) * Cx + (t - Dt) * Cy))
    f = f * phase_correction[None, :] / (phi_hat1 * phi_hat2)[None, :]

    return f[0] if single else f


def nufft3d3(
    x: jax.Array,
    y: jax.Array,
    z: jax.Array,
    c: jax.Array,
    s: jax.Array,
    t: jax.Array,
    u: jax.Array,
    n_modes: tuple[int, int, int],
    eps: float = 1e-6,
    isign: int = 1,
    upsampfac: float = 2.0,
) -> jax.Array:
    """
    3D Type-3 NUFFT: nonuniform to nonuniform.

    Computes:
        f[k] = sum_j c[j] * exp(isign * i * (s[k]*x[j] + t[k]*y[j] + u[k]*z[j]))

    Use `compute_type3_grid_sizes_3d()` to compute `n_modes` from data bounds.

    Args:
        x, y, z: Source points (nonuniform), shape (M,)
        c: Complex strengths, shape (M,) or (n_trans, M)
        s, t, u: Target frequencies (nonuniform), shape (N,)
        n_modes: Grid sizes (nf1, nf2, nf3). Use compute_type3_grid_sizes_3d().
        eps: Requested precision
        isign: Sign in the exponential
        upsampfac: Oversampling factor

    Returns:
        f: Values at target frequencies, shape (N,) or (n_trans, N)
    """
    # Work on a (n_trans, M) layout internally; remember whether to undo it.
    single = c.ndim != 2
    if single:
        c = c[None, :]

    kernel_params = compute_kernel_params(eps, upsampfac)
    nspread, beta, kc = kernel_params.nspread, kernel_params.beta, kernel_params.c

    # Source centres; target centres and half-widths (traceable operations).
    Cx = _midpoint(x)
    Cy = _midpoint(y)
    Cz = _midpoint(z)
    Ds, Ss = _center_and_halfwidth(s)
    Dt, St = _center_and_halfwidth(t)
    Du, Su = _center_and_halfwidth(u)

    # Fine grid parameters (n_modes must be static for JIT).
    nf1, nf2, nf3 = n_modes
    h1 = 2.0 * jnp.pi / nf1
    h2 = 2.0 * jnp.pi / nf2
    h3 = 2.0 * jnp.pi / nf3
    gamma1 = nf1 / (2.0 * upsampfac * Ss)
    gamma2 = nf2 / (2.0 * upsampfac * St)
    gamma3 = nf3 / (2.0 * upsampfac * Su)

    # Rescaled sources, folded onto the periodic grid domain.
    x_grid = _wrap_to_pi((x - Cx) / gamma1)
    y_grid = _wrap_to_pi((y - Cy) / gamma2)
    z_grid = _wrap_to_pi((z - Cz) / gamma3)

    # Imaginary unit in the dtype of the strengths.
    imag_unit = jnp.array(1j, dtype=c.dtype)

    # Pre-phase with the target centres.
    prephase = jnp.exp(imag_unit * isign * (Ds * x + Dt * y + Du * z))
    strengths = c * prephase[None, :]

    # Spread onto the fine grid.
    fw = _spread_3d_dispatch(x_grid, y_grid, z_grid, strengths, nf1, nf2, nf3, kernel_params)

    # Rescaled target frequencies and their periodic images.
    s_scaled = h1 * gamma1 * (s - Ds)
    t_scaled = h2 * gamma2 * (t - Dt)
    u_scaled = h3 * gamma3 * (u - Du)
    s_grid = _wrap_to_pi(s_scaled)
    t_grid = _wrap_to_pi(t_scaled)
    u_grid = _wrap_to_pi(u_scaled)

    # Inner Type 2 evaluated at the rescaled targets.
    f = nufft3d2(
        s_grid,
        t_grid,
        u_grid,
        fw,
        eps=eps,
        isign=isign,
        upsampfac=upsampfac,
    )

    # Deconvolution: phase correction divided by the per-axis kernel
    # transforms.
    real_dtype = jnp.finfo(c.dtype).dtype
    phi_hat1 = _kernel_ft_at_point(s_scaled, nspread, beta, kc, dtype=real_dtype)
    phi_hat2 = _kernel_ft_at_point(t_scaled, nspread, beta, kc, dtype=real_dtype)
    phi_hat3 = _kernel_ft_at_point(u_scaled, nspread, beta, kc, dtype=real_dtype)
    phase_correction = jnp.exp(imag_unit * isign * ((s - Ds) * Cx + (t - Dt) * Cy + (u - Du) * Cz))
    f = f * phase_correction[None, :] / (phi_hat1 * phi_hat2 * phi_hat3)[None, :]

    return f[0] if single else f