bayespecon._logdet.jax_logdet_trace_poly
- bayespecon._logdet.jax_logdet_trace_poly(rho, traces)
Evaluate trace-polynomial approximation of log|I - rho*W| in JAX.
JAX-native version of logdet_mc_poly_pytensor() using Horner’s method. Fully compatible with jax.jit and jax.grad.
Computes the truncated power-series approximation
\[\log|I_n - \rho W| \approx -\sum_{k=1}^{m} \frac{\rho^k}{k}\,\hat{\tau}_k\]
where \(\hat{\tau}_k \approx \text{tr}(W^k)\) are stochastic trace estimates from traceax_traces() or compute_flow_traces().
- Parameters:
- Returns:
Polynomial approximation of the log-determinant. Same shape as rho.
- Return type:
jax.numpy.ndarray
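As a sanity check of the power series, the sketch below uses a two-node weights matrix W = [[0, 1], [1, 0]], for which tr(W^k) is 2 for even k and 0 for odd k, and |I - rho*W| = 1 - rho^2 in closed form. It is written in plain Python with exact traces instead of stochastic estimates, and it assumes that the first entry of traces is the k = 1 term. The helper name logdet_series is hypothetical and not part of bayespecon.

```python
import math


def logdet_series(rho, traces):
    """Direct evaluation of -sum_{k=1}^{m} rho**k / k * traces[k-1].

    Assumption: traces[0] corresponds to k = 1, i.e. tr(W^1).
    """
    return -sum(rho**k / k * t for k, t in enumerate(traces, start=1))


# Exact traces of W = [[0, 1], [1, 0]]: W^k is the identity for even k
# (trace 2) and W itself for odd k (trace 0).
m = 40
traces = [2.0 if k % 2 == 0 else 0.0 for k in range(1, m + 1)]

rho = 0.5
approx = logdet_series(rho, traces)   # truncated power series
exact = math.log(1.0 - rho**2)        # log|I - rho*W| in closed form

print(approx, exact)
```

With |rho| < 1 the omitted terms shrink geometrically for this matrix, so the truncated sum and the closed form should agree closely for moderate m. With stochastic trace estimates, the estimation error in the traces adds to the truncation error.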
Notes
Horner evaluation of \(-\sum_{k=1}^m w_k \rho^k\) where \(w_k = \hat{\tau}_k / k\):
\[-\rho \bigl(w_1 + \rho(w_2 + \rho(\cdots + \rho\, w_m)\cdots)\bigr)\]
This is the same algorithm as logdet_mc_poly_pytensor() but uses jax.numpy instead of pytensor, making it compatible with jax.jit and jax.grad.
See also
logdet_mc_poly_pytensor: PyTensor symbolic version (for NUTS).
traceax_traces: Compute trace estimates via variance-reduced estimators.
compute_flow_traces: Compute trace estimates via Barry-Pace Hutchinson.
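Examples
The Horner scheme from the Notes can be sketched in a few lines of jax.numpy. This is an illustrative stand-in, not the library's source: the name horner_logdet_sketch is hypothetical, the trace values are made up for the example, and the sketch assumes that traces is a 1-D array whose first entry is the k = 1 estimate.

```python
import jax
import jax.numpy as jnp


def horner_logdet_sketch(rho, traces):
    """Hypothetical sketch of -sum_k (traces[k-1] / k) * rho**k via Horner.

    Assumption: traces is 1-D and traces[0] is the k = 1 trace estimate.
    """
    traces = jnp.asarray(traces)
    m = traces.shape[0]                 # static length, so the loop unrolls under jit
    w = traces / jnp.arange(1, m + 1)   # w_k = tau_k / k
    acc = 0.0
    for i in reversed(range(m)):        # innermost coefficient w_m first
        acc = w[i] + rho * acc
    return -rho * acc


# Illustrative trace values only (exact traces of a two-node W, m = 6).
traces = jnp.array([0.0, 2.0, 0.0, 2.0, 0.0, 2.0])
rho = 0.3

value = jax.jit(horner_logdet_sketch)(rho, traces)    # compiled evaluation
slope = jax.grad(horner_logdet_sketch)(rho, traces)   # derivative with respect to rho
batch = horner_logdet_sketch(jnp.array([0.1, 0.2]), traces)  # broadcasts over rho

print(value, slope, batch.shape)
```

The Python loop runs over a fixed number of coefficients, so tracing produces a straight-line chain of multiply-adds. For a large number of traces, jax.lax.scan or jax.lax.fori_loop would be an alternative that avoids unrolling.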