import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfm = tfp.mcmc
from ..networks import (
BaseFullyConnectedNet,
Discriminator,
BayesianVariationalNet,
BaseVariationalNet,
)
import numpy as np
from bayesgm.datasets import Gaussian_sampler, Base_sampler
import dateutil.tz
import datetime
import os
from tqdm import tqdm
[docs]
class BGM(object):
"""Bayesian Generative Model (BGM) for tabular data.
BGM learns a latent-variable generative model
:math:`Z \\sim \\mathcal{N}(0, I)`, :math:`X \\mid Z \\sim \\mathcal{N}(\\mu(Z), \\Sigma(Z))`
using an iterative algorithm that alternates between updating the generative
network parameters and the individual latent variables.
Parameters
----------
params : dict
Configuration dictionary. Required keys:
- ``'x_dim'`` (int): Dimension of the observed variable :math:`X`.
- ``'z_dim'`` (int): Dimension of the latent variable :math:`Z`.
- ``'dataset'`` (str): Dataset name (used for checkpoint paths).
- ``'output_dir'`` (str): Root directory for saving results and checkpoints.
Optional keys (with defaults):
- ``'use_bnn'`` (bool): Whether to use a Bayesian neural network for the generator. Default ``False``.
- ``'g_units'`` (list[int]): Hidden-layer sizes for the generator network. Default ``[64, 64, 64, 64, 64]``.
- ``'e_units'`` (list[int]): Hidden-layer sizes for the encoder network. Default ``[64, 64, 64, 64, 64]``.
- ``'dz_units'`` (list[int]): Hidden-layer sizes for the latent discriminator. Default ``[64, 32, 8]``.
- ``'dx_units'`` (list[int]): Hidden-layer sizes for the data discriminator. Default ``[64, 32, 8]``.
- ``'lr'`` (float): Learning rate for EGM initialization. Default ``0.001``.
- ``'lr_theta'`` (float): Learning rate for generator parameters. Default ``0.005``.
- ``'lr_z'`` (float): Learning rate for latent-variable updates. Default ``0.005``.
- ``'gamma'`` (float): Gradient-penalty coefficient. Default ``0``.
- ``'alpha'`` (float): Regularisation weight on the variance term. Default ``0.0``.
- ``'g_d_freq'`` (int): Discriminator-to-generator update ratio during EGM. Default ``1``.
- ``'save_model'`` (bool): Whether to save model checkpoints. Default ``True``.
- ``'save_res'`` (bool): Whether to save results. Default ``True``.
- ``'kl_weight'`` (float): KL-divergence weight when ``use_bnn`` is True. Default ``0.00005``.
timestamp : str or None, optional
Timestamp string for the run. If ``None``, the current local time is used.
random_seed : int or None, optional
If provided, sets the global random seed for reproducibility.
"""
def __init__(self, params, timestamp=None, random_seed=None):
super(BGM, self).__init__()
self.params = params
self.timestamp = timestamp
if random_seed is not None:
tf.keras.utils.set_random_seed(random_seed)
os.environ['TF_DETERMINISTIC_OPS'] = '1'
tf.config.experimental.enable_op_determinism()
if self.params['use_bnn']:
self.g_net = BayesianVariationalNet(input_dim=params['z_dim'],output_dim = params['x_dim'],
model_name='g_net', nb_units=params['g_units'])
else:
self.g_net = BaseVariationalNet(input_dim=params['z_dim'],output_dim = params['x_dim'],
model_name='g_net', nb_units=params['g_units'])
self.e_net = BaseFullyConnectedNet(input_dim=params['x_dim'],output_dim = params['z_dim'],
model_name='e_net', nb_units=params['e_units'])
self.dz_net = Discriminator(input_dim=params['z_dim'],model_name='dz_net',
nb_units=params['dz_units'])
self.dx_net = Discriminator(input_dim=params['x_dim'],model_name='dx_net',
nb_units=params['dx_units'])
#self.g_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
self.g_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
#self.d_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.9, beta_2=0.99)
self.d_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
self.z_sampler = Gaussian_sampler(mean=np.zeros(params['z_dim']), sd=1.0)
self.g_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
self.posterior_optimizer = tf.keras.optimizers.Adam(params['lr_z'], beta_1=0.9, beta_2=0.99)
self.initialize_nets()
if self.timestamp is None:
now = datetime.datetime.now(dateutil.tz.tzlocal())
self.timestamp = now.strftime('%Y%m%d_%H%M%S')
self.checkpoint_path = "{}/checkpoints/{}/{}".format(
params['output_dir'], params['dataset'], self.timestamp)
if self.params['save_model'] and not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
self.save_dir = "{}/results/{}/{}".format(
params['output_dir'], params['dataset'], self.timestamp)
if self.params['save_res'] and not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.ckpt = tf.train.Checkpoint(g_net = self.g_net,
e_net = self.e_net,
dz_net = self.dz_net,
dx_net = self.dx_net,
g_pre_optimizer = self.g_pre_optimizer,
d_pre_optimizer = self.d_pre_optimizer,
g_optimizer = self.g_optimizer,
posterior_optimizer = self.posterior_optimizer)
self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_path, max_to_keep=100)
if self.ckpt_manager.latest_checkpoint:
self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
[docs]
def get_config(self):
"""Return the configuration of the BGM model.
Returns
-------
dict
A dictionary with key ``"params"`` containing the full
configuration dictionary passed at construction time.
"""
return {
"params": self.params,
}
def initialize_nets(self, print_summary = False):
"""Initialize all the networks in BGM."""
self.g_net(np.zeros((1, self.params['z_dim'])))
if print_summary:
print(self.g_net.summary())
# Update generative model for X
@tf.function
def update_g_net(self, data_z, data_x):
with tf.GradientTape() as gen_tape:
mu_x, sigma_square_x = self.g_net(data_z)
#loss = -log(p(x|z))
loss_mse = tf.reduce_mean((data_x - mu_x)**2)
loss_x = tf.reduce_sum(((data_x - mu_x) ** 2) / (2 * sigma_square_x) + \
0.5 * tf.math.log(sigma_square_x), axis=1)
loss_x = tf.reduce_mean(loss_x) # Average over batch
if self.params['use_bnn']:
loss_kl = sum(self.g_net.losses)
loss_x += loss_kl * self.params['kl_weight']
# Calculate the gradients for generators and discriminators
g_gradients = gen_tape.gradient(loss_x, self.g_net.trainable_variables)
# Apply the gradients to the optimizer
self.g_optimizer.apply_gradients(zip(g_gradients, self.g_net.trainable_variables))
return loss_x, loss_mse
# Update posterior of latent variables Z
@tf.function
def update_latent_variable_sgd(self, data_z, data_x):
with tf.GradientTape() as tape:
# logp(x|z) for covariate model
mu_x, sigma_square_x = self.g_net(data_z)
loss_px_z = tf.reduce_sum(((data_x - mu_x) ** 2) / (2 * sigma_square_x) + \
0.5 * tf.math.log(sigma_square_x), axis=1)
loss_px_z = tf.reduce_mean(loss_px_z)
loss_prior_z = tf.reduce_sum(data_z**2, axis=1)/2
loss_prior_z = tf.reduce_mean(loss_prior_z)
loss_postrior_z = loss_px_z + loss_prior_z
#loss_postrior_z = loss_postrior_z/self.params['x_dim']
# Calculate the gradients
posterior_gradients = tape.gradient(loss_postrior_z, [data_z])
# Apply the gradients to the optimizer
self.posterior_optimizer.apply_gradients(zip(posterior_gradients, [data_z]))
return loss_postrior_z
#################################### EGM initialization ###########################################
@tf.function
def train_disc_step(self, data_z, data_x):
"""Train discriminators step.
Args:
data_z: latent tensor with shape [batch_size, z_dim].
data_x: data tensor with shape [batch_size, x_dim].
Returns:
Tuple of (dz_loss, dx_loss, d_loss): various discriminator loss functions.
"""
epsilon_z = tf.random.uniform([],minval=0., maxval=1.)
epsilon_x = tf.random.uniform([],minval=0., maxval=1.)
with tf.GradientTape(persistent=True) as disc_tape:
with tf.GradientTape() as gpz_tape:
data_z_ = self.e_net(data_x)
data_z_hat = data_z*epsilon_z + data_z_*(1-epsilon_z)
data_dz_hat = self.dz_net(data_z_hat)
with tf.GradientTape() as gpx_tape:
mu_x_, sigma_square_x_ = self.g_net(data_z)
data_x_ = self.g_net.reparameterize(mu_x_, sigma_square_x_)
data_x_hat = data_x*epsilon_x + data_x_*(1-epsilon_x)
data_dx_hat = self.dx_net(data_x_hat)
data_dx_ = self.dx_net(data_x_)
data_dz_ = self.dz_net(data_z_)
data_dx = self.dx_net(data_x)
data_dz = self.dz_net(data_z)
#dz_loss = -tf.reduce_mean(data_dz) + tf.reduce_mean(data_dz_)
#dx_loss = -tf.reduce_mean(data_dx) + tf.reduce_mean(data_dx_)
dz_loss = (tf.reduce_mean((0.9*tf.ones_like(data_dz) - data_dz)**2) \
+tf.reduce_mean((0.1*tf.ones_like(data_dz_) - data_dz_)**2))/2.0
dx_loss = (tf.reduce_mean((0.9*tf.ones_like(data_dx) - data_dx)**2) \
+tf.reduce_mean((0.1*tf.ones_like(data_dx_) - data_dx_)**2))/2.0
#gradient penalty for z
grad_z = gpz_tape.gradient(data_dz_hat, data_z_hat)
grad_norm_z = tf.sqrt(tf.reduce_sum(tf.square(grad_z), axis=1))#(bs,)
gpz_loss = tf.reduce_mean(tf.square(grad_norm_z - 1.0))
#gradient penalty for x
grad_x = gpx_tape.gradient(data_dx_hat, data_x_hat)
grad_norm_x = tf.sqrt(tf.reduce_sum(tf.square(grad_x), axis=1))#(bs,)
gpx_loss = tf.reduce_mean(tf.square(grad_norm_x - 1.0))
d_loss = dx_loss + dz_loss + \
self.params['gamma']*(gpz_loss + gpx_loss)
# Calculate the gradients for generators and discriminators
d_gradients = disc_tape.gradient(d_loss, self.dz_net.trainable_variables+self.dx_net.trainable_variables)
# Apply the gradients to the optimizer
self.d_pre_optimizer.apply_gradients(zip(d_gradients, self.dz_net.trainable_variables+self.dx_net.trainable_variables))
return dz_loss, dx_loss, d_loss
@tf.function
def train_gen_step(self, data_z, data_x):
"""Train generators step.
Args:
data_z: latent tensor with shape [batch_size, z_dim].
data_x: data tensor with shape [batch_size, x_dim].
Returns:
Tuple of (g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, reg_loss, g_e_loss):
various generator loss functions.
"""
with tf.GradientTape(persistent=True) as gen_tape:
mu_x_, sigma_square_x_ = self.g_net(data_z)
data_x_ = self.g_net.reparameterize(mu_x_, sigma_square_x_)
reg_loss = tf.reduce_mean(tf.square(sigma_square_x_))
data_z_ = self.e_net(data_x)
data_z__= self.e_net(data_x_)
mu_x__, sigma_square_x__ = self.g_net(data_z_)
data_x__ = self.g_net.reparameterize(mu_x__, sigma_square_x__)
data_dx_ = self.dx_net(data_x_)
data_dz_ = self.dz_net(data_z_)
l2_loss_x = tf.reduce_mean((data_x - data_x__)**2)
l2_loss_z = tf.reduce_mean((data_z - data_z__)**2)
#g_loss_adv = -tf.reduce_mean(data_dx_)
#e_loss_adv = -tf.reduce_mean(data_dz_)
g_loss_adv = tf.reduce_mean((0.9*tf.ones_like(data_dx_) - data_dx_)**2)
e_loss_adv = tf.reduce_mean((0.9*tf.ones_like(data_dz_) - data_dz_)**2)
g_e_loss = g_loss_adv + e_loss_adv + 10 * (l2_loss_x + l2_loss_z) + self.params['alpha'] * reg_loss
# if self.params['use_bnn']:
# loss_g_kl = sum(self.g_net.losses)
# loss_e_kl = sum(self.e_net.losses)
# g_e_loss += self.params['kl_weight'] * (loss_g_kl+loss_e_kl)
# Calculate the gradients for generators and discriminators
g_e_gradients = gen_tape.gradient(g_e_loss, self.g_net.trainable_variables+self.e_net.trainable_variables)
# Apply the gradients to the optimizer
self.g_pre_optimizer.apply_gradients(zip(g_e_gradients, self.g_net.trainable_variables+self.e_net.trainable_variables))
return g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, reg_loss, g_e_loss
def egm_init(self, data, egm_n_iter=10000, batch_size=32, egm_batches_per_eval=500, verbose=1):
self.data_sampler = Base_sampler(x=data,y=data,v=data, batch_size=batch_size, normalize=False)
print('EGM Initialization Starts ...')
for batch_iter in range(egm_n_iter+1):
# Update model parameters of Discriminator
for _ in range(self.params['g_d_freq']):
batch_x,_,_ = self.data_sampler.next_batch()
batch_z = self.z_sampler.get_batch(batch_size)
dz_loss, dx_loss, d_loss = self.train_disc_step(batch_z, batch_x)
# Update model parameters of G,E with SGD
batch_x,_,_ = self.data_sampler.next_batch()
batch_z = self.z_sampler.get_batch(batch_size)
g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, sigma_square_loss, g_e_loss = self.train_gen_step(batch_z, batch_x)
if batch_iter % egm_batches_per_eval == 0:
loss_contents = (
'EGM Initialization Iter [%d] : g_loss_adv[%.4f], e_loss_adv [%.4f], l2_loss_z [%.4f], l2_loss_x [%.4f], '
'sd^2_loss[%.4f], g_e_loss [%.4f], dz_loss [%.4f], dx_loss[%.4f], d_loss [%.4f]'
% (batch_iter, g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, sigma_square_loss, g_e_loss, dz_loss, dx_loss, d_loss)
)
if verbose:
print(loss_contents)
data_z_ = self.e_net(data)
data_x__, _ = self.g_net(data_z_)
MSE = tf.reduce_mean((data - data_x__)**2)
data_gen_1, sigma_square_x_1 = self.generate(nb_samples=5000)
data_gen_12, sigma_square_x_12 = self.generate(nb_samples=5000,use_x_sd=False)
if self.params['save_res']:
np.savez('%s/init_data_gen_at_%d.npz'%(self.save_dir, batch_iter),
gen1=data_gen_1, gen12=data_gen_12,
z=data_z_, x_rec=data_x__, var1=sigma_square_x_1, var12=sigma_square_x_12
)
print('MSE_x', MSE.numpy())
mse_x = self.evaluate(data = data, use_x_sd = True)
print('iter [%d/%d]: MSE_x: %.4f\n' % (batch_iter, egm_n_iter, mse_x))
mse_x = self.evaluate(data = data, use_x_sd = False)
print('iter [%d/%d]: MSE_x no x_sd: %.4f\n' % (batch_iter, egm_n_iter, mse_x))
if self.params['save_model']:
base_path = self.checkpoint_path + f"/weights_at_egm_init_{batch_iter}"
self.e_net.save_weights(f"{base_path}_encoder.weights.h5")
self.g_net.save_weights(f"{base_path}_generator.weights.h5")
print('Saving checkpoint for egm_init at {}'.format(base_path))
print('EGM Initialization Ends.')
#################################### EGM initialization #############################################
[docs]
def fit(self, data,
batch_size=32, epochs=100, epochs_per_eval=5,
use_egm_init=True, egm_n_iter=20000, egm_batches_per_eval=500, verbose=1):
"""Train the BGM model on observed data.
The training procedure consists of two phases:
1. **EGM initialization** (optional) — warm-start by jointly training encoder and
generator with adversarial losses to obtain a good starting point
for the latent variables and model parameters. This phase is optional and can be skipped by setting `use_egm_init` to False.
2. **Iterative optimization** — alternates between updating the
generator network parameters :math:`\\theta` and the per-sample
latent variables :math:`Z` via SGD.
Parameters
----------
data : np.ndarray
Observed data matrix with shape ``(n, x_dim)``.
batch_size : int, default=32
Mini-batch size.
epochs : int, default=100
Number of training epochs for the iterative phase.
epochs_per_eval : int, default=5
Evaluate and (optionally) save every this many epochs.
use_egm_init : bool, default=True
Whether to run EGM initialization before iterative training.
egm_n_iter : int, default=20000
Number of EGM initialization iterations.
egm_batches_per_eval : int, default=500
Evaluate EGM every this many iterations.
verbose : int, default=1
Verbosity level. Set to 0 to suppress progress messages.
"""
if self.params['save_res']:
f_params = open('{}/params.txt'.format(self.save_dir),'w')
f_params.write(str(self.params))
f_params.close()
if use_egm_init:
self.egm_init(data, egm_n_iter=egm_n_iter, egm_batches_per_eval=egm_batches_per_eval, batch_size=batch_size, verbose=verbose)
print('Initialize latent variables Z with e(V)...')
data_z_init = self.e_net(data)
else:
print('Random initialization of latent variables Z...')
data_z_init = np.random.normal(0, 1, size = (len(data), self.params['z_dim'])).astype('float32')
self.data_z = tf.Variable(data_z_init, name="Latent Variable",trainable=True)
self.history_loss = []
print('Iterative Updating Starts ...')
for epoch in range(epochs+1):
sample_idx = np.random.choice(len(data), len(data), replace=False)
# Create a progress bar for batches
with tqdm(total=len(data) // batch_size, desc=f"Epoch {epoch}/{epochs}", unit="batch") as batch_bar:
for i in range(0,len(data) - batch_size + 1,batch_size): ## Skip the incomplete last batch
batch_idx = sample_idx[i:i+batch_size]
# Update model parameters of G, H, F with SGD
batch_z = tf.Variable(tf.gather(self.data_z, batch_idx, axis = 0), name='batch_z', trainable=True)
batch_x = data[batch_idx,:]
loss_x, loss_mse_x = self.update_g_net(batch_z, batch_x)
# Update Z by maximizing a posterior or posterior mean
loss_postrior_z = self.update_latent_variable_sgd(batch_z, batch_x)
# Update data_z with updated batch_z
self.data_z.scatter_nd_update(
indices=tf.expand_dims(batch_idx, axis=1),
updates=batch_z
)
# Update the progress bar with the current loss information
loss_contents = (
'loss_x: [%.4f], loss_mse_x: [%.4f], loss_postrior_z: [%.4f]'
% (loss_x, loss_mse_x, loss_postrior_z)
)
batch_bar.set_postfix_str(loss_contents)
batch_bar.update(1)
# Evaluate the full training data and print metrics for the epoch
if epoch % epochs_per_eval == 0:
mse_x = self.evaluate(data = data, data_z = self.data_z)
self.history_loss.append(mse_x)
if verbose:
print('Epoch [%d/%d]: MSE_x: %.4f\n' % (epoch, epochs, mse_x))
if self.params['save_model']:
base_path = self.checkpoint_path + f"/weights_at_{epoch}"
self.g_net.save_weights(f"{base_path}_generator.weights.h5")
print('Saving checkpoint for epoch {} at {}'.format(epoch, base_path))
data_gen_1, sigma_square_x_1 = self.generate(nb_samples=5000)
data_gen_12, sigma_square_x_12 = self.generate(nb_samples=5000,use_x_sd=False)
if self.params['save_res']:
np.savez('%s/data_gen_at_%d.npz'%(self.save_dir, epoch),
gen1=data_gen_1, gen12=data_gen_12,
z=self.data_z.numpy(), var1=sigma_square_x_1, var12=sigma_square_x_12
)
@tf.function
def evaluate(self, data, data_z=None, use_x_sd=True):
"""Compute the mean squared error between observed and reconstructed data.
Parameters
----------
data : np.ndarray or tf.Tensor
Observed data with shape ``(n, x_dim)``.
data_z : tf.Tensor or None, optional
Latent variables with shape ``(n, z_dim)``.
If ``None``, the encoder network is used to infer them.
use_x_sd : bool, default=True
If ``True``, sample reconstructions from
:math:`\\mathcal{N}(\\mu, \\sigma^2)`.
If ``False``, use the mean :math:`\\mu` directly.
Returns
-------
mse_x : tf.Tensor
Scalar mean squared error.
"""
if data_z is None:
data_z = self.e_net(data, training=False)
mu_x, sigma_square_x = self.g_net(data_z, training=False)
if use_x_sd:
data_x_pred = self.g_net.reparameterize(mu_x, sigma_square_x)
else:
data_x_pred = mu_x
mse_x = tf.reduce_mean((data-data_x_pred)**2)
return mse_x
@tf.function
def generate(self, nb_samples=1000, use_x_sd=True):
"""Generate synthetic data from the trained model.
Samples latent codes from the standard normal prior and decodes them
through the generator network.
Parameters
----------
nb_samples : int, default=1000
Number of samples to generate.
use_x_sd : bool, default=True
If ``True``, add noise sampled from the learned variance.
If ``False``, return the generator mean directly.
Returns
-------
data_x_gen : tf.Tensor
Generated data with shape ``(nb_samples, x_dim)``.
sigma_square_x : tf.Tensor
Predicted variance with shape ``(nb_samples, x_dim)``.
"""
data_z = tf.random.normal(shape=(nb_samples, self.params['z_dim']), mean=0.0, stddev=1.0)
mu_x, sigma_square_x = self.g_net(data_z, training=False)
if use_x_sd:
data_x_gen = self.g_net.reparameterize(mu_x, sigma_square_x)
else:
data_x_gen = mu_x
return data_x_gen, sigma_square_x
@tf.function
def predict_on_posteriors(self, data_posterior_z):
n_mcmc = tf.shape(data_posterior_z)[0]
n_samples = tf.shape(data_posterior_z)[1]
# Flatten data
data_posterior_z_flat = tf.reshape(data_posterior_z, [-1, self.params['z_dim']]) # Flatten: Shape: (n_IS * n_samples, z_dim)
mu_x_flat, sigma_square_x_flat = self.g_net(data_posterior_z_flat, training=False) # Output shape: (n_MCMC*n_samples, x_dim)
data_x_pred_flat = self.g_net.reparameterize(mu_x_flat, sigma_square_x_flat)
# Correctly reshape mean and variance
#mu_x = tf.reshape(mu_x_flat, [n_mcmc, n_samples, self.params['x_dim']]) # Shape: (n_MCMC, n_samples, x_dim)
data_x_pred = tf.reshape(data_x_pred_flat, [n_mcmc, n_samples, self.params['x_dim']])
return data_x_pred
[docs]
def predict(self, data, alpha=0.05, return_samples=False, bs=100, n_mcmc=5000, burn_in=5000, step_size=0.01, num_leapfrog_steps=10, seed=42):
"""
Predict the posterior distribution with missing data handling.
Parameters
----------
data : np.ndarray or tf.Tensor
Input data with shape ``(n, x_dim)``.
Missing values should be encoded as ``np.nan``.
alpha : float, default=0.05
Significance level for prediction intervals.
return_samples : bool, default=False
If ``False``, return imputed data with shape ``(n, x_dim)``.
If ``True``, return posterior samples with shape
``(n_mcmc, n, x_dim)``.
bs : int, default=100
Batch size for posterior prediction.
n_mcmc : int, default=5000
Number of retained MCMC samples.
burn_in : int, default=5000
Number of burn-in iterations.
step_size : float, default=0.01
HMC step size.
num_leapfrog_steps : int, default=10
Number of leapfrog steps in HMC.
seed : int, default=42
Random seed.
Returns
-------
data_x_pred : np.ndarray
Imputed data if ``return_samples=False`` with shape ``(n, x_dim)``.
Posterior predictive samples if ``return_samples=True`` with shape
``(n_mcmc, n, x_dim)``.
pred_interval : np.ndarray or list[np.ndarray]
Prediction intervals on missing dimensions.
For a shared missing pattern, shape is ``(n, n_missing_dims, 2)``.
Otherwise, this is a per-sample list where element ``i`` has shape
``(n_missing_dims_i, 2)``.
"""
assert 0 < alpha < 1, "The significance level 'alpha' must be greater than 0 and less than 1."
if not isinstance(data, tf.Tensor):
data_tf = tf.convert_to_tensor(data, dtype=tf.float32)
else:
data_tf = tf.cast(data, tf.float32)
# Shape: (n, x_dim)
n_data_samples = data_tf.shape[0]
# Boolean mask of missingness (True where NaN)
is_nan_tf = tf.math.is_nan(data_tf)
# Observed mask (True where not NaN)
is_obs_tf = tf.logical_not(is_nan_tf)
# We'll still feed some numeric value at missing locations; they are ignored via indices.
data_clean_tf = tf.where(is_nan_tf,
tf.zeros_like(data_tf),
data_tf)
# Build ind_x1 as list-of-lists of observed feature indices
is_obs_np = is_obs_tf.numpy()
ind_x1_list = [
np.where(row)[0].tolist()
for row in is_obs_np
]
data_posterior_z = self.tfp_mcmc_sampler(
data=data_clean_tf,
ind_x1=ind_x1_list,
n_mcmc=n_mcmc,
burn_in=burn_in,
step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps,
seed=seed
)
# data_posterior_z: (n_mcmc, n_data_samples, z_dim)
data_x_pred_all = []
# Loop over data samples in batches
for i in range(0, n_data_samples, bs):
batch_posterior_z = data_posterior_z[:, i:i + bs, :] # (n_mcmc, bs_i, z_dim)
data_x_batch_pred = self.predict_on_posteriors(batch_posterior_z)
# Expected shape: (n_mcmc, bs_i, x_dim)
data_x_batch_pred = data_x_batch_pred.numpy()
data_x_pred_all.append(data_x_batch_pred)
# Concatenate along data dimension
data_x_pred_all = np.concatenate(data_x_pred_all, axis=1)
# Shape: (n_mcmc, n_data_samples, x_dim)
data_np = data_tf.numpy()
miss_mask_full = np.isnan(data_np).astype(np.float32)
obs_mask_full = 1.0 - miss_mask_full
data_obs_np = np.nan_to_num(data_np, nan=0.0)
# Compute prediction intervals on missing dimensions only
# Check if all samples share the same missing pattern
miss_mask_flat = miss_mask_full.astype(bool)
same_pattern = np.all(miss_mask_flat == miss_mask_flat[0])
if same_pattern:
# Common missing pattern across all samples
miss_idx = np.where(miss_mask_flat[0])[0] # (N_missed_dims,)
if miss_idx.size == 0:
# No missing dimensions at all
pred_interval = np.zeros((n_data_samples, 0, 2), dtype=np.float32)
else:
# Gather only missing dimension samples
dim_samples = data_x_pred_all[:, :, miss_idx] # (n_mcmc, n, N_missed_dims)
lower = np.quantile(dim_samples, alpha / 2.0, axis=0) # (n, N_missed_dims)
upper = np.quantile(dim_samples, 1.0 - alpha / 2.0, axis=0) # (n, N_missed_dims)
pred_interval = np.stack([lower, upper], axis=-1) # (n, N_missed_dims, 2)
else:
# Different missing patterns; return a list of per-sample intervals
pred_interval = []
for i in range(n_data_samples):
miss_idx_i = np.where(miss_mask_flat[i])[0] # (N_missed_dims_i,)
if miss_idx_i.size == 0:
# This sample has no missing dimensions
pred_interval.append(np.zeros((0, 2), dtype=np.float32))
continue
dim_samples_i = data_x_pred_all[:, i, miss_idx_i] # (n_mcmc, N_missed_dims_i)
lower_i = np.quantile(dim_samples_i, alpha / 2.0, axis=0) # (N_missed_dims_i,)
upper_i = np.quantile(dim_samples_i, 1.0 - alpha / 2.0, axis=0) # (N_missed_dims_i,)
intervals_i = np.stack([lower_i, upper_i], axis=-1) # (N_missed_dims_i, 2)
pred_interval.append(intervals_i)
if return_samples:
return data_x_pred_all, pred_interval
else:
# Return single imputed dataset: posterior mean across MCMC samples
# Shape: (n, x_dim) — observed values intact, missing filled with posterior means
data_imputed = np.mean(data_x_pred_all, axis=0)
# Ensure observed values are exactly the original (avoid floating-point drift)
data_imputed = miss_mask_full * data_imputed + obs_mask_full * data_obs_np
return data_imputed, pred_interval
@tf.function
def get_log_posterior(self, data_z, data_x, ind_x1=None, obs_mask=None):
"""
Calculate log posterior.
data_z: (tf.Tensor): Input data with shape (n, z_dim), where z_dim is the dimension of Z.
data_x: (tf.Tensor): Input data with shape (n, x_dim), where x_dim is the dimension of X.
Missing pixels can be any pattern (we ignore them via indices/mask).
ind_x1: None, or int32 Tensor of shape (n, K_max) with feature indices
for each sample (padded where obs_mask == 0).
obs_mask: None, or float32 Tensor of shape (n, K_max),
1 for real observed indices, 0 for padding.
return (tf.Tensor): Log posterior with shape (n, ).
"""
mu_x, sigma_square_x = self.g_net(data_z, training=False)
# Likelihood term log p(x_obs | z)
if ind_x1 is None:
# Use all features
loss_px_z = tf.reduce_sum(((data_x - mu_x) ** 2) / (2 * sigma_square_x) + \
0.5 * tf.math.log(sigma_square_x), axis=1)
else:
# Gather observed features per-sample
# ind_x1: (n, K_max), obs_mask: (n, K_max)
data_x_cond = tf.gather(data_x, ind_x1, batch_dims=1) # (n, K_max)
mu_x_cond = tf.gather(mu_x, ind_x1, batch_dims=1) # (n, K_max)
sigma_square_x_cond = tf.gather(sigma_square_x, ind_x1, batch_dims=1) # (n, K_max)
# Compute log-likelihood for observed features
ll_term = ((data_x_cond - mu_x_cond) ** 2) / (2 * sigma_square_x_cond) + \
0.5 * tf.math.log(sigma_square_x_cond) # (n, K_max)
if obs_mask is not None:
ll_term = ll_term * obs_mask # zero out padded positions
loss_px_z = tf.reduce_sum(ll_term, axis=1) # (n,)
loss_prior_z = tf.reduce_sum(data_z**2, axis=1) / 2
log_posterior = -(loss_prior_z + loss_px_z)
return log_posterior
def tfp_mcmc_sampler(self, data, ind_x1=None, n_mcmc=3000, burn_in=5000,
step_size=0.01, num_leapfrog_steps=10, seed=42):
"""
Samples from the posterior distribution P(Z|X) using TensorFlow Probability MCMC.
Args:
data: (tf.Tensor): Tensor or np.array, shape (n, x_dim).
Full data; missing features are ignored via ind_x1 / obs_mask.
ind_x1: None, or:
- list of lists: len == n, each sublist is observed feature indices
for that sample, possibly different lengths.
- 2-D int tensor: (n, K_max), same K_max for all.
- 1-D int tensor/list: shared indices for all samples.
n_mcmc: (int): Number of samples retained after burn-in.
burn_in: (int): Number of samples for burn-in.
step_size: (float): Step size for HMC kernel.
num_leapfrog_steps: (int): Number of leapfrog steps for HMC.
seed: (int): Random seed for reproducibility.
Returns:
tf.Tensor: Posterior samples with shape (n_mcmc, n, z_dim).
"""
# Convert data to tensor if not already
if not isinstance(data, tf.Tensor):
data = tf.convert_to_tensor(data, dtype=tf.float32)
n_samples = data.shape[0]
z_dim = self.params['z_dim']
ind_x1_tensor = None
obs_mask = None
if ind_x1 is not None:
# Case 1: list-of-lists (ragged, arbitrary lengths)
if isinstance(ind_x1, (list, tuple)) and len(ind_x1) > 0 and isinstance(ind_x1[0], (list, tuple)):
# Ensure length matches batch
assert len(ind_x1) == n_samples, \
f"len(ind_x1)={len(ind_x1)} != n_samples={n_samples}"
max_len = max(len(row) for row in ind_x1) if n_samples > 0 else 0
assert max_len > 0, f"No observed features"
ind_mat = np.zeros((n_samples, max_len), dtype=np.int32)
mask_mat = np.zeros((n_samples, max_len), dtype=np.float32)
for i, row in enumerate(ind_x1):
L = len(row)
if L > 0:
ind_mat[i, :L] = np.array(row, dtype=np.int32)
mask_mat[i, :L] = 1.0
ind_x1_tensor = tf.constant(ind_mat, dtype=tf.int32) # (n, K_max)
obs_mask = tf.constant(mask_mat, dtype=tf.float32) # (n, K_max)
else:
# Convert anything else (np arrays, tensors) to tf.Tensor
ind_x1_tensor = tf.convert_to_tensor(ind_x1, dtype=tf.int32)
if ind_x1_tensor.shape.rank == 1:
# Shared pattern for all samples; broadcast to (n, K)
K = tf.shape(ind_x1_tensor)[0]
ind_x1_tensor = tf.broadcast_to(ind_x1_tensor[tf.newaxis, :],
[n_samples, K])
elif ind_x1_tensor.shape.rank != 2:
raise ValueError("ind_x1 must be rank 1 or 2 if tensor-like.")
obs_mask = tf.ones_like(ind_x1_tensor, dtype=tf.float32)
# Initialize chains with standard normal distribution
initial_state = tf.random.normal(
shape=(n_samples, z_dim),
seed=seed,
dtype=tf.float32
)
# Define the target log probability function
def target_log_prob_fn(z):
"""
Target log probability function for MCMC.
Args:
z: (tf.Tensor): Latent variables with shape (n_samples, z_dim).
Returns:
tf.Tensor: Log probability with shape (n_samples,).
"""
return self.get_log_posterior(z, data, ind_x1_tensor, obs_mask)
# Create HMC kernel
hmc_kernel = tfm.HamiltonianMonteCarlo(
target_log_prob_fn=target_log_prob_fn,
step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps
)
# Add adaptive step size adjustment
adaptive_kernel = tfm.SimpleStepSizeAdaptation(
inner_kernel=hmc_kernel,
num_adaptation_steps=int(burn_in * 0.8),
target_accept_prob=0.75
)
# Run MCMC
@tf.function
def run_mcmc():
samples, kernel_results = tfm.sample_chain(
num_results=n_mcmc,
num_burnin_steps=burn_in,
current_state=initial_state,
kernel=adaptive_kernel,
trace_fn=lambda _, pkr: pkr.inner_results.is_accepted
)
return samples, kernel_results
# Execute MCMC
samples, is_accepted = run_mcmc()
# Calculate acceptance rate
acceptance_rate = tf.reduce_mean(tf.cast(is_accepted, tf.float32))
print(f"TFP MCMC Acceptance Rate: {acceptance_rate:.4f}")
return samples