import tensorflow as tf
import tensorflow_probability as tfp
from ..networks import BaseFullyConnectedNet, Discriminator, BayesianFullyConnectedNet
import numpy as np
from bayesgm.datasets import Gaussian_sampler
from bayesgm.utils.data_io import save_data
import dateutil.tz
import datetime
import os
from tqdm import tqdm
[docs]
class CausalBGM(object):
    """Causal Bayesian Generative Model (CausalBGM) for causal inference.

    CausalBGM learns a latent-variable generative model for causal inference
    with treatment :math:`X`, outcome :math:`Y`, and high-dimensional covariates
    :math:`V`. The latent variable :math:`Z` is partitioned into
    :math:`(Z_0, Z_1, Z_2, Z_3)` to disentangle confounding, outcome-specific,
    treatment-specific, and residual variation. Accordingly, the covariate
    network ``g_net`` reads all of :math:`Z`, the treatment network ``h_net``
    reads :math:`(Z_0, Z_2)`, and the outcome network ``f_net`` reads
    :math:`(Z_0, Z_1, X)`.

    Parameters
    ----------
    params : dict
        Configuration dictionary. Required keys:

        - ``'v_dim'`` (int): Dimension of covariates :math:`V`.
        - ``'z_dims'`` (list[int]): Dimensions ``[z0, z1, z2, z3]`` of the
          four latent sub-vectors.
        - ``'binary_treatment'`` (bool): ``True`` for binary treatment,
          ``False`` for continuous.
        - ``'dataset'`` (str): Dataset name (used for checkpoint and result
          paths).
        - ``'output_dir'`` (str): Root directory for outputs.
        - ``'use_z_rec'`` (bool or float): Multiplier of the latent
          reconstruction error ``||z - e(g(z))||^2`` in the EGM pre-training
          loss. It is only read during EGM initialization, i.e. when
          :meth:`fit` is called with ``use_egm_init=True``.

        Optional keys (with defaults):

        - ``'use_bnn'`` (bool): Whether to use Bayesian neural networks. Default ``True``.
        - ``'g_units'`` (list[int]): Hidden-layer sizes for the generator network. Default ``[64, 64, 64, 64, 64]``.
        - ``'e_units'`` (list[int]): Hidden-layer sizes for the encoder network. Default ``[64, 64, 64, 64, 64]``.
        - ``'f_units'`` (list[int]): Hidden-layer sizes for the outcome network. Default ``[64, 32, 8]``.
        - ``'h_units'`` (list[int]): Hidden-layer sizes for the treatment network. Default ``[64, 32, 8]``.
        - ``'dz_units'`` (list[int]): Hidden-layer sizes for the latent discriminator. Default ``[64, 32, 8]``.
        - ``'lr'`` (float): Learning rate for EGM pre-training. Default ``0.0002``.
        - ``'lr_theta'`` (float): Learning rate for network parameters. Default ``0.0001``.
        - ``'lr_z'`` (float): Learning rate for latent-variable updates. Default ``0.0001``.
        - ``'g_d_freq'`` (int): Number of discriminator updates per generator update during EGM pre-training. Default ``5``.
        - ``'save_model'`` (bool): Whether to save model checkpoints. Default ``False``.
        - ``'save_res'`` (bool): Whether to save results. Default ``True``.
        - ``'kl_weight'`` (float): KL-divergence weight when ``use_bnn`` is True. Default ``0.0001``.
        - ``'sigma_v'``, ``'sigma_x'``, ``'sigma_y'`` (float): Fixed standard
          deviations of the Gaussian likelihoods of :math:`V`, :math:`X`
          (continuous treatment only) and :math:`Y`; the squared value is
          used as the variance. When a key is absent, the corresponding
          variance is learned as ``softplus`` of the network's last output
          unit.
    timestamp : str or None, optional
        Timestamp string identifying the run in output paths. If ``None``,
        the current local time is used.
    random_seed : int or None, optional
        If provided, sets the global random seed for reproducibility.
    """
def __init__(self, params, timestamp=None, random_seed=None):
    """Build the networks, optimizers, output directories and checkpoint manager.

    Parameters
    ----------
    params : dict
        Model configuration (keys are described in the class docstring).
        Optional keys that are missing are filled with the documented
        defaults. The caller's dictionary is not modified; ``self.params``
        holds the merged dictionary. Missing required keys still raise
        ``KeyError``.
    timestamp : str or None, optional
        Run identifier used in the checkpoint and result paths. When
        ``None``, the current local time formatted as ``%Y%m%d_%H%M%S`` is
        used.
    random_seed : int or None, optional
        When given, calls ``tf.keras.utils.set_random_seed`` and enables
        deterministic TensorFlow ops.

    Notes
    -----
    If ``<output_dir>/checkpoints/<dataset>/<timestamp>`` already holds a
    checkpoint, the latest one is restored into the networks and optimizers.
    """
    super(CausalBGM, self).__init__()
    # Documented defaults for the optional keys. The dict is built per call,
    # so its list values are never shared between instances.
    defaults = {
        'use_bnn': True,
        'g_units': [64, 64, 64, 64, 64],
        'e_units': [64, 64, 64, 64, 64],
        'f_units': [64, 32, 8],
        'h_units': [64, 32, 8],
        'dz_units': [64, 32, 8],
        'lr': 0.0002,
        'lr_theta': 0.0001,
        'lr_z': 0.0001,
        'g_d_freq': 5,
        'save_model': False,
        'save_res': True,
        'kl_weight': 0.0001,
    }
    self.params = {**defaults, **params}
    params = self.params
    self.timestamp = timestamp
    if random_seed is not None:
        tf.keras.utils.set_random_seed(random_seed)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        tf.config.experimental.enable_op_determinism()

    # Both network flavours take the same constructor arguments; only the class differs.
    net_cls = BayesianFullyConnectedNet if params['use_bnn'] else BaseFullyConnectedNet
    z_dims = params['z_dims']
    # g: Z -> (mu_V, raw variance unit); e: V -> Z;
    # f: (Z0, Z1, X) -> (mu_Y, raw variance unit); h: (Z0, Z2) -> (mu_X, raw variance unit).
    self.g_net = net_cls(input_dim=sum(z_dims), output_dim=params['v_dim'] + 1,
                         model_name='g_net', nb_units=params['g_units'])
    self.e_net = net_cls(input_dim=params['v_dim'], output_dim=sum(z_dims),
                         model_name='e_net', nb_units=params['e_units'])
    self.f_net = net_cls(input_dim=z_dims[0] + z_dims[1] + 1, output_dim=2,
                         model_name='f_net', nb_units=params['f_units'])
    self.h_net = net_cls(input_dim=z_dims[0] + z_dims[2], output_dim=2,
                         model_name='h_net', nb_units=params['h_units'])
    self.dz_net = Discriminator(input_dim=sum(z_dims), model_name='dz_net',
                                nb_units=params['dz_units'])

    # EGM pre-training optimizers.
    self.g_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
    self.d_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
    self.z_sampler = Gaussian_sampler(mean=np.zeros(sum(z_dims)), sd=1.0)
    # Iterative-updating optimizers (network weights and latent variables).
    self.g_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
    self.f_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
    self.h_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
    self.posterior_optimizer = tf.keras.optimizers.Adam(params['lr_z'], beta_1=0.9, beta_2=0.99)
    self.initialize_nets()

    if self.timestamp is None:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        self.timestamp = now.strftime('%Y%m%d_%H%M%S')

    self.checkpoint_path = "{}/checkpoints/{}/{}".format(
        params['output_dir'], params['dataset'], self.timestamp)
    if self.params['save_model']:
        os.makedirs(self.checkpoint_path, exist_ok=True)
    self.save_dir = "{}/results/{}/{}".format(
        params['output_dir'], params['dataset'], self.timestamp)
    if self.params['save_res']:
        os.makedirs(self.save_dir, exist_ok=True)

    self.ckpt = tf.train.Checkpoint(g_net = self.g_net,
                                    e_net = self.e_net,
                                    f_net = self.f_net,
                                    h_net = self.h_net,
                                    dz_net = self.dz_net,
                                    g_pre_optimizer = self.g_pre_optimizer,
                                    d_pre_optimizer = self.d_pre_optimizer,
                                    g_optimizer = self.g_optimizer,
                                    f_optimizer = self.f_optimizer,
                                    h_optimizer = self.h_optimizer,
                                    posterior_optimizer = self.posterior_optimizer)
    self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_path, max_to_keep=5)
    if self.ckpt_manager.latest_checkpoint:
        self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
        print ('Latest checkpoint restored!!')
