bayespecon._logdet.jax_logdet_chebyshev

bayespecon._logdet.jax_logdet_chebyshev(rho, coeffs, rmin=-1.0, rmax=1.0)

Evaluate the Chebyshev approximation of log|I - rho*W| in JAX.

JAX-native version of logdet_chebyshev() using Clenshaw’s algorithm. Fully compatible with jax.jit and jax.grad.

Parameters:
rho : jax.numpy scalar or array

Spatial autoregressive parameter. Can be a scalar or an array of shape (G,) for vectorized evaluation over posterior draws.

coeffs : np.ndarray, shape (m,)

Chebyshev coefficients from chebyshev().

rmin : float, default=-1.0

Lower bound of the rho interval (must match the value used to compute coeffs).

rmax : float, default=1.0

Upper bound of the rho interval (must match the value used to compute coeffs).

Returns:

Chebyshev approximation of the log-determinant. Same shape as rho.

Return type:

jax.numpy.ndarray
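The coefficients are defined with respect to the mapped variable x, so evaluating them with bounds other than those used for fitting silently evaluates the series at the wrong rho. The NumPy sketch below illustrates this without calling this function. It assumes a hypothetical smooth function g in place of the log-determinant and uses numpy.polynomial.chebyshev.chebinterpolate as a stand-in for chebyshev(); both substitutions are for illustration only.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def g(rho):
    # Hypothetical smooth stand-in for log|I - rho*W| (illustration only)
    return np.log1p(-0.5 * rho)


def to_x(rho, rmin, rmax):
    # Affine map from [rmin, rmax] onto the Chebyshev interval [-1, 1]
    return (2.0 * rho - rmax - rmin) / (rmax - rmin)


rmin, rmax = 0.0, 0.9

# chebinterpolate samples its function on [-1, 1], so map those nodes back to rho
coeffs = C.chebinterpolate(
    lambda x: g(0.5 * (rmax - rmin) * x + 0.5 * (rmax + rmin)), 15
)

rho = np.array([0.1, 0.5, 0.8])
matched = C.chebval(to_x(rho, rmin, rmax), coeffs)    # same bounds as the fit
mismatched = C.chebval(to_x(rho, -1.0, 1.0), coeffs)  # default bounds: wrong here
```

With matching bounds the series reproduces g; with the default bounds it returns values of g at different rho.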

Notes

The parameter rho is mapped from [rmin, rmax] onto the Chebyshev interval [-1, 1] via

\[x = \frac{2\rho - r_{\max} - r_{\min}}{r_{\max} - r_{\min}}\]
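A minimal JAX sketch of this affine map follows. The helper name to_chebyshev_interval and the interval [-0.5, 0.9] are hypothetical and not part of bayespecon. The endpoints land on -1 and 1 and the midpoint on 0, up to floating-point rounding.

```python
import jax.numpy as jnp


def to_chebyshev_interval(rho, rmin=-1.0, rmax=1.0):
    """Hypothetical helper: affinely map rho from [rmin, rmax] onto [-1, 1]."""
    rho = jnp.asarray(rho)
    return (2.0 * rho - rmax - rmin) / (rmax - rmin)


# Lower endpoint, midpoint, and upper endpoint of an illustrative interval
x = to_chebyshev_interval(jnp.array([-0.5, 0.2, 0.9]), rmin=-0.5, rmax=0.9)
```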

and the approximation is evaluated via Clenshaw’s recurrence:

\[\begin{aligned} b_{m+1} &= 0, \quad b_m = c_m, \\ b_k &= 2x\,b_{k+1} - b_{k+2} + c_k, \qquad k = m-1, \dots, 1, \\ f(x) &= x\,b_1 - b_2 + c_0. \end{aligned}\]
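The recurrence can be sketched in JAX as follows. This is an illustration of the math above, not the library's source: the name clenshaw_chebyshev and the coefficient values are hypothetical. It assumes the series convention f(x) = sum_k c_k T_k(x), which is what the "+ c_0" in the last step implies. The result is checked against NumPy's own Chebyshev series evaluator.

```python
import numpy as np
import jax.numpy as jnp
from numpy.polynomial import chebyshev as C


def clenshaw_chebyshev(rho, coeffs, rmin=-1.0, rmax=1.0):
    """Hypothetical sketch: sum_k c_k T_k(x(rho)) via Clenshaw's recurrence."""
    cs = [float(c) for c in coeffs]         # c_0, ..., c_m as Python floats
    x = (2.0 * jnp.asarray(rho) - rmax - rmin) / (rmax - rmin)
    b1 = jnp.zeros_like(x)                  # holds b_{k+1}
    b2 = jnp.zeros_like(x)                  # holds b_{k+2}
    for c in reversed(cs[1:]):              # k = m, m-1, ..., 1
        b1, b2 = 2.0 * x * b1 - b2 + c, b1  # b_k = 2x b_{k+1} - b_{k+2} + c_k
    return x * b1 - b2 + cs[0]              # f = x b_1 - b_2 + c_0


# Compare with NumPy's Chebyshev series evaluation on illustrative inputs
coeffs = np.array([0.5, -1.0, 0.25, 0.1])
rmin, rmax = -0.95, 0.95
rho = jnp.linspace(-0.9, 0.9, 7)
approx = clenshaw_chebyshev(rho, coeffs, rmin, rmax)
x = (2.0 * np.asarray(rho) - rmax - rmin) / (rmax - rmin)
reference = C.chebval(x, coeffs)
```

The first loop pass yields b_m = c_m, because both accumulators start at zero.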

This is the same algorithm as logdet_chebyshev() but uses jax.numpy instead of PyTensor, making it compatible with jax.jit and jax.grad.
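The sketch below illustrates the jax.jit / jax.grad compatibility. It uses a hand-written Clenshaw helper rather than jax_logdet_chebyshev itself, and both the helper and the coefficients are hypothetical. The coefficients are a fixed-length NumPy array, so the Python loop is simply unrolled at trace time. Because jax.grad needs a scalar output, per-draw derivatives for a vector of G posterior draws go through jax.vmap. As a cross-check, the gradient is compared with the analytic derivative from numpy.polynomial.chebyshev.chebder.

```python
import numpy as np
import jax
import jax.numpy as jnp
from numpy.polynomial import chebyshev as C


def clenshaw(rho, coeffs, rmin=-1.0, rmax=1.0):
    # Hypothetical helper implementing the recurrence from the Notes
    cs = [float(c) for c in coeffs]
    x = (2.0 * jnp.asarray(rho) - rmax - rmin) / (rmax - rmin)
    b1, b2 = jnp.zeros_like(x), jnp.zeros_like(x)
    for c in reversed(cs[1:]):
        b1, b2 = 2.0 * x * b1 - b2 + c, b1
    return x * b1 - b2 + cs[0]


rmin, rmax = -0.9, 0.9
coeffs = np.array([0.0, -0.2, -0.4, 0.05, -0.02])  # hypothetical coefficients


def logdet(rho):
    return clenshaw(rho, coeffs, rmin, rmax)


logdet_jit = jax.jit(logdet)
dlogdet = jax.jit(jax.grad(logdet))  # grad of a scalar-valued function

draws = jnp.linspace(-0.5, 0.5, 5)   # stand-in for G = 5 posterior draws of rho
values = logdet_jit(draws)           # same shape as draws
slopes = jax.vmap(dlogdet)(draws)    # d/drho of the approximation at each draw

# Analytic check: d/drho = (dx/drho) * sum_k c'_k T_k(x), with dx/drho = 2/(rmax - rmin)
x = (2.0 * np.asarray(draws) - rmax - rmin) / (rmax - rmin)
expected_values = C.chebval(x, coeffs)
expected_slopes = C.chebval(x, C.chebder(coeffs)) * 2.0 / (rmax - rmin)
```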

See also

chebyshev

Compute Chebyshev coefficients from W.

logdet_chebyshev

PyTensor symbolic version (for NUTS).
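One way coefficients of this form could be obtained for a small, dense W is to interpolate the exact log-determinant at Chebyshev nodes. This is a hypothetical sketch, not necessarily what chebyshev() does. The helper fit_logdet_coeffs, the ring-shaped W, the degree 40, and the interval [-0.9, 0.9] are all illustrative assumptions. The sketch also assumes the same series convention as the recurrence above (f = sum_k c_k T_k). Under these assumptions the resulting array, together with the same rmin and rmax, would be the kind of input expected by the coeffs parameter.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def fit_logdet_coeffs(W, deg, rmin=-1.0, rmax=1.0):
    """Hypothetical helper: coefficients c_0..c_deg of log|I - rho*W|.

    Interpolates the exact dense log-determinant at Chebyshev nodes,
    which is only practical for small W.
    """
    eye = np.eye(W.shape[0])

    def logdet_at(x):
        # chebinterpolate samples x in [-1, 1]; map back to rho in [rmin, rmax]
        rho = 0.5 * (rmax - rmin) * x + 0.5 * (rmax + rmin)
        return np.array([np.linalg.slogdet(eye - r * W)[1] for r in rho])

    return C.chebinterpolate(logdet_at, deg)


# Row-standardized ring of n = 10 regions: each region has two neighbours
n = 10
W = np.zeros((n, n))
for i in range(n):
    W[i, (i - 1) % n] = 0.5
    W[i, (i + 1) % n] = 0.5

rmin, rmax = -0.9, 0.9
coeffs = fit_logdet_coeffs(W, 40, rmin=rmin, rmax=rmax)  # 41 coefficients

# Compare the series with the exact log-determinant inside [rmin, rmax]
rho = np.linspace(-0.85, 0.85, 9)
series = C.chebval((2.0 * rho - rmax - rmin) / (rmax - rmin), coeffs)
exact = np.array([np.linalg.slogdet(np.eye(n) - r * W)[1] for r in rho])
```

The interval stops short of rho = 1 because I - rho*W is singular there for a row-standardized W, and the log-determinant diverges.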