bayespecon._logdet.jax_logdet_trace_poly

bayespecon._logdet.jax_logdet_trace_poly(rho, traces)

Evaluate the trace-polynomial approximation of log|I - rho*W| in JAX.

JAX-native version of logdet_mc_poly_pytensor(), evaluated with Horner's method. It is fully compatible with jax.jit and jax.grad.

Computes the truncated power-series approximation

\[\log|I_n - \rho W| \approx -\sum_{k=1}^{m} \frac{\rho^k}{k}\,\hat{\tau}_k\]

where \(\hat{\tau}_k \approx \text{tr}(W^k)\) are stochastic trace estimates from traceax_traces() or compute_flow_traces().
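To see how close the truncated series gets, the sketch below compares it with an exact log-determinant on a tiny example. It is illustrative only: the 5-node ring weights matrix, the truncation order m = 30 and the value rho = 0.5 are assumptions chosen for the demo, and exact traces tr(W^k) stand in for the stochastic estimates the function normally receives.

```python
import jax
import jax.numpy as jnp

# Double precision so that truncation error, not rounding, dominates the gap.
jax.config.update("jax_enable_x64", True)

# Illustrative weights: a 5-node ring, row-standardized (two entries of 0.5 per row),
# so the spectral radius of W is 1 and the series converges for |rho| < 1.
n = 5
idx = jnp.arange(n)
A = jnp.zeros((n, n)).at[idx, (idx + 1) % n].set(1.0).at[idx, (idx - 1) % n].set(1.0)
W = A / A.sum(axis=1, keepdims=True)

m = 30  # truncation order of the power series (assumed for the demo)

# Exact traces tr(W^k), k = 1..m, as a stand-in for stochastic estimates.
powers = [W]
for _ in range(m - 1):
    powers.append(powers[-1] @ W)
tau = jnp.array([jnp.trace(P) for P in powers])

rho = 0.5
k = jnp.arange(1, m + 1)
series = -jnp.sum(rho**k / k * tau)  # -sum_k rho^k / k * tau_k

# Reference value from a dense determinant.
sign, exact = jnp.linalg.slogdet(jnp.eye(n) - rho * W)
print(float(series), float(exact))
```

The neglected tail is bounded by roughly |rho|^(m+1) / (m+1) per eigenvalue, so for rho = 0.5 and m = 30 the two printed numbers agree to many digits; as |rho| approaches 1 a larger m is needed.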

Parameters:
rho : jax.numpy scalar or array

Spatial autoregressive parameter. Can be a scalar or an array of shape (G,) for vectorized evaluation over posterior draws.

traces : np.ndarray, shape (m,)

Trace estimates, with traces[k-1] ≈ tr(W^k) for k = 1, ..., m.
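For readers unfamiliar with where such a traces array comes from, here is a generic Hutchinson-style estimator sketch. The helper name hutchinson_traces, its parameters, and the use of Rademacher probes are assumptions for illustration; this is not the implementation behind traceax_traces() or compute_flow_traces(). It relies on the identity E[z^T A z] = tr(A) for a random vector z with independent ±1 entries, and it builds W^k z by repeated matrix-vector products rather than forming W^k.

```python
import jax
import jax.numpy as jnp


def hutchinson_traces(W, m, num_probes, key):
    """Estimate tr(W^k) for k = 1..m with Rademacher probes (illustrative sketch).

    Hypothetical helper, not the library's traceax_traces or compute_flow_traces.
    Returns an array of shape (m,) with estimates[k-1] ~ tr(W^k).
    """
    n = W.shape[0]
    # Probe matrix: each column z has independent +1/-1 entries.
    Z = jax.random.rademacher(key, (n, num_probes)).astype(W.dtype)
    V = Z
    estimates = []
    for _ in range(m):
        V = W @ V  # after the k-th pass, the columns of V are W^k z
        # z^T W^k z for every probe, then averaged over probes.
        estimates.append(jnp.mean(jnp.sum(Z * V, axis=0)))
    return jnp.stack(estimates)


# Example: the same kind of row-standardized 5-node ring used for illustration.
n = 5
idx = jnp.arange(n)
A = jnp.zeros((n, n)).at[idx, (idx + 1) % n].set(1.0).at[idx, (idx - 1) % n].set(1.0)
W = A / A.sum(axis=1, keepdims=True)
traces = hutchinson_traces(W, m=10, num_probes=20000, key=jax.random.PRNGKey(0))
print(traces[:3])
```

The estimates are unbiased but noisy; the standard error shrinks as 1/sqrt(num_probes), which is why variance-reduced estimators are attractive in practice.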

Returns:

Polynomial approximation of the log-determinant. Same shape as rho.

Return type:

jax.numpy.ndarray

Notes

Horner evaluation of \(-\sum_{k=1}^m w_k \rho^k\) where \(w_k = \hat{\tau}_k / k\):

\[-\rho \bigl(w_1 + \rho(w_2 + \rho(\cdots + \rho\, w_m)\cdots)\bigr)\]
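The nested form above translates directly into a loop that starts from the innermost coefficient w_m and works outward. The function below is a hypothetical re-implementation written for exposition (the name horner_logdet is an assumption, not the library's source); because the coefficients are scalars, the same loop handles a scalar rho or an array of posterior draws through broadcasting.

```python
import jax.numpy as jnp


def horner_logdet(rho, traces):
    """Evaluate -sum_{k=1}^m (tau_k / k) rho^k by Horner's rule (illustrative sketch).

    Hypothetical re-implementation, not the library source.
    rho: scalar or array (e.g. shape (G,)); traces: shape (m,), traces[k-1] ~ tr(W^k).
    """
    traces = jnp.asarray(traces)
    m = traces.shape[0]
    w = traces / jnp.arange(1, m + 1)  # w[k-1] holds w_k = tau_k / k
    rho = jnp.asarray(rho)
    acc = w[m - 1]  # innermost bracket: w_m
    for k in range(m - 2, -1, -1):
        acc = w[k] + rho * acc  # w_{k+1} + rho * (inner bracket), 0-based index k
    return -rho * acc  # outer factor -rho; result broadcasts to rho's shape


traces = jnp.array([0.0, 2.0, 0.5, 1.0])  # made-up trace values for the demo
rho = jnp.array([-0.3, 0.0, 0.4, 0.9])    # e.g. G = 4 posterior draws
approx = horner_logdet(rho, traces)
print(approx)
```

Horner's rule uses m - 1 multiply-adds and never forms rho**k explicitly, which keeps the computation cheap and numerically well behaved for |rho| < 1.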

The algorithm is identical to logdet_mc_poly_pytensor(); only the array backend differs (jax.numpy rather than PyTensor), which is what lets the function be traced by jax.jit and differentiated with jax.grad.
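A short sketch of what jit and grad compatibility buys in practice. It uses a compact hypothetical variant, logdet_poly, built on jnp.polyval (which also evaluates polynomials by Horner's rule) rather than the library function itself, and made-up trace values. It compiles the evaluation for a vector of draws, takes the derivative at one value of rho, and vectorizes the derivative over draws with jax.vmap. The derivative can be checked by hand: d/drho of -sum_k (tau_k / k) rho^k is -sum_k tau_k rho^(k-1).

```python
import jax
import jax.numpy as jnp


def logdet_poly(rho, traces):
    """Hypothetical compact variant: -rho * (w_1 + w_2 rho + ... + w_m rho^(m-1))."""
    traces = jnp.asarray(traces)
    w = traces / jnp.arange(1, traces.shape[0] + 1)  # w_k = tau_k / k
    # polyval(p, x) = p[0] x^(N-1) + ... + p[N-1], so pass the reversed [w_m, ..., w_1].
    return -rho * jnp.polyval(w[::-1], rho)


traces = jnp.array([0.0, 2.0, 0.5, 1.0])  # made-up trace values for the demo

# Compiled evaluation over a vector of (assumed) posterior draws of rho.
f_jit = jax.jit(logdet_poly)
draws = jnp.linspace(-0.5, 0.9, 8)
values = f_jit(draws, traces)  # same shape as draws

# Gradient with respect to rho at a single point, and per-draw gradients.
dlogdet = jax.grad(logdet_poly)(0.5, traces)
grads = jax.vmap(jax.grad(logdet_poly), in_axes=(0, None))(draws, traces)
print(values, dlogdet, grads)
```

Here the hand-computed derivative at rho = 0.5 is -(0 + 2(0.5) + 0.5(0.25) + 1(0.125)) = -1.25, which is what jax.grad should return up to float32 rounding.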

See also

logdet_mc_poly_pytensor

PyTensor symbolic version (for NUTS).

traceax_traces

Compute trace estimates via variance-reduced estimators.

compute_flow_traces

Compute trace estimates via Barry-Pace Hutchinson.