[docs]
def get_config(self):
    """Describe how this model was configured.

    Returns
    -------
    dict
        A one-entry mapping ``{"params": ...}`` whose value is the
        configuration dictionary stored on the instance (the same object,
        not a copy).
    """
    config = dict(params=self.params)
    return config
def initialize_nets(self, print_summary = False):
    """Run one dummy forward pass through ``g_net``, ``f_net`` and ``h_net``.

    Each network is called on a single all-zero row of its input width, so
    that its weights exist before they are optimised or restored (assumes
    Keras-style lazy weight creation). ``e_net`` and ``dz_net`` are not
    touched here.

    Parameters
    ----------
    print_summary : bool, default=False
        If ``True``, print the summary of each of the three networks.
    """
    z_dims = self.params['z_dims']
    nets_and_widths = (
        (self.g_net, sum(z_dims)),                 # input: all of Z
        (self.f_net, z_dims[0] + z_dims[1] + 1),  # input: (Z0, Z1, X)
        (self.h_net, z_dims[0] + z_dims[2]),      # input: (Z0, Z2)
    )
    for net, width in nets_and_widths:
        net(np.zeros((1, width)))
    if print_summary:
        for net, _ in nets_and_widths:
            # summary() prints by itself and returns None; wrapping it in
            # print() emitted a stray "None" line after every summary.
            net.summary()
# Update generative model for covariates V
@tf.function
def update_g_net(self, data_z, data_v, eps=1e-6):
    """Take one optimizer step on ``g_net`` for the covariate model p(V | Z).

    The objective is the batch mean of the Gaussian negative log-likelihood
    with additive constants dropped,
    ``||v - mu_v||^2 / (2 * sigma_v^2) + v_dim * log(sigma_v^2) / 2``.
    When ``use_bnn`` is set, ``kl_weight`` times the network's KL losses is
    added. ``sigma_v^2`` is ``params['sigma_v'] ** 2`` if that key is
    configured; otherwise it is ``softplus`` of the last output unit plus
    ``eps``.

    Returns:
        (loss_v, loss_mse): the optimised objective and the plain mean
        squared error of ``mu_v``.
    """
    v_dim = self.params['v_dim']
    with tf.GradientTape() as tape:
        output = self.g_net(data_z)
        mu_v = output[:, :v_dim]
        if 'sigma_v' not in self.params:
            sigma_square_v = tf.nn.softplus(output[:, -1]) + eps
        else:
            sigma_square_v = self.params['sigma_v'] ** 2
        squared_error = (data_v - mu_v) ** 2
        loss_mse = tf.reduce_mean(squared_error)
        per_sample_nll = (tf.reduce_sum(squared_error, axis=1) / (2 * sigma_square_v)
                          + v_dim * tf.math.log(sigma_square_v) / 2)
        loss_v = tf.reduce_mean(per_sample_nll)
        if self.params['use_bnn']:
            loss_v += sum(self.g_net.losses) * self.params['kl_weight']
    trainable = self.g_net.trainable_variables
    grads = tape.gradient(loss_v, trainable)
    self.g_optimizer.apply_gradients(zip(grads, trainable))
    return loss_v, loss_mse
# Update generative model for treatment X
@tf.function
def update_h_net(self, data_z, data_x, eps=1e-6):
    """Take one optimizer step on ``h_net`` for the treatment model p(X | Z0, Z2).

    Binary treatment uses the mean sigmoid cross-entropy with ``mu_x`` as
    the logit. Continuous treatment uses the Gaussian negative
    log-likelihood (constants dropped). The variance is either
    ``params['sigma_x'] ** 2`` or ``softplus`` of the last output unit plus
    ``eps``. With ``use_bnn``, ``kl_weight`` times the KL losses is added to
    the objective.

    Returns:
        (loss_x, loss): the optimised objective, and the cross-entropy
        (binary) or plain mean squared error of ``mu_x`` (continuous).
    """
    z_dims = self.params['z_dims']
    with tf.GradientTape() as tape:
        z0 = data_z[:, :z_dims[0]]
        z2 = data_z[:, sum(z_dims[:2]):sum(z_dims[:3])]
        output = self.h_net(tf.concat([z0, z2], axis=-1))
        mu_x = output[:, :1]
        if not self.params['binary_treatment']:
            if 'sigma_x' not in self.params:
                sigma_square_x = tf.nn.softplus(output[:, -1]) + eps
            else:
                sigma_square_x = self.params['sigma_x'] ** 2
            squared_error = (data_x - mu_x) ** 2
            loss = tf.reduce_mean(squared_error)
            per_sample_nll = (tf.reduce_sum(squared_error, axis=1) / (2 * sigma_square_x)
                              + tf.math.log(sigma_square_x) / 2)
            loss_x = tf.reduce_mean(per_sample_nll)
        else:
            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=data_x, logits=mu_x)
            loss = tf.reduce_mean(cross_entropy)
            loss_x = loss
        if self.params['use_bnn']:
            loss_x += sum(self.h_net.losses) * self.params['kl_weight']
    trainable = self.h_net.trainable_variables
    grads = tape.gradient(loss_x, trainable)
    self.h_optimizer.apply_gradients(zip(grads, trainable))
    return loss_x, loss
