import datetime
import os
import dateutil.tz
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from bayesgm.datasets import Gaussian_sampler
from ..networks import (
BaseFullyConnectedNet,
BayesianFullyConnectedNet,
Discriminator,
MCMCFullyConnectedNet,
run_mcmc_for_net,
)
from .base import CausalBGM
class FullMCMCCausalBGM(CausalBGM):
    """CausalBGM with full MCMC sampling for both individual latent variables
    and neural-network parameters.

    After calling :meth:`fit` (which uses SGD for both network weights and
    latent variables), invoke :meth:`run_mcmc_training` to draw posterior
    samples of all network weights via Hamiltonian Monte Carlo.  The
    :meth:`predict` method then marginalises over *both* latent-variable and
    weight uncertainty.

    Inherits from :class:`CausalBGM`.

    Parameters
    ----------
    params : dict
        Same keys as :class:`CausalBGM`.  Keys read directly by this class:
        ``use_bnn``, ``z_dims``, ``v_dim``, ``g_units``, ``e_units``,
        ``f_units``, ``h_units``, ``dz_units``, ``lr``, ``lr_theta``,
        ``lr_z``, ``output_dir``, ``dataset``, ``save_model``, ``save_res``
        and ``binary_treatment``.
    timestamp : str or None, optional
        Timestamp string for the run.
    random_seed : int or None, optional
        If provided, sets the global random seed for reproducibility.

    Attributes
    ----------
    g_net_samples, h_net_samples, f_net_samples
        Posterior weight samples for the three generative networks, as
        returned by ``run_mcmc_for_net``.  Only available after
        :meth:`run_mcmc_training` has been called.
    """

    def __init__(self, params, timestamp=None, random_seed=None):
        """Build networks, optimizers and checkpointing for the model.

        Parameters
        ----------
        params : dict
            Model configuration (see the class docstring for the keys used).
        timestamp : str or None, optional
            Run identifier used in the checkpoint and result paths.  When
            ``None``, the local time formatted as ``%Y%m%d_%H%M%S`` is used.
        random_seed : int or None, optional
            When given, seeds Keras/TensorFlow and enables deterministic ops.
        """
        # NOTE: ``CausalBGM.__init__`` is not invoked; all state is set up here.
        self.params = params
        self.timestamp = timestamp
        if random_seed is not None:
            tf.keras.utils.set_random_seed(random_seed)
            os.environ['TF_DETERMINISTIC_OPS'] = '1'
            tf.config.experimental.enable_op_determinism()

        z_dims = params['z_dims']
        latent_dim = sum(z_dims)

        # With ``use_bnn`` the three generative networks (g, f, h) use the
        # MCMC-capable class and the encoder uses the Bayesian class;
        # otherwise every network is a plain fully connected net.
        if self.params['use_bnn']:
            sampled_net_cls = MCMCFullyConnectedNet
            encoder_cls = BayesianFullyConnectedNet
        else:
            sampled_net_cls = encoder_cls = BaseFullyConnectedNet

        # Construction order (g, e, f, h, dz) matches the original code so
        # that seeded weight initialisation is unaffected.
        # g_net: Z -> (mean of V, one raw variance value).
        self.g_net = sampled_net_cls(
            input_dim=latent_dim, output_dim=params['v_dim'] + 1,
            model_name='g_net', nb_units=params['g_units'])
        # e_net: V -> Z.
        self.e_net = encoder_cls(
            input_dim=params['v_dim'], output_dim=latent_dim,
            model_name='e_net', nb_units=params['e_units'])
        # f_net: (Z0, Z1, X) -> (mean of Y, raw variance of Y).
        self.f_net = sampled_net_cls(
            input_dim=z_dims[0] + z_dims[1] + 1, output_dim=2,
            model_name='f_net', nb_units=params['f_units'])
        # h_net: (Z0, Z2) -> (mean or logit of X, raw variance of X).
        self.h_net = sampled_net_cls(
            input_dim=z_dims[0] + z_dims[2], output_dim=2,
            model_name='h_net', nb_units=params['h_units'])
        self.dz_net = Discriminator(
            input_dim=latent_dim, model_name='dz_net',
            nb_units=params['dz_units'])

        self.g_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
        self.d_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
        self.z_sampler = Gaussian_sampler(mean=np.zeros(latent_dim), sd=1.0)

        self.g_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.f_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.h_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
        self.posterior_optimizer = tf.keras.optimizers.Adam(params['lr_z'], beta_1=0.9, beta_2=0.99)

        self.initialize_nets()

        if self.timestamp is None:
            now = datetime.datetime.now(dateutil.tz.tzlocal())
            self.timestamp = now.strftime('%Y%m%d_%H%M%S')

        self.checkpoint_path = "{}/checkpoints/{}/{}".format(
            params['output_dir'], params['dataset'], self.timestamp)
        if self.params['save_model']:
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(self.checkpoint_path, exist_ok=True)

        self.save_dir = "{}/results/{}/{}".format(
            params['output_dir'], params['dataset'], self.timestamp)
        if self.params['save_res']:
            os.makedirs(self.save_dir, exist_ok=True)

        self.ckpt = tf.train.Checkpoint(g_net=self.g_net,
                                        e_net=self.e_net,
                                        f_net=self.f_net,
                                        h_net=self.h_net,
                                        dz_net=self.dz_net,
                                        g_pre_optimizer=self.g_pre_optimizer,
                                        d_pre_optimizer=self.d_pre_optimizer,
                                        g_optimizer=self.g_optimizer,
                                        f_optimizer=self.f_optimizer,
                                        h_optimizer=self.h_optimizer,
                                        posterior_optimizer=self.posterior_optimizer)
        self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_path, max_to_keep=5)

        if self.ckpt_manager.latest_checkpoint:
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print ('Latest checkpoint restored!!')

    def run_mcmc_training(self, data, num_samples=2000, num_burnin=1000, eps=1e-6):
        """Draw posterior weight samples via Hamiltonian Monte Carlo.

        Runs HMC on the weights of ``g_net``, ``h_net``, and ``f_net``
        conditioned on the optimised latent variables from :meth:`fit`.
        Must be called **after** :meth:`fit`.  The samples are stored on the
        instance as ``g_net_samples``, ``h_net_samples`` and
        ``f_net_samples``; nothing is returned.

        Parameters
        ----------
        data : tuple of np.ndarray
            A triplet ``(data_x, data_y, data_v)``.
        num_samples : int, default=2000
            Number of HMC posterior samples to draw.
        num_burnin : int, default=1000
            Number of burn-in steps to discard.
        eps : float, default=1e-6
            Small constant added for numerical stability in the
            likelihood computation.

        Raises
        ------
        AttributeError
            If ``self.data_z`` does not exist, i.e. :meth:`fit` has not run.
        """
        if not hasattr(self, 'data_z'):
            raise AttributeError(
                "run_mcmc_training() needs the optimised latent variables "
                "'data_z'; call fit() first.")

        data_x, data_y, data_v = data
        data_z = self.data_z.numpy()  # Use the optimized latent variables from fit()
        z_dims = self.params['z_dims']
        data_z0 = data_z[:, :z_dims[0]]
        data_z1 = data_z[:, z_dims[0]:sum(z_dims[:2])]
        data_z2 = data_z[:, sum(z_dims[:2]):sum(z_dims[:3])]

        def gaussian_log_lik(target, mean, raw_variance):
            # ``raw_variance`` must keep its trailing axis, i.e. be (batch, 1).
            # A (batch,) variance would broadcast against a (batch, k) mean
            # into a (batch, batch) grid and silently sum the wrong terms.
            variance = tf.nn.softplus(raw_variance) + eps
            dist = tfp.distributions.Normal(mean, tf.sqrt(variance))
            return tf.reduce_sum(dist.log_prob(target))

        # --- MCMC for g_net (predicting V from Z) ---
        def g_net_likelihood(v_true, v_pred_out):
            mu_v = v_pred_out[:, :self.params['v_dim']]
            return gaussian_log_lik(v_true, mu_v, v_pred_out[:, -1:])

        self.g_net_samples = run_mcmc_for_net(
            self.g_net, data_z, data_v, g_net_likelihood,
            self.g_net.get_weights(), num_samples, num_burnin
        )

        # --- MCMC for h_net (predicting X from Z) ---
        def h_net_likelihood(x_true, x_pred_out):
            mu_x = x_pred_out[:, :1]
            if self.params['binary_treatment']:
                # For binary treatment the first output is used as a logit.
                dist = tfp.distributions.Bernoulli(logits=mu_x)
                return tf.reduce_sum(dist.log_prob(x_true))
            return gaussian_log_lik(x_true, mu_x, x_pred_out[:, -1:])

        h_net_input = tf.concat([data_z0, data_z2], axis=-1)
        self.h_net_samples = run_mcmc_for_net(
            self.h_net, h_net_input, data_x, h_net_likelihood,
            self.h_net.get_weights(), num_samples, num_burnin
        )

        # --- MCMC for f_net (predicting Y from Z, X) ---
        def f_net_likelihood(y_true, y_pred_out):
            mu_y = y_pred_out[:, :1]
            return gaussian_log_lik(y_true, mu_y, y_pred_out[:, -1:])

        f_net_input = tf.concat([data_z0, data_z1, data_x], axis=-1)
        self.f_net_samples = run_mcmc_for_net(
            self.f_net, f_net_input, data_y, f_net_likelihood,
            self.f_net.get_weights(), num_samples, num_burnin
        )

    # Predict with MCMC sampling
    def predict(self, data, alpha=0.01, n_mcmc=3000, x_values=None, q_sd=1.0, sample_y=True, bs=100):
        """Predict causal effects with full posterior uncertainty.

        Marginalises over **both** latent-variable and network-weight
        uncertainty.  :meth:`run_mcmc_training` must be called first to
        populate weight samples.

        Parameters
        ----------
        data : tuple of np.ndarray
            A triplet ``(data_x, data_y, data_v)``.
        alpha : float, default=0.01
            Significance level for posterior intervals.
        n_mcmc : int, default=3000
            Number of posterior MCMC samples for latent variables.
        x_values : array-like, float or None
            Treatment values for dose-response (continuous treatment).
            Required for continuous treatment; a scalar is treated as a
            single-element array.
        q_sd : float, default=1.0
            Proposal standard deviation for Metropolis-Hastings.
        sample_y : bool, default=True
            Whether to sample from the outcome variance model.
        bs : int, default=100
            Batch size for processing posterior samples.  ``n_mcmc`` does
            not need to be a multiple of ``bs``.

        Returns
        -------
        effect : np.ndarray
            ITE (binary) or ADRF (continuous) point estimates: the mean of
            the sampled effects over the posterior-sample axis.
        pos_int : np.ndarray
            Posterior intervals; lower (``alpha/2``) and upper
            (``1 - alpha/2``) quantiles stacked along axis 1.

        Raises
        ------
        ValueError
            If the treatment is continuous and ``x_values`` is ``None``.
        AttributeError
            If the weight samples from :meth:`run_mcmc_training` are missing.
        """
        assert 0 < alpha < 1, "The significance level 'alpha' must be greater than 0 and less than 1."

        missing = [name for name in ('g_net_samples', 'h_net_samples', 'f_net_samples')
                   if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                "predict() needs posterior weight samples ({}); call "
                "run_mcmc_training() first.".format(', '.join(missing)))

        if not self.params['binary_treatment']:
            # Validate x_values for continuous treatment
            if x_values is None:
                raise ValueError("For continous treatment, 'x_values' must not be None. Provide a list or a single treatment value.")

        if x_values is not None:
            # Accept a scalar, a 0-d array or any sequence; always end up 1-D.
            x_values = np.atleast_1d(np.array(x_values, dtype=float))

        print('MCMC Latent Variable Sampling ...')
        data_posterior_z = self.metropolis_hastings_sampler(data,
                                                            g_net_samples=self.g_net_samples,
                                                            h_net_samples=self.h_net_samples,
                                                            f_net_samples=self.f_net_samples,
                                                            n_keep=n_mcmc,
                                                            q_sd=q_sd)

        # x_values is legitimately None for binary treatment; len(None) would crash.
        if x_values is not None:
            print('Number of x_values:', len(x_values))
        print('Shape of NN weights by MCMC:', self.g_net_samples.shape, self.h_net_samples.shape, self.f_net_samples.shape)
        print('Shape of Latent Variable Z by MCMC:', data_posterior_z.shape)

        f_net_weights = self.f_net_samples
        # Randomly select one weight sample for each Z sample to create pairs
        num_z_samples = data_posterior_z.shape[0]  # MCMC sample size for Z
        num_weight_samples = f_net_weights.shape[0]  # MCMC sample size for weights
        paired_weight_indices = np.random.randint(0, num_weight_samples, size=num_z_samples)
        paired_f_net_weights = tf.gather(f_net_weights, paired_weight_indices)

        # Iterate over the posterior Z samples in batches
        effect_batches = []
        for start in range(0, num_z_samples, bs):
            stop = start + bs
            effect_batch = self.infer_from_latent_posterior(
                data_posterior_z[start:stop],
                f_net_weights=paired_f_net_weights[start:stop],
                x_values=x_values,
                sample_y=sample_y).numpy()
            effect_batches.append(effect_batch)

        # Concatenate before inspecting the shape: the last batch may be
        # shorter than ``bs``, so the list of batches can be ragged.
        causal_effects = np.concatenate(effect_batches, axis=0)
        print('Shape of causal effect:', causal_effects.shape)

        # Point estimate and posterior interval at the requested level.
        # The same reduction serves ITE (binary) and ADRF (continuous).
        effect = np.mean(causal_effects, axis=0)
        posterior_interval_upper = np.quantile(causal_effects, 1 - alpha / 2, axis=0)
        posterior_interval_lower = np.quantile(causal_effects, alpha / 2, axis=0)
        pos_int = np.stack([posterior_interval_lower, posterior_interval_upper], axis=1)
        return effect, pos_int

    @tf.function
    def infer_from_latent_posterior(self, data_posterior_z, f_net_weights=None, x_values=None, sample_y=True, eps=1e-6):
        """Compute one causal-effect draw per paired (Z sample, weight sample).

        ITE is estimated for binary treatment and ADRF is estimated for
        continuous treatment.

        Parameters
        ----------
        data_posterior_z : array-like
            Posterior latent variables with shape ``(n_samples, n, p)``,
            where ``p`` is the dimension of Z.
        f_net_weights : array-like
            Weight samples for ``f_net``, one per entry of
            ``data_posterior_z`` along the first axis.  Although the default
            is ``None``, a value is required: it is mapped over together
            with ``data_posterior_z``.
        x_values : array-like or None
            Treatment values at which the dose-response function is
            evaluated (continuous treatment only).
        sample_y : bool, default=True
            If True, draw the outcome from the predicted normal distribution
            (mean and variance); otherwise use the predicted mean only.
        eps : float, default=1e-6
            Constant added to the softplus-transformed variance output.

        Returns
        -------
        tf.Tensor
            ITE draws with shape ``(n_samples, n)`` for binary treatment, or
            ADRF draws with shape ``(n_samples, len(x_values))`` for
            continuous treatment.
        """
        # Helper function to compute effect for a single paired (z_sample, weight_sample)
        def compute_effect(elems):
            z_sample, weight_sample = elems
            z0 = z_sample[:, :self.params['z_dims'][0]]
            z1 = z_sample[:, self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]

            if self.params['binary_treatment']:
                # Predict outcome under treatment (x=1)
                input_pos = tf.concat([z0, z1, tf.ones([tf.shape(z_sample)[0], 1])], axis=-1)
                out_pos = self.f_net.call_with_weights(input_pos, weight_sample)
                mu_y_pos, sigma_y_pos = out_pos[:, :1], tf.nn.softplus(out_pos[:, 1:]) + eps

                # Predict outcome under control (x=0)
                input_neg = tf.concat([z0, z1, tf.zeros([tf.shape(z_sample)[0], 1])], axis=-1)
                out_neg = self.f_net.call_with_weights(input_neg, weight_sample)
                mu_y_neg, sigma_y_neg = out_neg[:, :1], tf.nn.softplus(out_neg[:, 1:]) + eps

                if sample_y:  # Account for aleatoric uncertainty
                    # sigma_y_* hold variances, hence the square root.
                    y_pred_pos = tf.random.normal(shape=tf.shape(mu_y_pos), mean=mu_y_pos, stddev=tf.sqrt(sigma_y_pos))
                    y_pred_neg = tf.random.normal(shape=tf.shape(mu_y_neg), mean=mu_y_neg, stddev=tf.sqrt(sigma_y_neg))
                else:  # Use only the mean (epistemic + latent uncertainty only)
                    y_pred_pos, y_pred_neg = mu_y_pos, mu_y_neg

                # One ITE draw per individual.  A TensorFlow op is required
                # here: NumPy functions cannot consume the symbolic tensors
                # that exist while tf.function / tf.map_fn trace this code.
                # Squeezing only the trailing axis keeps the (n,) shape even
                # when n == 1.
                ite_pred = y_pred_pos - y_pred_neg
                return tf.squeeze(ite_pred, axis=-1)
            else:
                # ADRF: average predicted outcome over individuals for each
                # treatment value in x_values.
                def compute_dose_response(x):
                    data_x_tile = tf.cast(tf.fill([tf.shape(z_sample)[0], 1], x), tf.float32)
                    y_out = self.f_net.call_with_weights(tf.concat([z0, z1, data_x_tile], axis=-1), weight_sample)
                    mu_y, sigma_y = y_out[:, :1], tf.nn.softplus(y_out[:, 1:]) + eps
                    if sample_y:
                        y_pred = tf.random.normal(shape=tf.shape(mu_y), mean=mu_y, stddev=tf.sqrt(sigma_y))
                    else:
                        y_pred = mu_y
                    return tf.reduce_mean(y_pred)

                return tf.map_fn(compute_dose_response, x_values, fn_output_signature=tf.float32)

        causal_effects = tf.map_fn(
            compute_effect,
            (data_posterior_z, f_net_weights),
            fn_output_signature=tf.float32 if self.params['binary_treatment'] else tf.TensorSpec(shape=(len(x_values),), dtype=tf.float32)
        )
        return causal_effects

    @tf.function
    def get_log_posterior(self, data_x, data_y, data_v, data_z, g_weights, h_weights, f_weights, eps=1e-6):
        """Calculate the log posterior of Z for a GIVEN set of network weights.

        The value is the negated sum of the covariate, treatment and outcome
        negative log-likelihood terms plus a standard-normal prior term on Z;
        additive constants are omitted.  No instance state is modified.

        Parameters
        ----------
        data_x, data_y, data_v : array-like
            Treatment, outcome and covariates.
        data_z : array-like
            Latent variables, one row per observation.
        g_weights, h_weights, f_weights : array-like
            Flattened weights for each network, passed to the networks'
            ``call_with_weights``.
        eps : float, default=1e-6
            Constant added to every softplus-transformed variance.

        Returns
        -------
        tf.Tensor
            Log posterior, one value per row of ``data_z``.
        """
        data_z0 = data_z[:, :self.params['z_dims'][0]]
        data_z1 = data_z[:, self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
        data_z2 = data_z[:, sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]

        # logp(v|z) for covariate model
        g_net_output = self.g_net.call_with_weights(data_z, g_weights)
        mu_v = g_net_output[:, :self.params['v_dim']]
        # Rank-1 variance on purpose: it divides the per-row squared error below.
        sigma_square_v = tf.nn.softplus(g_net_output[:, -1]) + eps

        # logp(x|z) for treatment model
        h_net_input = tf.concat([data_z0, data_z2], axis=-1)
        h_net_output = self.h_net.call_with_weights(h_net_input, h_weights)
        mu_x = h_net_output[:, :1]

        # logp(y|z,x) for outcome model
        f_net_input = tf.concat([data_z0, data_z1, data_x], axis=-1)
        f_net_output = self.f_net.call_with_weights(f_net_input, f_weights)
        mu_y = f_net_output[:, :1]

        # --- Calculate Likelihood Losses (Negative Log-Likelihoods) ---
        loss_pv_z = tf.reduce_sum((data_v - mu_v)**2, axis=1) / (2 * sigma_square_v) + \
            self.params['v_dim'] * tf.math.log(sigma_square_v) / 2

        if self.params['binary_treatment']:
            # The first h_net output is used as the logit of the treatment.
            loss_px_z = tf.squeeze(tf.nn.sigmoid_cross_entropy_with_logits(labels=data_x, logits=mu_x))
        else:
            sigma_square_x = tf.nn.softplus(h_net_output[:, -1]) + eps
            loss_px_z = tf.reduce_sum((data_x - mu_x)**2, axis=1) / (2 * sigma_square_x) + \
                tf.math.log(sigma_square_x) / 2

        sigma_square_y = tf.nn.softplus(f_net_output[:, -1]) + eps
        loss_py_zx = tf.reduce_sum((data_y - mu_y)**2, axis=1) / (2 * sigma_square_y) + \
            tf.math.log(sigma_square_y) / 2

        # --- Calculate Prior Loss ---
        loss_prior_z = tf.reduce_sum(data_z**2, axis=1) / 2

        # --- Total Negative Log-Posterior ---
        loss_posterior_z = loss_pv_z + loss_px_z + loss_py_zx + loss_prior_z
        log_posterior = -loss_posterior_z
        return log_posterior

    def metropolis_hastings_sampler(self, data, g_net_samples, h_net_samples, f_net_samples, initial_q_sd = 1.0, q_sd = None, burn_in = 5000, n_keep = 3000, target_acceptance_rate=0.25, tolerance=0.05, adjustment_interval=50, adaptive_sd=None, window_size=100):
        """Sample Z from P(Z|X,Y,V) with random-walk Metropolis-Hastings.

        One chain is run per observation, all updated together.  At every
        iteration a single index into the weight samples is drawn at random
        and both the current and the proposed state are scored under that
        same set of (g, h, f) weights.

        Parameters
        ----------
        data : tuple
            Triplet ``(data_x, data_y, data_v)``.
        g_net_samples, h_net_samples, f_net_samples : array-like
            Posterior weight samples, indexed along their first axis.  The
            same index is used for all three, so they are expected to have
            the same length (that of ``f_net_samples`` is used).
        initial_q_sd : float, default=1.0
            Starting proposal standard deviation when adapting; also the
            fallback when ``q_sd`` is ``None``.
        q_sd : float or None
            Fixed proposal standard deviation.  ``None`` or a non-positive
            value turns adaptation on unless ``adaptive_sd`` says otherwise.
        burn_in : int, default=5000
            Number of initial iterations that are discarded.
        n_keep : int, default=3000
            Number of samples retained after burn-in.
        target_acceptance_rate : float, default=0.25
            Acceptance rate aimed for while adapting.
        tolerance : float, default=0.05
            Allowed deviation from the target before ``q_sd`` is changed.
        adjustment_interval : int, default=50
            Number of iterations between adjustments of ``q_sd``.
        adaptive_sd : bool or None
            Force adaptation on or off; ``None`` derives it from ``q_sd``.
        window_size : int, default=100
            Number of most recent iterations used for the acceptance rate.

        Returns
        -------
        np.ndarray
            Posterior samples with shape ``(n_keep, n, q)``, where ``q`` is
            the dimension of Z.
        """
        data_x, data_y, data_v = data
        n_obs = len(data_x)
        state_shape = (n_obs, sum(self.params['z_dims']))

        # Initialize the state of n chains
        current_state = np.random.normal(0, 1, size=state_shape).astype('float32')

        samples = []
        counter = 0
        # Sliding window of per-chain accept/reject masks
        recent_acceptances = []
        num_weight_samples = f_net_samples.shape[0]

        # Determine if q_sd should be adaptive
        if adaptive_sd is None:
            adaptive_sd = (q_sd is None or q_sd <= 0)

        # Adaptive runs start from initial_q_sd.  It is also the fallback when
        # adaptation was switched off explicitly but no q_sd was supplied,
        # which would otherwise crash in np.random.normal.
        if adaptive_sd or q_sd is None:
            q_sd = initial_q_sd

        while len(samples) < n_keep:
            # Random-walk proposal around the current state
            proposed_state = current_state + np.random.normal(0, q_sd, size=state_shape).astype('float32')

            # One weight draw shared by all three networks for this iteration
            rand_idx = np.random.randint(0, num_weight_samples)
            g_w = g_net_samples[rand_idx]
            h_w = h_net_samples[rand_idx]
            f_w = f_net_samples[rand_idx]

            # The current state is re-scored every iteration because the
            # weights change from one iteration to the next.
            proposed_log_posterior = self.get_log_posterior(data_x, data_y, data_v, proposed_state, g_w, h_w, f_w)
            current_log_posterior = self.get_log_posterior(data_x, data_y, data_v, current_state, g_w, h_w, f_w)

            # Clamp the log-ratio at 0 before exponentiating to avoid overflow
            acceptance_ratio = np.exp(np.minimum(proposed_log_posterior - current_log_posterior, 0))

            # Accept or reject the proposed state, independently per chain
            indices = np.random.rand(n_obs) < acceptance_ratio
            current_state[indices] = proposed_state[indices]

            # Update the sliding window
            recent_acceptances.append(indices)
            if len(recent_acceptances) > window_size:
                # Keep only the most recent `window_size` elements
                recent_acceptances = recent_acceptances[-window_size:]

            # Adjust q_sd periodically during the burn-in phase
            if adaptive_sd and counter < burn_in and counter % adjustment_interval == 0 and counter > 0:
                current_acceptance_rate = np.sum(recent_acceptances) / (len(recent_acceptances) * n_obs)
                print(f"Current MCMC Acceptance Rate: {current_acceptance_rate:.4f}")

                if current_acceptance_rate < target_acceptance_rate - tolerance:
                    q_sd *= 0.9  # Decrease q_sd to increase acceptance rate
                elif current_acceptance_rate > target_acceptance_rate + tolerance:
                    q_sd *= 1.1  # Increase q_sd to decrease acceptance rate

                print(f"MCMC Proposal Standard Deviation (q_sd): {q_sd:.4f}")

            # Keep the state only once burn-in is over; copy because
            # current_state is mutated in place on the next iteration.
            if counter >= burn_in:
                samples.append(current_state.copy())

            counter += 1

        # Acceptance rate over the final window
        acceptance_rate = np.sum(recent_acceptances) / (len(recent_acceptances) * n_obs)
        print(f"Final MCMC Acceptance Rate: {acceptance_rate:.4f}")
        return np.array(samples)