bayespecon.GaussianSEMGibbs

class bayespecon.GaussianSEMGibbs(y, X, W_sparse, priors, logdet_fn, logdet_vec_fn, feature_names, model_type='sem', W_eigs=None, logdet_method=None, T=1)

Gibbs sampler for Gaussian spatial error (SEM) and spatial Durbin error (SDEM) models.

Three-block sampler: β (conjugate normal update), σ² (conjugate inverse-gamma update), and λ (sampled from its full conditional by slice sampling or MALA).
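One sweep through the three blocks can be sketched in NumPy as follows. This is an illustrative sketch, not bayespecon's implementation: the function name sem_gibbs_sweep, the prior hyperparameters (b0, B0_inv, a0, d0), the uniform prior on λ over fixed bounds, and the fixed-width stepping-out slice sampler (the NumPy path is documented as using an adaptive one) are all assumptions. The sketch relies on the SEM filtering identity (I - λW)y = (I - λW)Xβ + ε, which turns the β and σ² updates into ordinary conjugate regressions on filtered data, and on log|I - λW| = Σᵢ log(1 - λωᵢ) for real eigenvalues ωᵢ of W.

```python
import numpy as np


def sem_gibbs_sweep(y, X, W, eigs, beta, sigma2, lam, rng,
                    b0, B0_inv, a0=2.0, d0=1.0,
                    lam_bounds=(-0.99, 0.99), width=0.2):
    """One sweep of an illustrative 3-block Gibbs sampler for the Gaussian SEM
    y = X beta + u,  u = lam W u + eps,  eps ~ N(0, sigma2 I).

    Assumed (hypothetical) priors: beta ~ N(b0, inv(B0_inv)),
    sigma2 ~ Inv-Gamma(a0, d0), lam ~ Uniform(lam_bounds).
    eigs holds the real eigenvalues of W; lam must start inside lam_bounds.
    """
    n = y.shape[0]

    def resid(l, b):
        r = y - X @ b
        return r - l * (W @ r)            # (I - l W)(y - X b)

    # Block 1: beta | sigma2, lam is normal (conjugate regression on filtered data)
    ys = y - lam * (W @ y)
    Xs = X - lam * (W @ X)
    cov = np.linalg.inv(Xs.T @ Xs / sigma2 + B0_inv)
    mean = cov @ (Xs.T @ ys / sigma2 + B0_inv @ b0)
    beta = rng.multivariate_normal(mean, cov)

    # Block 2: sigma2 | beta, lam ~ Inv-Gamma(a0 + n/2, d0 + e'e/2)
    e = resid(lam, beta)
    # rate / Gamma(shape, 1) is distributed as Inv-Gamma(shape, rate)
    sigma2 = (d0 + 0.5 * e @ e) / rng.gamma(a0 + 0.5 * n)

    # Block 3: lam | beta, sigma2 by slice sampling (stepping out + shrinkage)
    lo, hi = lam_bounds

    def log_cond(l):
        if not (lo < l < hi):
            return -np.inf
        r = resid(l, beta)
        # log|I - l W| from eigenvalues, plus the Gaussian kernel
        return np.sum(np.log(1.0 - l * eigs)) - 0.5 * (r @ r) / sigma2

    log_y = log_cond(lam) + np.log(rng.uniform())   # slice height
    left = lam - width * rng.uniform()
    right = left + width
    while log_cond(left) > log_y:         # step out until both ends leave the slice
        left -= width
    while log_cond(right) > log_y:
        right += width
    while True:
        cand = rng.uniform(left, right)
        if log_cond(cand) > log_y:
            return beta, sigma2, cand
        if cand < lam:                    # shrink the interval toward the current lam
            left = cand
        else:
            right = cand
```

In this sketch only the λ block touches the log-determinant; β and σ² are exact conjugate draws given the current λ.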

Parameters:
y : ndarray of shape (n,)

Response vector.

X : ndarray of shape (n, k)

Design matrix (for SDEM, this is [X, WX]).

W_sparse : csr_matrix of shape (n, n)

Row-standardised spatial weights matrix.

priors : GaussianGibbsPriors

Prior hyperparameters.

logdet_fn : callable

Callable that evaluates log|I - lam*W| for a scalar lam and returns a NumPy scalar.

logdet_vec_fn : callable

Vectorized counterpart of logdet_fn that evaluates log|I - lam*W| over an array of lam values.

feature_names : list of str

Names for the columns of X.

model_type : str

"sem" or "sdem".

W_eigs : ndarray or None

Real eigenvalues of W (for JAX logdet).

logdet_method : str or None

Logdet method for JAX path (auto-selected when None).

T : int, default 1

Number of time periods for panel data (1 for a single cross-section).
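The constructor expects the weights matrix, design matrix, and log-determinant callables to be prepared in advance. The helpers below sketch one way to build them with NumPy/SciPy; row_standardise, sdem_design, and make_logdet_fns are hypothetical names, not bayespecon functions. The sketch assumes W is obtained by row-standardising a symmetric adjacency matrix (so its eigenvalues are real, as W_eigs requires) and that n is small enough for a dense eigendecomposition.

```python
import numpy as np
import scipy.sparse as sp


def row_standardise(A):
    """Return a row-standardised CSR copy of adjacency A (non-empty rows sum to 1)."""
    A = sp.csr_matrix(A, dtype=float)
    sums = np.asarray(A.sum(axis=1)).ravel()
    inv = np.divide(1.0, sums, out=np.zeros_like(sums), where=sums > 0)
    return sp.csr_matrix(sp.diags(inv) @ A)


def sdem_design(X, W_sparse):
    """Stack [X, WX], the SDEM design matrix."""
    return np.hstack([X, W_sparse @ X])


def make_logdet_fns(W_sparse):
    """Build scalar and vectorized log|I - lam W| callables from W's eigenvalues.

    Assumes real eigenvalues (row-standardised symmetric adjacency).
    Uses a dense eigendecomposition, so it is only suitable for small n.
    """
    eigs = np.linalg.eigvals(W_sparse.toarray()).real

    def logdet_fn(lam):
        # log|I - lam W| = sum_i log(1 - lam * w_i)
        return np.sum(np.log(1.0 - lam * eigs))

    def logdet_vec_fn(lams):
        # One log-determinant per entry of lams
        lams = np.asarray(lams, dtype=float)
        return np.log(1.0 - np.outer(lams, eigs)).sum(axis=1)

    return logdet_fn, logdet_vec_fn, eigs
```

The eigenvalue route makes each log-determinant an O(n) sum once the eigenvalues are known, which matters because the λ update evaluates it many times per sweep.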

__init__(y, X, W_sparse, priors, logdet_fn, logdet_vec_fn, feature_names, model_type='sem', W_eigs=None, logdet_method=None, T=1)

Methods

__init__(y, X, W_sparse, priors, logdet_fn, ...)

fit([draws, tune, chains, random_seed, ...])

Run Gibbs chains and assemble InferenceData.

fit(draws=2000, tune=1000, chains=4, random_seed=None, thin=1, n_jobs=-1, progressbar=True, gibbs_method='numpy', mala_step_size=0.05, use_mala=True, use_slice=True, slice_width=None, chain_method=None)

Run Gibbs chains and assemble InferenceData.

Parameters:
draws : int, default 2000

Number of post-warmup draws per chain.

tune : int, default 1000

Number of warmup (burn-in) draws per chain.

chains : int, default 4

Number of independent chains.

random_seed : int or None

Seed for reproducibility.

thin : int, default 1

Keep every thin-th draw after warmup.

n_jobs : int, default -1

Number of parallel workers for the NumPy path. -1 uses all CPUs. When n_jobs=1, chains run sequentially with progress bars. When n_jobs>1 (or -1), chains run in parallel via joblib. Ignored for the JAX path (use chain_method instead).

progressbar : bool, default True

Show per-chain progress bars.

gibbs_method : str, default "numpy"

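Execution backend: "numpy" for a Python-loop Gibbs sampler with adaptive slice sampling for ρ/λ, or "jax" for a fully JIT-compiled Gibbs sampler whose ρ/λ update uses slice sampling, MALA, or random-walk Metropolis–Hastings (see use_slice and use_mala). The JAX path requires JAX and equinox.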

mala_step_size : float, default 0.05

Initial MALA step size (or random-walk Metropolis–Hastings proposal standard deviation) for the JAX path. Ignored when gibbs_method="numpy".

use_mala : bool, default True

If True, use MALA (gradient-guided proposals) for the ρ/λ update in the JAX path. If False, use random-walk Metropolis–Hastings. Ignored when gibbs_method="numpy" or use_slice=True.

use_slice : bool, default True

If True, use slice sampling for the ρ/λ update in the JAX path; this takes priority over use_mala. Slice sampling typically yields a higher effective sample size (ESS) per draw than MALA. Ignored when gibbs_method="numpy".

slice_width : float or None, default None

Initial step-out width for slice sampling. If None, defaults to (rho_upper - rho_lower) * 0.1. Ignored when use_slice=False or gibbs_method="numpy".

chain_method : str or None, default None

How to run multiple chains for the JAX path. "vectorized" uses jax.vmap for JAX-native parallelism (all chains on one device). "sequential" runs chains one after another with progress bars. "parallel" is not supported for the JAX path. If None, defaults to "vectorized" when gibbs_method="jax". Ignored for the NumPy path (use n_jobs to control parallelism instead).
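When use_slice=False, the ρ/λ update is a Metropolis–Hastings step whose proposal depends on use_mala. The NumPy sketch below illustrates both variants for λ; it is not the library's JAX kernel. The function names, the flat prior on λ over (-0.99, 0.99), and reading step as the Langevin step size ε in the proposal λ' = λ + (ε²/2)∇log p(λ) + εz are assumptions. The gradient used is d/dλ log p(λ) = -Σᵢ ωᵢ/(1 - λωᵢ) + e'Wr/σ², with r = y - Xβ and e = (I - λW)r. With use_mala=False the drift term is dropped, leaving a symmetric random walk with standard deviation step, so no Hastings correction is needed.

```python
import numpy as np


def lam_log_cond(lam, r, Wr, eigs, sigma2, bounds=(-0.99, 0.99)):
    """Log full conditional of lam (flat prior on bounds) and its gradient.

    r = y - X beta and Wr = W @ r are precomputed for the current beta.
    """
    lo, hi = bounds
    if not (lo < lam < hi):
        return -np.inf, 0.0
    e = r - lam * Wr                      # (I - lam W)(y - X beta)
    denom = 1.0 - lam * eigs
    logp = np.sum(np.log(denom)) - 0.5 * (e @ e) / sigma2
    grad = -np.sum(eigs / denom) + (e @ Wr) / sigma2
    return logp, grad


def lam_mh_step(lam, r, Wr, eigs, sigma2, step, rng, use_mala=True):
    """One MALA (use_mala=True) or random-walk MH (use_mala=False) update of lam."""
    lp, g = lam_log_cond(lam, r, Wr, eigs, sigma2)
    drift = 0.5 * step**2 * g if use_mala else 0.0
    prop = lam + drift + step * rng.standard_normal()
    lp_new, g_new = lam_log_cond(prop, r, Wr, eigs, sigma2)
    if not np.isfinite(lp_new):           # proposal left the support: reject
        return lam, False
    log_alpha = lp_new - lp
    if use_mala:
        # Hastings correction: Langevin proposals are not symmetric
        fwd = prop - lam - 0.5 * step**2 * g
        bwd = lam - prop - 0.5 * step**2 * g_new
        log_alpha += (fwd**2 - bwd**2) / (2 * step**2)
    accept = np.log(rng.uniform()) < log_alpha
    return (prop if accept else lam), bool(accept)
```

The gradient drift lets MALA take larger informed moves than the random walk at the same step size, at the cost of one gradient evaluation per proposal.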

Returns:

An InferenceData object with posterior, log_likelihood, and observed_data groups.

Return type:

az.InferenceData
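On the NumPy path, chains run in parallel via joblib when n_jobs is not 1. One way to keep results reproducible for a given random_seed regardless of n_jobs is to spawn one child seed per chain, as sketched below. This reflects an assumed design, not a description of bayespecon's internals: run_chain is a hypothetical toy AR(1) chain standing in for a Gibbs chain, and run_chains is likewise hypothetical. The (chain, draw) stacking mirrors the layout ArviZ uses for the posterior group.

```python
import numpy as np
from joblib import Parallel, delayed


def run_chain(seed_seq, draws, tune):
    """Hypothetical stand-in for one Gibbs chain: a toy AR(1) Markov chain."""
    rng = np.random.default_rng(seed_seq)
    x = 0.0
    out = np.empty(draws)
    for i in range(tune + draws):
        x = 0.9 * x + rng.standard_normal()
        if i >= tune:                     # discard warmup iterations
            out[i - tune] = x
    return out


def run_chains(draws=1000, tune=500, chains=4, random_seed=None, n_jobs=-1):
    # One independent child seed per chain, derived from the master seed
    child_seeds = np.random.SeedSequence(random_seed).spawn(chains)
    results = Parallel(n_jobs=n_jobs)(
        delayed(run_chain)(s, draws, tune) for s in child_seeds
    )
    return np.stack(results)              # shape (chains, draws)
```

Because each chain owns its SeedSequence child, n_jobs=1 and n_jobs=2 produce identical draws for the same random_seed.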