# Update generative model for outcome Y
@tf.function
def update_f_net(self, data_z, data_x, data_y, eps=1e-6):
    """Take one optimizer step on ``f_net`` for the outcome model p(Y | Z0, Z1, X).

    The objective is the batch mean Gaussian negative log-likelihood
    (constants dropped). The variance is ``params['sigma_y'] ** 2`` if
    configured, else ``softplus`` of the last output unit plus ``eps``.
    With ``use_bnn``, ``kl_weight`` times the KL losses is added.

    Returns:
        (loss_y, loss_mse): the optimised objective and the plain mean
        squared error of ``mu_y``.
    """
    z_dims = self.params['z_dims']
    with tf.GradientTape() as tape:
        z0 = data_z[:, :z_dims[0]]
        z1 = data_z[:, z_dims[0]:sum(z_dims[:2])]
        output = self.f_net(tf.concat([z0, z1, data_x], axis=-1))
        mu_y = output[:, :1]
        if 'sigma_y' not in self.params:
            sigma_square_y = tf.nn.softplus(output[:, -1]) + eps
        else:
            sigma_square_y = self.params['sigma_y'] ** 2
        squared_error = (data_y - mu_y) ** 2
        loss_mse = tf.reduce_mean(squared_error)
        per_sample_nll = (tf.reduce_sum(squared_error, axis=1) / (2 * sigma_square_y)
                          + tf.math.log(sigma_square_y) / 2)
        loss_y = tf.reduce_mean(per_sample_nll)
        if self.params['use_bnn']:
            loss_y += sum(self.f_net.losses) * self.params['kl_weight']
    trainable = self.f_net.trainable_variables
    grads = tape.gradient(loss_y, trainable)
    self.f_optimizer.apply_gradients(zip(grads, trainable))
    return loss_y, loss_mse
# Update posterior of latent variables Z
@tf.function
def update_latent_variable_sgd(self, data_x, data_y, data_v, batch_idx, eps=1e-6):
    """Take one optimizer step on the latent variables ``self.data_z[batch_idx]``.

    The objective is the mini-batch mean of the negative log joint density
    ``-log p(V|Z) - log p(X|Z0,Z2) - log p(Y|Z0,Z1,X) - log p(Z)``, with
    additive constants dropped and a standard-normal prior on Z. Only
    ``self.data_z`` is updated (by ``posterior_optimizer``); network weights
    are left unchanged.

    Each network is evaluated exactly once per step, and both its mean and
    its variance unit are read from that single output. Previously every
    network was called twice, which doubled the cost and, for stochastic
    (Bayesian) layers, paired a mean and a variance from different weight
    draws.

    Returns:
        Scalar posterior loss for the batch.
    """
    z_dims = self.params['z_dims']
    v_dim = self.params['v_dim']
    with tf.GradientTape() as tape:
        data_z = tf.gather(self.data_z, batch_idx, axis=0)
        data_z0 = data_z[:, :z_dims[0]]
        data_z1 = data_z[:, z_dims[0]:sum(z_dims[:2])]
        data_z2 = data_z[:, sum(z_dims[:2]):sum(z_dims[:3])]

        # -log p(v|z): covariate model.
        g_out = self.g_net(data_z)
        mu_v = g_out[:, :v_dim]
        if 'sigma_v' in self.params:
            sigma_square_v = self.params['sigma_v'] ** 2
        else:
            sigma_square_v = tf.nn.softplus(g_out[:, -1]) + eps
        loss_pv_z = tf.reduce_mean(
            tf.reduce_sum((data_v - mu_v) ** 2, axis=1) / (2 * sigma_square_v)
            + v_dim * tf.math.log(sigma_square_v) / 2)

        # -log p(x|z0,z2): treatment model (Bernoulli-logit or Gaussian).
        h_out = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
        mu_x = h_out[:, :1]
        if self.params['binary_treatment']:
            loss_px_z = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=data_x, logits=mu_x))
        else:
            if 'sigma_x' in self.params:
                sigma_square_x = self.params['sigma_x'] ** 2
            else:
                sigma_square_x = tf.nn.softplus(h_out[:, -1]) + eps
            loss_px_z = tf.reduce_mean(
                tf.reduce_sum((data_x - mu_x) ** 2, axis=1) / (2 * sigma_square_x)
                + tf.math.log(sigma_square_x) / 2)

        # -log p(y|z0,z1,x): outcome model.
        f_out = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
        mu_y = f_out[:, :1]
        if 'sigma_y' in self.params:
            sigma_square_y = self.params['sigma_y'] ** 2
        else:
            sigma_square_y = tf.nn.softplus(f_out[:, -1]) + eps
        loss_py_zx = tf.reduce_mean(
            tf.reduce_sum((data_y - mu_y) ** 2, axis=1) / (2 * sigma_square_y)
            + tf.math.log(sigma_square_y) / 2)

        # -log p(z): standard-normal prior.
        loss_prior_z = tf.reduce_mean(tf.reduce_sum(data_z ** 2, axis=1) / 2)
        loss_posterior_z = loss_pv_z + loss_px_z + loss_py_zx + loss_prior_z
    posterior_gradients = tape.gradient(loss_posterior_z, [self.data_z])
    self.posterior_optimizer.apply_gradients(zip(posterior_gradients, [self.data_z]))
    return loss_posterior_z
