import datetime
import os
import dateutil.tz
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from bayesgm.datasets import Base_sampler, Gaussian_sampler
from ..networks import (
Discriminator,
MNISTDiscriminator,
MNISTEncoderConv,
MNISTGenerator,
)
from .base import BGM
[docs]
class MNISTBGM(BGM):
"""BGM model for MNIST imaging data.
Inherits from :class:`BGM` and overrides methods to use convolutional
neural networks and a Bernoulli likelihood for binary image data of
shape ``(28, 28, 1)``.
Parameters
----------
params : dict
Configuration dictionary. Same keys as :class:`BGM`, with
``x_dim`` corresponding to the flattened image dimensionality
(784 for MNIST).
timestamp : str or None, optional
Timestamp string for the run. If ``None``, the current local time
is used.
random_seed : int or None, optional
If provided, sets the global random seed for reproducibility.
"""
def __init__(self, params, timestamp=None, random_seed=None):
self.params = params
self.timestamp = timestamp
if random_seed is not None:
tf.keras.utils.set_random_seed(random_seed)
os.environ['TF_DETERMINISTIC_OPS'] = '1'
tf.config.experimental.enable_op_determinism()
# MNIST-specific networks
self.g_net = MNISTGenerator(z_dim=params['z_dim'], filters=32,
use_bnn=params['use_bnn'], name='g_net')
self.e_net = MNISTEncoderConv(z_dim=params['z_dim'], filters=32, name='e_net')
self.dz_net = Discriminator(input_dim=params['z_dim'], model_name='dz_net',
nb_units=params['dz_units'])
self.dx_net = MNISTDiscriminator(filters=64, name='dx_net')
self.g_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
self.d_pre_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
self.z_sampler = Gaussian_sampler(mean=np.zeros(params['z_dim']), sd=1.0)
self.g_optimizer = tf.keras.optimizers.Adam(params['lr_theta'], beta_1=0.9, beta_2=0.99)
self.posterior_optimizer = tf.keras.optimizers.Adam(params['lr_z'], beta_1=0.9, beta_2=0.99)
self.initialize_nets()
if self.timestamp is None:
now = datetime.datetime.now(dateutil.tz.tzlocal())
self.timestamp = now.strftime('%Y%m%d_%H%M%S')
self.checkpoint_path = "{}/checkpoints/{}/{}".format(
params['output_dir'], params['dataset'], self.timestamp)
if self.params['save_model'] and not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
self.save_dir = "{}/results/{}/{}".format(
params['output_dir'], params['dataset'], self.timestamp)
if self.params['save_res'] and not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.ckpt = tf.train.Checkpoint(g_net = self.g_net,
e_net = self.e_net,
dz_net = self.dz_net,
dx_net = self.dx_net,
g_pre_optimizer = self.g_pre_optimizer,
d_pre_optimizer = self.d_pre_optimizer,
g_optimizer = self.g_optimizer,
posterior_optimizer = self.posterior_optimizer)
self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_path, max_to_keep=100)
if self.ckpt_manager.latest_checkpoint:
self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
# Update generative model for X
@tf.function
def update_g_net(self, data_z, data_x):
"""
Updates the generative model g_net using Bernoulli log-likelihood for MNIST.
Args:
data_z: Tensor of shape (batch, z_dim), latent variable.
data_x: Tensor of shape (batch, 28, 28, 1), observed MNIST data.
Returns:
loss_x: Scalar loss value for training g_net.
loss_mse: Mean squared error between observed and predicted x.
"""
with tf.GradientTape() as gen_tape:
mu_x, sigma_square_x = self.g_net(data_z)
x_logits = self.g_net.reparameterize(mu_x, sigma_square_x)
# Convert logits to probabilities for MSE calculation
x_probs = tf.nn.sigmoid(x_logits)
loss_mse = tf.reduce_mean((data_x - x_probs)**2)
# Bernoulli log-likelihood: -log p(x|z)
# For Bernoulli: p(x) = x * p + (1-x) * (1-p) where p = sigmoid(logits)
# log p(x) = x * log(p) + (1-x) * log(1-p)
# log p(x) = x * logits - log(1 + exp(logits)) (using log-sum-exp trick)
x_logits = tf.clip_by_value(x_logits, -10, 10) # Prevent overflow
log_px_z = tf.reduce_sum(
data_x * x_logits - tf.nn.softplus(x_logits),
axis=[1, 2, 3] # Sum over spatial dimensions
)
loss_x = -tf.reduce_mean(log_px_z) # Negative log-likelihood
if self.params['use_bnn']:
loss_kl = sum(self.g_net.losses)
loss_x += loss_kl * self.params['kl_weight']
# Calculate the gradients for generators
g_gradients = gen_tape.gradient(loss_x, self.g_net.trainable_variables)
# Apply the gradients to the optimizer
self.g_optimizer.apply_gradients(zip(g_gradients, self.g_net.trainable_variables))
return loss_x, loss_mse
# Update posterior of latent variables Z
@tf.function
def update_latent_variable_sgd(self, data_z, data_x):
with tf.GradientTape() as tape:
# logp(x|z) for Bernoulli model
mu_x, sigma_square_x = self.g_net(data_z)
x_logits = self.g_net.reparameterize(mu_x, sigma_square_x)
# Bernoulli log-likelihood: -log p(x|z)
x_logits = tf.clip_by_value(x_logits, -10, 10) # Prevent overflow
log_px_z = tf.reduce_sum(
data_x * x_logits - tf.nn.softplus(x_logits),
axis=[1, 2, 3] # Sum over spatial dimensions
)
loss_px_z = -tf.reduce_mean(log_px_z) # Negative log-likelihood
loss_prior_z = tf.reduce_sum(data_z**2, axis=1)/2
loss_prior_z = tf.reduce_mean(loss_prior_z)
loss_postrior_z = loss_px_z + loss_prior_z
# Calculate the gradients
posterior_gradients = tape.gradient(loss_postrior_z, [data_z])
# Apply the gradients to the optimizer
self.posterior_optimizer.apply_gradients(zip(posterior_gradients, [data_z]))
return loss_postrior_z
#################################### EGM initialization ###########################################
@tf.function
def train_disc_step(self, data_z, data_x):
"""train discriminators step for MNIST image data.
Args:
data_z: Latent tensor with shape [batch_size, z_dim].
data_x: Image tensor with shape [batch_size, 28, 28, 1].
Returns:
returns various discriminator loss functions.
"""
epsilon_z = tf.random.uniform([],minval=0., maxval=1.)
epsilon_x = tf.random.uniform([],minval=0., maxval=1.)
with tf.GradientTape(persistent=True) as disc_tape:
with tf.GradientTape() as gpz_tape:
data_z_ = self.e_net(data_x)
data_z_hat = data_z*epsilon_z + data_z_*(1-epsilon_z)
data_dz_hat = self.dz_net(data_z_hat)
with tf.GradientTape() as gpx_tape:
mu_x_, sigma_square_x_ = self.g_net(data_z)
x_logits_ = self.g_net.reparameterize(mu_x_, sigma_square_x_)
data_x_ = tf.nn.sigmoid(x_logits_)
data_x_hat = data_x*epsilon_x + data_x_*(1-epsilon_x)
data_dx_hat = self.dx_net(data_x_hat)
data_dx_ = self.dx_net(data_x_)
data_dz_ = self.dz_net(data_z_)
data_dx = self.dx_net(data_x)
data_dz = self.dz_net(data_z)
dz_loss = (tf.reduce_mean((0.9*tf.ones_like(data_dz) - data_dz)**2) \
+tf.reduce_mean((0.1*tf.ones_like(data_dz_) - data_dz_)**2))/2.0
dx_loss = (tf.reduce_mean((0.9*tf.ones_like(data_dx) - data_dx)**2) \
+tf.reduce_mean((0.1*tf.ones_like(data_dx_) - data_dx_)**2))/2.0
#gradient penalty for z
grad_z = gpz_tape.gradient(data_dz_hat, data_z_hat)
grad_norm_z = tf.sqrt(tf.reduce_sum(tf.square(grad_z), axis=1))#(bs,)
gpz_loss = tf.reduce_mean(tf.square(grad_norm_z - 1.0))
#gradient penalty for x (spatial dimensions for images)
grad_x = gpx_tape.gradient(data_dx_hat, data_x_hat)
grad_norm_x = tf.sqrt(tf.reduce_sum(tf.square(grad_x), axis=[1, 2, 3]))#(bs,)
gpx_loss = tf.reduce_mean(tf.square(grad_norm_x - 1.0))
d_loss = dx_loss + dz_loss + \
self.params['gamma']*(gpz_loss + gpx_loss)
# Calculate the gradients for discriminators
d_gradients = disc_tape.gradient(d_loss, self.dz_net.trainable_variables+self.dx_net.trainable_variables)
# Apply the gradients to the optimizer
self.d_pre_optimizer.apply_gradients(zip(d_gradients, self.dz_net.trainable_variables+self.dx_net.trainable_variables))
return dz_loss, dx_loss, d_loss
@tf.function
def train_gen_step(self, data_z, data_x):
"""train generators step for MNIST image data.
Args:
data_z: Latent tensor with shape [batch_size, z_dim].
data_x: Image tensor with shape [batch_size, 28, 28, 1].
Returns:
returns various generator loss functions.
"""
with tf.GradientTape(persistent=True) as gen_tape:
mu_x_, sigma_square_x_ = self.g_net(data_z)
x_logits_ = self.g_net.reparameterize(mu_x_, sigma_square_x_)
data_x_ = tf.nn.sigmoid(x_logits_)
reg_loss = tf.reduce_mean(tf.square(sigma_square_x_))
data_z_ = self.e_net(data_x)
data_z__= self.e_net(data_x_)
mu_x__, sigma_square_x__ = self.g_net(data_z_)
x_logits__ = self.g_net.reparameterize(mu_x__, sigma_square_x__)
data_x__ = tf.nn.sigmoid(x_logits__)
data_dx_ = self.dx_net(data_x_)
data_dz_ = self.dz_net(data_z_)
l2_loss_x = tf.reduce_mean((data_x - data_x__)**2)
l2_loss_z = tf.reduce_mean((data_z - data_z__)**2)
g_loss_adv = tf.reduce_mean((0.9*tf.ones_like(data_dx_) - data_dx_)**2)
e_loss_adv = tf.reduce_mean((0.9*tf.ones_like(data_dz_) - data_dz_)**2)
g_e_loss = g_loss_adv + e_loss_adv + 10 * (l2_loss_x + l2_loss_z) + self.params['alpha'] * reg_loss
# Calculate the gradients for generators
g_e_gradients = gen_tape.gradient(g_e_loss, self.g_net.trainable_variables+self.e_net.trainable_variables)
# Apply the gradients to the optimizer
self.g_pre_optimizer.apply_gradients(zip(g_e_gradients, self.g_net.trainable_variables+self.e_net.trainable_variables))
return g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, reg_loss, g_e_loss
def egm_init(self, data, egm_n_iter=10000, batch_size=32, egm_batches_per_eval=500, verbose=1):
self.data_sampler = Base_sampler(x=data,y=data,v=data, batch_size=batch_size, normalize=False)
print('EGM Initialization Starts ...')
for batch_iter in range(egm_n_iter+1):
# Update model parameters of Discriminator
for _ in range(self.params['g_d_freq']):
batch_x,_,_ = self.data_sampler.next_batch()
batch_z = self.z_sampler.get_batch(batch_size)
dz_loss, dx_loss, d_loss = self.train_disc_step(batch_z, batch_x)
# Update model parameters of G,E with SGD
batch_x,_,_ = self.data_sampler.next_batch()
batch_z = self.z_sampler.get_batch(batch_size)
g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, sigma_square_loss, g_e_loss = self.train_gen_step(batch_z, batch_x)
if batch_iter % egm_batches_per_eval == 0:
loss_contents = (
'EGM Initialization Iter [%d] : g_loss_adv[%.4f], e_loss_adv [%.4f], l2_loss_z [%.4f], l2_loss_x [%.4f], '
'sd^2_loss[%.4f], g_e_loss [%.4f], dz_loss [%.4f], dx_loss[%.4f], d_loss [%.4f]'
% (batch_iter, g_loss_adv, e_loss_adv, l2_loss_z, l2_loss_x, sigma_square_loss, g_e_loss, dz_loss, dx_loss, d_loss)
)
if verbose:
print(loss_contents)
data_z_ = self.e_net(data)
mu_x__, sigma_square_x__ = self.g_net(data_z_)
x_logits__ = self.g_net.reparameterize(mu_x__, sigma_square_x__)
data_x__ = tf.nn.sigmoid(x_logits__)
MSE = tf.reduce_mean((data - data_x__)**2)
data_gen = self.generate(nb_samples=5000)
np.savez('%s/init_data_gen_at_%d.npz'%(self.save_dir, batch_iter),
data_gen=data_gen, z=data_z_, x_rec=data_x__)
print('MSE_x', MSE.numpy())
mse_x = self.evaluate(data = data)
print('iter [%d/%d]: MSE_x: %.4f\n' % (batch_iter, egm_n_iter, mse_x))
if self.params['save_model']:
base_path = self.checkpoint_path + f"/weights_at_egm_init_{batch_iter}"
self.e_net.save_weights(f"{base_path}_encoder.weights.h5")
self.g_net.save_weights(f"{base_path}_generator.weights.h5")
print('Saving checkpoint for egm_init at {}'.format(base_path))
print('EGM Initialization Ends.')
#################################### EGM initialization #############################################
[docs]
def fit(self, data,
batch_size=32, epochs=100, epochs_per_eval=5,
use_egm_init=True, egm_n_iter=10000, egm_batches_per_eval=500, verbose=1):
"""Train the MNIST BGM model on image data.
Parameters
----------
data : np.ndarray
MNIST image array with shape ``(n, 28, 28, 1)``, values in [0, 1].
batch_size : int, default=32
Mini-batch size.
epochs : int, default=100
Number of training epochs for the iterative phase.
epochs_per_eval : int, default=5
Evaluate and (optionally) save every this many epochs.
use_egm_init : bool, default=True
Whether to run EGM initialization before iterative training.
egm_n_iter : int, default=10000
Number of EGM initialization iterations.
egm_batches_per_eval : int, default=500
Evaluate EGM every this many iterations.
verbose : int, default=1
Verbosity level. Set to 0 to suppress progress messages.
"""
if self.params['save_res']:
f_params = open('{}/params.txt'.format(self.save_dir),'w')
f_params.write(str(self.params))
f_params.close()
if use_egm_init:
self.egm_init(data, egm_n_iter=egm_n_iter, egm_batches_per_eval=egm_batches_per_eval, batch_size=batch_size, verbose=verbose)
print('Initialize latent variables Z with e(V)...')
data_z_init = self.e_net(data)
else:
print('Random initialization of latent variables Z...')
data_z_init = np.random.normal(0, 1, size = (len(data), self.params['z_dim'])).astype('float32')
self.data_z = tf.Variable(data_z_init, name="Latent Variable",trainable=True)
self.history_loss = []
print('Iterative Updating Starts ...')
for epoch in range(epochs+1):
sample_idx = np.random.choice(len(data), len(data), replace=False)
# Create a progress bar for batches
with tqdm(total=len(data) // batch_size, desc=f"Epoch {epoch}/{epochs}", unit="batch") as batch_bar:
for i in range(0,len(data) - batch_size + 1,batch_size): ## Skip the incomplete last batch
batch_idx = sample_idx[i:i+batch_size]
# Update model parameters of G, H, F with SGD
batch_z = tf.Variable(tf.gather(self.data_z, batch_idx, axis = 0), name='batch_z', trainable=True)
batch_x = data[batch_idx,:]
loss_x, loss_mse_x = self.update_g_net(batch_z, batch_x)
# Update Z by maximizing a posterior or posterior mean
loss_postrior_z = self.update_latent_variable_sgd(batch_z, batch_x)
# Update data_z with updated batch_z
self.data_z.scatter_nd_update(
indices=tf.expand_dims(batch_idx, axis=1),
updates=batch_z
)
# Update the progress bar with the current loss information
loss_contents = (
'loss_x: [%.4f], loss_mse_x: [%.4f], loss_postrior_z: [%.4f]'
% (loss_x, loss_mse_x, loss_postrior_z)
)
batch_bar.set_postfix_str(loss_contents)
batch_bar.update(1)
# Evaluate the full training data and print metrics for the epoch
if epoch % epochs_per_eval == 0:
mse_x = self.evaluate(data = data, data_z = self.data_z)
self.history_loss.append(mse_x)
if verbose:
print('Epoch [%d/%d]: MSE_x: %.4f\n' % (epoch, epochs, mse_x))
if self.params['save_model']:
base_path = self.checkpoint_path + f"/weights_at_{epoch}"
self.g_net.save_weights(f"{base_path}_generator.weights.h5")
print('Saving checkpoint for epoch {} at {}'.format(epoch, base_path))
data_gen = self.generate(nb_samples=5000)
if self.params['save_res']:
np.savez('%s/data_gen_at_%d.npz'%(self.save_dir, epoch),
gen=data_gen,
z=self.data_z.numpy()
)
@tf.function
def evaluate(self, data, data_z=None):
"""Compute the mean squared error between observed and reconstructed MNIST images.
Parameters
----------
data : np.ndarray or tf.Tensor
Observed images with shape ``(n, 28, 28, 1)``.
data_z : tf.Tensor or None, optional
Latent variables with shape ``(n, z_dim)``.
If ``None``, the encoder is used to infer them.
Returns
-------
mse_x : tf.Tensor
Scalar mean squared error.
"""
if data_z is None:
data_z = self.e_net(data, training=False)
mu_x, sigma_square_x = self.g_net(data_z, training=False)
x_logits = self.g_net.reparameterize(mu_x, sigma_square_x)
data_x_pred = tf.nn.sigmoid(x_logits)
mse_x = tf.reduce_mean((data-data_x_pred)**2)
return mse_x
@tf.function
def generate(self, nb_samples=1000):
"""Generate synthetic MNIST images from the trained model.
Samples latent codes from the standard normal prior and decodes them
through the convolutional generator.
Parameters
----------
nb_samples : int, default=1000
Number of images to generate.
Returns
-------
data_x_pred : tf.Tensor
Generated images with shape ``(nb_samples, 28, 28, 1)``,
pixel values in [0, 1].
"""
data_z = tf.random.normal(shape=(nb_samples, self.params['z_dim']), mean=0.0, stddev=1.0)
mu_x, sigma_square_x = self.g_net(data_z, training=False)
x_logits = self.g_net.reparameterize(mu_x, sigma_square_x)
data_x_pred = tf.nn.sigmoid(x_logits)
return data_x_pred
@tf.function
def predict_on_posteriors(self, data_posterior_z):
n_mcmc = tf.shape(data_posterior_z)[0]
n_samples = tf.shape(data_posterior_z)[1]
# Flatten data
data_posterior_z_flat = tf.reshape(data_posterior_z, [-1, self.params['z_dim']]) # (n_mcmc * n_samples, z_dim)
mu_x_flat, sigma_square_x_flat = self.g_net(data_posterior_z_flat) # (n_mcmc*n_samples, 28, 28, 1)
x_logits_flat = self.g_net.reparameterize(mu_x_flat, sigma_square_x_flat)
data_x_pred_flat = tf.nn.sigmoid(x_logits_flat)
data_x_pred = tf.reshape(data_x_pred_flat, [n_mcmc, n_samples, 28, 28, 1])
return data_x_pred
[docs]
def predict(self, data, alpha=0.05, return_samples=False, bs=100, n_mcmc=5000, burn_in=5000, step_size=0.01, num_leapfrog_steps=10, seed=42):
"""
Predict the posterior distribution of P(x2|x1) for MNIST images.
Parameters
----------
data : np.ndarray or tf.Tensor
Observed data with shape ``(n, 28, 28, 1)``.
Missing pixels should be encoded as ``np.nan``.
alpha : float, default=0.05
Significance level for prediction intervals.
return_samples : bool, default=False
If ``False``, return imputed images with shape ``(n, 28, 28, 1)``.
If ``True``, return posterior samples with shape
``(n_mcmc, n, 28, 28, 1)``.
bs : int, default=100
Batch size for posterior prediction.
n_mcmc : int, default=5000
Number of retained MCMC samples.
burn_in : int, default=5000
Number of burn-in iterations.
step_size : float, default=0.01
HMC step size.
num_leapfrog_steps : int, default=10
Number of leapfrog steps in HMC.
seed : int, default=42
Random seed.
Returns
-------
data_x_pred : np.ndarray
Imputed images if ``return_samples=False`` with shape
``(n, 28, 28, 1)``. Posterior predictive samples if
``return_samples=True`` with shape ``(n_mcmc, n, 28, 28, 1)``.
pred_interval : np.ndarray or list[np.ndarray]
Prediction intervals on missing pixels.
For a shared missing pattern, shape is
``(n, n_missing_pixels, 2)``.
Otherwise, this is a per-sample list where element ``i`` has shape
``(n_missing_pixels_i, 2)``.
"""
assert 0 < alpha < 1, "The significance level 'alpha' must be greater than 0 and less than 1."
if not isinstance(data, tf.Tensor):
data_tf = tf.convert_to_tensor(data, dtype=tf.float32)
else:
data_tf = tf.cast(data, tf.float32)
# Shape: (n, 28, 28, 1)
n_data_samples = data_tf.shape[0]
# Boolean mask of missingness (True where NaN)
is_nan_tf = tf.math.is_nan(data_tf)
# Observed mask (True where not NaN)
is_obs_tf = tf.logical_not(is_nan_tf)
# We'll still feed some numeric value at missing locations; they are ignored via indices.
data_clean_tf = tf.where(is_nan_tf,
tf.zeros_like(data_tf),
data_tf)
# Flatten observed mask to build per-sample index lists
is_obs_flat_tf = tf.reshape(is_obs_tf, [n_data_samples, -1])
is_obs_flat_np = is_obs_flat_tf.numpy()
# Build ind_x1 as list-of-lists of observed pixel indices
ind_x1_list = [
np.where(row)[0].tolist()
for row in is_obs_flat_np
]
data_posterior_z = self.tfp_mcmc_sampler(
data=data_clean_tf,
ind_x1=ind_x1_list,
n_mcmc=n_mcmc,
burn_in=burn_in,
step_size=step_size,
num_leapfrog_steps=num_leapfrog_steps,
seed=seed
)
# data_posterior_z: (n_mcmc, n_data_samples, z_dim)
data_x_pred_all = []
# Loop over data dimension in batches
for i in range(0, n_data_samples, bs):
batch_posterior_z = data_posterior_z[:, i:i + bs, :] # (n_mcmc, bs_i, z_dim)
data_x_batch_pred = self.predict_on_posteriors(batch_posterior_z)
# Expected shape: (n_mcmc, bs_i, 28, 28, 1)
data_x_batch_pred = data_x_batch_pred.numpy()
data_x_pred_all.append(data_x_batch_pred)
# Concatenate along data dimension
data_x_pred_all = np.concatenate(data_x_pred_all, axis=1)
# Shape: (n_mcmc, n_data_samples, 28, 28, 1)
data_np = data_tf.numpy()
miss_mask_full = np.isnan(data_np).astype(np.float32)
obs_mask_full = 1.0 - miss_mask_full
data_obs_np = np.nan_to_num(data_np, nan=0.0)
# Compute prediction intervals on missing pixels only
n_mcmc_samples = data_x_pred_all.shape[0]
flat_pred = data_x_pred_all.reshape(n_mcmc_samples,
n_data_samples,
-1) # (n_mcmc, n, 784)
miss_mask_flat = miss_mask_full.reshape(n_data_samples, -1).astype(bool)
# Check if all samples share the same missing pattern
same_pattern = np.all(miss_mask_flat == miss_mask_flat[0])
if same_pattern:
# Common missing pattern across all samples
miss_idx = np.where(miss_mask_flat[0])[0] # (N_missed_pixels,)
if miss_idx.size == 0:
# No missing pixels at all
pred_interval = np.zeros((n_data_samples, 0, 2), dtype=np.float32)
else:
# Gather only missing pixel samples
pix_samples = flat_pred[:, :, miss_idx] # (n_mcmc, n, N_missed_pixels)
lower = np.quantile(pix_samples, alpha / 2.0, axis=0) # (n, N_missed_pixels)
upper = np.quantile(pix_samples, 1.0 - alpha / 2.0, axis=0) # (n, N_missed_pixels)
pred_interval = np.stack([lower, upper], axis=-1) # (n, N_missed_pixels, 2)
else:
# Different missing patterns; return a list of per-sample intervals
pred_interval = []
for i in range(n_data_samples):
miss_idx_i = np.where(miss_mask_flat[i])[0] # (N_missed_pixels_i,)
if miss_idx_i.size == 0:
# This sample has no missing pixels
pred_interval.append(np.zeros((0, 2), dtype=np.float32))
continue
pix_samples_i = flat_pred[:, i, miss_idx_i] # (n_mcmc, N_missed_pixels_i)
lower_i = np.quantile(pix_samples_i, alpha / 2.0, axis=0) # (N_missed_pixels_i,)
upper_i = np.quantile(pix_samples_i, 1.0 - alpha / 2.0, axis=0) # (N_missed_pixels_i,)
intervals_i = np.stack([lower_i, upper_i], axis=-1) # (N_missed_pixels_i, 2)
pred_interval.append(intervals_i)
if return_samples:
return data_x_pred_all, pred_interval
else:
# Return single imputed dataset: posterior mean across MCMC samples
# Shape: (n, 28, 28, 1) — observed pixels intact, missing filled with posterior means
data_imputed = np.mean(data_x_pred_all, axis=0)
# Ensure observed pixels are exactly the original (avoid floating-point drift)
data_imputed = miss_mask_full * data_imputed + obs_mask_full * data_obs_np
return data_imputed, pred_interval
@tf.function
def get_log_posterior(self, data_z, data_x, ind_x1=None, obs_mask=None):
"""
Calculate log posterior using Bernoulli likelihood for MNIST images.
data_z: (tf.Tensor): Input data with shape (n, z_dim).
data_x: (tf.Tensor): (n, 28, 28, 1) or (n, 784) full images;
missing pixels are ignored via indices/mask.
ind_x1: None, or int32 Tensor of shape (n, K_max) with pixel indices
for each sample (padded where obs_mask == 0).
obs_mask: None, or float32 Tensor of shape (n, K_max),
1 for real observed indices, 0 for padding.
return (tf.Tensor): Log posterior with shape (n, ).
"""
mu_x, sigma_square_x = self.g_net(data_z)
x_logits = self.g_net.reparameterize(mu_x, sigma_square_x)
# Clip logits to prevent overflow
x_logits = tf.clip_by_value(x_logits, -10, 10)
# Flatten both tensors for indexing
batch_size = tf.shape(data_x)[0]
data_x_flat = tf.reshape(data_x, [batch_size, -1]) # (n, 784)
x_logits_flat = tf.reshape(x_logits, [batch_size, -1]) # (n, 784)
# Bernoulli likelihood term log p(x_obs | z)
if ind_x1 is None:
ll_term = data_x_flat * x_logits_flat - tf.nn.softplus(x_logits_flat)
log_px_z = tf.reduce_sum(ll_term, axis=1)
else:
# Gather observed pixels per-sample
# ind_x1: (n, K_max), obs_mask: (n, K_max)
data_x_cond = tf.gather(data_x_flat, ind_x1, batch_dims=1) # (n, K_max)
x_logits_cond = tf.gather(x_logits_flat, ind_x1, batch_dims=1) # (n, K_max)
ll_term = data_x_cond * x_logits_cond - tf.nn.softplus(x_logits_cond) # (n, K_max)
if obs_mask is not None:
ll_term = ll_term * obs_mask # zero out padded positions
log_px_z = tf.reduce_sum(ll_term, axis=1) # (n,)
log_prior_z = -0.5 * tf.reduce_sum(data_z**2, axis=1)
return log_prior_z + log_px_z