bayespecon._logdet.make_logdet_jax_fn

bayespecon._logdet.make_logdet_jax_fn(W, method=None, rho_min=-1.0, rho_max=1.0, T=1, trace_estimator='hutchpp', trace_k=None)

Return a JAX-native function (rho) -> log|I - rho*W|.

Companion to make_logdet_fn() (PyTensor) and make_logdet_numpy_fn() (NumPy). The returned callable is JAX-native, so it can be used inside jax.jit and jax.grad.
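
A minimal sketch of what such a JAX-native callable can look like, assuming W has real eigenvalues. It uses the identity log|I - rho*W| = sum_i log(1 - rho*lambda_i) that underlies the "eigenvalue" method. The ring contiguity matrix and the helpers ring_weights and logdet_from_eigs are hypothetical illustrations, not part of bayespecon.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # float64, so results match NumPy closely


def ring_weights(n):
    """Hypothetical example W: row-standardized ring contiguity matrix.

    Symmetric, so its eigenvalues are real and lie in [-1, 1].
    """
    W = np.zeros((n, n))
    i = np.arange(n)
    W[i, (i + 1) % n] = 0.5
    W[i, (i - 1) % n] = 0.5
    return W


def logdet_from_eigs(eigs, T=1):
    """Hypothetical helper returning a JAX-native (rho) -> T * log|I - rho*W|."""
    eigs = jnp.asarray(eigs)

    def logdet(rho):
        # log|I - rho*W| = sum_i log(1 - rho * lambda_i): O(n) work per call
        return T * jnp.sum(jnp.log1p(-rho * eigs))

    return logdet


n = 50
W = ring_weights(n)
eigs = np.linalg.eigvalsh(W)  # one-off O(n^3) eigendecomposition, done outside JAX

logdet = jax.jit(logdet_from_eigs(eigs))
dlogdet = jax.jit(jax.grad(logdet_from_eigs(eigs)))  # d/drho via autodiff

rho = 0.4
print(float(logdet(rho)), float(dlogdet(rho)))
```

Closing over the eigenvalues moves the O(n³) work to setup, and jax.grad supplies the derivative -tr(W(I - rho*W)^{-1}) without a hand-written gradient.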

Parameters:
W : np.ndarray or scipy.sparse matrix

Either a 2-D dense (n, n) spatial weights matrix or a 1-D array of precomputed real eigenvalues of W. Passing eigenvalues skips the O(n³) eigendecomposition.

method : str or None

If None, the method is auto-selected: "eigenvalue" when n <= 500, otherwise "chebyshev". Supported values:

"eigenvalue": exact evaluation from the eigenvalues of W, O(n) per call.

"chebyshev": Chebyshev polynomial approximation evaluated with Clenshaw's algorithm, O(m) per call, where m is the number of Chebyshev terms. The coefficients are built from exact eigenvalues when n is small (or when eigenvalues are passed as W); otherwise they come from the stochastic trace estimator selected by trace_estimator.

rho_min : float, default=-1.0

Lower bound for the rho interval.

rho_max : float, default=1.0

Upper bound for the rho interval.

T : int, default=1

Panel time-period count. The returned log-determinant is multiplied by T.

trace_estimator : {"hutchinson", "hutchpp", "xtrace"}, default="hutchpp"

Stochastic trace estimator used to build the Chebyshev coefficients when an eigendecomposition is unavailable. Ignored when method="eigenvalue" or when eigenvalues are passed in.

trace_k : int, optional

Number of probe vectors for the trace estimator. If None, defaults to 30 (hutchinson), 50 (hutchpp), or 25 (xtrace).
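
One plausible way to obtain an O(m)-per-call "chebyshev" evaluation is sketched below. It assumes the log-determinant is interpolated as a Chebyshev series in rho over [rho_min, rho_max], with node values taken from exact eigenvalues (the small-n case); the library may construct its coefficients differently. The helper make_cheb_logdet, the degree m=40, and the example W are hypothetical. For a row-standardized W, rho = 1 makes I - rho*W singular, so the sketch interpolates on [-0.9, 0.9].

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpy.polynomial import chebyshev as C

jax.config.update("jax_enable_x64", True)


def ring_weights(n):
    """Hypothetical example W: symmetric, row-standardized ring contiguity matrix."""
    W = np.zeros((n, n))
    i = np.arange(n)
    W[i, (i + 1) % n] = 0.5
    W[i, (i - 1) % n] = 0.5
    return W


def make_cheb_logdet(eigs, rho_min, rho_max, m=40, T=1):
    """Hypothetical helper: degree-m Chebyshev interpolant of rho -> T * log|I - rho*W|."""
    eigs = np.asarray(eigs)
    mid = 0.5 * (rho_max + rho_min)
    half = 0.5 * (rho_max - rho_min)

    def exact(x):
        # Map Chebyshev points x in [-1, 1] to rho in [rho_min, rho_max]
        rho = mid + half * np.asarray(x)
        return np.log1p(-np.outer(rho, eigs)).sum(axis=1)

    # One-off setup outside JAX: coefficients c_0, ..., c_m
    coeffs = jnp.asarray(C.chebinterpolate(exact, m))

    def logdet(rho):
        x = (jnp.asarray(rho, dtype=coeffs.dtype) - mid) / half

        def step(carry, c):
            b1, b2 = carry
            # Clenshaw recurrence: b_k = c_k + 2x b_{k+1} - b_{k+2}
            return (c + 2.0 * x * b1 - b2, b1), None

        zero = jnp.zeros_like(x)
        # Scan over c_m, ..., c_1; cost is O(m) per call, independent of n
        (b1, b2), _ = jax.lax.scan(step, (zero, zero), jnp.flip(coeffs[1:]))
        return T * (coeffs[0] + x * b1 - b2)

    return logdet


W = ring_weights(50)
eigs = np.linalg.eigvalsh(W)
logdet = jax.jit(make_cheb_logdet(eigs, rho_min=-0.9, rho_max=0.9))
print(float(logdet(0.5)))
```

Only the coefficient vector depends on W, so evaluation cost does not grow with n. In the large-n case described under method, exact eigenvalues are unavailable and the coefficients are instead built from stochastic trace estimates.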

Returns:

A function (rho) -> jax.numpy.ndarray that computes log|I - rho*W| (T * log|I - rho*W| for panel models with T > 1). Compatible with jax.jit and jax.grad.

Return type:

callable

Raises:

ValueError – If method is not one of the supported JAX-compatible methods.

Notes

Not all logdet methods have JAX-native implementations. Grid/spline methods ("grid_dense", "grid_sparse", "sparse_spline", "grid_mc", "grid_ilu") and "exact" are not supported because they rely on scipy or pytensor-specific operations that cannot be called inside jax.jit. Use "eigenvalue" or "chebyshev" instead.
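
A small sketch of the tracing rule behind this restriction, using numpy.linalg.slogdet as a stand-in for the host-side scipy routines. Under jax.jit, rho is an abstract tracer, so a function that converts it to a concrete NumPy array raises jax.errors.TracerArrayConversionError, whereas the same quantity written in jax.numpy traces normally. host_logdet and jax_logdet are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

W = np.array([[0.0, 1.0], [1.0, 0.0]])  # tiny symmetric example W
eigs = jnp.asarray(np.linalg.eigvalsh(W))


def host_logdet(rho):
    # Leaves JAX: NumPy needs a concrete array, but under jit rho is a tracer
    A = jnp.eye(2) - rho * jnp.asarray(W)
    return np.linalg.slogdet(A)[1]


def jax_logdet(rho):
    # Stays in jax.numpy, so it can be traced, compiled, and differentiated
    return jnp.sum(jnp.log1p(-rho * eigs))


try:
    jax.jit(host_logdet)(0.3)
    host_ok = True
except jax.errors.TracerArrayConversionError:
    host_ok = False

print(host_ok)
print(float(jax.jit(jax_logdet)(0.3)))
```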

See also

make_logdet_fn

PyTensor symbolic version (for NUTS).

make_logdet_numpy_fn

NumPy scalar version (for Python-loop Gibbs).

make_logdet_numpy_vec_fn

NumPy vectorized version (for post-processing).