#################################### EGM initialization ###########################################
@tf.function
def train_disc_step(self, data_z, data_v):
    """Perform one WGAN-GP update of the latent critic ``dz_net`` (EGM pre-training).

    The critic is trained to score prior draws ``data_z`` higher than the
    encoded covariates ``e_net(data_v)``. A gradient penalty with weight 10
    pushes the norm of its input gradient towards 1 on interpolates between
    the two. Only ``dz_net`` is updated, using ``d_pre_optimizer``.

    Returns:
        (dz_loss, d_loss): the critic loss without and with the gradient
        penalty.
    """
    # NOTE(review): one interpolation weight is shared by the whole batch,
    # whereas the published WGAN-GP draws one per sample. Left unchanged on
    # purpose because changing it alters training dynamics.
    epsilon_z = tf.random.uniform([], minval=0., maxval=1.)
    # Only a single gradient is read from the outer tape, so it does not
    # need to be persistent (a persistent tape keeps every intermediate alive).
    with tf.GradientTape() as disc_tape:
        with tf.GradientTape() as gp_tape:
            data_z_ = self.e_net(data_v)
            data_z_hat = data_z * epsilon_z + data_z_ * (1 - epsilon_z)
            # Watch the interpolates explicitly instead of relying on them
            # being derived from e_net's trainable variables.
            gp_tape.watch(data_z_hat)
            data_dz_hat = self.dz_net(data_z_hat)
        data_dz_ = self.dz_net(data_z_)
        data_dz = self.dz_net(data_z)
        dz_loss = -tf.reduce_mean(data_dz) + tf.reduce_mean(data_dz_)
        # The gradient penalty is computed inside disc_tape so that it is
        # itself differentiable w.r.t. dz_net's weights.
        grad_z = gp_tape.gradient(data_dz_hat, data_z_hat)
        grad_norm_z = tf.sqrt(tf.reduce_sum(tf.square(grad_z), axis=1))
        gpz_loss = tf.reduce_mean(tf.square(grad_norm_z - 1.0))
        d_loss = dz_loss + 10 * gpz_loss
    d_gradients = disc_tape.gradient(d_loss, self.dz_net.trainable_variables)
    self.d_pre_optimizer.apply_gradients(zip(d_gradients, self.dz_net.trainable_variables))
    return dz_loss, d_loss
@tf.function
def train_gen_step(self, data_z, data_v, data_x, data_y):
    """Jointly update ``g_net``, ``e_net``, ``f_net`` and ``h_net`` once (EGM pre-training).

    The minimised loss is the sum of:

    - the adversarial encoder loss ``-mean(dz_net(e_net(v)))``;
    - the covariate reconstruction error ``mean((v - g(e(v)))^2)``;
    - ``params['use_z_rec']`` times the latent reconstruction error
      ``mean((z - e(g(z)))^2)``;
    - the treatment loss (sigmoid cross-entropy on logits for binary
      treatment, MSE otherwise) and the outcome MSE, both predicted from
      the encoded latents ``e(v)``;
    - ``0.001`` times the mean squared raw variance-unit outputs of
      ``g_net``, ``f_net`` and ``h_net``.

    All four networks share ``g_pre_optimizer``; ``dz_net`` is not updated.

    Returns:
        (e_loss_adv, l2_loss_v, l2_loss_z, l2_loss_x, l2_loss_y, g_e_loss)
    """
    v_dim = self.params['v_dim']
    z_dims = self.params['z_dims']
    with tf.GradientTape() as gen_tape:
        # Every network is evaluated once per distinct input, and mean and
        # variance units are sliced from that one output (previously g, f
        # and h were each run twice on identical inputs).
        g_out = self.g_net(data_z)
        data_v_ = g_out[:, :v_dim]
        sigma_square_loss = tf.reduce_mean(tf.square(g_out[:, -1]))

        data_z_ = self.e_net(data_v)
        data_z0 = data_z_[:, :z_dims[0]]
        data_z1 = data_z_[:, z_dims[0]:sum(z_dims[:2])]
        data_z2 = data_z_[:, sum(z_dims[:2]):sum(z_dims[:3])]
        data_z__ = self.e_net(data_v_)
        data_v__ = self.g_net(data_z_)[:, :v_dim]
        data_dz_ = self.dz_net(data_z_)

        l2_loss_v = tf.reduce_mean((data_v - data_v__) ** 2)
        l2_loss_z = tf.reduce_mean((data_z - data_z__) ** 2)
        e_loss_adv = -tf.reduce_mean(data_dz_)

        f_out = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
        data_y_ = f_out[:, :1]
        sigma_square_loss += tf.reduce_mean(tf.square(f_out[:, -1]))
        h_out = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
        data_x_ = h_out[:, :1]
        sigma_square_loss += tf.reduce_mean(tf.square(h_out[:, -1]))

        if self.params['binary_treatment']:
            l2_loss_x = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=data_x, logits=data_x_))
        else:
            l2_loss_x = tf.reduce_mean((data_x_ - data_x) ** 2)
        l2_loss_y = tf.reduce_mean((data_y_ - data_y) ** 2)
        g_e_loss = (e_loss_adv
                    + (l2_loss_v + self.params['use_z_rec'] * l2_loss_z)
                    + (l2_loss_x + l2_loss_y)
                    + 0.001 * sigma_square_loss)
    variables = (self.g_net.trainable_variables + self.e_net.trainable_variables
                 + self.f_net.trainable_variables + self.h_net.trainable_variables)
    g_e_gradients = gen_tape.gradient(g_e_loss, variables)
    self.g_pre_optimizer.apply_gradients(zip(g_e_gradients, variables))
    return e_loss_adv, l2_loss_v, l2_loss_z, l2_loss_x, l2_loss_y, g_e_loss
