bayespecon.GaussianSEMGibbs
class bayespecon.GaussianSEMGibbs(y, X, W_sparse, priors, logdet_fn, logdet_vec_fn, feature_names, model_type='sem', W_eigs=None, logdet_method=None, T=1)
Gibbs sampler for SEM/SDEM Gaussian models.
Three-block sampler: β is drawn from its conjugate normal full conditional, σ² from its conjugate inverse-gamma full conditional, and λ by slice sampling or MALA, because the log|I - λW| term makes its conditional non-standard.
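The three-block scheme above can be sketched in NumPy as follows. This is a minimal illustration, not bayespecon's implementation. It assumes a N(0, prior_var·I) prior on β, an inverse-gamma(a0, b0) prior on σ², a uniform prior on λ over lam_bounds, and a dense W small enough for a one-off eigendecomposition. The names sem_gibbs and slice_sample are hypothetical. The slice width of 10% of the λ range mirrors the documented slice_width default.

```python
import numpy as np


def slice_sample(x0, logp, lo, hi, width, rng):
    """One univariate slice-sampling update (stepping-out + shrinkage, Neal 2003).

    logp must return -inf outside (lo, hi); the bounded support guarantees
    that stepping out terminates without a step cap.
    """
    log_y = logp(x0) + np.log(rng.uniform())   # height of the auxiliary slice
    left = x0 - width * rng.uniform()          # randomly positioned initial interval
    right = left + width
    while left > lo and logp(left) > log_y:    # step out to the left
        left -= width
    while right < hi and logp(right) > log_y:  # step out to the right
        right += width
    left, right = max(left, lo), min(right, hi)
    while True:                                # shrink until a point is accepted
        x1 = rng.uniform(left, right)
        if logp(x1) > log_y:
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1


def sem_gibbs(y, X, W, n_iter=1500, burn=500, seed=0,
              lam_bounds=(-0.99, 0.99), a0=0.01, b0=0.01, prior_var=1e6):
    """Hypothetical 3-block Gibbs sampler for y = X beta + u, u = lam W u + eps."""
    rng = np.random.default_rng(seed)
    n, k = X.shape
    lo, hi = lam_bounds
    eigs = np.linalg.eigvals(W)                # log|I - lam W| = sum_i log|1 - lam w_i|
    Wy, WX = W @ y, W @ X
    prior_prec = np.eye(k) / prior_var         # beta ~ N(0, prior_var * I)
    beta, sigma2, lam = np.zeros(k), 1.0, 0.0
    keep = {"beta": [], "sigma2": [], "lam": []}

    def log_cond_lam(l):
        # log p(lam | beta, sigma2, y) up to a constant; uniform prior on (lo, hi)
        if not lo < l < hi:
            return -np.inf
        e = (y - l * Wy) - (X - l * WX) @ beta
        return np.sum(np.log(np.abs(1.0 - l * eigs))) - e @ e / (2.0 * sigma2)

    for it in range(n_iter):
        ys, Xs = y - lam * Wy, X - lam * WX    # spatially filtered data (I - lam W)
        # Block 1: beta | sigma2, lam  ~ conjugate normal
        V = np.linalg.inv(Xs.T @ Xs / sigma2 + prior_prec)
        beta = rng.multivariate_normal(V @ (Xs.T @ ys) / sigma2, V)
        # Block 2: sigma2 | beta, lam  ~ inverse-gamma, drawn as 1 / Gamma
        e = ys - Xs @ beta
        sigma2 = 1.0 / rng.gamma(a0 + n / 2.0, 1.0 / (b0 + e @ e / 2.0))
        # Block 3: lam | beta, sigma2  by slice sampling (width = 10% of the range)
        lam = slice_sample(lam, log_cond_lam, lo, hi, 0.1 * (hi - lo), rng)
        if it >= burn:
            keep["beta"].append(beta)
            keep["sigma2"].append(sigma2)
            keep["lam"].append(lam)
    return {name: np.array(v) for name, v in keep.items()}


# Usage on simulated data: ring contiguity, beta = (1, 2), lam = 0.5, sigma = 0.5
rng = np.random.default_rng(1)
n = 200
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx - 1) % n] = 0.5                   # row-standardised: 2 neighbours each
W[idx, (idx + 1) % n] = 0.5
X = np.column_stack([np.ones(n), rng.normal(size=n)])
u = np.linalg.solve(np.eye(n) - 0.5 * W, 0.5 * rng.normal(size=n))
y = X @ np.array([1.0, 2.0]) + u
out = sem_gibbs(y, X, W)
print(out["beta"].mean(axis=0), out["sigma2"].mean(), out["lam"].mean())
```

Given λ, the β and σ² draws are exact conjugate updates on the filtered data (I - λW)y and (I - λW)X. Only λ needs a generic univariate sampler.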
- Parameters:
- y : ndarray of shape (n,)
Response vector.
- X : ndarray of shape (n, k)
Design matrix (for SDEM, this is [X, WX]).
- W_sparse : csr_matrix of shape (n, n)
Row-standardised spatial weights matrix.
- priors : GaussianGibbsPriors
Prior hyperparameters.
- logdet_fn : callable
Callable returning log|I - λW| as a NumPy scalar for a given λ.
- logdet_vec_fn : callable
Vectorised counterpart of logdet_fn that evaluates log|I - λW| over an array of λ values.
- feature_names : list of str
Names for the columns of X.
- model_type : str, default "sem"
"sem" or "sdem".
- W_eigs : ndarray or None
Real eigenvalues of W (used by the JAX logdet).
- logdet_method : str or None
Logdet method for the JAX path (auto-selected when None).
- T : int, default 1
Number of panel time periods (1 for a single cross-section).
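For small problems, callables matching the logdet_fn and logdet_vec_fn descriptions can be built from the eigenvalues ω_i of W, using log|I - λW| = Σ_i log|1 - λω_i|. The sketch below assumes that approach and a modest n, because the dense eigendecomposition costs O(n³). bayespecon may compute the log-determinant differently (see logdet_method). The helpers row_standardise and make_logdet_fns are hypothetical and not part of the library. The sketch also shows one way to form the SDEM design [X, WX].

```python
import numpy as np
import scipy.sparse as sp


def row_standardise(W):
    """Scale each row of a sparse weights matrix so that it sums to 1."""
    W = sp.csr_matrix(W, dtype=float)
    row_sums = np.asarray(W.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0              # leave all-zero rows (islands) as zeros
    return (sp.diags(1.0 / row_sums) @ W).tocsr()


def make_logdet_fns(W_sparse):
    """Hypothetical helper: eigenvalue-based log|I - lam W| callables (small n only)."""
    eigs = np.linalg.eigvals(W_sparse.toarray())  # dense O(n^3) step, done once

    def logdet_fn(lam):
        # scalar lam -> NumPy float64 scalar
        return np.sum(np.log(np.abs(1.0 - lam * eigs)))

    def logdet_vec_fn(lams):
        # 1-D array of lam values -> one log-determinant per value
        lams = np.asarray(lams, dtype=float)
        return np.sum(np.log(np.abs(1.0 - lams[:, None] * eigs[None, :])), axis=1)

    return logdet_fn, logdet_vec_fn


# Usage: 4 units on a line (path contiguity)
A = sp.csr_matrix(np.array([[0, 1, 0, 0],
                            [1, 0, 1, 0],
                            [0, 1, 0, 1],
                            [0, 0, 1, 0]]))
W = row_standardise(A)
logdet_fn, logdet_vec_fn = make_logdet_fns(W)
print(logdet_fn(0.5), logdet_vec_fn([0.0, 0.5]))

# SDEM design [X, WX]: this sketch lags only the non-constant columns, because
# with a row-standardised W the spatial lag of a constant column is that column.
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(4), rng.normal(size=4)])
X_sdem = np.hstack([X, W @ X[:, 1:]])
```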
__init__(y, X, W_sparse, priors, logdet_fn, logdet_vec_fn, feature_names, model_type='sem', W_eigs=None, logdet_method=None, T=1)
Methods
- __init__(y, X, W_sparse, priors, logdet_fn, ...)
- fit([draws, tune, chains, random_seed, ...]): Run Gibbs chains and assemble InferenceData.
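Running several chains and assembling their output can be illustrated with a short sketch. It runs independent chains and stacks the draws with a leading (chain, draw) layout, which is how ArviZ InferenceData groups index posterior samples. The seeding scheme (NumPy's SeedSequence.spawn) is an assumption and not necessarily what bayespecon does with random_seed. run_chains and toy_chain are hypothetical names; toy_chain merely stands in for one sampler run.

```python
import numpy as np


def run_chains(sample_chain, chains=4, draws=2000, random_seed=None):
    """Hypothetical helper: run independent chains, stack draws as (chain, draw, ...)."""
    # SeedSequence.spawn derives statistically independent child seeds, so each
    # chain's stream depends only on random_seed and the chain index, not on the
    # order in which (possibly parallel) workers finish.
    children = np.random.SeedSequence(random_seed).spawn(chains)
    results = [sample_chain(np.random.default_rng(c), draws) for c in children]
    # Axis 0 indexes the chain, axis 1 the draw, matching the InferenceData layout.
    return {name: np.stack([res[name] for res in results]) for name in results[0]}


def toy_chain(rng, draws):
    """Stand-in for one sampler run: returns a trace for a scalar 'lam'."""
    return {"lam": rng.normal(loc=0.5, scale=0.1, size=draws)}


posterior = run_chains(toy_chain, chains=4, draws=500, random_seed=123)
print(posterior["lam"].shape)  # (4, 500)
```

Replacing the list comprehension with joblib's Parallel(n_jobs=...)(delayed(sample_chain)(...) for ...) would run the chains in parallel without changing the draws, because each chain's generator is fixed before dispatch.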
fit(draws=2000, tune=1000, chains=4, random_seed=None, thin=1, n_jobs=-1, progressbar=True, gibbs_method='numpy', mala_step_size=0.05, use_mala=True, use_slice=True, slice_width=None, chain_method=None)
Run Gibbs chains and assemble InferenceData.
- Parameters:
- draws : int, default 2000
Number of post-warmup draws per chain.
- tune : int, default 1000
Number of warmup (burn-in) draws per chain.
- chains : int, default 4
Number of independent chains.
- random_seed : int or None
Seed for reproducibility.
- thin : int, default 1
Keep every thin-th draw after warmup.
- n_jobs : int, default -1
Number of parallel workers for the NumPy path. -1 uses all CPUs. When n_jobs=1, chains run sequentially with progress bars. When n_jobs>1 (or -1), chains run in parallel via joblib. Ignored for the JAX path (use chain_method instead).
- progressbar : bool, default True
Show per-chain progress bars.
- gibbs_method : str, default "numpy"
Execution backend: "numpy" for Python-loop Gibbs with adaptive slice sampling, or "jax" for fully JIT-compiled Gibbs with slice sampling, MALA, or random-walk Metropolis–Hastings for ρ/λ (see use_slice and use_mala). The JAX path requires JAX and equinox.
- mala_step_size : float, default 0.05
Initial MALA step size (or random-walk Metropolis–Hastings proposal standard deviation) for the JAX path. Ignored when gibbs_method="numpy".
- use_mala : bool, default True
If True, use MALA (gradient-guided proposals) for the ρ/λ update in the JAX path. If False, use random-walk Metropolis–Hastings. Ignored when gibbs_method="numpy" or use_slice=True.
- use_slice : bool, default True
If True, use slice sampling for the ρ/λ update in the JAX path. Slice sampling typically yields a much higher effective sample size (ESS) per draw than MALA. Takes priority over use_mala. Ignored when gibbs_method="numpy".
- slice_width : float or None, default None
Initial step-out width for slice sampling. If None, defaults to (rho_upper - rho_lower) * 0.1, i.e. 10% of the admissible ρ/λ interval. Ignored when use_slice=False or gibbs_method="numpy".
- chain_method : str or None, default None
How to run multiple chains for the JAX path. "vectorized" uses jax.vmap for JAX-native parallelism (all chains on one device). "sequential" runs chains one after another with progress bars. "parallel" is not supported for the JAX path. If None, defaults to "vectorized" when gibbs_method="jax". Ignored for the NumPy path (use n_jobs to control parallelism instead).
- Returns:
InferenceData with posterior, log_likelihood, and observed_data groups.
- Return type:
az.InferenceData
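The MALA and random-walk Metropolis–Hastings options for the λ update can be illustrated with a scalar NumPy sketch. It assumes a flat prior on λ over (-1, 1) and holds β and σ² fixed, so the residual r = y - Xβ is given. It also uses a fixed step size. mala_step_size is documented as an initial value, which suggests the JAX path adapts it, but no adaptation is implemented here. lam_logpost_grad and mh_step_lam are hypothetical names, and use_gradient only loosely mirrors use_mala.

```python
import numpy as np


def lam_logpost_grad(lam, r, Wr, eigs, sigma2):
    """log p(lam | beta, sigma2, y) up to a constant (flat prior) and its derivative.

    With r = y - X beta, the filtered residual is e(lam) = r - lam * W r.
    """
    e = r - lam * Wr
    logp = np.sum(np.log(np.abs(1.0 - lam * eigs))) - e @ e / (2.0 * sigma2)
    # d/dlam log|1 - lam w| = Re(-w / (1 - lam w));  d/dlam (-e'e / 2 s2) = e'Wr / s2
    grad = np.sum(np.real(-eigs / (1.0 - lam * eigs))) + e @ Wr / sigma2
    return logp, grad


def mh_step_lam(lam, target, step, rng, lo=-1.0, hi=1.0, use_gradient=True):
    """One MALA (use_gradient=True) or random-walk MH (False) update for scalar lam."""
    logp, grad = target(lam)
    drift = 0.5 * step**2 * grad if use_gradient else 0.0
    prop = lam + drift + step * rng.normal()
    if not lo < prop < hi:                     # zero prior density: always reject
        return lam, False
    logp_p, grad_p = target(prop)
    log_alpha = logp_p - logp
    if use_gradient:                           # Langevin proposals are asymmetric
        drift_p = 0.5 * step**2 * grad_p
        log_q_fwd = -((prop - lam - drift) ** 2) / (2 * step**2)
        log_q_bwd = -((lam - prop - drift_p) ** 2) / (2 * step**2)
        log_alpha += log_q_bwd - log_q_fwd
    if np.log(rng.uniform()) < log_alpha:
        return prop, True
    return lam, False


# Usage: residuals simulated with lam = 0.4, sigma2 = 1 on a ring contiguity W
rng = np.random.default_rng(0)
n = 100
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx - 1) % n] = 0.5
W[idx, (idx + 1) % n] = 0.5
eigs = np.linalg.eigvals(W)
r = np.linalg.solve(np.eye(n) - 0.4 * W, rng.normal(size=n))
Wr = W @ r


def target(l):
    return lam_logpost_grad(l, r, Wr, eigs, 1.0)


lam, draws, accepted = 0.0, [], 0
for _ in range(3000):
    lam, acc = mh_step_lam(lam, target, step=0.05, rng=rng)
    draws.append(lam)
    accepted += acc
draws = np.array(draws)
print("acceptance:", accepted / 3000, "mean after warmup:", draws[1000:].mean())
```

The step of 0.05 mirrors the documented mala_step_size default. How well it works depends on the posterior scale of λ.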