Tutorial for Python Users

In this tutorial, we will go through the CausalBGM workflow in both the continuous treatment setting and the binary treatment setting.

After installation, CausalBGM can be used either through its Python API or with a single command line.

First, install the Python toolkit bayesgm; please refer to the install page.

Notes

  • We developed the Python toolkit bayesgm for AI-powered Bayesian generative modeling approaches. CausalBGM is one of the models in this package.

  • Usage of the CausalBGM APIs is detailed on the API page of the online documentation.

[1]:
import numpy as np
import pandas as pd
import yaml
import bayesgm
print("Currently using version %s of bayesgm." % bayesgm.__version__)
Currently using version 1.0.1 of bayesgm.

Using the CausalBGM Python API in the continuous treatment setting

We will use the Hirano and Imbens dataset as an example.

CausalBGM Configuration Parameters

Before creating a CausalBGM model, create a Python dict that holds its configuration parameters. These parameters are described below.

General Parameters

Config Parameter    Description
dataset             Dataset name to indicate the input data. Default: 'Sim_Hirano_Imbens'.
output_dir          Output directory to save the results during model training. Default: '.'.
save_res            Whether to save intermediate results. Default: True.
save_model          Whether to save the model after training. Default: False.
binary_treatment    Whether to use binary treatment settings. Default: False.
use_bnn             Whether to use Bayesian neural networks. Default: True.

Parameters for Iterative Updating Algorithm

Config Parameter    Description
z_dims              Dimensions of the latent components of Z. Default: [1, 1, 1, 7].
v_dim               Dimension of the covariates V. Default: 200.
lr_theta            Learning rate for updating model parameters. Default: 0.0001.
lr_z                Learning rate for updating latent variables. Default: 0.0001.
g_units             Number of units in each hidden layer of the covariate generative model. Default: [64, 64, 64, 64, 64].
f_units             Number of units in each hidden layer of the outcome generative model. Default: [64, 32, 8].
h_units             Number of units in each hidden layer of the treatment generative model. Default: [64, 32, 8].
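The sketch below makes the list-valued unit settings concrete. It counts the weights and biases of a plain fully connected network whose hidden layers have the sizes given by g_units. For illustration only, it assumes that the covariate model maps the full latent vector (10 = sum of the default z_dims) to the 200 covariates (v_dim). The actual wiring inside bayesgm may differ. With use_bnn=True, each Bayesian layer also carries extra variational parameters that this count ignores. The helper dense_param_count is hypothetical and not part of bayesgm.

```python
def dense_param_count(in_dim, hidden_units, out_dim):
    """Count weights + biases of a plain MLP (hypothetical helper, not part of bayesgm)."""
    sizes = [in_dim] + list(hidden_units) + [out_dim]
    # Each dense layer has fan_in * fan_out weights plus fan_out biases.
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))

z_dims = [1, 1, 1, 7]            # default from the table above
g_units = [64, 64, 64, 64, 64]   # default from the table above
v_dim = 200                      # default from the table above

# Assumption: the covariate model maps all of Z (dimension sum(z_dims)) to V.
print(dense_param_count(sum(z_dims), g_units, v_dim))  # 30344
```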

Parameters for EGM Initialization

Config Parameter    Description
kl_weight           Coefficient for the KL divergence term in BNNs. Default: 0.0001.
lr                  Learning rate for EGM initialization. Default: 0.0002.
g_d_freq            Frequency for updating discriminators and generators. Default: 5.
use_z_rec           Whether to use reconstruction for latent features. Default: True.
e_units             Number of units in each hidden layer of the encoder network. Default: [64, 64, 64, 64, 64].
dz_units            Number of units in each hidden layer of the latent-space discriminator network. Default: [64, 32, 8].
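If you prefer to build the config in code rather than load it from a YAML file, here is a minimal sketch. The key names and default values are copied from the tables above. make_config is a hypothetical convenience helper, not part of bayesgm; it rejects misspelled keys before they reach the model. The resulting dict would then be passed as params to bayesgm.models.CausalBGM.

```python
import copy

# Defaults as listed in the configuration tables above.
DEFAULT_PARAMS = {
    'dataset': 'Sim_Hirano_Imbens', 'output_dir': '.', 'save_res': True,
    'save_model': False, 'binary_treatment': False, 'use_bnn': True,
    'z_dims': [1, 1, 1, 7], 'v_dim': 200, 'lr_theta': 0.0001, 'lr_z': 0.0001,
    'g_units': [64, 64, 64, 64, 64], 'f_units': [64, 32, 8], 'h_units': [64, 32, 8],
    'kl_weight': 0.0001, 'lr': 0.0002, 'g_d_freq': 5, 'use_z_rec': True,
    'e_units': [64, 64, 64, 64, 64], 'dz_units': [64, 32, 8],
}

def make_config(**overrides):
    """Return a fresh config dict with overrides applied (hypothetical helper)."""
    unknown = set(overrides) - set(DEFAULT_PARAMS)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    params = copy.deepcopy(DEFAULT_PARAMS)  # deep copy so list values are not shared
    params.update(overrides)
    return params

params = make_config(output_dir='./results', save_model=True)
```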

Notes

  • Example config files for the datasets used in the paper are provided in the src/configs folder.

Loading config parameters

[2]:
with open('src/configs/Sim_Hirano_Imbens.yaml', 'r') as f:
    params = yaml.safe_load(f)
print(params)
{'dataset': 'Sim_Hirano_Imbens', 'output_dir': '.', 'save_res': True, 'save_model': False, 'binary_treatment': False, 'use_bnn': True, 'z_dims': [1, 1, 1, 7], 'v_dim': 200, 'lr_theta': 0.0001, 'lr_z': 0.0001, 'g_units': [64, 64, 64, 64, 64], 'f_units': [64, 32, 8], 'h_units': [64, 32, 8], 'kl_weight': 0.0001, 'lr': 0.0002, 'g_d_freq': 5, 'use_z_rec': True, 'e_units': [64, 64, 64, 64, 64], 'dz_units': [64, 32, 8]}

Instantiate a CausalBGM model

[3]:
model = bayesgm.models.CausalBGM(params=params, random_seed=None)

Data preparation

The input data are organized as a triplet (x, y, v) containing the treatment (X), the observed outcome (Y), and the covariates (V).

[4]:
from bayesgm.datasets import Sim_Hirano_Imbens_sampler
x,y,v = Sim_Hirano_Imbens_sampler(N=20000, v_dim=200).load_all()
print(x.shape,y.shape,v.shape)
(20000, 1) (20000, 1) (20000, 200)
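When you bring your own data instead of Sim_Hirano_Imbens_sampler, check that the arrays follow the same layout as the shapes printed above: one row per sample, x and y as single columns, and v with v_dim columns. The helper below is a hypothetical sketch, not part of bayesgm. It assumes that the number of covariate columns must equal the v_dim config entry; please confirm this on the API page.

```python
import numpy as np

def check_triplet(x, y, v, v_dim):
    """Sanity-check an (x, y, v) triplet before calling model.fit (hypothetical helper).

    Assumes the layout used in this tutorial: 2-D arrays with one row per sample,
    a single treatment column, a single outcome column, and v_dim covariate columns.
    """
    x, y, v = (np.asarray(a) for a in (x, y, v))
    for name, arr in (('x', x), ('y', y), ('v', v)):
        if arr.ndim != 2:
            raise ValueError(f"{name} must be 2-D, got shape {arr.shape}")
    n = x.shape[0]
    if y.shape[0] != n or v.shape[0] != n:
        raise ValueError("x, y and v must have the same number of rows")
    if x.shape[1] != 1 or y.shape[1] != 1:
        raise ValueError("x and y are expected to have a single column")
    if v.shape[1] != v_dim:
        raise ValueError(f"v has {v.shape[1]} columns but v_dim is {v_dim}")
    return x, y, v
```

A 1-D treatment or outcome vector can be turned into a single column with x.reshape(-1, 1) before calling the check.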

Model training

Train CausalBGM with model.fit, optionally preceded by an EGM warm-start. Its arguments are listed below.

Argument              Description
data                  Tuple of data arrays (x, y, v). Required.
batch_size            Batch size for training. Default: 32.
epochs                Number of epochs for training. Default: 100.
epochs_per_eval       Frequency of evaluations during training (e.g., every 5 epochs). Default: 5.
use_egm_init          Whether to run EGM initialization before iterative training. Default: True.
egm_n_iter            Number of EGM initialization iterations. Default: 30000.
egm_batches_per_eval  Number of iterations between evaluations during EGM initialization. Default: 500.
verbose               Controls the verbosity level, showing progress and evaluation metrics. Default: 1.
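You can estimate the length of a run by working out the schedule from these arguments. The fit call below leaves batch_size at its default of 32. With 20000 samples, each epoch therefore has 20000 / 32 = 625 minibatch updates, which matches the 625/625 progress bars in the log. The EGM log is printed at iterations 0, 500, and so on up to 30000. The helper training_schedule is hypothetical. It assumes ceiling rounding when n_samples is not a multiple of batch_size; the source does not confirm this detail.

```python
import math

def training_schedule(n_samples, batch_size=32, epochs=100,
                      egm_n_iter=30000, egm_batches_per_eval=500):
    """Rough counts implied by fit() arguments (hypothetical helper, not part of bayesgm)."""
    # Assumption: a final partial batch is still processed, hence the ceiling.
    batches_per_epoch = math.ceil(n_samples / batch_size)
    return {
        "batches_per_epoch": batches_per_epoch,
        "iterative_updates": batches_per_epoch * epochs,
        # EGM progress lines appear at iteration 0 and every egm_batches_per_eval iterations.
        "egm_log_iters": list(range(0, egm_n_iter + 1, egm_batches_per_eval)),
    }

schedule = training_schedule(n_samples=20000)
print(schedule["batches_per_epoch"], schedule["iterative_updates"],
      len(schedule["egm_log_iters"]))  # 625 62500 61
```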

Notes

The training procedure consists of two phases:

  • EGM initialization (optional): a warm-start that provides a good starting point for the latent variables and model parameters. It can be skipped by setting use_egm_init to False.

  • Stochastic iterative updating: alternates between updating the generator network parameters and the per-sample latent variables via stochastic gradient optimization.
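The sketch below illustrates the alternating scheme of the second phase with a deliberately tiny analogue; it is not CausalBGM itself. The model is a one-parameter linear "generator" v_i ≈ theta * z_i with Gaussian noise of variance sigma2 and a standard normal prior on each per-sample latent z_i. On each minibatch, it first takes a gradient step on theta with the latents held fixed. It then takes a gradient step on that batch's latents with theta held fixed. The model, the penalty theta**2 / 2 (added so the toy problem does not drift toward ever-larger theta), sigma2, and the learning rates are all illustrative assumptions. CausalBGM uses neural networks and its own update rules.

```python
import numpy as np

def alternating_map(v, n_epochs=50, batch_size=32, lr_theta=0.01, lr_z=0.05,
                    sigma2=0.25, seed=0):
    """Toy alternating updates for v_i ~ N(theta * z_i, sigma2), z_i ~ N(0, 1)."""
    rng = np.random.default_rng(seed)
    n = v.shape[0]
    theta = 0.1                 # shared "generator" parameter
    z = rng.normal(size=n)      # one latent variable per sample

    def objective():
        # Mean per-sample negative log joint (up to constants) plus a penalty on theta.
        fit = (v - theta * z) ** 2 / (2 * sigma2) + z ** 2 / 2
        return float(np.mean(fit) + theta ** 2 / 2)

    history = [objective()]
    for _ in range(n_epochs):
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            b = order[start:start + batch_size]
            # 1) Gradient step on theta with the batch's latents held fixed.
            resid = v[b] - theta * z[b]
            theta -= lr_theta * (np.mean(-resid * z[b]) / sigma2 + theta)
            # 2) Gradient step on the batch's latents with theta held fixed.
            resid = v[b] - theta * z[b]
            z[b] -= lr_z * (-resid * theta / sigma2 + z[b])
        history.append(objective())
    return theta, z, history

if __name__ == "__main__":
    rng = np.random.default_rng(1)
    v = 2.0 * rng.normal(size=1000) + 0.5 * rng.normal(size=1000)  # toy data
    theta, z, history = alternating_map(v)
    print(f"objective: {history[0]:.3f} -> {history[-1]:.3f}, theta = {theta:.3f}")
```

Keeping the two updates separate mirrors the structure described above: shared parameters and per-sample latents are optimized in turns rather than jointly.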

[5]:
model.fit(data=(x,y,v), epochs=100, epochs_per_eval=10, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, verbose=1)
EGM Initialization Starts ...
EGM Initialization Iter [0] : e_loss_adv [0.2336], l2_loss_v [1.0228], l2_loss_z [0.9938], l2_loss_x [1.0393], l2_loss_y [4.4029], g_e_loss [7.6924], dz_loss [0.0516], d_loss [0.3271]
EGM Initialization Iter [500] : e_loss_adv [1.7835], l2_loss_v [1.0174], l2_loss_z [0.8902], l2_loss_x [11.1814], l2_loss_y [12.4417], g_e_loss [27.3142], dz_loss [-1.1979], d_loss [-0.7519]
EGM Initialization Iter [1000] : e_loss_adv [0.2142], l2_loss_v [1.0167], l2_loss_z [0.9186], l2_loss_x [0.4651], l2_loss_y [1.7649], g_e_loss [4.3797], dz_loss [-1.5703], d_loss [-1.4926]
EGM Initialization Iter [1500] : e_loss_adv [0.6016], l2_loss_v [1.0229], l2_loss_z [0.8104], l2_loss_x [1.7311], l2_loss_y [1.1285], g_e_loss [5.2944], dz_loss [-1.4662], d_loss [-1.1440]
EGM Initialization Iter [2000] : e_loss_adv [-0.2172], l2_loss_v [0.9818], l2_loss_z [0.7609], l2_loss_x [3.8512], l2_loss_y [1.3351], g_e_loss [6.7121], dz_loss [-0.7049], d_loss [-0.5723]
EGM Initialization Iter [2500] : e_loss_adv [-0.6869], l2_loss_v [0.9955], l2_loss_z [0.7060], l2_loss_x [0.5864], l2_loss_y [1.4249], g_e_loss [3.0260], dz_loss [-0.7184], d_loss [-0.5650]
EGM Initialization Iter [3000] : e_loss_adv [-0.9327], l2_loss_v [0.9780], l2_loss_z [0.7114], l2_loss_x [1.4311], l2_loss_y [1.4923], g_e_loss [3.6802], dz_loss [-0.4438], d_loss [-0.2919]
EGM Initialization Iter [3500] : e_loss_adv [-1.2839], l2_loss_v [1.0583], l2_loss_z [0.7365], l2_loss_x [2.3093], l2_loss_y [1.9179], g_e_loss [4.7381], dz_loss [-0.8263], d_loss [-0.7364]
EGM Initialization Iter [4000] : e_loss_adv [-1.5408], l2_loss_v [1.0294], l2_loss_z [0.6064], l2_loss_x [4.0206], l2_loss_y [1.4722], g_e_loss [5.5880], dz_loss [-0.8158], d_loss [-0.6767]
EGM Initialization Iter [4500] : e_loss_adv [-1.0226], l2_loss_v [0.9935], l2_loss_z [0.6432], l2_loss_x [1.1369], l2_loss_y [1.1873], g_e_loss [2.9384], dz_loss [-0.7635], d_loss [-0.6356]
EGM Initialization Iter [5000] : e_loss_adv [-1.3171], l2_loss_v [1.0850], l2_loss_z [0.5702], l2_loss_x [0.8795], l2_loss_y [1.1124], g_e_loss [2.3301], dz_loss [-0.7465], d_loss [-0.6473]
EGM Initialization Iter [5500] : e_loss_adv [-1.3136], l2_loss_v [1.0697], l2_loss_z [0.6109], l2_loss_x [3.3471], l2_loss_y [1.4369], g_e_loss [5.1511], dz_loss [-0.7818], d_loss [-0.7041]
EGM Initialization Iter [6000] : e_loss_adv [-1.5287], l2_loss_v [1.0135], l2_loss_z [0.5739], l2_loss_x [3.3765], l2_loss_y [2.4388], g_e_loss [5.8741], dz_loss [-0.4657], d_loss [-0.3500]
EGM Initialization Iter [6500] : e_loss_adv [-0.9883], l2_loss_v [1.0253], l2_loss_z [0.4936], l2_loss_x [12.1984], l2_loss_y [1.8050], g_e_loss [14.5340], dz_loss [-1.0290], d_loss [-0.8756]
EGM Initialization Iter [7000] : e_loss_adv [-0.9280], l2_loss_v [1.0170], l2_loss_z [0.5469], l2_loss_x [0.7295], l2_loss_y [1.5921], g_e_loss [2.9574], dz_loss [-1.1066], d_loss [-0.9553]
EGM Initialization Iter [7500] : e_loss_adv [-1.3834], l2_loss_v [0.9478], l2_loss_z [0.5085], l2_loss_x [1.4535], l2_loss_y [0.8933], g_e_loss [2.4197], dz_loss [-1.1365], d_loss [-0.9165]
EGM Initialization Iter [8000] : e_loss_adv [-1.1943], l2_loss_v [1.0483], l2_loss_z [0.4082], l2_loss_x [1.5088], l2_loss_y [1.2450], g_e_loss [3.0161], dz_loss [-0.6851], d_loss [-0.5403]
EGM Initialization Iter [8500] : e_loss_adv [-0.9068], l2_loss_v [0.9925], l2_loss_z [0.5008], l2_loss_x [2.5160], l2_loss_y [1.6454], g_e_loss [4.7481], dz_loss [-0.4941], d_loss [-0.3221]
EGM Initialization Iter [9000] : e_loss_adv [-1.1367], l2_loss_v [1.0227], l2_loss_z [0.4641], l2_loss_x [0.8961], l2_loss_y [1.0888], g_e_loss [2.3350], dz_loss [-1.1502], d_loss [-1.0479]
EGM Initialization Iter [9500] : e_loss_adv [-0.4581], l2_loss_v [1.0066], l2_loss_z [0.4015], l2_loss_x [4.5273], l2_loss_y [1.1357], g_e_loss [6.6130], dz_loss [-0.4729], d_loss [-0.2438]
EGM Initialization Iter [10000] : e_loss_adv [-0.4825], l2_loss_v [1.0119], l2_loss_z [0.3976], l2_loss_x [4.9266], l2_loss_y [1.6185], g_e_loss [7.4722], dz_loss [-0.9087], d_loss [-0.8407]
EGM Initialization Iter [10500] : e_loss_adv [-0.3622], l2_loss_v [0.9236], l2_loss_z [0.3934], l2_loss_x [1.1994], l2_loss_y [0.8895], g_e_loss [3.0437], dz_loss [-0.5223], d_loss [-0.3849]
EGM Initialization Iter [11000] : e_loss_adv [-0.6587], l2_loss_v [1.0022], l2_loss_z [0.3753], l2_loss_x [0.6290], l2_loss_y [2.2966], g_e_loss [3.6443], dz_loss [-0.3311], d_loss [-0.2351]
EGM Initialization Iter [11500] : e_loss_adv [-0.8142], l2_loss_v [0.9219], l2_loss_z [0.3577], l2_loss_x [0.4504], l2_loss_y [0.6233], g_e_loss [1.5392], dz_loss [-0.5330], d_loss [-0.4307]
EGM Initialization Iter [12000] : e_loss_adv [-0.1782], l2_loss_v [0.9856], l2_loss_z [0.4137], l2_loss_x [17.7432], l2_loss_y [1.1676], g_e_loss [20.1319], dz_loss [-0.6628], d_loss [-0.5342]
EGM Initialization Iter [12500] : e_loss_adv [-0.7103], l2_loss_v [1.0942], l2_loss_z [0.3109], l2_loss_x [4.3603], l2_loss_y [0.7972], g_e_loss [5.8523], dz_loss [-0.3978], d_loss [-0.2856]
EGM Initialization Iter [13000] : e_loss_adv [-0.7747], l2_loss_v [1.0407], l2_loss_z [0.3655], l2_loss_x [1.3099], l2_loss_y [0.5965], g_e_loss [2.5379], dz_loss [-0.5982], d_loss [-0.5245]
EGM Initialization Iter [13500] : e_loss_adv [-1.5893], l2_loss_v [0.9822], l2_loss_z [0.3269], l2_loss_x [1.5761], l2_loss_y [1.4505], g_e_loss [2.7463], dz_loss [-0.2089], d_loss [-0.1767]
EGM Initialization Iter [14000] : e_loss_adv [-0.3512], l2_loss_v [0.9724], l2_loss_z [0.3501], l2_loss_x [4.6126], l2_loss_y [1.2828], g_e_loss [6.8669], dz_loss [-0.5639], d_loss [-0.5068]
EGM Initialization Iter [14500] : e_loss_adv [-0.8137], l2_loss_v [1.0103], l2_loss_z [0.4248], l2_loss_x [0.5124], l2_loss_y [1.4730], g_e_loss [2.6068], dz_loss [-0.3226], d_loss [-0.2140]
EGM Initialization Iter [15000] : e_loss_adv [-0.4964], l2_loss_v [0.9213], l2_loss_z [0.2990], l2_loss_x [1.2344], l2_loss_y [0.9027], g_e_loss [2.8611], dz_loss [-0.1730], d_loss [-0.0417]
EGM Initialization Iter [15500] : e_loss_adv [-0.9802], l2_loss_v [1.0312], l2_loss_z [0.3307], l2_loss_x [0.9689], l2_loss_y [1.3733], g_e_loss [2.7239], dz_loss [-0.7252], d_loss [-0.6841]
EGM Initialization Iter [16000] : e_loss_adv [-0.8511], l2_loss_v [0.9647], l2_loss_z [0.3063], l2_loss_x [1.1361], l2_loss_y [1.1887], g_e_loss [2.7447], dz_loss [-0.7088], d_loss [-0.6389]
EGM Initialization Iter [16500] : e_loss_adv [-0.9395], l2_loss_v [1.0061], l2_loss_z [0.3282], l2_loss_x [1.2912], l2_loss_y [1.5184], g_e_loss [3.2045], dz_loss [-0.1869], d_loss [-0.1471]
EGM Initialization Iter [17000] : e_loss_adv [-0.1236], l2_loss_v [1.0263], l2_loss_z [0.3395], l2_loss_x [2.4202], l2_loss_y [0.7661], g_e_loss [4.4285], dz_loss [-0.3217], d_loss [-0.2915]
EGM Initialization Iter [17500] : e_loss_adv [-0.6039], l2_loss_v [1.0110], l2_loss_z [0.2228], l2_loss_x [0.6884], l2_loss_y [1.0420], g_e_loss [2.3603], dz_loss [-0.6893], d_loss [-0.6441]
EGM Initialization Iter [18000] : e_loss_adv [-0.9193], l2_loss_v [1.0120], l2_loss_z [0.2393], l2_loss_x [0.4388], l2_loss_y [0.9639], g_e_loss [1.7347], dz_loss [-0.4117], d_loss [-0.3580]
EGM Initialization Iter [18500] : e_loss_adv [-1.1126], l2_loss_v [0.9904], l2_loss_z [0.2908], l2_loss_x [12.6224], l2_loss_y [1.8989], g_e_loss [14.6900], dz_loss [0.1250], d_loss [0.1782]
EGM Initialization Iter [19000] : e_loss_adv [-0.4944], l2_loss_v [1.0103], l2_loss_z [0.2811], l2_loss_x [0.9594], l2_loss_y [1.4073], g_e_loss [3.1637], dz_loss [-0.7970], d_loss [-0.7325]
EGM Initialization Iter [19500] : e_loss_adv [0.3062], l2_loss_v [0.9363], l2_loss_z [0.2672], l2_loss_x [17.2928], l2_loss_y [1.6232], g_e_loss [20.4258], dz_loss [-0.3881], d_loss [-0.2849]
EGM Initialization Iter [20000] : e_loss_adv [-1.7392], l2_loss_v [1.0347], l2_loss_z [0.2673], l2_loss_x [6.0748], l2_loss_y [1.1972], g_e_loss [6.8348], dz_loss [-0.6097], d_loss [-0.5434]
EGM Initialization Iter [20500] : e_loss_adv [-0.9999], l2_loss_v [0.9217], l2_loss_z [0.2955], l2_loss_x [3.1602], l2_loss_y [1.3958], g_e_loss [4.7734], dz_loss [0.0663], d_loss [0.1290]
EGM Initialization Iter [21000] : e_loss_adv [-0.3025], l2_loss_v [0.9891], l2_loss_z [0.2912], l2_loss_x [3.4434], l2_loss_y [1.3069], g_e_loss [5.7280], dz_loss [-0.2240], d_loss [-0.2015]
EGM Initialization Iter [21500] : e_loss_adv [-0.4693], l2_loss_v [1.0710], l2_loss_z [0.2530], l2_loss_x [0.4649], l2_loss_y [1.3412], g_e_loss [2.6608], dz_loss [-0.6601], d_loss [-0.6297]
EGM Initialization Iter [22000] : e_loss_adv [-1.7933], l2_loss_v [0.9932], l2_loss_z [0.2663], l2_loss_x [1.2690], l2_loss_y [1.4594], g_e_loss [2.1947], dz_loss [-0.0681], d_loss [-0.0063]
EGM Initialization Iter [22500] : e_loss_adv [-0.4883], l2_loss_v [0.9213], l2_loss_z [0.2608], l2_loss_x [0.6639], l2_loss_y [0.6807], g_e_loss [2.0383], dz_loss [-0.1923], d_loss [-0.1554]
EGM Initialization Iter [23000] : e_loss_adv [-0.5152], l2_loss_v [0.9531], l2_loss_z [0.2275], l2_loss_x [0.4341], l2_loss_y [1.0350], g_e_loss [2.1345], dz_loss [-0.2549], d_loss [-0.1995]
EGM Initialization Iter [23500] : e_loss_adv [-0.1655], l2_loss_v [0.9404], l2_loss_z [0.2472], l2_loss_x [0.5480], l2_loss_y [0.6557], g_e_loss [2.2257], dz_loss [-0.1739], d_loss [-0.1279]
EGM Initialization Iter [24000] : e_loss_adv [-0.8649], l2_loss_v [0.9337], l2_loss_z [0.2714], l2_loss_x [1.8735], l2_loss_y [1.4069], g_e_loss [3.6205], dz_loss [-0.2289], d_loss [-0.1504]
EGM Initialization Iter [24500] : e_loss_adv [0.3453], l2_loss_v [0.9769], l2_loss_z [0.2489], l2_loss_x [6.0515], l2_loss_y [0.9008], g_e_loss [8.5234], dz_loss [-0.2344], d_loss [-0.1849]
EGM Initialization Iter [25000] : e_loss_adv [-0.6483], l2_loss_v [0.9621], l2_loss_z [0.2230], l2_loss_x [0.6436], l2_loss_y [1.7994], g_e_loss [2.9798], dz_loss [-0.1736], d_loss [-0.1410]
EGM Initialization Iter [25500] : e_loss_adv [-0.6854], l2_loss_v [1.0132], l2_loss_z [0.2315], l2_loss_x [4.1439], l2_loss_y [2.5138], g_e_loss [7.2171], dz_loss [-0.4864], d_loss [-0.4177]
EGM Initialization Iter [26000] : e_loss_adv [-1.2118], l2_loss_v [1.0288], l2_loss_z [0.2737], l2_loss_x [5.5182], l2_loss_y [1.2290], g_e_loss [6.8380], dz_loss [0.1194], d_loss [0.1722]
EGM Initialization Iter [26500] : e_loss_adv [0.7093], l2_loss_v [1.0771], l2_loss_z [0.2293], l2_loss_x [3.3827], l2_loss_y [1.4104], g_e_loss [6.8087], dz_loss [-0.2914], d_loss [-0.1987]
EGM Initialization Iter [27000] : e_loss_adv [-0.7259], l2_loss_v [0.8823], l2_loss_z [0.2003], l2_loss_x [0.8137], l2_loss_y [0.8930], g_e_loss [2.0634], dz_loss [-0.7017], d_loss [-0.6602]
EGM Initialization Iter [27500] : e_loss_adv [-1.5183], l2_loss_v [0.9569], l2_loss_z [0.1955], l2_loss_x [2.3814], l2_loss_y [1.4743], g_e_loss [3.4897], dz_loss [-0.1165], d_loss [-0.0647]
EGM Initialization Iter [28000] : e_loss_adv [-0.0395], l2_loss_v [1.0021], l2_loss_z [0.2151], l2_loss_x [3.7281], l2_loss_y [0.9516], g_e_loss [5.8575], dz_loss [-0.5957], d_loss [-0.5354]
EGM Initialization Iter [28500] : e_loss_adv [-2.1634], l2_loss_v [1.0258], l2_loss_z [0.1998], l2_loss_x [0.9157], l2_loss_y [0.8383], g_e_loss [0.8162], dz_loss [-0.1053], d_loss [-0.0697]
EGM Initialization Iter [29000] : e_loss_adv [-0.0149], l2_loss_v [0.9536], l2_loss_z [0.2508], l2_loss_x [2.8241], l2_loss_y [1.4126], g_e_loss [5.4262], dz_loss [-0.6389], d_loss [-0.6117]
EGM Initialization Iter [29500] : e_loss_adv [-0.8025], l2_loss_v [0.9960], l2_loss_z [0.2465], l2_loss_x [0.9598], l2_loss_y [1.1993], g_e_loss [2.5991], dz_loss [-0.3781], d_loss [-0.2946]
EGM Initialization Iter [30000] : e_loss_adv [-1.7520], l2_loss_v [0.9474], l2_loss_z [0.2136], l2_loss_x [1.0793], l2_loss_y [0.7894], g_e_loss [1.2777], dz_loss [-0.4842], d_loss [-0.4431]
EGM Initialization Ends.
Initialize latent variables Z with e(V)...
Iterative Updating Starts ...
Epoch 0/100: 100%|██████████| 625/625 [00:17<00:00, 34.93batch/s, loss_px_z: [1.0017], loss_mse_x: [0.7962], loss_py_z: [1.7299], loss_mse_y: [1.8524], loss_pv_z: [104.1091], loss_mse_v: [0.9552], loss_postrior_z: [101.2081]]
Epoch [0/100]: MSE_x: 2.1162, MSE_y: 1.2035, MSE_v: 0.9740

Epoch 1/100: 100%|██████████| 625/625 [00:11<00:00, 55.72batch/s, loss_px_z: [1.0627], loss_mse_x: [0.9919], loss_py_z: [1.3802], loss_mse_y: [1.3957], loss_pv_z: [101.1004], loss_mse_v: [0.9326], loss_postrior_z: [99.0015]]
Epoch 2/100: 100%|██████████| 625/625 [00:11<00:00, 54.91batch/s, loss_px_z: [2.9263], loss_mse_x: [4.9884], loss_py_z: [1.2504], loss_mse_y: [1.2226], loss_pv_z: [104.0898], loss_mse_v: [0.9618], loss_postrior_z: [104.2342]]
Epoch 3/100: 100%|██████████| 625/625 [00:11<00:00, 55.40batch/s, loss_px_z: [1.3514], loss_mse_x: [1.8893], loss_py_z: [1.2817], loss_mse_y: [1.3911], loss_pv_z: [111.9185], loss_mse_v: [1.0400], loss_postrior_z: [111.6363]]
Epoch 4/100: 100%|██████████| 625/625 [00:11<00:00, 54.95batch/s, loss_px_z: [0.8047], loss_mse_x: [0.7948], loss_py_z: [1.1081], loss_mse_y: [1.0806], loss_pv_z: [109.0874], loss_mse_v: [1.0118], loss_postrior_z: [105.9435]]
Epoch 5/100: 100%|██████████| 625/625 [00:11<00:00, 54.52batch/s, loss_px_z: [1.1618], loss_mse_x: [1.1868], loss_py_z: [1.1339], loss_mse_y: [1.0139], loss_pv_z: [110.8662], loss_mse_v: [1.0276], loss_postrior_z: [109.2909]]
Epoch 6/100: 100%|██████████| 625/625 [00:11<00:00, 55.03batch/s, loss_px_z: [0.8462], loss_mse_x: [0.7771], loss_py_z: [1.3909], loss_mse_y: [1.6178], loss_pv_z: [103.7796], loss_mse_v: [0.9596], loss_postrior_z: [101.4011]]
Epoch 7/100: 100%|██████████| 625/625 [00:11<00:00, 54.77batch/s, loss_px_z: [0.9922], loss_mse_x: [1.4759], loss_py_z: [0.9513], loss_mse_y: [0.6665], loss_pv_z: [104.8899], loss_mse_v: [0.9705], loss_postrior_z: [102.6353]]
Epoch 8/100: 100%|██████████| 625/625 [00:11<00:00, 53.82batch/s, loss_px_z: [0.7981], loss_mse_x: [0.6528], loss_py_z: [1.3613], loss_mse_y: [1.5980], loss_pv_z: [109.5272], loss_mse_v: [1.0158], loss_postrior_z: [106.3503]]
Epoch 9/100: 100%|██████████| 625/625 [00:11<00:00, 54.67batch/s, loss_px_z: [0.7962], loss_mse_x: [0.5568], loss_py_z: [1.0703], loss_mse_y: [0.9598], loss_pv_z: [104.8426], loss_mse_v: [0.9710], loss_postrior_z: [101.9698]]
Epoch 10/100: 100%|██████████| 625/625 [00:11<00:00, 54.84batch/s, loss_px_z: [0.5672], loss_mse_x: [0.5001], loss_py_z: [1.3016], loss_mse_y: [1.4453], loss_pv_z: [98.7699], loss_mse_v: [0.9122], loss_postrior_z: [96.4060]]
Epoch [10/100]: MSE_x: 2.2314, MSE_y: 1.2006, MSE_v: 0.9687

Epoch 11/100: 100%|██████████| 625/625 [00:11<00:00, 54.30batch/s, loss_px_z: [0.5795], loss_mse_x: [0.6802], loss_py_z: [1.1793], loss_mse_y: [1.2066], loss_pv_z: [96.5351], loss_mse_v: [0.8871], loss_postrior_z: [93.4869]]
Epoch 12/100: 100%|██████████| 625/625 [00:11<00:00, 54.78batch/s, loss_px_z: [1.1033], loss_mse_x: [2.4904], loss_py_z: [1.1913], loss_mse_y: [1.2767], loss_pv_z: [98.9891], loss_mse_v: [0.9151], loss_postrior_z: [95.6955]]
Epoch 13/100: 100%|██████████| 625/625 [00:11<00:00, 54.29batch/s, loss_px_z: [1.6449], loss_mse_x: [1.9944], loss_py_z: [1.1341], loss_mse_y: [1.1039], loss_pv_z: [104.2422], loss_mse_v: [0.9682], loss_postrior_z: [103.0387]]
Epoch 14/100: 100%|██████████| 625/625 [00:11<00:00, 54.24batch/s, loss_px_z: [0.6592], loss_mse_x: [2.1695], loss_py_z: [1.0025], loss_mse_y: [0.8883], loss_pv_z: [102.1708], loss_mse_v: [0.9453], loss_postrior_z: [99.7335]]
Epoch 15/100: 100%|██████████| 625/625 [00:11<00:00, 55.68batch/s, loss_px_z: [0.7443], loss_mse_x: [1.6535], loss_py_z: [1.1693], loss_mse_y: [1.2531], loss_pv_z: [103.2916], loss_mse_v: [0.9567], loss_postrior_z: [101.1586]]
Epoch 16/100: 100%|██████████| 625/625 [00:11<00:00, 55.40batch/s, loss_px_z: [0.7354], loss_mse_x: [0.8059], loss_py_z: [1.1732], loss_mse_y: [1.2717], loss_pv_z: [106.4319], loss_mse_v: [0.9858], loss_postrior_z: [104.2073]]
Epoch 17/100: 100%|██████████| 625/625 [00:11<00:00, 55.67batch/s, loss_px_z: [2.1600], loss_mse_x: [2.8464], loss_py_z: [1.0883], loss_mse_y: [1.1696], loss_pv_z: [106.7967], loss_mse_v: [0.9891], loss_postrior_z: [106.0688]]
Epoch 18/100: 100%|██████████| 625/625 [00:11<00:00, 54.89batch/s, loss_px_z: [0.7860], loss_mse_x: [2.6597], loss_py_z: [1.0036], loss_mse_y: [0.9227], loss_pv_z: [102.5790], loss_mse_v: [0.9485], loss_postrior_z: [100.5853]]
Epoch 19/100: 100%|██████████| 625/625 [00:11<00:00, 53.88batch/s, loss_px_z: [0.9116], loss_mse_x: [0.9684], loss_py_z: [0.9765], loss_mse_y: [0.9605], loss_pv_z: [105.0637], loss_mse_v: [0.9744], loss_postrior_z: [103.2852]]
Epoch 20/100: 100%|██████████| 625/625 [00:11<00:00, 54.23batch/s, loss_px_z: [1.1986], loss_mse_x: [7.3664], loss_py_z: [1.3686], loss_mse_y: [2.0824], loss_pv_z: [103.5854], loss_mse_v: [0.9616], loss_postrior_z: [101.7267]]
Epoch [20/100]: MSE_x: 2.3551, MSE_y: 1.2030, MSE_v: 0.9676

Epoch 21/100: 100%|██████████| 625/625 [00:11<00:00, 54.97batch/s, loss_px_z: [0.6754], loss_mse_x: [0.8088], loss_py_z: [1.1038], loss_mse_y: [1.2782], loss_pv_z: [100.8194], loss_mse_v: [0.9332], loss_postrior_z: [97.6035]]
Epoch 22/100: 100%|██████████| 625/625 [00:11<00:00, 55.75batch/s, loss_px_z: [1.5267], loss_mse_x: [1.2338], loss_py_z: [0.9382], loss_mse_y: [0.8040], loss_pv_z: [104.1450], loss_mse_v: [0.9673], loss_postrior_z: [102.7854]]
Epoch 23/100: 100%|██████████| 625/625 [00:11<00:00, 54.29batch/s, loss_px_z: [0.7087], loss_mse_x: [1.0746], loss_py_z: [1.1086], loss_mse_y: [1.3181], loss_pv_z: [105.0445], loss_mse_v: [0.9765], loss_postrior_z: [102.6641]]
Epoch 24/100: 100%|██████████| 625/625 [00:11<00:00, 54.38batch/s, loss_px_z: [0.3100], loss_mse_x: [0.4896], loss_py_z: [1.1030], loss_mse_y: [1.2588], loss_pv_z: [103.4497], loss_mse_v: [0.9582], loss_postrior_z: [100.2494]]
Epoch 25/100: 100%|██████████| 625/625 [00:11<00:00, 55.14batch/s, loss_px_z: [0.4496], loss_mse_x: [1.0667], loss_py_z: [0.9804], loss_mse_y: [1.0471], loss_pv_z: [106.0292], loss_mse_v: [0.9863], loss_postrior_z: [103.5820]]
Epoch 26/100: 100%|██████████| 625/625 [00:11<00:00, 54.07batch/s, loss_px_z: [0.9430], loss_mse_x: [1.7313], loss_py_z: [0.9116], loss_mse_y: [0.8233], loss_pv_z: [111.1543], loss_mse_v: [1.0376], loss_postrior_z: [108.5810]]
Epoch 27/100: 100%|██████████| 625/625 [00:11<00:00, 55.23batch/s, loss_px_z: [0.3157], loss_mse_x: [0.3745], loss_py_z: [0.8482], loss_mse_y: [0.7153], loss_pv_z: [101.8660], loss_mse_v: [0.9439], loss_postrior_z: [100.0770]]
Epoch 28/100: 100%|██████████| 625/625 [00:11<00:00, 53.60batch/s, loss_px_z: [0.8641], loss_mse_x: [1.8446], loss_py_z: [0.9266], loss_mse_y: [0.8325], loss_pv_z: [103.8543], loss_mse_v: [0.9660], loss_postrior_z: [102.6205]]
Epoch 29/100: 100%|██████████| 625/625 [00:11<00:00, 55.00batch/s, loss_px_z: [0.7417], loss_mse_x: [1.1123], loss_py_z: [1.0385], loss_mse_y: [1.1367], loss_pv_z: [107.2606], loss_mse_v: [0.9963], loss_postrior_z: [104.4046]]
Epoch 30/100: 100%|██████████| 625/625 [00:11<00:00, 55.33batch/s, loss_px_z: [0.7017], loss_mse_x: [2.0424], loss_py_z: [1.0338], loss_mse_y: [1.0654], loss_pv_z: [108.2052], loss_mse_v: [1.0073], loss_postrior_z: [105.8030]]
Epoch [30/100]: MSE_x: 2.1616, MSE_y: 1.2084, MSE_v: 0.9667

Epoch 31/100: 100%|██████████| 625/625 [00:11<00:00, 55.53batch/s, loss_px_z: [0.5098], loss_mse_x: [0.6142], loss_py_z: [1.1598], loss_mse_y: [1.4640], loss_pv_z: [101.9042], loss_mse_v: [0.9434], loss_postrior_z: [98.7518]]
Epoch 32/100: 100%|██████████| 625/625 [00:11<00:00, 55.46batch/s, loss_px_z: [0.2685], loss_mse_x: [0.3559], loss_py_z: [0.8601], loss_mse_y: [0.7184], loss_pv_z: [99.3381], loss_mse_v: [0.9203], loss_postrior_z: [95.7403]]
Epoch 33/100: 100%|██████████| 625/625 [00:11<00:00, 54.63batch/s, loss_px_z: [0.6054], loss_mse_x: [0.6897], loss_py_z: [1.1666], loss_mse_y: [1.4648], loss_pv_z: [110.0090], loss_mse_v: [1.0229], loss_postrior_z: [107.6170]]
Epoch 34/100: 100%|██████████| 625/625 [00:11<00:00, 55.32batch/s, loss_px_z: [0.6286], loss_mse_x: [1.4439], loss_py_z: [1.1934], loss_mse_y: [1.4473], loss_pv_z: [105.3239], loss_mse_v: [0.9798], loss_postrior_z: [103.2204]]
Epoch 35/100: 100%|██████████| 625/625 [00:11<00:00, 53.97batch/s, loss_px_z: [0.6942], loss_mse_x: [1.5184], loss_py_z: [1.0618], loss_mse_y: [1.2188], loss_pv_z: [108.3755], loss_mse_v: [1.0158], loss_postrior_z: [107.3421]]
Epoch 36/100: 100%|██████████| 625/625 [00:11<00:00, 55.26batch/s, loss_px_z: [0.4048], loss_mse_x: [0.7296], loss_py_z: [1.0736], loss_mse_y: [1.2077], loss_pv_z: [98.7979], loss_mse_v: [0.9153], loss_postrior_z: [96.1371]]
Epoch 37/100: 100%|██████████| 625/625 [00:11<00:00, 56.41batch/s, loss_px_z: [0.6732], loss_mse_x: [2.5477], loss_py_z: [1.1506], loss_mse_y: [1.4312], loss_pv_z: [99.2480], loss_mse_v: [0.9178], loss_postrior_z: [97.5223]]
Epoch 38/100: 100%|██████████| 625/625 [00:11<00:00, 55.61batch/s, loss_px_z: [0.5150], loss_mse_x: [1.2934], loss_py_z: [0.8623], loss_mse_y: [0.8216], loss_pv_z: [106.1062], loss_mse_v: [0.9891], loss_postrior_z: [103.6298]]
Epoch 39/100: 100%|██████████| 625/625 [00:11<00:00, 55.37batch/s, loss_px_z: [0.9153], loss_mse_x: [3.1581], loss_py_z: [0.9936], loss_mse_y: [1.0444], loss_pv_z: [105.5084], loss_mse_v: [0.9839], loss_postrior_z: [103.6738]]
Epoch 40/100: 100%|██████████| 625/625 [00:11<00:00, 54.97batch/s, loss_px_z: [0.6372], loss_mse_x: [1.2377], loss_py_z: [0.8759], loss_mse_y: [0.8238], loss_pv_z: [105.4479], loss_mse_v: [0.9848], loss_postrior_z: [105.0747]]
Epoch [40/100]: MSE_x: 2.2160, MSE_y: 1.2174, MSE_v: 0.9662

Epoch 41/100: 100%|██████████| 625/625 [00:11<00:00, 55.49batch/s, loss_px_z: [0.3933], loss_mse_x: [0.7077], loss_py_z: [0.8098], loss_mse_y: [0.7087], loss_pv_z: [100.8510], loss_mse_v: [0.9372], loss_postrior_z: [98.1798]]
Epoch 42/100: 100%|██████████| 625/625 [00:11<00:00, 55.60batch/s, loss_px_z: [0.6129], loss_mse_x: [1.1368], loss_py_z: [1.1585], loss_mse_y: [1.5237], loss_pv_z: [100.2316], loss_mse_v: [0.9312], loss_postrior_z: [97.8113]]
Epoch 43/100: 100%|██████████| 625/625 [00:11<00:00, 54.80batch/s, loss_px_z: [0.3872], loss_mse_x: [0.6784], loss_py_z: [1.0379], loss_mse_y: [1.2025], loss_pv_z: [102.1871], loss_mse_v: [0.9516], loss_postrior_z: [99.1012]]
Epoch 44/100: 100%|██████████| 625/625 [00:11<00:00, 54.81batch/s, loss_px_z: [0.5947], loss_mse_x: [1.2953], loss_py_z: [0.9499], loss_mse_y: [0.9979], loss_pv_z: [106.5104], loss_mse_v: [0.9883], loss_postrior_z: [104.6743]]
Epoch 45/100: 100%|██████████| 625/625 [00:11<00:00, 54.32batch/s, loss_px_z: [0.6412], loss_mse_x: [3.7783], loss_py_z: [0.9002], loss_mse_y: [0.9011], loss_pv_z: [97.6416], loss_mse_v: [0.9105], loss_postrior_z: [94.2107]]
Epoch 46/100: 100%|██████████| 625/625 [00:11<00:00, 55.17batch/s, loss_px_z: [0.6056], loss_mse_x: [1.3386], loss_py_z: [0.9334], loss_mse_y: [0.9600], loss_pv_z: [104.2481], loss_mse_v: [0.9693], loss_postrior_z: [102.1923]]
Epoch 47/100: 100%|██████████| 625/625 [00:11<00:00, 54.82batch/s, loss_px_z: [0.6723], loss_mse_x: [7.5391], loss_py_z: [0.9108], loss_mse_y: [0.9672], loss_pv_z: [103.6044], loss_mse_v: [0.9630], loss_postrior_z: [102.2846]]
Epoch 48/100: 100%|██████████| 625/625 [00:11<00:00, 54.92batch/s, loss_px_z: [0.2197], loss_mse_x: [0.4434], loss_py_z: [1.0514], loss_mse_y: [1.4217], loss_pv_z: [104.7859], loss_mse_v: [0.9797], loss_postrior_z: [101.7223]]
Epoch 49/100: 100%|██████████| 625/625 [00:11<00:00, 55.71batch/s, loss_px_z: [0.2410], loss_mse_x: [0.4813], loss_py_z: [0.9054], loss_mse_y: [1.0336], loss_pv_z: [103.0146], loss_mse_v: [0.9606], loss_postrior_z: [101.0885]]
Epoch 50/100: 100%|██████████| 625/625 [00:11<00:00, 54.98batch/s, loss_px_z: [0.5846], loss_mse_x: [1.2790], loss_py_z: [0.9430], loss_mse_y: [1.0161], loss_pv_z: [111.7826], loss_mse_v: [1.0438], loss_postrior_z: [110.1667]]
Epoch [50/100]: MSE_x: 2.1777, MSE_y: 1.2043, MSE_v: 0.9657

Epoch 51/100: 100%|██████████| 625/625 [00:11<00:00, 54.96batch/s, loss_px_z: [1.0339], loss_mse_x: [1.5230], loss_py_z: [0.9928], loss_mse_y: [1.1792], loss_pv_z: [104.4264], loss_mse_v: [0.9691], loss_postrior_z: [102.5903]]
Epoch 52/100: 100%|██████████| 625/625 [00:11<00:00, 54.79batch/s, loss_px_z: [0.7486], loss_mse_x: [0.7666], loss_py_z: [1.1842], loss_mse_y: [1.5948], loss_pv_z: [102.5457], loss_mse_v: [0.9553], loss_postrior_z: [101.0967]]
Epoch 53/100: 100%|██████████| 625/625 [00:11<00:00, 54.98batch/s, loss_px_z: [0.3077], loss_mse_x: [1.0589], loss_py_z: [1.0254], loss_mse_y: [1.2029], loss_pv_z: [104.0903], loss_mse_v: [0.9721], loss_postrior_z: [101.5168]]
Epoch 54/100: 100%|██████████| 625/625 [00:11<00:00, 54.34batch/s, loss_px_z: [0.4946], loss_mse_x: [0.9493], loss_py_z: [0.9811], loss_mse_y: [1.1906], loss_pv_z: [103.0164], loss_mse_v: [0.9565], loss_postrior_z: [100.3452]]
Epoch 55/100: 100%|██████████| 625/625 [00:11<00:00, 55.08batch/s, loss_px_z: [0.3406], loss_mse_x: [5.0589], loss_py_z: [0.8992], loss_mse_y: [1.0087], loss_pv_z: [98.9583], loss_mse_v: [0.9173], loss_postrior_z: [97.5815]]
Epoch 56/100: 100%|██████████| 625/625 [00:11<00:00, 54.77batch/s, loss_px_z: [0.3262], loss_mse_x: [0.3477], loss_py_z: [1.0027], loss_mse_y: [1.1875], loss_pv_z: [107.3604], loss_mse_v: [1.0007], loss_postrior_z: [104.1735]]
Epoch 57/100: 100%|██████████| 625/625 [00:11<00:00, 55.67batch/s, loss_px_z: [0.2425], loss_mse_x: [0.5210], loss_py_z: [0.8429], loss_mse_y: [0.8359], loss_pv_z: [105.6930], loss_mse_v: [0.9835], loss_postrior_z: [101.8994]]
Epoch 58/100: 100%|██████████| 625/625 [00:11<00:00, 55.48batch/s, loss_px_z: [0.4562], loss_mse_x: [1.4668], loss_py_z: [0.9041], loss_mse_y: [1.0505], loss_pv_z: [106.8178], loss_mse_v: [0.9959], loss_postrior_z: [105.5796]]
Epoch 59/100: 100%|██████████| 625/625 [00:11<00:00, 54.85batch/s, loss_px_z: [0.1841], loss_mse_x: [0.9394], loss_py_z: [1.0556], loss_mse_y: [1.6084], loss_pv_z: [95.1684], loss_mse_v: [0.8880], loss_postrior_z: [93.4943]]
Epoch 60/100: 100%|██████████| 625/625 [00:11<00:00, 54.79batch/s, loss_px_z: [0.4262], loss_mse_x: [0.3860], loss_py_z: [0.7917], loss_mse_y: [0.8005], loss_pv_z: [104.1415], loss_mse_v: [0.9713], loss_postrior_z: [102.0446]]
Epoch [60/100]: MSE_x: 2.1817, MSE_y: 1.1895, MSE_v: 0.9652

Epoch 61/100: 100%|██████████| 625/625 [00:11<00:00, 55.19batch/s, loss_px_z: [0.3869], loss_mse_x: [0.7310], loss_py_z: [1.0391], loss_mse_y: [1.3771], loss_pv_z: [104.6561], loss_mse_v: [0.9732], loss_postrior_z: [101.3780]]
Epoch 62/100: 100%|██████████| 625/625 [00:11<00:00, 54.76batch/s, loss_px_z: [0.5479], loss_mse_x: [0.6498], loss_py_z: [0.8399], loss_mse_y: [0.8538], loss_pv_z: [101.8596], loss_mse_v: [0.9493], loss_postrior_z: [99.5369]]
Epoch 63/100: 100%|██████████| 625/625 [00:11<00:00, 55.01batch/s, loss_px_z: [0.5373], loss_mse_x: [0.9175], loss_py_z: [0.9829], loss_mse_y: [1.1854], loss_pv_z: [102.4637], loss_mse_v: [0.9533], loss_postrior_z: [100.1760]]
Epoch 64/100: 100%|██████████| 625/625 [00:11<00:00, 55.09batch/s, loss_px_z: [0.2867], loss_mse_x: [1.5419], loss_py_z: [0.9317], loss_mse_y: [1.2228], loss_pv_z: [96.9979], loss_mse_v: [0.9007], loss_postrior_z: [95.0470]]
Epoch 65/100: 100%|██████████| 625/625 [00:11<00:00, 55.44batch/s, loss_px_z: [0.6300], loss_mse_x: [1.4146], loss_py_z: [0.8759], loss_mse_y: [0.9186], loss_pv_z: [107.2320], loss_mse_v: [1.0003], loss_postrior_z: [105.5436]]
Epoch 66/100: 100%|██████████| 625/625 [00:11<00:00, 53.91batch/s, loss_px_z: [0.3859], loss_mse_x: [2.8775], loss_py_z: [0.9805], loss_mse_y: [1.2691], loss_pv_z: [104.1365], loss_mse_v: [0.9740], loss_postrior_z: [104.6624]]
Epoch 67/100: 100%|██████████| 625/625 [00:11<00:00, 55.30batch/s, loss_px_z: [0.5155], loss_mse_x: [2.5749], loss_py_z: [0.8094], loss_mse_y: [0.8627], loss_pv_z: [104.7769], loss_mse_v: [0.9788], loss_postrior_z: [102.0515]]
Epoch 68/100: 100%|██████████| 625/625 [00:11<00:00, 55.20batch/s, loss_px_z: [0.2483], loss_mse_x: [0.5678], loss_py_z: [1.0396], loss_mse_y: [1.2972], loss_pv_z: [104.7215], loss_mse_v: [0.9814], loss_postrior_z: [102.0759]]
Epoch 69/100: 100%|██████████| 625/625 [00:11<00:00, 55.26batch/s, loss_px_z: [1.4218], loss_mse_x: [1.1111], loss_py_z: [0.8525], loss_mse_y: [0.9750], loss_pv_z: [103.8017], loss_mse_v: [0.9686], loss_postrior_z: [102.3827]]
Epoch 70/100: 100%|██████████| 625/625 [00:11<00:00, 54.61batch/s, loss_px_z: [0.2590], loss_mse_x: [0.5932], loss_py_z: [0.8803], loss_mse_y: [1.1457], loss_pv_z: [104.2802], loss_mse_v: [0.9745], loss_postrior_z: [101.8869]]
Epoch [70/100]: MSE_x: 2.0605, MSE_y: 1.2068, MSE_v: 0.9650

Epoch 71/100: 100%|██████████| 625/625 [00:11<00:00, 54.63batch/s, loss_px_z: [0.9122], loss_mse_x: [4.1757], loss_py_z: [0.8866], loss_mse_y: [1.1041], loss_pv_z: [99.0652], loss_mse_v: [0.9213], loss_postrior_z: [98.8926]]
Epoch 72/100: 100%|██████████| 625/625 [00:11<00:00, 55.25batch/s, loss_px_z: [0.7746], loss_mse_x: [1.6122], loss_py_z: [0.7496], loss_mse_y: [0.7539], loss_pv_z: [104.4348], loss_mse_v: [0.9758], loss_postrior_z: [101.8911]]
Epoch 73/100: 100%|██████████| 625/625 [00:11<00:00, 55.42batch/s, loss_px_z: [0.5040], loss_mse_x: [1.4536], loss_py_z: [0.9477], loss_mse_y: [1.1859], loss_pv_z: [103.3225], loss_mse_v: [0.9635], loss_postrior_z: [101.3829]]
Epoch 74/100: 100%|██████████| 625/625 [00:11<00:00, 55.29batch/s, loss_px_z: [0.5007], loss_mse_x: [0.7652], loss_py_z: [0.9774], loss_mse_y: [1.3177], loss_pv_z: [102.7477], loss_mse_v: [0.9562], loss_postrior_z: [100.9115]]
Epoch 75/100: 100%|██████████| 625/625 [00:11<00:00, 55.45batch/s, loss_px_z: [0.6312], loss_mse_x: [0.5428], loss_py_z: [1.0763], loss_mse_y: [1.4158], loss_pv_z: [108.3980], loss_mse_v: [1.0175], loss_postrior_z: [106.3569]]
Epoch 76/100: 100%|██████████| 625/625 [00:11<00:00, 55.03batch/s, loss_px_z: [1.1616], loss_mse_x: [3.6411], loss_py_z: [0.9174], loss_mse_y: [1.1184], loss_pv_z: [101.4809], loss_mse_v: [0.9504], loss_postrior_z: [100.3156]]
Epoch 77/100: 100%|██████████| 625/625 [00:11<00:00, 55.45batch/s, loss_px_z: [0.2270], loss_mse_x: [1.4151], loss_py_z: [0.7666], loss_mse_y: [0.9107], loss_pv_z: [99.8907], loss_mse_v: [0.9310], loss_postrior_z: [97.7918]]
Epoch 78/100: 100%|██████████| 625/625 [00:11<00:00, 55.30batch/s, loss_px_z: [0.5104], loss_mse_x: [2.5735], loss_py_z: [1.0417], loss_mse_y: [1.3955], loss_pv_z: [102.6700], loss_mse_v: [0.9608], loss_postrior_z: [100.8437]]
Epoch 79/100: 100%|██████████| 625/625 [00:11<00:00, 55.60batch/s, loss_px_z: [0.2810], loss_mse_x: [1.1481], loss_py_z: [0.7788], loss_mse_y: [0.8689], loss_pv_z: [99.6064], loss_mse_v: [0.9245], loss_postrior_z: [99.1888]]
Epoch 80/100: 100%|██████████| 625/625 [00:11<00:00, 56.25batch/s, loss_px_z: [0.4661], loss_mse_x: [0.6606], loss_py_z: [0.9857], loss_mse_y: [1.2555], loss_pv_z: [97.2984], loss_mse_v: [0.9066], loss_postrior_z: [95.6442]]
Epoch [80/100]: MSE_x: 2.0278, MSE_y: 1.1949, MSE_v: 0.9645

Epoch 81/100: 100%|██████████| 625/625 [00:11<00:00, 54.68batch/s, loss_px_z: [0.3389], loss_mse_x: [1.4112], loss_py_z: [0.7984], loss_mse_y: [0.9120], loss_pv_z: [106.6279], loss_mse_v: [0.9946], loss_postrior_z: [105.3045]]
Epoch 82/100: 100%|██████████| 625/625 [00:11<00:00, 53.96batch/s, loss_px_z: [0.3469], loss_mse_x: [0.7758], loss_py_z: [0.7624], loss_mse_y: [0.9110], loss_pv_z: [96.9946], loss_mse_v: [0.9044], loss_postrior_z: [94.0614]]
Epoch 83/100: 100%|██████████| 625/625 [00:11<00:00, 53.81batch/s, loss_px_z: [0.2615], loss_mse_x: [0.3087], loss_py_z: [0.9122], loss_mse_y: [1.1561], loss_pv_z: [110.0109], loss_mse_v: [1.0239], loss_postrior_z: [107.2893]]
Epoch 84/100: 100%|██████████| 625/625 [00:11<00:00, 54.81batch/s, loss_px_z: [0.0938], loss_mse_x: [0.2996], loss_py_z: [0.8764], loss_mse_y: [1.0378], loss_pv_z: [104.0820], loss_mse_v: [0.9701], loss_postrior_z: [101.2576]]
Epoch 85/100: 100%|██████████| 625/625 [00:11<00:00, 54.97batch/s, loss_px_z: [0.7172], loss_mse_x: [1.4699], loss_py_z: [1.0360], loss_mse_y: [1.5079], loss_pv_z: [106.3758], loss_mse_v: [0.9988], loss_postrior_z: [104.7337]]
Epoch 86/100: 100%|██████████| 625/625 [00:11<00:00, 55.05batch/s, loss_px_z: [0.1134], loss_mse_x: [3.9313], loss_py_z: [0.8578], loss_mse_y: [1.3222], loss_pv_z: [99.4824], loss_mse_v: [0.9270], loss_postrior_z: [97.2417]]
Epoch 87/100: 100%|██████████| 625/625 [00:11<00:00, 54.46batch/s, loss_px_z: [0.1477], loss_mse_x: [0.4027], loss_py_z: [0.9585], loss_mse_y: [1.6592], loss_pv_z: [105.9451], loss_mse_v: [0.9952], loss_postrior_z: [104.9820]]
Epoch 88/100: 100%|██████████| 625/625 [00:11<00:00, 54.60batch/s, loss_px_z: [0.7088], loss_mse_x: [0.8604], loss_py_z: [0.9679], loss_mse_y: [1.2659], loss_pv_z: [109.0925], loss_mse_v: [1.0223], loss_postrior_z: [106.7311]]
Epoch 89/100: 100%|██████████| 625/625 [00:11<00:00, 54.62batch/s, loss_px_z: [0.3217], loss_mse_x: [0.6296], loss_py_z: [0.7441], loss_mse_y: [0.8663], loss_pv_z: [103.1253], loss_mse_v: [0.9651], loss_postrior_z: [101.4039]]
Epoch 90/100: 100%|██████████| 625/625 [00:11<00:00, 54.76batch/s, loss_px_z: [0.3576], loss_mse_x: [2.2540], loss_py_z: [0.9933], loss_mse_y: [1.3456], loss_pv_z: [103.8477], loss_mse_v: [0.9681], loss_postrior_z: [101.4713]]
Epoch [90/100]: MSE_x: 2.0580, MSE_y: 1.1633, MSE_v: 0.9644

Epoch 91/100: 100%|██████████| 625/625 [00:11<00:00, 54.66batch/s, loss_px_z: [0.1603], loss_mse_x: [0.4676], loss_py_z: [0.7444], loss_mse_y: [0.8241], loss_pv_z: [108.7655], loss_mse_v: [1.0234], loss_postrior_z: [106.3370]]
Epoch 92/100: 100%|██████████| 625/625 [00:11<00:00, 53.44batch/s, loss_px_z: [0.1892], loss_mse_x: [0.7174], loss_py_z: [1.1678], loss_mse_y: [1.8073], loss_pv_z: [113.5816], loss_mse_v: [1.0619], loss_postrior_z: [111.2462]]
Epoch 93/100: 100%|██████████| 625/625 [00:11<00:00, 54.54batch/s, loss_px_z: [0.1294], loss_mse_x: [0.7718], loss_py_z: [0.8597], loss_mse_y: [1.1468], loss_pv_z: [99.1875], loss_mse_v: [0.9276], loss_postrior_z: [97.8237]]
Epoch 94/100: 100%|██████████| 625/625 [00:11<00:00, 54.27batch/s, loss_px_z: [0.8639], loss_mse_x: [1.9431], loss_py_z: [1.0255], loss_mse_y: [1.4076], loss_pv_z: [102.7940], loss_mse_v: [0.9620], loss_postrior_z: [101.0259]]
Epoch 95/100: 100%|██████████| 625/625 [00:11<00:00, 55.64batch/s, loss_px_z: [0.4005], loss_mse_x: [1.7976], loss_py_z: [0.8490], loss_mse_y: [1.1049], loss_pv_z: [100.7964], loss_mse_v: [0.9436], loss_postrior_z: [98.7635]]
Epoch 96/100: 100%|██████████| 625/625 [00:11<00:00, 55.98batch/s, loss_px_z: [0.0117], loss_mse_x: [0.2462], loss_py_z: [0.7910], loss_mse_y: [0.9281], loss_pv_z: [109.1485], loss_mse_v: [1.0273], loss_postrior_z: [106.6896]]
Epoch 97/100: 100%|██████████| 625/625 [00:11<00:00, 52.94batch/s, loss_px_z: [0.2551], loss_mse_x: [0.8334], loss_py_z: [0.9307], loss_mse_y: [1.2324], loss_pv_z: [102.6908], loss_mse_v: [0.9649], loss_postrior_z: [100.4549]]
Epoch 98/100: 100%|██████████| 625/625 [00:11<00:00, 53.89batch/s, loss_px_z: [1.5356], loss_mse_x: [6.9516], loss_py_z: [0.8813], loss_mse_y: [1.1969], loss_pv_z: [103.2901], loss_mse_v: [0.9661], loss_postrior_z: [103.3188]]
Epoch 99/100: 100%|██████████| 625/625 [00:11<00:00, 55.43batch/s, loss_px_z: [0.3029], loss_mse_x: [2.1966], loss_py_z: [0.6606], loss_mse_y: [0.6747], loss_pv_z: [103.4311], loss_mse_v: [0.9697], loss_postrior_z: [101.9436]]
Epoch 100/100: 100%|██████████| 625/625 [00:11<00:00, 55.07batch/s, loss_px_z: [0.1279], loss_mse_x: [0.5865], loss_py_z: [0.7868], loss_mse_y: [0.9442], loss_pv_z: [105.4194], loss_mse_v: [0.9945], loss_postrior_z: [102.6229]]
Epoch [100/100]: MSE_x: 2.0460, MSE_y: 1.1746, MSE_v: 0.9638

Make predictions using the trained CausalBGM model

Estimate causal effects with posterior intervals from latent MCMC samples.

Config Parameter

Description

data

Tuple of data inputs (x, y, v), Required.

alpha

Significance level for the posterior interval; the interval has nominal coverage 1 - alpha. Default: 0.01.

n_mcmc

Number of posterior MCMC samples to draw. Default: 3000.

burn_in

Number of burn-in MCMC iterations to discard before collecting posterior samples. Default: 5000.

x_values

Treatment value(s) at which the dose-response function is estimated. Examples: 1.0 or [1.0, 2.0].

q_sd

Standard deviation for the proposal distribution used in Metropolis-Hastings (MH) sampling. Default: 1.0.

sample_y

Whether to consider the variance function in the outcome generative model. Default: True.

bs

Batch size at the inference stage, i.e., the number of subjects processed per prediction batch. Default: 10000.
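The n_mcmc, burn_in and q_sd parameters refer to Metropolis-Hastings sampling of the latent variables. To show how they interact, here is a minimal random-walk Metropolis-Hastings sampler for a one-dimensional target density. This is a generic sketch, not bayesgm's implementation (which samples the multi-dimensional latent Z); the function `random_walk_mh` and the standard-normal target `log_std_normal` are hypothetical.

```python
import numpy as np


def random_walk_mh(log_target, z0, n_mcmc=3000, burn_in=5000, q_sd=1.0, seed=0):
    """Generic random-walk Metropolis-Hastings sketch (not the bayesgm sampler).

    Runs burn_in + n_mcmc iterations, discards the first burn_in states,
    and returns the remaining n_mcmc states plus the overall acceptance rate.
    """
    rng = np.random.default_rng(seed)
    z = float(z0)
    log_p = log_target(z)
    draws = []
    n_accept = 0
    for t in range(burn_in + n_mcmc):
        # Gaussian proposal centred at the current state; q_sd sets the step size.
        z_prop = z + q_sd * rng.standard_normal()
        log_p_prop = log_target(z_prop)
        # The proposal is symmetric, so the acceptance ratio is just the target ratio.
        if np.log(rng.uniform()) < log_p_prop - log_p:
            z, log_p = z_prop, log_p_prop
            n_accept += 1
        # Keep only post-burn-in states.
        if t >= burn_in:
            draws.append(z)
    return np.array(draws), n_accept / (burn_in + n_mcmc)


def log_std_normal(z):
    """Hypothetical target: standard normal log-density, up to a constant."""
    return -0.5 * z ** 2


for q_sd in (0.5, 1.0, 5.0):
    draws, acc = random_walk_mh(log_std_normal, z0=0.0, q_sd=q_sd)
    print(f"q_sd={q_sd}: acceptance rate={acc:.3f}, mean of kept draws={draws.mean():.3f}")
```

In this sketch a small q_sd is accepted often but moves slowly, and a large q_sd proposes bigger jumps that are rejected more often. By the same logic, a low acceptance rate such as the 0.0948 reported below might suggest trying a smaller q_sd. This tutorial does not document a suitable target rate for bayesgm's sampler.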

Return

Type

Description

Shape

pre_adrf_mean

np.ndarray

Point estimates of the Average Dose-Response Function.

(len(x_values),)

pre_adrf_PI

np.ndarray

Posterior intervals for the ADRF, one [lower bound, upper bound] pair per treatment value.

(len(x_values), 2)
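pre_adrf_PI gives a lower and an upper bound for each treatment value. A common way to build such an interval from posterior draws at significance level alpha is the equal-tailed interval between the alpha/2 and 1 - alpha/2 quantiles. With the default alpha=0.01, that means the 0.5% and 99.5% quantiles, which gives a 99% interval. The sketch below assumes this construction and uses simulated draws. Whether bayesgm uses equal-tailed quantiles internally is an assumption, and `summarize_draws` is a hypothetical helper.

```python
import numpy as np


def summarize_draws(draws, alpha=0.01):
    """Posterior mean and equal-tailed (1 - alpha) interval per treatment value.

    draws: array of shape (n_mcmc, len(x_values)) with posterior draws of the ADRF.
    Returns (mean, PI) with shapes (len(x_values),) and (len(x_values), 2).
    """
    mean = draws.mean(axis=0)
    lower = np.quantile(draws, alpha / 2, axis=0)
    upper = np.quantile(draws, 1 - alpha / 2, axis=0)
    return mean, np.stack([lower, upper], axis=1)


# Simulated draws for 20 treatment values (stand-ins for real MCMC output).
rng = np.random.default_rng(0)
x_values = np.linspace(0, 3, 20)
draws = x_values + 0.1 * rng.standard_normal((3000, 20))

mean, pi = summarize_draws(draws, alpha=0.01)
print(mean.shape, pi.shape)  # (20,) (20, 2)
```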

[6]:
pre_adrf_mean, pre_adrf_PI = model.predict(data=(x,y,v), alpha=0.01, n_mcmc=3000, burn_in=5000, x_values=np.linspace(0,3,20), q_sd=1.0, bs=20000)
MCMC Latent Variable Sampling ...
Final MCMC Acceptance Rate: 0.0948

Visualizing and evaluating the results

Show the true average dose-response function (ADRF) and the predicted ADRF.
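The evaluation cell below scores the predicted ADRF against the truth on the J = 20 treatment values in x_values. Here μ(x_j) denotes the true ADRF (true_adrf) and μ̂(x_j) the estimate (pre_adrf_mean). This notation is introduced only for this note. The two metrics are:

```latex
\mathrm{RMSE} = \sqrt{\frac{1}{J}\sum_{j=1}^{J}\bigl(\mu(x_j) - \hat{\mu}(x_j)\bigr)^2},
\qquad
\mathrm{MAPE} = \frac{1}{J}\sum_{j=1}^{J}\left|\frac{\mu(x_j) - \hat{\mu}(x_j)}{\mu(x_j)}\right|
```

The code reports MAPE as a fraction rather than a percentage, so the printed 0.0103 corresponds to roughly 1%. Because MAPE divides by μ(x_j), it is only informative when the true ADRF stays away from zero on the grid.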

[7]:
import matplotlib.pyplot as plt
from bayesgm.utils import get_ADRF

x_values = np.linspace(0, 3, 20)  # treatment grid; must match the x_values passed to model.predict
true_adrf = get_ADRF(x_values=x_values, dataset='Imbens')  # true ADRF for the Hirano and Imbens dataset

# Evaluate the point estimates against the truth
rmse = np.sqrt(np.mean((true_adrf - pre_adrf_mean)**2))
mape = np.mean(np.abs((true_adrf - pre_adrf_mean) / true_adrf))

print(f"RMSE (Root Mean Squared Error): {rmse:.4f}")
print(f"MAPE (Mean Absolute Percentage Error): {mape:.4f}")

# Create the plot
plt.figure(figsize=(10, 6))
# Plot the ground truth curve
plt.plot(x_values, true_adrf, label='True ADRF', color='blue', linestyle='--', linewidth=2)
# Plot the predicted mean curve
plt.plot(x_values, pre_adrf_mean, label='Predicted ADRF', color='red', linewidth=2)
# Plot the posterior intervals
plt.fill_between(x_values, pre_adrf_PI[:, 0], pre_adrf_PI[:, 1], color='red', alpha=0.4, label='Posterior Interval')

# Add labels, legend, and title
plt.xlabel('Treatment Value', fontsize=12)
plt.ylabel('ADRF', fontsize=12)
plt.title('True vs Predicted ADRF (RMSE=%.4f,MAPE=%.4f)'%(rmse, mape), fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)

# Show the plot
plt.show()
RMSE (Root Mean Squared Error): 0.0188
MAPE (Mean Absolute Percentage Error): 0.0103
Figure: true ADRF (dashed blue line) vs. predicted ADRF (red line), with the posterior interval shaded in red.
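RMSE and MAPE score only the point estimate. A complementary check on pre_adrf_PI is the fraction of grid points where the true ADRF lies inside the posterior interval. The sketch below computes this with a hypothetical helper, `interval_coverage`. With the variables from the cell above, you could call it as `interval_coverage(true_adrf, pre_adrf_PI)`. The toy arrays here only illustrate the computation. On a single dataset this is a rough diagnostic, not a calibrated coverage estimate.

```python
import numpy as np


def interval_coverage(truth, pi):
    """Fraction of points where truth lies within [pi[:, 0], pi[:, 1]] (inclusive)."""
    truth = np.asarray(truth).ravel()
    pi = np.asarray(pi)
    inside = (pi[:, 0] <= truth) & (truth <= pi[:, 1])
    return inside.mean()


# Toy example: the third point (3.5) falls above its interval [2.9, 3.1].
truth = np.array([1.0, 2.0, 3.5, 4.0])
pi = np.array([[0.9, 1.1], [1.8, 2.2], [2.9, 3.1], [3.5, 4.5]])
print(interval_coverage(truth, pi))  # 0.75
```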

Use CausalBGM Python API in binary treatment setting

We will use the ACIC 2018 dataset as an example.

Reminder

  • Set binary_treatment to True in the config file.

  • Make sure v_dim matches the number of covariates in the dataset.
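Both reminders can be checked in code before building the model. The helper below is a hypothetical sketch, not part of bayesgm. It takes the loaded config dict and the treatment and covariate arrays. Checking that the treatment has exactly two distinct values is an extra assumption about what "binary" input should look like.

```python
import numpy as np


def check_binary_config(params, x, v):
    """Hypothetical pre-flight check for the binary-treatment setting.

    Returns a list of problems; an empty list means all checks passed.
    """
    problems = []
    if not params.get("binary_treatment", False):
        problems.append("binary_treatment should be True")
    if params.get("v_dim") != v.shape[1]:
        problems.append(f"v_dim={params.get('v_dim')} but data has {v.shape[1]} covariates")
    n_levels = np.unique(x).size
    if n_levels != 2:
        problems.append(f"expected 2 distinct treatment values, found {n_levels}")
    return problems


# Toy arrays shaped like the ACIC data loaded below: x (n, 1), v (n, 177).
rng = np.random.default_rng(0)
x_toy = rng.integers(0, 2, size=(1000, 1))
v_toy = rng.standard_normal((1000, 177))
print(check_binary_config({"binary_treatment": True, "v_dim": 177}, x_toy, v_toy))  # []
print(check_binary_config({"binary_treatment": False, "v_dim": 200}, x_toy, v_toy))
```

Once params and the data are loaded, check_binary_config(params, x, v) should return an empty list.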

[8]:
with open('src/configs/Semi_acic.yaml', 'r') as f:
    params = yaml.safe_load(f)
print(params)
{'dataset': 'Semi_acic', 'output_dir': '.', 'save_res': True, 'save_model': True, 'binary_treatment': True, 'use_bnn': True, 'z_dims': [3, 6, 3, 6], 'v_dim': 177, 'lr_theta': 0.0001, 'lr_z': 0.0001, 'g_units': [64, 64, 64, 64, 64], 'f_units': [64, 32, 8], 'h_units': [64, 32, 8], 'kl_weight': 0.0001, 'lr': 0.0002, 'g_d_freq': 5, 'use_z_rec': True, 'e_units': [64, 64, 64, 64, 64], 'dz_units': [64, 32, 8]}

Instantiate a CausalBGM model

[9]:
model = bayesgm.models.CausalBGM(params=params,random_seed=None)
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  loc = add_variable_fn(
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  untransformed_scale = add_variable_fn(

Data preparation

The input data are organized as a triplet (x, y, v) containing the treatment (X), the observed outcome (Y), and the covariates (V).
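If your own data are stored in a table rather than loaded through a bayesgm dataset sampler, you can assemble the triplet with pandas and NumPy. The helper `dataframe_to_triplet` and the column names `treat` and `outcome` below are hypothetical placeholders. The output shapes mirror the ones printed for the ACIC data: x and y as (n, 1) columns and v as an (n, p) matrix. Converting to float32 is a choice made here (common for TensorFlow models), not a documented bayesgm requirement.

```python
import numpy as np
import pandas as pd


def dataframe_to_triplet(df, treatment_col, outcome_col, covariate_cols=None):
    """Split a DataFrame into (x, y, v) arrays of shape (n, 1), (n, 1), (n, p).

    If covariate_cols is None, every column other than the treatment and
    outcome columns is treated as a covariate.
    """
    if covariate_cols is None:
        covariate_cols = [c for c in df.columns if c not in (treatment_col, outcome_col)]
    # Double brackets keep a 2-D (n, 1) shape for x and y.
    x = df[[treatment_col]].to_numpy(dtype="float32")
    y = df[[outcome_col]].to_numpy(dtype="float32")
    v = df[covariate_cols].to_numpy(dtype="float32")
    return x, y, v


# Toy table with hypothetical column names: 5 rows, 3 covariates.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((5, 3)), columns=["v1", "v2", "v3"])
df["treat"] = [0, 1, 1, 0, 1]
df["outcome"] = rng.standard_normal(5)
x, y, v = dataframe_to_triplet(df, "treat", "outcome")
print(x.shape, y.shape, v.shape)  # (5, 1) (5, 1) (5, 3)
```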

[10]:
# Load the ACIC 2018 competition data for the dataset identified by the given ufid.
x,y,v = bayesgm.datasets.Semi_acic_sampler(path='data/ACIC_2018',ufid='629e3d2c63914e45b227cc913c09cebe').load_all()
print(x.shape,y.shape,v.shape)
(1000, 1) (1000, 1) (1000, 177)

Model training

Train CausalBGM with an optional EGM warm-start.

Config Parameter

Description

data

Tuple of data inputs (x, y, v), Required.

batch_size

Batch size for training. Default: 32.

epochs

Number of epochs for training. Default: 100.

epochs_per_eval

Frequency of evaluations during training (e.g., every 5 epochs). Default: 5.

use_egm_init

Whether to run EGM initialization before iterative training. Default: True.

egm_n_iter

Number of EGM initialization iterations. Default: 30000.

egm_batches_per_eval

Number of iterations between evaluations during EGM initialization. Default: 500.

verbose

Controls verbosity level, showing progress and evaluation metrics. Default: 1.
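A few quick checks connect these parameters to the logs. The progress bars for the 1,000-sample ACIC data show 32 batches per epoch. This is consistent with ceil(n / batch_size) batches at the default batch_size of 32, assuming the final partial batch is kept. The evaluation lines appear at epochs 0, 10, 20, ... with epochs_per_eval=10, and the EGM log prints every 500 iterations from 0 to 30000. The helper names below are hypothetical, and the schedules are inferred from the printed logs rather than from bayesgm's source.

```python
import math


def batches_per_epoch(n_samples, batch_size=32):
    """Number of minibatches per epoch, counting a final partial batch."""
    return math.ceil(n_samples / batch_size)


def logged_steps(last_step, every):
    """Steps with a log line, assuming logging at step 0 and then every `every` steps."""
    return list(range(0, last_step + 1, every))


print(batches_per_epoch(1000))       # 32, matching the "32/32" progress bars
print(logged_steps(100, 10))         # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
print(len(logged_steps(30000, 500)))  # 61 EGM log lines
```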

[11]:
model.fit(data=(x,y,v), epochs=100, epochs_per_eval=10, use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500, verbose=1)
EGM Initialization Starts ...
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:95: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  loc = add_variable_fn(
/home/ql339/.conda/envs/py3.9/lib/python3.9/site-packages/tensorflow_probability/python/layers/util.py:105: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use the `layer.add_weight()` method instead.
  untransformed_scale = add_variable_fn(
EGM Initialization Iter [0] : e_loss_adv [0.1102], l2_loss_v [0.8648], l2_loss_z [1.0305], l2_loss_x [0.6933], l2_loss_y [0.5305], g_e_loss [3.2293], dz_loss [0.0733], d_loss [0.3746]
EGM Initialization Iter [500] : e_loss_adv [2.2379], l2_loss_v [1.0616], l2_loss_z [0.9627], l2_loss_x [0.6788], l2_loss_y [0.1069], g_e_loss [5.0480], dz_loss [-1.7182], d_loss [-1.2625]
EGM Initialization Iter [1000] : e_loss_adv [-0.0245], l2_loss_v [1.4840], l2_loss_z [1.0821], l2_loss_x [0.6829], l2_loss_y [0.0910], g_e_loss [3.3154], dz_loss [-1.3309], d_loss [-1.2576]
EGM Initialization Iter [1500] : e_loss_adv [-1.6068], l2_loss_v [0.4169], l2_loss_z [1.0244], l2_loss_x [0.6923], l2_loss_y [0.0486], g_e_loss [0.5754], dz_loss [-1.0655], d_loss [-0.9715]
EGM Initialization Iter [2000] : e_loss_adv [-1.6149], l2_loss_v [0.6514], l2_loss_z [0.8222], l2_loss_x [0.6996], l2_loss_y [0.1217], g_e_loss [0.6800], dz_loss [-0.8808], d_loss [-0.6829]
EGM Initialization Iter [2500] : e_loss_adv [-1.7912], l2_loss_v [1.6250], l2_loss_z [1.0236], l2_loss_x [0.6911], l2_loss_y [0.1043], g_e_loss [1.6528], dz_loss [-0.5761], d_loss [-0.5062]
EGM Initialization Iter [3000] : e_loss_adv [-2.6811], l2_loss_v [0.6036], l2_loss_z [0.8955], l2_loss_x [0.6970], l2_loss_y [0.1193], g_e_loss [-0.3656], dz_loss [-1.3826], d_loss [-1.1942]
EGM Initialization Iter [3500] : e_loss_adv [-2.8498], l2_loss_v [0.6821], l2_loss_z [0.8155], l2_loss_x [0.6820], l2_loss_y [0.0911], g_e_loss [-0.5792], dz_loss [-0.2488], d_loss [-0.1469]
EGM Initialization Iter [4000] : e_loss_adv [-2.9439], l2_loss_v [0.4767], l2_loss_z [0.9369], l2_loss_x [0.6887], l2_loss_y [0.0918], g_e_loss [-0.7499], dz_loss [-0.6104], d_loss [-0.4594]
EGM Initialization Iter [4500] : e_loss_adv [-2.6417], l2_loss_v [0.7067], l2_loss_z [0.9116], l2_loss_x [0.7060], l2_loss_y [0.1264], g_e_loss [-0.1910], dz_loss [-1.3566], d_loss [-1.0913]
EGM Initialization Iter [5000] : e_loss_adv [-1.0904], l2_loss_v [0.9057], l2_loss_z [0.9335], l2_loss_x [0.7084], l2_loss_y [0.1196], g_e_loss [1.5769], dz_loss [-0.4903], d_loss [-0.3944]
EGM Initialization Iter [5500] : e_loss_adv [-0.8474], l2_loss_v [0.6684], l2_loss_z [0.8527], l2_loss_x [0.6900], l2_loss_y [0.0964], g_e_loss [1.4600], dz_loss [-0.6281], d_loss [-0.5299]
EGM Initialization Iter [6000] : e_loss_adv [-0.8948], l2_loss_v [0.7359], l2_loss_z [0.9014], l2_loss_x [0.6886], l2_loss_y [0.0662], g_e_loss [1.4973], dz_loss [-0.3875], d_loss [-0.2597]
EGM Initialization Iter [6500] : e_loss_adv [-1.1519], l2_loss_v [0.4320], l2_loss_z [0.9488], l2_loss_x [0.6827], l2_loss_y [0.1137], g_e_loss [1.0252], dz_loss [-0.4316], d_loss [-0.3879]
EGM Initialization Iter [7000] : e_loss_adv [-0.1190], l2_loss_v [0.5975], l2_loss_z [0.9308], l2_loss_x [0.6818], l2_loss_y [0.1010], g_e_loss [2.1922], dz_loss [-0.4173], d_loss [-0.3681]
EGM Initialization Iter [7500] : e_loss_adv [-0.0798], l2_loss_v [0.6064], l2_loss_z [0.9438], l2_loss_x [0.6962], l2_loss_y [0.0829], g_e_loss [2.2496], dz_loss [-0.8070], d_loss [-0.7731]
EGM Initialization Iter [8000] : e_loss_adv [2.0215], l2_loss_v [0.4414], l2_loss_z [1.0132], l2_loss_x [0.6978], l2_loss_y [0.0823], g_e_loss [4.2563], dz_loss [-0.0485], d_loss [0.0026]
EGM Initialization Iter [8500] : e_loss_adv [1.7494], l2_loss_v [0.5732], l2_loss_z [0.9103], l2_loss_x [0.6829], l2_loss_y [0.0619], g_e_loss [3.9777], dz_loss [-0.5275], d_loss [-0.4650]
EGM Initialization Iter [9000] : e_loss_adv [-0.1547], l2_loss_v [0.5875], l2_loss_z [0.9173], l2_loss_x [0.6891], l2_loss_y [0.0647], g_e_loss [2.1039], dz_loss [-0.6923], d_loss [-0.6058]
EGM Initialization Iter [9500] : e_loss_adv [-1.8302], l2_loss_v [0.5040], l2_loss_z [0.9089], l2_loss_x [0.6838], l2_loss_y [0.0822], g_e_loss [0.3487], dz_loss [-0.4017], d_loss [-0.3120]
EGM Initialization Iter [10000] : e_loss_adv [-1.1754], l2_loss_v [0.4594], l2_loss_z [0.8607], l2_loss_x [0.6830], l2_loss_y [0.0497], g_e_loss [0.8774], dz_loss [-0.7414], d_loss [-0.6600]
EGM Initialization Iter [10500] : e_loss_adv [0.0125], l2_loss_v [0.4014], l2_loss_z [0.8827], l2_loss_x [0.6823], l2_loss_y [0.0821], g_e_loss [2.0610], dz_loss [-0.4633], d_loss [-0.4381]
EGM Initialization Iter [11000] : e_loss_adv [1.2771], l2_loss_v [0.8815], l2_loss_z [0.7977], l2_loss_x [0.6933], l2_loss_y [0.1165], g_e_loss [3.7661], dz_loss [-0.5535], d_loss [-0.4989]
EGM Initialization Iter [11500] : e_loss_adv [0.8197], l2_loss_v [0.6996], l2_loss_z [0.9713], l2_loss_x [0.6715], l2_loss_y [0.0668], g_e_loss [3.2289], dz_loss [-0.5401], d_loss [-0.4833]
EGM Initialization Iter [12000] : e_loss_adv [1.1846], l2_loss_v [0.4105], l2_loss_z [0.8513], l2_loss_x [0.6902], l2_loss_y [0.0928], g_e_loss [3.2294], dz_loss [-0.5215], d_loss [-0.4562]
EGM Initialization Iter [12500] : e_loss_adv [0.5591], l2_loss_v [0.5461], l2_loss_z [0.8962], l2_loss_x [0.6943], l2_loss_y [0.1157], g_e_loss [2.8115], dz_loss [-0.2863], d_loss [-0.2228]
EGM Initialization Iter [13000] : e_loss_adv [0.5567], l2_loss_v [0.5537], l2_loss_z [0.9396], l2_loss_x [0.6648], l2_loss_y [0.0530], g_e_loss [2.7678], dz_loss [-0.8052], d_loss [-0.7681]
EGM Initialization Iter [13500] : e_loss_adv [-0.4126], l2_loss_v [0.6628], l2_loss_z [0.8893], l2_loss_x [0.7001], l2_loss_y [0.0652], g_e_loss [1.9049], dz_loss [0.1546], d_loss [0.2192]
EGM Initialization Iter [14000] : e_loss_adv [1.3701], l2_loss_v [0.4984], l2_loss_z [0.7432], l2_loss_x [0.6523], l2_loss_y [0.0843], g_e_loss [3.3483], dz_loss [-0.1833], d_loss [-0.1407]
EGM Initialization Iter [14500] : e_loss_adv [0.3630], l2_loss_v [0.5708], l2_loss_z [0.8929], l2_loss_x [0.7033], l2_loss_y [0.0717], g_e_loss [2.6017], dz_loss [-0.1090], d_loss [-0.0825]
EGM Initialization Iter [15000] : e_loss_adv [1.1808], l2_loss_v [0.4881], l2_loss_z [0.7314], l2_loss_x [0.6782], l2_loss_y [0.1120], g_e_loss [3.1905], dz_loss [-0.4781], d_loss [-0.4294]
EGM Initialization Iter [15500] : e_loss_adv [0.7721], l2_loss_v [0.5489], l2_loss_z [0.7299], l2_loss_x [0.7268], l2_loss_y [0.1136], g_e_loss [2.8913], dz_loss [-0.4289], d_loss [-0.3480]
EGM Initialization Iter [16000] : e_loss_adv [0.5866], l2_loss_v [0.3066], l2_loss_z [0.8738], l2_loss_x [0.6267], l2_loss_y [0.0903], g_e_loss [2.4841], dz_loss [-0.5187], d_loss [-0.4643]
EGM Initialization Iter [16500] : e_loss_adv [2.6886], l2_loss_v [0.6556], l2_loss_z [0.9029], l2_loss_x [0.6635], l2_loss_y [0.0617], g_e_loss [4.9723], dz_loss [-0.0725], d_loss [-0.0014]
EGM Initialization Iter [17000] : e_loss_adv [0.6549], l2_loss_v [0.7038], l2_loss_z [0.8277], l2_loss_x [0.6232], l2_loss_y [0.0827], g_e_loss [2.8924], dz_loss [-0.2632], d_loss [-0.1660]
EGM Initialization Iter [17500] : e_loss_adv [-0.3580], l2_loss_v [0.5062], l2_loss_z [0.8819], l2_loss_x [0.6791], l2_loss_y [0.0736], g_e_loss [1.7828], dz_loss [-0.1837], d_loss [-0.0903]
EGM Initialization Iter [18000] : e_loss_adv [-0.6044], l2_loss_v [0.4451], l2_loss_z [0.8348], l2_loss_x [0.7131], l2_loss_y [0.0517], g_e_loss [1.4402], dz_loss [-0.3551], d_loss [-0.2829]
EGM Initialization Iter [18500] : e_loss_adv [-0.0970], l2_loss_v [0.3223], l2_loss_z [0.7733], l2_loss_x [0.6027], l2_loss_y [0.0677], g_e_loss [1.6691], dz_loss [-0.4478], d_loss [-0.3706]
EGM Initialization Iter [19000] : e_loss_adv [-0.3036], l2_loss_v [0.4483], l2_loss_z [0.9154], l2_loss_x [0.7860], l2_loss_y [0.0749], g_e_loss [1.9210], dz_loss [-0.0080], d_loss [0.0634]
EGM Initialization Iter [19500] : e_loss_adv [1.4144], l2_loss_v [0.4755], l2_loss_z [0.8437], l2_loss_x [0.5943], l2_loss_y [0.1192], g_e_loss [3.4471], dz_loss [-0.3931], d_loss [-0.3437]
EGM Initialization Iter [20000] : e_loss_adv [1.5499], l2_loss_v [0.5318], l2_loss_z [0.7672], l2_loss_x [0.6411], l2_loss_y [0.0971], g_e_loss [3.5870], dz_loss [-0.3918], d_loss [-0.3298]
EGM Initialization Iter [20500] : e_loss_adv [0.5439], l2_loss_v [0.4782], l2_loss_z [0.8471], l2_loss_x [0.6008], l2_loss_y [0.0701], g_e_loss [2.5401], dz_loss [-0.2296], d_loss [-0.1909]
EGM Initialization Iter [21000] : e_loss_adv [0.5634], l2_loss_v [0.3514], l2_loss_z [0.8784], l2_loss_x [0.6237], l2_loss_y [0.0578], g_e_loss [2.4749], dz_loss [-0.0582], d_loss [-0.0211]
EGM Initialization Iter [21500] : e_loss_adv [0.1586], l2_loss_v [1.0522], l2_loss_z [0.8501], l2_loss_x [0.5285], l2_loss_y [0.0775], g_e_loss [2.6669], dz_loss [-0.4601], d_loss [-0.3838]
EGM Initialization Iter [22000] : e_loss_adv [0.4936], l2_loss_v [0.4876], l2_loss_z [0.8209], l2_loss_x [0.5598], l2_loss_y [0.0686], g_e_loss [2.4305], dz_loss [-0.5799], d_loss [-0.5045]
EGM Initialization Iter [22500] : e_loss_adv [0.5633], l2_loss_v [0.8955], l2_loss_z [0.9021], l2_loss_x [0.4840], l2_loss_y [0.0793], g_e_loss [2.9243], dz_loss [0.0632], d_loss [0.1201]
EGM Initialization Iter [23000] : e_loss_adv [-0.6131], l2_loss_v [1.0310], l2_loss_z [0.8269], l2_loss_x [0.5019], l2_loss_y [0.0767], g_e_loss [1.8233], dz_loss [-0.4884], d_loss [-0.4198]
EGM Initialization Iter [23500] : e_loss_adv [0.2808], l2_loss_v [0.3826], l2_loss_z [0.7574], l2_loss_x [0.6134], l2_loss_y [0.0854], g_e_loss [2.1196], dz_loss [0.0413], d_loss [0.1513]
EGM Initialization Iter [24000] : e_loss_adv [0.7267], l2_loss_v [0.4576], l2_loss_z [0.7689], l2_loss_x [0.5064], l2_loss_y [0.1214], g_e_loss [2.5811], dz_loss [-0.6373], d_loss [-0.5943]
EGM Initialization Iter [24500] : e_loss_adv [1.3909], l2_loss_v [0.5250], l2_loss_z [0.8024], l2_loss_x [0.4561], l2_loss_y [0.0664], g_e_loss [3.2408], dz_loss [-0.1226], d_loss [-0.0769]
EGM Initialization Iter [25000] : e_loss_adv [-0.9248], l2_loss_v [0.4964], l2_loss_z [0.8490], l2_loss_x [0.5210], l2_loss_y [0.1215], g_e_loss [1.0632], dz_loss [0.2103], d_loss [0.2916]
EGM Initialization Iter [25500] : e_loss_adv [-0.7704], l2_loss_v [0.5056], l2_loss_z [0.7879], l2_loss_x [0.3088], l2_loss_y [0.0530], g_e_loss [0.8849], dz_loss [0.2377], d_loss [0.3105]
EGM Initialization Iter [26000] : e_loss_adv [-0.3391], l2_loss_v [0.4905], l2_loss_z [0.8404], l2_loss_x [0.4700], l2_loss_y [0.0738], g_e_loss [1.5356], dz_loss [-0.0343], d_loss [0.0082]
EGM Initialization Iter [26500] : e_loss_adv [-0.4712], l2_loss_v [0.3574], l2_loss_z [0.7607], l2_loss_x [0.4296], l2_loss_y [0.0830], g_e_loss [1.1594], dz_loss [-0.2168], d_loss [-0.1930]
EGM Initialization Iter [27000] : e_loss_adv [0.7430], l2_loss_v [0.5136], l2_loss_z [0.8916], l2_loss_x [0.4302], l2_loss_y [0.0700], g_e_loss [2.6484], dz_loss [-0.6683], d_loss [-0.5880]
EGM Initialization Iter [27500] : e_loss_adv [0.3668], l2_loss_v [0.4964], l2_loss_z [0.9325], l2_loss_x [0.3674], l2_loss_y [0.0685], g_e_loss [2.2315], dz_loss [0.1753], d_loss [0.2039]
EGM Initialization Iter [28000] : e_loss_adv [-0.9352], l2_loss_v [0.5420], l2_loss_z [0.8413], l2_loss_x [0.4752], l2_loss_y [0.0727], g_e_loss [0.9959], dz_loss [-0.3604], d_loss [-0.2736]
EGM Initialization Iter [28500] : e_loss_adv [-0.2398], l2_loss_v [0.4868], l2_loss_z [0.7556], l2_loss_x [0.4528], l2_loss_y [0.0851], g_e_loss [1.5405], dz_loss [-0.3888], d_loss [-0.2976]
EGM Initialization Iter [29000] : e_loss_adv [-0.6478], l2_loss_v [0.5740], l2_loss_z [0.8240], l2_loss_x [0.3217], l2_loss_y [0.0874], g_e_loss [1.1594], dz_loss [-0.1161], d_loss [-0.0649]
EGM Initialization Iter [29500] : e_loss_adv [-0.2740], l2_loss_v [0.5155], l2_loss_z [0.6885], l2_loss_x [0.2661], l2_loss_y [0.0785], g_e_loss [1.2745], dz_loss [-0.3944], d_loss [-0.3385]
EGM Initialization Iter [30000] : e_loss_adv [-0.7436], l2_loss_v [0.5136], l2_loss_z [0.8334], l2_loss_x [0.1789], l2_loss_y [0.0681], g_e_loss [0.8504], dz_loss [-0.0883], d_loss [-0.0273]
EGM Initialization Ends.
Initialize latent variables Z with e(V)...
Iterative Updating Starts ...
Epoch 0/100: 100%|██████████| 32/32 [00:10<00:00,  3.09batch/s, loss_px_z: [0.9413], loss_mse_x: [0.2514], loss_py_z: [0.6371], loss_mse_y: [0.0882], loss_pv_z: [22.3034], loss_mse_v: [0.3698], loss_postrior_z: [30.3454]]
Epoch [0/100]: MSE_x: 0.0967, MSE_y: 0.0816, MSE_v: 0.5076

Saving checkpoint for epoch 0 at ./checkpoints/Semi_acic/20260311_112445/ckpt-0
Epoch 1/100: 100%|██████████| 32/32 [00:00<00:00, 52.06batch/s, loss_px_z: [1.0814], loss_mse_x: [0.3921], loss_py_z: [0.6501], loss_mse_y: [0.1115], loss_pv_z: [59.3399], loss_mse_v: [0.6598], loss_postrior_z: [65.1993]]
Epoch 2/100: 100%|██████████| 32/32 [00:00<00:00, 52.19batch/s, loss_px_z: [1.1222], loss_mse_x: [0.4334], loss_py_z: [0.5929], loss_mse_y: [0.0376], loss_pv_z: [51.3154], loss_mse_v: [0.6012], loss_postrior_z: [56.4026]]
Epoch 3/100: 100%|██████████| 32/32 [00:00<00:00, 52.17batch/s, loss_px_z: [1.0116], loss_mse_x: [0.3234], loss_py_z: [0.5930], loss_mse_y: [0.0414], loss_pv_z: [24.1223], loss_mse_v: [0.3872], loss_postrior_z: [28.4659]]
Epoch 4/100: 100%|██████████| 32/32 [00:00<00:00, 52.53batch/s, loss_px_z: [0.8006], loss_mse_x: [0.1130], loss_py_z: [0.5958], loss_mse_y: [0.0510], loss_pv_z: [20.5702], loss_mse_v: [0.3585], loss_postrior_z: [34.4404]]
Epoch 5/100: 100%|██████████| 32/32 [00:00<00:00, 51.60batch/s, loss_px_z: [1.1752], loss_mse_x: [0.4883], loss_py_z: [0.6078], loss_mse_y: [0.0713], loss_pv_z: [18.7218], loss_mse_v: [0.3485], loss_postrior_z: [24.4937]]
Epoch 6/100: 100%|██████████| 32/32 [00:00<00:00, 50.58batch/s, loss_px_z: [0.8292], loss_mse_x: [0.1428], loss_py_z: [0.7203], loss_mse_y: [0.2263], loss_pv_z: [15.9238], loss_mse_v: [0.3311], loss_postrior_z: [22.7411]]
Epoch 7/100: 100%|██████████| 32/32 [00:00<00:00, 50.07batch/s, loss_px_z: [1.1182], loss_mse_x: [0.4325], loss_py_z: [0.6544], loss_mse_y: [0.1425], loss_pv_z: [49.3744], loss_mse_v: [0.5830], loss_postrior_z: [58.7593]]
Epoch 8/100: 100%|██████████| 32/32 [00:00<00:00, 53.07batch/s, loss_px_z: [0.9144], loss_mse_x: [0.2292], loss_py_z: [0.5907], loss_mse_y: [0.0616], loss_pv_z: [48.1219], loss_mse_v: [0.5634], loss_postrior_z: [59.8836]]
Epoch 9/100: 100%|██████████| 32/32 [00:00<00:00, 54.61batch/s, loss_px_z: [1.2536], loss_mse_x: [0.5691], loss_py_z: [0.5661], loss_mse_y: [0.0364], loss_pv_z: [10.2657], loss_mse_v: [0.2887], loss_postrior_z: [17.2223]]
Epoch 10/100: 100%|██████████| 32/32 [00:00<00:00, 55.06batch/s, loss_px_z: [0.9793], loss_mse_x: [0.2953], loss_py_z: [0.5569], loss_mse_y: [0.0274], loss_pv_z: [37.2485], loss_mse_v: [0.4839], loss_postrior_z: [44.8024]]
Epoch [10/100]: MSE_x: 0.0956, MSE_y: 0.0811, MSE_v: 0.4902

Saving checkpoint for epoch 10 at ./checkpoints/Semi_acic/20260311_112445/ckpt-10
Epoch 11/100: 100%|██████████| 32/32 [00:00<00:00, 54.95batch/s, loss_px_z: [1.1522], loss_mse_x: [0.4688], loss_py_z: [0.6065], loss_mse_y: [0.0960], loss_pv_z: [2.2391], loss_mse_v: [0.2359], loss_postrior_z: [11.8110]]
Epoch 12/100: 100%|██████████| 32/32 [00:00<00:00, 55.30batch/s, loss_px_z: [0.9200], loss_mse_x: [0.2372], loss_py_z: [0.5623], loss_mse_y: [0.0459], loss_pv_z: [20.5715], loss_mse_v: [0.3698], loss_postrior_z: [22.2090]]
Epoch 13/100: 100%|██████████| 32/32 [00:00<00:00, 59.66batch/s, loss_px_z: [1.0428], loss_mse_x: [0.3606], loss_py_z: [0.5959], loss_mse_y: [0.0940], loss_pv_z: [27.0938], loss_mse_v: [0.4216], loss_postrior_z: [33.3546]]
Epoch 14/100: 100%|██████████| 32/32 [00:00<00:00, 59.17batch/s, loss_px_z: [1.3462], loss_mse_x: [0.6646], loss_py_z: [0.5644], loss_mse_y: [0.0568], loss_pv_z: [34.1856], loss_mse_v: [0.5718], loss_postrior_z: [40.3281]]
Epoch 15/100: 100%|██████████| 32/32 [00:00<00:00, 60.14batch/s, loss_px_z: [0.9129], loss_mse_x: [0.2318], loss_py_z: [0.6260], loss_mse_y: [0.1425], loss_pv_z: [57.6632], loss_mse_v: [0.6218], loss_postrior_z: [65.6941]]
Epoch 16/100: 100%|██████████| 32/32 [00:00<00:00, 60.19batch/s, loss_px_z: [1.4620], loss_mse_x: [0.7816], loss_py_z: [0.5453], loss_mse_y: [0.0444], loss_pv_z: [3.3863], loss_mse_v: [0.2562], loss_postrior_z: [10.9483]]
Epoch 17/100: 100%|██████████| 32/32 [00:00<00:00, 60.11batch/s, loss_px_z: [0.8324], loss_mse_x: [0.1526], loss_py_z: [0.5713], loss_mse_y: [0.0809], loss_pv_z: [31.0689], loss_mse_v: [0.4438], loss_postrior_z: [36.7993]]
Epoch 18/100: 100%|██████████| 32/32 [00:00<00:00, 60.18batch/s, loss_px_z: [0.9134], loss_mse_x: [0.2342], loss_py_z: [0.5573], loss_mse_y: [0.0685], loss_pv_z: [29.7325], loss_mse_v: [0.4428], loss_postrior_z: [37.8426]]
Epoch 19/100: 100%|██████████| 32/32 [00:00<00:00, 59.97batch/s, loss_px_z: [1.0431], loss_mse_x: [0.3645], loss_py_z: [0.6281], loss_mse_y: [0.1604], loss_pv_z: [21.2909], loss_mse_v: [0.3899], loss_postrior_z: [28.2614]]
Epoch 20/100: 100%|██████████| 32/32 [00:00<00:00, 60.11batch/s, loss_px_z: [0.9431], loss_mse_x: [0.2650], loss_py_z: [0.5509], loss_mse_y: [0.0707], loss_pv_z: [20.4603], loss_mse_v: [0.3859], loss_postrior_z: [23.0501]]
Epoch [20/100]: MSE_x: 0.0929, MSE_y: 0.0824, MSE_v: 0.4880

Epoch 21/100: 100%|██████████| 32/32 [00:00<00:00, 60.17batch/s, loss_px_z: [1.0347], loss_mse_x: [0.3573], loss_py_z: [0.6146], loss_mse_y: [0.1555], loss_pv_z: [29.8462], loss_mse_v: [0.4482], loss_postrior_z: [36.4209]]
Epoch 22/100: 100%|██████████| 32/32 [00:00<00:00, 59.91batch/s, loss_px_z: [0.9561], loss_mse_x: [0.2792], loss_py_z: [0.5554], loss_mse_y: [0.0882], loss_pv_z: [1.5015], loss_mse_v: [0.2573], loss_postrior_z: [7.7514]]
Epoch 23/100: 100%|██████████| 32/32 [00:00<00:00, 60.14batch/s, loss_px_z: [0.8235], loss_mse_x: [0.1472], loss_py_z: [0.5874], loss_mse_y: [0.1351], loss_pv_z: [0.5594], loss_mse_v: [0.2621], loss_postrior_z: [6.8897]]
Epoch 24/100: 100%|██████████| 32/32 [00:00<00:00, 60.10batch/s, loss_px_z: [0.7600], loss_mse_x: [0.0843], loss_py_z: [0.5146], loss_mse_y: [0.0505], loss_pv_z: [15.1470], loss_mse_v: [0.3561], loss_postrior_z: [22.3120]]
Epoch 25/100: 100%|██████████| 32/32 [00:00<00:00, 60.14batch/s, loss_px_z: [0.8172], loss_mse_x: [0.1421], loss_py_z: [0.4946], loss_mse_y: [0.0423], loss_pv_z: [35.5047], loss_mse_v: [0.5002], loss_postrior_z: [39.7496]]
Epoch 26/100: 100%|██████████| 32/32 [00:00<00:00, 59.76batch/s, loss_px_z: [0.8243], loss_mse_x: [0.1498], loss_py_z: [0.5031], loss_mse_y: [0.0492], loss_pv_z: [21.2283], loss_mse_v: [0.4085], loss_postrior_z: [29.4189]]
Epoch 27/100: 100%|██████████| 32/32 [00:00<00:00, 60.15batch/s, loss_px_z: [1.2291], loss_mse_x: [0.5552], loss_py_z: [0.5260], loss_mse_y: [0.0858], loss_pv_z: [32.0752], loss_mse_v: [0.4673], loss_postrior_z: [42.3491]]
Epoch 28/100: 100%|██████████| 32/32 [00:00<00:00, 60.16batch/s, loss_px_z: [0.9256], loss_mse_x: [0.2523], loss_py_z: [0.6113], loss_mse_y: [0.1872], loss_pv_z: [11.5990], loss_mse_v: [0.3499], loss_postrior_z: [15.3578]]
Epoch 29/100: 100%|██████████| 32/32 [00:00<00:00, 60.13batch/s, loss_px_z: [0.9034], loss_mse_x: [0.2306], loss_py_z: [0.5699], loss_mse_y: [0.1462], loss_pv_z: [14.3208], loss_mse_v: [0.3664], loss_postrior_z: [21.1504]]
Epoch 30/100: 100%|██████████| 32/32 [00:00<00:00, 59.69batch/s, loss_px_z: [1.3914], loss_mse_x: [0.7193], loss_py_z: [0.4828], loss_mse_y: [0.0625], loss_pv_z: [-7.3091], loss_mse_v: [0.2280], loss_postrior_z: [-2.4749]]
Epoch [30/100]: MSE_x: 0.0939, MSE_y: 0.0844, MSE_v: 0.4857

Epoch 31/100: 100%|██████████| 32/32 [00:00<00:00, 60.18batch/s, loss_px_z: [0.7766], loss_mse_x: [0.1051], loss_py_z: [0.4541], loss_mse_y: [0.0316], loss_pv_z: [65.6081], loss_mse_v: [1.0058], loss_postrior_z: [62.1184]]
Epoch 32/100: 100%|██████████| 32/32 [00:00<00:00, 58.97batch/s, loss_px_z: [1.2586], loss_mse_x: [0.5876], loss_py_z: [0.5416], loss_mse_y: [0.1410], loss_pv_z: [25.5976], loss_mse_v: [0.6107], loss_postrior_z: [31.8609]]
Epoch 33/100: 100%|██████████| 32/32 [00:00<00:00, 60.15batch/s, loss_px_z: [1.1108], loss_mse_x: [0.4405], loss_py_z: [0.4404], loss_mse_y: [0.0340], loss_pv_z: [9.5764], loss_mse_v: [0.3409], loss_postrior_z: [13.5380]]
Epoch 34/100: 100%|██████████| 32/32 [00:00<00:00, 59.43batch/s, loss_px_z: [1.0514], loss_mse_x: [0.3816], loss_py_z: [0.4840], loss_mse_y: [0.0906], loss_pv_z: [-7.3509], loss_mse_v: [0.2674], loss_postrior_z: [-4.4814]]
Epoch 35/100: 100%|██████████| 32/32 [00:00<00:00, 57.56batch/s, loss_px_z: [0.9433], loss_mse_x: [0.2741], loss_py_z: [0.4575], loss_mse_y: [0.0741], loss_pv_z: [29.3482], loss_mse_v: [0.4471], loss_postrior_z: [27.1960]]
Epoch 36/100: 100%|██████████| 32/32 [00:00<00:00, 52.69batch/s, loss_px_z: [0.9184], loss_mse_x: [0.2498], loss_py_z: [0.4759], loss_mse_y: [0.1049], loss_pv_z: [10.3026], loss_mse_v: [0.3483], loss_postrior_z: [17.2433]]
Epoch 37/100: 100%|██████████| 32/32 [00:00<00:00, 52.65batch/s, loss_px_z: [0.9823], loss_mse_x: [0.3142], loss_py_z: [0.3996], loss_mse_y: [0.0382], loss_pv_z: [11.1830], loss_mse_v: [0.3703], loss_postrior_z: [19.4070]]
Epoch 38/100: 100%|██████████| 32/32 [00:00<00:00, 52.67batch/s, loss_px_z: [0.9945], loss_mse_x: [0.3270], loss_py_z: [0.4861], loss_mse_y: [0.1335], loss_pv_z: [-9.3832], loss_mse_v: [0.2701], loss_postrior_z: [-4.6423]]
Epoch 39/100: 100%|██████████| 32/32 [00:00<00:00, 52.46batch/s, loss_px_z: [0.7293], loss_mse_x: [0.0624], loss_py_z: [0.4619], loss_mse_y: [0.1198], loss_pv_z: [22.1069], loss_mse_v: [0.3943], loss_postrior_z: [26.1371]]
Epoch 40/100: 100%|██████████| 32/32 [00:00<00:00, 52.65batch/s, loss_px_z: [0.9726], loss_mse_x: [0.3064], loss_py_z: [0.4586], loss_mse_y: [0.1268], loss_pv_z: [19.7465], loss_mse_v: [0.4129], loss_postrior_z: [31.5037]]
Epoch [40/100]: MSE_x: 0.0933, MSE_y: 0.0893, MSE_v: 0.4900

Epoch 41/100: 100%|██████████| 32/32 [00:00<00:00, 52.70batch/s, loss_px_z: [0.7790], loss_mse_x: [0.1133], loss_py_z: [0.3834], loss_mse_y: [0.0446], loss_pv_z: [1.9458], loss_mse_v: [0.3133], loss_postrior_z: [8.3677]]
Epoch 42/100: 100%|██████████| 32/32 [00:00<00:00, 52.73batch/s, loss_px_z: [0.7895], loss_mse_x: [0.1244], loss_py_z: [0.3713], loss_mse_y: [0.0586], loss_pv_z: [31.2267], loss_mse_v: [0.4730], loss_postrior_z: [35.4569]]
Epoch 43/100: 100%|██████████| 32/32 [00:00<00:00, 52.60batch/s, loss_px_z: [0.7912], loss_mse_x: [0.1266], loss_py_z: [0.3231], loss_mse_y: [0.0287], loss_pv_z: [41.9379], loss_mse_v: [0.5249], loss_postrior_z: [43.0948]]
Epoch 44/100: 100%|██████████| 32/32 [00:00<00:00, 52.66batch/s, loss_px_z: [0.7445], loss_mse_x: [0.0806], loss_py_z: [0.3750], loss_mse_y: [0.0760], loss_pv_z: [21.5167], loss_mse_v: [0.4183], loss_postrior_z: [30.6867]]
Epoch 45/100: 100%|██████████| 32/32 [00:00<00:00, 52.53batch/s, loss_px_z: [0.7780], loss_mse_x: [0.1146], loss_py_z: [0.3586], loss_mse_y: [0.0796], loss_pv_z: [19.9246], loss_mse_v: [0.4149], loss_postrior_z: [21.5443]]
Epoch 46/100: 100%|██████████| 32/32 [00:00<00:00, 52.72batch/s, loss_px_z: [1.0830], loss_mse_x: [0.4202], loss_py_z: [0.2969], loss_mse_y: [0.0634], loss_pv_z: [-9.4088], loss_mse_v: [0.2761], loss_postrior_z: [-3.7582]]
Epoch 47/100: 100%|██████████| 32/32 [00:00<00:00, 52.78batch/s, loss_px_z: [1.0419], loss_mse_x: [0.3797], loss_py_z: [0.3354], loss_mse_y: [0.0947], loss_pv_z: [37.1565], loss_mse_v: [2.0222], loss_postrior_z: [40.7764]]
Epoch 48/100: 100%|██████████| 32/32 [00:00<00:00, 54.16batch/s, loss_px_z: [0.7746], loss_mse_x: [0.1130], loss_py_z: [0.2930], loss_mse_y: [0.0726], loss_pv_z: [7.5576], loss_mse_v: [0.3752], loss_postrior_z: [13.7704]]
Epoch 49/100: 100%|██████████| 32/32 [00:00<00:00, 52.62batch/s, loss_px_z: [0.8423], loss_mse_x: [0.1813], loss_py_z: [0.3029], loss_mse_y: [0.0860], loss_pv_z: [50.3740], loss_mse_v: [0.5826], loss_postrior_z: [42.9421]]
Epoch 50/100: 100%|██████████| 32/32 [00:00<00:00, 53.91batch/s, loss_px_z: [0.9005], loss_mse_x: [0.2401], loss_py_z: [0.2378], loss_mse_y: [0.0572], loss_pv_z: [23.7203], loss_mse_v: [0.4206], loss_postrior_z: [21.5316]]
Epoch [50/100]: MSE_x: 0.0924, MSE_y: 0.0922, MSE_v: 0.4850

Epoch 51/100: 100%|██████████| 32/32 [00:00<00:00, 56.99batch/s, loss_px_z: [0.8310], loss_mse_x: [0.1712], loss_py_z: [0.2651], loss_mse_y: [0.0840], loss_pv_z: [21.3173], loss_mse_v: [1.1722], loss_postrior_z: [32.6034]]
Epoch 52/100: 100%|██████████| 32/32 [00:00<00:00, 56.32batch/s, loss_px_z: [0.9389], loss_mse_x: [0.2797], loss_py_z: [0.2724], loss_mse_y: [0.0895], loss_pv_z: [41.2433], loss_mse_v: [1.2558], loss_postrior_z: [37.6628]]
Epoch 53/100: 100%|██████████| 32/32 [00:00<00:00, 55.43batch/s, loss_px_z: [1.0191], loss_mse_x: [0.3605], loss_py_z: [0.3110], loss_mse_y: [0.1181], loss_pv_z: [-5.5875], loss_mse_v: [0.2971], loss_postrior_z: [-2.6677]]
Epoch 54/100: 100%|██████████| 32/32 [00:00<00:00, 55.05batch/s, loss_px_z: [0.7439], loss_mse_x: [0.0858], loss_py_z: [0.1105], loss_mse_y: [0.0163], loss_pv_z: [20.1353], loss_mse_v: [0.5451], loss_postrior_z: [20.6309]]
Epoch 55/100: 100%|██████████| 32/32 [00:00<00:00, 53.03batch/s, loss_px_z: [0.9376], loss_mse_x: [0.2801], loss_py_z: [0.1811], loss_mse_y: [0.0681], loss_pv_z: [-17.3250], loss_mse_v: [0.2448], loss_postrior_z: [-14.3498]]
Epoch 56/100: 100%|██████████| 32/32 [00:00<00:00, 57.25batch/s, loss_px_z: [1.3045], loss_mse_x: [0.6476], loss_py_z: [0.3238], loss_mse_y: [0.1491], loss_pv_z: [46.6397], loss_mse_v: [0.5431], loss_postrior_z: [48.7620]]
Epoch 57/100: 100%|██████████| 32/32 [00:00<00:00, 57.07batch/s, loss_px_z: [0.8093], loss_mse_x: [0.1530], loss_py_z: [0.2234], loss_mse_y: [0.1067], loss_pv_z: [41.7925], loss_mse_v: [0.5469], loss_postrior_z: [48.3183]]
Epoch 58/100: 100%|██████████| 32/32 [00:00<00:00, 56.74batch/s, loss_px_z: [1.4082], loss_mse_x: [0.7524], loss_py_z: [0.0447], loss_mse_y: [0.0551], loss_pv_z: [-2.0897], loss_mse_v: [0.3089], loss_postrior_z: [1.1457]]
Epoch 59/100: 100%|██████████| 32/32 [00:00<00:00, 55.34batch/s, loss_px_z: [0.8823], loss_mse_x: [0.2271], loss_py_z: [0.2955], loss_mse_y: [0.1463], loss_pv_z: [-17.1023], loss_mse_v: [0.2474], loss_postrior_z: [-8.4271]]
Epoch 60/100: 100%|██████████| 32/32 [00:00<00:00, 57.76batch/s, loss_px_z: [1.3100], loss_mse_x: [0.6554], loss_py_z: [0.0349], loss_mse_y: [0.0532], loss_pv_z: [-9.9317], loss_mse_v: [0.2931], loss_postrior_z: [-3.7395]]
Epoch [60/100]: MSE_x: 0.0928, MSE_y: 0.0886, MSE_v: 0.4847

Epoch 61/100: 100%|██████████| 32/32 [00:00<00:00, 58.63batch/s, loss_px_z: [0.8233], loss_mse_x: [0.1693], loss_py_z: [-0.0091], loss_mse_y: [0.0491], loss_pv_z: [38.3667], loss_mse_v: [0.4997], loss_postrior_z: [45.0812]]
Epoch 62/100: 100%|██████████| 32/32 [00:00<00:00, 56.32batch/s, loss_px_z: [0.7631], loss_mse_x: [0.1096], loss_py_z: [0.2690], loss_mse_y: [0.1389], loss_pv_z: [19.4034], loss_mse_v: [0.4365], loss_postrior_z: [26.6290]]
Epoch 63/100: 100%|██████████| 32/32 [00:00<00:00, 57.21batch/s, loss_px_z: [1.1080], loss_mse_x: [0.4551], loss_py_z: [0.2811], loss_mse_y: [0.1415], loss_pv_z: [80.7834], loss_mse_v: [0.6600], loss_postrior_z: [72.9104]]
Epoch 64/100: 100%|██████████| 32/32 [00:00<00:00, 56.45batch/s, loss_px_z: [0.7215], loss_mse_x: [0.0692], loss_py_z: [0.2883], loss_mse_y: [0.1612], loss_pv_z: [53.9828], loss_mse_v: [0.6072], loss_postrior_z: [59.1240]]
Epoch 65/100: 100%|██████████| 32/32 [00:00<00:00, 53.46batch/s, loss_px_z: [1.5178], loss_mse_x: [0.8660], loss_py_z: [0.2269], loss_mse_y: [0.1412], loss_pv_z: [7.9063], loss_mse_v: [0.4066], loss_postrior_z: [8.8537]]
Epoch 66/100: 100%|██████████| 32/32 [00:00<00:00, 57.96batch/s, loss_px_z: [1.0116], loss_mse_x: [0.3604], loss_py_z: [0.5221], loss_mse_y: [0.2095], loss_pv_z: [12.0556], loss_mse_v: [0.4083], loss_postrior_z: [19.8110]]
Epoch 67/100: 100%|██████████| 32/32 [00:00<00:00, 52.53batch/s, loss_px_z: [0.9474], loss_mse_x: [0.2968], loss_py_z: [0.0703], loss_mse_y: [0.0881], loss_pv_z: [-6.6935], loss_mse_v: [0.2905], loss_postrior_z: [-1.6798]]
Epoch 68/100: 100%|██████████| 32/32 [00:00<00:00, 54.06batch/s, loss_px_z: [0.7904], loss_mse_x: [0.1404], loss_py_z: [-0.2489], loss_mse_y: [0.0387], loss_pv_z: [-25.1200], loss_mse_v: [0.2314], loss_postrior_z: [-16.3646]]
Epoch 69/100: 100%|██████████| 32/32 [00:00<00:00, 52.61batch/s, loss_px_z: [0.7263], loss_mse_x: [0.0768], loss_py_z: [0.1700], loss_mse_y: [0.1200], loss_pv_z: [-3.5599], loss_mse_v: [0.3203], loss_postrior_z: [9.7701]]
Epoch 70/100: 100%|██████████| 32/32 [00:00<00:00, 56.04batch/s, loss_px_z: [0.8032], loss_mse_x: [0.1543], loss_py_z: [-0.1722], loss_mse_y: [0.0521], loss_pv_z: [21.8734], loss_mse_v: [0.4367], loss_postrior_z: [29.7064]]
Epoch [70/100]: MSE_x: 0.0927, MSE_y: 0.0864, MSE_v: 0.4795

Epoch 71/100: 100%|██████████| 32/32 [00:00<00:00, 56.08batch/s, loss_px_z: [0.8890], loss_mse_x: [0.2406], loss_py_z: [-0.1999], loss_mse_y: [0.0420], loss_pv_z: [1.5693], loss_mse_v: [0.4193], loss_postrior_z: [6.5249]]
Epoch 72/100: 100%|██████████| 32/32 [00:00<00:00, 55.01batch/s, loss_px_z: [0.9643], loss_mse_x: [0.3165], loss_py_z: [0.2295], loss_mse_y: [0.1324], loss_pv_z: [25.2323], loss_mse_v: [0.4755], loss_postrior_z: [35.1813]]
Epoch 73/100: 100%|██████████| 32/32 [00:00<00:00, 56.47batch/s, loss_px_z: [0.9932], loss_mse_x: [0.3460], loss_py_z: [0.0042], loss_mse_y: [0.0815], loss_pv_z: [51.5540], loss_mse_v: [0.9088], loss_postrior_z: [57.8552]]
Epoch 74/100: 100%|██████████| 32/32 [00:00<00:00, 54.08batch/s, loss_px_z: [0.8933], loss_mse_x: [0.2466], loss_py_z: [-0.1914], loss_mse_y: [0.0467], loss_pv_z: [6.5270], loss_mse_v: [0.5062], loss_postrior_z: [8.9503]]
Epoch 75/100: 100%|██████████| 32/32 [00:00<00:00, 55.97batch/s, loss_px_z: [1.0264], loss_mse_x: [0.3803], loss_py_z: [-0.2124], loss_mse_y: [0.0410], loss_pv_z: [51.8404], loss_mse_v: [0.7711], loss_postrior_z: [53.4973]]
Epoch 76/100: 100%|██████████| 32/32 [00:00<00:00, 54.47batch/s, loss_px_z: [0.8433], loss_mse_x: [0.1978], loss_py_z: [-0.1416], loss_mse_y: [0.0737], loss_pv_z: [23.4251], loss_mse_v: [0.4216], loss_postrior_z: [25.4289]]
Epoch 77/100: 100%|██████████| 32/32 [00:00<00:00, 56.41batch/s, loss_px_z: [1.0254], loss_mse_x: [0.3805], loss_py_z: [-0.0419], loss_mse_y: [0.0844], loss_pv_z: [89.5222], loss_mse_v: [1.0466], loss_postrior_z: [94.9647]]
Epoch 78/100: 100%|██████████| 32/32 [00:00<00:00, 56.31batch/s, loss_px_z: [0.9360], loss_mse_x: [0.2916], loss_py_z: [0.3582], loss_mse_y: [0.1516], loss_pv_z: [-0.7475], loss_mse_v: [0.3250], loss_postrior_z: [9.3354]]
Epoch 79/100: 100%|██████████| 32/32 [00:00<00:00, 56.92batch/s, loss_px_z: [0.8097], loss_mse_x: [0.1659], loss_py_z: [0.0190], loss_mse_y: [0.1032], loss_pv_z: [23.7452], loss_mse_v: [0.5672], loss_postrior_z: [30.1029]]
Epoch 80/100: 100%|██████████| 32/32 [00:00<00:00, 56.77batch/s, loss_px_z: [1.5119], loss_mse_x: [0.8687], loss_py_z: [-0.2420], loss_mse_y: [0.0538], loss_pv_z: [52.9697], loss_mse_v: [0.7831], loss_postrior_z: [67.9704]]
Epoch [80/100]: MSE_x: 0.0921, MSE_y: 0.0892, MSE_v: 0.4806

Epoch 81/100: 100%|██████████| 32/32 [00:00<00:00, 58.74batch/s, loss_px_z: [1.0592], loss_mse_x: [0.4165], loss_py_z: [0.3977], loss_mse_y: [0.1552], loss_pv_z: [76.5605], loss_mse_v: [0.7894], loss_postrior_z: [88.1919]]
Epoch 82/100: 100%|██████████| 32/32 [00:00<00:00, 55.39batch/s, loss_px_z: [0.8014], loss_mse_x: [0.1592], loss_py_z: [-0.3580], loss_mse_y: [0.0247], loss_pv_z: [-10.8457], loss_mse_v: [0.3003], loss_postrior_z: [-1.7214]]
Epoch 83/100: 100%|██████████| 32/32 [00:00<00:00, 55.04batch/s, loss_px_z: [0.7829], loss_mse_x: [0.1414], loss_py_z: [0.1903], loss_mse_y: [0.1368], loss_pv_z: [5.1346], loss_mse_v: [0.3579], loss_postrior_z: [13.7117]]
Epoch 84/100: 100%|██████████| 32/32 [00:00<00:00, 53.91batch/s, loss_px_z: [0.7963], loss_mse_x: [0.1553], loss_py_z: [-0.1579], loss_mse_y: [0.0651], loss_pv_z: [-43.4583], loss_mse_v: [0.1511], loss_postrior_z: [-41.7073]]
Epoch 85/100: 100%|██████████| 32/32 [00:00<00:00, 55.40batch/s, loss_px_z: [1.7594], loss_mse_x: [1.1190], loss_py_z: [-0.3379], loss_mse_y: [0.0233], loss_pv_z: [7.3098], loss_mse_v: [0.3463], loss_postrior_z: [7.8066]]
Epoch 86/100: 100%|██████████| 32/32 [00:00<00:00, 55.51batch/s, loss_px_z: [0.9384], loss_mse_x: [0.2985], loss_py_z: [-0.0485], loss_mse_y: [0.0784], loss_pv_z: [10.2092], loss_mse_v: [0.3946], loss_postrior_z: [15.5025]]
Epoch 87/100: 100%|██████████| 32/32 [00:00<00:00, 56.19batch/s, loss_px_z: [0.9014], loss_mse_x: [0.2620], loss_py_z: [-0.0807], loss_mse_y: [0.0751], loss_pv_z: [4.4597], loss_mse_v: [0.3666], loss_postrior_z: [16.1913]]
Epoch 88/100: 100%|██████████| 32/32 [00:00<00:00, 54.59batch/s, loss_px_z: [0.8896], loss_mse_x: [0.2508], loss_py_z: [-0.4017], loss_mse_y: [0.0261], loss_pv_z: [-1.6607], loss_mse_v: [0.3429], loss_postrior_z: [9.9890]]
Epoch 89/100: 100%|██████████| 32/32 [00:00<00:00, 55.01batch/s, loss_px_z: [0.9987], loss_mse_x: [0.3604], loss_py_z: [-0.0197], loss_mse_y: [0.0888], loss_pv_z: [2.3344], loss_mse_v: [0.3574], loss_postrior_z: [8.8961]]
Epoch 90/100: 100%|██████████| 32/32 [00:00<00:00, 55.18batch/s, loss_px_z: [0.9082], loss_mse_x: [0.2706], loss_py_z: [0.3003], loss_mse_y: [0.1572], loss_pv_z: [41.4758], loss_mse_v: [0.6483], loss_postrior_z: [42.9172]]
Epoch [90/100]: MSE_x: 0.0929, MSE_y: 0.0851, MSE_v: 0.4804

Epoch 91/100: 100%|██████████| 32/32 [00:00<00:00, 55.70batch/s, loss_px_z: [1.2535], loss_mse_x: [0.6164], loss_py_z: [0.3955], loss_mse_y: [0.1559], loss_pv_z: [27.8524], loss_mse_v: [0.4503], loss_postrior_z: [26.3598]]
Epoch 92/100: 100%|██████████| 32/32 [00:00<00:00, 54.39batch/s, loss_px_z: [0.9255], loss_mse_x: [0.2889], loss_py_z: [-0.2278], loss_mse_y: [0.0506], loss_pv_z: [-30.0789], loss_mse_v: [0.2110], loss_postrior_z: [-27.2653]]
Epoch 93/100: 100%|██████████| 32/32 [00:00<00:00, 55.26batch/s, loss_px_z: [0.8545], loss_mse_x: [0.2185], loss_py_z: [-0.1203], loss_mse_y: [0.0712], loss_pv_z: [-3.8885], loss_mse_v: [0.3201], loss_postrior_z: [-0.4682]]
Epoch 94/100: 100%|██████████| 32/32 [00:00<00:00, 55.40batch/s, loss_px_z: [0.8185], loss_mse_x: [0.1830], loss_py_z: [-0.2654], loss_mse_y: [0.0451], loss_pv_z: [-1.6399], loss_mse_v: [0.3369], loss_postrior_z: [1.0570]]
Epoch 95/100: 100%|██████████| 32/32 [00:00<00:00, 58.24batch/s, loss_px_z: [0.6933], loss_mse_x: [0.0585], loss_py_z: [-0.0570], loss_mse_y: [0.0867], loss_pv_z: [0.5686], loss_mse_v: [0.3402], loss_postrior_z: [5.5315]]
Epoch 96/100: 100%|██████████| 32/32 [00:00<00:00, 60.05batch/s, loss_px_z: [0.8206], loss_mse_x: [0.1863], loss_py_z: [0.3453], loss_mse_y: [0.1475], loss_pv_z: [30.9373], loss_mse_v: [0.5577], loss_postrior_z: [33.8833]]
Epoch 97/100: 100%|██████████| 32/32 [00:00<00:00, 60.05batch/s, loss_px_z: [0.7989], loss_mse_x: [0.1652], loss_py_z: [-0.2076], loss_mse_y: [0.0503], loss_pv_z: [5.4328], loss_mse_v: [0.4984], loss_postrior_z: [11.1171]]
Epoch 98/100: 100%|██████████| 32/32 [00:00<00:00, 59.85batch/s, loss_px_z: [0.6983], loss_mse_x: [0.0651], loss_py_z: [-0.0814], loss_mse_y: [0.0811], loss_pv_z: [24.3846], loss_mse_v: [0.5054], loss_postrior_z: [31.1920]]
Epoch 99/100: 100%|██████████| 32/32 [00:00<00:00, 60.03batch/s, loss_px_z: [1.1036], loss_mse_x: [0.4710], loss_py_z: [-0.2236], loss_mse_y: [0.0453], loss_pv_z: [2.1287], loss_mse_v: [0.3433], loss_postrior_z: [5.4971]]
Epoch 100/100: 100%|██████████| 32/32 [00:00<00:00, 59.90batch/s, loss_px_z: [0.8939], loss_mse_x: [0.2618], loss_py_z: [-0.1434], loss_mse_y: [0.0728], loss_pv_z: [-24.3166], loss_mse_v: [0.2560], loss_postrior_z: [-25.0117]]
Epoch [100/100]: MSE_x: 0.0916, MSE_y: 0.0877, MSE_v: 0.4752


Make predictions using the trained CausalBGM model

The predict method estimates causal effects, together with posterior intervals, from MCMC samples of the latent variables. Its arguments and return values are listed below.

Parameter

Description

data

Tuple of data inputs (x, y, v). Required.

alpha

Significance level for the posterior interval, i.e., a (1 - alpha) posterior interval is returned. Default: 0.01.

n_mcmc

Number of posterior MCMC samples to draw. Default: 3000.

burn_in

Number of burn-in MCMC iterations discarded before posterior samples are retained. Default: 5000.

x_values

Treatment value(s) at which the dose-response function is predicted (continuous treatment setting). Examples: 1.0 or [1.0, 2.0].

q_sd

Standard deviation for the proposal distribution used in Metropolis-Hastings (MH) sampling. Default: 1.0.

sample_y

Whether to consider the variance function in the outcome generative model. Default: True.

bs

Batch size at the inference stage, i.e., the number of subjects processed per prediction batch. Default: 10000.

Return

Type

Description

Shape

pre_ite_mean

np.ndarray

Point estimates of the Individual Treatment Effect (ITE).

(len(x),)

pre_ite_PI

np.ndarray

Posterior intervals for the ITEs, given as [lower bound, upper bound].

(len(x), 2)
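To illustrate how a point estimate and a (1 - alpha) posterior interval can be derived from posterior draws, the sketch below summarizes a matrix of ITE samples by its mean and equal-tailed quantiles, producing arrays with the same shapes as pre_ite_mean and pre_ite_PI. The helper name summarize_posterior and the equal-tailed construction are assumptions for illustration only; CausalBGM's internal implementation may differ.

```python
import numpy as np

def summarize_posterior(ite_samples, alpha=0.01):
    """Summarize posterior ITE draws (hypothetical helper, not part of bayesgm).

    ite_samples: array of shape (n_mcmc, n_subjects), one row per retained draw.
    Returns the posterior mean, shape (n_subjects,), and an equal-tailed
    (1 - alpha) interval, shape (n_subjects, 2), as [lower, upper].
    """
    ite_samples = np.asarray(ite_samples, dtype=float)
    ite_mean = ite_samples.mean(axis=0)
    lower = np.quantile(ite_samples, alpha / 2, axis=0)
    upper = np.quantile(ite_samples, 1 - alpha / 2, axis=0)
    return ite_mean, np.stack([lower, upper], axis=1)

# Toy draws: 3000 posterior samples for 5 subjects whose effects are 0, 1, 2, 3, 4
rng = np.random.default_rng(0)
draws = rng.normal(loc=np.arange(5), scale=1.0, size=(3000, 5))
ite_mean, ite_pi = summarize_posterior(draws, alpha=0.01)
print(ite_mean.shape, ite_pi.shape)  # (5,) (5, 2)
```

With alpha=0.01, each interval spans the 0.5% to 99.5% quantiles of the draws for that subject.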

[12]:
pre_ite_mean, pre_ite_PI = model.predict(data=(x,y,v), alpha=0.01, n_mcmc=3000, burn_in=5000, q_sd=1.0, bs=1000)
MCMC Latent Variable Sampling ...
Final MCMC Acceptance Rate: 0.2116
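The q_sd, burn_in and n_mcmc arguments control the Metropolis-Hastings (MH) sampling behind the acceptance rate reported above. The sketch below is a generic random-walk MH sampler on a toy one-dimensional target (a standard normal log-density), written only to show what these arguments do; it is not the CausalBGM sampler, whose target is the posterior of the latent variables Z given the observed data, and the function name random_walk_mh is hypothetical.

```python
import numpy as np

def random_walk_mh(log_target, init, n_mcmc=3000, burn_in=5000, q_sd=1.0, seed=0):
    """Generic random-walk Metropolis-Hastings (illustrative, not bayesgm's code)."""
    rng = np.random.default_rng(seed)
    current = np.asarray(init, dtype=float)
    current_logp = log_target(current)
    samples, n_accept = [], 0
    for it in range(burn_in + n_mcmc):
        # Gaussian random-walk proposal with standard deviation q_sd
        proposal = current + q_sd * rng.standard_normal(current.shape)
        proposal_logp = log_target(proposal)
        # The proposal is symmetric, so the MH ratio reduces to the target ratio
        if np.log(rng.uniform()) < proposal_logp - current_logp:
            current, current_logp = proposal, proposal_logp
            n_accept += 1
        if it >= burn_in:  # keep draws only after the burn-in phase
            samples.append(current.copy())
    # Acceptance rate over all iterations, including burn-in
    acceptance_rate = n_accept / (burn_in + n_mcmc)
    return np.array(samples), acceptance_rate

def log_std_normal(z):
    # Standard normal log-density up to an additive constant
    return -0.5 * np.sum(z ** 2)

samples, acc = random_walk_mh(log_std_normal, init=np.zeros(1), q_sd=1.0)
print(samples.shape, round(acc, 2))
```

A larger q_sd proposes bigger jumps and usually lowers the acceptance rate, while a smaller one raises it but explores the posterior more slowly; for this toy target q_sd=1.0 accepts roughly 70% of proposals. The CLI help later in this tutorial also notes that a negative q_sd denotes adaptive MCMC.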

Evaluating the results

Compute the absolute error of the average treatment effect (\(\epsilon_{ATE}\)) and the precision in estimation of heterogeneous effect (\(\epsilon_{PEHE}\)) against the ground-truth ITEs provided by the ACIC 2018 counterfactual files.
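For reference, with \(n\) subjects, true individual effects \(\tau_i = y_i(1) - y_i(0)\) and estimates \(\hat{\tau}_i\) (pre_ite_mean), the two metrics computed in the next cell are as follows. Note that some papers report the square root of this PEHE.

```latex
\epsilon_{ATE} = \left| \frac{1}{n}\sum_{i=1}^{n} \hat{\tau}_i - \frac{1}{n}\sum_{i=1}^{n} \tau_i \right|,
\qquad
\epsilon_{PEHE} = \frac{1}{n}\sum_{i=1}^{n} \left( \hat{\tau}_i - \tau_i \right)^2
```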

[13]:
# Load the ground-truth ITE for this ACIC 2018 simulation (ufid identifies the simulation)
ufid = '629e3d2c63914e45b227cc913c09cebe'
covariates_file = 'data/ACIC_2018/x.csv'
df = pd.read_csv(covariates_file, index_col='sample_id', header=0, sep=',')
df_sim = pd.read_csv('data/ACIC_2018/scaling/factuals/%s.csv' % ufid, index_col='sample_id', header=0, sep=',')
dataset = df.join(df_sim, how='inner')
data_x = dataset['z'].values  # observed binary treatment
data_y = dataset['y'].values  # observed outcome
cf_id = ufid + '_cf'
y_true = pd.read_csv('data/ACIC_2018/scaling/counterfactuals/%s.csv' % cf_id, index_col='sample_id', header=0, sep=',')
# Align the counterfactual table with the sample order of `dataset`
y_true = y_true.loc[dataset.index]
y_0 = y_true.values[:, 0]  # potential outcome without treatment
y_1 = y_true.values[:, 1]  # potential outcome under treatment
ite_true = y_1 - y_0

# Evaluate: absolute ATE error and PEHE (mean squared ITE error)
delta_ate = abs(np.mean(pre_ite_mean) - np.mean(ite_true))
delta_pehe = np.mean((pre_ite_mean - ite_true)**2)

print(f"Delta ATE (Absolute Error in Average Treatment Effect): {delta_ate:.4f}")
print(f"Delta PEHE (Precision in Estimation of Heterogeneous Effect): {delta_pehe:.4f}")
Delta ATE (Absolute Error in Average Treatment Effect): 0.0069
Delta PEHE (Precision in Estimation of Heterogeneous Effect): 0.0001

Use CausalBGM by a command-line interface (CLI)

Installing bayesgm with pip install bayesgm also provides an independent console program, bayesgm. Because it runs from the shell, CausalBGM can be used from non-Python scripts and pipelines as well.

Reminder

  • Because the current CLI reads its data from a single input file (see the -i/--input option in the help output below for the supported formats), training and inference are performed on the same dataset.

  • Python or R APIs are recommended for more flexible usage.

[1]:
!bayesgm causalbgm -h
usage: bayesgm causalbgm [-h] -o OUTPUT_DIR -i INPUT [-t DELIMITER]
                         [-d DATASET] [-F SAVE_FORMAT] [-save_model]
                         [-save_res] [--use_bnn | --no-use_bnn]
                         [--use_egm_init | --no-use_egm_init] [--seed SEED]
                         [-B | --binary_treatment | --no-binary_treatment]
                         [-Z Z_DIMS [Z_DIMS ...]] [--lr_theta LR_THETA]
                         [--lr_z LR_Z] [--x_min X_MIN] [--x_max X_MAX]
                         [--x_values X_VALUES [X_VALUES ...]]
                         [--g_units G_UNITS [G_UNITS ...]]
                         [--f_units F_UNITS [F_UNITS ...]]
                         [--h_units H_UNITS [H_UNITS ...]]
                         [--kl_weight KL_WEIGHT] [--lr LR]
                         [--g_d_freq G_D_FREQ]
                         [--e_units E_UNITS [E_UNITS ...]]
                         [--dz_units DZ_UNITS [DZ_UNITS ...]]
                         [--use-z-rec | --no-use-z-rec] [-N N_ITER]
                         [--startoff STARTOFF]
                         [--batches_per_eval BATCHES_PER_EVAL] [-E EPOCHS]
                         [-M N_MCMC] [-q Q_SD]
                         [--epochs_per_eval EPOCHS_PER_EVAL] [--alpha ALPHA]

CausalBGM: An AI-powered Bayesian generative modeling approach for causal
inference in observational studies

optional arguments:
  -h, --help            show this help message and exit
  -o OUTPUT_DIR, --output_dir OUTPUT_DIR
                        Output directory
  -i INPUT, --input INPUT
                        Input data file must be in csv or txt or npz format
  -t DELIMITER, --delimiter DELIMITER
                        Delimiter for txt or csv files (default: tab '\t').
  -d DATASET, --dataset DATASET
                        Dataset name
  -F SAVE_FORMAT, --save_format SAVE_FORMAT
                        Saving format (default: txt)
  -save_model           Whether to save model. (default: False)
  -save_res             Whether to save intermediate results. (default: True)
  --use_bnn, --no-use_bnn
                        Whether use Bayesian neural nets. (default: True)
  --use_egm_init, --no-use_egm_init
                        Whether use EGM initialization. (default: True)
  --seed SEED           Random seed for reproduction (default: 123).
  -B, --binary_treatment, --no-binary_treatment
                        Whether use binary treatment setting. (default: True)
  -Z Z_DIMS [Z_DIMS ...], --z_dims Z_DIMS [Z_DIMS ...]
                        Latent dimensions of Z (default: [3, 3, 6, 6]).
  --lr_theta LR_THETA   Learning rate for updating model parameters (default:
                        0.0001).
  --lr_z LR_Z           Learning rate for updating latent variables (default:
                        0.0001).
  --x_min X_MIN         Lower bound for treatment interval (default: 0.0).
  --x_max X_MAX         Upper bound for treatment interval (default: 3.0).
  --x_values X_VALUES [X_VALUES ...]
                        List of treatment values to be predicted. Provide
                        space-separated values. Example: --x_values 0.5 1.0
                        1.5
  --g_units G_UNITS [G_UNITS ...]
                        Number of units for covariates generative model
                        (default: [64,64,64,64,64]).
  --f_units F_UNITS [F_UNITS ...]
                        Number of units for outcome generative model (default:
                        [64,32,8]).
  --h_units H_UNITS [H_UNITS ...]
                        Number of units for treatment generative model
                        (default: [64,32,8]).
  --kl_weight KL_WEIGHT
                        Coefficient for KL divergence term in BNNs (default:
                        0.0001).
  --lr LR               Learning rate for EGM initialization (default:
                        0.0001).
  --g_d_freq G_D_FREQ   Frequency for updating discriminators and generators
                        (default: 5).
  --e_units E_UNITS [E_UNITS ...]
                        Number of units for encoder network (default:
                        [64,64,64,64,64]).
  --dz_units DZ_UNITS [DZ_UNITS ...]
                        Number of units for discriminator network in latent
                        space (default: [64,32,8]).
  --use-z-rec, --no-use-z-rec
                        Use the reconstruction for latent features (default:
                        True). (default: True)
  -N N_ITER, --n_iter N_ITER
                        Number of iterations (default: 30000).
  --startoff STARTOFF   Iteration for starting evaluation (default: 0).
  --batches_per_eval BATCHES_PER_EVAL
                        Number of iterations per evaluation (default: 500).
  -E EPOCHS, --epochs EPOCHS
                        Number of epochs in iterative updating algorithm
                        (default: 100).
  -M N_MCMC, --n_mcmc N_MCMC
                        MCMC sample size (default: 3000).
  -q Q_SD, --q_sd Q_SD  Standard deviation for proposal distribution in MCMC,
                        a negative q_sd denotes adaptive MCMC (default: 1.0).
  --epochs_per_eval EPOCHS_PER_EVAL
                        Number of epochs per evaluation (default: 10).
  --alpha ALPHA         Significance level (default: 0.01).
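Several of the flags above correspond directly to keys of the Python config dict described at the start of this tutorial. The hypothetical helper config_to_cli_args below sketches that correspondence for a subset of keys (dataset, output_dir, lr_theta, lr_z, z_dims, g_units, f_units, binary_treatment, use_bnn), using only flags that appear in the help output; it is not part of bayesgm, and options without a config key (for example -N, -E, -M or --x_values) are left out.

```python
import shlex

# Config key -> CLI flag, taken from the `bayesgm causalbgm -h` output.
# Hypothetical helper for illustration; not part of bayesgm.
VALUE_FLAGS = {"dataset": "-d", "output_dir": "-o",
               "lr_theta": "--lr_theta", "lr_z": "--lr_z"}
LIST_FLAGS = {"z_dims": "-Z", "g_units": "--g_units", "f_units": "--f_units"}
BOOL_FLAGS = {"binary_treatment": "binary_treatment", "use_bnn": "use_bnn"}

def config_to_cli_args(input_file, params):
    """Build a `bayesgm causalbgm` argument list from a config dict."""
    args = ["bayesgm", "causalbgm", "-i", input_file]
    for key, value in params.items():
        if key in VALUE_FLAGS:
            args += [VALUE_FLAGS[key], str(value)]
        elif key in LIST_FLAGS:
            # Space-separated values, e.g. -Z 1 1 1 7
            args += [LIST_FLAGS[key]] + [str(v) for v in value]
        elif key in BOOL_FLAGS:
            # Paired switches such as --use_bnn / --no-use_bnn
            prefix = "--" if value else "--no-"
            args.append(prefix + BOOL_FLAGS[key])
    return args

params = {"dataset": "Demo", "output_dir": "./", "z_dims": [1, 1, 1, 7],
          "binary_treatment": False, "use_bnn": True}
print(shlex.join(config_to_cli_args("demo.csv", params)))
# bayesgm causalbgm -i demo.csv -d Demo -o ./ -Z 1 1 1 7 --no-binary_treatment --use_bnn
```

The resulting list could be passed to subprocess.run from a script, although calling the Python API directly is usually simpler when working in Python.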

The CLI arguments mirror the configuration parameters of the Python API described in the previous sections. Here, we run CausalBGM on a demo dataset in the continuous treatment setting.

[2]:
!bayesgm causalbgm  -i demo.csv -o ./ -d Demo -N 1000 -E 10 -M 500 -Z 1 1 1 7 --no-binary_treatment --x_values 0 1 2
EGM Initialization Starts ...
EGM Initialization Iter [0] : e_loss_adv [0.0104], l2_loss_v [1.0768], l2_loss_z [0.9892], l2_loss_x [0.9474], l2_loss_y [4.9964], g_e_loss [8.0202], dz_loss [0.0822], d_loss [1.8063]
EGM Initialization Iter [500] : e_loss_adv [1.9232], l2_loss_v [1.1323], l2_loss_z [1.2015], l2_loss_x [0.4507], l2_loss_y [2.8286], g_e_loss [7.5364], dz_loss [-0.9738], d_loss [-0.5502]
EGM Initialization Iter [1000] : e_loss_adv [2.0145], l2_loss_v [0.9970], l2_loss_z [1.1143], l2_loss_x [1.5727], l2_loss_y [0.9798], g_e_loss [6.6785], dz_loss [-1.6597], d_loss [-1.3993]
EGM Initialization Ends.
Initialize latent variables Z with e(V)...
Iterative Updating Starts ...
Epoch 0/10: 100%|█| 3/3 [00:09<00:00,  3.25s/batch, loss_px_z: [1.7446], loss_ms
Epoch [0/10]: MSE_x: 1.5652, MSE_y: 1.5891, MSE_v: 0.9300

Epoch 1/10: 100%|█| 3/3 [00:00<00:00, 36.85batch/s, loss_px_z: [1.0154], loss_ms
Epoch 2/10: 100%|█| 3/3 [00:00<00:00, 36.99batch/s, loss_px_z: [1.3450], loss_ms
Epoch 3/10: 100%|█| 3/3 [00:00<00:00, 36.98batch/s, loss_px_z: [1.6738], loss_ms
Epoch 4/10: 100%|█| 3/3 [00:00<00:00, 38.60batch/s, loss_px_z: [1.9682], loss_ms
Epoch 5/10: 100%|█| 3/3 [00:00<00:00, 32.90batch/s, loss_px_z: [2.1426], loss_ms
Epoch 6/10: 100%|█| 3/3 [00:00<00:00, 40.06batch/s, loss_px_z: [0.8936], loss_ms
Epoch 7/10: 100%|█| 3/3 [00:00<00:00, 39.97batch/s, loss_px_z: [1.1379], loss_ms
Epoch 8/10: 100%|█| 3/3 [00:00<00:00, 38.26batch/s, loss_px_z: [0.9878], loss_ms
Epoch 9/10: 100%|█| 3/3 [00:00<00:00, 37.52batch/s, loss_px_z: [0.6265], loss_ms
Epoch 10/10: 100%|█| 3/3 [00:00<00:00, 36.10batch/s, loss_px_z: [1.0492], loss_m
Epoch [10/10]: MSE_x: 1.5488, MSE_y: 1.5971, MSE_v: 0.9248

MCMC Latent Variable Sampling ...
Final MCMC Acceptance Rate: 0.1590