def egm_init(self, data, egm_n_iter=30000, batch_size=32, egm_batches_per_eval=500, verbose=1):
    """Run the EGM warm-start used by :meth:`fit`.

    This helper performs the Encoding Generative Modeling (EGM)
    initialization. In the current workflow it is typically called from
    :meth:`fit` when ``use_egm_init=True``. Each iteration performs
    ``params['g_d_freq']`` critic updates (:meth:`train_disc_step`) and then
    one joint update of G, E, F and H (:meth:`train_gen_step`).

    Parameters
    ----------
    data : tuple of np.ndarray
        A triplet ``(data_x, data_y, data_v)``.
    egm_n_iter : int, default=30000
        Number of EGM mini-batch iterations (iterations ``0..egm_n_iter``
        inclusive are run).
    batch_size : int, default=32
        Mini-batch size. Batches are drawn without replacement, so a value
        larger than the number of samples is reduced to the number of
        samples.
    egm_batches_per_eval : int, default=500
        Evaluate (and, if ``save_res``, save the causal estimate) every this
        many iterations.
    verbose : int, default=1
        Verbosity level (0 = silent loss logging).
    """
    data_x, data_y, data_v = data
    n_samples = len(data_x)
    # Sampling without replacement cannot draw more rows than exist; cap the
    # batch instead of letting np.random.choice raise.
    batch_size = min(batch_size, n_samples)
    # Placeholders so the log line is well defined when g_d_freq == 0 and
    # the critic is never updated.
    dz_loss = d_loss = float('nan')
    print('EGM Initialization Starts ...')
    for batch_iter in range(egm_n_iter + 1):
        # Update model parameters of the discriminator (g_d_freq steps).
        for _ in range(self.params['g_d_freq']):
            batch_idx = np.random.choice(n_samples, batch_size, replace=False)
            batch_z = self.z_sampler.get_batch(batch_size)
            batch_v = data_v[batch_idx, :]
            dz_loss, d_loss = self.train_disc_step(batch_z, batch_v)

        # Update model parameters of G, E, H, F on a fresh mini-batch.
        batch_z = self.z_sampler.get_batch(batch_size)
        batch_idx = np.random.choice(n_samples, batch_size, replace=False)
        batch_x = data_x[batch_idx, :]
        batch_y = data_y[batch_idx, :]
        batch_v = data_v[batch_idx, :]
        e_loss_adv, l2_loss_v, l2_loss_z, l2_loss_x, l2_loss_y, g_e_loss = self.train_gen_step(batch_z, batch_v, batch_x, batch_y)

        if batch_iter % egm_batches_per_eval == 0:
            loss_contents = (
                'EGM Initialization Iter [%d] : e_loss_adv [%.4f], l2_loss_v [%.4f], l2_loss_z [%.4f], '
                'l2_loss_x [%.4f], l2_loss_y [%.4f], g_e_loss [%.4f], dz_loss [%.4f], d_loss [%.4f]'
                % (batch_iter, e_loss_adv, l2_loss_v, l2_loss_z, l2_loss_x, l2_loss_y, g_e_loss, dz_loss, d_loss)
            )
            if verbose:
                print(loss_contents)
            causal_pre, mse_x, mse_y, mse_v = self.evaluate(data = data)
            causal_pre = causal_pre.numpy()
            if self.params['save_res']:
                save_data('{}/causal_pre_egm_init_iter-{}.txt'.format(self.save_dir, batch_iter), causal_pre)
    print('EGM Initialization Ends.')
#################################### EGM initialization #############################################
[docs]
def fit(self, data, epochs=100, epochs_per_eval=5, batch_size=32, startoff=0, use_egm_init=True,
        egm_n_iter=30000, egm_batches_per_eval=500, save_format='txt', verbose=1):
    """Train CausalBGM with an optional EGM warm-start.

    Parameters
    ----------
    data : tuple of np.ndarray
        Training data ``(data_x, data_y, data_v)``.
    epochs : int, default=100
        Number of training epochs (epochs ``0..epochs`` inclusive are run).
    epochs_per_eval : int, default=5
        Evaluate the full training set every this many epochs.
    batch_size : int, default=32
        Mini-batch size used for both EGM initialization and iterative
        updates.
    startoff : int, default=0
        Start tracking the best model only after this epoch.
    use_egm_init : bool, default=True
        If ``True``, run EGM initialization before iterative training.
    egm_n_iter : int, default=30000
        Number of EGM mini-batch iterations when ``use_egm_init=True``.
    egm_batches_per_eval : int, default=500
        Logging interval for EGM initialization.
    save_format : str, default='txt'
        File format used when saving causal estimates.
    verbose : int, default=1
        Verbosity level. Set to ``0`` to suppress progress logging (the
        per-batch progress bar and the per-evaluation metric line).

    Notes
    -----
    After the optional EGM warm-start, latent variables are initialized
    from ``e(V)``. If EGM is skipped, they are initialized from a standard
    normal distribution. Each mini-batch first updates G, H and F given the
    current Z, then takes one optimizer step on that batch's Z. At every
    evaluation epoch from ``startoff`` on, the estimate with the lowest
    outcome MSE is kept in ``self.best_causal_pre`` / ``self.best_epoch``
    (and checkpointed when ``save_model`` is set).
    """
    data_x, data_y, data_v = data
    n_samples = len(data_x)
    if self.params['save_res']:
        # `with` guarantees the file is closed even if the write fails.
        with open('{}/params.txt'.format(self.save_dir), 'w') as f_params:
            f_params.write(str(self.params))
    if use_egm_init:
        self.egm_init(data, egm_n_iter=egm_n_iter, egm_batches_per_eval=egm_batches_per_eval,
                      batch_size=batch_size, verbose=verbose)
        print('Initialize latent variables Z with e(V)...')
        data_z_init = self.e_net(data_v)
    else:
        print('Random initialization of latent variables Z...')
        data_z_init = np.random.normal(
            0, 1, size=(n_samples, sum(self.params['z_dims']))).astype('float32')
    # NOTE(review): a new variable is created on every call, while the
    # tf.function-compiled update_latent_variable_sgd reads self.data_z at
    # trace time and posterior_optimizer may hold state for an earlier
    # variable; calling fit() twice on one instance may therefore not behave
    # as expected - confirm before relying on it.
    self.data_z = tf.Variable(data_z_init, name="Latent Variable", trainable=True)
    best_loss = np.inf
    n_batches = int(np.ceil(n_samples / batch_size))
    print('Iterative Updating Starts ...')
    for epoch in range(epochs + 1):
        sample_idx = np.random.choice(n_samples, n_samples, replace=False)
        # The per-batch progress bar is the progress logging that verbose=0 suppresses.
        with tqdm(total=n_batches, desc=f"Epoch {epoch}/{epochs}", unit="batch",
                  disable=not verbose) as batch_bar:
            for i in range(0, n_samples, batch_size):
                batch_idx = sample_idx[i:i + batch_size]
                # Update model parameters of G, H, F given the current Z.
                batch_z = tf.gather(self.data_z, batch_idx, axis=0)
                batch_x = data_x[batch_idx, :]
                batch_y = data_y[batch_idx, :]
                batch_v = data_v[batch_idx, :]
                loss_v, loss_mse_v = self.update_g_net(batch_z, batch_v)
                loss_x, loss_mse_x = self.update_h_net(batch_z, batch_x)
                loss_y, loss_mse_y = self.update_f_net(batch_z, batch_x, batch_y)
                # Then update this batch's Z towards the posterior mode.
                loss_postrior_z = self.update_latent_variable_sgd(batch_x, batch_y, batch_v, batch_idx)
                if verbose:
                    loss_contents = (
                        'loss_px_z: [%.4f], loss_mse_x: [%.4f], loss_py_z: [%.4f], '
                        'loss_mse_y: [%.4f], loss_pv_z: [%.4f], loss_mse_v: [%.4f], loss_postrior_z: [%.4f]'
                        % (loss_x, loss_mse_x, loss_y, loss_mse_y, loss_v, loss_mse_v, loss_postrior_z)
                    )
                    batch_bar.set_postfix_str(loss_contents)
                batch_bar.update(1)

        # Evaluate the full training data and report metrics for the epoch.
        if epoch % epochs_per_eval == 0:
            causal_pre, mse_x, mse_y, mse_v = self.evaluate(data = data, data_z = self.data_z)
            causal_pre = causal_pre.numpy()
            if verbose:
                print('Epoch [%d/%d]: MSE_x: %.4f, MSE_y: %.4f, MSE_v: %.4f\n' % (epoch, epochs, mse_x, mse_y, mse_v))
            # Keep the estimate with the lowest outcome MSE seen from `startoff` on.
            if epoch >= startoff and mse_y < best_loss:
                best_loss = mse_y
                self.best_causal_pre = causal_pre
                self.best_epoch = epoch
                if self.params['save_model']:
                    ckpt_save_path = self.ckpt_manager.save(epoch)
                    print('Saving checkpoint for epoch {} at {}'.format(epoch, ckpt_save_path))
            if self.params['save_res']:
                save_data('{}/causal_pre_at_{}.{}'.format(self.save_dir, epoch, save_format), causal_pre)
