import datetime
import os
import dateutil.tz
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from bayesgm.datasets import Gaussian_sampler
from bayesgm.utils.data_io import save_data
from ..networks import BaseFullyConnectedNet, BayesianFullyConnectedNet, Discriminator
from .base import CausalBGM
class IdentifiableCausalBGM(CausalBGM):
    """Identifiable CausalBGM using nonlinear ICA theory (iVAE).

    Achieves identifiability under mild conditions by introducing an auxiliary
    variable :math:`U` and conditioning the latent prior on it:
    :math:`Z \\mid U \\sim \\mathcal{N}(\\mu(U), \\sigma^2(U) I)`, where the
    mean :math:`\\mu(U)` and a single isotropic variance :math:`\\sigma^2(U)`
    are produced by a prior network ``prior_net``.

    Inherits from :class:`CausalBGM`.

    Parameters
    ----------
    params : dict
        Same keys as :class:`CausalBGM`, plus optionally:

        - ``'n_segments'`` (int): Number of auxiliary-variable segments, i.e.
          the width of the one-hot encoding of :math:`U` (default 10). The
          default is written back into ``params``.
        - ``'prior_units'`` (list[int]): Hidden-layer sizes for the prior
          network (default ``[64]``).
        - ``'kl_weight'`` (float): Weight of the prior network's KL term when
          ``use_bnn`` is true (default 1.0).
    timestamp : str or None, optional
        Timestamp string for the run. Generated from the local time if None.
    random_seed : int or None, optional
        If provided, sets the global random seed and enables deterministic
        TensorFlow ops for reproducibility.

    Notes
    -----
    The identifiability argument assumes :math:`U` is observed.
    :meth:`fit`, :meth:`predict` and :meth:`metropolis_hastings_sampler`
    accept an optional ``data_u``; when it is omitted, every sample is
    assigned a segment uniformly at random.
    """

    def __init__(self, params, timestamp=None, random_seed=None):
        self.params = params
        self.timestamp = timestamp
        # Set random seed for reproducibility
        if random_seed is not None:
            tf.keras.utils.set_random_seed(random_seed)
            os.environ['TF_DETERMINISTIC_OPS'] = '1'
            tf.config.experimental.enable_op_determinism()

        # iVAE: default number of auxiliary-variable segments. Written into params
        # so the value is recorded alongside the other run parameters.
        self.params.setdefault('n_segments', 10)
        z_dims = params['z_dims']
        z_dim = sum(z_dims)

        # Networks g (covariates), e (encoder), f (outcome), h (treatment), plus the
        # iVAE prior network p(z|u), which outputs z_dim means and one variance logit.
        net_cls = BayesianFullyConnectedNet if self.params['use_bnn'] else BaseFullyConnectedNet
        self.g_net = net_cls(input_dim=z_dim, output_dim=params['v_dim'] + 1,
                             model_name='g_net', nb_units=params['g_units'])
        self.e_net = net_cls(input_dim=params['v_dim'], output_dim=z_dim,
                             model_name='e_net', nb_units=params['e_units'])
        self.f_net = net_cls(input_dim=z_dims[0] + z_dims[1] + 1, output_dim=2,
                             model_name='f_net', nb_units=params['f_units'])
        self.h_net = net_cls(input_dim=z_dims[0] + z_dims[2], output_dim=2,
                             model_name='h_net', nb_units=params['h_units'])
        self.prior_net = net_cls(input_dim=self.params['n_segments'], output_dim=z_dim + 1,
                                 model_name='prior_net', nb_units=params.get('prior_units', [64]))
        self.dz_net = Discriminator(input_dim=z_dim, model_name='dz_net',
                                    nb_units=params['dz_units'])

        # Optimizers for pre-training and main training phase
        self.g_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
        self.d_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
        self.z_sampler = Gaussian_sampler(mean=np.zeros(z_dim), sd=1.0)
        self.g_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.f_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.h_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.posterior_optimizer = tf.keras.optimizers.Adam(params['lr_z'], beta_1=0.9, beta_2=0.99)
        # iVAE: separate optimizer for the prior network parameters
        self.prior_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.initialize_nets()

        # Checkpoint and results saving setup
        if self.timestamp is None:
            now = datetime.datetime.now(dateutil.tz.tzlocal())
            self.timestamp = now.strftime('%Y%m%d_%H%M%S')
        self.checkpoint_path = "{}/checkpoints/{}/{}".format(
            params['output_dir'], params['dataset'], self.timestamp)
        if self.params['save_model']:
            os.makedirs(self.checkpoint_path, exist_ok=True)
        self.save_dir = "{}/results/{}/{}".format(
            params['output_dir'], params['dataset'], self.timestamp)
        if self.params['save_res']:
            os.makedirs(self.save_dir, exist_ok=True)

        self.ckpt = tf.train.Checkpoint(g_net=self.g_net,
                                        e_net=self.e_net,
                                        f_net=self.f_net,
                                        h_net=self.h_net,
                                        dz_net=self.dz_net,
                                        prior_net=self.prior_net,
                                        g_pre_optimizer=self.g_pre_optimizer,
                                        d_pre_optimizer=self.d_pre_optimizer,
                                        g_optimizer=self.g_optimizer,
                                        f_optimizer=self.f_optimizer,
                                        h_optimizer=self.h_optimizer,
                                        posterior_optimizer=self.posterior_optimizer,
                                        prior_optimizer=self.prior_optimizer)
        self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_path, max_to_keep=5)
        if self.ckpt_manager.latest_checkpoint:
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print('Latest checkpoint restored!!')

    def initialize_nets(self, print_summary=False):
        """Build g, f, h and prior networks with one dummy forward pass each.

        Parameters
        ----------
        print_summary : bool, default=False
            If True, print each network's summary.
        """
        z_dims = self.params['z_dims']
        self.g_net(np.zeros((1, sum(z_dims))))
        self.f_net(np.zeros((1, z_dims[0] + z_dims[1] + 1)))
        self.h_net(np.zeros((1, z_dims[0] + z_dims[2])))
        self.prior_net(np.zeros((1, self.params['n_segments'])))
        if print_summary:
            # summary() prints by itself; wrapping it in print() would also emit "None".
            for net in (self.g_net, self.f_net, self.h_net, self.prior_net):
                net.summary()

    def _prepare_auxiliary(self, n_samples, data_u=None):
        """Return the auxiliary variable U as a float32 array (n_samples, n_segments).

        Parameters
        ----------
        n_samples : int
            Number of rows U must have.
        data_u : array-like or None
            ``None`` draws a segment uniformly at random for every sample and
            one-hot encodes it. A 1-D array is treated as integer segment
            labels in ``[0, n_segments)`` and one-hot encoded. A 2-D array must
            already have shape ``(n_samples, n_segments)``.

        Raises
        ------
        ValueError
            If ``data_u`` has the wrong shape or invalid segment labels.
        """
        n_segments = self.params['n_segments']
        if data_u is None:
            segment_indices = np.random.randint(0, n_segments, size=n_samples)
            return tf.keras.utils.to_categorical(segment_indices, num_classes=n_segments).astype('float32')
        data_u = np.asarray(data_u)
        if data_u.ndim == 1:
            if data_u.shape[0] != n_samples:
                raise ValueError(f"data_u has {data_u.shape[0]} labels but data has {n_samples} samples.")
            labels = data_u.astype(np.int64)
            if not np.array_equal(labels, data_u) or np.any(labels < 0) or np.any(labels >= n_segments):
                raise ValueError(f"Segment labels in data_u must be integers in [0, {n_segments}).")
            return tf.keras.utils.to_categorical(labels, num_classes=n_segments).astype('float32')
        if data_u.ndim == 2:
            if data_u.shape != (n_samples, n_segments):
                raise ValueError(f"data_u must have shape ({n_samples}, {n_segments}), got {data_u.shape}.")
            return data_u.astype('float32')
        raise ValueError("data_u must be 1-D segment labels or a 2-D (n_samples, n_segments) array.")

    def _negative_log_likelihoods(self, data_x, data_y, data_v, data_z, eps=1e-6):
        """Per-sample negative log-likelihoods (additive constants dropped).

        Returns
        -------
        tuple of tf.Tensor
            ``(nll_v, nll_x, nll_y)``, each of shape ``(batch,)``, for the
            covariate model p(v|z), treatment model p(x|z0,z2) and outcome
            model p(y|z0,z1,x). Variances come from ``params['sigma_*']`` when
            given, otherwise from the softplus of the network's last output.
        """
        z_dims = self.params['z_dims']
        v_dim = self.params['v_dim']
        data_z0 = data_z[:, :z_dims[0]]
        data_z1 = data_z[:, z_dims[0]:sum(z_dims[:2])]
        data_z2 = data_z[:, sum(z_dims[:2]):sum(z_dims[:3])]

        # Covariate model: one forward pass supplies both mean and variance.
        g_out = self.g_net(data_z)
        mu_v = g_out[:, :v_dim]
        if 'sigma_v' in self.params:
            sigma_square_v = self.params['sigma_v'] ** 2
        else:
            sigma_square_v = tf.nn.softplus(g_out[:, -1]) + eps
        nll_v = (tf.reduce_sum((data_v - mu_v) ** 2, axis=1) / (2 * sigma_square_v)
                 + v_dim * tf.math.log(sigma_square_v) / 2)

        # Treatment model: Bernoulli (logit = first output) or Gaussian.
        h_out = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
        mu_x = h_out[:, :1]
        if self.params['binary_treatment']:
            nll_x = tf.nn.sigmoid_cross_entropy_with_logits(labels=data_x, logits=mu_x)[:, 0]
        else:
            if 'sigma_x' in self.params:
                sigma_square_x = self.params['sigma_x'] ** 2
            else:
                sigma_square_x = tf.nn.softplus(h_out[:, -1]) + eps
            nll_x = (tf.reduce_sum((data_x - mu_x) ** 2, axis=1) / (2 * sigma_square_x)
                     + tf.math.log(sigma_square_x) / 2)

        # Outcome model: Gaussian in y given (z0, z1, x).
        f_out = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
        mu_y = f_out[:, :1]
        if 'sigma_y' in self.params:
            sigma_square_y = self.params['sigma_y'] ** 2
        else:
            sigma_square_y = tf.nn.softplus(f_out[:, -1]) + eps
        nll_y = (tf.reduce_sum((data_y - mu_y) ** 2, axis=1) / (2 * sigma_square_y)
                 + tf.math.log(sigma_square_y) / 2)
        return nll_v, nll_x, nll_y

    def _negative_log_conditional_prior(self, data_z, data_u, eps=1e-6):
        """Per-sample -log p(z|u) up to a constant, for Z|U ~ N(mu(U), sigma^2(U) I).

        Returns a tensor of shape ``(batch,)``.
        """
        z_dim = sum(self.params['z_dims'])
        prior_output = self.prior_net(data_u)
        mu_z_prior = prior_output[:, :z_dim]
        # A single variance shared by all latent dimensions, from the last output unit.
        sigma_square_z_prior = tf.nn.softplus(prior_output[:, -1]) + eps
        return (tf.reduce_sum((data_z - mu_z_prior) ** 2, axis=1) / (2.0 * sigma_square_z_prior)
                + z_dim * tf.math.log(sigma_square_z_prior) / 2.0)

    @tf.function
    def update_latent_variable_sgd(self, data_x, data_y, data_v, data_z, data_u, eps=1e-6):
        """One gradient step on the batch latents Z and the prior network.

        Minimizes the batch-mean negative log posterior
        ``-log p(v|z) - log p(x|z) - log p(y|z,x) - log p(z|u)`` (plus the
        weighted KL term of ``prior_net`` when ``use_bnn`` is set). ``data_z``
        is updated in place by ``posterior_optimizer``; the prior network's
        parameters are updated by ``prior_optimizer``.

        Returns
        -------
        tf.Tensor
            The scalar loss before the update.
        """
        with tf.GradientTape() as tape:
            nll_v, nll_x, nll_y = self._negative_log_likelihoods(data_x, data_y, data_v, data_z, eps)
            loss_prior_z = tf.reduce_mean(self._negative_log_conditional_prior(data_z, data_u, eps))
            if self.params['use_bnn']:
                # KL divergence of the Bayesian prior network, collected during the forward pass.
                loss_prior_z += sum(self.prior_net.losses) * self.params.get('kl_weight', 1.0)
            loss_posterior_z = (tf.reduce_mean(nll_v) + tf.reduce_mean(nll_x)
                                + tf.reduce_mean(nll_y) + loss_prior_z)

        prior_vars = list(self.prior_net.trainable_variables)
        gradients = tape.gradient(loss_posterior_z, [data_z] + prior_vars)
        # E-step: update Z; M-step (prior part): update the prior network.
        self.posterior_optimizer.apply_gradients([(gradients[0], data_z)])
        self.prior_optimizer.apply_gradients(zip(gradients[1:], prior_vars))
        return loss_posterior_z

    def fit(self, data,
            batch_size=32, epochs=100, epochs_per_eval=5, startoff=0,
            use_egm_init=True, egm_n_iter=30000, egm_batches_per_eval=500,
            verbose=1, save_format='txt', data_u=None):
        """Train the IdentifiableCausalBGM model on observed data.

        The training procedure consists of two phases:

        1. **EGM initialization** (optional) — warm-start by jointly training
           encoder and generator with adversarial losses to obtain a good
           starting point for the latent variables and model parameters.
           This phase can be skipped by setting ``use_egm_init`` to ``False``.
        2. **Iterative optimization** — jointly optimizes latent variables,
           network parameters, and the conditional prior network, using the
           auxiliary variable :math:`U`.

        Parameters
        ----------
        data : tuple of np.ndarray
            A triplet ``(data_x, data_y, data_v)``.
        batch_size : int, default=32
            Mini-batch size. The incomplete last batch of each epoch is skipped.
        epochs : int, default=100
            Number of training epochs (epochs ``0..epochs`` are run).
        epochs_per_eval : int, default=5
            Evaluate every this many epochs.
        startoff : int, default=0
            Only start tracking the best model after this many epochs.
        use_egm_init : bool, default=True
            Whether to run EGM initialization before iterative training.
        egm_n_iter : int, default=30000
            Number of EGM initialization iterations.
        egm_batches_per_eval : int, default=500
            Evaluate EGM every this many iterations.
        verbose : int, default=1
            Verbosity level.
        save_format : str, default='txt'
            File format for saving causal estimates.
        data_u : array-like or None, optional
            Observed auxiliary variable: integer segment labels of shape
            ``(n,)`` or an array of shape ``(n, n_segments)``. If None,
            segments are assigned uniformly at random. The array actually used
            is stored as ``self.data_u``.
        """
        data_x, data_y, data_v = data
        n_samples = len(data_x)
        z_dim = sum(self.params['z_dims'])

        # iVAE: auxiliary variable U (user-supplied or randomly generated).
        if data_u is None:
            print(f"Generating auxiliary variable U for {self.params['n_segments']} segments.")
        data_u = self._prepare_auxiliary(n_samples, data_u)
        self.data_u = data_u

        if self.params['save_res']:
            with open('{}/params.txt'.format(self.save_dir), 'w') as f_params:
                f_params.write(str(self.params))

        if use_egm_init:
            self.egm_init(data, egm_n_iter=egm_n_iter, egm_batches_per_eval=egm_batches_per_eval,
                          batch_size=batch_size, verbose=verbose)
            print('Initialize latent variables Z with e(V)...')
            data_z_init = self.e_net(data_v)
        else:
            print('Random initialization of latent variables Z...')
            data_z_init = np.random.normal(0, 1, size=(n_samples, z_dim)).astype('float32')
        self.data_z = tf.Variable(data_z_init, name="Latent Variable", trainable=True)

        best_loss = np.inf
        print('Iterative Updating Starts ...')
        for epoch in range(epochs + 1):
            sample_idx = np.random.choice(n_samples, n_samples, replace=False)
            with tqdm(total=n_samples // batch_size, desc=f"Epoch {epoch}/{epochs}", unit="batch") as batch_bar:
                for i in range(0, n_samples - batch_size + 1, batch_size):  # Skip the incomplete last batch
                    batch_idx = sample_idx[i:i + batch_size]
                    # Per-batch copy of Z so the posterior optimizer can update it in place.
                    batch_z = tf.Variable(tf.gather(self.data_z, batch_idx, axis=0), name='batch_z', trainable=True)
                    batch_x = data_x[batch_idx, :]
                    batch_y = data_y[batch_idx, :]
                    batch_v = data_v[batch_idx, :]
                    batch_u = data_u[batch_idx, :]

                    # Update model parameters of G, H, F with SGD
                    loss_v, loss_mse_v = self.update_g_net(batch_z, batch_v)
                    loss_x, loss_mse_x = self.update_h_net(batch_z, batch_x)
                    loss_y, loss_mse_y = self.update_f_net(batch_z, batch_x, batch_y)
                    # Update Z (MAP step) and the prior network parameters
                    loss_posterior_z = self.update_latent_variable_sgd(batch_x, batch_y, batch_v, batch_z, batch_u)
                    # Write the updated batch latents back into the full latent matrix
                    self.data_z.scatter_nd_update(
                        indices=tf.expand_dims(batch_idx, axis=1),
                        updates=batch_z
                    )

                    loss_contents = (
                        'loss_px_z: [%.4f], loss_mse_x: [%.4f], loss_py_z: [%.4f], '
                        'loss_mse_y: [%.4f], loss_pv_z: [%.4f], loss_mse_v: [%.4f], loss_postrior_z: [%.4f]'
                        % (loss_x, loss_mse_x, loss_y, loss_mse_y, loss_v, loss_mse_v, loss_posterior_z)
                    )
                    batch_bar.set_postfix_str(loss_contents)
                    batch_bar.update(1)

            if epoch % epochs_per_eval == 0:
                causal_pre, mse_x, mse_y, mse_v, data_x_pred, data_y_pred, data_v_pred = self.evaluate(data = data, data_z = self.data_z)
                causal_pre = causal_pre.numpy()
                if verbose:
                    print('Epoch [%d/%d]: MSE_x: %.4f, MSE_y: %.4f, MSE_v: %.4f\n' % (epoch, epochs, mse_x, mse_y, mse_v))
                if epoch >= startoff and mse_y < best_loss:
                    best_loss = mse_y
                    self.best_causal_pre = causal_pre
                    self.best_epoch = epoch
                    if self.params['save_model']:
                        ckpt_save_path = self.ckpt_manager.save(epoch)
                        print('Saving checkpoint for epoch {} at {}'.format(epoch, ckpt_save_path))
                if self.params['save_res']:
                    save_data('{}/causal_pre_at_{}.{}'.format(self.save_dir, epoch, save_format), causal_pre)

    def predict(self, data, alpha=0.01, n_mcmc=3000, x_values=None, q_sd=1.0, sample_y=True, bs=100,
                data_u=None):
        """Predict causal effects with posterior uncertainty via MCMC.

        Same interface as :meth:`CausalBGM.predict`, plus ``data_u``.

        Parameters
        ----------
        data : tuple of np.ndarray
            A triplet ``(data_x, data_y, data_v)``.
        alpha : float, default=0.01
            Significance level for posterior intervals.
        n_mcmc : int, default=3000
            Number of posterior MCMC samples.
        x_values : array-like or None
            Treatment values for dose-response (required for continuous
            treatment; ignored for binary treatment).
        q_sd : float, default=1.0
            Proposal standard deviation for Metropolis-Hastings. A value of
            None or <= 0 enables adaptive tuning.
        sample_y : bool, default=True
            Whether to sample from the outcome variance model.
        bs : int, default=100
            Batch size for processing posterior samples.
        data_u : array-like or None, optional
            Auxiliary variable for the conditional prior used during MCMC
            (same formats as in :meth:`fit`). If None, segments are assigned
            uniformly at random.

        Returns
        -------
        effect : np.ndarray
            ITE (binary) or ADRF (continuous) point estimates.
        pos_int : np.ndarray
            Posterior intervals with shape ``(n, 2)`` or
            ``(len(x_values), 2)``.
        """
        assert 0 < alpha < 1, "The significance level 'alpha' must be greater than 0 and less than 1."
        if not self.params['binary_treatment'] and x_values is None:
            raise ValueError("For continuous treatment, 'x_values' must not be None.")
        if x_values is not None:
            # atleast_1d also covers 0-d numpy arrays, which np.isscalar does not.
            x_values = np.atleast_1d(np.asarray(x_values, dtype=float))

        print('MCMC Latent Variable Sampling ...')
        data_posterior_z, _ = self.metropolis_hastings_sampler(data, n_keep=n_mcmc, q_sd=q_sd, data_u=data_u)

        # U only enters through the posterior of Z, so effects are computed from Z alone.
        causal_effects = [
            self.infer_from_latent_posterior(data_posterior_z[i:i + bs], x_values=x_values,
                                             sample_y=sample_y).numpy()
            for i in range(0, data_posterior_z.shape[0], bs)
        ]
        # Binary batches are (draws, n_units); continuous batches are (len(x_values), draws).
        draw_axis = 0 if self.params['binary_treatment'] else 1
        causal_effects = np.concatenate(causal_effects, axis=draw_axis)
        point_estimate = np.mean(causal_effects, axis=draw_axis)
        posterior_interval_lower = np.quantile(causal_effects, alpha / 2, axis=draw_axis)
        posterior_interval_upper = np.quantile(causal_effects, 1 - alpha / 2, axis=draw_axis)
        pos_int = np.stack([posterior_interval_lower, posterior_interval_upper], axis=1)
        return point_estimate, pos_int

    @tf.function
    def infer_from_latent_posterior(self, data_posterior_z, x_values=None, sample_y=True, eps=1e-6):
        """Compute potential-outcome contrasts from posterior draws of Z.

        Parameters
        ----------
        data_posterior_z : tensor of shape (n_draws, n_units, z_dim)
            Posterior samples of the latent variables.
        x_values : tensor or None
            Treatment grid for the continuous case.
        sample_y : bool
            If True, draw y from N(mu_y, sigma_y^2); otherwise use mu_y.

        Returns
        -------
        tf.Tensor
            Binary treatment: ITE draws of shape ``(n_draws, n_units)``.
            Continuous treatment: average dose response of shape
            ``(len(x_values), n_draws)``.
        """
        # f takes (z0, z1, x); z0 and z1 are the leading contiguous latent block.
        z01_dim = sum(self.params['z_dims'][:2])

        def outcome_draws(make_x):
            """Outcome draws (n_draws, n_units) with treatment column make_x(n_units)."""
            y_out_all = tf.map_fn(
                lambda z: self.f_net(tf.concat([z[:, :z01_dim], make_x(tf.shape(z)[0])], axis=-1)),
                data_posterior_z,
                fn_output_signature=tf.float32
            )
            mu_y_all = y_out_all[:, :, 0]
            if not sample_y:
                return mu_y_all
            if 'sigma_y' in self.params:
                sigma_square_y = self.params['sigma_y'] ** 2
            else:
                sigma_square_y = tf.nn.softplus(y_out_all[:, :, 1]) + eps
            return tf.random.normal(shape=tf.shape(mu_y_all), mean=mu_y_all, stddev=tf.sqrt(sigma_square_y))

        if self.params['binary_treatment']:
            y_pred_pos_all = outcome_draws(lambda n: tf.ones([n, 1]))
            y_pred_neg_all = outcome_draws(lambda n: tf.zeros([n, 1]))
            return y_pred_pos_all - y_pred_neg_all

        def compute_dose_response(x):
            """Average (over units) outcome draws at treatment level x."""
            y_pred_all = outcome_draws(lambda n: tf.cast(tf.fill([n, 1], x), tf.float32))
            return tf.reduce_mean(y_pred_all, axis=1)

        return tf.map_fn(compute_dose_response, x_values, fn_output_signature=tf.float32)

    @tf.function
    def get_log_posterior(self, data_x, data_y, data_v, data_z, data_u, eps=1e-6):
        """Unnormalized per-sample log posterior log p(x,y,v|z) + log p(z|u).

        Returns a tensor of shape ``(batch,)``. The BNN KL term is not included.
        """
        nll_v, nll_x, nll_y = self._negative_log_likelihoods(data_x, data_y, data_v, data_z, eps)
        nll_prior = self._negative_log_conditional_prior(data_z, data_u, eps)
        return -(nll_v + nll_x + nll_y + nll_prior)

    def metropolis_hastings_sampler(self, data, initial_q_sd=1.0, q_sd=None, burn_in=5000, n_keep=3000,
                                    target_acceptance_rate=0.25, tolerance=0.05, adjustment_interval=50,
                                    adaptive_sd=None, window_size=100, data_u=None):
        """Sample Z from its posterior with one random-walk MH chain per sample.

        Parameters
        ----------
        data : tuple of np.ndarray
            A triplet ``(data_x, data_y, data_v)``.
        initial_q_sd : float, default=1.0
            Starting proposal sd when adaptive tuning is used (or q_sd is None).
        q_sd : float or None
            Fixed proposal sd.
        burn_in : int, default=5000
            Number of initial iterations discarded.
        n_keep : int, default=3000
            Number of post-burn-in states kept.
        target_acceptance_rate, tolerance : float
            During burn-in, q_sd is scaled by 0.9 / 1.1 whenever the windowed
            acceptance rate leaves ``target_acceptance_rate +/- tolerance``.
        adjustment_interval : int, default=50
            Iterations between proposal-sd adjustments.
        adaptive_sd : bool or None
            Enable adaptation; defaults to True when q_sd is None or <= 0.
        window_size : int, default=100
            Number of recent iterations used for the acceptance rate.
        data_u : array-like or None, optional
            Auxiliary variable (same formats as in :meth:`fit`). If None,
            segments are assigned uniformly at random.

        Returns
        -------
        samples : np.ndarray of shape (n_keep, n_samples, z_dim)
        data_u : np.ndarray of shape (n_samples, n_segments)
            The auxiliary variable used during sampling.
        """
        data_x, data_y, data_v = data
        n_samples = len(data_x)
        z_dim = sum(self.params['z_dims'])
        data_u = self._prepare_auxiliary(n_samples, data_u)

        # Initialize the state of n chains
        current_state = np.random.normal(0, 1, size=(n_samples, z_dim)).astype('float32')
        samples = []
        counter = 0
        recent_acceptances = []
        if adaptive_sd is None:
            adaptive_sd = (q_sd is None or q_sd <= 0)
        if adaptive_sd or q_sd is None:
            q_sd = initial_q_sd

        while len(samples) < n_keep:
            proposed_state = current_state + np.random.normal(0, q_sd, size=(n_samples, z_dim)).astype('float32')
            proposed_log_posterior = self.get_log_posterior(data_x, data_y, data_v, proposed_state, data_u)
            current_log_posterior = self.get_log_posterior(data_x, data_y, data_v, current_state, data_u)
            acceptance_ratio = np.exp(np.minimum(proposed_log_posterior - current_log_posterior, 0))
            accepted = np.random.rand(n_samples) < acceptance_ratio
            current_state[accepted] = proposed_state[accepted]

            # Acceptance-rate tracking over a sliding window, used for adaptation.
            recent_acceptances.append(accepted)
            if len(recent_acceptances) > window_size:
                recent_acceptances = recent_acceptances[-window_size:]
            if adaptive_sd and counter < burn_in and counter % adjustment_interval == 0 and counter > 0:
                current_acceptance_rate = np.sum(recent_acceptances) / (len(recent_acceptances) * n_samples)
                print(f"Current MCMC Acceptance Rate: {current_acceptance_rate:.4f}")
                if current_acceptance_rate < target_acceptance_rate - tolerance:
                    q_sd *= 0.9
                elif current_acceptance_rate > target_acceptance_rate + tolerance:
                    q_sd *= 1.1

            if counter >= burn_in:
                samples.append(current_state.copy())
            counter += 1

        # Guard against an empty window (n_keep <= 0 means no iterations ran).
        if recent_acceptances:
            acceptance_rate = np.sum(recent_acceptances) / (len(recent_acceptances) * n_samples)
            print(f"Final MCMC Acceptance Rate: {acceptance_rate:.4f}")
        return np.array(samples), data_u