bayespecon._logdet.make_logdet_jax_fn
bayespecon._logdet.make_logdet_jax_fn(W, method=None, rho_min=-1.0, rho_max=1.0, T=1, trace_estimator='hutchpp', trace_k=None)
Return a JAX-native function (rho) -> log|I - rho*W|.
Companion to make_logdet_fn() (PyTensor) and make_logdet_numpy_fn() (NumPy). The returned callable is JAX-native and can be used inside jax.jit and jax.grad.
Parameters:
- W : np.ndarray or scipy.sparse matrix
Either a 2-D (n, n) spatial weights matrix (dense or scipy.sparse) or a 1-D array of pre-computed real eigenvalues of W. Passing eigenvalues skips the O(n³) eigendecomposition.
- method : str or None
Method used to evaluate the log-determinant. Auto-selected when None: "eigenvalue" if n <= 500, otherwise "chebyshev". Supported values:
  - "eigenvalue": exact evaluation from the eigenvalues of W, O(n) per call.
  - "chebyshev": Chebyshev polynomial approximation evaluated with Clenshaw's algorithm, O(m) per call, where m is the number of Chebyshev terms. The coefficients are built from exact eigenvalues when n is small (or when eigenvalues are passed as W); otherwise they come from the stochastic trace estimator selected by trace_estimator.
- rho_min : float, default=-1.0
Lower bound for the rho interval.
- rho_max : float, default=1.0
Upper bound for the rho interval.
- T : int, default=1
Panel time-period count. The returned log-determinant is multiplied by T.
- trace_estimator : {"hutchinson", "hutchpp", "xtrace"}, default="hutchpp"
Stochastic trace estimator used to build the Chebyshev coefficients when an eigendecomposition is unavailable. Ignored when method="eigenvalue" or when eigenvalues are passed as W.
- trace_k : int, optional
Number of probe vectors for the trace estimator. Defaults: 30 ("hutchinson"), 50 ("hutchpp"), 25 ("xtrace").
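The "chebyshev" method can be pictured with a small NumPy sketch. It assumes one plausible construction, which may differ from bayespecon's own: f(rho) = sum_i log(1 - rho*lambda_i) is interpolated at Chebyshev nodes on an interval [rho_min, rho_max], and the resulting series is evaluated with Clenshaw's recurrence, which costs O(m) per call for m coefficients. Here the coefficients come from exact eigenvalues only; the stochastic-trace path is not reproduced. The helpers ring_weights, chebyshev_coeffs and clenshaw are hypothetical. The interval is narrowed to [-0.9, 0.9] because a row-standardized W has eigenvalue 1, so the log-determinant diverges as rho approaches 1.

```python
import numpy as np


def ring_weights(n):
    """Hypothetical fixture: row-standardized ring contiguity matrix (symmetric)."""
    W = np.zeros((n, n))
    idx = np.arange(n)
    W[idx, (idx + 1) % n] = 0.5
    W[idx, (idx - 1) % n] = 0.5
    return W


def chebyshev_coeffs(f, lo, hi, m):
    """Coefficients c_0..c_{m-1} of the Chebyshev interpolant of f on [lo, hi]."""
    theta = np.pi * (np.arange(m) + 0.5) / m          # angles of the m Chebyshev nodes
    rho_nodes = 0.5 * (hi - lo) * np.cos(theta) + 0.5 * (hi + lo)
    f_nodes = np.array([f(r) for r in rho_nodes])
    # Discrete orthogonality: c_j = (2/m) * sum_k f(x_k) * cos(j * theta_k)
    c = (2.0 / m) * np.cos(np.outer(np.arange(m), theta)) @ f_nodes
    c[0] /= 2.0                                       # so that f ~ sum_j c_j T_j(x)
    return c


def clenshaw(c, rho, lo, hi):
    """Evaluate sum_j c_j T_j(x) at x = x(rho) with Clenshaw's recurrence, O(m)."""
    x = (2.0 * rho - (lo + hi)) / (hi - lo)           # map [lo, hi] onto [-1, 1]
    b1, b2 = 0.0, 0.0
    for cj in c[:0:-1]:                               # j = m-1, ..., 1
        b1, b2 = 2.0 * x * b1 - b2 + cj, b1
    return x * b1 - b2 + c[0]


n, lo, hi, m = 31, -0.9, 0.9, 40
eigs = np.linalg.eigvalsh(ring_weights(n))            # real eigenvalues in [-1, 1]


def exact_logdet(rho):
    """log|I - rho*W| from the eigenvalues, used as the reference."""
    return np.sum(np.log1p(-rho * eigs))


c = chebyshev_coeffs(exact_logdet, lo, hi, m)         # one-off setup cost
for rho in (-0.5, 0.0, 0.85):
    print(rho, exact_logdet(rho), clenshaw(c, rho, lo, hi))
```

Because the coefficients are computed once, each later evaluation costs only the m-step recurrence, which is what makes this kind of approximation cheap to call repeatedly inside a sampler.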
Returns:
Function (rho) -> jax.numpy.ndarray that computes log|I - rho*W| (or T * log|I - rho*W| for panel models). Fully compatible with jax.jit and jax.grad.
Return type:
callable
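What the returned callable computes can be cross-checked in plain NumPy. The sketch below uses hypothetical helpers that are not part of bayespecon. It evaluates log|I - rho*W| from eigenvalues as sum_i log(1 - rho*lambda_i), confirms that the panel value T * log|I - rho*W| equals the log-determinant of the block-diagonal matrix I_T ⊗ (I - rho*W), and compares the analytic derivative -T * sum_i lambda_i / (1 - rho*lambda_i) with a central finite difference. That derivative is what jax.grad of the returned function would be expected to produce; the sketch itself does not call JAX.

```python
import numpy as np


def ring_weights(n):
    """Hypothetical fixture: row-standardized ring contiguity matrix (symmetric)."""
    W = np.zeros((n, n))
    idx = np.arange(n)
    W[idx, (idx + 1) % n] = 0.5
    W[idx, (idx - 1) % n] = 0.5
    return W


def logdet_eig(rho, eigs, T=1):
    """T * log|I - rho*W| = T * sum_i log(1 - rho*lambda_i); O(n) per call."""
    return T * np.sum(np.log1p(-rho * eigs))


def dlogdet_eig(rho, eigs, T=1):
    """Analytic d/drho of logdet_eig: -T * sum_i lambda_i / (1 - rho*lambda_i)."""
    return -T * np.sum(eigs / (1.0 - rho * eigs))


n, T, rho = 30, 3, 0.4
W = ring_weights(n)
eigs = np.linalg.eigvalsh(W)          # one O(n^3) decomposition, reused for every rho

# Cross-sectional check against a dense log-determinant.
sign, logabs = np.linalg.slogdet(np.eye(n) - rho * W)

# Panel check: log|I_T kron (I - rho*W)| should equal T * log|I - rho*W|.
_, panel_logabs = np.linalg.slogdet(np.kron(np.eye(T), np.eye(n) - rho * W))

# Gradient check: central finite difference of the panel log-determinant.
h = 1e-6
fd = (logdet_eig(rho + h, eigs, T) - logdet_eig(rho - h, eigs, T)) / (2 * h)

print(logdet_eig(rho, eigs), logabs)
print(logdet_eig(rho, eigs, T), panel_logabs)
print(dlogdet_eig(rho, eigs, T), fd)
```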
Raises:
ValueError – If method is not one of the supported JAX-compatible methods.
Notes
Not all logdet methods have JAX-native implementations. The grid and spline methods ("grid_dense", "grid_sparse", "sparse_spline", "grid_mc", "grid_ilu") and "exact" are not supported because they rely on SciPy- or PyTensor-specific operations that cannot be called inside jax.jit. Use "eigenvalue" or "chebyshev" instead.
See also
make_logdet_fn: PyTensor symbolic version (for NUTS).
make_logdet_numpy_fn: NumPy scalar version (for Gibbs samplers written as Python loops).
make_logdet_numpy_vec_fn: NumPy vectorized version (for post-processing).
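The trace_estimator and trace_k parameters control how traces are estimated from random probe vectors. As a minimal NumPy sketch of the idea, the plain Hutchinson estimator below averages z^T A z over k Rademacher probes and needs only matrix-vector products, never A itself. Hutch++ (the default) and XTrace are variance-reduced refinements of this estimator and are not shown; hutchinson_trace and ring_weights are hypothetical names, not bayespecon functions.

```python
import numpy as np


def hutchinson_trace(matvec, n, k=30, seed=0):
    """Plain Hutchinson estimate of tr(A) using only products A @ z."""
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(k):
        z = rng.choice([-1.0, 1.0], size=n)  # Rademacher probe, E[z z^T] = I
        estimates.append(z @ matvec(z))      # z^T A z is an unbiased estimate of tr(A)
    return float(np.mean(estimates))


def ring_weights(n):
    """Hypothetical fixture: row-standardized ring contiguity matrix."""
    W = np.zeros((n, n))
    idx = np.arange(n)
    W[idx, (idx + 1) % n] = 0.5
    W[idx, (idx - 1) % n] = 0.5
    return W


n = 200
W = ring_weights(n)
exact = np.trace(W @ W)                                     # equals 0.5 * n for this ring
approx = hutchinson_trace(lambda z: W @ (W @ z), n, k=30)   # tr(W^2) from 30 probes
print(f"exact={exact:.1f}  hutchinson={approx:.1f}")
```

Increasing k shrinks the estimator's standard error roughly as 1/sqrt(k), which is the trade-off trace_k exposes.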