@tf.function
def evaluate(self, data, data_z=None, nb_intervals=200):
    """Compute fit diagnostics and a point causal estimate on a full data set.

    Parameters
    ----------
    data : tuple
        ``(data_x, data_y, data_v)``. Treatment and outcome are used as 2-D
        column arrays (sliced against ``[:, :1]`` network outputs below).
    data_z : tensor-like or None, optional
        Latent variables to use. If ``None``, they are encoded as
        ``e_net(data_v)``.
    nb_intervals : int, default=200
        Number of grid points of the dose-response curve (continuous
        treatment only).

    Returns
    -------
    (effect, mse_x, mse_y, mse_v)
        For binary treatment, ``effect`` is the per-unit ITE
        ``f(z0, z1, x=1) - f(z0, z1, x=0)`` (the mean output of ``f_net``).
        For continuous treatment, it is the average dose response at
        ``nb_intervals`` evenly spaced treatment values between the 5th and
        95th percentiles of ``data_x``. ``mse_x`` compares sigmoid
        probabilities with the labels for binary treatment.
    """
    data_x, data_y, data_v = data
    if data_z is None:
        data_z = self.e_net(data_v)
    z_dims = self.params['z_dims']
    data_z0 = data_z[:, :z_dims[0]]
    data_z1 = data_z[:, z_dims[0]:sum(z_dims[:2])]
    data_z2 = data_z[:, sum(z_dims[:2]):sum(z_dims[:3])]
    data_v_pred = self.g_net(data_z)[:, :self.params['v_dim']]
    data_y_pred = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))[:, :1]
    data_x_pred = self.h_net(tf.concat([data_z0, data_z2], axis=-1))[:, :1]
    if self.params['binary_treatment']:
        # h_net emits a logit; compare probabilities with the 0/1 labels.
        data_x_pred = tf.sigmoid(data_x_pred)
    mse_v = tf.reduce_mean((data_v - data_v_pred) ** 2)
    mse_x = tf.reduce_mean((data_x - data_x_pred) ** 2)
    mse_y = tf.reduce_mean((data_y - data_y_pred) ** 2)
    if self.params['binary_treatment']:
        # Individual treatment effect (ITE); its mean is the ATE. The
        # constant treatment columns follow data_x's (possibly dynamic) batch
        # size and dtype, rather than being float64 NumPy arrays whose length
        # must be known when the function is traced.
        y_pred_pos = self.f_net(tf.concat([data_z0, data_z1, tf.ones_like(data_x)], axis=-1))[:, :1]
        y_pred_neg = self.f_net(tf.concat([data_z0, data_z1, tf.zeros_like(data_x)], axis=-1))[:, :1]
        ite_pre = y_pred_pos - y_pred_neg
        return ite_pre, mse_x, mse_y, mse_v
    # Average dose-response function (ADRF) on a grid over the central 90% of
    # the observed treatment values.
    x_min = tfp.stats.percentile(data_x, 5.0)
    x_max = tfp.stats.percentile(data_x, 95.0)
    x_values = tf.linspace(x_min, x_max, nb_intervals)

    def compute_dose_response(x):
        # Mean predicted outcome when every unit receives treatment x.
        data_x_tile = tf.fill([tf.shape(data_x)[0], 1], x)
        data_x_tile = tf.cast(data_x_tile, tf.float32)
        y_pred = self.f_net(tf.concat([data_z0, data_z1, data_x_tile], axis=-1))[:, :1]
        return tf.reduce_mean(y_pred)

    dose_response = tf.map_fn(compute_dose_response, x_values, fn_output_signature=tf.float32)
    return dose_response, mse_x, mse_y, mse_v
