Models

BGM Family

BGM

class bayesgm.models.bgm.BGM(params, timestamp=None, random_seed=None)

Bayesian Generative Model (BGM) for tabular data.

BGM learns a latent-variable generative model \(Z \sim \mathcal{N}(0, I)\), \(X \mid Z \sim \mathcal{N}(\mu(Z), \Sigma(Z))\) using an iterative algorithm that alternates between updating the generative network parameters and the individual latent variables.
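The generative process can be illustrated with a small NumPy sketch. The linear mean map, the softplus variance map and the diagonal covariance are illustrative assumptions that stand in for the trained generator network. This is not the bayesgm implementation. The `use_x_sd` switch mirrors the documented `generate()` option.

```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, z_dim = 3, 2

# Hypothetical stand-in for the generator network: linear maps for the mean and
# (through a softplus) for a diagonal variance.
W_mu = rng.normal(size=(z_dim, x_dim))
W_var = rng.normal(size=(z_dim, x_dim))

def generator(z):
    """Return (mu(z), sigma^2(z)), each of shape (n, x_dim)."""
    mu = z @ W_mu
    sigma_sq = np.logaddexp(0.0, z @ W_var) + 1e-6   # softplus keeps variances positive
    return mu, sigma_sq

def sample_x(nb_samples, use_x_sd=True):
    z = rng.standard_normal((nb_samples, z_dim))     # Z ~ N(0, I)
    mu, sigma_sq = generator(z)
    if use_x_sd:
        # X | Z ~ N(mu(Z), diag(sigma^2(Z)))
        x = mu + np.sqrt(sigma_sq) * rng.standard_normal(mu.shape)
    else:
        x = mu                                       # generator mean only
    return x, sigma_sq

x, sigma_sq = sample_x(1000)
print(x.shape, sigma_sq.shape)  # (1000, 3) (1000, 3)
```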

Parameters:
  • params (dict) –

    Configuration dictionary. Required keys:

    • 'x_dim' (int): Dimension of the observed variable \(X\).

    • 'z_dim' (int): Dimension of the latent variable \(Z\).

    • 'dataset' (str): Dataset name (used for checkpoint paths).

    • 'output_dir' (str): Root directory for saving results and checkpoints.

    Optional keys (with defaults):

    • 'use_bnn' (bool): Whether to use a Bayesian neural network for the generator. Default False.

    • 'g_units' (list[int]): Hidden-layer sizes for the generator network. Default [64, 64, 64, 64, 64].

    • 'e_units' (list[int]): Hidden-layer sizes for the encoder network. Default [64, 64, 64, 64, 64].

    • 'dz_units' (list[int]): Hidden-layer sizes for the latent discriminator. Default [64, 32, 8].

    • 'dx_units' (list[int]): Hidden-layer sizes for the data discriminator. Default [64, 32, 8].

    • 'lr' (float): Learning rate for EGM initialization. Default 0.001.

    • 'lr_theta' (float): Learning rate for generator parameters. Default 0.005.

    • 'lr_z' (float): Learning rate for latent-variable updates. Default 0.005.

    • 'gamma' (float): Gradient-penalty coefficient. Default 0.

    • 'alpha' (float): Regularisation weight on the variance term. Default 0.0.

    • 'g_d_freq' (int): Discriminator-to-generator update ratio during EGM. Default 1.

    • 'save_model' (bool): Whether to save model checkpoints. Default True.

    • 'save_res' (bool): Whether to save results. Default True.

    • 'kl_weight' (float): KL-divergence weight when use_bnn is True. Default 0.00005.

  • timestamp (str or None, optional) – Timestamp string for the run. If None, the current local time is used.

  • random_seed (int or None, optional) – If provided, sets the global random seed for reproducibility.
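The following sketch shows a configuration dictionary with the four required keys and a few optional keys set to their documented defaults. The dataset name and output path are hypothetical placeholders. The constructor call is commented out so the snippet runs without bayesgm installed.

```python
# Required keys for BGM, plus a few optional keys at their documented defaults.
params = {
    "x_dim": 10,
    "z_dim": 2,
    "dataset": "my_tabular_data",   # hypothetical dataset name
    "output_dir": "./results",      # hypothetical output path
    "use_bnn": False,
    "g_units": [64, 64, 64, 64, 64],
    "lr_theta": 0.005,
    "lr_z": 0.005,
    "kl_weight": 0.00005,
}

# Fail early if a required key is missing.
REQUIRED_KEYS = {"x_dim", "z_dim", "dataset", "output_dir"}
missing = REQUIRED_KEYS - set(params)
assert not missing, f"missing required keys: {missing}"

# Usage, assuming bayesgm is installed:
# from bayesgm.models.bgm import BGM
# model = BGM(params, random_seed=123)
```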

Training

fit(data, batch_size=32, epochs=100, epochs_per_eval=5, use_egm_init=True, egm_n_iter=20000, egm_batches_per_eval=500, verbose=1)

Train the BGM model on observed data.

The training procedure consists of two phases:

  1. EGM initialization (optional): warm-start by jointly training the encoder and generator with adversarial losses, which gives good starting values for the latent variables and model parameters. Skip this phase by setting use_egm_init=False.

  2. Iterative optimization: alternate between updating the generator network parameters \(\theta\) and the per-sample latent variables \(Z\) via SGD.
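The iterative phase can be sketched on a toy problem. This NumPy example makes several assumptions: a linear generator, unit observation variance, and a hypothetical L2 penalty on \(\theta\) that prevents the scale of W and Z from drifting. It alternates one gradient step on \(\theta\) with one gradient step on the latents of the same mini-batch. It illustrates the alternating scheme only and is not the BGM training loop.

```python
import numpy as np

rng = np.random.default_rng(0)
n, x_dim, z_dim = 500, 5, 2

# Toy data from a linear-Gaussian model (an illustrative assumption).
W_true = rng.normal(size=(z_dim, x_dim))
X = rng.standard_normal((n, z_dim)) @ W_true + 0.1 * rng.standard_normal((n, x_dim))

# theta: a linear "generator" W; Z: one latent vector per sample.
W = rng.normal(size=(z_dim, x_dim))
Z = rng.standard_normal((n, z_dim))
lam = 0.1  # hypothetical L2 penalty on theta

def objective(W, Z):
    """Mean negative log joint: unit-variance likelihood + N(0, I) prior on Z + L2 on W."""
    resid = X - Z @ W
    per_sample = 0.5 * (np.sum(resid**2, axis=1) + np.sum(Z**2, axis=1))
    return per_sample.mean() + 0.5 * lam * np.sum(W**2)

lr_theta, lr_z, batch_size, epochs = 0.05, 0.02, 32, 50
start = objective(W, Z)
for epoch in range(epochs):
    for idx in np.array_split(rng.permutation(n), n // batch_size):
        zb, xb = Z[idx], X[idx]
        # Step 1: gradient step on theta with this batch's latents held fixed.
        resid = xb - zb @ W
        W = W - lr_theta * (-(zb.T @ resid) / len(idx) + lam * W)
        # Step 2: gradient step on each sample's latent with theta held fixed.
        resid = xb - zb @ W
        Z[idx] = zb - lr_z * (-(resid @ W.T) + zb)
end = objective(W, Z)
print(f"objective: {start:.3f} -> {end:.3f}")
```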

Parameters:
  • data (np.ndarray) – Observed data matrix with shape (n, x_dim).

  • batch_size (int, default=32) – Mini-batch size.

  • epochs (int, default=100) – Number of training epochs for the iterative phase.

  • epochs_per_eval (int, default=5) – Evaluate and (optionally) save every this many epochs.

  • use_egm_init (bool, default=True) – Whether to run EGM initialization before iterative training.

  • egm_n_iter (int, default=20000) – Number of EGM initialization iterations.

  • egm_batches_per_eval (int, default=500) – Evaluate EGM every this many iterations.

  • verbose (int, default=1) – Verbosity level. Set to 0 to suppress progress messages.

Inference

predict(data, alpha=0.05, return_samples=False, bs=100, n_mcmc=5000, burn_in=5000, step_size=0.01, num_leapfrog_steps=10, seed=42)

Predict the posterior distribution of the missing entries in each row given its observed entries, using Hamiltonian Monte Carlo (HMC) sampling.
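The sketch below shows how the HMC parameters (n_mcmc, burn_in, step_size, num_leapfrog_steps, seed) fit together. It uses a toy linear-Gaussian generator where the latent posterior is known in closed form. The model, the fixed weights and the smaller sampler settings are illustrative assumptions. This is not the bayesgm sampler. It samples \(Z\) given the observed entries and then draws the missing entry from the generator.

```python
import numpy as np

# Toy generator (an assumption for illustration): X | Z ~ N(Z @ W, sigma^2 I).
W = np.array([[1.0, 0.5, -0.5, 1.0],
              [0.0, 1.0, 1.0, -1.0]])
sigma = 0.5
x = np.array([1.0, 0.5, 0.0, np.nan])   # last entry missing, encoded as NaN
obs = ~np.isnan(x)
W_obs, W_mis, x_obs = W[:, obs], W[:, ~obs], x[obs]

def potential(z):
    """Negative log posterior of z given the observed entries (up to a constant)."""
    resid = x_obs - z @ W_obs
    return 0.5 * resid @ resid / sigma**2 + 0.5 * z @ z

def grad_potential(z):
    resid = x_obs - z @ W_obs
    return -(W_obs @ resid) / sigma**2 + z

def hmc(n_mcmc=2000, burn_in=500, step_size=0.05, num_leapfrog_steps=10, seed=42):
    rng = np.random.default_rng(seed)
    z = np.zeros(W.shape[0])
    samples = []
    for it in range(burn_in + n_mcmc):
        p = rng.standard_normal(z.shape)
        # Leapfrog integration: half momentum step, alternating full steps, half step.
        z_new, p_new = z.copy(), p - 0.5 * step_size * grad_potential(z)
        for step in range(num_leapfrog_steps):
            z_new = z_new + step_size * p_new
            if step < num_leapfrog_steps - 1:
                p_new = p_new - step_size * grad_potential(z_new)
        p_new = p_new - 0.5 * step_size * grad_potential(z_new)
        # Metropolis correction on the total energy.
        h_old = potential(z) + 0.5 * p @ p
        h_new = potential(z_new) + 0.5 * p_new @ p_new
        if np.log(rng.uniform()) < h_old - h_new:
            z = z_new
        if it >= burn_in:                 # discard burn-in draws
            samples.append(z.copy())
    return np.array(samples)

z_samples = hmc()
# Posterior predictive draws for the missing entry, then a mean imputation.
pred_rng = np.random.default_rng(1)
x_mis_samples = z_samples @ W_mis + sigma * pred_rng.standard_normal((len(z_samples), W_mis.shape[1]))
imputed = x.copy()
imputed[~obs] = x_mis_samples.mean(axis=0)
```

For this toy model the exact posterior mean of \(Z\) is (5/7, 2/9), which gives a convenient check on the sampler.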

Parameters:
  • data (np.ndarray or tf.Tensor) – Input data with shape (n, x_dim). Missing values should be encoded as np.nan.

  • alpha (float, default=0.05) – Significance level for prediction intervals.

  • return_samples (bool, default=False) – If False, return imputed data with shape (n, x_dim). If True, return posterior samples with shape (n_mcmc, n, x_dim).

  • bs (int, default=100) – Batch size for posterior prediction.

  • n_mcmc (int, default=5000) – Number of retained MCMC samples.

  • burn_in (int, default=5000) – Number of burn-in iterations.

  • step_size (float, default=0.01) – HMC step size.

  • num_leapfrog_steps (int, default=10) – Number of leapfrog steps in HMC.

  • seed (int, default=42) – Random seed.

Returns:

  • data_x_pred (np.ndarray) – Imputed data if return_samples=False with shape (n, x_dim). Posterior predictive samples if return_samples=True with shape (n_mcmc, n, x_dim).

  • pred_interval (np.ndarray or list[np.ndarray]) – Prediction intervals on missing dimensions. For a shared missing pattern, shape is (n, n_missing_dims, 2). Otherwise, this is a per-sample list where element i has shape (n_missing_dims_i, 2).
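The documented shapes of pred_interval can be reproduced from an array of posterior predictive draws. The sketch assumes equal-tailed quantile intervals at level alpha and posterior-mean imputation. Both are illustrative choices, and the library may compute them differently. `summarize_posterior` is a hypothetical helper.

```python
import numpy as np

def summarize_posterior(data, samples, alpha=0.05):
    """Turn posterior predictive draws into imputations and intervals.

    data:    (n, x_dim) array with NaN marking missing entries.
    samples: (n_mcmc, n, x_dim) posterior predictive draws.
    """
    mask = np.isnan(data)
    imputed = np.where(mask, samples.mean(axis=0), data)
    lo = np.quantile(samples, alpha / 2, axis=0)        # (n, x_dim)
    hi = np.quantile(samples, 1 - alpha / 2, axis=0)
    bounds = np.stack([lo, hi], axis=-1)                 # (n, x_dim, 2)
    if (mask == mask[0]).all():
        # Shared missing pattern: one array of shape (n, n_missing_dims, 2).
        intervals = bounds[:, mask[0], :]
    else:
        # Row-specific patterns: list whose i-th element has shape (n_missing_dims_i, 2).
        intervals = [bounds[i, mask[i], :] for i in range(len(data))]
    return imputed, intervals

rng = np.random.default_rng(0)
data = rng.normal(size=(4, 3))
data[:, 2] = np.nan                      # same missing dimension in every row
samples = rng.normal(size=(500, 4, 3))   # stand-in for posterior predictive draws
imputed, intervals = summarize_posterior(data, samples)
print(intervals.shape)                   # (4, 1, 2)
```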

generate(nb_samples=1000, use_x_sd=True)

Generate synthetic data from the trained model.

Samples latent codes from the standard normal prior and decodes them through the generator network.

Parameters:
  • nb_samples (int, default=1000) – Number of samples to generate.

  • use_x_sd (bool, default=True) – If True, add noise sampled from the learned variance. If False, return the generator mean directly.

Returns:

  • data_x_gen (tf.Tensor) – Generated data with shape (nb_samples, x_dim).

  • sigma_square_x (tf.Tensor) – Predicted variance with shape (nb_samples, x_dim).

evaluate(data, data_z=None, use_x_sd=True)

Compute the mean squared error between observed and reconstructed data.

Parameters:
  • data (np.ndarray or tf.Tensor) – Observed data with shape (n, x_dim).

  • data_z (tf.Tensor or None, optional) – Latent variables with shape (n, z_dim). If None, the encoder network is used to infer them.

  • use_x_sd (bool, default=True) – If True, sample reconstructions from \(\mathcal{N}(\mu, \sigma^2)\). If False, use the mean \(\mu\) directly.

Returns:

mse_x – Scalar mean squared error.

Return type:

tf.Tensor
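The effect of use_x_sd on the reported error is illustrated below. Here `mu` and `sigma_sq` stand in for the generator outputs at the inferred latents, and `reconstruction_mse` is a hypothetical helper. When reconstructions are sampled, the MSE grows by roughly the average predicted variance, even if the mean fits perfectly.

```python
import numpy as np

def reconstruction_mse(x, mu, sigma_sq, use_x_sd=True, seed=0):
    """MSE between x and a reconstruction drawn from N(mu, sigma_sq), or mu itself."""
    if use_x_sd:
        rng = np.random.default_rng(seed)
        x_rec = mu + np.sqrt(sigma_sq) * rng.standard_normal(mu.shape)
    else:
        x_rec = mu
    return float(np.mean((x - x_rec) ** 2))

# A perfect mean reconstruction with unit predicted variance.
x = np.zeros((1000, 5))
mu, sigma_sq = np.zeros_like(x), np.ones_like(x)
print(reconstruction_mse(x, mu, sigma_sq, use_x_sd=False))  # 0.0
print(reconstruction_mse(x, mu, sigma_sq, use_x_sd=True))   # close to 1.0
```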

Configuration

get_config()

Return the configuration of the BGM model.

Returns:

A dictionary with key "params" containing the full configuration dictionary passed at construction time.

Return type:

dict

MNISTBGM

class bayesgm.models.bgm.MNISTBGM(params, timestamp=None, random_seed=None)

BGM model for MNIST imaging data.

Inherits from BGM and overrides methods to use convolutional neural networks and a Bernoulli likelihood for binary image data of shape (28, 28, 1).
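A Bernoulli likelihood scores each pixel value in [0, 1] against a predicted pixel probability. The NumPy sketch below assumes the generator outputs probabilities, for example after a sigmoid, and sums the per-pixel log-likelihood over each image. The library's exact formulation may differ, for instance by working with logits.

```python
import numpy as np

def bernoulli_log_likelihood(x, p, eps=1e-7):
    """Sum of log Bernoulli(x | p) over the pixels of each image.

    x: (n, 28, 28, 1) array with values in [0, 1].
    p: (n, 28, 28, 1) pixel probabilities from a generator.
    """
    p = np.clip(p, eps, 1 - eps)                  # avoid log(0)
    ll = x * np.log(p) + (1 - x) * np.log(1 - p)
    return ll.reshape(len(x), -1).sum(axis=1)     # one value per image

x = np.ones((2, 28, 28, 1))
p = np.full_like(x, 0.5)
print(bernoulli_log_likelihood(x, p))             # 784 * log(0.5) for each image
```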

Parameters:
  • params (dict) – Configuration dictionary. Same keys as BGM, with x_dim corresponding to the flattened image dimensionality (784 for MNIST).

  • timestamp (str or None, optional) – Timestamp string for the run. If None, the current local time is used.

  • random_seed (int or None, optional) – If provided, sets the global random seed for reproducibility.

Training

fit(data, batch_size=32, epochs=100, epochs_per_eval=5, use_egm_init=True, egm_n_iter=10000, egm_batches_per_eval=500, verbose=1)

Train the MNIST BGM model on image data.

Parameters:
  • data (np.ndarray) – MNIST image array with shape (n, 28, 28, 1), values in [0, 1].

  • batch_size (int, default=32) – Mini-batch size.

  • epochs (int, default=100) – Number of training epochs for the iterative phase.

  • epochs_per_eval (int, default=5) – Evaluate and (optionally) save every this many epochs.

  • use_egm_init (bool, default=True) – Whether to run EGM initialization before iterative training.

  • egm_n_iter (int, default=10000) – Number of EGM initialization iterations.

  • egm_batches_per_eval (int, default=500) – Evaluate EGM every this many iterations.

  • verbose (int, default=1) – Verbosity level. Set to 0 to suppress progress messages.

Inference

predict(data, alpha=0.05, return_samples=False, bs=100, n_mcmc=5000, burn_in=5000, step_size=0.01, num_leapfrog_steps=10, seed=42)

Predict the posterior distribution \(P(x_2 \mid x_1)\) of the missing pixels \(x_2\) given the observed pixels \(x_1\) for MNIST images.

Parameters:
  • data (np.ndarray or tf.Tensor) – Observed data with shape (n, 28, 28, 1). Missing pixels should be encoded as np.nan.

  • alpha (float, default=0.05) – Significance level for prediction intervals.

  • return_samples (bool, default=False) – If False, return imputed images with shape (n, 28, 28, 1). If True, return posterior samples with shape (n_mcmc, n, 28, 28, 1).

  • bs (int, default=100) – Batch size for posterior prediction.

  • n_mcmc (int, default=5000) – Number of retained MCMC samples.

  • burn_in (int, default=5000) – Number of burn-in iterations.

  • step_size (float, default=0.01) – HMC step size.

  • num_leapfrog_steps (int, default=10) – Number of leapfrog steps in HMC.

  • seed (int, default=42) – Random seed.

Returns:

  • data_x_pred (np.ndarray) – Imputed images if return_samples=False with shape (n, 28, 28, 1). Posterior predictive samples if return_samples=True with shape (n_mcmc, n, 28, 28, 1).

  • pred_interval (np.ndarray or list[np.ndarray]) – Prediction intervals on missing pixels. For a shared missing pattern, shape is (n, n_missing_pixels, 2). Otherwise, this is a per-sample list where element i has shape (n_missing_pixels_i, 2).
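To set up an image-completion query, encode the hidden pixels as np.nan. The sketch uses random arrays as a stand-in for MNIST images and hides the bottom half of every image. Every row then shares the same missing pattern, so pred_interval would come back as a single array of shape (n, 392, 2). The `model.predict` line assumes a trained MNISTBGM instance and is commented out.

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.uniform(size=(5, 28, 28, 1))   # stand-in for MNIST images in [0, 1]

# Hide the bottom half: x1 = top half (observed), x2 = bottom half (missing).
masked = images.copy()
masked[:, 14:, :, :] = np.nan

mask = np.isnan(masked)
shared_pattern = bool((mask == mask[0]).all())
n_missing_pixels = int(mask[0].sum())
print(shared_pattern, n_missing_pixels)     # True 392

# Assuming a trained MNISTBGM instance named `model`:
# imputed, pred_interval = model.predict(masked)   # pred_interval: (5, 392, 2)
```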

generate(nb_samples=1000)

Generate synthetic MNIST images from the trained model.

Samples latent codes from the standard normal prior and decodes them through the convolutional generator.

Parameters:

nb_samples (int, default=1000) – Number of images to generate.

Returns:

data_x_pred – Generated images with shape (nb_samples, 28, 28, 1), pixel values in [0, 1].

Return type:

tf.Tensor

evaluate(data, data_z=None)

Compute the mean squared error between observed and reconstructed MNIST images.

Parameters:
  • data (np.ndarray or tf.Tensor) – Observed images with shape (n, 28, 28, 1).

  • data_z (tf.Tensor or None, optional) – Latent variables with shape (n, z_dim). If None, the encoder is used to infer them.

Returns:

mse_x – Scalar mean squared error.

Return type:

tf.Tensor

Configuration

get_config()

Return the configuration of the BGM model.

Returns:

A dictionary with key "params" containing the full configuration dictionary passed at construction time.

Return type:

dict

CausalBGM Family

CausalBGM

class bayesgm.models.causalbgm.CausalBGM(params, timestamp=None, random_seed=None)

Causal Bayesian Generative Model (CausalBGM) for causal inference.

CausalBGM learns a latent-variable generative model for causal inference with treatment \(X\), outcome \(Y\), and high-dimensional covariates \(V\). The latent variable \(Z\) is partitioned into \((Z_0, Z_1, Z_2, Z_3)\) to disentangle confounding, outcome-specific, treatment-specific, and residual variation.
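The partition of \(Z\) can be pictured as slicing one concatenated latent vector at the boundaries given by z_dims. The routing below is an assumption based on the roles described above. The outcome side sees \((Z_0, Z_1)\), the treatment side sees \((Z_0, Z_2)\), and the covariates see all four blocks. The outcome network would also receive the treatment \(X\), which is omitted here. The z_dims values are hypothetical.

```python
import numpy as np

z_dims = [3, 3, 6, 6]   # hypothetical [z0, z1, z2, z3]
n = 4
z = np.random.default_rng(0).standard_normal((n, sum(z_dims)))

# Split the concatenated latent vector at the cumulative block boundaries.
z0, z1, z2, z3 = np.split(z, np.cumsum(z_dims)[:-1], axis=1)

# Assumed routing, following the roles described above.
outcome_inputs = np.concatenate([z0, z1], axis=1)     # confounding + outcome-specific
treatment_inputs = np.concatenate([z0, z2], axis=1)   # confounding + treatment-specific
covariate_inputs = z                                  # all four blocks
```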

Parameters:
  • params (dict) –

    Configuration dictionary. Required keys:

    • 'v_dim' (int): Dimension of covariates \(V\).

    • 'z_dims' (list[int]): Dimensions [z0, z1, z2, z3] of the four latent sub-vectors.

    • 'binary_treatment' (bool): True for binary treatment, False for continuous.

    • 'dataset' (str): Dataset name (used for checkpoint paths).

    • 'output_dir' (str): Root directory for outputs.

    Optional keys (with defaults):

    • 'use_bnn' (bool): Whether to use Bayesian neural networks. Default True.

    • 'g_units' (list[int]): Hidden-layer sizes for the generator network. Default [64, 64, 64, 64, 64].

    • 'e_units' (list[int]): Hidden-layer sizes for the encoder network. Default [64, 64, 64, 64, 64].

    • 'f_units' (list[int]): Hidden-layer sizes for the outcome network. Default [64, 32, 8].

    • 'h_units' (list[int]): Hidden-layer sizes for the treatment network. Default [64, 32, 8].

    • 'dz_units' (list[int]): Hidden-layer sizes for the latent discriminator. Default [64, 32, 8].

    • 'lr' (float): Learning rate for EGM pre-training. Default 0.0002.

    • 'lr_theta' (float): Learning rate for network parameters. Default 0.0001.

    • 'lr_z' (float): Learning rate for latent-variable updates. Default 0.0001.

    • 'g_d_freq' (int): Discriminator-to-generator update ratio. Default 5.

    • 'save_model' (bool): Whether to save model checkpoints. Default False.

    • 'save_res' (bool): Whether to save results. Default True.

    • 'kl_weight' (float): KL-divergence weight when use_bnn is True. Default 0.0001.

  • timestamp (str or None, optional) – Timestamp string for the run. If None, the current local time is used.

  • random_seed (int or None, optional) – If provided, sets the global random seed for reproducibility.
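A CausalBGM configuration needs the five required keys, and z_dims must contain one entry per latent block. In the sketch below, the dataset name, output path and block sizes are hypothetical. The optional keys are shown at their documented defaults. The constructor call is commented out so the snippet runs without bayesgm installed.

```python
params = {
    # Required keys
    "v_dim": 20,
    "z_dims": [3, 3, 6, 6],          # hypothetical sizes for (Z0, Z1, Z2, Z3)
    "binary_treatment": True,
    "dataset": "my_causal_data",     # hypothetical name
    "output_dir": "./results",       # hypothetical path
    # Optional keys, shown with their documented defaults
    "use_bnn": True,
    "lr_theta": 0.0001,
    "lr_z": 0.0001,
    "kl_weight": 0.0001,
}

REQUIRED = {"v_dim", "z_dims", "binary_treatment", "dataset", "output_dir"}
assert REQUIRED <= set(params), "missing required keys"
assert len(params["z_dims"]) == 4    # one entry per latent block
total_z = sum(params["z_dims"])

# Usage, assuming bayesgm is installed:
# from bayesgm.models.causalbgm import CausalBGM
# model = CausalBGM(params, random_seed=123)
```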

Training

fit(data, epochs=100, epochs_per_eval=5, batch_size=32, startoff=0, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, save_format='txt', verbose=1)

Train CausalBGM with an optional EGM warm-start.

Parameters:
  • data (tuple of np.ndarray) – Training data (data_x, data_y, data_v).

  • epochs (int, default=100) – Number of training epochs.

  • epochs_per_eval (int, default=5) – Evaluate the full training set every this many epochs.

  • batch_size (int, default=32) – Mini-batch size used for both EGM initialization and iterative updates.

  • startoff (int, default=0) – Start tracking the best model only after this epoch.

  • use_egm_init (bool, default=True) – If True, run EGM initialization before iterative training.

  • egm_n_iter (int, default=30000) – Number of EGM mini-batch iterations when use_egm_init=True.

  • egm_batches_per_eval (int, default=500) – Logging interval for EGM initialization.

  • save_format (str, default='txt') – File format used when saving causal estimates.

  • verbose (int, default=1) – Verbosity level. Set to 0 to suppress progress logging.

Notes

After the optional EGM warm-start, latent variables are initialized from the encoder output \(e(V)\). If EGM is skipped, they are initialized by sampling from a standard normal distribution.
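The sketch below builds a synthetic (data_x, data_y, data_v) triplet with a confounded binary treatment, together with the standard normal latent initialization used when EGM is skipped. The simulation design and the (n, 1) shapes for data_x and data_y are assumptions for illustration. Check the shapes your bayesgm version expects before use.

```python
import numpy as np

rng = np.random.default_rng(0)
n, v_dim = 1000, 20
data_v = rng.standard_normal((n, v_dim))

# Binary treatment whose probability depends on the first covariate (confounding).
propensity = 1 / (1 + np.exp(-data_v[:, 0]))
data_x = rng.binomial(1, propensity).astype("float32").reshape(-1, 1)

# Outcome: constant treatment effect of 2.0 plus confounding through v0 and noise.
data_y = (2.0 * data_x[:, 0] + data_v[:, 0] + 0.5 * rng.standard_normal(n)).reshape(-1, 1)
data = (data_x, data_y, data_v)

# Without EGM, latents start from N(0, I), one row per sample.
z_dims = [3, 3, 6, 6]   # hypothetical block sizes
z_init = rng.standard_normal((n, sum(z_dims)))
```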

Inference

predict(data, alpha=0.01, n_mcmc=3000, burn_in=5000, x_values=None, q_sd=1.0, sample_y=True, bs=10000)

Estimate causal effects with posterior intervals from latent MCMC samples.
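The roles of n_mcmc, burn_in and q_sd match those of a random-walk Metropolis-Hastings sampler with Gaussian proposals of standard deviation q_sd. The NumPy sketch below runs such a sampler on a toy 2-D Gaussian target. In CausalBGM the target would instead be the latent posterior for each subject given its (x, y, v). This is a generic illustration and not the library's sampler.

```python
import numpy as np

def metropolis_hastings(log_target, z0, n_mcmc=3000, burn_in=5000, q_sd=1.0, seed=0):
    """Random-walk Metropolis-Hastings with Gaussian proposals N(z, q_sd^2 I)."""
    rng = np.random.default_rng(seed)
    z = np.asarray(z0, dtype=float)
    logp = log_target(z)
    samples, n_accept = [], 0
    for it in range(burn_in + n_mcmc):
        proposal = z + q_sd * rng.standard_normal(z.shape)
        logp_prop = log_target(proposal)
        # Symmetric proposal, so the acceptance ratio reduces to the target ratio.
        if np.log(rng.uniform()) < logp_prop - logp:
            z, logp = proposal, logp_prop
            n_accept += 1
        if it >= burn_in:          # keep only post-burn-in draws
            samples.append(z.copy())
    return np.array(samples), n_accept / (burn_in + n_mcmc)

# Toy target: an isotropic Gaussian "posterior" with mean (1, -1) and sd 0.5.
mean = np.array([1.0, -1.0])
log_target = lambda z: -0.5 * np.sum(((z - mean) / 0.5) ** 2)
samples, acc_rate = metropolis_hastings(log_target, z0=np.zeros(2))
```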

Parameters:
  • data (tuple of np.ndarray) – Test data (data_x, data_y, data_v).

  • alpha (float, default=0.01) – Significance level used for posterior intervals.

  • n_mcmc (int, default=3000) – Number of retained MCMC samples.

  • burn_in (int, default=5000) – Number of burn-in iterations for the Metropolis-Hastings sampler.

  • x_values (float or array-like, optional) – Treatment values used to evaluate the dose-response curve for continuous-treatment settings.

  • q_sd (float, default=1.0) – Proposal standard deviation for the Metropolis-Hastings sampler.

  • sample_y (bool, default=True) – If True, sample from the outcome model using the variance head. If False, use the posterior mean of the outcome model.

  • bs (int, default=10000) – Number of test subjects processed per prediction batch.

Returns:

  • effect (np.ndarray) – Binary treatment: individual treatment effect (ITE) estimates with shape (n,). Continuous treatment: average dose-response function (ADRF) estimates with shape (len(x_values),).

  • pos_int (np.ndarray) – Posterior intervals with shape (n, 2) for binary treatment or (len(x_values), 2) for continuous treatment.
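The sketch below shows how point estimates and (n, 2) or (len(x_values), 2) intervals can be obtained from effect draws. It assumes posterior-mean point estimates and equal-tailed \(1 - \alpha\) intervals, which the library may compute differently. For the ADRF, it assumes outcome predictions are averaged over subjects within each MCMC draw before the interval is taken. All draws are synthetic stand-ins.

```python
import numpy as np

def summarize_effects(effect_samples, alpha=0.01):
    """Posterior mean and equal-tailed (1 - alpha) interval along axis 0.

    effect_samples: (n_mcmc, m) draws, where m = n subjects (ITE)
    or m = len(x_values) treatment levels (ADRF).
    """
    effect = effect_samples.mean(axis=0)                          # (m,)
    lower = np.quantile(effect_samples, alpha / 2, axis=0)
    upper = np.quantile(effect_samples, 1 - alpha / 2, axis=0)
    return effect, np.stack([lower, upper], axis=1)               # (m, 2)

rng = np.random.default_rng(0)
# Hypothetical ITE draws, y(1) - y(0) per subject: (n_mcmc, n).
ite_samples = 2.0 + rng.standard_normal((3000, 50))
effect, pos_int = summarize_effects(ite_samples)

# Hypothetical outcome draws at each treatment level: (n_mcmc, len(x_values), n).
x_values = np.linspace(0, 1, 5)
y_draws = x_values[None, :, None] ** 2 + 0.1 * rng.standard_normal((3000, 5, 50))
adrf, adrf_int = summarize_effects(y_draws.mean(axis=2))   # average over subjects first
```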

evaluate(data, data_z=None, nb_intervals=200)

Configuration

get_config()

Return the configuration of the CausalBGM model.

Returns:

A dictionary with key "params" containing the full configuration dictionary passed at construction time.

Return type:

dict

IdentifiableCausalBGM

class bayesgm.models.causalbgm.IdentifiableCausalBGM(params, timestamp=None, random_seed=None)

Identifiable CausalBGM based on nonlinear ICA theory, following the identifiable VAE (iVAE).

Achieves identifiability under mild conditions by introducing an auxiliary variable \(U\) and conditioning the latent prior on it: \(Z \mid U \sim \mathcal{N}(\mu(U), \sigma^2(U) I)\).
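The conditional prior \(Z \mid U \sim \mathcal{N}(\mu(U), \sigma^2(U) I)\) can be evaluated in a few lines. The sketch below makes several assumptions. It treats \(U\) as a one-hot segment label (matching n_segments). It stands in for the prior network with a per-segment lookup of a mean vector and a single log-variance. It draws the segment labels at random, whereas the library generates \(U\) internally in a way this page does not specify.

```python
import numpy as np

rng = np.random.default_rng(0)
n_segments, z_dim, n = 10, 4, 6

# Hypothetical prior "network": per-segment mean vector and scalar log-variance,
# i.e. linear layers applied to a one-hot encoding of U.
prior_mu = rng.normal(size=(n_segments, z_dim))
prior_logvar = rng.normal(scale=0.1, size=(n_segments, 1))

def log_conditional_prior(z, u_onehot):
    """log N(z; mu(U), sigma^2(U) I), summed over latent dimensions."""
    mu = u_onehot @ prior_mu            # (n, z_dim)
    logvar = u_onehot @ prior_logvar    # (n, 1), broadcast across dimensions
    return -0.5 * np.sum(np.log(2 * np.pi) + logvar + (z - mu) ** 2 / np.exp(logvar), axis=1)

u = rng.integers(0, n_segments, size=n)      # hypothetical segment labels
u_onehot = np.eye(n_segments)[u]
z = rng.standard_normal((n, z_dim))
logp = log_conditional_prior(z, u_onehot)    # one log-density per sample
```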

Inherits from CausalBGM.

Parameters:
  • params (dict) –

    Same keys as CausalBGM, plus optionally:

    • 'n_segments' (int): Number of auxiliary-variable segments. Default 10.

    • 'prior_units' (list[int]): Hidden-layer sizes for the prior network. Default [64].

  • timestamp (str or None, optional) – Timestamp string for the run.

  • random_seed (int or None, optional) – If provided, sets the global random seed for reproducibility.

Training

fit(data, batch_size=32, epochs=100, epochs_per_eval=5, startoff=0, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, verbose=1, save_format='txt')

Train the IdentifiableCausalBGM model on observed data.

The training procedure consists of two phases:

  1. EGM initialization (optional): warm-start by jointly training the encoder and generator with adversarial losses, which gives good starting values for the latent variables and model parameters. Skip this phase by setting use_egm_init=False.

  2. Iterative optimization: generate an auxiliary variable \(U\) internally and jointly optimize the latent variables, the network parameters, and the conditional prior network.

Parameters:
  • data (tuple of np.ndarray) – A triplet (data_x, data_y, data_v).

  • batch_size (int, default=32) – Mini-batch size.

  • epochs (int, default=100) – Number of training epochs.

  • epochs_per_eval (int, default=5) – Evaluate every this many epochs.

  • startoff (int, default=0) – Only start tracking the best model after this many epochs.

  • use_egm_init (bool, default=True) – Whether to run EGM initialization before iterative training.

  • egm_n_iter (int, default=30000) – Number of EGM initialization iterations.

  • egm_batches_per_eval (int, default=500) – Evaluate EGM every this many iterations.

  • verbose (int, default=1) – Verbosity level.

  • save_format (str, default='txt') – File format for saving causal estimates.

Inference

predict(data, alpha=0.01, n_mcmc=3000, x_values=None, q_sd=1.0, sample_y=True, bs=100)

Predict causal effects with posterior uncertainty via MCMC.

Similar interface to CausalBGM.predict(), except that there is no burn_in argument and bs defaults to 100. Internally generates a fresh auxiliary variable \(U\) for the conditional prior during MCMC sampling.

Parameters:
  • data (tuple of np.ndarray) – A triplet (data_x, data_y, data_v).

  • alpha (float, default=0.01) – Significance level for posterior intervals.

  • n_mcmc (int, default=3000) – Number of posterior MCMC samples.

  • x_values (array-like or None) – Treatment values for dose-response (continuous treatment).

  • q_sd (float, default=1.0) – Proposal standard deviation for Metropolis-Hastings.

  • sample_y (bool, default=True) – Whether to sample from the outcome variance model.

  • bs (int, default=100) – Batch size for processing posterior samples.

Returns:

  • effect (np.ndarray) – ITE (binary) or ADRF (continuous) point estimates.

  • pos_int (np.ndarray) – Posterior intervals with shape (n, 2) or (len(x_values), 2).

evaluate(data, data_z=None, nb_intervals=200)

Configuration

get_config()

Return the configuration of the CausalBGM model.

Returns:

A dictionary with key "params" containing the full configuration dictionary passed at construction time.

Return type:

dict

FullMCMCCausalBGM

class bayesgm.models.causalbgm.FullMCMCCausalBGM(params, timestamp=None, random_seed=None)

CausalBGM with full MCMC sampling for both individual latent variables and neural-network parameters.

After calling fit() (which uses SGD for both network weights and latent variables), invoke run_mcmc_training() to draw posterior samples of all network weights via Hamiltonian Monte Carlo. The predict() method then marginalises over both latent-variable and weight uncertainty.

Inherits from CausalBGM.

Parameters:
  • params (dict) – Same keys as CausalBGM.

  • timestamp (str or None, optional) – Timestamp string for the run.

  • random_seed (int or None, optional) – If provided, sets the global random seed for reproducibility.

Training

fit(data, epochs=100, epochs_per_eval=5, batch_size=32, startoff=0, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, save_format='txt', verbose=1)

Train CausalBGM with an optional EGM warm-start.

Parameters:
  • data (tuple of np.ndarray) – Training data (data_x, data_y, data_v).

  • epochs (int, default=100) – Number of training epochs.

  • epochs_per_eval (int, default=5) – Evaluate the full training set every this many epochs.

  • batch_size (int, default=32) – Mini-batch size used for both EGM initialization and iterative updates.

  • startoff (int, default=0) – Start tracking the best model only after this epoch.

  • use_egm_init (bool, default=True) – If True, run EGM initialization before iterative training.

  • egm_n_iter (int, default=30000) – Number of EGM mini-batch iterations when use_egm_init=True.

  • egm_batches_per_eval (int, default=500) – Logging interval for EGM initialization.

  • save_format (str, default='txt') – File format used when saving causal estimates.

  • verbose (int, default=1) – Verbosity level. Set to 0 to suppress progress logging.

Notes

After the optional EGM warm-start, latent variables are initialized from the encoder output \(e(V)\). If EGM is skipped, they are initialized by sampling from a standard normal distribution.

run_mcmc_training(data, num_samples=2000, num_burnin=1000, eps=1e-06)

Draw posterior weight samples via Hamiltonian Monte Carlo.

Runs HMC on the weights of g_net, h_net, and f_net conditioned on the optimised latent variables from fit(). Must be called after fit().

Parameters:
  • data (tuple of np.ndarray) – A triplet (data_x, data_y, data_v).

  • num_samples (int, default=2000) – Number of HMC posterior samples to draw.

  • num_burnin (int, default=1000) – Number of burn-in steps to discard.

  • eps (float, default=1e-6) – Small constant added for numerical stability in the likelihood computation.

Inference

predict(data, alpha=0.01, n_mcmc=3000, x_values=None, q_sd=1.0, sample_y=True, bs=100)

Predict causal effects with full posterior uncertainty.

Marginalises over both latent-variable and network-weight uncertainty. run_mcmc_training() must be called first to populate weight samples.
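Marginalizing over both sources of uncertainty amounts to pooling effect draws across weight samples and the latent draws made under each weight sample. The NumPy sketch below uses a hypothetical scalar "weight" and a toy effect function in place of the HMC weight samples and the networks. Fresh latent draws per weight sample stand in for rerunning latent MCMC under those weights. This illustrates the pooling only and is not how FullMCMCCausalBGM computes it.

```python
import numpy as np

rng = np.random.default_rng(0)
n_weight_samples, n_mcmc, n = 20, 100, 8

def ite_draws_for_weights(w, z_samples):
    """Hypothetical outcome model: ITE = w * (1 + z) per latent draw and subject."""
    return w * (1.0 + z_samples)

weight_samples = 2.0 + 0.1 * rng.standard_normal(n_weight_samples)  # stand-in for HMC draws

# For each weight draw, take latent draws and evaluate effects; then pool everything.
pooled = np.concatenate(
    [ite_draws_for_weights(w, 0.2 * rng.standard_normal((n_mcmc, n))) for w in weight_samples],
    axis=0,
)                                                                      # (20 * 100, n)
effect = pooled.mean(axis=0)                                           # (n,)
alpha = 0.01
pos_int = np.quantile(pooled, [alpha / 2, 1 - alpha / 2], axis=0).T    # (n, 2)
```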

Parameters:
  • data (tuple of np.ndarray) – A triplet (data_x, data_y, data_v).

  • alpha (float, default=0.01) – Significance level for posterior intervals.

  • n_mcmc (int, default=3000) – Number of posterior MCMC samples for latent variables.

  • x_values (array-like or None) – Treatment values for dose-response (continuous treatment).

  • q_sd (float, default=1.0) – Proposal standard deviation for Metropolis-Hastings.

  • sample_y (bool, default=True) – Whether to sample from the outcome variance model.

  • bs (int, default=100) – Batch size for processing posterior samples.

Returns:

  • effect (np.ndarray) – ITE (binary) or ADRF (continuous) point estimates.

  • pos_int (np.ndarray) – Posterior intervals with shape (n, 2) for binary treatment or (len(x_values), 2) for continuous treatment.

evaluate(data, data_z=None, nb_intervals=200)

Configuration

get_config()

Return the configuration of the CausalBGM model.

Returns:

A dictionary with key "params" containing the full configuration dictionary passed at construction time.

Return type:

dict