bayespecon.GibbsEstimation

class bayespecon.GibbsEstimation(y, X, W_sparse, Wy, priors, logdet_fn, logdet_vec_fn, feature_names, model_type, W_eigs=None, logdet_method=None, T=1)

Base class for Gaussian spatial Gibbs sampler configuration and execution.

Encapsulates the data, priors, cached quantities, and chain-running logic for the three-block Gibbs sampler over (β, σ², ρ/λ). Subclasses supply the model-specific details (SAR vs. SEM, collapsed vs. uncollapsed ρ/λ updates).
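As a rough illustration of the three-block structure, the sketch below runs SAR sweeps (y = ρWy + Xβ + ε) in plain NumPy. It assumes conjugate priors β ~ N(b0, B0) and σ² ~ InvGamma(a0, d0) plus a uniform prior on ρ; the names sar_gibbs_sweep, b0, B0_inv, a0 and d0 are hypothetical and are not the fields of GaussianGibbsPriors or methods of this class. The ρ block uses a griddy-Gibbs step (sampling ρ from its conditional evaluated on a fixed grid), which is a stand-in for the slice and MALA updates that fit offers. For SEM, λ enters through the error term instead, so the β and σ² blocks would act on the filtered data (I - λW)y and (I - λW)X.

```python
import numpy as np


def sar_gibbs_sweep(y, X, Wy, beta, sigma2, rho, rng,
                    b0, B0_inv, a0, d0, rho_grid, logdet_grid):
    """One sweep of a hypothetical 3-block Gibbs sampler for y = rho*W y + X beta + eps."""
    n = y.shape[0]
    # Block 1: beta | sigma2, rho is Gaussian (conjugate normal update).
    y_star = y - rho * Wy
    V = np.linalg.inv(B0_inv + X.T @ X / sigma2)
    V = 0.5 * (V + V.T)  # symmetrise against round-off
    m = V @ (B0_inv @ b0 + X.T @ y_star / sigma2)
    beta = rng.multivariate_normal(m, V)

    # Block 2: sigma2 | beta, rho is inverse-gamma; scale / Gamma(shape, 1) ~ InvGamma(shape, scale).
    e = y_star - X @ beta
    sigma2 = (d0 + 0.5 * e @ e) / rng.gamma(a0 + n / 2.0)

    # Block 3: rho | beta, sigma2 evaluated on a grid (griddy Gibbs).
    # log|I - rho W| is precomputed for every grid point in logdet_grid.
    resid = y[None, :] - rho_grid[:, None] * Wy[None, :] - (X @ beta)[None, :]
    log_post = logdet_grid - 0.5 * np.sum(resid ** 2, axis=1) / sigma2
    p = np.exp(log_post - log_post.max())
    rho = rng.choice(rho_grid, p=p / p.sum())
    return beta, sigma2, rho


# Toy data: ring of n units, each with two neighbours weighted 0.5 (rows sum to 1).
n = 200
idx = np.arange(n)
W = np.zeros((n, n))
W[idx, (idx + 1) % n] = 0.5
W[idx, (idx - 1) % n] = 0.5
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(n), rng.normal(size=n)])
true_beta, true_rho = np.array([1.0, 2.0]), 0.5
y = np.linalg.solve(np.eye(n) - true_rho * W, X @ true_beta + rng.normal(size=n))
Wy = W @ y

rho_grid = np.linspace(-0.99, 0.99, 199)
eigs = np.linalg.eigvals(W).real  # real because this W is symmetric
logdet_grid = np.log(1.0 - np.outer(rho_grid, eigs)).sum(axis=1)

beta, sigma2, rho = np.zeros(2), 1.0, 0.0
draws = []
for it in range(600):
    beta, sigma2, rho = sar_gibbs_sweep(y, X, Wy, beta, sigma2, rho, rng,
                                        np.zeros(2), 1e-4 * np.eye(2), 0.01, 0.01,
                                        rho_grid, logdet_grid)
    if it >= 200:  # discard warm-up sweeps
        draws.append(np.r_[beta, sigma2, rho])
post = np.array(draws)  # columns: beta0, beta1, sigma2, rho
print(post.mean(axis=0))
```

Computing log|I - ρW| once for the whole grid keeps each sweep cheap: with Wy precomputed, the determinant is the only term in the ρ conditional that involves an n × n matrix.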

Parameters:
y : ndarray of shape (n,)

Response vector.

X : ndarray of shape (n, k)

Design matrix (for SDM/SDEM, this is [X, WX]).

W_sparse : csr_matrix of shape (n, n)

Row-standardised spatial weights matrix.

Wy : ndarray of shape (n,) or None

W @ y (precomputed, for SAR/SDM).

priors : GaussianGibbsPriors

Prior hyperparameters.

logdet_fn : callable

Callable returning log|I - ρW| for a scalar ρ, as a NumPy scalar.

logdet_vec_fn : callable

Vectorized counterpart of logdet_fn: returns log|I - ρW| for each value in an array of ρ values.

feature_names : list of str

Names for the columns of X (for InferenceData coords).

model_type : str

One of “sar”, “sem”, “sdm”, “sdem”.
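The inputs above can be prepared along the following lines. This is a minimal sketch with dense NumPy arrays (a scipy.sparse csr_matrix supports the same W @ y products); row_standardise and make_logdet_fns are hypothetical helpers, not part of bayespecon, and the exact call signature the class expects for logdet_fn and logdet_vec_fn is an assumption. The log-determinant uses the eigenvalue identity log|I - ρW| = Σ_i log(1 - ρλ_i), where λ_i are the eigenvalues of W. For the SDM/SDEM design, the sketch lags only the non-constant columns, because W @ 1 = 1 for a row-standardised W without isolated units, so a lagged intercept would duplicate the constant column; the page does not say whether bayespecon does the same.

```python
import numpy as np


def row_standardise(A):
    """Hypothetical helper: divide each row of a non-negative weights matrix by its sum."""
    A = np.asarray(A, dtype=float)
    row_sums = A.sum(axis=1, keepdims=True)
    # Rows of isolated units (no neighbours) stay zero instead of dividing by 0.
    return np.divide(A, row_sums, out=np.zeros_like(A), where=row_sums > 0)


def make_logdet_fns(W):
    """Hypothetical helper: scalar and vectorised log|I - rho W| from the eigenvalues of W."""
    eigs = np.linalg.eigvals(W)  # computed once, O(n^3); only practical for small n

    def logdet_fn(rho):
        # Re(log z) = log|z|, so summing real parts gives log|det(I - rho W)|.
        return float(np.real(np.sum(np.log(1.0 - rho * eigs))))

    def logdet_vec_fn(rhos):
        rhos = np.asarray(rhos, dtype=float)
        return np.real(np.log(1.0 - np.outer(rhos, eigs))).sum(axis=1)

    return logdet_fn, logdet_vec_fn


# Path graph of 6 units: unit i neighbours i-1 and i+1.
n = 6
A = np.zeros((n, n))
A[np.arange(n - 1), np.arange(1, n)] = 1.0
A[np.arange(1, n), np.arange(n - 1)] = 1.0
W = row_standardise(A)

rng = np.random.default_rng(1)
X = np.column_stack([np.ones(n), rng.normal(size=(n, 2))])
y = rng.normal(size=n)
Wy = W @ y  # computed once, reused for every rho
# SDM/SDEM design [X, WX]; the constant column is not lagged because W @ 1 = 1.
X_sdm = np.hstack([X, W @ X[:, 1:]])

logdet_fn, logdet_vec_fn = make_logdet_fns(W)
print(logdet_fn(0.4), np.linalg.slogdet(np.eye(n) - 0.4 * W)[1])
```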

__init__(y, X, W_sparse, Wy, priors, logdet_fn, logdet_vec_fn, feature_names, model_type, W_eigs=None, logdet_method=None, T=1)

Methods

__init__(y, X, W_sparse, Wy, priors, ...[, ...])

fit([draws, tune, chains, random_seed, ...])

Run Gibbs chains and assemble InferenceData.

fit(draws=2000, tune=1000, chains=4, random_seed=None, thin=1, n_jobs=-1, progressbar=True, gibbs_method='numpy', mala_step_size=0.05, use_mala=True, use_slice=True, slice_width=None, chain_method=None)

Run Gibbs chains and assemble InferenceData.
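A sketch of what running and assembling chains involves for the draws, tune, chains, random_seed and thin arguments: each chain gets its own random stream spawned from one seed, warm-up draws are discarded, the remainder is thinned, and the chains are stacked into the (chain, draw) layout that the posterior group of an ArviZ InferenceData uses. run_chain is a hypothetical stand-in (an AR(1) process) for one Gibbs chain, the chains run sequentially rather than through joblib, and the sketch assumes thinning is applied within the draws retained draws; the page does not say how fit counts draws when thin > 1.

```python
import numpy as np


def run_chain(seed_seq, n_iter):
    """Hypothetical stand-in for one Gibbs chain: an AR(1) process that starts
    far from its stationary mean (0), so early warm-up draws are off-target."""
    rng = np.random.default_rng(seed_seq)
    x = np.empty(n_iter)
    x[0] = 10.0
    for t in range(1, n_iter):
        x[t] = 0.9 * x[t - 1] + rng.normal()
    return x


def run_chains(draws=2000, tune=1000, chains=4, random_seed=None, thin=1):
    # One independent, reproducible child stream per chain.
    children = np.random.SeedSequence(random_seed).spawn(chains)
    kept = []
    for child in children:
        trace = run_chain(child, tune + draws)
        kept.append(trace[tune:][::thin])  # drop warm-up, then keep every thin-th draw
    return np.stack(kept)  # shape (chains, number of kept draws)


samples = run_chains(draws=2000, tune=1000, chains=4, random_seed=42, thin=2)
print(samples.shape)  # (4, 1000)
```

Calling run_chains twice with the same random_seed returns identical arrays, while the spawned child seeds keep the chains from sharing a random stream.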

Parameters:
draws : int, default 2000

Number of post-warmup draws per chain.

tune : int, default 1000

Number of warmup (burn-in) draws per chain.

chains : int, default 4

Number of independent chains.

random_seed : int or None

Seed for reproducibility.

thin : int, default 1

Keep every thin-th draw after warmup.

n_jobs : int, default -1

Number of parallel workers for the NumPy path. -1 uses all CPUs. When n_jobs=1, chains run sequentially with progress bars. When n_jobs>1 (or -1), chains run in parallel via joblib. Ignored for the JAX path (use chain_method instead).

progressbar : bool, default True

Show per-chain progress bars.

gibbs_method : str, default "numpy"

Execution backend: "numpy" for Python-loop Gibbs with adaptive slice sampling, or "jax" for full-JIT Gibbs with MALA for ρ/λ. The JAX path requires JAX and equinox.

mala_step_size : float, default 0.05

Initial MALA step size for the JAX path, or the proposal standard deviation when random-walk Metropolis–Hastings is used (use_mala=False). Ignored when gibbs_method="numpy".

use_mala : bool, default True

If True, use MALA (gradient-guided proposals) for the ρ/λ update in the JAX path. If False, use random-walk Metropolis–Hastings. Ignored when gibbs_method="numpy" or use_slice=True.

use_slice : bool, default True

If True, use slice sampling for the ρ/λ update in the JAX path. Slice sampling typically yields a higher effective sample size (ESS) per draw than MALA, and it takes priority over use_mala. Ignored when gibbs_method="numpy".

slice_width : float or None, default None

Initial step-out width for slice sampling. If None, defaults to (rho_upper - rho_lower) * 0.1. Ignored when use_slice=False or gibbs_method="numpy".

chain_method : str or None, default None

How to run multiple chains for the JAX path. "vectorized" uses jax.vmap for JAX-native parallelism (all chains on one device). "sequential" runs chains one after another with progress bars. "parallel" is not supported for the JAX path. If None, defaults to "vectorized" when gibbs_method="jax". Ignored for the NumPy path (use n_jobs to control parallelism instead).
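The ρ/λ update options can be illustrated with two generic one-dimensional samplers: a slice sampler using Neal's stepping-out and shrinkage procedure, with its initial width set by the slice_width default rule above, and a single MALA step that reduces to random-walk Metropolis–Hastings with proposal standard deviation eps when the gradient is zero. These are textbook NumPy sketches on a toy Gaussian target, not the package's JAX kernels; slice_step and mala_step are hypothetical names, and any step-size or width adaptation is omitted. For the real ρ conditional, the log-density would combine log|I - ρW|, the Gaussian residual term and the prior, and MALA would also need d/dρ log|I - ρW| = -Σ_i λ_i / (1 - ρλ_i).

```python
import numpy as np


def slice_step(x0, logp, rng, w, max_steps=50):
    """One univariate slice-sampling update (stepping out, then shrinkage)."""
    # Slice height on the log scale: log(U * p(x0)) = logp(x0) - Exp(1).
    log_y = logp(x0) - rng.exponential()
    # Randomly place an interval of width w around x0, then step out.
    left = x0 - w * rng.uniform()
    right = left + w
    j = int(max_steps * rng.uniform())
    k = max_steps - 1 - j
    while j > 0 and logp(left) > log_y:
        left -= w
        j -= 1
    while k > 0 and logp(right) > log_y:
        right += w
        k -= 1
    # Shrink the interval towards x0 until a point inside the slice is drawn.
    while True:
        x1 = rng.uniform(left, right)
        if logp(x1) > log_y:
            return x1
        if x1 < x0:
            left = x1
        else:
            right = x1


def mala_step(x0, logp, grad_logp, rng, eps):
    """One MALA update; if grad_logp always returns 0 this is random-walk MH."""
    mean_fwd = x0 + 0.5 * eps**2 * grad_logp(x0)
    x1 = mean_fwd + eps * rng.normal()
    lp1 = logp(x1)
    if not np.isfinite(lp1):  # proposal left the support: reject
        return x0, False
    mean_rev = x1 + 0.5 * eps**2 * grad_logp(x1)
    log_q_fwd = -0.5 * ((x1 - mean_fwd) / eps) ** 2
    log_q_rev = -0.5 * ((x0 - mean_rev) / eps) ** 2
    log_alpha = lp1 - logp(x0) + log_q_rev - log_q_fwd
    if -rng.exponential() < log_alpha:  # same as log(U) < log_alpha
        return x1, True
    return x0, False


# Toy target on the support (-1, 1): Gaussian with mean 0.3 and sd 0.1.
lower, upper = -1.0, 1.0


def logp(r):
    return -0.5 * ((r - 0.3) / 0.1) ** 2 if lower < r < upper else -np.inf


def grad_logp(r):
    return -(r - 0.3) / 0.1**2


rng = np.random.default_rng(0)
width = (upper - lower) * 0.1  # same rule as the slice_width default
x, slice_draws = 0.0, []
for _ in range(3000):
    x = slice_step(x, logp, rng, width)
    slice_draws.append(x)

x, mala_draws, n_accept = 0.0, [], 0
for _ in range(5000):
    x, accepted = mala_step(x, logp, grad_logp, rng, eps=0.05)
    mala_draws.append(x)
    n_accept += accepted

slice_draws, mala_draws = np.array(slice_draws), np.array(mala_draws)
print(slice_draws.mean(), slice_draws.std(), mala_draws.mean(), n_accept / 5000)
```

Neither sampler needs a tuned step to be correct: the slice sampler adapts its interval to the target within each update, while MALA's efficiency depends directly on eps, which is why the JAX path exposes mala_step_size.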

Returns:

An InferenceData object with posterior, log_likelihood, and observed_data groups.

Return type:

az.InferenceData