# Predict with MCMC sampling
[docs]
def predict(self, data, alpha=0.01, n_mcmc=3000, burn_in=5000, x_values=None, q_sd=1.0, sample_y=True, bs=10000):
    """Estimate causal effects with posterior intervals from latent MCMC samples.

    The test subjects are processed in chunks of ``bs``. For each chunk,
    latent posterior draws are obtained with
    :meth:`metropolis_hastings_sampler` and mapped to effect draws with
    :meth:`infer_from_latent_posterior`.

    Parameters
    ----------
    data : tuple of np.ndarray
        Test data ``(data_x, data_y, data_v)``.
    alpha : float, default=0.01
        Significance level; intervals span the ``alpha/2`` and
        ``1 - alpha/2`` quantiles of the draws.
    n_mcmc : int, default=3000
        Number of retained MCMC samples.
    burn_in : int, default=5000
        Number of burn-in iterations for the Metropolis-Hastings sampler.
    x_values : float or array-like, optional
        Treatment values at which the dose-response curve is evaluated.
        Required for continuous treatment.
    q_sd : float, default=1.0
        Proposal standard deviation for the Metropolis-Hastings sampler.
    sample_y : bool, default=True
        If ``True``, sample from the outcome model using the variance head.
        If ``False``, use the posterior mean of the outcome model.
    bs : int, default=10000
        Number of test subjects processed per chunk (at least 1).

    Returns
    -------
    effect : np.ndarray
        Binary treatment: ITE estimates with shape ``(n,)``.
        Continuous treatment: ADRF estimates with shape ``(len(x_values),)``.
        For the ADRF, chunk results are combined as a size-weighted average
        per MCMC draw before summarising.
    pos_int : np.ndarray
        Posterior intervals with shape ``(n, 2)`` for binary treatment or
        ``(len(x_values), 2)`` for continuous treatment, columns being
        ``[lower, upper]``.
    """
    assert 0 < alpha < 1, "The significance level 'alpha' must be greater than 0 and less than 1."
    is_binary = self.params['binary_treatment']
    if not is_binary and x_values is None:
        # Continuous treatment requires an evaluation grid.
        raise ValueError("For continuous treatment, 'x_values' must not be None. Provide a list or a single treatment value.")
    if x_values is not None:
        x_values = np.array([x_values] if np.isscalar(x_values) else x_values, dtype=float)

    data_x, data_y, data_v = data
    n_test = len(data_x)
    bs = max(1, int(bs))
    lower_q, upper_q = alpha / 2, 1 - alpha / 2

    def chunk_draws():
        # Yield (start, end, effect draws) for consecutive chunks of subjects.
        for lo in range(0, n_test, bs):
            hi = min(lo + bs, n_test)
            chunk = (data_x[lo:hi], data_y[lo:hi], data_v[lo:hi])
            posterior_z = self.metropolis_hastings_sampler(
                chunk, burn_in=burn_in, n_keep=n_mcmc, q_sd=q_sd
            )
            draws = self.infer_from_latent_posterior(
                posterior_z, x_values=x_values, sample_y=sample_y
            ).numpy()
            yield lo, hi, draws

    print('MCMC Latent Variable Sampling ...')
    if is_binary:
        ite_mean = np.zeros(n_test, dtype=np.float32)
        upper = np.zeros(n_test, dtype=np.float32)
        lower = np.zeros(n_test, dtype=np.float32)
        for lo, hi, draws in chunk_draws():
            # draws: (n_mcmc, chunk size) -> summarise over MCMC draws.
            ite_mean[lo:hi] = draws.mean(axis=0)
            upper[lo:hi] = np.quantile(draws, upper_q, axis=0)
            lower[lo:hi] = np.quantile(draws, lower_q, axis=0)
        return ite_mean, np.stack([lower, upper], axis=1)

    weighted_sum = np.zeros((len(x_values), n_mcmc), dtype=np.float32)
    total = 0
    for lo, hi, draws in chunk_draws():
        # draws: (len(x_values), n_mcmc), averaged over this chunk's subjects.
        chunk_size = hi - lo
        weighted_sum += draws * chunk_size
        total += chunk_size
    causal_effects = weighted_sum / float(total)
    adrf = causal_effects.mean(axis=1)
    pos_int = np.stack(
        [np.quantile(causal_effects, lower_q, axis=1),
         np.quantile(causal_effects, upper_q, axis=1)],
        axis=1,
    )
    return adrf, pos_int
@tf.function
def infer_from_latent_posterior(self, data_posterior_z, x_values=None, sample_y=True, eps=1e-6):
    """Map posterior draws of Z to draws of the causal estimand.

    The ITE is produced for binary treatment and the ADRF for continuous
    treatment. No summarisation (means or intervals) is done here.

    Args:
        data_posterior_z: Posterior draws of Z with shape
            ``(n_samples, n, p)``, where ``p = sum(z_dims)``.
        x_values: 1-D array-like of treatment values at which the
            dose-response curve is evaluated. Required for continuous
            treatment; ignored for binary treatment.
        sample_y: If ``True``, draw ``Y ~ N(mu_y, sigma_y^2)`` from the
            outcome model; otherwise use ``mu_y``. The variance is
            ``params['sigma_y'] ** 2`` if configured, otherwise ``softplus``
            of ``f_net``'s second output plus ``eps``.
        eps: Floor added to the learned variance.

    Returns:
        Binary treatment: ITE draws with shape ``(n_samples, n)``.
        Continuous treatment: tensor of shape ``(len(x_values), n_samples)``
        whose entry ``[j, s]`` is the outcome averaged over the ``n`` units
        at treatment ``x_values[j]`` for posterior draw ``s``.
    """
    d0 = self.params['z_dims'][0]
    d01 = sum(self.params['z_dims'][:2])

    def outcome_draws(x):
        # Outcome draws with every unit's treatment set to x;
        # shape (n_samples, n). f_net sees (Z0, Z1, X) for each posterior draw.
        x = tf.cast(x, tf.float32)
        y_out_all = tf.map_fn(
            lambda z: self.f_net(tf.concat([z[:, :d0],
                                            z[:, d0:d01],
                                            tf.fill([tf.shape(z)[0], 1], x)], axis=-1)),
            data_posterior_z,
            fn_output_signature=tf.float32,
        )
        mu_y_all = y_out_all[:, :, 0]
        if 'sigma_y' in self.params:
            sigma_square_y = self.params['sigma_y'] ** 2
        else:
            sigma_square_y = tf.nn.softplus(y_out_all[:, :, 1]) + eps
        if sample_y:
            return tf.random.normal(
                shape=tf.shape(mu_y_all), mean=mu_y_all, stddev=tf.sqrt(sigma_square_y)
            )
        return mu_y_all

    if self.params['binary_treatment']:
        # ITE draws: outcome under treatment minus outcome under control
        # (treated outcomes are drawn first).
        return outcome_draws(1.0) - outcome_draws(0.0)

    # ADRF: for each treatment value, the outcome averaged over units, per draw.
    return tf.map_fn(lambda x: tf.reduce_mean(outcome_draws(x), axis=1),
                     x_values, fn_output_signature=tf.float32)
