Tutorial for Python Users
In this tutorial, we walk through the CausalBGM workflow under both the continuous treatment setting and the binary treatment setting.
CausalBGM can be used through its Python API or through a single command-line call after installation.
First, install the Python bayesgm toolkit; please refer to the install page.
Notes
[1]:
import numpy as np
import pandas as pd
import yaml
import bayesgm
print("Currently use version %s of bayesgm."%bayesgm.__version__)
Currently use version 1.0.1 of bayesgm.
Use CausalBGM Python API in continuous treatment setting
We will use the Hirano and Imbens dataset as an example.
CausalBGM Configuration Parameters
Before creating a CausalBGM model, create a Python dict that holds the model configuration. The configuration parameters are described below.
General Parameters
| Config Parameter | Description |
|---|---|
| dataset | Dataset name to indicate the input data. Default: ‘Sim_Hirano_Imbens’. |
| output_dir | Output directory to save the results during the model training. Default: ‘.’. |
| save_res | Whether to save intermediate results. Default: True. |
| save_model | Whether to save the model after training. Default: False. |
| binary_treatment | Whether to use binary treatment settings. Default: False. |
| use_bnn | Whether to use Bayesian neural networks. Default: True. |
Parameters for Iterative Updating Algorithm
| Config Parameter | Description |
|---|---|
| z_dims | Latent dimensions of |
| v_dim | Dimension of covariates. Default: 200. |
| lr_theta | Learning rate for updating model parameters. Default: 0.0001. |
| lr_z | Learning rate for updating latent variables. Default: 0.0001. |
| g_units | Number of units for covariates generative model. Default: [64, 64, 64, 64, 64]. |
| f_units | Number of units for outcome generative model. Default: [64, 32, 8]. |
| h_units | Number of units for treatment generative model. Default: [64, 32, 8]. |
Parameters for EGM Initialization
| Config Parameter | Description |
|---|---|
| kl_weight | Coefficient for KL divergence term in BNNs. Default: 0.0001. |
| lr | Learning rate for EGM initialization. Default: 0.0002. |
| g_d_freq | Frequency for updating discriminators and generators. Default: 5. |
| use_z_rec | Whether to use reconstruction for latent features. Default: True. |
| e_units | Number of units for the encoder network. Default: [64, 64, 64, 64, 64]. |
| dz_units | Number of units for the discriminator network in latent space. Default: [64, 32, 8]. |
Notes
Example config files for the datasets used in the paper are provided in the folder src/configs.
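As an alternative to a YAML file, the same configuration can be written directly as a Python dict. The sketch below copies the keys and values of the Sim_Hirano_Imbens configuration printed in this tutorial. `REQUIRED_KEYS` and `check_params` are hypothetical helpers, not part of bayesgm; they only verify that no key is missing before the dict is handed to the model.

```python
# Configuration for CausalBGM written as a plain Python dict.
# Values mirror the Sim_Hirano_Imbens example config shown in this tutorial.
params = {
    # General parameters
    "dataset": "Sim_Hirano_Imbens",
    "output_dir": ".",
    "save_res": True,
    "save_model": False,
    "binary_treatment": False,
    "use_bnn": True,
    # Iterative updating algorithm
    "z_dims": [1, 1, 1, 7],
    "v_dim": 200,
    "lr_theta": 0.0001,
    "lr_z": 0.0001,
    "g_units": [64, 64, 64, 64, 64],
    "f_units": [64, 32, 8],
    "h_units": [64, 32, 8],
    # EGM initialization
    "kl_weight": 0.0001,
    "lr": 0.0002,
    "g_d_freq": 5,
    "use_z_rec": True,
    "e_units": [64, 64, 64, 64, 64],
    "dz_units": [64, 32, 8],
}

# Hypothetical helper (not part of bayesgm): the keys this tutorial's config uses.
REQUIRED_KEYS = {
    "dataset", "output_dir", "save_res", "save_model", "binary_treatment",
    "use_bnn", "z_dims", "v_dim", "lr_theta", "lr_z", "g_units", "f_units",
    "h_units", "kl_weight", "lr", "g_d_freq", "use_z_rec", "e_units", "dz_units",
}


def check_params(params):
    """Return the sorted list of expected keys that are missing from params."""
    return sorted(REQUIRED_KEYS - set(params))


missing = check_params(params)
print("Missing keys:", missing)
```

A dict built this way can be passed as `params` when instantiating the model, in the same way as a dict loaded from YAML.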
Loading config parameters
[2]:
with open('src/configs/Sim_Hirano_Imbens.yaml', 'r') as f:
    params = yaml.safe_load(f)
print(params)
{'dataset': 'Sim_Hirano_Imbens', 'output_dir': '.', 'save_res': True, 'save_model': False, 'binary_treatment': False, 'use_bnn': True, 'z_dims': [1, 1, 1, 7], 'v_dim': 200, 'lr_theta': 0.0001, 'lr_z': 0.0001, 'g_units': [64, 64, 64, 64, 64], 'f_units': [64, 32, 8], 'h_units': [64, 32, 8], 'kl_weight': 0.0001, 'lr': 0.0002, 'g_d_freq': 5, 'use_z_rec': True, 'e_units': [64, 64, 64, 64, 64], 'dz_units': [64, 32, 8]}
Instantiate a CausalBGM model
[3]:
model = bayesgm.models.CausalBGM(params=params, random_seed=None)
2026-03-11 10:40:29.953365: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-03-11 10:40:30.165681: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-03-11 10:40:30.278509: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-03-11 10:40:32.578548: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/ql339/.conda/envs/py3.9/lib/
2026-03-11 10:40:32.579064: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/ql339/.conda/envs/py3.9/lib/
2026-03-11 10:40:32.579068: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2026-03-11 10:40:39.357386: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2027] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 9.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.
2026-03-11 10:40:39.357660: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-03-11 10:40:39.363478: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2027] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 9.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.
2026-03-11 10:40:44.230986: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 140929 MB memory: -> device: 0, name: NVIDIA H200, pci bus id: 0000:bb:00.0, compute capability: 9.0
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
loc = add_variable_fn(
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
untransformed_scale = add_variable_fn(
2026-03-11 10:40:47.589822: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
Data preparation
The input data are organized in a triplet, which contains treatment (X), potential outcome (Y), and covariates (V).
[4]:
from bayesgm.datasets import Sim_Hirano_Imbens_sampler
x,y,v = Sim_Hirano_Imbens_sampler(N=20000, v_dim=200).load_all()
print(x.shape,y.shape,v.shape)
(20000, 1) (20000, 1) (20000, 200)
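To use your own data instead of the built-in sampler, the arrays need the same layout as the shapes printed above: treatment and outcome as column vectors of shape (N, 1), and covariates as a matrix of shape (N, v_dim). A minimal NumPy sketch follows. The synthetic data-generating process, the helper name `to_triplet`, and the choice of float32 are assumptions made for illustration and are not taken from bayesgm.

```python
import numpy as np

rng = np.random.default_rng(0)
N, v_dim = 1000, 200

# Synthetic stand-in data (illustrative only, not the Hirano-Imbens simulation).
v_raw = rng.normal(size=(N, v_dim))                      # covariates V
x_raw = v_raw[:, 0] + rng.normal(size=N)                 # continuous treatment X
y_raw = 2.0 * x_raw + v_raw[:, 1] + rng.normal(size=N)   # outcome Y


def to_triplet(x, y, v, dtype=np.float32):
    """Reshape treatment/outcome into (N, 1) columns and check that rows align."""
    x = np.asarray(x, dtype=dtype).reshape(-1, 1)
    y = np.asarray(y, dtype=dtype).reshape(-1, 1)
    v = np.asarray(v, dtype=dtype)
    if v.ndim != 2:
        raise ValueError("v must be a 2-D array of shape (N, v_dim)")
    if not (len(x) == len(y) == len(v)):
        raise ValueError("x, y, and v must have the same number of rows")
    return x, y, v


x_own, y_own, v_own = to_triplet(x_raw, y_raw, v_raw)
print(x_own.shape, y_own.shape, v_own.shape)
```

The number of covariate columns should agree with the `v_dim` entry of the configuration dict.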
Model training
Train CausalBGM with an optional EGM warm-start.
| Config Parameter | Description |
|---|---|
| data | Tuple of data inputs |
| | Batch size for training. Default: 32. |
| epochs | Number of epochs for training. Default: 100. |
| epochs_per_eval | Frequency of evaluations during training (e.g., every 5 epochs). Default: 5. |
| use_egm_init | Whether to run EGM initialization before iterative training. Default: True. |
| | Frequency of evaluations during training (e.g., every 5 epochs). Default: 5. |
| egm_n_iter | Number of EGM initialization iterations. Default: 30000. |
| egm_batches_per_eval | Evaluate EGM initialization every this many iterations. Default: 500. |
| verbose | Controls verbosity level, showing progress and evaluation metrics. Default: 1. |
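As a quick sanity check on these settings, the number of batches per epoch and the number of EGM evaluation lines can be worked out from the values used in this tutorial. The sketch assumes that one epoch is a full pass over the data and that EGM evaluation happens at iteration 0 and then every `egm_batches_per_eval` iterations. Both assumptions are consistent with the training log in this tutorial, but they are not confirmed from the library source.

```python
import math

n_samples = 20000            # N passed to the sampler in this tutorial
batch_size = 32              # default batch size from the table
egm_n_iter = 30000           # EGM initialization iterations
egm_batches_per_eval = 500   # EGM evaluation interval

# Batches needed to cover the dataset once.
batches_per_epoch = math.ceil(n_samples / batch_size)

# Evaluation lines at iterations 0, 500, ..., 30000 (iteration 0 included).
egm_eval_lines = egm_n_iter // egm_batches_per_eval + 1

print(batches_per_epoch)  # 625
print(egm_eval_lines)     # 61
```

The value 625 matches the batch count shown in the per-epoch progress bars of the training output.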
Notes
The training procedure consists of two phases:
1. EGM initialization (optional): a warm start that provides a good starting point for the latent variables and model parameters. Skip it by setting use_egm_init to False.
2. Stochastic iterative updating: alternates between updating the generator network parameters and the per-sample latent variables via stochastic gradient optimization.
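The alternating scheme of the second phase can be illustrated on a toy problem. The sketch below is not CausalBGM's implementation: it swaps the neural networks for a single linear map `W` and the Bayesian objective for a plain squared-error loss, and all names and step sizes are chosen for illustration. It only shows the pattern of alternating minibatch gradient steps on shared parameters and on per-sample latent variables.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, d = 512, 2, 10

# Toy data from a linear generative model v = z @ W + noise.
Z_true = rng.normal(size=(n, k))
W_true = rng.normal(size=(k, d))
V = Z_true @ W_true + 0.1 * rng.normal(size=(n, d))

# Learnable quantities: shared parameters W and one latent vector per sample.
W = 0.1 * rng.normal(size=(k, d))
Z = 0.1 * rng.normal(size=(n, k))
lr_theta, lr_z, batch_size = 0.05, 0.01, 32  # illustrative step sizes


def mse(V, Z, W):
    """Mean squared reconstruction error over all samples and features."""
    return float(np.mean((Z @ W - V) ** 2))


loss_start = mse(V, Z, W)
for epoch in range(100):
    order = rng.permutation(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        # Step 1: update the shared parameters with the latents held fixed.
        resid = Z[idx] @ W - V[idx]
        W -= lr_theta * (Z[idx].T @ resid) / len(idx)
        # Step 2: update this batch's latents with the parameters held fixed.
        resid = Z[idx] @ W - V[idx]
        Z[idx] -= lr_z * (resid @ W.T)
loss_end = mse(V, Z, W)

print(f"MSE before: {loss_start:.4f}, after: {loss_end:.4f}")
```

The two learning rates play roles analogous to `lr_theta` and `lr_z` in the configuration: one controls the step on the shared model parameters, the other the step on each sample's latent variables.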
[5]:
model.fit(data=(x,y,v), epochs=100, epochs_per_eval=10, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, verbose=1)
EGM Initialization Starts ...
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
loc = add_variable_fn(
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
untransformed_scale = add_variable_fn(
EGM Initialization Iter [0] : e_loss_adv [0.2336], l2_loss_v [1.0228], l2_loss_z [0.9938], l2_loss_x [1.0393], l2_loss_y [4.4029], g_e_loss [7.6924], dz_loss [0.0516], d_loss [0.3271]
EGM Initialization Iter [500] : e_loss_adv [1.7835], l2_loss_v [1.0174], l2_loss_z [0.8902], l2_loss_x [11.1814], l2_loss_y [12.4417], g_e_loss [27.3142], dz_loss [-1.1979], d_loss [-0.7519]
EGM Initialization Iter [1000] : e_loss_adv [0.2142], l2_loss_v [1.0167], l2_loss_z [0.9186], l2_loss_x [0.4651], l2_loss_y [1.7649], g_e_loss [4.3797], dz_loss [-1.5703], d_loss [-1.4926]
EGM Initialization Iter [1500] : e_loss_adv [0.6016], l2_loss_v [1.0229], l2_loss_z [0.8104], l2_loss_x [1.7311], l2_loss_y [1.1285], g_e_loss [5.2944], dz_loss [-1.4662], d_loss [-1.1440]
EGM Initialization Iter [2000] : e_loss_adv [-0.2172], l2_loss_v [0.9818], l2_loss_z [0.7609], l2_loss_x [3.8512], l2_loss_y [1.3351], g_e_loss [6.7121], dz_loss [-0.7049], d_loss [-0.5723]
EGM Initialization Iter [2500] : e_loss_adv [-0.6869], l2_loss_v [0.9955], l2_loss_z [0.7060], l2_loss_x [0.5864], l2_loss_y [1.4249], g_e_loss [3.0260], dz_loss [-0.7184], d_loss [-0.5650]
EGM Initialization Iter [3000] : e_loss_adv [-0.9327], l2_loss_v [0.9780], l2_loss_z [0.7114], l2_loss_x [1.4311], l2_loss_y [1.4923], g_e_loss [3.6802], dz_loss [-0.4438], d_loss [-0.2919]
EGM Initialization Iter [3500] : e_loss_adv [-1.2839], l2_loss_v [1.0583], l2_loss_z [0.7365], l2_loss_x [2.3093], l2_loss_y [1.9179], g_e_loss [4.7381], dz_loss [-0.8263], d_loss [-0.7364]
EGM Initialization Iter [4000] : e_loss_adv [-1.5408], l2_loss_v [1.0294], l2_loss_z [0.6064], l2_loss_x [4.0206], l2_loss_y [1.4722], g_e_loss [5.5880], dz_loss [-0.8158], d_loss [-0.6767]
EGM Initialization Iter [4500] : e_loss_adv [-1.0226], l2_loss_v [0.9935], l2_loss_z [0.6432], l2_loss_x [1.1369], l2_loss_y [1.1873], g_e_loss [2.9384], dz_loss [-0.7635], d_loss [-0.6356]
EGM Initialization Iter [5000] : e_loss_adv [-1.3171], l2_loss_v [1.0850], l2_loss_z [0.5702], l2_loss_x [0.8795], l2_loss_y [1.1124], g_e_loss [2.3301], dz_loss [-0.7465], d_loss [-0.6473]
EGM Initialization Iter [5500] : e_loss_adv [-1.3136], l2_loss_v [1.0697], l2_loss_z [0.6109], l2_loss_x [3.3471], l2_loss_y [1.4369], g_e_loss [5.1511], dz_loss [-0.7818], d_loss [-0.7041]
EGM Initialization Iter [6000] : e_loss_adv [-1.5287], l2_loss_v [1.0135], l2_loss_z [0.5739], l2_loss_x [3.3765], l2_loss_y [2.4388], g_e_loss [5.8741], dz_loss [-0.4657], d_loss [-0.3500]
EGM Initialization Iter [6500] : e_loss_adv [-0.9883], l2_loss_v [1.0253], l2_loss_z [0.4936], l2_loss_x [12.1984], l2_loss_y [1.8050], g_e_loss [14.5340], dz_loss [-1.0290], d_loss [-0.8756]
EGM Initialization Iter [7000] : e_loss_adv [-0.9280], l2_loss_v [1.0170], l2_loss_z [0.5469], l2_loss_x [0.7295], l2_loss_y [1.5921], g_e_loss [2.9574], dz_loss [-1.1066], d_loss [-0.9553]
EGM Initialization Iter [7500] : e_loss_adv [-1.3834], l2_loss_v [0.9478], l2_loss_z [0.5085], l2_loss_x [1.4535], l2_loss_y [0.8933], g_e_loss [2.4197], dz_loss [-1.1365], d_loss [-0.9165]
EGM Initialization Iter [8000] : e_loss_adv [-1.1943], l2_loss_v [1.0483], l2_loss_z [0.4082], l2_loss_x [1.5088], l2_loss_y [1.2450], g_e_loss [3.0161], dz_loss [-0.6851], d_loss [-0.5403]
EGM Initialization Iter [8500] : e_loss_adv [-0.9068], l2_loss_v [0.9925], l2_loss_z [0.5008], l2_loss_x [2.5160], l2_loss_y [1.6454], g_e_loss [4.7481], dz_loss [-0.4941], d_loss [-0.3221]
EGM Initialization Iter [9000] : e_loss_adv [-1.1367], l2_loss_v [1.0227], l2_loss_z [0.4641], l2_loss_x [0.8961], l2_loss_y [1.0888], g_e_loss [2.3350], dz_loss [-1.1502], d_loss [-1.0479]
EGM Initialization Iter [9500] : e_loss_adv [-0.4581], l2_loss_v [1.0066], l2_loss_z [0.4015], l2_loss_x [4.5273], l2_loss_y [1.1357], g_e_loss [6.6130], dz_loss [-0.4729], d_loss [-0.2438]
EGM Initialization Iter [10000] : e_loss_adv [-0.4825], l2_loss_v [1.0119], l2_loss_z [0.3976], l2_loss_x [4.9266], l2_loss_y [1.6185], g_e_loss [7.4722], dz_loss [-0.9087], d_loss [-0.8407]
EGM Initialization Iter [10500] : e_loss_adv [-0.3622], l2_loss_v [0.9236], l2_loss_z [0.3934], l2_loss_x [1.1994], l2_loss_y [0.8895], g_e_loss [3.0437], dz_loss [-0.5223], d_loss [-0.3849]
EGM Initialization Iter [11000] : e_loss_adv [-0.6587], l2_loss_v [1.0022], l2_loss_z [0.3753], l2_loss_x [0.6290], l2_loss_y [2.2966], g_e_loss [3.6443], dz_loss [-0.3311], d_loss [-0.2351]
EGM Initialization Iter [11500] : e_loss_adv [-0.8142], l2_loss_v [0.9219], l2_loss_z [0.3577], l2_loss_x [0.4504], l2_loss_y [0.6233], g_e_loss [1.5392], dz_loss [-0.5330], d_loss [-0.4307]
EGM Initialization Iter [12000] : e_loss_adv [-0.1782], l2_loss_v [0.9856], l2_loss_z [0.4137], l2_loss_x [17.7432], l2_loss_y [1.1676], g_e_loss [20.1319], dz_loss [-0.6628], d_loss [-0.5342]
EGM Initialization Iter [12500] : e_loss_adv [-0.7103], l2_loss_v [1.0942], l2_loss_z [0.3109], l2_loss_x [4.3603], l2_loss_y [0.7972], g_e_loss [5.8523], dz_loss [-0.3978], d_loss [-0.2856]
EGM Initialization Iter [13000] : e_loss_adv [-0.7747], l2_loss_v [1.0407], l2_loss_z [0.3655], l2_loss_x [1.3099], l2_loss_y [0.5965], g_e_loss [2.5379], dz_loss [-0.5982], d_loss [-0.5245]
EGM Initialization Iter [13500] : e_loss_adv [-1.5893], l2_loss_v [0.9822], l2_loss_z [0.3269], l2_loss_x [1.5761], l2_loss_y [1.4505], g_e_loss [2.7463], dz_loss [-0.2089], d_loss [-0.1767]
EGM Initialization Iter [14000] : e_loss_adv [-0.3512], l2_loss_v [0.9724], l2_loss_z [0.3501], l2_loss_x [4.6126], l2_loss_y [1.2828], g_e_loss [6.8669], dz_loss [-0.5639], d_loss [-0.5068]
EGM Initialization Iter [14500] : e_loss_adv [-0.8137], l2_loss_v [1.0103], l2_loss_z [0.4248], l2_loss_x [0.5124], l2_loss_y [1.4730], g_e_loss [2.6068], dz_loss [-0.3226], d_loss [-0.2140]
EGM Initialization Iter [15000] : e_loss_adv [-0.4964], l2_loss_v [0.9213], l2_loss_z [0.2990], l2_loss_x [1.2344], l2_loss_y [0.9027], g_e_loss [2.8611], dz_loss [-0.1730], d_loss [-0.0417]
EGM Initialization Iter [15500] : e_loss_adv [-0.9802], l2_loss_v [1.0312], l2_loss_z [0.3307], l2_loss_x [0.9689], l2_loss_y [1.3733], g_e_loss [2.7239], dz_loss [-0.7252], d_loss [-0.6841]
EGM Initialization Iter [16000] : e_loss_adv [-0.8511], l2_loss_v [0.9647], l2_loss_z [0.3063], l2_loss_x [1.1361], l2_loss_y [1.1887], g_e_loss [2.7447], dz_loss [-0.7088], d_loss [-0.6389]
EGM Initialization Iter [16500] : e_loss_adv [-0.9395], l2_loss_v [1.0061], l2_loss_z [0.3282], l2_loss_x [1.2912], l2_loss_y [1.5184], g_e_loss [3.2045], dz_loss [-0.1869], d_loss [-0.1471]
EGM Initialization Iter [17000] : e_loss_adv [-0.1236], l2_loss_v [1.0263], l2_loss_z [0.3395], l2_loss_x [2.4202], l2_loss_y [0.7661], g_e_loss [4.4285], dz_loss [-0.3217], d_loss [-0.2915]
EGM Initialization Iter [17500] : e_loss_adv [-0.6039], l2_loss_v [1.0110], l2_loss_z [0.2228], l2_loss_x [0.6884], l2_loss_y [1.0420], g_e_loss [2.3603], dz_loss [-0.6893], d_loss [-0.6441]
EGM Initialization Iter [18000] : e_loss_adv [-0.9193], l2_loss_v [1.0120], l2_loss_z [0.2393], l2_loss_x [0.4388], l2_loss_y [0.9639], g_e_loss [1.7347], dz_loss [-0.4117], d_loss [-0.3580]
EGM Initialization Iter [18500] : e_loss_adv [-1.1126], l2_loss_v [0.9904], l2_loss_z [0.2908], l2_loss_x [12.6224], l2_loss_y [1.8989], g_e_loss [14.6900], dz_loss [0.1250], d_loss [0.1782]
EGM Initialization Iter [19000] : e_loss_adv [-0.4944], l2_loss_v [1.0103], l2_loss_z [0.2811], l2_loss_x [0.9594], l2_loss_y [1.4073], g_e_loss [3.1637], dz_loss [-0.7970], d_loss [-0.7325]
EGM Initialization Iter [19500] : e_loss_adv [0.3062], l2_loss_v [0.9363], l2_loss_z [0.2672], l2_loss_x [17.2928], l2_loss_y [1.6232], g_e_loss [20.4258], dz_loss [-0.3881], d_loss [-0.2849]
EGM Initialization Iter [20000] : e_loss_adv [-1.7392], l2_loss_v [1.0347], l2_loss_z [0.2673], l2_loss_x [6.0748], l2_loss_y [1.1972], g_e_loss [6.8348], dz_loss [-0.6097], d_loss [-0.5434]
EGM Initialization Iter [20500] : e_loss_adv [-0.9999], l2_loss_v [0.9217], l2_loss_z [0.2955], l2_loss_x [3.1602], l2_loss_y [1.3958], g_e_loss [4.7734], dz_loss [0.0663], d_loss [0.1290]
EGM Initialization Iter [21000] : e_loss_adv [-0.3025], l2_loss_v [0.9891], l2_loss_z [0.2912], l2_loss_x [3.4434], l2_loss_y [1.3069], g_e_loss [5.7280], dz_loss [-0.2240], d_loss [-0.2015]
EGM Initialization Iter [21500] : e_loss_adv [-0.4693], l2_loss_v [1.0710], l2_loss_z [0.2530], l2_loss_x [0.4649], l2_loss_y [1.3412], g_e_loss [2.6608], dz_loss [-0.6601], d_loss [-0.6297]
EGM Initialization Iter [22000] : e_loss_adv [-1.7933], l2_loss_v [0.9932], l2_loss_z [0.2663], l2_loss_x [1.2690], l2_loss_y [1.4594], g_e_loss [2.1947], dz_loss [-0.0681], d_loss [-0.0063]
EGM Initialization Iter [22500] : e_loss_adv [-0.4883], l2_loss_v [0.9213], l2_loss_z [0.2608], l2_loss_x [0.6639], l2_loss_y [0.6807], g_e_loss [2.0383], dz_loss [-0.1923], d_loss [-0.1554]
EGM Initialization Iter [23000] : e_loss_adv [-0.5152], l2_loss_v [0.9531], l2_loss_z [0.2275], l2_loss_x [0.4341], l2_loss_y [1.0350], g_e_loss [2.1345], dz_loss [-0.2549], d_loss [-0.1995]
EGM Initialization Iter [23500] : e_loss_adv [-0.1655], l2_loss_v [0.9404], l2_loss_z [0.2472], l2_loss_x [0.5480], l2_loss_y [0.6557], g_e_loss [2.2257], dz_loss [-0.1739], d_loss [-0.1279]
EGM Initialization Iter [24000] : e_loss_adv [-0.8649], l2_loss_v [0.9337], l2_loss_z [0.2714], l2_loss_x [1.8735], l2_loss_y [1.4069], g_e_loss [3.6205], dz_loss [-0.2289], d_loss [-0.1504]
EGM Initialization Iter [24500] : e_loss_adv [0.3453], l2_loss_v [0.9769], l2_loss_z [0.2489], l2_loss_x [6.0515], l2_loss_y [0.9008], g_e_loss [8.5234], dz_loss [-0.2344], d_loss [-0.1849]
EGM Initialization Iter [25000] : e_loss_adv [-0.6483], l2_loss_v [0.9621], l2_loss_z [0.2230], l2_loss_x [0.6436], l2_loss_y [1.7994], g_e_loss [2.9798], dz_loss [-0.1736], d_loss [-0.1410]
EGM Initialization Iter [25500] : e_loss_adv [-0.6854], l2_loss_v [1.0132], l2_loss_z [0.2315], l2_loss_x [4.1439], l2_loss_y [2.5138], g_e_loss [7.2171], dz_loss [-0.4864], d_loss [-0.4177]
EGM Initialization Iter [26000] : e_loss_adv [-1.2118], l2_loss_v [1.0288], l2_loss_z [0.2737], l2_loss_x [5.5182], l2_loss_y [1.2290], g_e_loss [6.8380], dz_loss [0.1194], d_loss [0.1722]
EGM Initialization Iter [26500] : e_loss_adv [0.7093], l2_loss_v [1.0771], l2_loss_z [0.2293], l2_loss_x [3.3827], l2_loss_y [1.4104], g_e_loss [6.8087], dz_loss [-0.2914], d_loss [-0.1987]
EGM Initialization Iter [27000] : e_loss_adv [-0.7259], l2_loss_v [0.8823], l2_loss_z [0.2003], l2_loss_x [0.8137], l2_loss_y [0.8930], g_e_loss [2.0634], dz_loss [-0.7017], d_loss [-0.6602]
EGM Initialization Iter [27500] : e_loss_adv [-1.5183], l2_loss_v [0.9569], l2_loss_z [0.1955], l2_loss_x [2.3814], l2_loss_y [1.4743], g_e_loss [3.4897], dz_loss [-0.1165], d_loss [-0.0647]
EGM Initialization Iter [28000] : e_loss_adv [-0.0395], l2_loss_v [1.0021], l2_loss_z [0.2151], l2_loss_x [3.7281], l2_loss_y [0.9516], g_e_loss [5.8575], dz_loss [-0.5957], d_loss [-0.5354]
EGM Initialization Iter [28500] : e_loss_adv [-2.1634], l2_loss_v [1.0258], l2_loss_z [0.1998], l2_loss_x [0.9157], l2_loss_y [0.8383], g_e_loss [0.8162], dz_loss [-0.1053], d_loss [-0.0697]
EGM Initialization Iter [29000] : e_loss_adv [-0.0149], l2_loss_v [0.9536], l2_loss_z [0.2508], l2_loss_x [2.8241], l2_loss_y [1.4126], g_e_loss [5.4262], dz_loss [-0.6389], d_loss [-0.6117]
EGM Initialization Iter [29500] : e_loss_adv [-0.8025], l2_loss_v [0.9960], l2_loss_z [0.2465], l2_loss_x [0.9598], l2_loss_y [1.1993], g_e_loss [2.5991], dz_loss [-0.3781], d_loss [-0.2946]
EGM Initialization Iter [30000] : e_loss_adv [-1.7520], l2_loss_v [0.9474], l2_loss_z [0.2136], l2_loss_x [1.0793], l2_loss_y [0.7894], g_e_loss [1.2777], dz_loss [-0.4842], d_loss [-0.4431]
EGM Initialization Ends.
Initialize latent variables Z with e(V)...
Iterative Updating Starts ...
Epoch 0/100: 100%|██████████| 625/625 [00:17<00:00, 34.93batch/s, loss_px_z: [1.0017], loss_mse_x: [0.7962], loss_py_z: [1.7299], loss_mse_y: [1.8524], loss_pv_z: [104.1091], loss_mse_v: [0.9552], loss_postrior_z: [101.2081]]
Epoch [0/100]: MSE_x: 2.1162, MSE_y: 1.2035, MSE_v: 0.9740
Epoch 1/100: 100%|██████████| 625/625 [00:11<00:00, 55.72batch/s, loss_px_z: [1.0627], loss_mse_x: [0.9919], loss_py_z: [1.3802], loss_mse_y: [1.3957], loss_pv_z: [101.1004], loss_mse_v: [0.9326], loss_postrior_z: [99.0015]]
Epoch 2/100: 100%|██████████| 625/625 [00:11<00:00, 54.91batch/s, loss_px_z: [2.9263], loss_mse_x: [4.9884], loss_py_z: [1.2504], loss_mse_y: [1.2226], loss_pv_z: [104.0898], loss_mse_v: [0.9618], loss_postrior_z: [104.2342]]
Epoch 3/100: 100%|██████████| 625/625 [00:11<00:00, 55.40batch/s, loss_px_z: [1.3514], loss_mse_x: [1.8893], loss_py_z: [1.2817], loss_mse_y: [1.3911], loss_pv_z: [111.9185], loss_mse_v: [1.0400], loss_postrior_z: [111.6363]]
Epoch 4/100: 100%|██████████| 625/625 [00:11<00:00, 54.95batch/s, loss_px_z: [0.8047], loss_mse_x: [0.7948], loss_py_z: [1.1081], loss_mse_y: [1.0806], loss_pv_z: [109.0874], loss_mse_v: [1.0118], loss_postrior_z: [105.9435]]
Epoch 5/100: 100%|██████████| 625/625 [00:11<00:00, 54.52batch/s, loss_px_z: [1.1618], loss_mse_x: [1.1868], loss_py_z: [1.1339], loss_mse_y: [1.0139], loss_pv_z: [110.8662], loss_mse_v: [1.0276], loss_postrior_z: [109.2909]]
Epoch 6/100: 100%|██████████| 625/625 [00:11<00:00, 55.03batch/s, loss_px_z: [0.8462], loss_mse_x: [0.7771], loss_py_z: [1.3909], loss_mse_y: [1.6178], loss_pv_z: [103.7796], loss_mse_v: [0.9596], loss_postrior_z: [101.4011]]
Epoch 7/100: 100%|██████████| 625/625 [00:11<00:00, 54.77batch/s, loss_px_z: [0.9922], loss_mse_x: [1.4759], loss_py_z: [0.9513], loss_mse_y: [0.6665], loss_pv_z: [104.8899], loss_mse_v: [0.9705], loss_postrior_z: [102.6353]]
Epoch 8/100: 100%|██████████| 625/625 [00:11<00:00, 53.82batch/s, loss_px_z: [0.7981], loss_mse_x: [0.6528], loss_py_z: [1.3613], loss_mse_y: [1.5980], loss_pv_z: [109.5272], loss_mse_v: [1.0158], loss_postrior_z: [106.3503]]
Epoch 9/100: 100%|██████████| 625/625 [00:11<00:00, 54.67batch/s, loss_px_z: [0.7962], loss_mse_x: [0.5568], loss_py_z: [1.0703], loss_mse_y: [0.9598], loss_pv_z: [104.8426], loss_mse_v: [0.9710], loss_postrior_z: [101.9698]]
Epoch 10/100: 100%|██████████| 625/625 [00:11<00:00, 54.84batch/s, loss_px_z: [0.5672], loss_mse_x: [0.5001], loss_py_z: [1.3016], loss_mse_y: [1.4453], loss_pv_z: [98.7699], loss_mse_v: [0.9122], loss_postrior_z: [96.4060]]
Epoch [10/100]: MSE_x: 2.2314, MSE_y: 1.2006, MSE_v: 0.9687
Epoch 11/100: 100%|██████████| 625/625 [00:11<00:00, 54.30batch/s, loss_px_z: [0.5795], loss_mse_x: [0.6802], loss_py_z: [1.1793], loss_mse_y: [1.2066], loss_pv_z: [96.5351], loss_mse_v: [0.8871], loss_postrior_z: [93.4869]]
Epoch 12/100: 100%|██████████| 625/625 [00:11<00:00, 54.78batch/s, loss_px_z: [1.1033], loss_mse_x: [2.4904], loss_py_z: [1.1913], loss_mse_y: [1.2767], loss_pv_z: [98.9891], loss_mse_v: [0.9151], loss_postrior_z: [95.6955]]
Epoch 13/100: 100%|██████████| 625/625 [00:11<00:00, 54.29batch/s, loss_px_z: [1.6449], loss_mse_x: [1.9944], loss_py_z: [1.1341], loss_mse_y: [1.1039], loss_pv_z: [104.2422], loss_mse_v: [0.9682], loss_postrior_z: [103.0387]]
Epoch 14/100: 100%|██████████| 625/625 [00:11<00:00, 54.24batch/s, loss_px_z: [0.6592], loss_mse_x: [2.1695], loss_py_z: [1.0025], loss_mse_y: [0.8883], loss_pv_z: [102.1708], loss_mse_v: [0.9453], loss_postrior_z: [99.7335]]
Epoch 15/100: 100%|██████████| 625/625 [00:11<00:00, 55.68batch/s, loss_px_z: [0.7443], loss_mse_x: [1.6535], loss_py_z: [1.1693], loss_mse_y: [1.2531], loss_pv_z: [103.2916], loss_mse_v: [0.9567], loss_postrior_z: [101.1586]]
Epoch 16/100: 100%|██████████| 625/625 [00:11<00:00, 55.40batch/s, loss_px_z: [0.7354], loss_mse_x: [0.8059], loss_py_z: [1.1732], loss_mse_y: [1.2717], loss_pv_z: [106.4319], loss_mse_v: [0.9858], loss_postrior_z: [104.2073]]
Epoch 17/100: 100%|██████████| 625/625 [00:11<00:00, 55.67batch/s, loss_px_z: [2.1600], loss_mse_x: [2.8464], loss_py_z: [1.0883], loss_mse_y: [1.1696], loss_pv_z: [106.7967], loss_mse_v: [0.9891], loss_postrior_z: [106.0688]]
Epoch 18/100: 100%|██████████| 625/625 [00:11<00:00, 54.89batch/s, loss_px_z: [0.7860], loss_mse_x: [2.6597], loss_py_z: [1.0036], loss_mse_y: [0.9227], loss_pv_z: [102.5790], loss_mse_v: [0.9485], loss_postrior_z: [100.5853]]
Epoch 19/100: 100%|██████████| 625/625 [00:11<00:00, 53.88batch/s, loss_px_z: [0.9116], loss_mse_x: [0.9684], loss_py_z: [0.9765], loss_mse_y: [0.9605], loss_pv_z: [105.0637], loss_mse_v: [0.9744], loss_postrior_z: [103.2852]]
Epoch 20/100: 100%|██████████| 625/625 [00:11<00:00, 54.23batch/s, loss_px_z: [1.1986], loss_mse_x: [7.3664], loss_py_z: [1.3686], loss_mse_y: [2.0824], loss_pv_z: [103.5854], loss_mse_v: [0.9616], loss_postrior_z: [101.7267]]
Epoch [20/100]: MSE_x: 2.3551, MSE_y: 1.2030, MSE_v: 0.9676
Epoch 21/100: 100%|██████████| 625/625 [00:11<00:00, 54.97batch/s, loss_px_z: [0.6754], loss_mse_x: [0.8088], loss_py_z: [1.1038], loss_mse_y: [1.2782], loss_pv_z: [100.8194], loss_mse_v: [0.9332], loss_postrior_z: [97.6035]]
Epoch 22/100: 100%|██████████| 625/625 [00:11<00:00, 55.75batch/s, loss_px_z: [1.5267], loss_mse_x: [1.2338], loss_py_z: [0.9382], loss_mse_y: [0.8040], loss_pv_z: [104.1450], loss_mse_v: [0.9673], loss_postrior_z: [102.7854]]
Epoch 23/100: 100%|██████████| 625/625 [00:11<00:00, 54.29batch/s, loss_px_z: [0.7087], loss_mse_x: [1.0746], loss_py_z: [1.1086], loss_mse_y: [1.3181], loss_pv_z: [105.0445], loss_mse_v: [0.9765], loss_postrior_z: [102.6641]]
Epoch 24/100: 100%|██████████| 625/625 [00:11<00:00, 54.38batch/s, loss_px_z: [0.3100], loss_mse_x: [0.4896], loss_py_z: [1.1030], loss_mse_y: [1.2588], loss_pv_z: [103.4497], loss_mse_v: [0.9582], loss_postrior_z: [100.2494]]
Epoch 25/100: 100%|██████████| 625/625 [00:11<00:00, 55.14batch/s, loss_px_z: [0.4496], loss_mse_x: [1.0667], loss_py_z: [0.9804], loss_mse_y: [1.0471], loss_pv_z: [106.0292], loss_mse_v: [0.9863], loss_postrior_z: [103.5820]]
Epoch 26/100: 100%|██████████| 625/625 [00:11<00:00, 54.07batch/s, loss_px_z: [0.9430], loss_mse_x: [1.7313], loss_py_z: [0.9116], loss_mse_y: [0.8233], loss_pv_z: [111.1543], loss_mse_v: [1.0376], loss_postrior_z: [108.5810]]
Epoch 27/100: 100%|██████████| 625/625 [00:11<00:00, 55.23batch/s, loss_px_z: [0.3157], loss_mse_x: [0.3745], loss_py_z: [0.8482], loss_mse_y: [0.7153], loss_pv_z: [101.8660], loss_mse_v: [0.9439], loss_postrior_z: [100.0770]]
Epoch 28/100: 100%|██████████| 625/625 [00:11<00:00, 53.60batch/s, loss_px_z: [0.8641], loss_mse_x: [1.8446], loss_py_z: [0.9266], loss_mse_y: [0.8325], loss_pv_z: [103.8543], loss_mse_v: [0.9660], loss_postrior_z: [102.6205]]
Epoch 29/100: 100%|██████████| 625/625 [00:11<00:00, 55.00batch/s, loss_px_z: [0.7417], loss_mse_x: [1.1123], loss_py_z: [1.0385], loss_mse_y: [1.1367], loss_pv_z: [107.2606], loss_mse_v: [0.9963], loss_postrior_z: [104.4046]]
Epoch 30/100: 100%|██████████| 625/625 [00:11<00:00, 55.33batch/s, loss_px_z: [0.7017], loss_mse_x: [2.0424], loss_py_z: [1.0338], loss_mse_y: [1.0654], loss_pv_z: [108.2052], loss_mse_v: [1.0073], loss_postrior_z: [105.8030]]
Epoch [30/100]: MSE_x: 2.1616, MSE_y: 1.2084, MSE_v: 0.9667
Epoch 31/100: 100%|██████████| 625/625 [00:11<00:00, 55.53batch/s, loss_px_z: [0.5098], loss_mse_x: [0.6142], loss_py_z: [1.1598], loss_mse_y: [1.4640], loss_pv_z: [101.9042], loss_mse_v: [0.9434], loss_postrior_z: [98.7518]]
Epoch 32/100: 100%|██████████| 625/625 [00:11<00:00, 55.46batch/s, loss_px_z: [0.2685], loss_mse_x: [0.3559], loss_py_z: [0.8601], loss_mse_y: [0.7184], loss_pv_z: [99.3381], loss_mse_v: [0.9203], loss_postrior_z: [95.7403]]
Epoch 33/100: 100%|██████████| 625/625 [00:11<00:00, 54.63batch/s, loss_px_z: [0.6054], loss_mse_x: [0.6897], loss_py_z: [1.1666], loss_mse_y: [1.4648], loss_pv_z: [110.0090], loss_mse_v: [1.0229], loss_postrior_z: [107.6170]]
Epoch 34/100: 100%|██████████| 625/625 [00:11<00:00, 55.32batch/s, loss_px_z: [0.6286], loss_mse_x: [1.4439], loss_py_z: [1.1934], loss_mse_y: [1.4473], loss_pv_z: [105.3239], loss_mse_v: [0.9798], loss_postrior_z: [103.2204]]
Epoch 35/100: 100%|██████████| 625/625 [00:11<00:00, 53.97batch/s, loss_px_z: [0.6942], loss_mse_x: [1.5184], loss_py_z: [1.0618], loss_mse_y: [1.2188], loss_pv_z: [108.3755], loss_mse_v: [1.0158], loss_postrior_z: [107.3421]]
Epoch 36/100: 100%|██████████| 625/625 [00:11<00:00, 55.26batch/s, loss_px_z: [0.4048], loss_mse_x: [0.7296], loss_py_z: [1.0736], loss_mse_y: [1.2077], loss_pv_z: [98.7979], loss_mse_v: [0.9153], loss_postrior_z: [96.1371]]
Epoch 37/100: 100%|██████████| 625/625 [00:11<00:00, 56.41batch/s, loss_px_z: [0.6732], loss_mse_x: [2.5477], loss_py_z: [1.1506], loss_mse_y: [1.4312], loss_pv_z: [99.2480], loss_mse_v: [0.9178], loss_postrior_z: [97.5223]]
Epoch 38/100: 100%|██████████| 625/625 [00:11<00:00, 55.61batch/s, loss_px_z: [0.5150], loss_mse_x: [1.2934], loss_py_z: [0.8623], loss_mse_y: [0.8216], loss_pv_z: [106.1062], loss_mse_v: [0.9891], loss_postrior_z: [103.6298]]
Epoch 39/100: 100%|██████████| 625/625 [00:11<00:00, 55.37batch/s, loss_px_z: [0.9153], loss_mse_x: [3.1581], loss_py_z: [0.9936], loss_mse_y: [1.0444], loss_pv_z: [105.5084], loss_mse_v: [0.9839], loss_postrior_z: [103.6738]]
Epoch 40/100: 100%|██████████| 625/625 [00:11<00:00, 54.97batch/s, loss_px_z: [0.6372], loss_mse_x: [1.2377], loss_py_z: [0.8759], loss_mse_y: [0.8238], loss_pv_z: [105.4479], loss_mse_v: [0.9848], loss_postrior_z: [105.0747]]
Epoch [40/100]: MSE_x: 2.2160, MSE_y: 1.2174, MSE_v: 0.9662
Epoch 41/100: 100%|██████████| 625/625 [00:11<00:00, 55.49batch/s, loss_px_z: [0.3933], loss_mse_x: [0.7077], loss_py_z: [0.8098], loss_mse_y: [0.7087], loss_pv_z: [100.8510], loss_mse_v: [0.9372], loss_postrior_z: [98.1798]]
Epoch 42/100: 100%|██████████| 625/625 [00:11<00:00, 55.60batch/s, loss_px_z: [0.6129], loss_mse_x: [1.1368], loss_py_z: [1.1585], loss_mse_y: [1.5237], loss_pv_z: [100.2316], loss_mse_v: [0.9312], loss_postrior_z: [97.8113]]
Epoch 43/100: 100%|██████████| 625/625 [00:11<00:00, 54.80batch/s, loss_px_z: [0.3872], loss_mse_x: [0.6784], loss_py_z: [1.0379], loss_mse_y: [1.2025], loss_pv_z: [102.1871], loss_mse_v: [0.9516], loss_postrior_z: [99.1012]]
Epoch 44/100: 100%|██████████| 625/625 [00:11<00:00, 54.81batch/s, loss_px_z: [0.5947], loss_mse_x: [1.2953], loss_py_z: [0.9499], loss_mse_y: [0.9979], loss_pv_z: [106.5104], loss_mse_v: [0.9883], loss_postrior_z: [104.6743]]
Epoch 45/100: 100%|██████████| 625/625 [00:11<00:00, 54.32batch/s, loss_px_z: [0.6412], loss_mse_x: [3.7783], loss_py_z: [0.9002], loss_mse_y: [0.9011], loss_pv_z: [97.6416], loss_mse_v: [0.9105], loss_postrior_z: [94.2107]]
Epoch 46/100: 100%|██████████| 625/625 [00:11<00:00, 55.17batch/s, loss_px_z: [0.6056], loss_mse_x: [1.3386], loss_py_z: [0.9334], loss_mse_y: [0.9600], loss_pv_z: [104.2481], loss_mse_v: [0.9693], loss_postrior_z: [102.1923]]
Epoch 47/100: 100%|██████████| 625/625 [00:11<00:00, 54.82batch/s, loss_px_z: [0.6723], loss_mse_x: [7.5391], loss_py_z: [0.9108], loss_mse_y: [0.9672], loss_pv_z: [103.6044], loss_mse_v: [0.9630], loss_postrior_z: [102.2846]]
Epoch 48/100: 100%|██████████| 625/625 [00:11<00:00, 54.92batch/s, loss_px_z: [0.2197], loss_mse_x: [0.4434], loss_py_z: [1.0514], loss_mse_y: [1.4217], loss_pv_z: [104.7859], loss_mse_v: [0.9797], loss_postrior_z: [101.7223]]
Epoch 49/100: 100%|██████████| 625/625 [00:11<00:00, 55.71batch/s, loss_px_z: [0.2410], loss_mse_x: [0.4813], loss_py_z: [0.9054], loss_mse_y: [1.0336], loss_pv_z: [103.0146], loss_mse_v: [0.9606], loss_postrior_z: [101.0885]]
Epoch 50/100: 100%|██████████| 625/625 [00:11<00:00, 54.98batch/s, loss_px_z: [0.5846], loss_mse_x: [1.2790], loss_py_z: [0.9430], loss_mse_y: [1.0161], loss_pv_z: [111.7826], loss_mse_v: [1.0438], loss_postrior_z: [110.1667]]
Epoch [50/100]: MSE_x: 2.1777, MSE_y: 1.2043, MSE_v: 0.9657
Epoch 51/100: 100%|██████████| 625/625 [00:11<00:00, 54.96batch/s, loss_px_z: [1.0339], loss_mse_x: [1.5230], loss_py_z: [0.9928], loss_mse_y: [1.1792], loss_pv_z: [104.4264], loss_mse_v: [0.9691], loss_postrior_z: [102.5903]]
Epoch 52/100: 100%|██████████| 625/625 [00:11<00:00, 54.79batch/s, loss_px_z: [0.7486], loss_mse_x: [0.7666], loss_py_z: [1.1842], loss_mse_y: [1.5948], loss_pv_z: [102.5457], loss_mse_v: [0.9553], loss_postrior_z: [101.0967]]
Epoch 53/100: 100%|██████████| 625/625 [00:11<00:00, 54.98batch/s, loss_px_z: [0.3077], loss_mse_x: [1.0589], loss_py_z: [1.0254], loss_mse_y: [1.2029], loss_pv_z: [104.0903], loss_mse_v: [0.9721], loss_postrior_z: [101.5168]]
Epoch 54/100: 100%|██████████| 625/625 [00:11<00:00, 54.34batch/s, loss_px_z: [0.4946], loss_mse_x: [0.9493], loss_py_z: [0.9811], loss_mse_y: [1.1906], loss_pv_z: [103.0164], loss_mse_v: [0.9565], loss_postrior_z: [100.3452]]
Epoch 55/100: 100%|██████████| 625/625 [00:11<00:00, 55.08batch/s, loss_px_z: [0.3406], loss_mse_x: [5.0589], loss_py_z: [0.8992], loss_mse_y: [1.0087], loss_pv_z: [98.9583], loss_mse_v: [0.9173], loss_postrior_z: [97.5815]]
Epoch 56/100: 100%|██████████| 625/625 [00:11<00:00, 54.77batch/s, loss_px_z: [0.3262], loss_mse_x: [0.3477], loss_py_z: [1.0027], loss_mse_y: [1.1875], loss_pv_z: [107.3604], loss_mse_v: [1.0007], loss_postrior_z: [104.1735]]
Epoch 57/100: 100%|██████████| 625/625 [00:11<00:00, 55.67batch/s, loss_px_z: [0.2425], loss_mse_x: [0.5210], loss_py_z: [0.8429], loss_mse_y: [0.8359], loss_pv_z: [105.6930], loss_mse_v: [0.9835], loss_postrior_z: [101.8994]]
Epoch 58/100: 100%|██████████| 625/625 [00:11<00:00, 55.48batch/s, loss_px_z: [0.4562], loss_mse_x: [1.4668], loss_py_z: [0.9041], loss_mse_y: [1.0505], loss_pv_z: [106.8178], loss_mse_v: [0.9959], loss_postrior_z: [105.5796]]
Epoch 59/100: 100%|██████████| 625/625 [00:11<00:00, 54.85batch/s, loss_px_z: [0.1841], loss_mse_x: [0.9394], loss_py_z: [1.0556], loss_mse_y: [1.6084], loss_pv_z: [95.1684], loss_mse_v: [0.8880], loss_postrior_z: [93.4943]]
Epoch 60/100: 100%|██████████| 625/625 [00:11<00:00, 54.79batch/s, loss_px_z: [0.4262], loss_mse_x: [0.3860], loss_py_z: [0.7917], loss_mse_y: [0.8005], loss_pv_z: [104.1415], loss_mse_v: [0.9713], loss_postrior_z: [102.0446]]
Epoch [60/100]: MSE_x: 2.1817, MSE_y: 1.1895, MSE_v: 0.9652
Epoch 61/100: 100%|██████████| 625/625 [00:11<00:00, 55.19batch/s, loss_px_z: [0.3869], loss_mse_x: [0.7310], loss_py_z: [1.0391], loss_mse_y: [1.3771], loss_pv_z: [104.6561], loss_mse_v: [0.9732], loss_postrior_z: [101.3780]]
Epoch 62/100: 100%|██████████| 625/625 [00:11<00:00, 54.76batch/s, loss_px_z: [0.5479], loss_mse_x: [0.6498], loss_py_z: [0.8399], loss_mse_y: [0.8538], loss_pv_z: [101.8596], loss_mse_v: [0.9493], loss_postrior_z: [99.5369]]
Epoch 63/100: 100%|██████████| 625/625 [00:11<00:00, 55.01batch/s, loss_px_z: [0.5373], loss_mse_x: [0.9175], loss_py_z: [0.9829], loss_mse_y: [1.1854], loss_pv_z: [102.4637], loss_mse_v: [0.9533], loss_postrior_z: [100.1760]]
Epoch 64/100: 100%|██████████| 625/625 [00:11<00:00, 55.09batch/s, loss_px_z: [0.2867], loss_mse_x: [1.5419], loss_py_z: [0.9317], loss_mse_y: [1.2228], loss_pv_z: [96.9979], loss_mse_v: [0.9007], loss_postrior_z: [95.0470]]
Epoch 65/100: 100%|██████████| 625/625 [00:11<00:00, 55.44batch/s, loss_px_z: [0.6300], loss_mse_x: [1.4146], loss_py_z: [0.8759], loss_mse_y: [0.9186], loss_pv_z: [107.2320], loss_mse_v: [1.0003], loss_postrior_z: [105.5436]]
Epoch 66/100: 100%|██████████| 625/625 [00:11<00:00, 53.91batch/s, loss_px_z: [0.3859], loss_mse_x: [2.8775], loss_py_z: [0.9805], loss_mse_y: [1.2691], loss_pv_z: [104.1365], loss_mse_v: [0.9740], loss_postrior_z: [104.6624]]
Epoch 67/100: 100%|██████████| 625/625 [00:11<00:00, 55.30batch/s, loss_px_z: [0.5155], loss_mse_x: [2.5749], loss_py_z: [0.8094], loss_mse_y: [0.8627], loss_pv_z: [104.7769], loss_mse_v: [0.9788], loss_postrior_z: [102.0515]]
Epoch 68/100: 100%|██████████| 625/625 [00:11<00:00, 55.20batch/s, loss_px_z: [0.2483], loss_mse_x: [0.5678], loss_py_z: [1.0396], loss_mse_y: [1.2972], loss_pv_z: [104.7215], loss_mse_v: [0.9814], loss_postrior_z: [102.0759]]
Epoch 69/100: 100%|██████████| 625/625 [00:11<00:00, 55.26batch/s, loss_px_z: [1.4218], loss_mse_x: [1.1111], loss_py_z: [0.8525], loss_mse_y: [0.9750], loss_pv_z: [103.8017], loss_mse_v: [0.9686], loss_postrior_z: [102.3827]]
Epoch 70/100: 100%|██████████| 625/625 [00:11<00:00, 54.61batch/s, loss_px_z: [0.2590], loss_mse_x: [0.5932], loss_py_z: [0.8803], loss_mse_y: [1.1457], loss_pv_z: [104.2802], loss_mse_v: [0.9745], loss_postrior_z: [101.8869]]
Epoch [70/100]: MSE_x: 2.0605, MSE_y: 1.2068, MSE_v: 0.9650
Epoch 71/100: 100%|██████████| 625/625 [00:11<00:00, 54.63batch/s, loss_px_z: [0.9122], loss_mse_x: [4.1757], loss_py_z: [0.8866], loss_mse_y: [1.1041], loss_pv_z: [99.0652], loss_mse_v: [0.9213], loss_postrior_z: [98.8926]]
Epoch 72/100: 100%|██████████| 625/625 [00:11<00:00, 55.25batch/s, loss_px_z: [0.7746], loss_mse_x: [1.6122], loss_py_z: [0.7496], loss_mse_y: [0.7539], loss_pv_z: [104.4348], loss_mse_v: [0.9758], loss_postrior_z: [101.8911]]
Epoch 73/100: 100%|██████████| 625/625 [00:11<00:00, 55.42batch/s, loss_px_z: [0.5040], loss_mse_x: [1.4536], loss_py_z: [0.9477], loss_mse_y: [1.1859], loss_pv_z: [103.3225], loss_mse_v: [0.9635], loss_postrior_z: [101.3829]]
Epoch 74/100: 100%|██████████| 625/625 [00:11<00:00, 55.29batch/s, loss_px_z: [0.5007], loss_mse_x: [0.7652], loss_py_z: [0.9774], loss_mse_y: [1.3177], loss_pv_z: [102.7477], loss_mse_v: [0.9562], loss_postrior_z: [100.9115]]
Epoch 75/100: 100%|██████████| 625/625 [00:11<00:00, 55.45batch/s, loss_px_z: [0.6312], loss_mse_x: [0.5428], loss_py_z: [1.0763], loss_mse_y: [1.4158], loss_pv_z: [108.3980], loss_mse_v: [1.0175], loss_postrior_z: [106.3569]]
Epoch 76/100: 100%|██████████| 625/625 [00:11<00:00, 55.03batch/s, loss_px_z: [1.1616], loss_mse_x: [3.6411], loss_py_z: [0.9174], loss_mse_y: [1.1184], loss_pv_z: [101.4809], loss_mse_v: [0.9504], loss_postrior_z: [100.3156]]
Epoch 77/100: 100%|██████████| 625/625 [00:11<00:00, 55.45batch/s, loss_px_z: [0.2270], loss_mse_x: [1.4151], loss_py_z: [0.7666], loss_mse_y: [0.9107], loss_pv_z: [99.8907], loss_mse_v: [0.9310], loss_postrior_z: [97.7918]]
Epoch 78/100: 100%|██████████| 625/625 [00:11<00:00, 55.30batch/s, loss_px_z: [0.5104], loss_mse_x: [2.5735], loss_py_z: [1.0417], loss_mse_y: [1.3955], loss_pv_z: [102.6700], loss_mse_v: [0.9608], loss_postrior_z: [100.8437]]
Epoch 79/100: 100%|██████████| 625/625 [00:11<00:00, 55.60batch/s, loss_px_z: [0.2810], loss_mse_x: [1.1481], loss_py_z: [0.7788], loss_mse_y: [0.8689], loss_pv_z: [99.6064], loss_mse_v: [0.9245], loss_postrior_z: [99.1888]]
Epoch 80/100: 100%|██████████| 625/625 [00:11<00:00, 56.25batch/s, loss_px_z: [0.4661], loss_mse_x: [0.6606], loss_py_z: [0.9857], loss_mse_y: [1.2555], loss_pv_z: [97.2984], loss_mse_v: [0.9066], loss_postrior_z: [95.6442]]
Epoch [80/100]: MSE_x: 2.0278, MSE_y: 1.1949, MSE_v: 0.9645
Epoch 81/100: 100%|██████████| 625/625 [00:11<00:00, 54.68batch/s, loss_px_z: [0.3389], loss_mse_x: [1.4112], loss_py_z: [0.7984], loss_mse_y: [0.9120], loss_pv_z: [106.6279], loss_mse_v: [0.9946], loss_postrior_z: [105.3045]]
Epoch 82/100: 100%|██████████| 625/625 [00:11<00:00, 53.96batch/s, loss_px_z: [0.3469], loss_mse_x: [0.7758], loss_py_z: [0.7624], loss_mse_y: [0.9110], loss_pv_z: [96.9946], loss_mse_v: [0.9044], loss_postrior_z: [94.0614]]
Epoch 83/100: 100%|██████████| 625/625 [00:11<00:00, 53.81batch/s, loss_px_z: [0.2615], loss_mse_x: [0.3087], loss_py_z: [0.9122], loss_mse_y: [1.1561], loss_pv_z: [110.0109], loss_mse_v: [1.0239], loss_postrior_z: [107.2893]]
Epoch 84/100: 100%|██████████| 625/625 [00:11<00:00, 54.81batch/s, loss_px_z: [0.0938], loss_mse_x: [0.2996], loss_py_z: [0.8764], loss_mse_y: [1.0378], loss_pv_z: [104.0820], loss_mse_v: [0.9701], loss_postrior_z: [101.2576]]
Epoch 85/100: 100%|██████████| 625/625 [00:11<00:00, 54.97batch/s, loss_px_z: [0.7172], loss_mse_x: [1.4699], loss_py_z: [1.0360], loss_mse_y: [1.5079], loss_pv_z: [106.3758], loss_mse_v: [0.9988], loss_postrior_z: [104.7337]]
Epoch 86/100: 100%|██████████| 625/625 [00:11<00:00, 55.05batch/s, loss_px_z: [0.1134], loss_mse_x: [3.9313], loss_py_z: [0.8578], loss_mse_y: [1.3222], loss_pv_z: [99.4824], loss_mse_v: [0.9270], loss_postrior_z: [97.2417]]
Epoch 87/100: 100%|██████████| 625/625 [00:11<00:00, 54.46batch/s, loss_px_z: [0.1477], loss_mse_x: [0.4027], loss_py_z: [0.9585], loss_mse_y: [1.6592], loss_pv_z: [105.9451], loss_mse_v: [0.9952], loss_postrior_z: [104.9820]]
Epoch 88/100: 100%|██████████| 625/625 [00:11<00:00, 54.60batch/s, loss_px_z: [0.7088], loss_mse_x: [0.8604], loss_py_z: [0.9679], loss_mse_y: [1.2659], loss_pv_z: [109.0925], loss_mse_v: [1.0223], loss_postrior_z: [106.7311]]
Epoch 89/100: 100%|██████████| 625/625 [00:11<00:00, 54.62batch/s, loss_px_z: [0.3217], loss_mse_x: [0.6296], loss_py_z: [0.7441], loss_mse_y: [0.8663], loss_pv_z: [103.1253], loss_mse_v: [0.9651], loss_postrior_z: [101.4039]]
Epoch 90/100: 100%|██████████| 625/625 [00:11<00:00, 54.76batch/s, loss_px_z: [0.3576], loss_mse_x: [2.2540], loss_py_z: [0.9933], loss_mse_y: [1.3456], loss_pv_z: [103.8477], loss_mse_v: [0.9681], loss_postrior_z: [101.4713]]
Epoch [90/100]: MSE_x: 2.0580, MSE_y: 1.1633, MSE_v: 0.9644
Epoch 91/100: 100%|██████████| 625/625 [00:11<00:00, 54.66batch/s, loss_px_z: [0.1603], loss_mse_x: [0.4676], loss_py_z: [0.7444], loss_mse_y: [0.8241], loss_pv_z: [108.7655], loss_mse_v: [1.0234], loss_postrior_z: [106.3370]]
Epoch 92/100: 100%|██████████| 625/625 [00:11<00:00, 53.44batch/s, loss_px_z: [0.1892], loss_mse_x: [0.7174], loss_py_z: [1.1678], loss_mse_y: [1.8073], loss_pv_z: [113.5816], loss_mse_v: [1.0619], loss_postrior_z: [111.2462]]
Epoch 93/100: 100%|██████████| 625/625 [00:11<00:00, 54.54batch/s, loss_px_z: [0.1294], loss_mse_x: [0.7718], loss_py_z: [0.8597], loss_mse_y: [1.1468], loss_pv_z: [99.1875], loss_mse_v: [0.9276], loss_postrior_z: [97.8237]]
Epoch 94/100: 100%|██████████| 625/625 [00:11<00:00, 54.27batch/s, loss_px_z: [0.8639], loss_mse_x: [1.9431], loss_py_z: [1.0255], loss_mse_y: [1.4076], loss_pv_z: [102.7940], loss_mse_v: [0.9620], loss_postrior_z: [101.0259]]
Epoch 95/100: 100%|██████████| 625/625 [00:11<00:00, 55.64batch/s, loss_px_z: [0.4005], loss_mse_x: [1.7976], loss_py_z: [0.8490], loss_mse_y: [1.1049], loss_pv_z: [100.7964], loss_mse_v: [0.9436], loss_postrior_z: [98.7635]]
Epoch 96/100: 100%|██████████| 625/625 [00:11<00:00, 55.98batch/s, loss_px_z: [0.0117], loss_mse_x: [0.2462], loss_py_z: [0.7910], loss_mse_y: [0.9281], loss_pv_z: [109.1485], loss_mse_v: [1.0273], loss_postrior_z: [106.6896]]
Epoch 97/100: 100%|██████████| 625/625 [00:11<00:00, 52.94batch/s, loss_px_z: [0.2551], loss_mse_x: [0.8334], loss_py_z: [0.9307], loss_mse_y: [1.2324], loss_pv_z: [102.6908], loss_mse_v: [0.9649], loss_postrior_z: [100.4549]]
Epoch 98/100: 100%|██████████| 625/625 [00:11<00:00, 53.89batch/s, loss_px_z: [1.5356], loss_mse_x: [6.9516], loss_py_z: [0.8813], loss_mse_y: [1.1969], loss_pv_z: [103.2901], loss_mse_v: [0.9661], loss_postrior_z: [103.3188]]
Epoch 99/100: 100%|██████████| 625/625 [00:11<00:00, 55.43batch/s, loss_px_z: [0.3029], loss_mse_x: [2.1966], loss_py_z: [0.6606], loss_mse_y: [0.6747], loss_pv_z: [103.4311], loss_mse_v: [0.9697], loss_postrior_z: [101.9436]]
Epoch 100/100: 100%|██████████| 625/625 [00:11<00:00, 55.07batch/s, loss_px_z: [0.1279], loss_mse_x: [0.5865], loss_py_z: [0.7868], loss_mse_y: [0.9442], loss_pv_z: [105.4194], loss_mse_v: [0.9945], loss_postrior_z: [102.6229]]
Epoch [100/100]: MSE_x: 2.0460, MSE_y: 1.1746, MSE_v: 0.9638
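The periodic evaluation lines in the log above (printed every 10 epochs in this run) are easier to compare once they are pulled into a small table. A minimal sketch, assuming the log text has been captured as a string; the regular expression and the `parse_eval_lines` helper are illustrative and not part of bayesgm:

```python
import re

# Evaluation lines copied from the training log above.
log_text = """\
Epoch [50/100]: MSE_x: 2.1777, MSE_y: 1.2043, MSE_v: 0.9657
Epoch [60/100]: MSE_x: 2.1817, MSE_y: 1.1895, MSE_v: 0.9652
Epoch [70/100]: MSE_x: 2.0605, MSE_y: 1.2068, MSE_v: 0.9650
Epoch [80/100]: MSE_x: 2.0278, MSE_y: 1.1949, MSE_v: 0.9645
Epoch [90/100]: MSE_x: 2.0580, MSE_y: 1.1633, MSE_v: 0.9644
Epoch [100/100]: MSE_x: 2.0460, MSE_y: 1.1746, MSE_v: 0.9638
"""

# Matches only the bracketed evaluation lines, not the progress-bar lines.
EVAL_LINE = re.compile(
    r"Epoch \[(\d+)/\d+\]: MSE_x: ([\d.]+), MSE_y: ([\d.]+), MSE_v: ([\d.]+)"
)

def parse_eval_lines(text):
    """Return a list of (epoch, mse_x, mse_y, mse_v) tuples found in the log text."""
    rows = []
    for match in EVAL_LINE.finditer(text):
        epoch, mse_x, mse_y, mse_v = match.groups()
        rows.append((int(epoch), float(mse_x), float(mse_y), float(mse_v)))
    return rows

rows = parse_eval_lines(log_text)
for epoch, mse_x, mse_y, mse_v in rows:
    print(f"epoch {epoch:3d}: MSE_x={mse_x:.4f} MSE_y={mse_y:.4f} MSE_v={mse_v:.4f}")
```

Over these epochs MSE_v drifts slowly from 0.9657 to 0.9638, while MSE_x and MSE_y fluctuate from one evaluation to the next.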
Make predictions using the trained CausalBGM model
Estimate causal effects with posterior intervals from latent MCMC samples.
| Config Parameter | Description |
|---|---|
| data | Tuple of data inputs. |
| alpha | Significance level for the posterior interval. Default: 0.01. |
| n_mcmc | Number of posterior MCMC samples to draw. Default: 3000. |
| burn_in | Number of burn-in MCMC samples before drawing. Default: 5000. |
| x_values | Treatment value(s) at which the dose-response function is predicted. Examples: 1.0 or [1.0, 2.0]. |
| q_sd | Standard deviation of the proposal distribution used in Metropolis-Hastings (MH) sampling. Default: 1.0. |
| | Whether to consider the variance function in the outcome generative model. Default: True. |
| bs | Batch size at the inference stage, i.e. the number of test subjects processed per batch prediction. Default: 10000. |
| Return | Type | Description | Shape |
|---|---|---|---|
| | | Point estimates of the Average Dose-Response Function. | |
| | | Posterior intervals for the ADRF, representing | |
[6]:
pre_adrf_mean, pre_adrf_PI = model.predict(data=(x,y,v), alpha=0.01, n_mcmc=3000, burn_in=5000, x_values=np.linspace(0,3,20), q_sd=1.0, bs=20000)
MCMC Latent Variable Sampling ...
Final MCMC Acceptance Rate: 0.0948
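The arguments `n_mcmc`, `burn_in`, `q_sd`, and `alpha` map onto a standard random-walk Metropolis-Hastings recipe. The sketch below illustrates that recipe on a toy one-dimensional standard normal target. It is not CausalBGM's sampler: `log_target` and `random_walk_mh` are hypothetical names introduced for illustration, and the equal-tailed interval is an assumption, since the tutorial does not say how the interval is formed.

```python
import numpy as np

def log_target(z):
    """Toy log-density (standard normal, up to a constant); a stand-in for a real posterior."""
    return -0.5 * z ** 2

def random_walk_mh(n_mcmc=3000, burn_in=5000, q_sd=1.0, seed=0):
    """Random-walk Metropolis-Hastings with a Gaussian proposal of standard deviation q_sd."""
    rng = np.random.default_rng(seed)
    z = 0.0
    samples = np.empty(n_mcmc)
    n_accept = 0
    for i in range(burn_in + n_mcmc):
        proposal = z + q_sd * rng.standard_normal()
        # Symmetric proposal: accept with probability min(1, target(proposal) / target(z)).
        if np.log(rng.uniform()) < log_target(proposal) - log_target(z):
            z = proposal
            n_accept += 1
        if i >= burn_in:
            samples[i - burn_in] = z  # keep draws only after the burn-in phase
    return samples, n_accept / (burn_in + n_mcmc)

samples, acc_rate = random_walk_mh()
alpha = 0.01
# Equal-tailed (1 - alpha) interval computed from the retained draws.
lower, upper = np.quantile(samples, [alpha / 2, 1 - alpha / 2])
print(f"acceptance rate: {acc_rate:.4f}")
print(f"{100 * (1 - alpha):.0f}% interval: [{lower:.3f}, {upper:.3f}]")
```

In a random-walk sampler, a larger proposal standard deviation generally lowers the acceptance rate, while a smaller one raises it at the cost of slower exploration, so the acceptance rate printed by `predict` is a useful diagnostic when adjusting `q_sd`.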
Visualizing and evaluating the results
Show the true average dose-response function (ADRF) and the predicted ADRF.
[7]:
import matplotlib.pyplot as plt
from bayesgm.utils import get_ADRF
x_values = np.linspace(0,3,20) # treatment values
true_adrf = get_ADRF(x_values = x_values, dataset='Imbens') # true ADRF for Hirano and Imbens dataset
# Evaluate
rmse = np.sqrt(np.mean((true_adrf-pre_adrf_mean)**2))
mape = np.mean(np.abs((true_adrf - pre_adrf_mean) / true_adrf))
print(f"RMSE (Root Mean Squared Error): {rmse:.4f}")
print(f"MAPE (Mean Absolute Percentage Error): {mape:.4f}")
# Create the plot
plt.figure(figsize=(10, 6))
# Plot the ground truth curve
plt.plot(x_values, true_adrf, label='True ADRF', color='blue', linestyle='--', linewidth=2)
# Plot the predicted mean curve
plt.plot(x_values, pre_adrf_mean, label='Predicted ADRF', color='red', linewidth=2)
# Plot the posterior intervals
plt.fill_between(x_values, pre_adrf_PI[:, 0], pre_adrf_PI[:, 1], color='red', alpha=0.4, label='Posterior Interval')
# Add labels, legend, and title
plt.xlabel('Treatment Value', fontsize=12)
plt.ylabel('ADRF', fontsize=12)
plt.title('True vs Predicted ADRF (RMSE=%.4f,MAPE=%.4f)'%(rmse, mape), fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)
# Show the plot
plt.show()
RMSE (Root Mean Squared Error): 0.0188
MAPE (Mean Absolute Percentage Error): 0.0103
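Beyond RMSE and MAPE, one simple additional check is how many of the true ADRF values fall inside the posterior interval. A minimal sketch on small synthetic arrays: `interval_coverage` is a hypothetical helper, and the `(n_points, 2)` layout with lower bounds in column 0 and upper bounds in column 1 is assumed from the `pre_adrf_PI[:, 0]` and `pre_adrf_PI[:, 1]` indexing in the plotting code above.

```python
import numpy as np

def interval_coverage(truth, intervals):
    """Fraction of true values lying inside their [lower, upper] interval."""
    truth = np.asarray(truth, dtype=float)
    intervals = np.asarray(intervals, dtype=float)
    inside = (truth >= intervals[:, 0]) & (truth <= intervals[:, 1])
    return inside.mean()

# Synthetic stand-ins for the true curve and the posterior intervals.
truth = np.array([1.0, 2.0, 3.0, 4.0])
intervals = np.array([[0.8, 1.2], [1.9, 2.3], [3.1, 3.5], [3.7, 4.4]])

coverage = interval_coverage(truth, intervals)
print(f"coverage: {coverage:.2f}")  # three of the four points are covered
```

With the notebook's variables, the equivalent call would be `interval_coverage(true_adrf, pre_adrf_PI)`, provided the arrays have the shapes assumed here.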
Use CausalBGM Python API in binary treatment setting
We will use the ACIC 2018 dataset as an example.
Reminder
Set binary_treatment to True in the config file.
Make sure that v_dim matches the number of covariates in the dataset.
[8]:
with open('src/configs/Semi_acic.yaml', 'r') as f:
    params = yaml.safe_load(f)
print(params)
{'dataset': 'Semi_acic', 'output_dir': '.', 'save_res': True, 'save_model': True, 'binary_treatment': True, 'use_bnn': True, 'z_dims': [3, 6, 3, 6], 'v_dim': 177, 'lr_theta': 0.0001, 'lr_z': 0.0001, 'g_units': [64, 64, 64, 64, 64], 'f_units': [64, 32, 8], 'h_units': [64, 32, 8], 'kl_weight': 0.0001, 'lr': 0.0002, 'g_d_freq': 5, 'use_z_rec': True, 'e_units': [64, 64, 64, 64, 64], 'dz_units': [64, 32, 8]}
Instantiate a CausalBGM model
[9]:
model = bayesgm.models.CausalBGM(params=params,random_seed=None)
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
loc = add_variable_fn(
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
untransformed_scale = add_variable_fn(
Data preparation
The input data are organized as a triplet containing the treatment (X), the potential outcome (Y), and the covariates (V).
[10]:
# Get the data from the ACIC 2018 competition dataset for a specified ufid.
x,y,v = bayesgm.datasets.Semi_acic_sampler(path='data/ACIC_2018',ufid='629e3d2c63914e45b227cc913c09cebe').load_all()
print(x.shape,y.shape,v.shape)
(1000, 1) (1000, 1) (1000, 177)
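The reminder at the start of this section (binary_treatment set to True, v_dim equal to the number of covariates) can be checked programmatically before training. A minimal sketch on synthetic arrays with the shapes printed above: `check_config` is a hypothetical helper, not a bayesgm function, and the 0/1 coding of the binary treatment is an assumption.

```python
import numpy as np

def check_config(params, x, v):
    """Return a list of human-readable problems; an empty list means the checks passed."""
    problems = []
    if params["v_dim"] != v.shape[1]:
        problems.append(
            f"v_dim is {params['v_dim']} but the data has {v.shape[1]} covariates"
        )
    # Assumption: a binary treatment is coded as 0/1.
    is_binary = np.isin(x, [0, 1]).all()
    if params["binary_treatment"] and not is_binary:
        problems.append("binary_treatment is True but x has values other than 0 and 1")
    return problems

# Synthetic stand-ins with the same shapes as the data loaded above.
rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=(1000, 1)).astype(float)
v = rng.normal(size=(1000, 177))

params = {"binary_treatment": True, "v_dim": 177}
print(check_config(params, x, v))  # → []

bad_params = {"binary_treatment": True, "v_dim": 200}
print(check_config(bad_params, x, v))
```

Running such a check right after loading the data surfaces a mismatched v_dim before any time is spent on EGM initialization.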
Model training
Train CausalBGM with an optional EGM warm-start.
| Config Parameter | Description |
|---|---|
| data | Tuple of data inputs. |
| | Batch size for training. Default: 32. |
| epochs | Number of epochs for training. Default: 100. |
| epochs_per_eval | Frequency of evaluations during training (e.g., every 5 epochs). Default: 5. |
| use_egm_init | Whether to run EGM initialization before iterative training. Default: True. |
| | Frequency of evaluations during training (e.g., every 5 epochs). Default: 5. |
| egm_n_iter | Number of EGM initialization iterations. Default: 30000. |
| egm_batches_per_eval | Interval, in iterations, between evaluations during EGM initialization. Default: 500. |
| verbose | Controls verbosity level, showing progress and evaluation metrics. Default: 1. |
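The defaults above determine how much output a run produces. As a worked example, the arithmetic below relates them to the training log for this dataset. It assumes the final partial batch of each epoch is kept, which the tutorial does not state but which is consistent with the `32/32` shown in the progress bars; `batches_per_epoch` is a hypothetical helper.

```python
import math

def batches_per_epoch(n_samples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(n_samples / batch_size)

# 1000 subjects with the default batch size of 32.
n_batches = batches_per_epoch(1000, 32)
print(n_batches)  # → 32

# EGM initialization: 30000 iterations, evaluated every 500 iterations,
# counting both iteration 0 and iteration 30000.
egm_eval_iters = list(range(0, 30000 + 1, 500))
print(len(egm_eval_iters))  # → 61
```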
[11]:
model.fit(data=(x,y,v), epochs=100, epochs_per_eval=10, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, verbose=1)
EGM Initialization Starts ...
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
loc = add_variable_fn(
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
untransformed_scale = add_variable_fn(
EGM Initialization Iter [0] : e_loss_adv [0.1102], l2_loss_v [0.8648], l2_loss_z [1.0305], l2_loss_x [0.6933], l2_loss_y [0.5305], g_e_loss [3.2293], dz_loss [0.0733], d_loss [0.3746]
EGM Initialization Iter [500] : e_loss_adv [2.2379], l2_loss_v [1.0616], l2_loss_z [0.9627], l2_loss_x [0.6788], l2_loss_y [0.1069], g_e_loss [5.0480], dz_loss [-1.7182], d_loss [-1.2625]
EGM Initialization Iter [1000] : e_loss_adv [-0.0245], l2_loss_v [1.4840], l2_loss_z [1.0821], l2_loss_x [0.6829], l2_loss_y [0.0910], g_e_loss [3.3154], dz_loss [-1.3309], d_loss [-1.2576]
EGM Initialization Iter [1500] : e_loss_adv [-1.6068], l2_loss_v [0.4169], l2_loss_z [1.0244], l2_loss_x [0.6923], l2_loss_y [0.0486], g_e_loss [0.5754], dz_loss [-1.0655], d_loss [-0.9715]
EGM Initialization Iter [2000] : e_loss_adv [-1.6149], l2_loss_v [0.6514], l2_loss_z [0.8222], l2_loss_x [0.6996], l2_loss_y [0.1217], g_e_loss [0.6800], dz_loss [-0.8808], d_loss [-0.6829]
EGM Initialization Iter [2500] : e_loss_adv [-1.7912], l2_loss_v [1.6250], l2_loss_z [1.0236], l2_loss_x [0.6911], l2_loss_y [0.1043], g_e_loss [1.6528], dz_loss [-0.5761], d_loss [-0.5062]
EGM Initialization Iter [3000] : e_loss_adv [-2.6811], l2_loss_v [0.6036], l2_loss_z [0.8955], l2_loss_x [0.6970], l2_loss_y [0.1193], g_e_loss [-0.3656], dz_loss [-1.3826], d_loss [-1.1942]
EGM Initialization Iter [3500] : e_loss_adv [-2.8498], l2_loss_v [0.6821], l2_loss_z [0.8155], l2_loss_x [0.6820], l2_loss_y [0.0911], g_e_loss [-0.5792], dz_loss [-0.2488], d_loss [-0.1469]
EGM Initialization Iter [4000] : e_loss_adv [-2.9439], l2_loss_v [0.4767], l2_loss_z [0.9369], l2_loss_x [0.6887], l2_loss_y [0.0918], g_e_loss [-0.7499], dz_loss [-0.6104], d_loss [-0.4594]
EGM Initialization Iter [4500] : e_loss_adv [-2.6417], l2_loss_v [0.7067], l2_loss_z [0.9116], l2_loss_x [0.7060], l2_loss_y [0.1264], g_e_loss [-0.1910], dz_loss [-1.3566], d_loss [-1.0913]
EGM Initialization Iter [5000] : e_loss_adv [-1.0904], l2_loss_v [0.9057], l2_loss_z [0.9335], l2_loss_x [0.7084], l2_loss_y [0.1196], g_e_loss [1.5769], dz_loss [-0.4903], d_loss [-0.3944]
EGM Initialization Iter [5500] : e_loss_adv [-0.8474], l2_loss_v [0.6684], l2_loss_z [0.8527], l2_loss_x [0.6900], l2_loss_y [0.0964], g_e_loss [1.4600], dz_loss [-0.6281], d_loss [-0.5299]
EGM Initialization Iter [6000] : e_loss_adv [-0.8948], l2_loss_v [0.7359], l2_loss_z [0.9014], l2_loss_x [0.6886], l2_loss_y [0.0662], g_e_loss [1.4973], dz_loss [-0.3875], d_loss [-0.2597]
EGM Initialization Iter [6500] : e_loss_adv [-1.1519], l2_loss_v [0.4320], l2_loss_z [0.9488], l2_loss_x [0.6827], l2_loss_y [0.1137], g_e_loss [1.0252], dz_loss [-0.4316], d_loss [-0.3879]
EGM Initialization Iter [7000] : e_loss_adv [-0.1190], l2_loss_v [0.5975], l2_loss_z [0.9308], l2_loss_x [0.6818], l2_loss_y [0.1010], g_e_loss [2.1922], dz_loss [-0.4173], d_loss [-0.3681]
EGM Initialization Iter [7500] : e_loss_adv [-0.0798], l2_loss_v [0.6064], l2_loss_z [0.9438], l2_loss_x [0.6962], l2_loss_y [0.0829], g_e_loss [2.2496], dz_loss [-0.8070], d_loss [-0.7731]
EGM Initialization Iter [8000] : e_loss_adv [2.0215], l2_loss_v [0.4414], l2_loss_z [1.0132], l2_loss_x [0.6978], l2_loss_y [0.0823], g_e_loss [4.2563], dz_loss [-0.0485], d_loss [0.0026]
EGM Initialization Iter [8500] : e_loss_adv [1.7494], l2_loss_v [0.5732], l2_loss_z [0.9103], l2_loss_x [0.6829], l2_loss_y [0.0619], g_e_loss [3.9777], dz_loss [-0.5275], d_loss [-0.4650]
EGM Initialization Iter [9000] : e_loss_adv [-0.1547], l2_loss_v [0.5875], l2_loss_z [0.9173], l2_loss_x [0.6891], l2_loss_y [0.0647], g_e_loss [2.1039], dz_loss [-0.6923], d_loss [-0.6058]
EGM Initialization Iter [9500] : e_loss_adv [-1.8302], l2_loss_v [0.5040], l2_loss_z [0.9089], l2_loss_x [0.6838], l2_loss_y [0.0822], g_e_loss [0.3487], dz_loss [-0.4017], d_loss [-0.3120]
EGM Initialization Iter [10000] : e_loss_adv [-1.1754], l2_loss_v [0.4594], l2_loss_z [0.8607], l2_loss_x [0.6830], l2_loss_y [0.0497], g_e_loss [0.8774], dz_loss [-0.7414], d_loss [-0.6600]
EGM Initialization Iter [10500] : e_loss_adv [0.0125], l2_loss_v [0.4014], l2_loss_z [0.8827], l2_loss_x [0.6823], l2_loss_y [0.0821], g_e_loss [2.0610], dz_loss [-0.4633], d_loss [-0.4381]
EGM Initialization Iter [11000] : e_loss_adv [1.2771], l2_loss_v [0.8815], l2_loss_z [0.7977], l2_loss_x [0.6933], l2_loss_y [0.1165], g_e_loss [3.7661], dz_loss [-0.5535], d_loss [-0.4989]
EGM Initialization Iter [11500] : e_loss_adv [0.8197], l2_loss_v [0.6996], l2_loss_z [0.9713], l2_loss_x [0.6715], l2_loss_y [0.0668], g_e_loss [3.2289], dz_loss [-0.5401], d_loss [-0.4833]
EGM Initialization Iter [12000] : e_loss_adv [1.1846], l2_loss_v [0.4105], l2_loss_z [0.8513], l2_loss_x [0.6902], l2_loss_y [0.0928], g_e_loss [3.2294], dz_loss [-0.5215], d_loss [-0.4562]
EGM Initialization Iter [12500] : e_loss_adv [0.5591], l2_loss_v [0.5461], l2_loss_z [0.8962], l2_loss_x [0.6943], l2_loss_y [0.1157], g_e_loss [2.8115], dz_loss [-0.2863], d_loss [-0.2228]
EGM Initialization Iter [13000] : e_loss_adv [0.5567], l2_loss_v [0.5537], l2_loss_z [0.9396], l2_loss_x [0.6648], l2_loss_y [0.0530], g_e_loss [2.7678], dz_loss [-0.8052], d_loss [-0.7681]
EGM Initialization Iter [13500] : e_loss_adv [-0.4126], l2_loss_v [0.6628], l2_loss_z [0.8893], l2_loss_x [0.7001], l2_loss_y [0.0652], g_e_loss [1.9049], dz_loss [0.1546], d_loss [0.2192]
EGM Initialization Iter [14000] : e_loss_adv [1.3701], l2_loss_v [0.4984], l2_loss_z [0.7432], l2_loss_x [0.6523], l2_loss_y [0.0843], g_e_loss [3.3483], dz_loss [-0.1833], d_loss [-0.1407]
EGM Initialization Iter [14500] : e_loss_adv [0.3630], l2_loss_v [0.5708], l2_loss_z [0.8929], l2_loss_x [0.7033], l2_loss_y [0.0717], g_e_loss [2.6017], dz_loss [-0.1090], d_loss [-0.0825]
EGM Initialization Iter [15000] : e_loss_adv [1.1808], l2_loss_v [0.4881], l2_loss_z [0.7314], l2_loss_x [0.6782], l2_loss_y [0.1120], g_e_loss [3.1905], dz_loss [-0.4781], d_loss [-0.4294]
EGM Initialization Iter [15500] : e_loss_adv [0.7721], l2_loss_v [0.5489], l2_loss_z [0.7299], l2_loss_x [0.7268], l2_loss_y [0.1136], g_e_loss [2.8913], dz_loss [-0.4289], d_loss [-0.3480]
EGM Initialization Iter [16000] : e_loss_adv [0.5866], l2_loss_v [0.3066], l2_loss_z [0.8738], l2_loss_x [0.6267], l2_loss_y [0.0903], g_e_loss [2.4841], dz_loss [-0.5187], d_loss [-0.4643]
EGM Initialization Iter [16500] : e_loss_adv [2.6886], l2_loss_v [0.6556], l2_loss_z [0.9029], l2_loss_x [0.6635], l2_loss_y [0.0617], g_e_loss [4.9723], dz_loss [-0.0725], d_loss [-0.0014]
EGM Initialization Iter [17000] : e_loss_adv [0.6549], l2_loss_v [0.7038], l2_loss_z [0.8277], l2_loss_x [0.6232], l2_loss_y [0.0827], g_e_loss [2.8924], dz_loss [-0.2632], d_loss [-0.1660]
EGM Initialization Iter [17500] : e_loss_adv [-0.3580], l2_loss_v [0.5062], l2_loss_z [0.8819], l2_loss_x [0.6791], l2_loss_y [0.0736], g_e_loss [1.7828], dz_loss [-0.1837], d_loss [-0.0903]
EGM Initialization Iter [18000] : e_loss_adv [-0.6044], l2_loss_v [0.4451], l2_loss_z [0.8348], l2_loss_x [0.7131], l2_loss_y [0.0517], g_e_loss [1.4402], dz_loss [-0.3551], d_loss [-0.2829]
EGM Initialization Iter [18500] : e_loss_adv [-0.0970], l2_loss_v [0.3223], l2_loss_z [0.7733], l2_loss_x [0.6027], l2_loss_y [0.0677], g_e_loss [1.6691], dz_loss [-0.4478], d_loss [-0.3706]
EGM Initialization Iter [19000] : e_loss_adv [-0.3036], l2_loss_v [0.4483], l2_loss_z [0.9154], l2_loss_x [0.7860], l2_loss_y [0.0749], g_e_loss [1.9210], dz_loss [-0.0080], d_loss [0.0634]
EGM Initialization Iter [19500] : e_loss_adv [1.4144], l2_loss_v [0.4755], l2_loss_z [0.8437], l2_loss_x [0.5943], l2_loss_y [0.1192], g_e_loss [3.4471], dz_loss [-0.3931], d_loss [-0.3437]
EGM Initialization Iter [20000] : e_loss_adv [1.5499], l2_loss_v [0.5318], l2_loss_z [0.7672], l2_loss_x [0.6411], l2_loss_y [0.0971], g_e_loss [3.5870], dz_loss [-0.3918], d_loss [-0.3298]
EGM Initialization Iter [20500] : e_loss_adv [0.5439], l2_loss_v [0.4782], l2_loss_z [0.8471], l2_loss_x [0.6008], l2_loss_y [0.0701], g_e_loss [2.5401], dz_loss [-0.2296], d_loss [-0.1909]
EGM Initialization Iter [21000] : e_loss_adv [0.5634], l2_loss_v [0.3514], l2_loss_z [0.8784], l2_loss_x [0.6237], l2_loss_y [0.0578], g_e_loss [2.4749], dz_loss [-0.0582], d_loss [-0.0211]
EGM Initialization Iter [21500] : e_loss_adv [0.1586], l2_loss_v [1.0522], l2_loss_z [0.8501], l2_loss_x [0.5285], l2_loss_y [0.0775], g_e_loss [2.6669], dz_loss [-0.4601], d_loss [-0.3838]
EGM Initialization Iter [22000] : e_loss_adv [0.4936], l2_loss_v [0.4876], l2_loss_z [0.8209], l2_loss_x [0.5598], l2_loss_y [0.0686], g_e_loss [2.4305], dz_loss [-0.5799], d_loss [-0.5045]
EGM Initialization Iter [22500] : e_loss_adv [0.5633], l2_loss_v [0.8955], l2_loss_z [0.9021], l2_loss_x [0.4840], l2_loss_y [0.0793], g_e_loss [2.9243], dz_loss [0.0632], d_loss [0.1201]
EGM Initialization Iter [23000] : e_loss_adv [-0.6131], l2_loss_v [1.0310], l2_loss_z [0.8269], l2_loss_x [0.5019], l2_loss_y [0.0767], g_e_loss [1.8233], dz_loss [-0.4884], d_loss [-0.4198]
EGM Initialization Iter [23500] : e_loss_adv [0.2808], l2_loss_v [0.3826], l2_loss_z [0.7574], l2_loss_x [0.6134], l2_loss_y [0.0854], g_e_loss [2.1196], dz_loss [0.0413], d_loss [0.1513]
EGM Initialization Iter [24000] : e_loss_adv [0.7267], l2_loss_v [0.4576], l2_loss_z [0.7689], l2_loss_x [0.5064], l2_loss_y [0.1214], g_e_loss [2.5811], dz_loss [-0.6373], d_loss [-0.5943]
EGM Initialization Iter [24500] : e_loss_adv [1.3909], l2_loss_v [0.5250], l2_loss_z [0.8024], l2_loss_x [0.4561], l2_loss_y [0.0664], g_e_loss [3.2408], dz_loss [-0.1226], d_loss [-0.0769]
EGM Initialization Iter [25000] : e_loss_adv [-0.9248], l2_loss_v [0.4964], l2_loss_z [0.8490], l2_loss_x [0.5210], l2_loss_y [0.1215], g_e_loss [1.0632], dz_loss [0.2103], d_loss [0.2916]
EGM Initialization Iter [25500] : e_loss_adv [-0.7704], l2_loss_v [0.5056], l2_loss_z [0.7879], l2_loss_x [0.3088], l2_loss_y [0.0530], g_e_loss [0.8849], dz_loss [0.2377], d_loss [0.3105]
EGM Initialization Iter [26000] : e_loss_adv [-0.3391], l2_loss_v [0.4905], l2_loss_z [0.8404], l2_loss_x [0.4700], l2_loss_y [0.0738], g_e_loss [1.5356], dz_loss [-0.0343], d_loss [0.0082]
EGM Initialization Iter [26500] : e_loss_adv [-0.4712], l2_loss_v [0.3574], l2_loss_z [0.7607], l2_loss_x [0.4296], l2_loss_y [0.0830], g_e_loss [1.1594], dz_loss [-0.2168], d_loss [-0.1930]
EGM Initialization Iter [27000] : e_loss_adv [0.7430], l2_loss_v [0.5136], l2_loss_z [0.8916], l2_loss_x [0.4302], l2_loss_y [0.0700], g_e_loss [2.6484], dz_loss [-0.6683], d_loss [-0.5880]
EGM Initialization Iter [27500] : e_loss_adv [0.3668], l2_loss_v [0.4964], l2_loss_z [0.9325], l2_loss_x [0.3674], l2_loss_y [0.0685], g_e_loss [2.2315], dz_loss [0.1753], d_loss [0.2039]
EGM Initialization Iter [28000] : e_loss_adv [-0.9352], l2_loss_v [0.5420], l2_loss_z [0.8413], l2_loss_x [0.4752], l2_loss_y [0.0727], g_e_loss [0.9959], dz_loss [-0.3604], d_loss [-0.2736]
EGM Initialization Iter [28500] : e_loss_adv [-0.2398], l2_loss_v [0.4868], l2_loss_z [0.7556], l2_loss_x [0.4528], l2_loss_y [0.0851], g_e_loss [1.5405], dz_loss [-0.3888], d_loss [-0.2976]
EGM Initialization Iter [29000] : e_loss_adv [-0.6478], l2_loss_v [0.5740], l2_loss_z [0.8240], l2_loss_x [0.3217], l2_loss_y [0.0874], g_e_loss [1.1594], dz_loss [-0.1161], d_loss [-0.0649]
EGM Initialization Iter [29500] : e_loss_adv [-0.2740], l2_loss_v [0.5155], l2_loss_z [0.6885], l2_loss_x [0.2661], l2_loss_y [0.0785], g_e_loss [1.2745], dz_loss [-0.3944], d_loss [-0.3385]
EGM Initialization Iter [30000] : e_loss_adv [-0.7436], l2_loss_v [0.5136], l2_loss_z [0.8334], l2_loss_x [0.1789], l2_loss_y [0.0681], g_e_loss [0.8504], dz_loss [-0.0883], d_loss [-0.0273]
EGM Initialization Ends.
Initialize latent variables Z with e(V)...
Iterative Updating Starts ...
Epoch 0/100: 100%|██████████| 32/32 [00:10<00:00, 3.09batch/s, loss_px_z: [0.9413], loss_mse_x: [0.2514], loss_py_z: [0.6371], loss_mse_y: [0.0882], loss_pv_z: [22.3034], loss_mse_v: [0.3698], loss_postrior_z: [30.3454]]
Epoch [0/100]: MSE_x: 0.0967, MSE_y: 0.0816, MSE_v: 0.5076
Saving checkpoint for epoch 0 at ./checkpoints/Semi_acic/20260311_112445/ckpt-0
Epoch 1/100: 100%|██████████| 32/32 [00:00<00:00, 52.06batch/s, loss_px_z: [1.0814], loss_mse_x: [0.3921], loss_py_z: [0.6501], loss_mse_y: [0.1115], loss_pv_z: [59.3399], loss_mse_v: [0.6598], loss_postrior_z: [65.1993]]
Epoch 2/100: 100%|██████████| 32/32 [00:00<00:00, 52.19batch/s, loss_px_z: [1.1222], loss_mse_x: [0.4334], loss_py_z: [0.5929], loss_mse_y: [0.0376], loss_pv_z: [51.3154], loss_mse_v: [0.6012], loss_postrior_z: [56.4026]]
Epoch 3/100: 100%|██████████| 32/32 [00:00<00:00, 52.17batch/s, loss_px_z: [1.0116], loss_mse_x: [0.3234], loss_py_z: [0.5930], loss_mse_y: [0.0414], loss_pv_z: [24.1223], loss_mse_v: [0.3872], loss_postrior_z: [28.4659]]
Epoch 4/100: 100%|██████████| 32/32 [00:00<00:00, 52.53batch/s, loss_px_z: [0.8006], loss_mse_x: [0.1130], loss_py_z: [0.5958], loss_mse_y: [0.0510], loss_pv_z: [20.5702], loss_mse_v: [0.3585], loss_postrior_z: [34.4404]]
Epoch 5/100: 100%|██████████| 32/32 [00:00<00:00, 51.60batch/s, loss_px_z: [1.1752], loss_mse_x: [0.4883], loss_py_z: [0.6078], loss_mse_y: [0.0713], loss_pv_z: [18.7218], loss_mse_v: [0.3485], loss_postrior_z: [24.4937]]
Epoch 6/100: 100%|██████████| 32/32 [00:00<00:00, 50.58batch/s, loss_px_z: [0.8292], loss_mse_x: [0.1428], loss_py_z: [0.7203], loss_mse_y: [0.2263], loss_pv_z: [15.9238], loss_mse_v: [0.3311], loss_postrior_z: [22.7411]]
Epoch 7/100: 100%|██████████| 32/32 [00:00<00:00, 50.07batch/s, loss_px_z: [1.1182], loss_mse_x: [0.4325], loss_py_z: [0.6544], loss_mse_y: [0.1425], loss_pv_z: [49.3744], loss_mse_v: [0.5830], loss_postrior_z: [58.7593]]
Epoch 8/100: 100%|██████████| 32/32 [00:00<00:00, 53.07batch/s, loss_px_z: [0.9144], loss_mse_x: [0.2292], loss_py_z: [0.5907], loss_mse_y: [0.0616], loss_pv_z: [48.1219], loss_mse_v: [0.5634], loss_postrior_z: [59.8836]]
Epoch 9/100: 100%|██████████| 32/32 [00:00<00:00, 54.61batch/s, loss_px_z: [1.2536], loss_mse_x: [0.5691], loss_py_z: [0.5661], loss_mse_y: [0.0364], loss_pv_z: [10.2657], loss_mse_v: [0.2887], loss_postrior_z: [17.2223]]
Epoch 10/100: 100%|██████████| 32/32 [00:00<00:00, 55.06batch/s, loss_px_z: [0.9793], loss_mse_x: [0.2953], loss_py_z: [0.5569], loss_mse_y: [0.0274], loss_pv_z: [37.2485], loss_mse_v: [0.4839], loss_postrior_z: [44.8024]]
Epoch [10/100]: MSE_x: 0.0956, MSE_y: 0.0811, MSE_v: 0.4902
Saving checkpoint for epoch 10 at ./checkpoints/Semi_acic/20260311_112445/ckpt-10
Epoch 11/100: 100%|██████████| 32/32 [00:00<00:00, 54.95batch/s, loss_px_z: [1.1522], loss_mse_x: [0.4688], loss_py_z: [0.6065], loss_mse_y: [0.0960], loss_pv_z: [2.2391], loss_mse_v: [0.2359], loss_postrior_z: [11.8110]]
Epoch 12/100: 100%|██████████| 32/32 [00:00<00:00, 55.30batch/s, loss_px_z: [0.9200], loss_mse_x: [0.2372], loss_py_z: [0.5623], loss_mse_y: [0.0459], loss_pv_z: [20.5715], loss_mse_v: [0.3698], loss_postrior_z: [22.2090]]
Epoch 13/100: 100%|██████████| 32/32 [00:00<00:00, 59.66batch/s, loss_px_z: [1.0428], loss_mse_x: [0.3606], loss_py_z: [0.5959], loss_mse_y: [0.0940], loss_pv_z: [27.0938], loss_mse_v: [0.4216], loss_postrior_z: [33.3546]]
Epoch 14/100: 100%|██████████| 32/32 [00:00<00:00, 59.17batch/s, loss_px_z: [1.3462], loss_mse_x: [0.6646], loss_py_z: [0.5644], loss_mse_y: [0.0568], loss_pv_z: [34.1856], loss_mse_v: [0.5718], loss_postrior_z: [40.3281]]
Epoch 15/100: 100%|██████████| 32/32 [00:00<00:00, 60.14batch/s, loss_px_z: [0.9129], loss_mse_x: [0.2318], loss_py_z: [0.6260], loss_mse_y: [0.1425], loss_pv_z: [57.6632], loss_mse_v: [0.6218], loss_postrior_z: [65.6941]]
Epoch 16/100: 100%|██████████| 32/32 [00:00<00:00, 60.19batch/s, loss_px_z: [1.4620], loss_mse_x: [0.7816], loss_py_z: [0.5453], loss_mse_y: [0.0444], loss_pv_z: [3.3863], loss_mse_v: [0.2562], loss_postrior_z: [10.9483]]
Epoch 17/100: 100%|██████████| 32/32 [00:00<00:00, 60.11batch/s, loss_px_z: [0.8324], loss_mse_x: [0.1526], loss_py_z: [0.5713], loss_mse_y: [0.0809], loss_pv_z: [31.0689], loss_mse_v: [0.4438], loss_postrior_z: [36.7993]]
Epoch 18/100: 100%|██████████| 32/32 [00:00<00:00, 60.18batch/s, loss_px_z: [0.9134], loss_mse_x: [0.2342], loss_py_z: [0.5573], loss_mse_y: [0.0685], loss_pv_z: [29.7325], loss_mse_v: [0.4428], loss_postrior_z: [37.8426]]
Epoch 19/100: 100%|██████████| 32/32 [00:00<00:00, 59.97batch/s, loss_px_z: [1.0431], loss_mse_x: [0.3645], loss_py_z: [0.6281], loss_mse_y: [0.1604], loss_pv_z: [21.2909], loss_mse_v: [0.3899], loss_postrior_z: [28.2614]]
Epoch 20/100: 100%|██████████| 32/32 [00:00<00:00, 60.11batch/s, loss_px_z: [0.9431], loss_mse_x: [0.2650], loss_py_z: [0.5509], loss_mse_y: [0.0707], loss_pv_z: [20.4603], loss_mse_v: [0.3859], loss_postrior_z: [23.0501]]
Epoch [20/100]: MSE_x: 0.0929, MSE_y: 0.0824, MSE_v: 0.4880
Epoch 21/100: 100%|██████████| 32/32 [00:00<00:00, 60.17batch/s, loss_px_z: [1.0347], loss_mse_x: [0.3573], loss_py_z: [0.6146], loss_mse_y: [0.1555], loss_pv_z: [29.8462], loss_mse_v: [0.4482], loss_postrior_z: [36.4209]]
Epoch 22/100: 100%|██████████| 32/32 [00:00<00:00, 59.91batch/s, loss_px_z: [0.9561], loss_mse_x: [0.2792], loss_py_z: [0.5554], loss_mse_y: [0.0882], loss_pv_z: [1.5015], loss_mse_v: [0.2573], loss_postrior_z: [7.7514]]
Epoch 23/100: 100%|██████████| 32/32 [00:00<00:00, 60.14batch/s, loss_px_z: [0.8235], loss_mse_x: [0.1472], loss_py_z: [0.5874], loss_mse_y: [0.1351], loss_pv_z: [0.5594], loss_mse_v: [0.2621], loss_postrior_z: [6.8897]]
Epoch 24/100: 100%|██████████| 32/32 [00:00<00:00, 60.10batch/s, loss_px_z: [0.7600], loss_mse_x: [0.0843], loss_py_z: [0.5146], loss_mse_y: [0.0505], loss_pv_z: [15.1470], loss_mse_v: [0.3561], loss_postrior_z: [22.3120]]
Epoch 25/100: 100%|██████████| 32/32 [00:00<00:00, 60.14batch/s, loss_px_z: [0.8172], loss_mse_x: [0.1421], loss_py_z: [0.4946], loss_mse_y: [0.0423], loss_pv_z: [35.5047], loss_mse_v: [0.5002], loss_postrior_z: [39.7496]]
Epoch 26/100: 100%|██████████| 32/32 [00:00<00:00, 59.76batch/s, loss_px_z: [0.8243], loss_mse_x: [0.1498], loss_py_z: [0.5031], loss_mse_y: [0.0492], loss_pv_z: [21.2283], loss_mse_v: [0.4085], loss_postrior_z: [29.4189]]
Epoch 27/100: 100%|██████████| 32/32 [00:00<00:00, 60.15batch/s, loss_px_z: [1.2291], loss_mse_x: [0.5552], loss_py_z: [0.5260], loss_mse_y: [0.0858], loss_pv_z: [32.0752], loss_mse_v: [0.4673], loss_postrior_z: [42.3491]]
Epoch 28/100: 100%|██████████| 32/32 [00:00<00:00, 60.16batch/s, loss_px_z: [0.9256], loss_mse_x: [0.2523], loss_py_z: [0.6113], loss_mse_y: [0.1872], loss_pv_z: [11.5990], loss_mse_v: [0.3499], loss_postrior_z: [15.3578]]
Epoch 29/100: 100%|██████████| 32/32 [00:00<00:00, 60.13batch/s, loss_px_z: [0.9034], loss_mse_x: [0.2306], loss_py_z: [0.5699], loss_mse_y: [0.1462], loss_pv_z: [14.3208], loss_mse_v: [0.3664], loss_postrior_z: [21.1504]]
Epoch 30/100: 100%|██████████| 32/32 [00:00<00:00, 59.69batch/s, loss_px_z: [1.3914], loss_mse_x: [0.7193], loss_py_z: [0.4828], loss_mse_y: [0.0625], loss_pv_z: [-7.3091], loss_mse_v: [0.2280], loss_postrior_z: [-2.4749]]
Epoch [30/100]: MSE_x: 0.0939, MSE_y: 0.0844, MSE_v: 0.4857
Epoch 31/100: 100%|██████████| 32/32 [00:00<00:00, 60.18batch/s, loss_px_z: [0.7766], loss_mse_x: [0.1051], loss_py_z: [0.4541], loss_mse_y: [0.0316], loss_pv_z: [65.6081], loss_mse_v: [1.0058], loss_postrior_z: [62.1184]]
Epoch 32/100: 100%|██████████| 32/32 [00:00<00:00, 58.97batch/s, loss_px_z: [1.2586], loss_mse_x: [0.5876], loss_py_z: [0.5416], loss_mse_y: [0.1410], loss_pv_z: [25.5976], loss_mse_v: [0.6107], loss_postrior_z: [31.8609]]
Epoch 33/100: 100%|██████████| 32/32 [00:00<00:00, 60.15batch/s, loss_px_z: [1.1108], loss_mse_x: [0.4405], loss_py_z: [0.4404], loss_mse_y: [0.0340], loss_pv_z: [9.5764], loss_mse_v: [0.3409], loss_postrior_z: [13.5380]]
Epoch 34/100: 100%|██████████| 32/32 [00:00<00:00, 59.43batch/s, loss_px_z: [1.0514], loss_mse_x: [0.3816], loss_py_z: [0.4840], loss_mse_y: [0.0906], loss_pv_z: [-7.3509], loss_mse_v: [0.2674], loss_postrior_z: [-4.4814]]
Epoch 35/100: 100%|██████████| 32/32 [00:00<00:00, 57.56batch/s, loss_px_z: [0.9433], loss_mse_x: [0.2741], loss_py_z: [0.4575], loss_mse_y: [0.0741], loss_pv_z: [29.3482], loss_mse_v: [0.4471], loss_postrior_z: [27.1960]]
Epoch 36/100: 100%|██████████| 32/32 [00:00<00:00, 52.69batch/s, loss_px_z: [0.9184], loss_mse_x: [0.2498], loss_py_z: [0.4759], loss_mse_y: [0.1049], loss_pv_z: [10.3026], loss_mse_v: [0.3483], loss_postrior_z: [17.2433]]
Epoch 37/100: 100%|██████████| 32/32 [00:00<00:00, 52.65batch/s, loss_px_z: [0.9823], loss_mse_x: [0.3142], loss_py_z: [0.3996], loss_mse_y: [0.0382], loss_pv_z: [11.1830], loss_mse_v: [0.3703], loss_postrior_z: [19.4070]]
Epoch 38/100: 100%|██████████| 32/32 [00:00<00:00, 52.67batch/s, loss_px_z: [0.9945], loss_mse_x: [0.3270], loss_py_z: [0.4861], loss_mse_y: [0.1335], loss_pv_z: [-9.3832], loss_mse_v: [0.2701], loss_postrior_z: [-4.6423]]
Epoch 39/100: 100%|██████████| 32/32 [00:00<00:00, 52.46batch/s, loss_px_z: [0.7293], loss_mse_x: [0.0624], loss_py_z: [0.4619], loss_mse_y: [0.1198], loss_pv_z: [22.1069], loss_mse_v: [0.3943], loss_postrior_z: [26.1371]]
Epoch 40/100: 100%|██████████| 32/32 [00:00<00:00, 52.65batch/s, loss_px_z: [0.9726], loss_mse_x: [0.3064], loss_py_z: [0.4586], loss_mse_y: [0.1268], loss_pv_z: [19.7465], loss_mse_v: [0.4129], loss_postrior_z: [31.5037]]
Epoch [40/100]: MSE_x: 0.0933, MSE_y: 0.0893, MSE_v: 0.4900
Epoch 41/100: 100%|██████████| 32/32 [00:00<00:00, 52.70batch/s, loss_px_z: [0.7790], loss_mse_x: [0.1133], loss_py_z: [0.3834], loss_mse_y: [0.0446], loss_pv_z: [1.9458], loss_mse_v: [0.3133], loss_postrior_z: [8.3677]]
Epoch 42/100: 100%|██████████| 32/32 [00:00<00:00, 52.73batch/s, loss_px_z: [0.7895], loss_mse_x: [0.1244], loss_py_z: [0.3713], loss_mse_y: [0.0586], loss_pv_z: [31.2267], loss_mse_v: [0.4730], loss_postrior_z: [35.4569]]
Epoch 43/100: 100%|██████████| 32/32 [00:00<00:00, 52.60batch/s, loss_px_z: [0.7912], loss_mse_x: [0.1266], loss_py_z: [0.3231], loss_mse_y: [0.0287], loss_pv_z: [41.9379], loss_mse_v: [0.5249], loss_postrior_z: [43.0948]]
Epoch 44/100: 100%|██████████| 32/32 [00:00<00:00, 52.66batch/s, loss_px_z: [0.7445], loss_mse_x: [0.0806], loss_py_z: [0.3750], loss_mse_y: [0.0760], loss_pv_z: [21.5167], loss_mse_v: [0.4183], loss_postrior_z: [30.6867]]
Epoch 45/100: 100%|██████████| 32/32 [00:00<00:00, 52.53batch/s, loss_px_z: [0.7780], loss_mse_x: [0.1146], loss_py_z: [0.3586], loss_mse_y: [0.0796], loss_pv_z: [19.9246], loss_mse_v: [0.4149], loss_postrior_z: [21.5443]]
Epoch 46/100: 100%|██████████| 32/32 [00:00<00:00, 52.72batch/s, loss_px_z: [1.0830], loss_mse_x: [0.4202], loss_py_z: [0.2969], loss_mse_y: [0.0634], loss_pv_z: [-9.4088], loss_mse_v: [0.2761], loss_postrior_z: [-3.7582]]
Epoch 47/100: 100%|██████████| 32/32 [00:00<00:00, 52.78batch/s, loss_px_z: [1.0419], loss_mse_x: [0.3797], loss_py_z: [0.3354], loss_mse_y: [0.0947], loss_pv_z: [37.1565], loss_mse_v: [2.0222], loss_postrior_z: [40.7764]]
Epoch 48/100: 100%|██████████| 32/32 [00:00<00:00, 54.16batch/s, loss_px_z: [0.7746], loss_mse_x: [0.1130], loss_py_z: [0.2930], loss_mse_y: [0.0726], loss_pv_z: [7.5576], loss_mse_v: [0.3752], loss_postrior_z: [13.7704]]
Epoch 49/100: 100%|██████████| 32/32 [00:00<00:00, 52.62batch/s, loss_px_z: [0.8423], loss_mse_x: [0.1813], loss_py_z: [0.3029], loss_mse_y: [0.0860], loss_pv_z: [50.3740], loss_mse_v: [0.5826], loss_postrior_z: [42.9421]]
Epoch 50/100: 100%|██████████| 32/32 [00:00<00:00, 53.91batch/s, loss_px_z: [0.9005], loss_mse_x: [0.2401], loss_py_z: [0.2378], loss_mse_y: [0.0572], loss_pv_z: [23.7203], loss_mse_v: [0.4206], loss_postrior_z: [21.5316]]
Epoch [50/100]: MSE_x: 0.0924, MSE_y: 0.0922, MSE_v: 0.4850
Epoch 51/100: 100%|██████████| 32/32 [00:00<00:00, 56.99batch/s, loss_px_z: [0.8310], loss_mse_x: [0.1712], loss_py_z: [0.2651], loss_mse_y: [0.0840], loss_pv_z: [21.3173], loss_mse_v: [1.1722], loss_postrior_z: [32.6034]]
Epoch 52/100: 100%|██████████| 32/32 [00:00<00:00, 56.32batch/s, loss_px_z: [0.9389], loss_mse_x: [0.2797], loss_py_z: [0.2724], loss_mse_y: [0.0895], loss_pv_z: [41.2433], loss_mse_v: [1.2558], loss_postrior_z: [37.6628]]
Epoch 53/100: 100%|██████████| 32/32 [00:00<00:00, 55.43batch/s, loss_px_z: [1.0191], loss_mse_x: [0.3605], loss_py_z: [0.3110], loss_mse_y: [0.1181], loss_pv_z: [-5.5875], loss_mse_v: [0.2971], loss_postrior_z: [-2.6677]]
Epoch 54/100: 100%|██████████| 32/32 [00:00<00:00, 55.05batch/s, loss_px_z: [0.7439], loss_mse_x: [0.0858], loss_py_z: [0.1105], loss_mse_y: [0.0163], loss_pv_z: [20.1353], loss_mse_v: [0.5451], loss_postrior_z: [20.6309]]
Epoch 55/100: 100%|██████████| 32/32 [00:00<00:00, 53.03batch/s, loss_px_z: [0.9376], loss_mse_x: [0.2801], loss_py_z: [0.1811], loss_mse_y: [0.0681], loss_pv_z: [-17.3250], loss_mse_v: [0.2448], loss_postrior_z: [-14.3498]]
Epoch 56/100: 100%|██████████| 32/32 [00:00<00:00, 57.25batch/s, loss_px_z: [1.3045], loss_mse_x: [0.6476], loss_py_z: [0.3238], loss_mse_y: [0.1491], loss_pv_z: [46.6397], loss_mse_v: [0.5431], loss_postrior_z: [48.7620]]
Epoch 57/100: 100%|██████████| 32/32 [00:00<00:00, 57.07batch/s, loss_px_z: [0.8093], loss_mse_x: [0.1530], loss_py_z: [0.2234], loss_mse_y: [0.1067], loss_pv_z: [41.7925], loss_mse_v: [0.5469], loss_postrior_z: [48.3183]]
Epoch 58/100: 100%|██████████| 32/32 [00:00<00:00, 56.74batch/s, loss_px_z: [1.4082], loss_mse_x: [0.7524], loss_py_z: [0.0447], loss_mse_y: [0.0551], loss_pv_z: [-2.0897], loss_mse_v: [0.3089], loss_postrior_z: [1.1457]]
Epoch 59/100: 100%|██████████| 32/32 [00:00<00:00, 55.34batch/s, loss_px_z: [0.8823], loss_mse_x: [0.2271], loss_py_z: [0.2955], loss_mse_y: [0.1463], loss_pv_z: [-17.1023], loss_mse_v: [0.2474], loss_postrior_z: [-8.4271]]
Epoch 60/100: 100%|██████████| 32/32 [00:00<00:00, 57.76batch/s, loss_px_z: [1.3100], loss_mse_x: [0.6554], loss_py_z: [0.0349], loss_mse_y: [0.0532], loss_pv_z: [-9.9317], loss_mse_v: [0.2931], loss_postrior_z: [-3.7395]]
Epoch [60/100]: MSE_x: 0.0928, MSE_y: 0.0886, MSE_v: 0.4847
Epoch 61/100: 100%|██████████| 32/32 [00:00<00:00, 58.63batch/s, loss_px_z: [0.8233], loss_mse_x: [0.1693], loss_py_z: [-0.0091], loss_mse_y: [0.0491], loss_pv_z: [38.3667], loss_mse_v: [0.4997], loss_postrior_z: [45.0812]]
Epoch 62/100: 100%|██████████| 32/32 [00:00<00:00, 56.32batch/s, loss_px_z: [0.7631], loss_mse_x: [0.1096], loss_py_z: [0.2690], loss_mse_y: [0.1389], loss_pv_z: [19.4034], loss_mse_v: [0.4365], loss_postrior_z: [26.6290]]
Epoch 63/100: 100%|██████████| 32/32 [00:00<00:00, 57.21batch/s, loss_px_z: [1.1080], loss_mse_x: [0.4551], loss_py_z: [0.2811], loss_mse_y: [0.1415], loss_pv_z: [80.7834], loss_mse_v: [0.6600], loss_postrior_z: [72.9104]]
Epoch 64/100: 100%|██████████| 32/32 [00:00<00:00, 56.45batch/s, loss_px_z: [0.7215], loss_mse_x: [0.0692], loss_py_z: [0.2883], loss_mse_y: [0.1612], loss_pv_z: [53.9828], loss_mse_v: [0.6072], loss_postrior_z: [59.1240]]
Epoch 65/100: 100%|██████████| 32/32 [00:00<00:00, 53.46batch/s, loss_px_z: [1.5178], loss_mse_x: [0.8660], loss_py_z: [0.2269], loss_mse_y: [0.1412], loss_pv_z: [7.9063], loss_mse_v: [0.4066], loss_postrior_z: [8.8537]]
Epoch 66/100: 100%|██████████| 32/32 [00:00<00:00, 57.96batch/s, loss_px_z: [1.0116], loss_mse_x: [0.3604], loss_py_z: [0.5221], loss_mse_y: [0.2095], loss_pv_z: [12.0556], loss_mse_v: [0.4083], loss_postrior_z: [19.8110]]
Epoch 67/100: 100%|██████████| 32/32 [00:00<00:00, 52.53batch/s, loss_px_z: [0.9474], loss_mse_x: [0.2968], loss_py_z: [0.0703], loss_mse_y: [0.0881], loss_pv_z: [-6.6935], loss_mse_v: [0.2905], loss_postrior_z: [-1.6798]]
Epoch 68/100: 100%|██████████| 32/32 [00:00<00:00, 54.06batch/s, loss_px_z: [0.7904], loss_mse_x: [0.1404], loss_py_z: [-0.2489], loss_mse_y: [0.0387], loss_pv_z: [-25.1200], loss_mse_v: [0.2314], loss_postrior_z: [-16.3646]]
Epoch 69/100: 100%|██████████| 32/32 [00:00<00:00, 52.61batch/s, loss_px_z: [0.7263], loss_mse_x: [0.0768], loss_py_z: [0.1700], loss_mse_y: [0.1200], loss_pv_z: [-3.5599], loss_mse_v: [0.3203], loss_postrior_z: [9.7701]]
Epoch 70/100: 100%|██████████| 32/32 [00:00<00:00, 56.04batch/s, loss_px_z: [0.8032], loss_mse_x: [0.1543], loss_py_z: [-0.1722], loss_mse_y: [0.0521], loss_pv_z: [21.8734], loss_mse_v: [0.4367], loss_postrior_z: [29.7064]]
Epoch [70/100]: MSE_x: 0.0927, MSE_y: 0.0864, MSE_v: 0.4795
Epoch 71/100: 100%|██████████| 32/32 [00:00<00:00, 56.08batch/s, loss_px_z: [0.8890], loss_mse_x: [0.2406], loss_py_z: [-0.1999], loss_mse_y: [0.0420], loss_pv_z: [1.5693], loss_mse_v: [0.4193], loss_postrior_z: [6.5249]]
Epoch 72/100: 100%|██████████| 32/32 [00:00<00:00, 55.01batch/s, loss_px_z: [0.9643], loss_mse_x: [0.3165], loss_py_z: [0.2295], loss_mse_y: [0.1324], loss_pv_z: [25.2323], loss_mse_v: [0.4755], loss_postrior_z: [35.1813]]
Epoch 73/100: 100%|██████████| 32/32 [00:00<00:00, 56.47batch/s, loss_px_z: [0.9932], loss_mse_x: [0.3460], loss_py_z: [0.0042], loss_mse_y: [0.0815], loss_pv_z: [51.5540], loss_mse_v: [0.9088], loss_postrior_z: [57.8552]]
Epoch 74/100: 100%|██████████| 32/32 [00:00<00:00, 54.08batch/s, loss_px_z: [0.8933], loss_mse_x: [0.2466], loss_py_z: [-0.1914], loss_mse_y: [0.0467], loss_pv_z: [6.5270], loss_mse_v: [0.5062], loss_postrior_z: [8.9503]]
Epoch 75/100: 100%|██████████| 32/32 [00:00<00:00, 55.97batch/s, loss_px_z: [1.0264], loss_mse_x: [0.3803], loss_py_z: [-0.2124], loss_mse_y: [0.0410], loss_pv_z: [51.8404], loss_mse_v: [0.7711], loss_postrior_z: [53.4973]]
Epoch 76/100: 100%|██████████| 32/32 [00:00<00:00, 54.47batch/s, loss_px_z: [0.8433], loss_mse_x: [0.1978], loss_py_z: [-0.1416], loss_mse_y: [0.0737], loss_pv_z: [23.4251], loss_mse_v: [0.4216], loss_postrior_z: [25.4289]]
Epoch 77/100: 100%|██████████| 32/32 [00:00<00:00, 56.41batch/s, loss_px_z: [1.0254], loss_mse_x: [0.3805], loss_py_z: [-0.0419], loss_mse_y: [0.0844], loss_pv_z: [89.5222], loss_mse_v: [1.0466], loss_postrior_z: [94.9647]]
Epoch 78/100: 100%|██████████| 32/32 [00:00<00:00, 56.31batch/s, loss_px_z: [0.9360], loss_mse_x: [0.2916], loss_py_z: [0.3582], loss_mse_y: [0.1516], loss_pv_z: [-0.7475], loss_mse_v: [0.3250], loss_postrior_z: [9.3354]]
Epoch 79/100: 100%|██████████| 32/32 [00:00<00:00, 56.92batch/s, loss_px_z: [0.8097], loss_mse_x: [0.1659], loss_py_z: [0.0190], loss_mse_y: [0.1032], loss_pv_z: [23.7452], loss_mse_v: [0.5672], loss_postrior_z: [30.1029]]
Epoch 80/100: 100%|██████████| 32/32 [00:00<00:00, 56.77batch/s, loss_px_z: [1.5119], loss_mse_x: [0.8687], loss_py_z: [-0.2420], loss_mse_y: [0.0538], loss_pv_z: [52.9697], loss_mse_v: [0.7831], loss_postrior_z: [67.9704]]
Epoch [80/100]: MSE_x: 0.0921, MSE_y: 0.0892, MSE_v: 0.4806
Epoch 81/100: 100%|██████████| 32/32 [00:00<00:00, 58.74batch/s, loss_px_z: [1.0592], loss_mse_x: [0.4165], loss_py_z: [0.3977], loss_mse_y: [0.1552], loss_pv_z: [76.5605], loss_mse_v: [0.7894], loss_postrior_z: [88.1919]]
Epoch 82/100: 100%|██████████| 32/32 [00:00<00:00, 55.39batch/s, loss_px_z: [0.8014], loss_mse_x: [0.1592], loss_py_z: [-0.3580], loss_mse_y: [0.0247], loss_pv_z: [-10.8457], loss_mse_v: [0.3003], loss_postrior_z: [-1.7214]]
Epoch 83/100: 100%|██████████| 32/32 [00:00<00:00, 55.04batch/s, loss_px_z: [0.7829], loss_mse_x: [0.1414], loss_py_z: [0.1903], loss_mse_y: [0.1368], loss_pv_z: [5.1346], loss_mse_v: [0.3579], loss_postrior_z: [13.7117]]
Epoch 84/100: 100%|██████████| 32/32 [00:00<00:00, 53.91batch/s, loss_px_z: [0.7963], loss_mse_x: [0.1553], loss_py_z: [-0.1579], loss_mse_y: [0.0651], loss_pv_z: [-43.4583], loss_mse_v: [0.1511], loss_postrior_z: [-41.7073]]
Epoch 85/100: 100%|██████████| 32/32 [00:00<00:00, 55.40batch/s, loss_px_z: [1.7594], loss_mse_x: [1.1190], loss_py_z: [-0.3379], loss_mse_y: [0.0233], loss_pv_z: [7.3098], loss_mse_v: [0.3463], loss_postrior_z: [7.8066]]
Epoch 86/100: 100%|██████████| 32/32 [00:00<00:00, 55.51batch/s, loss_px_z: [0.9384], loss_mse_x: [0.2985], loss_py_z: [-0.0485], loss_mse_y: [0.0784], loss_pv_z: [10.2092], loss_mse_v: [0.3946], loss_postrior_z: [15.5025]]
Epoch 87/100: 100%|██████████| 32/32 [00:00<00:00, 56.19batch/s, loss_px_z: [0.9014], loss_mse_x: [0.2620], loss_py_z: [-0.0807], loss_mse_y: [0.0751], loss_pv_z: [4.4597], loss_mse_v: [0.3666], loss_postrior_z: [16.1913]]
Epoch 88/100: 100%|██████████| 32/32 [00:00<00:00, 54.59batch/s, loss_px_z: [0.8896], loss_mse_x: [0.2508], loss_py_z: [-0.4017], loss_mse_y: [0.0261], loss_pv_z: [-1.6607], loss_mse_v: [0.3429], loss_postrior_z: [9.9890]]
Epoch 89/100: 100%|██████████| 32/32 [00:00<00:00, 55.01batch/s, loss_px_z: [0.9987], loss_mse_x: [0.3604], loss_py_z: [-0.0197], loss_mse_y: [0.0888], loss_pv_z: [2.3344], loss_mse_v: [0.3574], loss_postrior_z: [8.8961]]
Epoch 90/100: 100%|██████████| 32/32 [00:00<00:00, 55.18batch/s, loss_px_z: [0.9082], loss_mse_x: [0.2706], loss_py_z: [0.3003], loss_mse_y: [0.1572], loss_pv_z: [41.4758], loss_mse_v: [0.6483], loss_postrior_z: [42.9172]]
Epoch [90/100]: MSE_x: 0.0929, MSE_y: 0.0851, MSE_v: 0.4804
Epoch 91/100: 100%|██████████| 32/32 [00:00<00:00, 55.70batch/s, loss_px_z: [1.2535], loss_mse_x: [0.6164], loss_py_z: [0.3955], loss_mse_y: [0.1559], loss_pv_z: [27.8524], loss_mse_v: [0.4503], loss_postrior_z: [26.3598]]
Epoch 92/100: 100%|██████████| 32/32 [00:00<00:00, 54.39batch/s, loss_px_z: [0.9255], loss_mse_x: [0.2889], loss_py_z: [-0.2278], loss_mse_y: [0.0506], loss_pv_z: [-30.0789], loss_mse_v: [0.2110], loss_postrior_z: [-27.2653]]
Epoch 93/100: 100%|██████████| 32/32 [00:00<00:00, 55.26batch/s, loss_px_z: [0.8545], loss_mse_x: [0.2185], loss_py_z: [-0.1203], loss_mse_y: [0.0712], loss_pv_z: [-3.8885], loss_mse_v: [0.3201], loss_postrior_z: [-0.4682]]
Epoch 94/100: 100%|██████████| 32/32 [00:00<00:00, 55.40batch/s, loss_px_z: [0.8185], loss_mse_x: [0.1830], loss_py_z: [-0.2654], loss_mse_y: [0.0451], loss_pv_z: [-1.6399], loss_mse_v: [0.3369], loss_postrior_z: [1.0570]]
Epoch 95/100: 100%|██████████| 32/32 [00:00<00:00, 58.24batch/s, loss_px_z: [0.6933], loss_mse_x: [0.0585], loss_py_z: [-0.0570], loss_mse_y: [0.0867], loss_pv_z: [0.5686], loss_mse_v: [0.3402], loss_postrior_z: [5.5315]]
Epoch 96/100: 100%|██████████| 32/32 [00:00<00:00, 60.05batch/s, loss_px_z: [0.8206], loss_mse_x: [0.1863], loss_py_z: [0.3453], loss_mse_y: [0.1475], loss_pv_z: [30.9373], loss_mse_v: [0.5577], loss_postrior_z: [33.8833]]
Epoch 97/100: 100%|██████████| 32/32 [00:00<00:00, 60.05batch/s, loss_px_z: [0.7989], loss_mse_x: [0.1652], loss_py_z: [-0.2076], loss_mse_y: [0.0503], loss_pv_z: [5.4328], loss_mse_v: [0.4984], loss_postrior_z: [11.1171]]
Epoch 98/100: 100%|██████████| 32/32 [00:00<00:00, 59.85batch/s, loss_px_z: [0.6983], loss_mse_x: [0.0651], loss_py_z: [-0.0814], loss_mse_y: [0.0811], loss_pv_z: [24.3846], loss_mse_v: [0.5054], loss_postrior_z: [31.1920]]
Epoch 99/100: 100%|██████████| 32/32 [00:00<00:00, 60.03batch/s, loss_px_z: [1.1036], loss_mse_x: [0.4710], loss_py_z: [-0.2236], loss_mse_y: [0.0453], loss_pv_z: [2.1287], loss_mse_v: [0.3433], loss_postrior_z: [5.4971]]
Epoch 100/100: 100%|██████████| 32/32 [00:00<00:00, 59.90batch/s, loss_px_z: [0.8939], loss_mse_x: [0.2618], loss_py_z: [-0.1434], loss_mse_y: [0.0728], loss_pv_z: [-24.3166], loss_mse_v: [0.2560], loss_postrior_z: [-25.0117]]
Epoch [100/100]: MSE_x: 0.0916, MSE_y: 0.0877, MSE_v: 0.4752
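The periodic evaluation lines in the log above (for example `Epoch [10/100]: MSE_x: ...`) can be collected to check how the reconstruction errors evolve over training. The following is a minimal sketch, assuming the training output has been captured as text; `parse_eval_lines` is a hypothetical helper written for illustration and is not part of bayesgm.

```python
import re

# Pattern for the periodic evaluation lines printed during training,
# e.g. "Epoch [10/100]: MSE_x: 0.0956, MSE_y: 0.0811, MSE_v: 0.4902".
# The per-batch progress-bar lines ("Epoch 2/100: 100%|...") do not match.
EVAL_LINE = re.compile(
    r"Epoch \[(\d+)/\d+\]: MSE_x: ([\d.]+), MSE_y: ([\d.]+), MSE_v: ([\d.]+)"
)

def parse_eval_lines(log_text):
    """Hypothetical helper: return a list of (epoch, mse_x, mse_y, mse_v) tuples."""
    records = []
    for match in EVAL_LINE.finditer(log_text):
        epoch, mse_x, mse_y, mse_v = match.groups()
        records.append((int(epoch), float(mse_x), float(mse_y), float(mse_v)))
    return records

# Three lines copied from the training output above.
log_text = """\
Epoch [10/100]: MSE_x: 0.0956, MSE_y: 0.0811, MSE_v: 0.4902
Saving checkpoint for epoch 10 at ./checkpoints/Semi_acic/20260311_112445/ckpt-10
Epoch [100/100]: MSE_x: 0.0916, MSE_y: 0.0877, MSE_v: 0.4752
"""
records = parse_eval_lines(log_text)
for epoch, mse_x, mse_y, mse_v in records:
    print(epoch, mse_x, mse_y, mse_v)
```

The per-batch losses in the progress bars fluctuate strongly from epoch to epoch, so the evaluation lines are usually the easier signal to track.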
Make predictions using the trained CausalBGM model
Estimate causal effects with posterior intervals from latent MCMC samples.
| Config Parameter | Description |
|---|---|
| data | Tuple of data inputs. |
| alpha | Significance level for the posterior interval. Default: 0.01. |
| n_mcmc | Number of posterior MCMC samples to draw. Default: 3000. |
| burn_in | Number of burn-in MCMC samples before drawing. Default: 5000. |
| x_values | Treatment value(s) for dose-response function to be predicted. Examples: 1.0 or [1.0, 2.0]. |
| q_sd | Standard deviation for the proposal distribution used in Metropolis-Hastings (MH) sampling. Default: 1.0. |
| | Whether to consider the variance function in the outcome generative model. Default: True. |
| bs | Batch size in inference stage, denoting number of test subjects processed per batch prediction. Default: 10000. |
| Return | Type | Description | Shape |
|---|---|---|---|
| | | Point estimates of the Individual Treatment Effect (ITE). | |
| | | Posterior intervals for the ITEs, representing | |
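To make the relationship between `alpha` and the posterior interval concrete, the sketch below computes an equal-tailed (1 - alpha) interval from a matrix of posterior draws. This is an illustration under stated assumptions, not necessarily how bayesgm computes its intervals internally: the function `posterior_interval` and the `(n_mcmc, n_subjects)` sample layout are hypothetical, and the draws are synthetic.

```python
import numpy as np

def posterior_interval(ite_samples, alpha=0.01):
    """Equal-tailed (1 - alpha) interval per subject.

    ite_samples: array of shape (n_mcmc, n_subjects) holding posterior ITE
    draws (an assumed layout, used here for illustration only).
    Returns (mean, interval), where interval has shape (n_subjects, 2).
    """
    ite_samples = np.asarray(ite_samples, dtype=float)
    lower = np.quantile(ite_samples, alpha / 2, axis=0)
    upper = np.quantile(ite_samples, 1 - alpha / 2, axis=0)
    return ite_samples.mean(axis=0), np.stack([lower, upper], axis=1)

rng = np.random.default_rng(0)
# Synthetic draws: 3000 samples for 5 subjects centered at 0, 1, 2, 3, 4.
samples = rng.normal(loc=np.arange(5), scale=1.0, size=(3000, 5))
ite_mean, ite_pi = posterior_interval(samples, alpha=0.01)
print(ite_mean.shape, ite_pi.shape)  # (5,) (5, 2)
```

With `alpha=0.01`, the bounds are the 0.5% and 99.5% quantiles of the draws, so a smaller `alpha` gives a wider interval.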
[12]:
pre_ite_mean, pre_ite_PI = model.predict(data=(x,y,v), alpha=0.01, n_mcmc=3000, burn_in=5000, q_sd=1.0, bs=1000)
MCMC Latent Variable Sampling ...
Final MCMC Acceptance Rate: 0.2116
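The roles of `n_mcmc`, `burn_in`, and `q_sd`, and the meaning of the reported acceptance rate, can be illustrated with a generic random-walk Metropolis-Hastings sampler. The sketch below is not the bayesgm implementation: `random_walk_mh` is a hypothetical function, and the target is a toy two-dimensional standard normal rather than the CausalBGM latent posterior.

```python
import numpy as np

def random_walk_mh(log_prob, z0, n_mcmc=3000, burn_in=5000, q_sd=1.0, seed=0):
    """Random-walk Metropolis-Hastings with a Gaussian proposal N(z, q_sd^2 I).

    Returns the n_mcmc draws kept after burn_in and the overall acceptance rate.
    """
    rng = np.random.default_rng(seed)
    z = np.asarray(z0, dtype=float)
    current_lp = log_prob(z)
    draws, accepted = [], 0
    total = burn_in + n_mcmc
    for step in range(total):
        proposal = z + q_sd * rng.standard_normal(z.shape)
        proposal_lp = log_prob(proposal)
        # Symmetric proposal: accept with probability min(1, p(proposal) / p(z)).
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            z, current_lp = proposal, proposal_lp
            accepted += 1
        # Discard the first burn_in states, keep the rest.
        if step >= burn_in:
            draws.append(z.copy())
    return np.array(draws), accepted / total

def log_prob(z):
    # Toy target: standard normal in 2 dimensions (not the CausalBGM posterior).
    return -0.5 * np.sum(z ** 2)

draws, rate = random_walk_mh(log_prob, z0=np.zeros(2), n_mcmc=3000, burn_in=5000, q_sd=1.0)
print(draws.shape, round(rate, 3))
```

In a sampler of this kind, a larger `q_sd` proposes bigger jumps and typically lowers the acceptance rate, while a smaller `q_sd` raises it at the cost of slower exploration.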
Evaluating the results
Calculate the absolute error of the average treatment effect (\(\epsilon_{ATE}\)) and the precision in estimation of heterogeneous effect (\(\epsilon_{PEHE}\)).
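Written out, the two metrics as computed in the evaluation cell are given below. The symbols are introduced here for illustration: \(\hat{\tau}_i\) is the estimated ITE of subject \(i\) (`pre_ite_mean`), \(\tau_i\) is the ground-truth ITE (`ite_true`), and \(n\) is the number of subjects.

```latex
\[
\epsilon_{ATE} = \left| \frac{1}{n}\sum_{i=1}^{n} \hat{\tau}_i - \frac{1}{n}\sum_{i=1}^{n} \tau_i \right|,
\qquad
\epsilon_{PEHE} = \frac{1}{n}\sum_{i=1}^{n} \left( \hat{\tau}_i - \tau_i \right)^2 .
\]
```

Note that \(\epsilon_{PEHE}\) here is a mean squared error, matching the code; some references report its square root instead.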
[13]:
# Get the ground truth ITE for one simulated ACIC 2018 dataset
ufid = '629e3d2c63914e45b227cc913c09cebe'
covariates_file = 'data/ACIC_2018/x.csv'
df = pd.read_csv(covariates_file, index_col='sample_id', header=0, sep=',')
df_sim = pd.read_csv('data/ACIC_2018/scaling/factuals/%s.csv'%ufid, index_col='sample_id', header=0, sep=',')
# Keep only the samples present in both the covariate file and the factual file
dataset = df.join(df_sim, how='inner')
data_x = dataset['z'].values  # column 'z' of the factual file
data_y = dataset['y'].values  # column 'y' of the factual file
cf_id = ufid + '_cf'
y_true = pd.read_csv('data/ACIC_2018/scaling/counterfactuals/%s.csv'%cf_id, index_col='sample_id', header=0, sep=',')
y_0 = y_true.values[:,0]
y_1 = y_true.values[:,1]
# True ITE: difference between the two columns of the counterfactual file
ite_true = y_1 - y_0
# Evaluate
delta_ate = abs(np.mean(pre_ite_mean) - np.mean(ite_true))
delta_pehe = np.mean((pre_ite_mean - ite_true)**2)
print(f"Delta ATE (Absolute Error in Average Treatment Effect): {delta_ate:.4f}")
print(f"Delta PEHE (Precision in Estimation of Heterogeneous Effect): {delta_pehe:.4f}")
Delta ATE (Absolute Error in Average Treatment Effect): 0.0069
Delta PEHE (Precision in Estimation of Heterogeneous Effect): 0.0001
Use CausalBGM by a command-line interface (CLI)
Installing CausalBGM with pip install bayesgm also provides an independent console program for general use. This has the advantage of being usable from non-Python scripts.
Reminder
Because the current CLI reads its data directly from a single input file (txt, csv, or npz format, as listed in the help message below), training and inference are performed on the same dataset.
Python or R APIs are recommended for more flexible usage.
[1]:
!bayesgm causalbgm -h
2026-03-12 22:10:07.495988: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-03-12 22:10:07.584365: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-03-12 22:10:07.607316: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-03-12 22:10:08.326003: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/ql339/.conda/envs/py3.9/lib/
2026-03-12 22:10:08.326092: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/ql339/.conda/envs/py3.9/lib/
2026-03-12 22:10:08.326098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
usage: bayesgm causalbgm [-h] -o OUTPUT_DIR -i INPUT [-t DELIMITER]
[-d DATASET] [-F SAVE_FORMAT] [-save_model]
[-save_res] [--use_bnn | --no-use_bnn]
[--use_egm_init | --no-use_egm_init] [--seed SEED]
[-B | --binary_treatment | --no-binary_treatment]
[-Z Z_DIMS [Z_DIMS ...]] [--lr_theta LR_THETA]
[--lr_z LR_Z] [--x_min X_MIN] [--x_max X_MAX]
[--x_values X_VALUES [X_VALUES ...]]
[--g_units G_UNITS [G_UNITS ...]]
[--f_units F_UNITS [F_UNITS ...]]
[--h_units H_UNITS [H_UNITS ...]]
[--kl_weight KL_WEIGHT] [--lr LR]
[--g_d_freq G_D_FREQ]
[--e_units E_UNITS [E_UNITS ...]]
[--dz_units DZ_UNITS [DZ_UNITS ...]]
[--use-z-rec | --no-use-z-rec] [-N N_ITER]
[--startoff STARTOFF]
[--batches_per_eval BATCHES_PER_EVAL] [-E EPOCHS]
[-M N_MCMC] [-q Q_SD]
[--epochs_per_eval EPOCHS_PER_EVAL] [--alpha ALPHA]
CausalBGM: An AI-powered Bayesian generative modeling approach for causal
inference in observational studies
optional arguments:
-h, --help show this help message and exit
-o OUTPUT_DIR, --output_dir OUTPUT_DIR
Output directory
-i INPUT, --input INPUT
Input data file must be in csv or txt or npz format
-t DELIMITER, --delimiter DELIMITER
Delimiter for txt or csv files (default: tab '\t').
-d DATASET, --dataset DATASET
Dataset name
-F SAVE_FORMAT, --save_format SAVE_FORMAT
Saving format (default: txt)
-save_model Whether to save model. (default: False)
-save_res Whether to save intermediate results. (default: True)
--use_bnn, --no-use_bnn
Whether use Bayesian neural nets. (default: True)
--use_egm_init, --no-use_egm_init
Whether use EGM initialization. (default: True)
--seed SEED Random seed for reproduction (default: 123).
-B, --binary_treatment, --no-binary_treatment
Whether use binary treatment setting. (default: True)
-Z Z_DIMS [Z_DIMS ...], --z_dims Z_DIMS [Z_DIMS ...]
Latent dimensions of Z (default: [3, 3, 6, 6]).
--lr_theta LR_THETA Learning rate for updating model parameters (default:
0.0001).
--lr_z LR_Z Learning rate for updating latent variables (default:
0.0001).
--x_min X_MIN Lower bound for treatment interval (default: 0.0).
--x_max X_MAX Upper bound for treatment interval (default: 3.0).
--x_values X_VALUES [X_VALUES ...]
List of treatment values to be predicted. Provide
space-separated values. Example: --x_values 0.5 1.0
1.5
--g_units G_UNITS [G_UNITS ...]
Number of units for covariates generative model
(default: [64,64,64,64,64]).
--f_units F_UNITS [F_UNITS ...]
Number of units for outcome generative model (default:
[64,32,8]).
--h_units H_UNITS [H_UNITS ...]
Number of units for treatment generative model
(default: [64,32,8]).
--kl_weight KL_WEIGHT
Coefficient for KL divergence term in BNNs (default:
0.0001).
--lr LR Learning rate for EGM initialization (default:
0.0001).
--g_d_freq G_D_FREQ Frequency for updating discriminators and generators
(default: 5).
--e_units E_UNITS [E_UNITS ...]
Number of units for encoder network (default:
[64,64,64,64,64]).
--dz_units DZ_UNITS [DZ_UNITS ...]
Number of units for discriminator network in latent
space (default: [64,32,8]).
--use-z-rec, --no-use-z-rec
Use the reconstruction for latent features (default:
True). (default: True)
-N N_ITER, --n_iter N_ITER
Number of iterations (default: 30000).
--startoff STARTOFF Iteration for starting evaluation (default: 0).
--batches_per_eval BATCHES_PER_EVAL
Number of iterations per evaluation (default: 500).
-E EPOCHS, --epochs EPOCHS
Number of epochs in iterative updating algorithm
(default: 100).
-M N_MCMC, --n_mcmc N_MCMC
MCMC sample size (default: 3000).
-q Q_SD, --q_sd Q_SD Standard deviation for proposal distribution in MCMC,
a negative q_sd denotes adaptive MCMC (default: 1.0).
--epochs_per_eval EPOCHS_PER_EVAL
Number of epochs per evaluation (default: 10).
--alpha ALPHA Significance level (default: 0.01).
The config parameters are consistent with the Python APIs in the previous sections. Here, we use a demo dataset (continuous treatment setting) as an example.
[2]:
!bayesgm causalbgm -i demo.csv -o ./ -d Demo -N 1000 -E 10 -M 500 -Z 1 1 1 7 --no-binary_treatment --x_values 0 1 2
2025-01-20 11:39:17.090463: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-01-20 11:39:17.213772: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-20 11:39:19.651249: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-01-20 11:39:20.144803: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43430 MB memory: -> device: 0, name: NVIDIA A40, pci bus id: 0000:06:00.0, compute capability: 8.6
/home/users/liuqiao/.local/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
loc = add_variable_fn(
/home/users/liuqiao/.local/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
untransformed_scale = add_variable_fn(
2025-01-20 11:39:21.124243: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
EGM Initialization Starts ...
EGM Initialization Iter [0] : e_loss_adv [0.0104], l2_loss_v [1.0768], l2_loss_z [0.9892], l2_loss_x [0.9474], l2_loss_y [4.9964], g_e_loss [8.0202], dz_loss [0.0822], d_loss [1.8063]
EGM Initialization Iter [500] : e_loss_adv [1.9232], l2_loss_v [1.1323], l2_loss_z [1.2015], l2_loss_x [0.4507], l2_loss_y [2.8286], g_e_loss [7.5364], dz_loss [-0.9738], d_loss [-0.5502]
EGM Initialization Iter [1000] : e_loss_adv [2.0145], l2_loss_v [0.9970], l2_loss_z [1.1143], l2_loss_x [1.5727], l2_loss_y [0.9798], g_e_loss [6.6785], dz_loss [-1.6597], d_loss [-1.3993]
EGM Initialization Ends.
Initialize latent variables Z with e(V)...
Iterative Updating Starts ...
Epoch 0/10: 100%|█| 3/3 [00:09<00:00, 3.25s/batch, loss_px_z: [1.7446], loss_ms
Epoch [0/10]: MSE_x: 1.5652, MSE_y: 1.5891, MSE_v: 0.9300
Epoch 1/10: 100%|█| 3/3 [00:00<00:00, 36.85batch/s, loss_px_z: [1.0154], loss_ms
Epoch 2/10: 100%|█| 3/3 [00:00<00:00, 36.99batch/s, loss_px_z: [1.3450], loss_ms
Epoch 3/10: 100%|█| 3/3 [00:00<00:00, 36.98batch/s, loss_px_z: [1.6738], loss_ms
Epoch 4/10: 100%|█| 3/3 [00:00<00:00, 38.60batch/s, loss_px_z: [1.9682], loss_ms
Epoch 5/10: 100%|█| 3/3 [00:00<00:00, 32.90batch/s, loss_px_z: [2.1426], loss_ms
Epoch 6/10: 100%|█| 3/3 [00:00<00:00, 40.06batch/s, loss_px_z: [0.8936], loss_ms
Epoch 7/10: 100%|█| 3/3 [00:00<00:00, 39.97batch/s, loss_px_z: [1.1379], loss_ms
Epoch 8/10: 100%|█| 3/3 [00:00<00:00, 38.26batch/s, loss_px_z: [0.9878], loss_ms
Epoch 9/10: 100%|█| 3/3 [00:00<00:00, 37.52batch/s, loss_px_z: [0.6265], loss_ms
Epoch 10/10: 100%|█| 3/3 [00:00<00:00, 36.10batch/s, loss_px_z: [1.0492], loss_m
Epoch [10/10]: MSE_x: 1.5488, MSE_y: 1.5971, MSE_v: 0.9248
MCMC Latent Variable Sampling ...
Final MCMC Acceptance Rate: 0.1590
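The per-epoch summary lines and the final acceptance rate in the log above can be collected programmatically when comparing runs. The following is a minimal sketch, assuming the log text has been captured to a string and that the line formats stay exactly as shown above. The helper name `parse_training_log` is hypothetical and is not part of bayesgm.

```python
import re

# Sample lines copied from the training log above.
LOG = """\
Epoch 0/10: 100%|...| 3/3 [00:09<00:00, 3.25s/batch, loss_px_z: [1.7446], loss_ms
Epoch [0/10]: MSE_x: 1.5652, MSE_y: 1.5891, MSE_v: 0.9300
Epoch [10/10]: MSE_x: 1.5488, MSE_y: 1.5971, MSE_v: 0.9248
Final MCMC Acceptance Rate: 0.1590
"""

# Summary lines use "Epoch [i/n]:"; progress-bar lines use "Epoch i/n:" and are skipped.
EPOCH_RE = re.compile(r"^Epoch \[(\d+)/(\d+)\]: (.*)$")
METRIC_RE = re.compile(r"(\w+): ([-+]?\d+(?:\.\d+)?)")
ACCEPT_RE = re.compile(r"^Final MCMC Acceptance Rate: ([-+]?\d+(?:\.\d+)?)$")


def parse_training_log(text):
    """Return ({epoch: {metric: value}}, acceptance_rate or None).

    Hypothetical helper for illustration; it only understands the two
    line formats shown in the log above.
    """
    epochs = {}
    acceptance_rate = None
    for raw_line in text.splitlines():
        line = raw_line.strip()
        match = EPOCH_RE.match(line)
        if match:
            metrics = {
                name: float(value)
                for name, value in METRIC_RE.findall(match.group(3))
            }
            epochs[int(match.group(1))] = metrics
            continue
        match = ACCEPT_RE.match(line)
        if match:
            acceptance_rate = float(match.group(1))
    return epochs, acceptance_rate


epochs, acceptance_rate = parse_training_log(LOG)
for epoch in sorted(epochs):
    print(epoch, epochs[epoch])
print("acceptance rate:", acceptance_rate)
```

Keying the result by epoch number makes it straightforward to load into a `pandas.DataFrame` (for example with `pd.DataFrame.from_dict(epochs, orient="index")`) and track how `MSE_x`, `MSE_y`, and `MSE_v` change between the first and last evaluated epochs.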