@tf.function
def get_log_posterior(self, data_x, data_y, data_v, data_z, eps=1e-6):
    """
    Calculate the per-sample unnormalised log posterior log p(Z | X, Y, V).

    Computed as log p(V|Z) + log p(X|Z0,Z2) + log p(Y|Z0,Z1,X) + log p(Z)
    with additive constants dropped. The likelihoods are Gaussian
    (Bernoulli-logit for binary X) and the prior on Z is standard normal.
    Each variance is the fixed ``params['sigma_*'] ** 2`` when configured,
    otherwise ``softplus`` of the network's last output unit plus ``eps``.

    data_x: Treatment with shape (n, 1).
    data_y: Outcome with shape (n, 1).
    data_v: Covariates with shape (n, v_dim).
    data_z: Latent variables with shape (n, sum(z_dims)).
    return: Log posterior with shape (n,).
    """
    z_dims = self.params['z_dims']
    data_z0 = data_z[:, :z_dims[0]]
    data_z1 = data_z[:, z_dims[0]:sum(z_dims[:2])]
    data_z2 = data_z[:, sum(z_dims[:2]):sum(z_dims[:3])]

    g_net_output = self.g_net(data_z)
    mu_v = g_net_output[:, :self.params['v_dim']]
    if 'sigma_v' in self.params:
        sigma_square_v = self.params['sigma_v'] ** 2
    else:
        sigma_square_v = tf.nn.softplus(g_net_output[:, -1]) + eps

    h_net_output = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
    mu_x = h_net_output[:, :1]

    f_net_output = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
    mu_y = f_net_output[:, :1]
    if 'sigma_y' in self.params:
        sigma_square_y = self.params['sigma_y'] ** 2
    else:
        sigma_square_y = tf.nn.softplus(f_net_output[:, -1]) + eps

    loss_pv_z = tf.reduce_sum((data_v - mu_v) ** 2, axis=1) / (2 * sigma_square_v) + \
        self.params['v_dim'] * tf.math.log(sigma_square_v) / 2
    if self.params['binary_treatment']:
        # Reduce over the treatment column explicitly: tf.squeeze would turn
        # an n == 1 batch into a scalar instead of shape (1,).
        loss_px_z = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=data_x, logits=mu_x), axis=1)
    else:
        if 'sigma_x' in self.params:
            sigma_square_x = self.params['sigma_x'] ** 2
        else:
            sigma_square_x = tf.nn.softplus(h_net_output[:, -1]) + eps
        loss_px_z = tf.reduce_sum((data_x - mu_x) ** 2, axis=1) / (2 * sigma_square_x) + \
            tf.math.log(sigma_square_x) / 2
    loss_py_zx = tf.reduce_sum((data_y - mu_y) ** 2, axis=1) / (2 * sigma_square_y) + \
        tf.math.log(sigma_square_y) / 2
    loss_prior_z = tf.reduce_sum(data_z ** 2, axis=1) / 2
    loss_posterior_z = loss_pv_z + loss_px_z + loss_py_zx + loss_prior_z
    return -loss_posterior_z
def metropolis_hastings_sampler(self, data, initial_q_sd = 1.0, q_sd = None, burn_in = 5000, n_keep = 3000, target_acceptance_rate=0.25, tolerance=0.05, adjustment_interval=50, adaptive_sd=None, window_size=100):
    """
    Sample Z from the posterior P(Z | X, Y, V) with random-walk Metropolis-Hastings.

    One chain is run per data row, and all chains move together. Each
    iteration proposes ``z' = z + N(0, q_sd^2 I)`` for every chain and accepts
    per chain with probability ``min(1, exp(log p(z'|.) - log p(z|.)))``,
    using :meth:`get_log_posterior`. Chains start from a standard-normal draw.

    Args:
        data (tuple): ``(data_x, data_y, data_v)``.
        initial_q_sd (float): Starting proposal standard deviation when the
            proposal is adaptive.
        q_sd (float or None): Proposal standard deviation. With
            ``adaptive_sd=None``, a value of ``None`` or ``<= 0`` turns
            adaptation on, starting from ``initial_q_sd``.
        burn_in (int): Number of initial iterations that are discarded
            (default 5000). Adaptation only happens during burn-in.
        n_keep (int): Number of post-burn-in states returned (one per iteration).
        target_acceptance_rate (float): Target acceptance rate for adaptation.
        tolerance (float): Half-width of the accepted band around the target.
            Below the band ``q_sd`` is multiplied by 0.9; above it, by 1.1.
        adjustment_interval (int): Number of iterations between adaptation checks.
        adaptive_sd (bool or None): Force adaptation on or off. ``None``
            derives it from ``q_sd`` as described above.
        window_size (int): Number of most recent iterations used to estimate
            the acceptance rate; a non-positive value keeps all iterations.

    Returns:
        np.ndarray: Posterior samples with shape ``(n_keep, n, sum(z_dims))``.
    """
    from collections import deque

    data_x, data_y, data_v = data
    n_chains = len(data_x)
    z_dim = sum(self.params['z_dims'])
    # Initialize the state of the n chains.
    current_state = np.random.normal(0, 1, size=(n_chains, z_dim)).astype('float32')
    samples = []
    counter = 0
    # Acceptance indicators of the most recent iterations. The deque drops
    # the oldest entry itself instead of the list being re-sliced each step.
    recent_acceptances = deque(maxlen=window_size if window_size > 0 else None)

    def windowed_acceptance_rate():
        # Fraction of accepted proposals over all chains in the window;
        # NaN (instead of a 0/0 division) when nothing has been recorded.
        total = len(recent_acceptances) * n_chains
        if total == 0:
            return float('nan')
        return sum(int(np.count_nonzero(a)) for a in recent_acceptances) / total

    # Determine whether q_sd should be adaptive, and its starting value.
    if adaptive_sd is None:
        adaptive_sd = (q_sd is None or q_sd <= 0)
    if adaptive_sd:
        q_sd = initial_q_sd

    while len(samples) < n_keep:
        proposed_state = current_state + np.random.normal(0, q_sd, size=(n_chains, z_dim)).astype('float32')
        # NOTE(review): the current state's log posterior is re-evaluated on
        # every iteration. For deterministic networks it could be cached and
        # updated on acceptance (halving the cost); if the Bayesian networks
        # draw new weights per call, this re-evaluation is part of the
        # sampler's behaviour, so it is kept.
        proposed_log_posterior = self.get_log_posterior(data_x, data_y, data_v, proposed_state)
        current_log_posterior = self.get_log_posterior(data_x, data_y, data_v, current_state)
        # Capping the log ratio at 0 keeps exp() from overflowing; the
        # acceptance probability is min(1, ratio) anyway.
        acceptance_ratio = np.exp(np.minimum(proposed_log_posterior - current_log_posterior, 0))
        accepted = np.random.rand(n_chains) < acceptance_ratio
        current_state[accepted] = proposed_state[accepted]
        recent_acceptances.append(accepted)

        # Adapt q_sd periodically, during burn-in only.
        if adaptive_sd and 0 < counter < burn_in and counter % adjustment_interval == 0:
            current_acceptance_rate = windowed_acceptance_rate()
            print(f"Current MCMC Acceptance Rate: {current_acceptance_rate:.4f}")
            if current_acceptance_rate < target_acceptance_rate - tolerance:
                q_sd *= 0.9  # smaller steps -> more acceptances
            elif current_acceptance_rate > target_acceptance_rate + tolerance:
                q_sd *= 1.1  # larger steps -> fewer acceptances
            print(f"MCMC Proposal Standard Deviation (q_sd): {q_sd:.4f}")

        if counter >= burn_in:
            samples.append(current_state.copy())
        counter += 1

    acceptance_rate = windowed_acceptance_rate()
    print(f"Final MCMC Acceptance Rate: {acceptance_rate:.4f}")
    return np.array(samples)